diff --git a/.changeset/weak-bugs-itch.md b/.changeset/weak-bugs-itch.md
new file mode 100644
index 0000000000..5b5b895f62
--- /dev/null
+++ b/.changeset/weak-bugs-itch.md
@@ -0,0 +1,7 @@
+---
+"@gradio/app": minor
+"@gradio/client": minor
+"gradio": minor
+---
+
+feat:Client fixes
diff --git a/client/js/src/client.ts b/client/js/src/client.ts
index c6cf920e23..71c82f5a84 100644
--- a/client/js/src/client.ts
+++ b/client/js/src/client.ts
@@ -139,18 +139,24 @@ export class Client {
 			if (config) {
 				this.config = config;
 				if (this.config && this.config.connect_heartbeat) {
-					// connect to the heartbeat endpoint via GET request
-					const heartbeat_url = new URL(
-						`${this.config.root}/heartbeat/${this.session_hash}`
-					);
-					this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
-
 					if (this.config.space_id && this.options.hf_token) {
 						this.jwt = await get_jwt(
 							this.config.space_id,
 							this.options.hf_token
 						);
 					}
+
+					// connect to the heartbeat endpoint via GET request
+					const heartbeat_url = new URL(
+						`${this.config.root}/heartbeat/${this.session_hash}`
+					);
+
+					// if the jwt is available, add it to the query params
+					if (this.jwt) {
+						heartbeat_url.searchParams.set("jwt", this.jwt);
+					}
+
+					this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
 				}
 			}
 		});
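A quick usage sketch of the client.ts change above, not part of the patch; the Space id and token below are hypothetical placeholders:

// With this change, connecting to a private Space with an hf_token first
// fetches a JWT and then opens the heartbeat as
// `${root}/heartbeat/${session_hash}?jwt=...`, so the heartbeat request is
// authorized with the same credentials as the rest of the session.
import { Client } from "@gradio/client";

const client = await Client.connect("user/private-space", {
	hf_token: "hf_xxxxxxxxxxxxxxxx"
});
// The heartbeat stream is opened internally during connect; nothing extra
// is required from the caller.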
diff --git a/client/js/src/helpers/api_info.ts b/client/js/src/helpers/api_info.ts
index 32015f1939..9ac0add27c 100644
--- a/client/js/src/helpers/api_info.ts
+++ b/client/js/src/helpers/api_info.ts
@@ -77,7 +77,11 @@ export function transform_api_info(
 	Object.entries(api_info[category]).forEach(
 		([endpoint, { parameters, returns }]) => {
 			const dependencyIndex =
-				config.dependencies.findIndex((dep) => dep.api_name === endpoint) ||
+				config.dependencies.findIndex(
+					(dep) =>
+						dep.api_name === endpoint ||
+						dep.api_name === endpoint.replace("/", "")
+				) ||
 				api_map[endpoint.replace("/", "")] ||
 				-1;
 
@@ -86,6 +90,34 @@ export function transform_api_info(
 					? config.dependencies[dependencyIndex].types
 					: { continuous: false, generator: false };
 
+			if (
+				dependencyIndex !== -1 &&
+				config.dependencies[dependencyIndex]?.inputs?.length !==
+					parameters.length
+			) {
+				const components = config.dependencies[dependencyIndex].inputs.map(
+					(input) => config.components.find((c) => c.id === input)?.type
+				);
+
+				try {
+					components.forEach((comp, idx) => {
+						if (comp === "state") {
+							const new_param = {
+								component: "state",
+								example: null,
+								parameter_default: null,
+								parameter_has_default: true,
+								parameter_name: null,
+								hidden: true
+							};
+
+							// @ts-ignore
+							parameters.splice(idx, 0, new_param);
+						}
+					});
+				} catch (e) {}
+			}
+
 			const transform_type = (
 				data: ApiData,
 				component: string,
@@ -93,17 +125,17 @@
 				signature_type: "return" | "parameter"
 			): JsApiData => ({
 				...data,
-				description: get_description(data.type, serializer),
+				description: get_description(data?.type, serializer),
 				type:
-					get_type(data.type, component, serializer, signature_type) || ""
+					get_type(data?.type, component, serializer, signature_type) || ""
 			});
 
 			transformed_info[category][endpoint] = {
 				parameters: parameters.map((p: ApiData) =>
-					transform_type(p, p.component, p.serializer, "parameter")
+					transform_type(p, p?.component, p?.serializer, "parameter")
 				),
 				returns: returns.map((r: ApiData) =>
-					transform_type(r, r.component, r.serializer, "return")
+					transform_type(r, r?.component, r?.serializer, "return")
 				),
 				type: dependencyTypes
 			};
@@ -121,7 +153,7 @@ export function get_type(
 	serializer: string,
 	signature_type: "return" | "parameter"
 ): string | undefined {
-	switch (type.type) {
+	switch (type?.type) {
 		case "string":
 			return "string";
 		case "boolean":
@@ -166,7 +198,7 @@ export function get_description(
 	} else if (serializer === "FileSerializable") {
 		return "array of files or single file";
 	}
-	return type.description;
+	return type?.description;
 }
 
 export function handle_message(
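The state-injection branch above is easier to see in isolation. A standalone TypeScript sketch of the idea, with simplified types and made-up data, not part of the patch:

// When a dependency lists more input components than the API info reports
// parameters, the missing inputs are assumed to be `state` components and a
// hidden placeholder is spliced back in so positional indices stay aligned.
interface Param {
	component: string;
	parameter_name: string | null;
	hidden?: boolean;
}

// Types of the dependency's input components, in order (hypothetical data).
const input_component_types = ["textbox", "state"];

// The parameters reported by the API info: the state input is missing.
const parameters: Param[] = [
	{ component: "textbox", parameter_name: "message" }
];

input_component_types.forEach((comp, idx) => {
	if (comp === "state") {
		parameters.splice(idx, 0, {
			component: "state",
			parameter_name: null,
			hidden: true
		});
	}
});

console.log(parameters.map((p) => p.component)); // ["textbox", "state"]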
return message[\"text\"]\n", "\n", "\n", "demo = gr.ChatInterface(\n", " fn=echo,\n", " examples=[{\"text\": \"hello\"}, {\"text\": \"hola\"}, {\"text\": \"merhaba\"}],\n", " title=\"Echo Bot\",\n", " multimodal=True,\n", ")\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/chatinterface_multimodal/run.py b/demo/chatinterface_multimodal/run.py index ff797547fb..bc73bf4820 100644 --- a/demo/chatinterface_multimodal/run.py +++ b/demo/chatinterface_multimodal/run.py @@ -1,7 +1,14 @@ import gradio as gr + def echo(message, history): return message["text"] -demo = gr.ChatInterface(fn=echo, examples=[{"text": "hello"}, {"text": "hola"}, {"text": "merhaba"}], title="Echo Bot", multimodal=True) + +demo = gr.ChatInterface( + fn=echo, + examples=[{"text": "hello"}, {"text": "hola"}, {"text": "merhaba"}], + title="Echo Bot", + multimodal=True, +) demo.launch() diff --git a/demo/chatinterface_system_prompt/run.ipynb b/demo/chatinterface_system_prompt/run.ipynb index 7ad525d719..6bef666380 100644 --- a/demo/chatinterface_system_prompt/run.ipynb +++ b/demo/chatinterface_system_prompt/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i+1]\n", "\n", "demo = gr.ChatInterface(echo, \n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"), \n", " gr.Slider(10, 100)\n", " ]\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(\n", " echo,\n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"),\n", " gr.Slider(10, 100),\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/chatinterface_system_prompt/run.py b/demo/chatinterface_system_prompt/run.py index e8b1422c4f..8731cab83f 100644 --- a/demo/chatinterface_system_prompt/run.py +++ 
@@ -1,18 +1,21 @@
 import gradio as gr
 import time
 
+
 def echo(message, history, system_prompt, tokens):
     response = f"System prompt: {system_prompt}\n Message: {message}."
     for i in range(min(len(response), int(tokens))):
         time.sleep(0.05)
-        yield response[: i+1]
+        yield response[: i + 1]
 
-demo = gr.ChatInterface(echo,
-                        additional_inputs=[
-                            gr.Textbox("You are helpful AI.", label="System Prompt"),
-                            gr.Slider(10, 100)
-                        ]
-                        )
+
+demo = gr.ChatInterface(
+    echo,
+    additional_inputs=[
+        gr.Textbox("You are helpful AI.", label="System Prompt"),
+        gr.Slider(10, 100),
+    ],
+)
 
 if __name__ == "__main__":
-    demo.queue().launch()
\ No newline at end of file
+    demo.queue().launch()
diff --git a/js/app/src/api_docs/CodeSnippet.svelte b/js/app/src/api_docs/CodeSnippet.svelte
index cc94b83ba4..684b15226e 100644
--- a/js/app/src/api_docs/CodeSnippet.svelte
+++ b/js/app/src/api_docs/CodeSnippet.svelte
@@ -78,7 +78,7 @@ result = client.predict
 	import { Client } from "@gradio/client";
 
-{#each blob_examples as { label, type, python_type, component, example_input, serializer }, i}
+{#each blob_examples as { label, parameter_name, type, python_type, component, example_input, serializer }, i}
 const response_{i} = await fetch("{example_input.url}");
 const example{component} = await response_{i}.blob();
@@ -88,11 +88,12 @@ const client = await Client.connect("{root}");
 const result = await client.predict({#if named}"/{dependency.api_name}"{:else}{dependency_index}{/if}, {
-{#each endpoint_parameters as { label, type, python_type, component, example_input, serializer }, i}{#if blob_components.includes(component)}
+{#each endpoint_parameters as { label, parameter_name, type, python_type, component, example_input, serializer }, i}{#if blob_components.includes(component)}
-	{label}: example{component}
+	{parameter_name}: example{component},
 {:else}
-	{label}: {represent_value(
+	>{parameter_name}: {represent_value(
 		example_input,
 		python_type.type,
 		"js"
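For reference, a hedged sketch of the kind of snippet the updated template would now render for an endpoint "/chat" whose single input has parameter_name "message"; the Space id is made up:

// Payload keys now come from parameter_name rather than the component label.
import { Client } from "@gradio/client";

const client = await Client.connect("user/echo-bot");
const result = await client.predict("/chat", {
	message: "Hello!"
});

console.log(result.data);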