fixed tests based on Ali's feedback

This commit is contained in:
Abubakar Abid 2021-11-09 12:30:59 -06:00
parent 424d390319
commit d09b24ae5f
12 changed files with 35 additions and 42 deletions

View File

@ -71,7 +71,7 @@ class Interface:
capture_session=False, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
title=None, description=None, article=None, thumbnail=None,
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
allow_screenshot=True, allow_flagging=None, flagging_options=None, encrypt=False,
show_tips=False, flagging_dir="flagged", analytics_enabled=None, enable_queue=False, api_mode=False):
"""
Parameters:
@ -171,20 +171,13 @@ class Interface:
self.server_port = server_port
self.simple_server = None
self.allow_screenshot = allow_screenshot
# If parameter is provided, use that; otherwise, environmet variable; otherwise, True
if allow_flagging is None:
self.allow_flagging = bool(os.getenv("GRADIO_FLAGGING") or True)
else:
self.allow_flagging = allow_flagging
# For allow_flagging and analytics_enabled: (1) first check for parameter, (2) check for environment variable, (3) default to True
self.allow_flagging = allow_flagging if allow_flagging is not None else os.getenv("GRADIO_ALLOW_FLAGGING", "True")=="True"
self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True")=="True"
self.flagging_options = flagging_options
self.flagging_dir = flagging_dir
self.encrypt = encrypt
Interface.instances.add(self)
# If parameter is provided, use that; otherwise, environmet variable; otherwise, True
if analytics_enabled is None:
self.analytics_enabled = bool(os.getenv("GRADIO_ANALYTICS") or True)
else:
self.analytics_enabled = analytics_enabled
self.save_to = None
self.share = None
self.share_url = None

View File

@ -19,7 +19,7 @@ TIMEOUT = 10
GAP_TO_SCREENSHOT = 2
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
def wait_for_url(url):

View File

@ -150,25 +150,25 @@ class TestLoadInterface(unittest.TestCase):
def test_sentiment_model(self):
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
io = gr.Interface(**interface_info, analytics_enabled=False)
io = gr.Interface(**interface_info)
output = io("I am happy, I love you.")
self.assertGreater(output['Positive'], 0.5)
def test_image_classification_model(self):
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
io = gr.Interface(**interface_info, analytics_enabled=False)
io = gr.Interface(**interface_info)
output = io("test/test_data/lion.jpg")
self.assertGreater(output['lion'], 0.5)
def test_translation_model(self):
interface_info = gr.external.load_interface("models/t5-base")
io = gr.Interface(**interface_info, analytics_enabled=False)
io = gr.Interface(**interface_info)
output = io("My name is Sarah and I live in London")
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
def test_numerical_to_label_space(self):
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
io = gr.Interface(**interface_info, analytics_enabled=False)
io = gr.Interface(**interface_info)
output = io("male", 77, 10)
self.assertLess(output['Survives'], 0.5)
@ -178,7 +178,7 @@ class TestLoadInterface(unittest.TestCase):
raise AssertionError("File does not exist: %s" % str(path))
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
io = gr.Interface(**interface_info, analytics_enabled=False)
io = gr.Interface(**interface_info)
output = io("test/test_data/lion.jpg")
assertIsFile(output)

View File

@ -57,10 +57,10 @@ class TestTextbox(unittest.TestCase):
self.assertIsInstance(text_input.generate_sample(), str)
def test_in_interface(self):
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox", analytics_enabled=False)
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
iface = gr.Interface(lambda sentence: max([len(word) for word in sentence.split()]), gr.inputs.Textbox(),
gr.outputs.Textbox(), interpretation="default", analytics_enabled=False)
gr.outputs.Textbox(), interpretation="default")
scores, alternative_outputs = iface.interpret(["Return the length of the longest word in this sentence"])
self.assertEqual(scores, [[('Return', 0.0), (' ', 0), ('the', 0.0), (' ', 0), ('length', 0.0), (' ', 0),
('of', 0.0), (' ', 0), ('the', 0.0), (' ', 0), ('longest', 0.0), (' ', 0),
@ -87,9 +87,9 @@ class TestNumber(unittest.TestCase):
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}))
def test_in_interface(self):
iface = gr.Interface(lambda x: x**2, "number", "textbox", analytics_enabled=False)
iface = gr.Interface(lambda x: x**2, "number", "textbox")
self.assertEqual(iface.process([2])[0], ['4.0'])
iface = gr.Interface(lambda x: x**2, "number", "textbox", interpretation="default", analytics_enabled=False)
iface = gr.Interface(lambda x: x**2, "number", "textbox", interpretation="default")
scores, alternative_outputs = iface.interpret([2])
self.assertEqual(scores, [[(1.94, -0.23640000000000017), (1.96, -0.15840000000000032),
(1.98, -0.07960000000000012), [2, None], (2.02, 0.08040000000000003),
@ -122,9 +122,9 @@ class TestSlider(unittest.TestCase):
})
def test_in_interface(self):
iface = gr.Interface(lambda x: x**2, "slider", "textbox", analytics_enabled=False)
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
self.assertEqual(iface.process([2])[0], ['4'])
iface = gr.Interface(lambda x: x**2, "slider", "textbox", interpretation="default", analytics_enabled=False)
iface = gr.Interface(lambda x: x**2, "slider", "textbox", interpretation="default")
scores, alternative_outputs = iface.interpret([2])
self.assertEqual(scores, [[-4.0, 200.08163265306123, 812.3265306122449, 1832.7346938775513, 3261.3061224489797,
5098.040816326531, 7342.938775510205, 9996.0]])

View File

@ -5,7 +5,7 @@ import requests
import os
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestInterface(unittest.TestCase):

View File

@ -6,7 +6,7 @@ from gradio import Interface
import numpy as np
import os
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestDefault(unittest.TestCase):
def test_default_text(self):

View File

@ -4,7 +4,7 @@ from gradio import mix
import os
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestSeries(unittest.TestCase):

View File

@ -10,7 +10,7 @@ from unittest.mock import ANY
import urllib.request
import os
os.environ["GRADIO_ANALYTICS_ENABLED"] = "True" # Enables analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestUser(unittest.TestCase):
@ -59,7 +59,7 @@ class TestPort(unittest.TestCase):
class TestFlaskRoutes(unittest.TestCase):
def setUp(self) -> None:
self.io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
self.io = gr.Interface(lambda x: x, "text", "text")
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
self.client = self.app.test_client()
@ -106,7 +106,7 @@ class TestFlaskRoutes(unittest.TestCase):
class TestAuthenticatedFlaskRoutes(unittest.TestCase):
def setUp(self) -> None:
self.io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
self.io = gr.Interface(lambda x: x, "text", "text")
self.app, _, _ = self.io.launch(auth=("test", "correct_password"), prevent_thread_lock=True)
self.client = self.app.test_client()
@ -126,7 +126,7 @@ class TestAuthenticatedFlaskRoutes(unittest.TestCase):
class TestInterfaceCustomParameters(unittest.TestCase):
def test_show_error(self):
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=False)
io = gr.Interface(lambda x: 1/x, "number", "number")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = app.test_client()
response = client.post('/api/predict/', json={"data": [0]})
@ -136,13 +136,13 @@ class TestInterfaceCustomParameters(unittest.TestCase):
def test_feature_logging(self):
with mock.patch('requests.post') as mock_post:
io = gr.Interface(lambda x: 1/x, "number", "number")
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=True)
io.launch(show_error=True, prevent_thread_lock=True)
networking.log_feature_analytics("test_feature")
mock_post.assert_called_with(networking.GRADIO_FEATURE_ANALYTICS_URL, data=ANY, timeout=ANY)
io.close()
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=False)
io = gr.Interface(lambda x: 1/x, "number", "number")
print(io.analytics_enabled)
io.launch(show_error=True, prevent_thread_lock=True)
with mock.patch('requests.post') as mock_post:
@ -152,7 +152,7 @@ class TestInterfaceCustomParameters(unittest.TestCase):
class TestFlagging(unittest.TestCase):
def test_num_rows_written(self):
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
io = gr.Interface(lambda x: x, "text", "text")
io.launch(prevent_thread_lock=True)
with tempfile.TemporaryDirectory() as tmpdirname:
row_count = networking.flag_data(["test"], ["test"], flag_path=tmpdirname)
@ -164,7 +164,7 @@ class TestFlagging(unittest.TestCase):
@mock.patch("requests.post")
@mock.patch("gradio.networking.flag_data")
def test_flagging_analytics(self, mock_flag, mock_post):
io = gr.Interface(lambda x: x, "text", "text")
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=True)
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = app.test_client()
response = client.post('/api/flag/', json={"data": {"input_data": ["test"], "output_data": ["test"]}})
@ -176,7 +176,7 @@ class TestFlagging(unittest.TestCase):
@mock.patch("requests.post")
class TestInterpretation(unittest.TestCase):
def test_interpretation(self, mock_post):
io = gr.Interface(lambda x: len(x), "text", "label", interpretation="default")
io = gr.Interface(lambda x: len(x), "text", "label", interpretation="default", analytics_enabled=True)
app, _, _ = io.launch(prevent_thread_lock=True)
client = app.test_client()
io.interpret = mock.MagicMock(return_value=(None, None))
@ -187,13 +187,13 @@ class TestInterpretation(unittest.TestCase):
class TestState(unittest.TestCase):
def test_state_initialization(self):
io = gr.Interface(lambda x: len(x), "text", "label", analytics_enabled=False)
io = gr.Interface(lambda x: len(x), "text", "label")
app, _, _ = io.launch(prevent_thread_lock=True)
with app.test_request_context():
self.assertIsNone(networking.get_state())
def test_state_value(self):
io = gr.Interface(lambda x: len(x), "text", "label", analytics_enabled=False)
io = gr.Interface(lambda x: len(x), "text", "label")
io.launch(prevent_thread_lock=True)
app, _, _ = io.launch(prevent_thread_lock=True)
with app.test_request_context():
@ -220,7 +220,7 @@ class TestURLs(unittest.TestCase):
class TestQueuing(unittest.TestCase):
def test_queueing(self):
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
io = gr.Interface(lambda x: x, "text", "text")
app, _, _ = io.launch(prevent_thread_lock=True)
client = app.test_client()
# mock queue methods and post method

View File

@ -7,7 +7,7 @@ import tempfile
import os
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class OutputComponent(unittest.TestCase):

View File

@ -8,7 +8,7 @@ import os
import tempfile
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class ImagePreprocessing(unittest.TestCase):

View File

@ -9,7 +9,7 @@ import paramiko
import os
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestTunneling(unittest.TestCase):

View File

@ -7,7 +7,7 @@ import requests
import os
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestUtils(unittest.TestCase):