mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Client fixes (#8272)
* fix param name * fix hidden state variable * pass jwt to heartbeat event * notebooks * format * add changeset --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
30463c5e15
commit
fbf4edde7c
7
.changeset/weak-bugs-itch.md
Normal file
7
.changeset/weak-bugs-itch.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/app": minor
|
||||
"@gradio/client": minor
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Client fixes
|
@ -139,18 +139,24 @@ export class Client {
|
||||
if (config) {
|
||||
this.config = config;
|
||||
if (this.config && this.config.connect_heartbeat) {
|
||||
// connect to the heartbeat endpoint via GET request
|
||||
const heartbeat_url = new URL(
|
||||
`${this.config.root}/heartbeat/${this.session_hash}`
|
||||
);
|
||||
this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
|
||||
|
||||
if (this.config.space_id && this.options.hf_token) {
|
||||
this.jwt = await get_jwt(
|
||||
this.config.space_id,
|
||||
this.options.hf_token
|
||||
);
|
||||
}
|
||||
|
||||
// connect to the heartbeat endpoint via GET request
|
||||
const heartbeat_url = new URL(
|
||||
`${this.config.root}/heartbeat/${this.session_hash}`
|
||||
);
|
||||
|
||||
// if the jwt is available, add it to the query params
|
||||
if (this.jwt) {
|
||||
heartbeat_url.searchParams.set("jwt", this.jwt);
|
||||
}
|
||||
|
||||
this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -77,7 +77,11 @@ export function transform_api_info(
|
||||
Object.entries(api_info[category]).forEach(
|
||||
([endpoint, { parameters, returns }]) => {
|
||||
const dependencyIndex =
|
||||
config.dependencies.findIndex((dep) => dep.api_name === endpoint) ||
|
||||
config.dependencies.findIndex(
|
||||
(dep) =>
|
||||
dep.api_name === endpoint ||
|
||||
dep.api_name === endpoint.replace("/", "")
|
||||
) ||
|
||||
api_map[endpoint.replace("/", "")] ||
|
||||
-1;
|
||||
|
||||
@ -86,6 +90,34 @@ export function transform_api_info(
|
||||
? config.dependencies[dependencyIndex].types
|
||||
: { continuous: false, generator: false };
|
||||
|
||||
if (
|
||||
dependencyIndex !== -1 &&
|
||||
config.dependencies[dependencyIndex]?.inputs?.length !==
|
||||
parameters.length
|
||||
) {
|
||||
const components = config.dependencies[dependencyIndex].inputs.map(
|
||||
(input) => config.components.find((c) => c.id === input)?.type
|
||||
);
|
||||
|
||||
try {
|
||||
components.forEach((comp, idx) => {
|
||||
if (comp === "state") {
|
||||
const new_param = {
|
||||
component: "state",
|
||||
example: null,
|
||||
parameter_default: null,
|
||||
parameter_has_default: true,
|
||||
parameter_name: null,
|
||||
hidden: true
|
||||
};
|
||||
|
||||
// @ts-ignore
|
||||
parameters.splice(idx, 0, new_param);
|
||||
}
|
||||
});
|
||||
} catch (e) {}
|
||||
}
|
||||
|
||||
const transform_type = (
|
||||
data: ApiData,
|
||||
component: string,
|
||||
@ -93,17 +125,17 @@ export function transform_api_info(
|
||||
signature_type: "return" | "parameter"
|
||||
): JsApiData => ({
|
||||
...data,
|
||||
description: get_description(data.type, serializer),
|
||||
description: get_description(data?.type, serializer),
|
||||
type:
|
||||
get_type(data.type, component, serializer, signature_type) || ""
|
||||
get_type(data?.type, component, serializer, signature_type) || ""
|
||||
});
|
||||
|
||||
transformed_info[category][endpoint] = {
|
||||
parameters: parameters.map((p: ApiData) =>
|
||||
transform_type(p, p.component, p.serializer, "parameter")
|
||||
transform_type(p, p?.component, p?.serializer, "parameter")
|
||||
),
|
||||
returns: returns.map((r: ApiData) =>
|
||||
transform_type(r, r.component, r.serializer, "return")
|
||||
transform_type(r, r?.component, r?.serializer, "return")
|
||||
),
|
||||
type: dependencyTypes
|
||||
};
|
||||
@ -121,7 +153,7 @@ export function get_type(
|
||||
serializer: string,
|
||||
signature_type: "return" | "parameter"
|
||||
): string | undefined {
|
||||
switch (type.type) {
|
||||
switch (type?.type) {
|
||||
case "string":
|
||||
return "string";
|
||||
case "boolean":
|
||||
@ -166,7 +198,7 @@ export function get_description(
|
||||
} else if (serializer === "FileSerializable") {
|
||||
return "array of files or single file";
|
||||
}
|
||||
return type.description;
|
||||
return type?.description;
|
||||
}
|
||||
|
||||
export function handle_message(
|
||||
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def echo(message, history):\n", " return message[\"text\"]\n", "\n", "demo = gr.ChatInterface(fn=echo, examples=[{\"text\": \"hello\"}, {\"text\": \"hola\"}, {\"text\": \"merhaba\"}], title=\"Echo Bot\", multimodal=True)\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "\n", "def echo(message, history):\n", " return message[\"text\"]\n", "\n", "\n", "demo = gr.ChatInterface(\n", " fn=echo,\n", " examples=[{\"text\": \"hello\"}, {\"text\": \"hola\"}, {\"text\": \"merhaba\"}],\n", " title=\"Echo Bot\",\n", " multimodal=True,\n", ")\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -1,7 +1,14 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def echo(message, history):
|
||||
return message["text"]
|
||||
|
||||
demo = gr.ChatInterface(fn=echo, examples=[{"text": "hello"}, {"text": "hola"}, {"text": "merhaba"}], title="Echo Bot", multimodal=True)
|
||||
|
||||
demo = gr.ChatInterface(
|
||||
fn=echo,
|
||||
examples=[{"text": "hello"}, {"text": "hola"}, {"text": "merhaba"}],
|
||||
title="Echo Bot",
|
||||
multimodal=True,
|
||||
)
|
||||
demo.launch()
|
||||
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i+1]\n", "\n", "demo = gr.ChatInterface(echo, \n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"), \n", " gr.Slider(10, 100)\n", " ]\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(\n", " echo,\n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"),\n", " gr.Slider(10, 100),\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -1,18 +1,21 @@
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
|
||||
def echo(message, history, system_prompt, tokens):
|
||||
response = f"System prompt: {system_prompt}\n Message: {message}."
|
||||
for i in range(min(len(response), int(tokens))):
|
||||
time.sleep(0.05)
|
||||
yield response[: i+1]
|
||||
yield response[: i + 1]
|
||||
|
||||
demo = gr.ChatInterface(echo,
|
||||
additional_inputs=[
|
||||
gr.Textbox("You are helpful AI.", label="System Prompt"),
|
||||
gr.Slider(10, 100)
|
||||
]
|
||||
)
|
||||
|
||||
demo = gr.ChatInterface(
|
||||
echo,
|
||||
additional_inputs=[
|
||||
gr.Textbox("You are helpful AI.", label="System Prompt"),
|
||||
gr.Slider(10, 100),
|
||||
],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.queue().launch()
|
||||
demo.queue().launch()
|
||||
|
@ -78,7 +78,7 @@ result = client.<span class="highlight">predict</span
|
||||
</div>
|
||||
<div bind:this={js_code}>
|
||||
<pre>import { Client } from "@gradio/client";
|
||||
{#each blob_examples as { label, type, python_type, component, example_input, serializer }, i}<!--
|
||||
{#each blob_examples as { component, example_input }, i}<!--
|
||||
-->
|
||||
const response_{i} = await fetch("{example_input.url}");
|
||||
const example{component} = await response_{i}.blob();
|
||||
@ -88,11 +88,12 @@ const client = await Client.connect(<span class="token string">"{root}"</span>);
|
||||
const result = await client.predict({#if named}<span class="api-name"
|
||||
>"/{dependency.api_name}"</span
|
||||
>{:else}{dependency_index}{/if}, { <!--
|
||||
-->{#each endpoint_parameters as { label, type, python_type, component, example_input, serializer }, i}<!--
|
||||
-->{#each endpoint_parameters as { label, parameter_name, type, python_type, component, example_input, serializer }, i}<!--
|
||||
-->{#if blob_components.includes(component)}<!--
|
||||
-->
|
||||
<span
|
||||
class="example-inputs">{label}: example{component}</span
|
||||
class="example-inputs"
|
||||
>{parameter_name}: example{component}</span
|
||||
>, <!--
|
||||
--><span class="desc"
|
||||
><!--
|
||||
@ -104,7 +105,7 @@ const result = await client.predict({#if named}<span class="api-name"
|
||||
-->{:else}<!--
|
||||
-->
|
||||
<span class="example-inputs"
|
||||
>{label}: {represent_value(
|
||||
>{parameter_name}: {represent_value(
|
||||
example_input,
|
||||
python_type.type,
|
||||
"js"
|
||||
|
Loading…
x
Reference in New Issue
Block a user