mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
fixed tests based on Ali's feedback
This commit is contained in:
parent
424d390319
commit
d09b24ae5f
@ -71,7 +71,7 @@ class Interface:
|
||||
capture_session=False, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
|
||||
title=None, description=None, article=None, thumbnail=None,
|
||||
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
|
||||
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
|
||||
allow_screenshot=True, allow_flagging=None, flagging_options=None, encrypt=False,
|
||||
show_tips=False, flagging_dir="flagged", analytics_enabled=None, enable_queue=False, api_mode=False):
|
||||
"""
|
||||
Parameters:
|
||||
@ -171,20 +171,13 @@ class Interface:
|
||||
self.server_port = server_port
|
||||
self.simple_server = None
|
||||
self.allow_screenshot = allow_screenshot
|
||||
# If parameter is provided, use that; otherwise, environmet variable; otherwise, True
|
||||
if allow_flagging is None:
|
||||
self.allow_flagging = bool(os.getenv("GRADIO_FLAGGING") or True)
|
||||
else:
|
||||
self.allow_flagging = allow_flagging
|
||||
# For allow_flagging and analytics_enabled: (1) first check for parameter, (2) check for environment variable, (3) default to True
|
||||
self.allow_flagging = allow_flagging if allow_flagging is not None else os.getenv("GRADIO_ALLOW_FLAGGING", "True")=="True"
|
||||
self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True")=="True"
|
||||
self.flagging_options = flagging_options
|
||||
self.flagging_dir = flagging_dir
|
||||
self.encrypt = encrypt
|
||||
Interface.instances.add(self)
|
||||
# If parameter is provided, use that; otherwise, environmet variable; otherwise, True
|
||||
if analytics_enabled is None:
|
||||
self.analytics_enabled = bool(os.getenv("GRADIO_ANALYTICS") or True)
|
||||
else:
|
||||
self.analytics_enabled = analytics_enabled
|
||||
self.save_to = None
|
||||
self.share = None
|
||||
self.share_url = None
|
||||
|
@ -19,7 +19,7 @@ TIMEOUT = 10
|
||||
|
||||
GAP_TO_SCREENSHOT = 2
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
def wait_for_url(url):
|
||||
|
@ -150,25 +150,25 @@ class TestLoadInterface(unittest.TestCase):
|
||||
|
||||
def test_sentiment_model(self):
|
||||
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
||||
io = gr.Interface(**interface_info, analytics_enabled=False)
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("I am happy, I love you.")
|
||||
self.assertGreater(output['Positive'], 0.5)
|
||||
|
||||
def test_image_classification_model(self):
|
||||
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
|
||||
io = gr.Interface(**interface_info, analytics_enabled=False)
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("test/test_data/lion.jpg")
|
||||
self.assertGreater(output['lion'], 0.5)
|
||||
|
||||
def test_translation_model(self):
|
||||
interface_info = gr.external.load_interface("models/t5-base")
|
||||
io = gr.Interface(**interface_info, analytics_enabled=False)
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("My name is Sarah and I live in London")
|
||||
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
|
||||
|
||||
def test_numerical_to_label_space(self):
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
|
||||
io = gr.Interface(**interface_info, analytics_enabled=False)
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("male", 77, 10)
|
||||
self.assertLess(output['Survives'], 0.5)
|
||||
|
||||
@ -178,7 +178,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
raise AssertionError("File does not exist: %s" % str(path))
|
||||
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
|
||||
io = gr.Interface(**interface_info, analytics_enabled=False)
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("test/test_data/lion.jpg")
|
||||
assertIsFile(output)
|
||||
|
||||
|
@ -57,10 +57,10 @@ class TestTextbox(unittest.TestCase):
|
||||
self.assertIsInstance(text_input.generate_sample(), str)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox", analytics_enabled=False)
|
||||
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
|
||||
iface = gr.Interface(lambda sentence: max([len(word) for word in sentence.split()]), gr.inputs.Textbox(),
|
||||
gr.outputs.Textbox(), interpretation="default", analytics_enabled=False)
|
||||
gr.outputs.Textbox(), interpretation="default")
|
||||
scores, alternative_outputs = iface.interpret(["Return the length of the longest word in this sentence"])
|
||||
self.assertEqual(scores, [[('Return', 0.0), (' ', 0), ('the', 0.0), (' ', 0), ('length', 0.0), (' ', 0),
|
||||
('of', 0.0), (' ', 0), ('the', 0.0), (' ', 0), ('longest', 0.0), (' ', 0),
|
||||
@ -87,9 +87,9 @@ class TestNumber(unittest.TestCase):
|
||||
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}))
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox", analytics_enabled=False)
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ['4.0'])
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox", interpretation="default", analytics_enabled=False)
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox", interpretation="default")
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
self.assertEqual(scores, [[(1.94, -0.23640000000000017), (1.96, -0.15840000000000032),
|
||||
(1.98, -0.07960000000000012), [2, None], (2.02, 0.08040000000000003),
|
||||
@ -122,9 +122,9 @@ class TestSlider(unittest.TestCase):
|
||||
})
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox", analytics_enabled=False)
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ['4'])
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox", interpretation="default", analytics_enabled=False)
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox", interpretation="default")
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
self.assertEqual(scores, [[-4.0, 200.08163265306123, 812.3265306122449, 1832.7346938775513, 3261.3061224489797,
|
||||
5098.040816326531, 7342.938775510205, 9996.0]])
|
||||
|
@ -5,7 +5,7 @@ import requests
|
||||
import os
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestInterface(unittest.TestCase):
|
||||
|
@ -6,7 +6,7 @@ from gradio import Interface
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
class TestDefault(unittest.TestCase):
|
||||
def test_default_text(self):
|
||||
|
@ -4,7 +4,7 @@ from gradio import mix
|
||||
import os
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestSeries(unittest.TestCase):
|
||||
|
@ -10,7 +10,7 @@ from unittest.mock import ANY
|
||||
import urllib.request
|
||||
import os
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "True" # Enables analytics
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestUser(unittest.TestCase):
|
||||
@ -59,7 +59,7 @@ class TestPort(unittest.TestCase):
|
||||
|
||||
class TestFlaskRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
|
||||
self.io = gr.Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
|
||||
self.client = self.app.test_client()
|
||||
|
||||
@ -106,7 +106,7 @@ class TestFlaskRoutes(unittest.TestCase):
|
||||
|
||||
class TestAuthenticatedFlaskRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
|
||||
self.io = gr.Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(auth=("test", "correct_password"), prevent_thread_lock=True)
|
||||
self.client = self.app.test_client()
|
||||
|
||||
@ -126,7 +126,7 @@ class TestAuthenticatedFlaskRoutes(unittest.TestCase):
|
||||
|
||||
class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
def test_show_error(self):
|
||||
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=False)
|
||||
io = gr.Interface(lambda x: 1/x, "number", "number")
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = app.test_client()
|
||||
response = client.post('/api/predict/', json={"data": [0]})
|
||||
@ -136,13 +136,13 @@ class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
|
||||
def test_feature_logging(self):
|
||||
with mock.patch('requests.post') as mock_post:
|
||||
io = gr.Interface(lambda x: 1/x, "number", "number")
|
||||
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=True)
|
||||
io.launch(show_error=True, prevent_thread_lock=True)
|
||||
networking.log_feature_analytics("test_feature")
|
||||
mock_post.assert_called_with(networking.GRADIO_FEATURE_ANALYTICS_URL, data=ANY, timeout=ANY)
|
||||
io.close()
|
||||
|
||||
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=False)
|
||||
io = gr.Interface(lambda x: 1/x, "number", "number")
|
||||
print(io.analytics_enabled)
|
||||
io.launch(show_error=True, prevent_thread_lock=True)
|
||||
with mock.patch('requests.post') as mock_post:
|
||||
@ -152,7 +152,7 @@ class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
|
||||
class TestFlagging(unittest.TestCase):
|
||||
def test_num_rows_written(self):
|
||||
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
|
||||
io = gr.Interface(lambda x: x, "text", "text")
|
||||
io.launch(prevent_thread_lock=True)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
row_count = networking.flag_data(["test"], ["test"], flag_path=tmpdirname)
|
||||
@ -164,7 +164,7 @@ class TestFlagging(unittest.TestCase):
|
||||
@mock.patch("requests.post")
|
||||
@mock.patch("gradio.networking.flag_data")
|
||||
def test_flagging_analytics(self, mock_flag, mock_post):
|
||||
io = gr.Interface(lambda x: x, "text", "text")
|
||||
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=True)
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = app.test_client()
|
||||
response = client.post('/api/flag/', json={"data": {"input_data": ["test"], "output_data": ["test"]}})
|
||||
@ -176,7 +176,7 @@ class TestFlagging(unittest.TestCase):
|
||||
@mock.patch("requests.post")
|
||||
class TestInterpretation(unittest.TestCase):
|
||||
def test_interpretation(self, mock_post):
|
||||
io = gr.Interface(lambda x: len(x), "text", "label", interpretation="default")
|
||||
io = gr.Interface(lambda x: len(x), "text", "label", interpretation="default", analytics_enabled=True)
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
client = app.test_client()
|
||||
io.interpret = mock.MagicMock(return_value=(None, None))
|
||||
@ -187,13 +187,13 @@ class TestInterpretation(unittest.TestCase):
|
||||
|
||||
class TestState(unittest.TestCase):
|
||||
def test_state_initialization(self):
|
||||
io = gr.Interface(lambda x: len(x), "text", "label", analytics_enabled=False)
|
||||
io = gr.Interface(lambda x: len(x), "text", "label")
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
with app.test_request_context():
|
||||
self.assertIsNone(networking.get_state())
|
||||
|
||||
def test_state_value(self):
|
||||
io = gr.Interface(lambda x: len(x), "text", "label", analytics_enabled=False)
|
||||
io = gr.Interface(lambda x: len(x), "text", "label")
|
||||
io.launch(prevent_thread_lock=True)
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
with app.test_request_context():
|
||||
@ -220,7 +220,7 @@ class TestURLs(unittest.TestCase):
|
||||
|
||||
class TestQueuing(unittest.TestCase):
|
||||
def test_queueing(self):
|
||||
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
|
||||
io = gr.Interface(lambda x: x, "text", "text")
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
client = app.test_client()
|
||||
# mock queue methods and post method
|
||||
|
@ -7,7 +7,7 @@ import tempfile
|
||||
import os
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class OutputComponent(unittest.TestCase):
|
||||
|
@ -8,7 +8,7 @@ import os
|
||||
import tempfile
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class ImagePreprocessing(unittest.TestCase):
|
||||
|
@ -9,7 +9,7 @@ import paramiko
|
||||
import os
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestTunneling(unittest.TestCase):
|
||||
|
@ -7,7 +7,7 @@ import requests
|
||||
import os
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "" # Disables analytics
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestUtils(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user