diff --git a/app/components/realtime-chat/realtime-chat.module.scss b/app/components/realtime-chat/realtime-chat.module.scss
index 6151eb8aa..2d803a13d 100644
--- a/app/components/realtime-chat/realtime-chat.module.scss
+++ b/app/components/realtime-chat/realtime-chat.module.scss
@@ -24,12 +24,19 @@
   .bottom-icons {
     display: flex;
     justify-content: space-between;
+    align-items: center;
     width: 100%;
     position: absolute;
    bottom: 20px;
     box-sizing: border-box;
     padding: 0 20px;
   }
+  .icon-center {
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    gap: 4px;
+  }

   .icon-left,
   .icon-right {
diff --git a/app/components/realtime-chat/realtime-chat.tsx b/app/components/realtime-chat/realtime-chat.tsx
index 3278b5eec..43d76c643 100644
--- a/app/components/realtime-chat/realtime-chat.tsx
+++ b/app/components/realtime-chat/realtime-chat.tsx
@@ -1,34 +1,220 @@
 import VoiceIcon from "@/app/icons/voice.svg";
 import VoiceOffIcon from "@/app/icons/voice-off.svg";
 import Close24Icon from "@/app/icons/close-24.svg";
+import PowerIcon from "@/app/icons/power.svg";
+
 import styles from "./realtime-chat.module.scss";
 import clsx from "clsx";
-import { useState, useRef, useCallback } from "react";
+import { useState, useRef, useCallback, useEffect } from "react";
 import { useAccessStore, useChatStore, ChatMessage } from "@/app/store";
+import { IconButton } from "@/app/components/button";
+
+import {
+  Modality,
+  RTClient,
+  RTInputAudioItem,
+  RTResponse,
+  TurnDetection,
+} from "rt-client";
+import { AudioHandler } from "@/app/lib/audio";
+
 interface RealtimeChatProps {
   onClose?: () => void;
   onStartVoice?: () => void;
   onPausedVoice?: () => void;
-  sampleRate?: number;
 }

 export function RealtimeChat({
   onClose,
   onStartVoice,
   onPausedVoice,
-  sampleRate = 24000,
 }: RealtimeChatProps) {
-  const [isVoicePaused, setIsVoicePaused] = useState(true);
-  const clientRef = useRef(null);
   const currentItemId = useRef("");
   const currentBotMessage = useRef<ChatMessage>();
   const currentUserMessage = useRef<ChatMessage>();
   const accessStore = useAccessStore.getState();
   const chatStore = useChatStore();
+  const [isRecording, setIsRecording] = useState(false);
+  const [isConnected, setIsConnected] = useState(false);
+  const [isConnecting, setIsConnecting] = useState(false);
+  const [modality, setModality] = useState("audio");
+  const [isAzure, setIsAzure] = useState(false);
+  const [endpoint, setEndpoint] = useState("");
+  const [deployment, setDeployment] = useState("");
+  const [useVAD, setUseVAD] = useState(true);
+
+  const clientRef = useRef<RTClient | null>(null);
+  const audioHandlerRef = useRef<AudioHandler | null>(null);
+
+  const apiKey = accessStore.openaiApiKey;
+
+  const handleConnect = async () => {
+    if (!isConnected) {
+      try {
+        setIsConnecting(true);
+        clientRef.current = isAzure
+          ? new RTClient(new URL(endpoint), { key: apiKey }, { deployment })
+          : new RTClient(
+              { key: apiKey },
+              { model: "gpt-4o-realtime-preview-2024-10-01" },
+            );
+        const modalities: Modality[] =
+          modality === "audio" ? ["text", "audio"] : ["text"];
+        const turnDetection: TurnDetection = useVAD
{ type: "server_vad" } + : null; + clientRef.current.configure({ + instructions: "Hi", + input_audio_transcription: { model: "whisper-1" }, + turn_detection: turnDetection, + tools: [], + temperature: 0.9, + modalities, + }); + startResponseListener(); + + setIsConnected(true); + } catch (error) { + console.error("Connection failed:", error); + } finally { + setIsConnecting(false); + } + } else { + await disconnect(); + } + }; + + const disconnect = async () => { + if (clientRef.current) { + try { + await clientRef.current.close(); + clientRef.current = null; + setIsConnected(false); + } catch (error) { + console.error("Disconnect failed:", error); + } + } + }; + + const startResponseListener = async () => { + if (!clientRef.current) return; + + try { + for await (const serverEvent of clientRef.current.events()) { + if (serverEvent.type === "response") { + await handleResponse(serverEvent); + } else if (serverEvent.type === "input_audio") { + await handleInputAudio(serverEvent); + } + } + } catch (error) { + if (clientRef.current) { + console.error("Response iteration error:", error); + } + } + }; + + const handleResponse = async (response: RTResponse) => { + for await (const item of response) { + if (item.type === "message" && item.role === "assistant") { + const message = { + type: item.role, + content: "", + }; + // setMessages((prevMessages) => [...prevMessages, message]); + for await (const content of item) { + if (content.type === "text") { + for await (const text of content.textChunks()) { + message.content += text; + // setMessages((prevMessages) => { + // prevMessages[prevMessages.length - 1].content = message.content; + // return [...prevMessages]; + // }); + } + } else if (content.type === "audio") { + const textTask = async () => { + for await (const text of content.transcriptChunks()) { + message.content += text; + // setMessages((prevMessages) => { + // prevMessages[prevMessages.length - 1].content = + // message.content; + // return [...prevMessages]; + // }); + } + }; + const audioTask = async () => { + audioHandlerRef.current?.startStreamingPlayback(); + for await (const audio of content.audioChunks()) { + audioHandlerRef.current?.playChunk(audio); + } + }; + await Promise.all([textTask(), audioTask()]); + } + } + } + } + }; + + const handleInputAudio = async (item: RTInputAudioItem) => { + audioHandlerRef.current?.stopStreamingPlayback(); + await item.waitForCompletion(); + // setMessages((prevMessages) => [ + // ...prevMessages, + // { + // type: "user", + // content: item.transcription || "", + // }, + // ]); + }; + + const toggleRecording = async () => { + if (!isRecording && clientRef.current) { + try { + if (!audioHandlerRef.current) { + audioHandlerRef.current = new AudioHandler(); + await audioHandlerRef.current.initialize(); + } + await audioHandlerRef.current.startRecording(async (chunk) => { + await clientRef.current?.sendAudio(chunk); + }); + setIsRecording(true); + } catch (error) { + console.error("Failed to start recording:", error); + } + } else if (audioHandlerRef.current) { + try { + audioHandlerRef.current.stopRecording(); + if (!useVAD) { + const inputAudio = await clientRef.current?.commitAudio(); + await handleInputAudio(inputAudio!); + await clientRef.current?.generateResponse(); + } + setIsRecording(false); + } catch (error) { + console.error("Failed to stop recording:", error); + } + } + }; + + useEffect(() => { + const initAudioHandler = async () => { + const handler = new AudioHandler(); + await handler.initialize(); + audioHandlerRef.current = 
+    };
+
+    initAudioHandler().catch(console.error);
+
+    return () => {
+      disconnect();
+      audioHandlerRef.current?.close().catch(console.error);
+    };
+  }, []);
+
 // useEffect(() => {
 //   if (
 //     clientRef.current?.getTurnDetectionType() === "server_vad" &&
@@ -223,12 +409,16 @@ export function RealtimeChat({

   const handleStartVoice = useCallback(() => {
     onStartVoice?.();
-    setIsVoicePaused(false);
+    handleConnect();
   }, []);

   const handlePausedVoice = () => {
     onPausedVoice?.();
-    setIsVoicePaused(true);
   };
+
+  const handleClose = () => {
+    onClose?.();
+    disconnect();
+  };

   return (
@@ -241,15 +431,39 @@ export function RealtimeChat({
       <div className={styles["bottom-icons"]}>
-        <div className={styles["icon-left"]}>
-          {isVoicePaused ? (
-            <VoiceOffIcon onClick={handleStartVoice} />
-          ) : (
-            <VoiceIcon onClick={handlePausedVoice} />
-          )}
+        <div>
+          <IconButton
+            icon={isRecording ? <VoiceIcon /> : <VoiceOffIcon />}
+            onClick={toggleRecording}
+            disabled={!isConnected}
+            bordered
+            shadow
+          />
         </div>
-        <div className={styles["icon-right"]} onClick={onClose}>
-          <Close24Icon />
+        <div className={styles["icon-center"]}>
+          <IconButton
+            icon={<PowerIcon />}
+            text={
+              isConnecting
+                ? "Connecting..."
+                : isConnected
+                ? "Disconnect"
+                : "Connect"
+            }
+            onClick={handleConnect}
+            disabled={isConnecting}
+            bordered
+            shadow
+          />
+        </div>
+        <div>
+          <IconButton
+            icon={<Close24Icon />}
+            onClick={handleClose}
+            disabled={!isConnected}
+            bordered
+            shadow
+          />
         </div>
       </div>
     </div>
diff --git a/app/icons/power.svg b/app/icons/power.svg
new file mode 100644
index 000000000..f60fc4266
--- /dev/null
+++ b/app/icons/power.svg
@@ -0,0 +1,7 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/app/lib/audio.ts b/app/lib/audio.ts
new file mode 100644
index 000000000..b3674a2c5
--- /dev/null
+++ b/app/lib/audio.ts
@@ -0,0 +1,134 @@
+export class AudioHandler {
+  private context: AudioContext;
+  private workletNode: AudioWorkletNode | null = null;
+  private stream: MediaStream | null = null;
+  private source: MediaStreamAudioSourceNode | null = null;
+  private readonly sampleRate = 24000;
+
+  private nextPlayTime: number = 0;
+  private isPlaying: boolean = false;
+  private playbackQueue: AudioBufferSourceNode[] = [];
+
+  constructor() {
+    this.context = new AudioContext({ sampleRate: this.sampleRate });
+  }
+
+  async initialize() {
+    await this.context.audioWorklet.addModule("/audio-processor.js");
+  }
+
+  async startRecording(onChunk: (chunk: Uint8Array) => void) {
+    try {
+      if (!this.workletNode) {
+        await this.initialize();
+      }
+
+      this.stream = await navigator.mediaDevices.getUserMedia({
+        audio: {
+          channelCount: 1,
+          sampleRate: this.sampleRate,
+          echoCancellation: true,
+          noiseSuppression: true,
+        },
+      });
+
+      await this.context.resume();
+      this.source = this.context.createMediaStreamSource(this.stream);
+      this.workletNode = new AudioWorkletNode(
+        this.context,
+        "audio-recorder-processor",
+      );
+
+      this.workletNode.port.onmessage = (event) => {
+        if (event.data.eventType === "audio") {
+          const float32Data = event.data.audioData;
+          const int16Data = new Int16Array(float32Data.length);
+
+          for (let i = 0; i < float32Data.length; i++) {
+            const s = Math.max(-1, Math.min(1, float32Data[i]));
+            int16Data[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
+          }
+
+          const uint8Data = new Uint8Array(int16Data.buffer);
+          onChunk(uint8Data);
+        }
+      };
+
+      this.source.connect(this.workletNode);
+      this.workletNode.connect(this.context.destination);
+
+      this.workletNode.port.postMessage({ command: "START_RECORDING" });
+    } catch (error) {
+      console.error("Error starting recording:", error);
+      throw error;
+    }
+  }
+
+  stopRecording() {
+    if (!this.workletNode || !this.source || !this.stream) {
+      throw new Error("Recording not started");
+    }
+
+    this.workletNode.port.postMessage({ command: "STOP_RECORDING" });
+
+    this.workletNode.disconnect();
+    this.source.disconnect();
+    this.stream.getTracks().forEach((track) => track.stop());
+  }
+  startStreamingPlayback() {
+    this.isPlaying = true;
+    this.nextPlayTime = this.context.currentTime;
+  }
+
+  stopStreamingPlayback() {
+    this.isPlaying = false;
+    this.playbackQueue.forEach((source) => source.stop());
+    this.playbackQueue = [];
+  }
+
+  playChunk(chunk: Uint8Array) {
+    if (!this.isPlaying) return;
+
+    const int16Data = new Int16Array(chunk.buffer);
+
+    const float32Data = new Float32Array(int16Data.length);
+    for (let i = 0; i < int16Data.length; i++) {
+      float32Data[i] = int16Data[i] / (int16Data[i] < 0 ? 0x8000 : 0x7fff);
+    }
+
+    const audioBuffer = this.context.createBuffer(
+      1,
+      float32Data.length,
+      this.sampleRate,
+    );
+    audioBuffer.getChannelData(0).set(float32Data);
+
+    const source = this.context.createBufferSource();
+    source.buffer = audioBuffer;
+    source.connect(this.context.destination);
+
+    const chunkDuration = audioBuffer.length / this.sampleRate;
+
+    source.start(this.nextPlayTime);
+
+    this.playbackQueue.push(source);
+    source.onended = () => {
+      const index = this.playbackQueue.indexOf(source);
+      if (index > -1) {
+        this.playbackQueue.splice(index, 1);
+      }
+    };
+
+    this.nextPlayTime += chunkDuration;
+
+    if (this.nextPlayTime < this.context.currentTime) {
+      this.nextPlayTime = this.context.currentTime;
+    }
+  }
+  async close() {
+    this.workletNode?.disconnect();
+    this.source?.disconnect();
+    this.stream?.getTracks().forEach((track) => track.stop());
+    await this.context.close();
+  }
+}
diff --git a/package.json b/package.json
index c49a84d42..235652c39 100644
--- a/package.json
+++ b/package.json
@@ -52,7 +52,8 @@
     "sass": "^1.59.2",
     "spark-md5": "^3.0.2",
     "use-debounce": "^9.0.4",
-    "zustand": "^4.3.8"
+    "zustand": "^4.3.8",
+    "rt-client": "https://github.com/Azure-Samples/aoai-realtime-audio-sdk/releases/download/js/v0.5.0/rt-client-0.5.0.tgz"
   },
   "devDependencies": {
     "@tauri-apps/api": "^1.6.0",
diff --git a/public/audio-processor.js b/public/audio-processor.js
new file mode 100644
index 000000000..4fae6ea1a
--- /dev/null
+++ b/public/audio-processor.js
@@ -0,0 +1,48 @@
+// @ts-nocheck
+class AudioRecorderProcessor extends AudioWorkletProcessor {
+  constructor() {
+    super();
+    this.isRecording = false;
+    this.bufferSize = 2400; // 100ms at 24kHz
+    this.currentBuffer = [];
+
+    this.port.onmessage = (event) => {
+      if (event.data.command === "START_RECORDING") {
+        this.isRecording = true;
+      } else if (event.data.command === "STOP_RECORDING") {
+        this.isRecording = false;
+
+        if (this.currentBuffer.length > 0) {
+          this.sendBuffer();
+        }
+      }
+    };
+  }
+
+  sendBuffer() {
+    if (this.currentBuffer.length > 0) {
+      const audioData = new Float32Array(this.currentBuffer);
+      this.port.postMessage({
+        eventType: "audio",
+        audioData: audioData,
+      });
+      this.currentBuffer = [];
+    }
+  }
+
+  process(inputs) {
+    const input = inputs[0];
+    if (input.length > 0 && this.isRecording) {
+      const audioData = input[0];
+
+      this.currentBuffer.push(...audioData);
+
+      if (this.currentBuffer.length >= this.bufferSize) {
+        this.sendBuffer();
+      }
+    }
+    return true;
+  }
+}
+
+registerProcessor("audio-recorder-processor", AudioRecorderProcessor);
diff --git a/yarn.lock b/yarn.lock
index 5c2dfe4ed..4f92858f3 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -7455,6 +7455,12 @@ robust-predicates@^3.0.0:
   resolved "https://registry.npmmirror.com/robust-predicates/-/robust-predicates-3.0.1.tgz#ecde075044f7f30118682bd9fb3f123109577f9a"
   integrity sha512-ndEIpszUHiG4HtDsQLeIuMvRsDnn8c8rYStabochtUeCvfuvNptb5TUbVD68LRAILPX7p9nqQGh4xJgn3EHS/g==

+"rt-client@https://github.com/Azure-Samples/aoai-realtime-audio-sdk/releases/download/js/v0.5.0/rt-client-0.5.0.tgz":
+  version "0.5.0"
+  resolved "https://github.com/Azure-Samples/aoai-realtime-audio-sdk/releases/download/js/v0.5.0/rt-client-0.5.0.tgz#abf2e9a850201e3571b8d36830f77bc52af3de9b"
+  dependencies:
+    ws "^8.18.0"
+
 run-parallel@^1.1.9:
   version "1.2.0"
   resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee"
@@ -8498,9 +8504,9 @@ write-file-atomic@^4.0.2:
     imurmurhash "^0.1.4"
     signal-exit "^3.0.7"

-ws@^8.11.0:
+ws@^8.11.0, ws@^8.18.0:
   version "8.18.0"
-  resolved "https://registry.npmmirror.com/ws/-/ws-8.18.0.tgz#0d7505a6eafe2b0e712d232b42279f53bc289bbc"
+  resolved "https://registry.yarnpkg.com/ws/-/ws-8.18.0.tgz#0d7505a6eafe2b0e712d232b42279f53bc289bbc"
   integrity sha512-8VbfWfHLbbwu3+N6OKsOMpBdT4kXPDDB9cJk2bJ6mh9ucxdlnNvH1e+roYkKmN9Nxw2yjz7VzeO9oOz2zJ04Pw==

 xml-name-validator@^4.0.0:
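
For reference, a minimal sketch of how the pieces introduced by this patch fit together: the `AudioHandler` captures PCM16 chunks through the `audio-recorder-processor` worklet and forwards them to an `rt-client` session, while assistant audio streams back through `startStreamingPlayback`/`playChunk`. All names come from the diff above; the standalone wiring (function name, parameter) is illustrative only, not part of the patch.

```ts
// Illustrative sketch — not part of the patch. Uses AudioHandler and rt-client
// exactly as they are used in realtime-chat.tsx above; the wrapper is hypothetical.
import { RTClient } from "rt-client";
import { AudioHandler } from "@/app/lib/audio";

async function startVoiceSession(apiKey: string) {
  // Open a realtime session, mirroring handleConnect.
  const client = new RTClient(
    { key: apiKey },
    { model: "gpt-4o-realtime-preview-2024-10-01" },
  );
  client.configure({
    modalities: ["text", "audio"],
    turn_detection: { type: "server_vad" },
    input_audio_transcription: { model: "whisper-1" },
  });

  // Prepare the worklet-backed recorder (public/audio-processor.js).
  const audio = new AudioHandler();
  await audio.initialize();

  // Microphone chunks arrive as PCM16 Uint8Array and go straight to the session.
  await audio.startRecording(async (chunk) => {
    await client.sendAudio(chunk);
  });

  // Stream assistant audio back out as it arrives, mirroring handleResponse.
  for await (const event of client.events()) {
    if (event.type === "response") {
      for await (const item of event) {
        if (item.type !== "message" || item.role !== "assistant") continue;
        for await (const content of item) {
          if (content.type === "audio") {
            audio.startStreamingPlayback();
            for await (const chunk of content.audioChunks()) {
              audio.playChunk(chunk);
            }
          }
        }
      }
    }
  }
}
```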