diff --git a/gradio.egg-info/PKG-INFO b/gradio.egg-info/PKG-INFO index 1cee009b41..7b170030b5 100644 --- a/gradio.egg-info/PKG-INFO +++ b/gradio.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: gradio -Version: 2.9b23 +Version: 2.9b25 Summary: Python library for easily interacting with trained machine learning models Home-page: https://github.com/gradio-app/gradio-UI Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq, Pete Allen, Ömer Faruk Özdemir diff --git a/gradio/interface.py b/gradio/interface.py index db3dabb7a5..5344d6837d 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -450,9 +450,14 @@ class Interface(Blocks): cache_interface_examples(self) if self.allow_flagging != "never": - self.flagging_callback.setup( - self.input_components + self.output_components, self.flagging_dir - ) + if self.interface_type == self.InterfaceTypes.UNIFIED: + self.flagging_callback.setup(self.input_components, self.flagging_dir) + elif self.interface_type == self.InterfaceTypes.INPUT_ONLY: + pass + else: + self.flagging_callback.setup( + self.input_components + self.output_components, self.flagging_dir + ) with self: if self.title: @@ -498,15 +503,6 @@ class Interface(Blocks): submit_btn = Button("Submit") if self.allow_flagging == "manual": flag_btn = Button("Flag", variant="secondary") - flag_btn.click( - lambda *flag_data: self.flagging_callback.flag( - flag_data - ), - inputs=self.input_components, - outputs=[], - _preprocess=False, - queue=False, - ) if self.interface_type in [ self.InterfaceTypes.STANDARD, @@ -523,15 +519,6 @@ class Interface(Blocks): submit_btn = Button("Generate") if self.allow_flagging == "manual": flag_btn = Button("Flag", variant="secondary") - flag_btn.click( - lambda *flag_data: self.flagging_callback.flag( - flag_data - ), - inputs=self.input_components - + self.output_components, - outputs=[], - _preprocess=False, - ) if self.interpretation: interpretation_btn = Button( "Interpret", variant="secondary" @@ -587,6 +574,27 @@ class Interface(Blocks): )} """, ) + if self.allow_flagging == "manual": + if self.interface_type in [ + self.InterfaceTypes.STANDARD, + self.InterfaceTypes.OUTPUT_ONLY, + ]: + flag_btn.click( + lambda *flag_data: self.flagging_callback.flag(flag_data), + inputs=self.input_components + self.output_components, + outputs=[], + _preprocess=False, + queue=False, + ) + elif self.interface_type == self.InterfaceTypes.UNIFIED: + flag_btn.click( + lambda *flag_data: self.flagging_callback.flag(flag_data), + inputs=self.input_components, + outputs=[], + _preprocess=False, + queue=False, + ) + if self.examples: non_state_inputs = [ c for c in self.input_components if not isinstance(c, Variable) diff --git a/gradio/routes.py b/gradio/routes.py index ca79ab0b5e..018b2b53f5 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -70,7 +70,7 @@ class QueuePushBody(BaseModel): class PredictBody(BaseModel): session_hash: Optional[str] data: Any - fn_index: int + fn_index: int = 0 ########### diff --git a/gradio/version.txt b/gradio/version.txt index c9c02e33da..9ea98528a7 100644 --- a/gradio/version.txt +++ b/gradio/version.txt @@ -1 +1 @@ -2.9b23 \ No newline at end of file +2.9b25 \ No newline at end of file diff --git a/setup.py b/setup.py index 46c3b588e9..0e03d0f186 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ long_description = (this_directory / "README.md").read_text() setup( name="gradio", - version="2.9b23", + version="2.9b25", include_package_data=True, description="Python library for easily interacting with trained machine learning models", long_description=long_description, diff --git a/test/test_networking.py b/test/test_networking.py index deeecc07d7..456cbfa182 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -6,11 +6,10 @@ import unittest.mock as mock import urllib import warnings -import aiohttp from fastapi.testclient import TestClient import gradio as gr -from gradio import Interface, flagging, networking +from gradio import Interface, networking os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -41,7 +40,7 @@ class TestInterfaceErrors(unittest.TestCase): io = Interface(lambda x: 1 / x, "number", "number") app, _, _ = io.launch(show_error=True, prevent_thread_lock=True) client = TestClient(app) - response = client.post("/api/predict/", json={"data": [0], "fn_index": 1}) + response = client.post("/api/predict/", json={"data": [0], "fn_index": 0}) self.assertEqual(response.status_code, 500) self.assertTrue("error" in response.json()) io.close() @@ -50,7 +49,7 @@ class TestInterfaceErrors(unittest.TestCase): io = Interface(lambda x: 1 / x, "number", "number") app, _, _ = io.launch(show_error=True, prevent_thread_lock=True) client = TestClient(app) - response = client.post("/api/predict/", json={"data": [0]}) + response = client.post("/api/predict/", json={"fn_index": [0]}) self.assertEqual(response.status_code, 422) io.close() diff --git a/test/test_routes.py b/test/test_routes.py index 6a2b526ace..1eced1e897 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -38,12 +38,18 @@ class TestRoutes(unittest.TestCase): def test_predict_route(self): response = self.client.post( - "/api/predict/", json={"data": ["test"], "fn_index": 1} + "/api/predict/", json={"data": ["test"], "fn_index": 0} ) self.assertEqual(response.status_code, 200) output = dict(response.json()) self.assertEqual(output["data"], ["testtest"]) + def test_predict_route_without_fn_index(self): + response = self.client.post("/api/predict/", json={"data": ["test"]}) + self.assertEqual(response.status_code, 200) + output = dict(response.json()) + self.assertEqual(output["data"], ["testtest"]) + def test_state(self): def predict(input, history): if history is None: @@ -56,14 +62,14 @@ class TestRoutes(unittest.TestCase): client = TestClient(app) response = client.post( "/api/predict/", - json={"data": ["test", None], "fn_index": 1, "session_hash": "_"}, + json={"data": ["test", None], "fn_index": 0, "session_hash": "_"}, ) output = dict(response.json()) print("output", output) self.assertEqual(output["data"], ["test", None]) response = client.post( "/api/predict/", - json={"data": ["test", None], "fn_index": 1, "session_hash": "_"}, + json={"data": ["test", None], "fn_index": 0, "session_hash": "_"}, ) output = dict(response.json()) self.assertEqual(output["data"], ["testtest", None])