mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-05 11:10:03 +08:00
Fix serialization error in curl api (#9189)
* Fix code * add changeset * Fix code * fix * add another test --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
4a85559219
commit
ab142ee13d
5
.changeset/moody-dogs-search.md
Normal file
5
.changeset/moody-dogs-search.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Fix serialization error in curl api
|
@ -891,12 +891,13 @@ class App(FastAPI):
|
||||
event_id: str,
|
||||
):
|
||||
def process_msg(message: EventMessage) -> str | None:
|
||||
msg = message.model_dump()
|
||||
if isinstance(message, ProcessCompletedMessage):
|
||||
event = "complete" if message.success else "error"
|
||||
data = message.output.get("data")
|
||||
data = msg["output"].get("data")
|
||||
elif isinstance(message, ProcessGeneratingMessage):
|
||||
event = "generating" if message.success else "error"
|
||||
data = message.output.get("data")
|
||||
data = msg["output"].get("data")
|
||||
elif isinstance(message, HeartbeatMessage):
|
||||
event = "heartbeat"
|
||||
data = None
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Contains tests for networking.py and app.py"""
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
@ -1454,3 +1455,35 @@ def test_file_access():
|
||||
demo.close()
|
||||
not_allowed_file.unlink()
|
||||
allowed_file.unlink()
|
||||
|
||||
|
||||
def test_bash_api_serialization():
|
||||
demo = gr.Interface(lambda x: x, "json", "json")
|
||||
|
||||
app, _, _ = demo.launch(prevent_thread_lock=True)
|
||||
test_client = TestClient(app)
|
||||
|
||||
with test_client:
|
||||
submit = test_client.post("/call/predict", json={"data": [{"a": 1}]})
|
||||
event_id = submit.json()["event_id"]
|
||||
response = test_client.get(f"/call/predict/{event_id}")
|
||||
assert response.status_code == 200
|
||||
assert "event: complete\ndata:" in response.text
|
||||
assert json.dumps({"a": 1}) in response.text
|
||||
|
||||
|
||||
def test_bash_api_multiple_inputs_outputs():
|
||||
demo = gr.Interface(
|
||||
lambda x, y: (y, x), ["textbox", "number"], ["number", "textbox"]
|
||||
)
|
||||
|
||||
app, _, _ = demo.launch(prevent_thread_lock=True)
|
||||
test_client = TestClient(app)
|
||||
|
||||
with test_client:
|
||||
submit = test_client.post("/call/predict", json={"data": ["abc", 123]})
|
||||
event_id = submit.json()["event_id"]
|
||||
response = test_client.get(f"/call/predict/{event_id}")
|
||||
assert response.status_code == 200
|
||||
assert "event: complete\ndata:" in response.text
|
||||
assert json.dumps([123, "abc"]) in response.text
|
||||
|
Loading…
Reference in New Issue
Block a user