Restored /api/predict/ endpoint for Interfaces (#1199)

* updated PyPi version to 2.9b25

* added /api/predict reverse compatibility

* fixed flagging

* formatting

* fixed networking tests

* added queue false
This commit is contained in:
Abubakar Abid 2022-05-09 18:05:30 -07:00 committed by GitHub
parent 962a254a6b
commit 871c9713b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 45 additions and 32 deletions

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: gradio
Version: 2.9b23
Version: 2.9b25
Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq, Pete Allen, Ömer Faruk Özdemir

View File

@ -450,9 +450,14 @@ class Interface(Blocks):
cache_interface_examples(self)
if self.allow_flagging != "never":
self.flagging_callback.setup(
self.input_components + self.output_components, self.flagging_dir
)
if self.interface_type == self.InterfaceTypes.UNIFIED:
self.flagging_callback.setup(self.input_components, self.flagging_dir)
elif self.interface_type == self.InterfaceTypes.INPUT_ONLY:
pass
else:
self.flagging_callback.setup(
self.input_components + self.output_components, self.flagging_dir
)
with self:
if self.title:
@ -498,15 +503,6 @@ class Interface(Blocks):
submit_btn = Button("Submit")
if self.allow_flagging == "manual":
flag_btn = Button("Flag", variant="secondary")
flag_btn.click(
lambda *flag_data: self.flagging_callback.flag(
flag_data
),
inputs=self.input_components,
outputs=[],
_preprocess=False,
queue=False,
)
if self.interface_type in [
self.InterfaceTypes.STANDARD,
@ -523,15 +519,6 @@ class Interface(Blocks):
submit_btn = Button("Generate")
if self.allow_flagging == "manual":
flag_btn = Button("Flag", variant="secondary")
flag_btn.click(
lambda *flag_data: self.flagging_callback.flag(
flag_data
),
inputs=self.input_components
+ self.output_components,
outputs=[],
_preprocess=False,
)
if self.interpretation:
interpretation_btn = Button(
"Interpret", variant="secondary"
@ -587,6 +574,27 @@ class Interface(Blocks):
)}
""",
)
if self.allow_flagging == "manual":
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.OUTPUT_ONLY,
]:
flag_btn.click(
lambda *flag_data: self.flagging_callback.flag(flag_data),
inputs=self.input_components + self.output_components,
outputs=[],
_preprocess=False,
queue=False,
)
elif self.interface_type == self.InterfaceTypes.UNIFIED:
flag_btn.click(
lambda *flag_data: self.flagging_callback.flag(flag_data),
inputs=self.input_components,
outputs=[],
_preprocess=False,
queue=False,
)
if self.examples:
non_state_inputs = [
c for c in self.input_components if not isinstance(c, Variable)

View File

@ -70,7 +70,7 @@ class QueuePushBody(BaseModel):
class PredictBody(BaseModel):
session_hash: Optional[str]
data: Any
fn_index: int
fn_index: int = 0
###########

View File

@ -1 +1 @@
2.9b23
2.9b25

View File

@ -9,7 +9,7 @@ long_description = (this_directory / "README.md").read_text()
setup(
name="gradio",
version="2.9b23",
version="2.9b25",
include_package_data=True,
description="Python library for easily interacting with trained machine learning models",
long_description=long_description,

View File

@ -6,11 +6,10 @@ import unittest.mock as mock
import urllib
import warnings
import aiohttp
from fastapi.testclient import TestClient
import gradio as gr
from gradio import Interface, flagging, networking
from gradio import Interface, networking
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -41,7 +40,7 @@ class TestInterfaceErrors(unittest.TestCase):
io = Interface(lambda x: 1 / x, "number", "number")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": [0], "fn_index": 1})
response = client.post("/api/predict/", json={"data": [0], "fn_index": 0})
self.assertEqual(response.status_code, 500)
self.assertTrue("error" in response.json())
io.close()
@ -50,7 +49,7 @@ class TestInterfaceErrors(unittest.TestCase):
io = Interface(lambda x: 1 / x, "number", "number")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": [0]})
response = client.post("/api/predict/", json={"fn_index": [0]})
self.assertEqual(response.status_code, 422)
io.close()

View File

@ -38,12 +38,18 @@ class TestRoutes(unittest.TestCase):
def test_predict_route(self):
response = self.client.post(
"/api/predict/", json={"data": ["test"], "fn_index": 1}
"/api/predict/", json={"data": ["test"], "fn_index": 0}
)
self.assertEqual(response.status_code, 200)
output = dict(response.json())
self.assertEqual(output["data"], ["testtest"])
def test_predict_route_without_fn_index(self):
response = self.client.post("/api/predict/", json={"data": ["test"]})
self.assertEqual(response.status_code, 200)
output = dict(response.json())
self.assertEqual(output["data"], ["testtest"])
def test_state(self):
def predict(input, history):
if history is None:
@ -56,14 +62,14 @@ class TestRoutes(unittest.TestCase):
client = TestClient(app)
response = client.post(
"/api/predict/",
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
)
output = dict(response.json())
print("output", output)
self.assertEqual(output["data"], ["test", None])
response = client.post(
"/api/predict/",
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
)
output = dict(response.json())
self.assertEqual(output["data"], ["testtest", None])