JSON type fix in Client and and typing fix for /chat endpoint in gr.ChatInterface (#10193)

* fix

* add changeset

* add changeset

* fix

* chat interface fixes

* rename

* add changeset

* format

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-12-12 17:58:03 -08:00 committed by GitHub
parent 5e6e234cba
commit 424365bdbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 6 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---
feat:JSON type fix in Client and and typing fix for `/chat` endpoint in `gr.ChatInterface`

View File

@ -926,7 +926,7 @@ def _json_schema_to_python_type(schema: Any, defs) -> str:
type_ = get_type(schema) type_ = get_type(schema)
if type_ == {}: if type_ == {}:
if "json" in schema.get("description", {}): if "json" in schema.get("description", {}):
return "Dict[Any, Any]" return "str | float | bool | list | dict"
else: else:
return "Any" return "Any"
elif type_ == "$ref": elif type_ == "$ref":

View File

@ -170,7 +170,7 @@ def test_json_schema_to_python_type(schema):
elif schema == "FileSerializable": elif schema == "FileSerializable":
answer = "str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)) | List[str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file))]" answer = "str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)) | List[str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file))]"
elif schema == "JSONSerializable": elif schema == "JSONSerializable":
answer = "Dict[Any, Any]" answer = "str | float | bool | list | dict"
elif schema == "GallerySerializable": elif schema == "GallerySerializable":
answer = "Tuple[Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)), str | None]" answer = "Tuple[Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)), str | None]"
elif schema == "SingleFileSerializable": elif schema == "SingleFileSerializable":

View File

@ -18,6 +18,7 @@ from gradio_client.documentation import document
from gradio import utils from gradio import utils
from gradio.blocks import Blocks from gradio.blocks import Blocks
from gradio.components import ( from gradio.components import (
JSON,
Button, Button,
Chatbot, Chatbot,
Component, Component,
@ -283,7 +284,7 @@ class ChatInterface(Blocks):
self.textbox.stop_btn = False self.textbox.stop_btn = False
self.fake_api_btn = Button("Fake API", visible=False) self.fake_api_btn = Button("Fake API", visible=False)
self.fake_response_textbox = Textbox( self.api_response = JSON(
label="Response", visible=False label="Response", visible=False
) # Used to store the response from the API call ) # Used to store the response from the API call
@ -311,6 +312,7 @@ class ChatInterface(Blocks):
input_component.render() input_component.render()
self.saved_input = State() # Stores the most recent user message self.saved_input = State() # Stores the most recent user message
self.null_component = State() # Used to discard unneeded values
self.chatbot_state = ( self.chatbot_state = (
State(self.chatbot.value) if self.chatbot.value else State([]) State(self.chatbot.value) if self.chatbot.value else State([])
) )
@ -357,8 +359,7 @@ class ChatInterface(Blocks):
submit_fn_kwargs = { submit_fn_kwargs = {
"fn": submit_fn, "fn": submit_fn,
"inputs": [self.saved_input, self.chatbot_state] + self.additional_inputs, "inputs": [self.saved_input, self.chatbot_state] + self.additional_inputs,
"outputs": [self.fake_response_textbox, self.chatbot] "outputs": [self.null_component, self.chatbot] + self.additional_outputs,
+ self.additional_outputs,
"show_api": False, "show_api": False,
"concurrency_limit": cast( "concurrency_limit": cast(
Union[int, Literal["default"], None], self.concurrency_limit Union[int, Literal["default"], None], self.concurrency_limit
@ -395,11 +396,12 @@ class ChatInterface(Blocks):
self.fake_api_btn.click( self.fake_api_btn.click(
submit_fn, submit_fn,
[self.textbox, self.chatbot_state] + self.additional_inputs, [self.textbox, self.chatbot_state] + self.additional_inputs,
[self.fake_response_textbox, self.chatbot_state] + self.additional_outputs, [self.api_response, self.chatbot_state] + self.additional_outputs,
api_name=cast(Union[str, Literal[False]], self.api_name), api_name=cast(Union[str, Literal[False]], self.api_name),
concurrency_limit=cast( concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit Union[int, Literal["default"], None], self.concurrency_limit
), ),
postprocess=False,
) )
if ( if (