significant cleanup and test fixes

This commit is contained in:
Abubakar Abid 2021-12-16 09:43:31 -06:00
parent 97fc1d1dfd
commit 8dc11093b9
6 changed files with 75 additions and 92 deletions

View File

@ -3,17 +3,15 @@ This is the core file in the `gradio` package, and defines the Interface class,
interface using the input and output types.
"""
import gradio
from gradio import networking, strings, utils, encryptor, queue
from gradio.inputs import get_input_instance
from gradio.outputs import get_output_instance
from gradio import networking, strings, utils, encryptor, queue
from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
from gradio.external import load_interface
import copy
import csv
import getpass
import inspect
import json
import markdown2
import numpy as np
import os
@ -28,7 +26,6 @@ import weakref
ip_address = networking.get_local_ip_address()
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
class Interface:
@ -144,6 +141,7 @@ class Interface:
article = utils.readme_to_html(article)
article = markdown2.markdown(
article, extras=["fenced-code-blocks"])
self.article = article
self.thumbnail = thumbnail
theme = theme if theme is not None else os.getenv("GRADIO_THEME", "default")
@ -180,7 +178,6 @@ class Interface:
self.flagging_options = flagging_options
self.flagging_dir = flagging_dir
self.encrypt = encrypt
Interface.instances.add(self)
self.save_to = None
self.share = None
self.share_url = None
@ -231,6 +228,7 @@ class Interface:
# Alert user if a more recent version of the library exists
utils.version_check()
Interface.instances.add(self)
def __call__(self, *params):
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
@ -606,7 +604,7 @@ class Interface:
self.show_error = show_error
# Count number of launches
launch_counter()
utils.launch_counter()
# If running in a colab or not able to access localhost, automatically create a shareable link
is_colab = utils.colab_check()
@ -684,7 +682,7 @@ class Interface:
if self.analytics_enabled:
utils.launch_analytics(data)
show_tip(self)
utils.show_tip(self)
# Run server perpetually under certain circumstances
if debug or int(os.getenv('GRADIO_DEBUG', 0)) == 1:
@ -753,31 +751,11 @@ class Interface:
utils.integration_analytics(data)
def show_tip(io):
# Only show tip every other use.
if io.show_tips and random.random() < 0.5:
print(random.choice(strings.en.TIPS))
def launch_counter():
try:
if not os.path.exists(JSON_PATH):
launches = {"launches": 1}
with open(JSON_PATH, "w+") as j:
json.dump(launches, j)
else:
with open(JSON_PATH) as j:
launches = json.load(j)
launches["launches"] += 1
if launches["launches"] in [25, 50]:
print(strings.en["BETA_INVITE"])
with open(JSON_PATH, "w") as j:
j.write(json.dumps(launches))
except:
pass
def close_all():
for io in Interface.get_instances():
io.close()
def reset_all():
warnings.warn("The `reset_all()` method has been renamed to `close_all()`. Please use `close_all()` instead.")
close_all()

View File

@ -25,10 +25,10 @@ class Parallel(gradio.Interface):
"inputs": interfaces[0].input_components,
"outputs": outputs,
"repeat_outputs_per_model": False,
"api_mode": interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute
}
kwargs.update(options)
super().__init__(**kwargs)
self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute
class Series(gradio.Interface):
@ -67,8 +67,8 @@ class Series(gradio.Interface):
"fn": connected_fn,
"inputs": interfaces[0].input_components,
"outputs": interfaces[-1].output_components,
"api_mode": interfaces[0].api_mode, # TODO(abidlabs): allow mixing api_mode and non-api_mode interfaces
}
kwargs.update(options)
super().__init__(**kwargs)
self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute

View File

@ -1,9 +1,13 @@
import gradio
import analytics
import json.decoder
import warnings
import requests
import pkg_resources
from distutils.version import StrictVersion
import json
import json.decoder
import os
import pkg_resources
import warnings
import random
from socket import gaierror
from urllib3.exceptions import MaxRetryError
@ -11,6 +15,7 @@ from urllib3.exceptions import MaxRetryError
analytics_url = 'https://api.gradio.app/'
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
def version_check():
@ -110,3 +115,29 @@ def readme_to_html(article):
except requests.exceptions.RequestException:
pass
return article
def show_tip(io):
# Only show tip every other use.
if io.show_tips and random.random() < 0.5:
print(random.choice(gradio.strings.en.TIPS))
def launch_counter():
try:
if not os.path.exists(JSON_PATH):
launches = {"launches": 1}
with open(JSON_PATH, "w+") as j:
json.dump(launches, j)
else:
with open(JSON_PATH) as j:
launches = json.load(j)
launches["launches"] += 1
if launches["launches"] in [25, 50]:
print(gradio.strings.en["BETA_INVITE"])
with open(JSON_PATH, "w") as j:
j.write(json.dumps(launches))
except:
pass

View File

@ -151,24 +151,28 @@ class TestLoadInterface(unittest.TestCase):
def test_sentiment_model(self):
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("I am happy, I love you.")
self.assertGreater(output['Positive'], 0.5)
def test_image_classification_model(self):
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("test/test_data/lion.jpg")
self.assertGreater(output['lion'], 0.5)
def test_translation_model(self):
interface_info = gr.external.load_interface("models/t5-base")
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("My name is Sarah and I live in London")
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
def test_numerical_to_label_space(self):
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("male", 77, 10)
self.assertLess(output['Survives'], 0.5)
@ -179,6 +183,7 @@ class TestLoadInterface(unittest.TestCase):
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("test/test_data/lion.jpg")
assertIsFile(output)

View File

@ -24,43 +24,6 @@ def captured_output():
sys.stdout, sys.stderr = old_out, old_err
class TestInterface(unittest.TestCase):
# send_error_analytics should probably actually be a method of Interface
# (so it doesn't have to take the 'enabled' argument)
# and since it's specific to the launch method, it should probably be
# renamed to send_launch_error_analytics.
# these tests test its current behavior
@mock.patch("requests.post")
def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post):
mock_post.side_effect = requests.ConnectionError()
send_error_analytics(True)
mock_post.assert_called()
@mock.patch("requests.post")
def test_error_analytics_doesnt_post_if_not_enabled(self, mock_post):
send_error_analytics(False)
mock_post.assert_not_called()
@mock.patch("requests.post")
def test_error_analytics_successful(self, mock_post):
send_error_analytics(True)
mock_post.assert_called()
# as above, send_launch_analytics should probably be a method of Interface
@mock.patch("requests.post")
def test_launch_analytics_doesnt_crash_on_connection_error(self, mock_post):
mock_post.side_effect = requests.ConnectionError()
send_launch_analytics(analytics_enabled=True,
inbrowser=True, is_colab="is_colab",
share="share", share_url="share_url")
mock_post.assert_called()
@mock.patch("requests.post")
def test_launch_analytics_doesnt_post_if_not_enabled(self, mock_post):
send_launch_analytics(analytics_enabled=False,
inbrowser=True, is_colab="is_colab",
share="share", share_url="share_url")
mock_post.assert_not_called()
def test_reset_all(self):
interface = Interface(lambda input: None, "textbox", "label")
interface.close = mock.MagicMock()
@ -160,19 +123,6 @@ class TestInterface(unittest.TestCase):
self.assertEqual(len(interface.examples[0]), 1)
interface.close()
def test_launch_counter(self):
with tempfile.NamedTemporaryFile() as tmp:
with mock.patch('gradio.interface.JSON_PATH', tmp.name):
interface = Interface(lambda x: x, "textbox", "label")
os.remove(tmp.name)
interface.launch(prevent_thread_lock=True)
with open(tmp.name) as j:
self.assertEqual(json.load(j)['launches'], 1)
interface.launch(prevent_thread_lock=True)
with open(tmp.name) as j:
self.assertEqual(json.load(j)['launches'], 2)
interface.close()
@mock.patch('IPython.display.display')
def test_inline_display(self, mock_display):
interface = Interface(lambda x: x, "textbox", "label")

View File

@ -1,10 +1,12 @@
from gradio.utils import *
import unittest
import os
import pkg_resources
import requests
import tempfile
import unittest
import unittest.mock as mock
import warnings
import requests
import os
import gradio
from gradio.utils import *
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -63,8 +65,13 @@ class TestUtils(unittest.TestCase):
def test_error_analytics_successful(self, mock_post):
error_analytics("placeholder")
mock_post.assert_called()
@mock.patch("requests.post")
def test_launch_analytics_doesnt_crash_on_connection_error(self, mock_post):
mock_post.side_effect = requests.ConnectionError()
launch_analytics(data={})
mock_post.assert_called()
@mock.patch("IPython.get_ipython")
@mock.patch("gradio.utils.error_analytics")
def test_colab_check_sends_analytics_on_import_fail(self, mock_error_analytics, mock_get_ipython):
@ -96,6 +103,18 @@ class TestUtils(unittest.TestCase):
def test_readme_to_html_correct_parse(self):
readme_to_html("https://github.com/gradio-app/gradio/blob/master/README.md")
def test_launch_counter(self):
with tempfile.NamedTemporaryFile() as tmp:
with mock.patch('gradio.utils.JSON_PATH', tmp.name):
interface = gradio.Interface(lambda x: x, "textbox", "label")
os.remove(tmp.name)
interface.launch(prevent_thread_lock=True)
with open(tmp.name) as j:
self.assertEqual(json.load(j)['launches'], 1)
interface.launch(prevent_thread_lock=True)
with open(tmp.name) as j:
self.assertEqual(json.load(j)['launches'], 2)
if __name__ == '__main__':
unittest.main()