fixed tests

This commit is contained in:
Abubakar Abid 2021-11-20 17:45:13 -06:00
parent aedbe9b990
commit b44d41ece8
2 changed files with 19 additions and 21 deletions

View File

@ -277,7 +277,7 @@ def log_feature_analytics(feature):
def flag():
log_feature_analytics('flag')
data = request.json['data']
app.interface.flagging_handler.flag(app.interface, data['input_data'], data['output_data'], data.get("flag_option"), data.get("flag_index"),
app.interface.flagging_callback.flag(app.interface, data['input_data'], data['output_data'], data.get("flag_option"), data.get("flag_index"),
current_user.id if current_user.is_authenticated else None)
return jsonify(success=True)

View File

@ -1,14 +1,12 @@
from gradio import networking
import gradio as gr
from gradio import networking, Interface, reset_all, flagging
import unittest
import unittest.mock as mock
import ipaddress
import requests
import warnings
from unittest.mock import ANY
from unittest.mock import ANY, MagicMock
import urllib.request
import os
import tempfile
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -60,7 +58,7 @@ class TestPort(unittest.TestCase):
class TestFlaskRoutes(unittest.TestCase):
def setUp(self) -> None:
self.io = gr.Interface(lambda x: x, "text", "text")
self.io = Interface(lambda x: x, "text", "text")
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
self.client = self.app.test_client()
@ -102,12 +100,12 @@ class TestFlaskRoutes(unittest.TestCase):
def tearDown(self) -> None:
self.io.close()
gr.reset_all()
reset_all()
class TestAuthenticatedFlaskRoutes(unittest.TestCase):
def setUp(self) -> None:
self.io = gr.Interface(lambda x: x, "text", "text")
self.io = Interface(lambda x: x, "text", "text")
self.app, _, _ = self.io.launch(auth=("test", "correct_password"), prevent_thread_lock=True)
self.client = self.app.test_client()
@ -123,12 +121,12 @@ class TestAuthenticatedFlaskRoutes(unittest.TestCase):
def tearDown(self) -> None:
self.io.close()
gr.reset_all()
reset_all()
class TestInterfaceCustomParameters(unittest.TestCase):
def test_show_error(self):
io = gr.Interface(lambda x: 1/x, "number", "number")
io = Interface(lambda x: 1/x, "number", "number")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = app.test_client()
response = client.post('/api/predict/', json={"data": [0]})
@ -138,14 +136,13 @@ class TestInterfaceCustomParameters(unittest.TestCase):
def test_feature_logging(self):
with mock.patch('requests.post') as mock_post:
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=True)
io = Interface(lambda x: 1/x, "number", "number", analytics_enabled=True)
io.launch(show_error=True, prevent_thread_lock=True)
networking.log_feature_analytics("test_feature")
mock_post.assert_called_with(networking.GRADIO_FEATURE_ANALYTICS_URL, data=ANY, timeout=ANY)
io.close()
io = gr.Interface(lambda x: 1/x, "number", "number")
print(io.analytics_enabled)
io = Interface(lambda x: 1/x, "number", "number")
io.launch(show_error=True, prevent_thread_lock=True)
with mock.patch('requests.post') as mock_post:
networking.log_feature_analytics("test_feature")
@ -155,14 +152,15 @@ class TestInterfaceCustomParameters(unittest.TestCase):
class TestFlagging(unittest.TestCase):
@mock.patch("requests.post")
@mock.patch("gradio.flagging.CSVLogger.flag")
def test_flagging_analytics(self, mock_flag, mock_post):
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=True)
def test_flagging_analytics(self, mock_post):
callback = flagging.CSVLogger()
callback.flag = mock.MagicMock()
io = Interface(lambda x: x, "text", "text", analytics_enabled=True, flagging_callback=callback)
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = app.test_client()
response = client.post('/api/flag/', json={"data": {"input_data": ["test"], "output_data": ["test"]}})
mock_post.assert_any_call(networking.GRADIO_FEATURE_ANALYTICS_URL, data=ANY, timeout=ANY)
mock_flag.assert_called_once()
callback.flag.assert_called_once()
self.assertEqual(response.status_code, 200)
io.close()
@ -170,7 +168,7 @@ class TestFlagging(unittest.TestCase):
@mock.patch("requests.post")
class TestInterpretation(unittest.TestCase):
def test_interpretation(self, mock_post):
io = gr.Interface(lambda x: len(x), "text", "label", interpretation="default", analytics_enabled=True)
io = Interface(lambda x: len(x), "text", "label", interpretation="default", analytics_enabled=True)
app, _, _ = io.launch(prevent_thread_lock=True)
client = app.test_client()
io.interpret = mock.MagicMock(return_value=(None, None))
@ -181,13 +179,13 @@ class TestInterpretation(unittest.TestCase):
class TestState(unittest.TestCase):
def test_state_initialization(self):
io = gr.Interface(lambda x: len(x), "text", "label")
io = Interface(lambda x: len(x), "text", "label")
app, _, _ = io.launch(prevent_thread_lock=True)
with app.test_request_context():
self.assertIsNone(networking.get_state())
def test_state_value(self):
io = gr.Interface(lambda x: len(x), "text", "label")
io = Interface(lambda x: len(x), "text", "label")
io.launch(prevent_thread_lock=True)
app, _, _ = io.launch(prevent_thread_lock=True)
with app.test_request_context():
@ -214,7 +212,7 @@ class TestURLs(unittest.TestCase):
class TestQueuing(unittest.TestCase):
def test_queueing(self):
io = gr.Interface(lambda x: x, "text", "text")
io = Interface(lambda x: x, "text", "text")
app, _, _ = io.launch(prevent_thread_lock=True)
client = app.test_client()
# mock queue methods and post method