mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-17 11:29:58 +08:00
significant cleanup and test fixes
This commit is contained in:
parent
97fc1d1dfd
commit
8dc11093b9
@ -3,17 +3,15 @@ This is the core file in the `gradio` package, and defines the Interface class,
|
||||
interface using the input and output types.
|
||||
"""
|
||||
|
||||
import gradio
|
||||
from gradio import networking, strings, utils, encryptor, queue
|
||||
from gradio.inputs import get_input_instance
|
||||
from gradio.outputs import get_output_instance
|
||||
from gradio import networking, strings, utils, encryptor, queue
|
||||
from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
|
||||
from gradio.external import load_interface
|
||||
import copy
|
||||
import csv
|
||||
import getpass
|
||||
import inspect
|
||||
import json
|
||||
import markdown2
|
||||
import numpy as np
|
||||
import os
|
||||
@ -28,7 +26,6 @@ import weakref
|
||||
|
||||
|
||||
ip_address = networking.get_local_ip_address()
|
||||
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
|
||||
|
||||
|
||||
class Interface:
|
||||
@ -144,6 +141,7 @@ class Interface:
|
||||
article = utils.readme_to_html(article)
|
||||
article = markdown2.markdown(
|
||||
article, extras=["fenced-code-blocks"])
|
||||
|
||||
self.article = article
|
||||
self.thumbnail = thumbnail
|
||||
theme = theme if theme is not None else os.getenv("GRADIO_THEME", "default")
|
||||
@ -180,7 +178,6 @@ class Interface:
|
||||
self.flagging_options = flagging_options
|
||||
self.flagging_dir = flagging_dir
|
||||
self.encrypt = encrypt
|
||||
Interface.instances.add(self)
|
||||
self.save_to = None
|
||||
self.share = None
|
||||
self.share_url = None
|
||||
@ -231,6 +228,7 @@ class Interface:
|
||||
|
||||
# Alert user if a more recent version of the library exists
|
||||
utils.version_check()
|
||||
Interface.instances.add(self)
|
||||
|
||||
def __call__(self, *params):
|
||||
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
|
||||
@ -606,7 +604,7 @@ class Interface:
|
||||
self.show_error = show_error
|
||||
|
||||
# Count number of launches
|
||||
launch_counter()
|
||||
utils.launch_counter()
|
||||
|
||||
# If running in a colab or not able to access localhost, automatically create a shareable link
|
||||
is_colab = utils.colab_check()
|
||||
@ -684,7 +682,7 @@ class Interface:
|
||||
if self.analytics_enabled:
|
||||
utils.launch_analytics(data)
|
||||
|
||||
show_tip(self)
|
||||
utils.show_tip(self)
|
||||
|
||||
# Run server perpetually under certain circumstances
|
||||
if debug or int(os.getenv('GRADIO_DEBUG', 0)) == 1:
|
||||
@ -753,31 +751,11 @@ class Interface:
|
||||
utils.integration_analytics(data)
|
||||
|
||||
|
||||
def show_tip(io):
|
||||
# Only show tip every other use.
|
||||
if io.show_tips and random.random() < 0.5:
|
||||
print(random.choice(strings.en.TIPS))
|
||||
|
||||
|
||||
def launch_counter():
|
||||
try:
|
||||
if not os.path.exists(JSON_PATH):
|
||||
launches = {"launches": 1}
|
||||
with open(JSON_PATH, "w+") as j:
|
||||
json.dump(launches, j)
|
||||
else:
|
||||
with open(JSON_PATH) as j:
|
||||
launches = json.load(j)
|
||||
launches["launches"] += 1
|
||||
if launches["launches"] in [25, 50]:
|
||||
print(strings.en["BETA_INVITE"])
|
||||
with open(JSON_PATH, "w") as j:
|
||||
j.write(json.dumps(launches))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def close_all():
|
||||
for io in Interface.get_instances():
|
||||
io.close()
|
||||
|
||||
|
||||
def reset_all():
|
||||
warnings.warn("The `reset_all()` method has been renamed to `close_all()`. Please use `close_all()` instead.")
|
||||
close_all()
|
||||
|
@ -25,10 +25,10 @@ class Parallel(gradio.Interface):
|
||||
"inputs": interfaces[0].input_components,
|
||||
"outputs": outputs,
|
||||
"repeat_outputs_per_model": False,
|
||||
"api_mode": interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute
|
||||
}
|
||||
kwargs.update(options)
|
||||
super().__init__(**kwargs)
|
||||
self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute
|
||||
|
||||
|
||||
class Series(gradio.Interface):
|
||||
@ -67,8 +67,8 @@ class Series(gradio.Interface):
|
||||
"fn": connected_fn,
|
||||
"inputs": interfaces[0].input_components,
|
||||
"outputs": interfaces[-1].output_components,
|
||||
"api_mode": interfaces[0].api_mode, # TODO(abidlabs): allow mixing api_mode and non-api_mode interfaces
|
||||
}
|
||||
kwargs.update(options)
|
||||
super().__init__(**kwargs)
|
||||
self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute
|
||||
|
||||
|
@ -1,9 +1,13 @@
|
||||
import gradio
|
||||
import analytics
|
||||
import json.decoder
|
||||
import warnings
|
||||
import requests
|
||||
import pkg_resources
|
||||
from distutils.version import StrictVersion
|
||||
import json
|
||||
import json.decoder
|
||||
import os
|
||||
import pkg_resources
|
||||
import warnings
|
||||
import random
|
||||
from socket import gaierror
|
||||
from urllib3.exceptions import MaxRetryError
|
||||
|
||||
@ -11,6 +15,7 @@ from urllib3.exceptions import MaxRetryError
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
|
||||
|
||||
|
||||
def version_check():
|
||||
@ -110,3 +115,29 @@ def readme_to_html(article):
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
return article
|
||||
|
||||
|
||||
def show_tip(io):
|
||||
# Only show tip every other use.
|
||||
if io.show_tips and random.random() < 0.5:
|
||||
print(random.choice(gradio.strings.en.TIPS))
|
||||
|
||||
|
||||
def launch_counter():
|
||||
try:
|
||||
if not os.path.exists(JSON_PATH):
|
||||
launches = {"launches": 1}
|
||||
with open(JSON_PATH, "w+") as j:
|
||||
json.dump(launches, j)
|
||||
else:
|
||||
with open(JSON_PATH) as j:
|
||||
launches = json.load(j)
|
||||
launches["launches"] += 1
|
||||
if launches["launches"] in [25, 50]:
|
||||
print(gradio.strings.en["BETA_INVITE"])
|
||||
with open(JSON_PATH, "w") as j:
|
||||
j.write(json.dumps(launches))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -151,24 +151,28 @@ class TestLoadInterface(unittest.TestCase):
|
||||
def test_sentiment_model(self):
|
||||
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("I am happy, I love you.")
|
||||
self.assertGreater(output['Positive'], 0.5)
|
||||
|
||||
def test_image_classification_model(self):
|
||||
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("test/test_data/lion.jpg")
|
||||
self.assertGreater(output['lion'], 0.5)
|
||||
|
||||
def test_translation_model(self):
|
||||
interface_info = gr.external.load_interface("models/t5-base")
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("My name is Sarah and I live in London")
|
||||
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
|
||||
|
||||
def test_numerical_to_label_space(self):
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("male", 77, 10)
|
||||
self.assertLess(output['Survives'], 0.5)
|
||||
|
||||
@ -179,6 +183,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("test/test_data/lion.jpg")
|
||||
assertIsFile(output)
|
||||
|
||||
|
@ -24,43 +24,6 @@ def captured_output():
|
||||
sys.stdout, sys.stderr = old_out, old_err
|
||||
|
||||
class TestInterface(unittest.TestCase):
|
||||
# send_error_analytics should probably actually be a method of Interface
|
||||
# (so it doesn't have to take the 'enabled' argument)
|
||||
# and since it's specific to the launch method, it should probably be
|
||||
# renamed to send_launch_error_analytics.
|
||||
# these tests test its current behavior
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post):
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
send_error_analytics(True)
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_doesnt_post_if_not_enabled(self, mock_post):
|
||||
send_error_analytics(False)
|
||||
mock_post.assert_not_called()
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_successful(self, mock_post):
|
||||
send_error_analytics(True)
|
||||
mock_post.assert_called()
|
||||
|
||||
# as above, send_launch_analytics should probably be a method of Interface
|
||||
@mock.patch("requests.post")
|
||||
def test_launch_analytics_doesnt_crash_on_connection_error(self, mock_post):
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
send_launch_analytics(analytics_enabled=True,
|
||||
inbrowser=True, is_colab="is_colab",
|
||||
share="share", share_url="share_url")
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_launch_analytics_doesnt_post_if_not_enabled(self, mock_post):
|
||||
send_launch_analytics(analytics_enabled=False,
|
||||
inbrowser=True, is_colab="is_colab",
|
||||
share="share", share_url="share_url")
|
||||
mock_post.assert_not_called()
|
||||
|
||||
def test_reset_all(self):
|
||||
interface = Interface(lambda input: None, "textbox", "label")
|
||||
interface.close = mock.MagicMock()
|
||||
@ -160,19 +123,6 @@ class TestInterface(unittest.TestCase):
|
||||
self.assertEqual(len(interface.examples[0]), 1)
|
||||
interface.close()
|
||||
|
||||
def test_launch_counter(self):
|
||||
with tempfile.NamedTemporaryFile() as tmp:
|
||||
with mock.patch('gradio.interface.JSON_PATH', tmp.name):
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
os.remove(tmp.name)
|
||||
interface.launch(prevent_thread_lock=True)
|
||||
with open(tmp.name) as j:
|
||||
self.assertEqual(json.load(j)['launches'], 1)
|
||||
interface.launch(prevent_thread_lock=True)
|
||||
with open(tmp.name) as j:
|
||||
self.assertEqual(json.load(j)['launches'], 2)
|
||||
interface.close()
|
||||
|
||||
@mock.patch('IPython.display.display')
|
||||
def test_inline_display(self, mock_display):
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
|
@ -1,10 +1,12 @@
|
||||
from gradio.utils import *
|
||||
import unittest
|
||||
import os
|
||||
import pkg_resources
|
||||
import requests
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import warnings
|
||||
import requests
|
||||
import os
|
||||
import gradio
|
||||
from gradio.utils import *
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
@ -63,8 +65,13 @@ class TestUtils(unittest.TestCase):
|
||||
def test_error_analytics_successful(self, mock_post):
|
||||
error_analytics("placeholder")
|
||||
mock_post.assert_called()
|
||||
|
||||
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_launch_analytics_doesnt_crash_on_connection_error(self, mock_post):
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
launch_analytics(data={})
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
@mock.patch("gradio.utils.error_analytics")
|
||||
def test_colab_check_sends_analytics_on_import_fail(self, mock_error_analytics, mock_get_ipython):
|
||||
@ -96,6 +103,18 @@ class TestUtils(unittest.TestCase):
|
||||
def test_readme_to_html_correct_parse(self):
|
||||
readme_to_html("https://github.com/gradio-app/gradio/blob/master/README.md")
|
||||
|
||||
def test_launch_counter(self):
|
||||
with tempfile.NamedTemporaryFile() as tmp:
|
||||
with mock.patch('gradio.utils.JSON_PATH', tmp.name):
|
||||
interface = gradio.Interface(lambda x: x, "textbox", "label")
|
||||
os.remove(tmp.name)
|
||||
interface.launch(prevent_thread_lock=True)
|
||||
with open(tmp.name) as j:
|
||||
self.assertEqual(json.load(j)['launches'], 1)
|
||||
interface.launch(prevent_thread_lock=True)
|
||||
with open(tmp.name) as j:
|
||||
self.assertEqual(json.load(j)['launches'], 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user