gradio/test/test_networking.py
Hoàng Tùng Lâm (Linus) 66ef85b874
Fixed misleading log when server_name is '0.0.0.0' (#2176)
* Fixed misleading log when server_name is '0.0.0.0'

* some cleanups

* formatting

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
2022-09-04 18:48:08 -07:00

138 lines
4.6 KiB
Python

"""Contains tests for networking.py and app.py"""
import os
import unittest
import unittest.mock as mock
import urllib
import urllib.parse
import warnings

from fastapi.testclient import TestClient

import gradio as gr
from gradio import Interface, networking
# Disable analytics for every Interface created by the tests below.
# NOTE(review): presumably gradio reads this variable when an Interface is
# constructed rather than at import time — confirm, since it is set after
# `import gradio`.
os.environ.update(GRADIO_ANALYTICS_ENABLED="False")

class TestPort(unittest.TestCase):
    """Tests for ``networking.get_first_available_port``."""

    # Bounds of the port range passed to get_first_available_port by both tests.
    START = 7860
    END = 7960

    def _first_available_port(self):
        """Return the first available port in the test range.

        If ``get_first_available_port`` raises ``OSError`` (no port available),
        the calling test is reported as *skipped* instead of silently passing.
        """
        try:
            return networking.get_first_available_port(self.START, self.END)
        except OSError:
            self.skipTest("Unable to test, no ports available")

    def test_port_is_in_range(self):
        port = self._first_available_port()
        # NOTE(review): END is accepted as a valid result here; confirm whether
        # get_first_available_port treats its upper bound as exclusive.
        self.assertLessEqual(self.START, port)
        self.assertLessEqual(port, self.END)

    def test_same_port_is_returned(self):
        # Nothing binds the port between the two calls, so both calls are
        # expected to report the same port.
        port1 = self._first_available_port()
        port2 = self._first_available_port()
        self.assertEqual(port1, port2)


class TestInterfaceErrors(unittest.TestCase):
    """Tests for how the ``/api/predict/`` route reports errors."""

    def _launch(self, fn):
        """Launch a ``number -> number`` Interface around ``fn`` and return a TestClient.

        ``close`` is registered with ``addCleanup`` so the launched Interface is
        shut down even when an assertion in the test fails.
        """
        interface = Interface(fn, "number", "number")
        app, _, _ = interface.launch(show_error=True, prevent_thread_lock=True)
        self.addCleanup(interface.close)
        return TestClient(app)

    def test_processing_error(self):
        # Input 0 makes the prediction function evaluate 1 / 0 and raise.
        client = self._launch(lambda x: 1 / x)
        response = client.post("/api/predict/", json={"data": [0], "fn_index": 0})
        self.assertEqual(response.status_code, 500)
        self.assertIn("error", response.json())

    def test_validation_error(self):
        # Malformed payload: no "data" key and a list-valued "fn_index".
        client = self._launch(lambda x: 1 / x)
        response = client.post("/api/predict/", json={"fn_index": [0]})
        self.assertEqual(response.status_code, 422)


class TestStartServer(unittest.TestCase):
    """Tests for ``networking.start_server``."""

    def test_start_server(self):
        # start_server is called directly rather than through launch(), so the
        # Interface attributes this test relies on are populated by hand here.
        interface = Interface(lambda x: x, "number", "number")
        interface.favicon_path = None
        interface.config = interface.get_config_file()
        interface.show_error = True
        interface.flagging_callback.setup(gr.Number(), interface.flagging_dir)
        interface.auth = None
        interface.enable_queue = False

        port = networking.get_first_available_port(
            networking.INITIAL_PORT_VALUE,
            networking.INITIAL_PORT_VALUE + networking.TRY_NUM_PORTS,
        )
        _, _, local_path, _, server = networking.start_server(
            interface, server_port=port
        )
        # Close the server even if one of the assertions below fails.
        self.addCleanup(server.close)

        url = urllib.parse.urlparse(local_path)
        # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
        self.assertEqual(url.scheme, "http")
        self.assertEqual(url.port, port)


# class TestFlagging(unittest.TestCase):
# def test_flagging_analytics(self):
# callback = flagging.CSVLogger()
# callback.flag = mock.MagicMock()
# aiohttp.ClientSession.post = mock.MagicMock()
# aiohttp.ClientSession.post.__aenter__ = None
# aiohttp.ClientSession.post.__aexit__ = None
# io = Interface(
# lambda x: x,
# "text",
# "text",
# analytics_enabled=True,
# flagging_callback=callback,
# )
# app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
# client = TestClient(app)
# response = client.post(
# "/api/flag/",
# json={"data": {"input_data": ["test"], "output_data": ["test"]}},
# )
# aiohttp.ClientSession.post.assert_called()
# callback.flag.assert_called_once()
# self.assertEqual(response.status_code, 200)
# io.close()
# class TestInterpretation(unittest.TestCase):
# def test_interpretation(self):
# io = Interface(
# lambda x: len(x),
# "text",
# "label",
# interpretation="default",
# analytics_enabled=True,
# )
# app, _, _ = io.launch(prevent_thread_lock=True)
# client = TestClient(app)
# aiohttp.ClientSession.post = mock.MagicMock()
# aiohttp.ClientSession.post.__aenter__ = None
# aiohttp.ClientSession.post.__aexit__ = None
# io.interpret = mock.MagicMock(return_value=(None, None))
# response = client.post("/api/interpret/", json={"data": ["test test"]})
# aiohttp.ClientSession.post.assert_called()
# self.assertEqual(response.status_code, 200)
# io.close()
class TestURLs(unittest.TestCase):
    """Tests for ``networking.setup_tunnel`` and ``networking.url_ok``."""

    def test_setup_tunnel(self):
        # Patch create_tunnel only for the duration of this test. Assigning
        # networking.create_tunnel directly would leave the mock in place for
        # every test that runs afterwards.
        with mock.patch.object(
            networking, "create_tunnel", return_value="test"
        ) as create_tunnel:
            res = networking.setup_tunnel(None, None)
        create_tunnel.assert_called()
        self.assertEqual(res, "test")

    def test_url_ok(self):
        # NOTE(review): presumably performs a real request to gradio.app, so this
        # test needs outbound network access and may be flaky offline.
        res = networking.url_ok("https://www.gradio.app")
        self.assertTrue(res)


# Allow running this file directly, e.g. ``python test_networking.py``.
if "__main__" == __name__:
    unittest.main()