JSON type fix in Client and and typing fix for /chat endpoint in gr.ChatInterface (#10193)

* fix

* add changeset

* add changeset

* fix

* chat interface fixes

* rename

* add changeset

* format

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-12-12 17:58:03 -08:00 committed by GitHub
parent 5e6e234cba
commit 424365bdbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 6 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---
feat:JSON type fix in Client and and typing fix for `/chat` endpoint in `gr.ChatInterface`

View File

@ -926,7 +926,7 @@ def _json_schema_to_python_type(schema: Any, defs) -> str:
type_ = get_type(schema)
if type_ == {}:
if "json" in schema.get("description", {}):
return "Dict[Any, Any]"
return "str | float | bool | list | dict"
else:
return "Any"
elif type_ == "$ref":

View File

@ -170,7 +170,7 @@ def test_json_schema_to_python_type(schema):
elif schema == "FileSerializable":
answer = "str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)) | List[str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file))]"
elif schema == "JSONSerializable":
answer = "Dict[Any, Any]"
answer = "str | float | bool | list | dict"
elif schema == "GallerySerializable":
answer = "Tuple[Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)), str | None]"
elif schema == "SingleFileSerializable":

View File

@ -18,6 +18,7 @@ from gradio_client.documentation import document
from gradio import utils
from gradio.blocks import Blocks
from gradio.components import (
JSON,
Button,
Chatbot,
Component,
@ -283,7 +284,7 @@ class ChatInterface(Blocks):
self.textbox.stop_btn = False
self.fake_api_btn = Button("Fake API", visible=False)
self.fake_response_textbox = Textbox(
self.api_response = JSON(
label="Response", visible=False
) # Used to store the response from the API call
@ -311,6 +312,7 @@ class ChatInterface(Blocks):
input_component.render()
self.saved_input = State() # Stores the most recent user message
self.null_component = State() # Used to discard unneeded values
self.chatbot_state = (
State(self.chatbot.value) if self.chatbot.value else State([])
)
@ -357,8 +359,7 @@ class ChatInterface(Blocks):
submit_fn_kwargs = {
"fn": submit_fn,
"inputs": [self.saved_input, self.chatbot_state] + self.additional_inputs,
"outputs": [self.fake_response_textbox, self.chatbot]
+ self.additional_outputs,
"outputs": [self.null_component, self.chatbot] + self.additional_outputs,
"show_api": False,
"concurrency_limit": cast(
Union[int, Literal["default"], None], self.concurrency_limit
@ -395,11 +396,12 @@ class ChatInterface(Blocks):
self.fake_api_btn.click(
submit_fn,
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.fake_response_textbox, self.chatbot_state] + self.additional_outputs,
[self.api_response, self.chatbot_state] + self.additional_outputs,
api_name=cast(Union[str, Literal[False]], self.api_name),
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
postprocess=False,
)
if (