2022-01-26 04:32:32 +08:00
|
|
|
"""Contains tests for networking.py and app.py"""
|
2022-09-24 04:01:44 +08:00
|
|
|
import json
|
2022-01-26 04:32:32 +08:00
|
|
|
import os
|
2022-09-24 04:01:44 +08:00
|
|
|
import sys
|
2023-02-06 10:04:26 +08:00
|
|
|
import tempfile
|
2023-05-04 06:30:38 +08:00
|
|
|
from pathlib import Path
|
2022-09-24 04:01:44 +08:00
|
|
|
from unittest.mock import patch
|
2022-01-26 04:32:32 +08:00
|
|
|
|
2022-12-14 07:01:27 +08:00
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
2022-09-24 04:01:44 +08:00
|
|
|
import pytest
|
2022-10-19 04:12:51 +08:00
|
|
|
import starlette.routing
|
2022-09-24 04:01:44 +08:00
|
|
|
import websockets
|
2022-08-11 06:29:14 +08:00
|
|
|
from fastapi import FastAPI
|
2022-01-26 04:32:32 +08:00
|
|
|
from fastapi.testclient import TestClient
|
2023-04-14 07:20:33 +08:00
|
|
|
from gradio_client import media_data
|
2022-01-26 04:32:32 +08:00
|
|
|
|
2022-10-25 07:32:37 +08:00
|
|
|
import gradio as gr
|
2023-01-05 01:18:04 +08:00
|
|
|
from gradio import (
|
|
|
|
Blocks,
|
|
|
|
Button,
|
|
|
|
Interface,
|
|
|
|
Number,
|
|
|
|
Textbox,
|
|
|
|
close_all,
|
|
|
|
routes,
|
|
|
|
)
|
2022-01-26 04:32:32 +08:00
|
|
|
|
|
|
|
# Set GRADIO_ANALYTICS_ENABLED to "False" for this test process, at import
# time and therefore before any test below launches an app.
# NOTE(review): presumably this keeps gradio from sending analytics while the
# tests run -- confirm against gradio's handling of this variable.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
|
|
|
|
|
|
|
|
2022-11-08 08:37:55 +08:00
|
|
|
@pytest.fixture()
def test_client():
    """Yield a ``TestClient`` for a launched text-doubling ``Interface``.

    The interface maps its text input ``x`` to ``x + x``. Once the test using
    the fixture has finished, the interface is closed and ``close_all()`` is
    called.
    """
    doubler = Interface(lambda x: x + x, "text", "text")
    app, _, _ = doubler.launch(prevent_thread_lock=True)
    yield TestClient(app)
    # Teardown: everything below runs after the requesting test completes.
    doubler.close()
    close_all()
|
2022-01-26 04:32:32 +08:00
|
|
|
|
|
|
|
|
2022-11-08 08:37:55 +08:00
|
|
|
class TestRoutes:
    """Tests for the HTTP routes served by a launched gradio app.

    Tests that launch their own app close it in a ``finally`` block so that a
    failing assertion does not leave a server running for the tests that
    follow.
    """

    @staticmethod
    def _get_file(path, **launch_kwargs):
        """GET ``/file=<path>`` from a freshly launched File-to-File interface.

        Args:
            path: File path to request through the ``/file=`` route.
            **launch_kwargs: Extra keyword arguments forwarded to ``launch()``
                (the tests pass ``allowed_paths`` and/or ``blocked_paths``).

        Returns:
            The ``TestClient`` response for the request. The interface is
            closed before returning, whether or not the request raised.
        """
        io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
        app, _, _ = io.launch(prevent_thread_lock=True, **launch_kwargs)
        try:
            return TestClient(app).get(f"/file={path}")
        finally:
            io.close()

    @staticmethod
    def _assert_named_routes(demo, expected_by_route):
        """Launch *demo* and check the output of each named API route.

        Args:
            demo: A ``Blocks`` whose events were registered with ``api_name``.
            expected_by_route: Mapping of route (e.g. ``"/api/p/"``) to the
                single string expected back when ``["test"]`` is posted.
                Routes are exercised in the mapping's insertion order.
        """
        app, _, _ = demo.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)
            for route, expected in expected_by_route.items():
                response = client.post(route, json={"data": ["test"]})
                assert response.status_code == 200
                output = dict(response.json())
                assert output["data"] == [expected]
        finally:
            demo.close()

    def test_get_main_route(self, test_client):
        """The root page is served with a 200."""
        response = test_client.get("/")
        assert response.status_code == 200

    def test_static_files_served_safely(self, test_client):
        """Encoded ``..`` segments under ``/static`` are rejected with a 403."""
        # Make sure things outside the static folder are not accessible
        response = test_client.get(r"/static/..%2findex.html")
        assert response.status_code == 403
        response = test_client.get(r"/static/..%2f..%2fapi_docs.html")
        assert response.status_code == 403

    def test_get_config_route(self, test_client):
        """The ``/config/`` route is served with a 200."""
        response = test_client.get("/config/")
        assert response.status_code == 200

    def test_upload_path(self, test_client):
        """Uploading a file returns the path of a saved copy with equal content."""
        with open("test/test_files/alphabet.txt") as f:
            response = test_client.post("/upload", files={"files": f})
        assert response.status_code == 200
        saved_path = response.json()[0]
        assert "alphabet" in saved_path
        assert saved_path.endswith(".txt")
        with open(saved_path) as saved_file:
            assert saved_file.read() == "abcdefghijklmnopqrstuvwxyz"

    def test_custom_upload_path(self):
        """With ``GRADIO_TEMP_DIR`` set, uploads are saved under that directory."""
        custom_dir = str(Path(tempfile.gettempdir()) / "gradio-test")
        # patch.dict puts the environment back exactly as it was on exit, so
        # the variable does not linger (even as "") for later tests.
        with patch.dict(os.environ, {"GRADIO_TEMP_DIR": custom_dir}):
            io = Interface(lambda x: x + x, "text", "text")
            app, _, _ = io.launch(prevent_thread_lock=True)
            try:
                test_client = TestClient(app)
                with open("test/test_files/alphabet.txt") as f:
                    response = test_client.post("/upload", files={"files": f})
                assert response.status_code == 200
                saved_path = response.json()[0]
                assert "alphabet" in saved_path
                assert saved_path.startswith(custom_dir)
                assert saved_path.endswith(".txt")
                with open(saved_path) as saved_file:
                    assert saved_file.read() == "abcdefghijklmnopqrstuvwxyz"
            finally:
                io.close()

    def test_predict_route(self, test_client):
        """Posting to ``/api/predict/`` with ``fn_index`` runs the function."""
        response = test_client.post(
            "/api/predict/", json={"data": ["test"], "fn_index": 0}
        )
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["testtest"]

    def test_named_predict_route(self):
        """Events with distinct ``api_name`` values get their own routes."""
        with Blocks() as demo:
            i = Textbox()
            o = Textbox()
            i.change(lambda x: f"{x}1", i, o, api_name="p")
            i.change(lambda x: f"{x}2", i, o, api_name="q")

        self._assert_named_routes(demo, {"/api/p/": "test1", "/api/q/": "test2"})

    def test_same_named_predict_route(self):
        """A repeated ``api_name`` is reachable under a ``_1`` suffix."""
        with Blocks() as demo:
            i = Textbox()
            o = Textbox()
            i.change(lambda x: f"{x}0", i, o, api_name="p")
            i.change(lambda x: f"{x}1", i, o, api_name="p")

        self._assert_named_routes(demo, {"/api/p/": "test0", "/api/p_1/": "test1"})

    def test_multiple_renamed(self):
        """An explicit name colliding with a generated one is suffixed again."""
        with Blocks() as demo:
            i = Textbox()
            o = Textbox()
            i.change(lambda x: f"{x}0", i, o, api_name="p")
            i.change(lambda x: f"{x}1", i, o, api_name="p")
            i.change(lambda x: f"{x}2", i, o, api_name="p_1")

        self._assert_named_routes(
            demo,
            {"/api/p/": "test0", "/api/p_1/": "test1", "/api/p_1_1/": "test2"},
        )

    def test_predict_route_without_fn_index(self, test_client):
        """``/api/predict/`` works when the payload omits ``fn_index``."""
        response = test_client.post("/api/predict/", json={"data": ["test"]})
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["testtest"]

    def test_predict_route_batching(self):
        """A ``batch=True`` event accepts single and ``batched`` payloads."""

        def batch_fn(x):
            results = []
            for word in x:
                results.append(f"Hello {word}")
            return (results,)

        with gr.Blocks() as demo:
            text = gr.Textbox()
            btn = gr.Button()
            btn.click(batch_fn, inputs=text, outputs=text, batch=True, api_name="pred")

        demo.queue()
        app, _, _ = demo.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)
            response = client.post("/api/pred/", json={"data": ["test"]})
            output = dict(response.json())
            assert output["data"] == ["Hello test"]

            # The original test launches a second time before the batched
            # request; that sequence is kept as-is.
            app, _, _ = demo.launch(prevent_thread_lock=True)
            client = TestClient(app)
            response = client.post(
                "/api/pred/", json={"data": [["test", "test2"]], "batched": True}
            )
            output = dict(response.json())
            assert output["data"] == [["Hello test", "Hello test2"]]
        finally:
            demo.close()

    def test_state(self):
        """State persists across two requests sharing a ``session_hash``."""

        def predict(input, history):
            if history is None:
                history = ""
            history += input
            return history, history

        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
        app, _, _ = io.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)
            # The same payload is sent twice; the second reply shows that the
            # history from the first call was accumulated.
            for expected in ("test", "testtest"):
                response = client.post(
                    "/api/predict/",
                    json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
                )
                output = dict(response.json())
                assert output["data"] == [expected, None]
        finally:
            io.close()

    def test_get_allowed_paths(self):
        """A temp file is served only once it is covered by ``allowed_paths``."""
        # Close the file after writing so its content is on disk, and remove
        # it at the end so the test leaves nothing behind.
        with tempfile.NamedTemporaryFile(mode="w", delete=False) as allowed_file:
            allowed_file.write(media_data.BASE64_IMAGE)
        try:
            # No allowed_paths: the request is expected to be refused.
            file_response = self._get_file(allowed_file.name)
            assert file_response.status_code == 403

            # Allowing either the parent directory or the file itself works.
            for allowed in (
                os.path.dirname(allowed_file.name),
                os.path.abspath(allowed_file.name),
            ):
                file_response = self._get_file(
                    allowed_file.name, allowed_paths=[allowed]
                )
                assert file_response.status_code == 200
                assert len(file_response.text) == len(media_data.BASE64_IMAGE)
        finally:
            os.remove(allowed_file.name)

    def test_get_blocked_paths(self):
        """``blocked_paths`` yields a 403, even for otherwise served files."""
        # Baseline: a file created in the current directory is served
        with tempfile.NamedTemporaryFile(
            dir=".", suffix=".jpg", delete=False
        ) as tmp_file:
            file_response = self._get_file(tmp_file.name)
            assert file_response.status_code == 200
        os.remove(tmp_file.name)

        # Test that blocking a default Gradio file path works
        with tempfile.NamedTemporaryFile(
            dir=".", suffix=".jpg", delete=False
        ) as tmp_file:
            file_response = self._get_file(
                tmp_file.name, blocked_paths=[os.path.abspath(tmp_file.name)]
            )
            assert file_response.status_code == 403
        os.remove(tmp_file.name)

        # Test that blocking a default Gradio directory works
        # NOTE(review): this stanza blocks the file's own path, exactly like
        # the previous one, not its directory -- confirm whether
        # os.path.dirname(...) was intended here.
        with tempfile.NamedTemporaryFile(
            dir=".", suffix=".jpg", delete=False
        ) as tmp_file:
            file_response = self._get_file(
                tmp_file.name, blocked_paths=[os.path.abspath(tmp_file.name)]
            )
            assert file_response.status_code == 403
        os.remove(tmp_file.name)

        # Test that blocking a directory works even if it's also allowed
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            file_response = self._get_file(
                tmp_file.name,
                allowed_paths=[os.path.dirname(tmp_file.name)],
                blocked_paths=[os.path.dirname(tmp_file.name)],
            )
            assert file_response.status_code == 403
        os.remove(tmp_file.name)

    def test_get_file_created_by_app(self):
        """A file produced by a prediction can be fetched, including by range."""
        io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
        app, _, _ = io.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)
            response = client.post(
                "/api/predict/",
                json={
                    "data": [
                        {
                            "data": media_data.BASE64_IMAGE,
                            "name": "bus.png",
                            "size": len(media_data.BASE64_IMAGE),
                        }
                    ],
                    "fn_index": 0,
                    "session_hash": "_",
                },
            ).json()
            created_file = response["data"][0]["name"]
            file_response = client.get(f"/file={created_file}")
            assert file_response.is_success

            # The older "/file/<path>" form must keep working too.
            backwards_compatible_file_response = client.get(f"/file/{created_file}")
            assert backwards_compatible_file_response.is_success

            file_response_with_full_range = client.get(
                f"/file={created_file}", headers={"Range": "bytes=0-"}
            )
            assert file_response_with_full_range.is_success
            assert file_response.text == file_response_with_full_range.text

            # "bytes=0-10" is inclusive at both ends, hence 11 characters.
            file_response_with_partial_range = client.get(
                f"/file={created_file}", headers={"Range": "bytes=0-10"}
            )
            assert file_response_with_partial_range.is_success
            assert len(file_response_with_partial_range.text) == 11
        finally:
            io.close()

    def test_mount_gradio_app(self):
        """Two queued interfaces mounted on one FastAPI app are both served."""
        app = FastAPI()

        demo = gr.Interface(
            lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
        ).queue()
        demo1 = gr.Interface(
            lambda s: f"Hello from py, {s}!", "textbox", "textbox"
        ).queue()

        app = gr.mount_gradio_app(app, demo, path="/ps")
        app = gr.mount_gradio_app(app, demo1, path="/py")

        # Use context manager to trigger start up events
        with TestClient(app) as client:
            assert client.get("/ps").is_success
            assert client.get("/py").is_success

    def test_static_file_missing(self, test_client):
        """A missing file under ``/static`` gives a 404."""
        response = test_client.get(r"/static/not-here.js")
        assert response.status_code == 404

    def test_asset_file_missing(self, test_client):
        """A missing file under ``/assets`` gives a 404."""
        response = test_client.get(r"/assets/not-here.js")
        assert response.status_code == 404

    def test_dynamic_file_missing(self, test_client):
        """A missing path requested through ``/file=`` gives a 404."""
        response = test_client.get(r"/file=not-here.js")
        assert response.status_code == 404

    def test_dynamic_file_directory(self, test_client):
        """Requesting a directory through ``/file=`` gives a 403."""
        response = test_client.get(r"/file=gradio")
        assert response.status_code == 403

    def test_mount_gradio_app_raises_error_if_event_queued_but_queue_disabled(self):
        """Launching with a ``queue=True`` event raises ``ValueError``.

        ``demo.queue()`` is never called here, so the queue itself is not
        enabled when ``launch()`` runs.
        """
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    input_ = gr.Textbox()
                    btn = gr.Button("Greet")
                with gr.Column():
                    output = gr.Textbox()
            btn.click(
                lambda x: f"Hello, {x}",
                inputs=input_,
                outputs=output,
                queue=True,
                api_name="greet",
            )

        try:
            with pytest.raises(
                ValueError, match="The queue is enabled for event greet"
            ):
                demo.launch(prevent_thread_lock=True)
        finally:
            demo.close()
|
|
|
|
|
2022-01-26 04:32:32 +08:00
|
|
|
|
2022-08-11 06:29:14 +08:00
|
|
|
class TestApp:
    """Checks on building the routes ``App`` directly, without launching."""

    def test_create_app(self):
        """``routes.App.create_app`` returns an object that is a ``FastAPI``."""
        echo_interface = Interface(lambda x: x, "text", "text")
        created = routes.App.create_app(echo_interface)
        assert isinstance(created, FastAPI)
|
|
|
|
|
|
|
|
|
2022-11-08 08:37:55 +08:00
|
|
|
class TestAuthenticatedRoutes:
    """Tests for the ``/login`` route of an app launched with ``auth``."""

    def test_post_login(self):
        """Correct credentials give a 200; a wrong password gives a 400."""
        io = Interface(lambda x: x, "text", "text")
        app, _, _ = io.launch(
            auth=("test", "correct_password"),
            prevent_thread_lock=True,
            enable_queue=False,
        )
        # Close the launched server even when an assertion below fails, so it
        # does not stay up for the rest of the test session.
        try:
            client = TestClient(app)

            response = client.post(
                "/login",
                data={"username": "test", "password": "correct_password"},
            )
            assert response.status_code == 200

            response = client.post(
                "/login",
                data={"username": "test", "password": "incorrect_password"},
            )
            assert response.status_code == 400
        finally:
            io.close()
|
2022-01-26 04:32:32 +08:00
|
|
|
|
|
|
|
|
2022-10-25 07:32:37 +08:00
|
|
|
class TestQueueRoutes:
    """Tests for the websocket ``queue/join`` route and its URL helper."""

    @pytest.mark.asyncio
    @pytest.mark.skipif(
        sys.version_info < (3, 8),
        reason="Mocks don't work with async context managers in 3.7",
    )
    @patch("gradio.routes.get_server_url_from_ws_url", return_value="foo_url")
    async def test_queue_join_routes_sets_url_if_none_set(self, mock_get_url):
        """Joining the queue fills in ``server_path`` when it was ``None``.

        ``get_server_url_from_ws_url`` is patched to return ``"foo_url"``, so
        seeing that value on the queue afterwards shows the join route used
        the helper's result.
        """
        io = Interface(lambda x: x, "text", "text").queue()
        io.launch(prevent_thread_lock=True)
        # Close the launched server even when the exchange below fails.
        try:
            io._queue.server_path = None
            async with websockets.connect(
                f"{io.local_url.replace('http', 'ws')}queue/join"
            ) as ws:
                completed = False
                # Answer the server's prompts until it reports completion.
                while not completed:
                    msg = json.loads(await ws.recv())
                    if msg["msg"] == "send_data":
                        await ws.send(json.dumps({"data": ["foo"], "fn_index": 0}))
                    if msg["msg"] == "send_hash":
                        await ws.send(
                            json.dumps({"fn_index": 0, "session_hash": "shdce"})
                        )
                    completed = msg["msg"] == "process_completed"
            assert io._queue.server_path == "foo_url"
        finally:
            io.close()

    @pytest.mark.parametrize(
        "ws_url,answer",
        [
            ("ws://127.0.0.1:7861/queue/join", "http://127.0.0.1:7861/"),
            (
                "ws://127.0.0.1:7861/gradio/gradio/gradio/queue/join",
                "http://127.0.0.1:7861/gradio/gradio/gradio/",
            ),
            (
                "wss://gradio-titanic-survival.hf.space/queue/join",
                "https://gradio-titanic-survival.hf.space/",
            ),
        ],
    )
    def test_get_server_url_from_ws_url(self, ws_url, answer):
        """The helper maps each ``ws(s)://.../queue/join`` URL to ``answer``."""
        assert routes.get_server_url_from_ws_url(ws_url) == answer
|
|
|
|
|
|
|
|
|
|
|
|
class TestDevMode:
    """Checks the ``dev_mode`` flag of a Blocks mounted on a FastAPI app."""

    def test_mount_gradio_app_set_dev_mode_false(self):
        """After ``mount_gradio_app`` the mounted blocks have a falsy ``dev_mode``."""
        host_app = FastAPI()

        @host_app.get("/")
        def read_main():
            return {"message": "Hello!"}

        with gr.Blocks() as mounted_blocks:
            gr.Textbox("Hello from gradio!")

        host_app = routes.mount_gradio_app(host_app, mounted_blocks, path="/gradio")

        # The first Mount route on the host app is taken to be the gradio app.
        first_mount = next(
            filter(
                lambda candidate: isinstance(candidate, starlette.routing.Mount),
                host_app.routes,
            )
        )
        assert not first_mount.app.blocks.dev_mode
|
2022-10-19 04:12:51 +08:00
|
|
|
|
|
|
|
|
2022-11-19 16:52:06 +08:00
|
|
|
class TestPassingRequest:
    """Tests that a ``gr.Request`` parameter is populated for the function.

    Each test's function asserts on the request it receives and echoes its
    input; the test then checks the predict call returned a 200 with the echo.
    Every launched app is closed in a ``finally`` block so a failing assertion
    does not leave a server running.
    """

    def test_request_included_with_regular_function(self):
        """``request.client.host`` is a string inside the function."""

        def identity(name, request: gr.Request):
            assert isinstance(request.client.host, str)
            return name

        io = gr.Interface(identity, "textbox", "textbox")
        app, _, _ = io.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)

            response = client.post("/api/predict/", json={"data": ["test"]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["test"]
        finally:
            io.close()

    def test_request_get_headers(self):
        """``request.headers`` supports item access and list-returning views."""

        def identity(name, request: gr.Request):
            assert isinstance(request.headers["user-agent"], str)
            assert isinstance(request.headers.items(), list)
            assert isinstance(request.headers.keys(), list)
            assert isinstance(request.headers.values(), list)
            assert isinstance(dict(request.headers), dict)
            user_agent = request.headers["user-agent"]
            assert "testclient" in user_agent
            return name

        io = gr.Interface(identity, "textbox", "textbox")
        app, _, _ = io.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)

            response = client.post("/api/predict/", json={"data": ["test"]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["test"]
        finally:
            io.close()

    def test_request_includes_username_as_none_if_no_auth(self):
        """Without ``auth``, ``request.username`` is ``None``."""

        def identity(name, request: gr.Request):
            assert request.username is None
            return name

        io = gr.Interface(identity, "textbox", "textbox")
        app, _, _ = io.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)

            response = client.post("/api/predict/", json={"data": ["test"]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["test"]
        finally:
            io.close()

    def test_request_includes_username_with_auth(self):
        """After logging in, ``request.username`` is the logged-in user."""

        def identity(name, request: gr.Request):
            assert request.username == "admin"
            return name

        io = gr.Interface(identity, "textbox", "textbox")
        app, _, _ = io.launch(prevent_thread_lock=True, auth=("admin", "password"))
        try:
            client = TestClient(app)

            # Log in first; the same client is then used for the predict call.
            client.post(
                "/login",
                data={"username": "admin", "password": "password"},
            )
            response = client.post("/api/predict/", json={"data": ["test"]})
            assert response.status_code == 200
            output = dict(response.json())
            assert output["data"] == ["test"]
        finally:
            io.close()
|
|
|
|
|
|
|
|
|
2022-10-25 23:45:15 +08:00
|
|
|
def test_predict_route_is_blocked_if_api_open_false():
    """With ``queue(api_open=False)``, a direct POST to predict gets a 401.

    Also checks that ``show_api`` is falsy on the interface after launch.
    """
    io = Interface(lambda x: x, "text", "text", examples=[["freddy"]]).queue(
        api_open=False
    )
    app, _, _ = io.launch(prevent_thread_lock=True)
    # Close the launched server even when an assertion below fails.
    try:
        assert not io.show_api
        client = TestClient(app)
        result = client.post(
            "/api/predict", json={"fn_index": 0, "data": [5], "session_hash": "foo"}
        )
        assert result.status_code == 401
    finally:
        io.close()
|
|
|
|
|
|
|
|
|
|
|
|
def test_predict_route_not_blocked_if_queue_disabled():
    """With ``api_open=False``, only the ``queue=True`` event's route is blocked.

    The ``queue=False`` event stays callable over HTTP (200), while the
    ``queue=True`` event answers 401. ``show_api`` ends up falsy even though
    ``show_api=True`` was passed to ``launch()``.
    """
    with Blocks() as demo:
        # Named input_ rather than input to avoid shadowing the builtin.
        input_ = Textbox()
        output = Textbox()
        number = Number()
        button = Button()
        button.click(
            lambda x: f"Hello, {x}!", input_, output, queue=False, api_name="not_blocked"
        )
        button.click(lambda: 42, None, number, queue=True, api_name="blocked")
    app, _, _ = demo.queue(api_open=False).launch(
        prevent_thread_lock=True, show_api=True
    )
    # Close the launched server even when an assertion below fails.
    try:
        assert not demo.show_api
        client = TestClient(app)

        result = client.post("/api/blocked", json={"data": [], "session_hash": "foo"})
        assert result.status_code == 401
        result = client.post(
            "/api/not_blocked", json={"data": ["freddy"], "session_hash": "foo"}
        )
        assert result.status_code == 200
        assert result.json()["data"] == ["Hello, freddy!"]
    finally:
        demo.close()
|
|
|
|
|
|
|
|
|
|
|
|
def test_predict_route_not_blocked_if_routes_open():
    """With ``api_open=True``, a queued event's route is callable over HTTP.

    In that case ``show_api`` is truthy even though ``show_api=False`` was
    passed to ``launch()``. Relaunching with ``api_open=False`` makes
    ``show_api`` falsy.
    """
    with Blocks() as demo:
        # Named input_ rather than input to avoid shadowing the builtin.
        input_ = Textbox()
        output = Textbox()
        button = Button()
        button.click(
            lambda x: f"Hello, {x}!", input_, output, queue=True, api_name="not_blocked"
        )
    app, _, _ = demo.queue(api_open=True).launch(
        prevent_thread_lock=True, show_api=False
    )
    try:
        assert demo.show_api
        client = TestClient(app)

        result = client.post(
            "/api/not_blocked", json={"data": ["freddy"], "session_hash": "foo"}
        )
        assert result.status_code == 200
        assert result.json()["data"] == ["Hello, freddy!"]
    finally:
        demo.close()

    demo.queue(api_open=False).launch(prevent_thread_lock=True, show_api=False)
    # The relaunched server was previously left running; close it as well.
    try:
        assert not demo.show_api
    finally:
        demo.close()
|
|
|
|
|
|
|
|
|
|
|
|
def test_show_api_queue_not_enabled():
    """Without a queue, ``show_api`` follows the ``launch()`` argument.

    It is truthy after a default launch and falsy after launching with
    ``show_api=False``.
    """
    io = Interface(lambda x: x, "text", "text", examples=[["freddy"]])
    io.launch(prevent_thread_lock=True)
    try:
        assert io.show_api
    finally:
        io.close()

    io.launch(prevent_thread_lock=True, show_api=False)
    # The relaunched server was previously left running; close it as well.
    try:
        assert not io.show_api
    finally:
        io.close()
|
2022-12-14 07:01:27 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_orjson_serialization():
    """The root page loads (200) when a DataFrame of mixed column types is shown.

    The frame mixes datetimes, formatted date strings, numpy floats, numpy
    int64 values, booleans and markdown-like strings.
    NOTE(review): the name suggests this guards orjson-based serialization of
    such values -- the test itself only asserts the status code of ``/``.
    """
    columns = {
        "date_1": pd.date_range("2021-01-01", periods=2),
        "date_2": pd.date_range("2022-02-15", periods=2).strftime("%B %d, %Y, %r"),
        "number": np.array([0.2233, 0.57281]),
        "number_2": np.array([84, 23]).astype(np.int64),
        "bool": [True, False],
        "markdown": ["# Hello", "# Goodbye"],
    }
    frame = pd.DataFrame(columns)

    with gr.Blocks() as demo:
        gr.DataFrame(frame)

    app, _, _ = demo.launch(prevent_thread_lock=True)
    root_response = TestClient(app).get("/")
    assert root_response.status_code == 200
    demo.close()
|