mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-15 02:11:15 +08:00
JSON type fix in Client and and typing fix for /chat
endpoint in gr.ChatInterface
(#10193)
* fix * add changeset * add changeset * fix * chat interface fixes * rename * add changeset * format * changes --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
5e6e234cba
commit
424365bdbd
6
.changeset/warm-dragons-carry.md
Normal file
6
.changeset/warm-dragons-carry.md
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
---
|
||||||
|
"gradio": patch
|
||||||
|
"gradio_client": patch
|
||||||
|
---
|
||||||
|
|
||||||
|
feat:JSON type fix in Client and and typing fix for `/chat` endpoint in `gr.ChatInterface`
|
@ -926,7 +926,7 @@ def _json_schema_to_python_type(schema: Any, defs) -> str:
|
|||||||
type_ = get_type(schema)
|
type_ = get_type(schema)
|
||||||
if type_ == {}:
|
if type_ == {}:
|
||||||
if "json" in schema.get("description", {}):
|
if "json" in schema.get("description", {}):
|
||||||
return "Dict[Any, Any]"
|
return "str | float | bool | list | dict"
|
||||||
else:
|
else:
|
||||||
return "Any"
|
return "Any"
|
||||||
elif type_ == "$ref":
|
elif type_ == "$ref":
|
||||||
|
@ -170,7 +170,7 @@ def test_json_schema_to_python_type(schema):
|
|||||||
elif schema == "FileSerializable":
|
elif schema == "FileSerializable":
|
||||||
answer = "str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)) | List[str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file))]"
|
answer = "str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)) | List[str | Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file))]"
|
||||||
elif schema == "JSONSerializable":
|
elif schema == "JSONSerializable":
|
||||||
answer = "Dict[Any, Any]"
|
answer = "str | float | bool | list | dict"
|
||||||
elif schema == "GallerySerializable":
|
elif schema == "GallerySerializable":
|
||||||
answer = "Tuple[Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)), str | None]"
|
answer = "Tuple[Dict(name: str (name of file), data: str (base64 representation of file), size: int (size of image in bytes), is_file: bool (true if the file has been uploaded to the server), orig_name: str (original name of the file)), str | None]"
|
||||||
elif schema == "SingleFileSerializable":
|
elif schema == "SingleFileSerializable":
|
||||||
|
@ -18,6 +18,7 @@ from gradio_client.documentation import document
|
|||||||
from gradio import utils
|
from gradio import utils
|
||||||
from gradio.blocks import Blocks
|
from gradio.blocks import Blocks
|
||||||
from gradio.components import (
|
from gradio.components import (
|
||||||
|
JSON,
|
||||||
Button,
|
Button,
|
||||||
Chatbot,
|
Chatbot,
|
||||||
Component,
|
Component,
|
||||||
@ -283,7 +284,7 @@ class ChatInterface(Blocks):
|
|||||||
self.textbox.stop_btn = False
|
self.textbox.stop_btn = False
|
||||||
|
|
||||||
self.fake_api_btn = Button("Fake API", visible=False)
|
self.fake_api_btn = Button("Fake API", visible=False)
|
||||||
self.fake_response_textbox = Textbox(
|
self.api_response = JSON(
|
||||||
label="Response", visible=False
|
label="Response", visible=False
|
||||||
) # Used to store the response from the API call
|
) # Used to store the response from the API call
|
||||||
|
|
||||||
@ -311,6 +312,7 @@ class ChatInterface(Blocks):
|
|||||||
input_component.render()
|
input_component.render()
|
||||||
|
|
||||||
self.saved_input = State() # Stores the most recent user message
|
self.saved_input = State() # Stores the most recent user message
|
||||||
|
self.null_component = State() # Used to discard unneeded values
|
||||||
self.chatbot_state = (
|
self.chatbot_state = (
|
||||||
State(self.chatbot.value) if self.chatbot.value else State([])
|
State(self.chatbot.value) if self.chatbot.value else State([])
|
||||||
)
|
)
|
||||||
@ -357,8 +359,7 @@ class ChatInterface(Blocks):
|
|||||||
submit_fn_kwargs = {
|
submit_fn_kwargs = {
|
||||||
"fn": submit_fn,
|
"fn": submit_fn,
|
||||||
"inputs": [self.saved_input, self.chatbot_state] + self.additional_inputs,
|
"inputs": [self.saved_input, self.chatbot_state] + self.additional_inputs,
|
||||||
"outputs": [self.fake_response_textbox, self.chatbot]
|
"outputs": [self.null_component, self.chatbot] + self.additional_outputs,
|
||||||
+ self.additional_outputs,
|
|
||||||
"show_api": False,
|
"show_api": False,
|
||||||
"concurrency_limit": cast(
|
"concurrency_limit": cast(
|
||||||
Union[int, Literal["default"], None], self.concurrency_limit
|
Union[int, Literal["default"], None], self.concurrency_limit
|
||||||
@ -395,11 +396,12 @@ class ChatInterface(Blocks):
|
|||||||
self.fake_api_btn.click(
|
self.fake_api_btn.click(
|
||||||
submit_fn,
|
submit_fn,
|
||||||
[self.textbox, self.chatbot_state] + self.additional_inputs,
|
[self.textbox, self.chatbot_state] + self.additional_inputs,
|
||||||
[self.fake_response_textbox, self.chatbot_state] + self.additional_outputs,
|
[self.api_response, self.chatbot_state] + self.additional_outputs,
|
||||||
api_name=cast(Union[str, Literal[False]], self.api_name),
|
api_name=cast(Union[str, Literal[False]], self.api_name),
|
||||||
concurrency_limit=cast(
|
concurrency_limit=cast(
|
||||||
Union[int, Literal["default"], None], self.concurrency_limit
|
Union[int, Literal["default"], None], self.concurrency_limit
|
||||||
),
|
),
|
||||||
|
postprocess=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
Loading…
Reference in New Issue
Block a user