commit f6e1f8398b (parent d544eead38)
Dogtiti, 2024-11-06 22:07:33 +08:00
7 changed files with 435 additions and 18 deletions

realtime-chat.module.scss

@@ -24,12 +24,19 @@
.bottom-icons {
display: flex;
justify-content: space-between;
align-items: center;
width: 100%;
position: absolute;
bottom: 20px;
box-sizing: border-box;
padding: 0 20px;
}
.icon-center {
display: flex;
justify-content: center;
align-items: center;
gap: 4px;
}
.icon-left,
.icon-right {

realtime-chat.tsx

@@ -1,34 +1,220 @@
import VoiceIcon from "@/app/icons/voice.svg";
import VoiceOffIcon from "@/app/icons/voice-off.svg";
import Close24Icon from "@/app/icons/close-24.svg";
import PowerIcon from "@/app/icons/power.svg";
import styles from "./realtime-chat.module.scss";
import clsx from "clsx";
-import { useState, useRef, useCallback } from "react";
import { useState, useRef, useCallback, useEffect } from "react";
import { useAccessStore, useChatStore, ChatMessage } from "@/app/store";
import { IconButton } from "@/app/components/button";
import {
Modality,
RTClient,
RTInputAudioItem,
RTResponse,
TurnDetection,
} from "rt-client";
import { AudioHandler } from "@/app/lib/audio";
interface RealtimeChatProps {
onClose?: () => void;
onStartVoice?: () => void;
onPausedVoice?: () => void;
sampleRate?: number;
}
export function RealtimeChat({
onClose,
onStartVoice,
onPausedVoice,
sampleRate = 24000,
}: RealtimeChatProps) {
const [isVoicePaused, setIsVoicePaused] = useState(true);
-const clientRef = useRef<null>(null);
const currentItemId = useRef<string>("");
const currentBotMessage = useRef<ChatMessage | null>();
const currentUserMessage = useRef<ChatMessage | null>();
const accessStore = useAccessStore.getState();
const chatStore = useChatStore();
const [isRecording, setIsRecording] = useState(false);
const [isConnected, setIsConnected] = useState(false);
const [isConnecting, setIsConnecting] = useState(false);
const [modality, setModality] = useState("audio");
const [isAzure, setIsAzure] = useState(false);
const [endpoint, setEndpoint] = useState("");
const [deployment, setDeployment] = useState("");
const [useVAD, setUseVAD] = useState(true);
const clientRef = useRef<RTClient | null>(null);
const audioHandlerRef = useRef<AudioHandler | null>(null);
const apiKey = accessStore.openaiApiKey;
const handleConnect = async () => {
if (!isConnected) {
try {
setIsConnecting(true);
clientRef.current = isAzure
? new RTClient(new URL(endpoint), { key: apiKey }, { deployment })
: new RTClient(
{ key: apiKey },
{ model: "gpt-4o-realtime-preview-2024-10-01" },
);
const modalities: Modality[] =
modality === "audio" ? ["text", "audio"] : ["text"];
const turnDetection: TurnDetection = useVAD
? { type: "server_vad" }
: null;
clientRef.current.configure({
instructions: "Hi",
input_audio_transcription: { model: "whisper-1" },
turn_detection: turnDetection,
tools: [],
temperature: 0.9,
modalities,
});
startResponseListener();
setIsConnected(true);
} catch (error) {
console.error("Connection failed:", error);
} finally {
setIsConnecting(false);
}
} else {
await disconnect();
}
};
const disconnect = async () => {
if (clientRef.current) {
try {
await clientRef.current.close();
clientRef.current = null;
setIsConnected(false);
} catch (error) {
console.error("Disconnect failed:", error);
}
}
};
const startResponseListener = async () => {
if (!clientRef.current) return;
try {
for await (const serverEvent of clientRef.current.events()) {
if (serverEvent.type === "response") {
await handleResponse(serverEvent);
} else if (serverEvent.type === "input_audio") {
await handleInputAudio(serverEvent);
}
}
} catch (error) {
if (clientRef.current) {
console.error("Response iteration error:", error);
}
}
};
const handleResponse = async (response: RTResponse) => {
for await (const item of response) {
if (item.type === "message" && item.role === "assistant") {
const message = {
type: item.role,
content: "",
};
// setMessages((prevMessages) => [...prevMessages, message]);
for await (const content of item) {
if (content.type === "text") {
for await (const text of content.textChunks()) {
message.content += text;
// setMessages((prevMessages) => {
// prevMessages[prevMessages.length - 1].content = message.content;
// return [...prevMessages];
// });
}
} else if (content.type === "audio") {
const textTask = async () => {
for await (const text of content.transcriptChunks()) {
message.content += text;
// setMessages((prevMessages) => {
// prevMessages[prevMessages.length - 1].content =
// message.content;
// return [...prevMessages];
// });
}
};
const audioTask = async () => {
audioHandlerRef.current?.startStreamingPlayback();
for await (const audio of content.audioChunks()) {
audioHandlerRef.current?.playChunk(audio);
}
};
await Promise.all([textTask(), audioTask()]);
}
}
}
}
};
const handleInputAudio = async (item: RTInputAudioItem) => {
audioHandlerRef.current?.stopStreamingPlayback();
await item.waitForCompletion();
// setMessages((prevMessages) => [
// ...prevMessages,
// {
// type: "user",
// content: item.transcription || "",
// },
// ]);
};
const toggleRecording = async () => {
if (!isRecording && clientRef.current) {
try {
if (!audioHandlerRef.current) {
audioHandlerRef.current = new AudioHandler();
await audioHandlerRef.current.initialize();
}
await audioHandlerRef.current.startRecording(async (chunk) => {
await clientRef.current?.sendAudio(chunk);
});
setIsRecording(true);
} catch (error) {
console.error("Failed to start recording:", error);
}
} else if (audioHandlerRef.current) {
try {
audioHandlerRef.current.stopRecording();
if (!useVAD) {
const inputAudio = await clientRef.current?.commitAudio();
await handleInputAudio(inputAudio!);
await clientRef.current?.generateResponse();
}
setIsRecording(false);
} catch (error) {
console.error("Failed to stop recording:", error);
}
}
};
useEffect(() => {
const initAudioHandler = async () => {
const handler = new AudioHandler();
await handler.initialize();
audioHandlerRef.current = handler;
};
initAudioHandler().catch(console.error);
return () => {
disconnect();
audioHandlerRef.current?.close().catch(console.error);
};
}, []);
// useEffect(() => {
// if (
// clientRef.current?.getTurnDetectionType() === "server_vad" &&
@@ -223,12 +409,16 @@ export function RealtimeChat({
const handleStartVoice = useCallback(() => {
onStartVoice?.();
setIsVoicePaused(false);
handleConnect();
}, []);
const handlePausedVoice = () => {
onPausedVoice?.();
setIsVoicePaused(true);
};
const handleClose = () => {
onClose?.();
disconnect();
};
return (
@@ -241,15 +431,39 @@ export function RealtimeChat({
<div className={styles["icon-center"]}></div>
</div>
<div className={styles["bottom-icons"]}>
<div className={styles["icon-left"]}>
{isVoicePaused ? (
<VoiceOffIcon onClick={handleStartVoice} />
) : (
<VoiceIcon onClick={handlePausedVoice} />
)}
<div>
<IconButton
icon={isRecording ? <VoiceOffIcon /> : <VoiceIcon />}
onClick={toggleRecording}
disabled={!isConnected}
bordered
shadow
/>
</div>
<div className={styles["icon-right"]} onClick={onClose}>
<Close24Icon />
<div className={styles["icon-center"]}>
<IconButton
icon={<PowerIcon />}
text={
isConnecting
? "Connecting..."
: isConnected
? "Disconnect"
: "Connect"
}
onClick={handleConnect}
disabled={isConnecting}
bordered
shadow
/>
</div>
<div onClick={handleClose}>
<IconButton
icon={<Close24Icon />}
onClick={handleClose}
disabled={!isConnected}
bordered
shadow
/>
</div>
</div>
</div>
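The component above boils down to one lifecycle: handleConnect constructs an RTClient (Azure endpoint/deployment, or a plain OpenAI key), configure() sets instructions, modalities, Whisper transcription, and optional server-side VAD, and startResponseListener then drains the events() async iterator, routing "response" events to handleResponse and "input_audio" events to handleInputAudio. Stripped of React state, the same flow looks like this (a minimal sketch that mirrors the calls in the diff; apiKey and the loop bodies are placeholders):

import { Modality, RTClient, TurnDetection } from "rt-client";

async function runRealtimeSession(apiKey: string) {
  // Non-Azure construction path, as in handleConnect above.
  const client = new RTClient(
    { key: apiKey },
    { model: "gpt-4o-realtime-preview-2024-10-01" },
  );
  const modalities: Modality[] = ["text", "audio"];
  const turnDetection: TurnDetection = { type: "server_vad" };
  client.configure({
    instructions: "Hi",
    input_audio_transcription: { model: "whisper-1" },
    turn_detection: turnDetection,
    tools: [],
    temperature: 0.9,
    modalities,
  });
  // One loop serves the whole session: model output arrives as
  // "response" events, transcribed user turns as "input_audio".
  for await (const event of client.events()) {
    if (event.type === "response") {
      // stream text/audio content items here (see handleResponse)
    } else if (event.type === "input_audio") {
      await event.waitForCompletion();
    }
  }
  await client.close();
}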

app/icons/power.svg (new file, 7 lines)

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="24" height="24" viewBox="0 0 48 48" fill="none" xmlns="http://www.w3.org/2000/svg">
<path
d="M14.5 8C13.8406 8.37652 13.2062 8.79103 12.6 9.24051C11.5625 10.0097 10.6074 10.8814 9.75 11.8402C6.79377 15.1463 5 19.4891 5 24.2455C5 34.6033 13.5066 43 24 43C34.4934 43 43 34.6033 43 24.2455C43 19.4891 41.2062 15.1463 38.25 11.8402C37.3926 10.8814 36.4375 10.0097 35.4 9.24051C34.7938 8.79103 34.1594 8.37652 33.5 8"
stroke="#333" stroke-width="4" stroke-linecap="round" stroke-linejoin="round" />
<path d="M24 4V24" stroke="#333" stroke-width="4" stroke-linecap="round" stroke-linejoin="round" />
</svg>


app/lib/audio.ts (new file, 134 lines)

@@ -0,0 +1,134 @@
export class AudioHandler {
private context: AudioContext;
private workletNode: AudioWorkletNode | null = null;
private stream: MediaStream | null = null;
private source: MediaStreamAudioSourceNode | null = null;
private readonly sampleRate = 24000;
private nextPlayTime: number = 0;
private isPlaying: boolean = false;
private playbackQueue: AudioBufferSourceNode[] = [];
constructor() {
this.context = new AudioContext({ sampleRate: this.sampleRate });
}
async initialize() {
await this.context.audioWorklet.addModule("/audio-processor.js");
}
async startRecording(onChunk: (chunk: Uint8Array) => void) {
try {
if (!this.workletNode) {
await this.initialize();
}
this.stream = await navigator.mediaDevices.getUserMedia({
audio: {
channelCount: 1,
sampleRate: this.sampleRate,
echoCancellation: true,
noiseSuppression: true,
},
});
await this.context.resume();
this.source = this.context.createMediaStreamSource(this.stream);
this.workletNode = new AudioWorkletNode(
this.context,
"audio-recorder-processor",
);
this.workletNode.port.onmessage = (event) => {
if (event.data.eventType === "audio") {
const float32Data = event.data.audioData;
const int16Data = new Int16Array(float32Data.length);
for (let i = 0; i < float32Data.length; i++) {
const s = Math.max(-1, Math.min(1, float32Data[i]));
int16Data[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
}
const uint8Data = new Uint8Array(int16Data.buffer);
onChunk(uint8Data);
}
};
this.source.connect(this.workletNode);
this.workletNode.connect(this.context.destination);
this.workletNode.port.postMessage({ command: "START_RECORDING" });
} catch (error) {
console.error("Error starting recording:", error);
throw error;
}
}
stopRecording() {
if (!this.workletNode || !this.source || !this.stream) {
throw new Error("Recording not started");
}
this.workletNode.port.postMessage({ command: "STOP_RECORDING" });
this.workletNode.disconnect();
this.source.disconnect();
this.stream.getTracks().forEach((track) => track.stop());
}
startStreamingPlayback() {
this.isPlaying = true;
this.nextPlayTime = this.context.currentTime;
}
stopStreamingPlayback() {
this.isPlaying = false;
this.playbackQueue.forEach((source) => source.stop());
this.playbackQueue = [];
}
playChunk(chunk: Uint8Array) {
if (!this.isPlaying) return;
const int16Data = new Int16Array(chunk.buffer);
const float32Data = new Float32Array(int16Data.length);
for (let i = 0; i < int16Data.length; i++) {
float32Data[i] = int16Data[i] / (int16Data[i] < 0 ? 0x8000 : 0x7fff);
}
const audioBuffer = this.context.createBuffer(
1,
float32Data.length,
this.sampleRate,
);
audioBuffer.getChannelData(0).set(float32Data);
const source = this.context.createBufferSource();
source.buffer = audioBuffer;
source.connect(this.context.destination);
const chunkDuration = audioBuffer.length / this.sampleRate;
source.start(this.nextPlayTime);
this.playbackQueue.push(source);
source.onended = () => {
const index = this.playbackQueue.indexOf(source);
if (index > -1) {
this.playbackQueue.splice(index, 1);
}
};
this.nextPlayTime += chunkDuration;
if (this.nextPlayTime < this.context.currentTime) {
this.nextPlayTime = this.context.currentTime;
}
}
async close() {
this.workletNode?.disconnect();
this.source?.disconnect();
this.stream?.getTracks().forEach((track) => track.stop());
await this.context.close();
}
}
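Two details of AudioHandler deserve a note. The sample conversion is the standard asymmetric PCM16 mapping: negative samples scale by 0x8000 and positive ones by 0x7fff, matching the int16 range [-32768, 32767]. Pulled out as standalone helpers for clarity (a sketch, not part of the commit):

// Float32 [-1, 1] -> little-endian PCM16 bytes, as in startRecording.
function float32ToPcm16(float32Data: Float32Array): Uint8Array {
  const int16Data = new Int16Array(float32Data.length);
  for (let i = 0; i < float32Data.length; i++) {
    const s = Math.max(-1, Math.min(1, float32Data[i]));
    int16Data[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
  }
  return new Uint8Array(int16Data.buffer);
}

// PCM16 bytes -> Float32, the inverse used by playChunk.
function pcm16ToFloat32(chunk: Uint8Array): Float32Array {
  const int16Data = new Int16Array(chunk.buffer);
  const float32Data = new Float32Array(int16Data.length);
  for (let i = 0; i < int16Data.length; i++) {
    float32Data[i] = int16Data[i] / (int16Data[i] < 0 ? 0x8000 : 0x7fff);
  }
  return float32Data;
}

The playback side is simple gapless scheduling: each chunk lasts audioBuffer.length / sampleRate seconds (100 ms for a 2400-sample chunk at 24 kHz), nextPlayTime advances by exactly that amount per chunk, and the final clamp against context.currentTime keeps the queue from scheduling into the past after an underrun.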

package.json

@@ -52,7 +52,8 @@
"sass": "^1.59.2",
"spark-md5": "^3.0.2",
"use-debounce": "^9.0.4",
"zustand": "^4.3.8"
"zustand": "^4.3.8",
"rt-client": "https://github.com/Azure-Samples/aoai-realtime-audio-sdk/releases/download/js/v0.5.0/rt-client-0.5.0.tgz"
},
"devDependencies": {
"@tauri-apps/api": "^1.6.0",

public/audio-processor.js (new file, 48 lines)

@@ -0,0 +1,48 @@
// @ts-nocheck
class AudioRecorderProcessor extends AudioWorkletProcessor {
constructor() {
super();
this.isRecording = false;
this.bufferSize = 2400; // 100ms at 24kHz
this.currentBuffer = [];
this.port.onmessage = (event) => {
if (event.data.command === "START_RECORDING") {
this.isRecording = true;
} else if (event.data.command === "STOP_RECORDING") {
this.isRecording = false;
if (this.currentBuffer.length > 0) {
this.sendBuffer();
}
}
};
}
sendBuffer() {
if (this.currentBuffer.length > 0) {
const audioData = new Float32Array(this.currentBuffer);
this.port.postMessage({
eventType: "audio",
audioData: audioData,
});
this.currentBuffer = [];
}
}
process(inputs) {
const input = inputs[0];
if (input.length > 0 && this.isRecording) {
const audioData = input[0];
this.currentBuffer.push(...audioData);
if (this.currentBuffer.length >= this.bufferSize) {
this.sendBuffer();
}
}
return true;
}
}
registerProcessor("audio-recorder-processor", AudioRecorderProcessor);
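Sizing note for the processor: browsers call process() once per 128-sample render quantum, so the 2400-sample buffer fills after about 19 quanta (2400 / 128 = 18.75), and each posted chunk covers 2400 / 24000 = 100 ms of audio at the 24 kHz context rate. Returning true keeps the worklet alive while recording is off, and the STOP_RECORDING branch flushes any partial buffer so the tail of an utterance is not dropped.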

yarn.lock

@@ -7455,6 +7455,12 @@ robust-predicates@^3.0.0:
resolved "https://registry.npmmirror.com/robust-predicates/-/robust-predicates-3.0.1.tgz#ecde075044f7f30118682bd9fb3f123109577f9a"
integrity sha512-ndEIpszUHiG4HtDsQLeIuMvRsDnn8c8rYStabochtUeCvfuvNptb5TUbVD68LRAILPX7p9nqQGh4xJgn3EHS/g==
"rt-client@https://github.com/Azure-Samples/aoai-realtime-audio-sdk/releases/download/js/v0.5.0/rt-client-0.5.0.tgz":
version "0.5.0"
resolved "https://github.com/Azure-Samples/aoai-realtime-audio-sdk/releases/download/js/v0.5.0/rt-client-0.5.0.tgz#abf2e9a850201e3571b8d36830f77bc52af3de9b"
dependencies:
ws "^8.18.0"
run-parallel@^1.1.9:
version "1.2.0"
resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee"
@@ -8498,9 +8504,9 @@ write-file-atomic@^4.0.2:
imurmurhash "^0.1.4"
signal-exit "^3.0.7"
-ws@^8.11.0:
ws@^8.11.0, ws@^8.18.0:
version "8.18.0"
resolved "https://registry.npmmirror.com/ws/-/ws-8.18.0.tgz#0d7505a6eafe2b0e712d232b42279f53bc289bbc"
resolved "https://registry.yarnpkg.com/ws/-/ws-8.18.0.tgz#0d7505a6eafe2b0e712d232b42279f53bc289bbc"
integrity sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw==
xml-name-validator@^4.0.0: