diff --git a/gradio/interface.py b/gradio/interface.py index ba5f3ac6d9..cd825e5f90 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -3,17 +3,15 @@ This is the core file in the `gradio` package, and defines the Interface class, interface using the input and output types. """ -import gradio +from gradio import networking, strings, utils, encryptor, queue from gradio.inputs import get_input_instance from gradio.outputs import get_output_instance -from gradio import networking, strings, utils, encryptor, queue from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value from gradio.external import load_interface import copy import csv import getpass import inspect -import json import markdown2 import numpy as np import os @@ -28,7 +26,6 @@ import weakref ip_address = networking.get_local_ip_address() -JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json") class Interface: @@ -144,6 +141,7 @@ class Interface: article = utils.readme_to_html(article) article = markdown2.markdown( article, extras=["fenced-code-blocks"]) + self.article = article self.thumbnail = thumbnail theme = theme if theme is not None else os.getenv("GRADIO_THEME", "default") @@ -180,7 +178,6 @@ class Interface: self.flagging_options = flagging_options self.flagging_dir = flagging_dir self.encrypt = encrypt - Interface.instances.add(self) self.save_to = None self.share = None self.share_url = None @@ -231,6 +228,7 @@ class Interface: # Alert user if a more recent version of the library exists utils.version_check() + Interface.instances.add(self) def __call__(self, *params): if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API @@ -606,7 +604,7 @@ class Interface: self.show_error = show_error # Count number of launches - launch_counter() + utils.launch_counter() # If running in a colab or not able to access localhost, automatically create a shareable link is_colab = utils.colab_check() @@ -684,7 +682,7 @@ class Interface: if self.analytics_enabled: utils.launch_analytics(data) - show_tip(self) + utils.show_tip(self) # Run server perpetually under certain circumstances if debug or int(os.getenv('GRADIO_DEBUG', 0)) == 1: @@ -753,31 +751,11 @@ class Interface: utils.integration_analytics(data) -def show_tip(io): - # Only show tip every other use. - if io.show_tips and random.random() < 0.5: - print(random.choice(strings.en.TIPS)) - - -def launch_counter(): - try: - if not os.path.exists(JSON_PATH): - launches = {"launches": 1} - with open(JSON_PATH, "w+") as j: - json.dump(launches, j) - else: - with open(JSON_PATH) as j: - launches = json.load(j) - launches["launches"] += 1 - if launches["launches"] in [25, 50]: - print(strings.en["BETA_INVITE"]) - with open(JSON_PATH, "w") as j: - j.write(json.dumps(launches)) - except: - pass - - def close_all(): for io in Interface.get_instances(): io.close() + +def reset_all(): + warnings.warn("The `reset_all()` method has been renamed to `close_all()`. Please use `close_all()` instead.") + close_all() diff --git a/gradio/mix.py b/gradio/mix.py index 09bdc14484..4564ecfc1d 100644 --- a/gradio/mix.py +++ b/gradio/mix.py @@ -25,10 +25,10 @@ class Parallel(gradio.Interface): "inputs": interfaces[0].input_components, "outputs": outputs, "repeat_outputs_per_model": False, - "api_mode": interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute } kwargs.update(options) super().__init__(**kwargs) + self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute class Series(gradio.Interface): @@ -67,8 +67,8 @@ class Series(gradio.Interface): "fn": connected_fn, "inputs": interfaces[0].input_components, "outputs": interfaces[-1].output_components, - "api_mode": interfaces[0].api_mode, # TODO(abidlabs): allow mixing api_mode and non-api_mode interfaces } kwargs.update(options) super().__init__(**kwargs) + self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute diff --git a/gradio/utils.py b/gradio/utils.py index 5d2c7b7937..930407d954 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -1,9 +1,13 @@ +import gradio import analytics -import json.decoder -import warnings import requests -import pkg_resources from distutils.version import StrictVersion +import json +import json.decoder +import os +import pkg_resources +import warnings +import random from socket import gaierror from urllib3.exceptions import MaxRetryError @@ -11,6 +15,7 @@ from urllib3.exceptions import MaxRetryError analytics_url = 'https://api.gradio.app/' PKG_VERSION_URL = "https://api.gradio.app/pkg-version" analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy" +JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json") def version_check(): @@ -110,3 +115,29 @@ def readme_to_html(article): except requests.exceptions.RequestException: pass return article + + +def show_tip(io): + # Only show tip every other use. + if io.show_tips and random.random() < 0.5: + print(random.choice(gradio.strings.en.TIPS)) + + +def launch_counter(): + try: + if not os.path.exists(JSON_PATH): + launches = {"launches": 1} + with open(JSON_PATH, "w+") as j: + json.dump(launches, j) + else: + with open(JSON_PATH) as j: + launches = json.load(j) + launches["launches"] += 1 + if launches["launches"] in [25, 50]: + print(gradio.strings.en["BETA_INVITE"]) + with open(JSON_PATH, "w") as j: + j.write(json.dumps(launches)) + except: + pass + + diff --git a/test/test_external.py b/test/test_external.py index 70c3756770..7bc4e39428 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -151,24 +151,28 @@ class TestLoadInterface(unittest.TestCase): def test_sentiment_model(self): interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier") io = gr.Interface(**interface_info) + io.api_mode = True output = io("I am happy, I love you.") self.assertGreater(output['Positive'], 0.5) def test_image_classification_model(self): interface_info = gr.external.load_interface("models/google/vit-base-patch16-224") io = gr.Interface(**interface_info) + io.api_mode = True output = io("test/test_data/lion.jpg") self.assertGreater(output['lion'], 0.5) def test_translation_model(self): interface_info = gr.external.load_interface("models/t5-base") io = gr.Interface(**interface_info) + io.api_mode = True output = io("My name is Sarah and I live in London") self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London') def test_numerical_to_label_space(self): interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival") io = gr.Interface(**interface_info) + io.api_mode = True output = io("male", 77, 10) self.assertLess(output['Survives'], 0.5) @@ -179,6 +183,7 @@ class TestLoadInterface(unittest.TestCase): interface_info = gr.external.load_interface("spaces/abidlabs/image-identity") io = gr.Interface(**interface_info) + io.api_mode = True output = io("test/test_data/lion.jpg") assertIsFile(output) diff --git a/test/test_interfaces.py b/test/test_interfaces.py index 8207c6d6d0..0dd0b6c111 100644 --- a/test/test_interfaces.py +++ b/test/test_interfaces.py @@ -24,43 +24,6 @@ def captured_output(): sys.stdout, sys.stderr = old_out, old_err class TestInterface(unittest.TestCase): - # send_error_analytics should probably actually be a method of Interface - # (so it doesn't have to take the 'enabled' argument) - # and since it's specific to the launch method, it should probably be - # renamed to send_launch_error_analytics. - # these tests test its current behavior - @mock.patch("requests.post") - def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post): - mock_post.side_effect = requests.ConnectionError() - send_error_analytics(True) - mock_post.assert_called() - - @mock.patch("requests.post") - def test_error_analytics_doesnt_post_if_not_enabled(self, mock_post): - send_error_analytics(False) - mock_post.assert_not_called() - - @mock.patch("requests.post") - def test_error_analytics_successful(self, mock_post): - send_error_analytics(True) - mock_post.assert_called() - - # as above, send_launch_analytics should probably be a method of Interface - @mock.patch("requests.post") - def test_launch_analytics_doesnt_crash_on_connection_error(self, mock_post): - mock_post.side_effect = requests.ConnectionError() - send_launch_analytics(analytics_enabled=True, - inbrowser=True, is_colab="is_colab", - share="share", share_url="share_url") - mock_post.assert_called() - - @mock.patch("requests.post") - def test_launch_analytics_doesnt_post_if_not_enabled(self, mock_post): - send_launch_analytics(analytics_enabled=False, - inbrowser=True, is_colab="is_colab", - share="share", share_url="share_url") - mock_post.assert_not_called() - def test_reset_all(self): interface = Interface(lambda input: None, "textbox", "label") interface.close = mock.MagicMock() @@ -160,19 +123,6 @@ class TestInterface(unittest.TestCase): self.assertEqual(len(interface.examples[0]), 1) interface.close() - def test_launch_counter(self): - with tempfile.NamedTemporaryFile() as tmp: - with mock.patch('gradio.interface.JSON_PATH', tmp.name): - interface = Interface(lambda x: x, "textbox", "label") - os.remove(tmp.name) - interface.launch(prevent_thread_lock=True) - with open(tmp.name) as j: - self.assertEqual(json.load(j)['launches'], 1) - interface.launch(prevent_thread_lock=True) - with open(tmp.name) as j: - self.assertEqual(json.load(j)['launches'], 2) - interface.close() - @mock.patch('IPython.display.display') def test_inline_display(self, mock_display): interface = Interface(lambda x: x, "textbox", "label") diff --git a/test/test_utils.py b/test/test_utils.py index 5947572b93..9c83e8293d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,10 +1,12 @@ -from gradio.utils import * -import unittest +import os import pkg_resources +import requests +import tempfile +import unittest import unittest.mock as mock import warnings -import requests -import os +import gradio +from gradio.utils import * os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -63,8 +65,13 @@ class TestUtils(unittest.TestCase): def test_error_analytics_successful(self, mock_post): error_analytics("placeholder") mock_post.assert_called() - - + + @mock.patch("requests.post") + def test_launch_analytics_doesnt_crash_on_connection_error(self, mock_post): + mock_post.side_effect = requests.ConnectionError() + launch_analytics(data={}) + mock_post.assert_called() + @mock.patch("IPython.get_ipython") @mock.patch("gradio.utils.error_analytics") def test_colab_check_sends_analytics_on_import_fail(self, mock_error_analytics, mock_get_ipython): @@ -96,6 +103,18 @@ class TestUtils(unittest.TestCase): def test_readme_to_html_correct_parse(self): readme_to_html("https://github.com/gradio-app/gradio/blob/master/README.md") + def test_launch_counter(self): + with tempfile.NamedTemporaryFile() as tmp: + with mock.patch('gradio.utils.JSON_PATH', tmp.name): + interface = gradio.Interface(lambda x: x, "textbox", "label") + os.remove(tmp.name) + interface.launch(prevent_thread_lock=True) + with open(tmp.name) as j: + self.assertEqual(json.load(j)['launches'], 1) + interface.launch(prevent_thread_lock=True) + with open(tmp.name) as j: + self.assertEqual(json.load(j)['launches'], 2) + if __name__ == '__main__': unittest.main()