Fix serialization error in curl api (#9189)

* Fix code

* add changeset

* Fix code

* fix

* add another test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2024-08-27 18:49:36 -04:00 committed by GitHub
parent 4a85559219
commit ab142ee13d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 2 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Fix serialization error in curl api

View File

@ -891,12 +891,13 @@ class App(FastAPI):
event_id: str,
):
def process_msg(message: EventMessage) -> str | None:
msg = message.model_dump()
if isinstance(message, ProcessCompletedMessage):
event = "complete" if message.success else "error"
data = message.output.get("data")
data = msg["output"].get("data")
elif isinstance(message, ProcessGeneratingMessage):
event = "generating" if message.success else "error"
data = message.output.get("data")
data = msg["output"].get("data")
elif isinstance(message, HeartbeatMessage):
event = "heartbeat"
data = None

View File

@ -1,6 +1,7 @@
"""Contains tests for networking.py and app.py"""
import functools
import json
import os
import pickle
import tempfile
@ -1454,3 +1455,35 @@ def test_file_access():
demo.close()
not_allowed_file.unlink()
allowed_file.unlink()
def test_bash_api_serialization():
demo = gr.Interface(lambda x: x, "json", "json")
app, _, _ = demo.launch(prevent_thread_lock=True)
test_client = TestClient(app)
with test_client:
submit = test_client.post("/call/predict", json={"data": [{"a": 1}]})
event_id = submit.json()["event_id"]
response = test_client.get(f"/call/predict/{event_id}")
assert response.status_code == 200
assert "event: complete\ndata:" in response.text
assert json.dumps({"a": 1}) in response.text
def test_bash_api_multiple_inputs_outputs():
demo = gr.Interface(
lambda x, y: (y, x), ["textbox", "number"], ["number", "textbox"]
)
app, _, _ = demo.launch(prevent_thread_lock=True)
test_client = TestClient(app)
with test_client:
submit = test_client.post("/call/predict", json={"data": ["abc", 123]})
event_id = submit.json()["event_id"]
response = test_client.get(f"/call/predict/{event_id}")
assert response.status_code == 200
assert "event: complete\ndata:" in response.text
assert json.dumps([123, "abc"]) in response.text