mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Restored /api/predict/ endpoint for Interfaces (#1199)
* updated PyPi version to 2.9b25 * added /api/predict reverse compatibility * fixed flagging * formatting * fixed networking tests * added queue false
This commit is contained in:
parent
962a254a6b
commit
871c9713b4
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: gradio
|
||||
Version: 2.9b23
|
||||
Version: 2.9b25
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
Home-page: https://github.com/gradio-app/gradio-UI
|
||||
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq, Pete Allen, Ömer Faruk Özdemir
|
||||
|
@ -450,9 +450,14 @@ class Interface(Blocks):
|
||||
cache_interface_examples(self)
|
||||
|
||||
if self.allow_flagging != "never":
|
||||
self.flagging_callback.setup(
|
||||
self.input_components + self.output_components, self.flagging_dir
|
||||
)
|
||||
if self.interface_type == self.InterfaceTypes.UNIFIED:
|
||||
self.flagging_callback.setup(self.input_components, self.flagging_dir)
|
||||
elif self.interface_type == self.InterfaceTypes.INPUT_ONLY:
|
||||
pass
|
||||
else:
|
||||
self.flagging_callback.setup(
|
||||
self.input_components + self.output_components, self.flagging_dir
|
||||
)
|
||||
|
||||
with self:
|
||||
if self.title:
|
||||
@ -498,15 +503,6 @@ class Interface(Blocks):
|
||||
submit_btn = Button("Submit")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btn = Button("Flag", variant="secondary")
|
||||
flag_btn.click(
|
||||
lambda *flag_data: self.flagging_callback.flag(
|
||||
flag_data
|
||||
),
|
||||
inputs=self.input_components,
|
||||
outputs=[],
|
||||
_preprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
@ -523,15 +519,6 @@ class Interface(Blocks):
|
||||
submit_btn = Button("Generate")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btn = Button("Flag", variant="secondary")
|
||||
flag_btn.click(
|
||||
lambda *flag_data: self.flagging_callback.flag(
|
||||
flag_data
|
||||
),
|
||||
inputs=self.input_components
|
||||
+ self.output_components,
|
||||
outputs=[],
|
||||
_preprocess=False,
|
||||
)
|
||||
if self.interpretation:
|
||||
interpretation_btn = Button(
|
||||
"Interpret", variant="secondary"
|
||||
@ -587,6 +574,27 @@ class Interface(Blocks):
|
||||
)}
|
||||
""",
|
||||
)
|
||||
if self.allow_flagging == "manual":
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.OUTPUT_ONLY,
|
||||
]:
|
||||
flag_btn.click(
|
||||
lambda *flag_data: self.flagging_callback.flag(flag_data),
|
||||
inputs=self.input_components + self.output_components,
|
||||
outputs=[],
|
||||
_preprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
elif self.interface_type == self.InterfaceTypes.UNIFIED:
|
||||
flag_btn.click(
|
||||
lambda *flag_data: self.flagging_callback.flag(flag_data),
|
||||
inputs=self.input_components,
|
||||
outputs=[],
|
||||
_preprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
if self.examples:
|
||||
non_state_inputs = [
|
||||
c for c in self.input_components if not isinstance(c, Variable)
|
||||
|
@ -70,7 +70,7 @@ class QueuePushBody(BaseModel):
|
||||
class PredictBody(BaseModel):
|
||||
session_hash: Optional[str]
|
||||
data: Any
|
||||
fn_index: int
|
||||
fn_index: int = 0
|
||||
|
||||
|
||||
###########
|
||||
|
@ -1 +1 @@
|
||||
2.9b23
|
||||
2.9b25
|
2
setup.py
2
setup.py
@ -9,7 +9,7 @@ long_description = (this_directory / "README.md").read_text()
|
||||
|
||||
setup(
|
||||
name="gradio",
|
||||
version="2.9b23",
|
||||
version="2.9b25",
|
||||
include_package_data=True,
|
||||
description="Python library for easily interacting with trained machine learning models",
|
||||
long_description=long_description,
|
||||
|
@ -6,11 +6,10 @@ import unittest.mock as mock
|
||||
import urllib
|
||||
import warnings
|
||||
|
||||
import aiohttp
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import gradio as gr
|
||||
from gradio import Interface, flagging, networking
|
||||
from gradio import Interface, networking
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -41,7 +40,7 @@ class TestInterfaceErrors(unittest.TestCase):
|
||||
io = Interface(lambda x: 1 / x, "number", "number")
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
response = client.post("/api/predict/", json={"data": [0], "fn_index": 1})
|
||||
response = client.post("/api/predict/", json={"data": [0], "fn_index": 0})
|
||||
self.assertEqual(response.status_code, 500)
|
||||
self.assertTrue("error" in response.json())
|
||||
io.close()
|
||||
@ -50,7 +49,7 @@ class TestInterfaceErrors(unittest.TestCase):
|
||||
io = Interface(lambda x: 1 / x, "number", "number")
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
response = client.post("/api/predict/", json={"data": [0]})
|
||||
response = client.post("/api/predict/", json={"fn_index": [0]})
|
||||
self.assertEqual(response.status_code, 422)
|
||||
io.close()
|
||||
|
||||
|
@ -38,12 +38,18 @@ class TestRoutes(unittest.TestCase):
|
||||
|
||||
def test_predict_route(self):
|
||||
response = self.client.post(
|
||||
"/api/predict/", json={"data": ["test"], "fn_index": 1}
|
||||
"/api/predict/", json={"data": ["test"], "fn_index": 0}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["testtest"])
|
||||
|
||||
def test_predict_route_without_fn_index(self):
|
||||
response = self.client.post("/api/predict/", json={"data": ["test"]})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["testtest"])
|
||||
|
||||
def test_state(self):
|
||||
def predict(input, history):
|
||||
if history is None:
|
||||
@ -56,14 +62,14 @@ class TestRoutes(unittest.TestCase):
|
||||
client = TestClient(app)
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
|
||||
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
print("output", output)
|
||||
self.assertEqual(output["data"], ["test", None])
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
|
||||
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["testtest", None])
|
||||
|
Loading…
Reference in New Issue
Block a user