mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
fixed tests
This commit is contained in:
parent
aedbe9b990
commit
b44d41ece8
@ -277,7 +277,7 @@ def log_feature_analytics(feature):
|
||||
def flag():
|
||||
log_feature_analytics('flag')
|
||||
data = request.json['data']
|
||||
app.interface.flagging_handler.flag(app.interface, data['input_data'], data['output_data'], data.get("flag_option"), data.get("flag_index"),
|
||||
app.interface.flagging_callback.flag(app.interface, data['input_data'], data['output_data'], data.get("flag_option"), data.get("flag_index"),
|
||||
current_user.id if current_user.is_authenticated else None)
|
||||
return jsonify(success=True)
|
||||
|
||||
|
@ -1,14 +1,12 @@
|
||||
from gradio import networking
|
||||
import gradio as gr
|
||||
from gradio import networking, Interface, reset_all, flagging
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import ipaddress
|
||||
import requests
|
||||
import warnings
|
||||
from unittest.mock import ANY
|
||||
from unittest.mock import ANY, MagicMock
|
||||
import urllib.request
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
@ -60,7 +58,7 @@ class TestPort(unittest.TestCase):
|
||||
|
||||
class TestFlaskRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = gr.Interface(lambda x: x, "text", "text")
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
|
||||
self.client = self.app.test_client()
|
||||
|
||||
@ -102,12 +100,12 @@ class TestFlaskRoutes(unittest.TestCase):
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
gr.reset_all()
|
||||
reset_all()
|
||||
|
||||
|
||||
class TestAuthenticatedFlaskRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = gr.Interface(lambda x: x, "text", "text")
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(auth=("test", "correct_password"), prevent_thread_lock=True)
|
||||
self.client = self.app.test_client()
|
||||
|
||||
@ -123,12 +121,12 @@ class TestAuthenticatedFlaskRoutes(unittest.TestCase):
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
gr.reset_all()
|
||||
reset_all()
|
||||
|
||||
|
||||
class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
def test_show_error(self):
|
||||
io = gr.Interface(lambda x: 1/x, "number", "number")
|
||||
io = Interface(lambda x: 1/x, "number", "number")
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = app.test_client()
|
||||
response = client.post('/api/predict/', json={"data": [0]})
|
||||
@ -138,14 +136,13 @@ class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
|
||||
def test_feature_logging(self):
|
||||
with mock.patch('requests.post') as mock_post:
|
||||
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=True)
|
||||
io = Interface(lambda x: 1/x, "number", "number", analytics_enabled=True)
|
||||
io.launch(show_error=True, prevent_thread_lock=True)
|
||||
networking.log_feature_analytics("test_feature")
|
||||
mock_post.assert_called_with(networking.GRADIO_FEATURE_ANALYTICS_URL, data=ANY, timeout=ANY)
|
||||
io.close()
|
||||
|
||||
io = gr.Interface(lambda x: 1/x, "number", "number")
|
||||
print(io.analytics_enabled)
|
||||
io = Interface(lambda x: 1/x, "number", "number")
|
||||
io.launch(show_error=True, prevent_thread_lock=True)
|
||||
with mock.patch('requests.post') as mock_post:
|
||||
networking.log_feature_analytics("test_feature")
|
||||
@ -155,14 +152,15 @@ class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
|
||||
class TestFlagging(unittest.TestCase):
|
||||
@mock.patch("requests.post")
|
||||
@mock.patch("gradio.flagging.CSVLogger.flag")
|
||||
def test_flagging_analytics(self, mock_flag, mock_post):
|
||||
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=True)
|
||||
def test_flagging_analytics(self, mock_post):
|
||||
callback = flagging.CSVLogger()
|
||||
callback.flag = mock.MagicMock()
|
||||
io = Interface(lambda x: x, "text", "text", analytics_enabled=True, flagging_callback=callback)
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = app.test_client()
|
||||
response = client.post('/api/flag/', json={"data": {"input_data": ["test"], "output_data": ["test"]}})
|
||||
mock_post.assert_any_call(networking.GRADIO_FEATURE_ANALYTICS_URL, data=ANY, timeout=ANY)
|
||||
mock_flag.assert_called_once()
|
||||
callback.flag.assert_called_once()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
io.close()
|
||||
|
||||
@ -170,7 +168,7 @@ class TestFlagging(unittest.TestCase):
|
||||
@mock.patch("requests.post")
|
||||
class TestInterpretation(unittest.TestCase):
|
||||
def test_interpretation(self, mock_post):
|
||||
io = gr.Interface(lambda x: len(x), "text", "label", interpretation="default", analytics_enabled=True)
|
||||
io = Interface(lambda x: len(x), "text", "label", interpretation="default", analytics_enabled=True)
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
client = app.test_client()
|
||||
io.interpret = mock.MagicMock(return_value=(None, None))
|
||||
@ -181,13 +179,13 @@ class TestInterpretation(unittest.TestCase):
|
||||
|
||||
class TestState(unittest.TestCase):
|
||||
def test_state_initialization(self):
|
||||
io = gr.Interface(lambda x: len(x), "text", "label")
|
||||
io = Interface(lambda x: len(x), "text", "label")
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
with app.test_request_context():
|
||||
self.assertIsNone(networking.get_state())
|
||||
|
||||
def test_state_value(self):
|
||||
io = gr.Interface(lambda x: len(x), "text", "label")
|
||||
io = Interface(lambda x: len(x), "text", "label")
|
||||
io.launch(prevent_thread_lock=True)
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
with app.test_request_context():
|
||||
@ -214,7 +212,7 @@ class TestURLs(unittest.TestCase):
|
||||
|
||||
class TestQueuing(unittest.TestCase):
|
||||
def test_queueing(self):
|
||||
io = gr.Interface(lambda x: x, "text", "text")
|
||||
io = Interface(lambda x: x, "text", "text")
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
client = app.test_client()
|
||||
# mock queue methods and post method
|
||||
|
Loading…
Reference in New Issue
Block a user