mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-03 01:50:59 +08:00
9b15e9a1f8
* Add note to guide * Add more explanation * Undo changes * CHANGELOG * Make work with queue + fix docs * Changelog fix * Add unit test --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
1441 lines
49 KiB
Python
1441 lines
49 KiB
Python
import asyncio
|
|
import copy
|
|
import io
|
|
import json
|
|
import os
|
|
import pathlib
|
|
import random
|
|
import sys
|
|
import time
|
|
import unittest.mock as mock
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
from functools import partial
|
|
from string import capwords
|
|
from unittest.mock import patch
|
|
|
|
import mlflow
|
|
import pytest
|
|
import uvicorn
|
|
import wandb
|
|
import websockets
|
|
from fastapi.testclient import TestClient
|
|
|
|
import gradio as gr
|
|
from gradio.exceptions import DuplicateBlockError
|
|
from gradio.networking import Server, get_first_available_port
|
|
from gradio.test_data.blocks_configs import XRAY_CONFIG
|
|
from gradio.utils import assert_configs_are_equivalent_besides_ids
|
|
|
|
# Ask pytest to load the pytest-asyncio plugin; the coroutine tests in this
# module are marked with `@pytest.mark.asyncio`.
pytest_plugins = ("pytest_asyncio",)
|
|
|
|
# Set at import time for the whole test process. Presumably gradio reads this
# variable to decide whether to send analytics -- confirm against gradio.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
|
|
|
|
|
@contextmanager
def captured_output():
    """Temporarily replace ``sys.stdout`` and ``sys.stderr`` with in-memory buffers.

    Yields:
        A ``(stdout, stderr)`` pair of ``io.StringIO`` objects that receive
        everything written to the standard streams inside the ``with`` body.
        The previous streams are put back on exit, even if the body raises.
    """
    saved_streams = (sys.stdout, sys.stderr)
    stdout_buffer = io.StringIO()
    stderr_buffer = io.StringIO()
    sys.stdout = stdout_buffer
    sys.stderr = stderr_buffer
    try:
        yield sys.stdout, sys.stderr
    finally:
        sys.stdout, sys.stderr = saved_streams
|
|
|
|
|
|
class TestBlocksMethods:
    """Tests for ``gr.Blocks`` configuration, launching, integrations and queueing."""

    maxDiff = None

    @staticmethod
    async def _predict_via_queue(join_url, value):
        """Run event 0 through the queue websocket and return its output data.

        Args:
            join_url: ``ws://`` URL of the ``queue/join`` endpoint.
            value: Single input value sent as the event's data.

        Returns:
            The ``output["data"]`` list of the ``process_completed`` message.
        """
        async with websockets.connect(join_url) as ws:
            while True:
                msg = json.loads(await ws.recv())
                if msg["msg"] == "send_data":
                    await ws.send(json.dumps({"data": [value], "fn_index": 0}))
                if msg["msg"] == "send_hash":
                    await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
                if msg["msg"] == "process_completed":
                    return msg["output"]["data"]

    def test_set_share(self):
        """``share`` is falsy by default and for ``None``, truthy once set to ``True``."""
        with gr.Blocks() as demo:
            # self.share is False when instantiating the class
            assert not demo.share
            # default is False, if share is None
            demo.share = None
            assert not demo.share
            # if set to True, it doesn't change
            demo.share = True
            assert demo.share

    @patch("gradio.networking.setup_tunnel")
    @patch("gradio.utils.colab_check")
    def test_set_share_in_colab(self, mock_colab_check, mock_setup_tunnel):
        """With the colab check mocked to True, ``share`` only turns on with the queue."""
        mock_colab_check.return_value = True
        mock_setup_tunnel.return_value = "http://localhost:7860/"
        with gr.Blocks() as demo:
            # self.share is False when instantiating the class
            assert not demo.share
            # share default is False, if share is None in colab and no queueing
            demo.launch(prevent_thread_lock=True)
            assert not demo.share
            demo.close()
            # share becomes true, if share is None in colab with queueing
            demo.queue()
            demo.launch(prevent_thread_lock=True)
            assert demo.share
            demo.close()

    def test_default_enabled_deprecated(self):
        """``queue(default_enabled=...)`` warns; plain ``queue()`` does not mention it."""
        # Named `interface` rather than `io` so the imported `io` module is
        # not shadowed inside this test.
        interface = gr.Interface(lambda s: s, gr.Textbox(), gr.Textbox())
        with pytest.warns(
            UserWarning, match="The default_enabled parameter of queue has no effect"
        ):
            interface.queue(default_enabled=True)

        interface = gr.Interface(lambda s: s, gr.Textbox(), gr.Textbox())
        with warnings.catch_warnings(record=True) as record:
            warnings.simplefilter("always")
            interface.queue()
        for warning in record:
            assert "default_enabled" not in str(warning.message)

    def test_xray(self):
        """The demo's config matches ``XRAY_CONFIG`` and ``show_api`` follows ``launch``."""

        def fake_func():
            return "Hello There"

        def xray_model(diseases, img):
            return {disease: random.random() for disease in diseases}

        def ct_model(diseases, img):
            return {disease: 0.1 for disease in diseases}

        with gr.Blocks() as demo:
            gr.Markdown(
                """
            # Detect Disease From Scan
            With this model you can lorem ipsum
            - ipsum 1
            - ipsum 2
            """
            )
            disease = gr.CheckboxGroup(
                choices=["Covid", "Malaria", "Lung Cancer"], label="Disease to Scan For"
            )

            with gr.Tabs():
                with gr.TabItem("X-ray"):
                    with gr.Row():
                        xray_scan = gr.Image()
                        xray_results = gr.JSON()
                    xray_run = gr.Button("Run")
                    xray_run.click(
                        xray_model, inputs=[disease, xray_scan], outputs=xray_results
                    )

                with gr.TabItem("CT Scan"):
                    with gr.Row():
                        ct_scan = gr.Image()
                        ct_results = gr.JSON()
                    ct_run = gr.Button("Run")
                    ct_run.click(
                        ct_model, inputs=[disease, ct_scan], outputs=ct_results
                    )
            textbox = gr.Textbox()
            demo.load(fake_func, [], [textbox])

        config = demo.get_config_file()
        assert assert_configs_are_equivalent_besides_ids(XRAY_CONFIG, config)
        assert config["show_api"] is True
        try:
            demo.launch(prevent_thread_lock=True, show_api=False)
            assert demo.config["show_api"] is False
        finally:
            # Always shut the server down so it does not outlive the test.
            demo.close()

    def test_load_from_config(self):
        """A Blocks rebuilt with ``from_config`` produces an equivalent config."""

        def update(name):
            return f"Welcome to Gradio, {name}!"

        with gr.Blocks() as demo1:
            inp = gr.Textbox(placeholder="What is your name?")
            out = gr.Textbox()

            inp.submit(fn=update, inputs=inp, outputs=out, api_name="greet")

            gr.Image().style(height=54, width=240)

        config1 = demo1.get_config_file()
        demo2 = gr.Blocks.from_config(config1, [update])
        config2 = demo2.get_config_file()
        assert assert_configs_are_equivalent_besides_ids(config1, config2)

    def test_partial_fn_in_config(self):
        """A ``functools.partial`` can be registered as an event handler."""

        def greet(name, formatter):
            return formatter("Hello " + name + "!")

        greet_upper_case = partial(greet, formatter=capwords)
        with gr.Blocks() as demo:
            t = gr.Textbox()
            o = gr.Textbox()
            t.change(greet_upper_case, t, o)

        assert len(demo.fns) == 1
        assert "fn" in str(demo.fns[0])

    @pytest.mark.asyncio
    async def test_dict_inputs_in_config(self):
        """Passing inputs as a set hands the function a dict keyed by component."""
        with gr.Blocks() as demo:
            first = gr.Textbox()
            last = gr.Textbox()
            btn = gr.Button()
            greeting = gr.Textbox()

            def greet(data):
                return f"Hello {data[first]} {data[last]}"

            btn.click(greet, {first, last}, greeting)

        result = await demo.process_api(inputs=["huggy", "face"], fn_index=0, state={})
        assert result["data"] == ["Hello huggy face"]

    @pytest.mark.asyncio
    async def test_async_function(self):
        """An ``async def`` handler is awaited by ``process_api``."""

        async def wait(x):
            await asyncio.sleep(0.01)
            return x

        with gr.Blocks() as demo:
            text = gr.Textbox()
            button = gr.Button()
            button.click(wait, [text], [text])

        start = time.time()
        result = await demo.process_api(inputs=[1], fn_index=0, state={})
        end = time.time()
        difference = end - start
        assert difference >= 0.01
        assert result

    def test_integration_wandb(self):
        """``integrate(wandb=...)`` logs only once a share URL is available."""
        # patch.object restores the real wandb attributes on exit, so the
        # mocks cannot leak into other tests.
        with mock.patch.object(wandb, "log") as mock_log, mock.patch.object(
            wandb, "Html"
        ):
            with captured_output() as (out, _err):
                demo = gr.Blocks()
                with demo:
                    gr.Textbox("Hi there!")
                demo.integrate(wandb=wandb)

                assert (
                    out.getvalue().strip()
                    == "The WandB integration requires you to `launch(share=True)` first."
                )
                demo.share_url = "tmp"
                demo.integrate(wandb=wandb)
                mock_log.assert_called_once()

    @mock.patch("comet_ml.Experiment")
    def test_integration_comet(self, mock_experiment):
        """``integrate(comet_ml=...)`` logs the local URL, then the share URL."""
        experiment = mock_experiment()
        experiment.log_text = mock.MagicMock()
        experiment.log_other = mock.MagicMock()

        demo = gr.Blocks()
        with demo:
            gr.Textbox("Hi there!")

        demo.launch(prevent_thread_lock=True)
        try:
            demo.integrate(comet_ml=experiment)
            experiment.log_text.assert_called_with("gradio: " + demo.local_url)
            demo.share_url = "tmp"  # used to avoid creating real share links.
            demo.integrate(comet_ml=experiment)
            experiment.log_text.assert_called_with("gradio: " + demo.share_url)
            assert experiment.log_other.call_count == 2
        finally:
            demo.share_url = None
            demo.close()

    def test_integration_mlflow(self):
        """``integrate(mlflow=...)`` logs the local link, then the share link."""
        demo = gr.Blocks()
        with demo:
            gr.Textbox("Hi there!")

        # patch.object puts the real mlflow.log_param back when the block exits.
        with mock.patch.object(mlflow, "log_param") as mock_log_param:
            demo.launch(prevent_thread_lock=True)
            try:
                demo.integrate(mlflow=mlflow)
                mock_log_param.assert_called_with(
                    "Gradio Interface Local Link", demo.local_url
                )
                demo.share_url = "tmp"  # used to avoid creating real share links.
                demo.integrate(mlflow=mlflow)
                mock_log_param.assert_called_with(
                    "Gradio Interface Share Link", demo.share_url
                )
            finally:
                demo.share_url = None
                demo.close()

    @mock.patch("requests.post")
    def test_initiated_analytics(self, mock_post):
        """Creating a Blocks with analytics enabled issues exactly one POST."""
        with gr.Blocks(analytics_enabled=True):
            pass
        mock_post.assert_called_once()

    def test_show_error(self):
        """``show_error`` is reset by ``launch()`` unless passed explicitly."""
        with gr.Blocks() as demo:
            pass

        assert demo.show_error
        demo.launch(prevent_thread_lock=True)
        assert not demo.show_error
        demo.close()
        demo.launch(show_error=True, prevent_thread_lock=True)
        assert demo.show_error
        demo.close()

    def test_custom_css(self):
        """A long custom CSS string is stored unchanged on the Blocks."""
        css = """
            .gr-button {
                color: white;
                border-color: black;
                background: black;
            }
        """
        css = css * 5  # simulate a long css string
        block = gr.Blocks(css=css)

        assert block.css == css

    @pytest.mark.asyncio
    async def test_restart_after_close(self):
        """A queued interface still answers after being closed and relaunched."""
        interface = gr.Interface(lambda s: s, gr.Textbox(), gr.Textbox()).queue()
        try:
            interface.launch(prevent_thread_lock=True)
            data = await self._predict_via_queue(
                f"{interface.local_url.replace('http', 'ws')}queue/join", "freddy"
            )
            assert data[0] == "freddy"

            interface.close()
            interface.launch(prevent_thread_lock=True)
            data = await self._predict_via_queue(
                f"{interface.local_url.replace('http', 'ws')}queue/join", "Victor"
            )
            assert data[0] == "Victor"
        finally:
            # Shut down the relaunched server so it does not outlive the test.
            interface.close()

    @pytest.mark.asyncio
    async def test_run_without_launching(self):
        """Test that we can start the app and use queue without calling .launch().

        This is essentially what the 'gradio' reload mode does
        """

        port = get_first_available_port(7860, 7870)

        interface = gr.Interface(lambda s: s, gr.Textbox(), gr.Textbox()).queue()

        config = uvicorn.Config(app=interface.app, port=port, log_level="warning")

        server = Server(config=config)
        server.run_in_thread()

        try:
            data = await self._predict_via_queue(
                f"ws://localhost:{port}/queue/join", "Victor"
            )
            assert data[0] == "Victor"
        finally:
            server.close()
|
|
|
|
|
|
class TestComponentsInBlocks:
    """Checks on how components behave once they are placed inside ``gr.Blocks``."""

    def test_slider_random_value_config(self):
        """Sliders built with ``randomize=True`` get a non-queued load event."""
        with gr.Blocks() as demo:
            gr.Slider(
                label="Non-random Slider (Static)",
                value=11.2,
                minimum=-10.2,
                maximum=15,
            )
            gr.Slider(
                label="Random Slider (Input 1)",
                randomize=True,
                minimum=100,
                maximum=200,
            )
            gr.Slider(
                label="Random Slider (Input 2)",
                randomize=True,
                minimum=10,
                maximum=23.2,
            )

        io_blocks = [
            block
            for block in demo.blocks.values()
            if isinstance(block, gr.components.IOComponent)
        ]
        for block in io_blocks:
            if "Non-random" in block.label:
                assert not block.load_event
            else:
                assert block.load_event

        dependencies = demo.config["dependencies"]
        triggered_on_load = [dep["trigger"] == "load" for dep in dependencies]
        assert all(triggered_on_load)
        assert len(triggered_on_load) == 2
        assert not any(dep["queue"] for dep in dependencies)

    def test_io_components_attach_load_events_when_value_is_fn(self, io_components):
        """Every component given a callable value registers an ``every=1`` load event."""
        component_classes = [
            cls for cls in io_components if cls not in (gr.State,)
        ]
        inputs = [cls(value=lambda: None, every=1) for cls in component_classes]
        interface = gr.Interface(lambda *args: None, inputs=inputs, outputs=None)

        load_dependencies = [
            dep for dep in interface.config["dependencies"] if dep["trigger"] == "load"
        ]
        assert len(load_dependencies) == len(component_classes)
        assert all(dep["every"] == 1 for dep in load_dependencies)

    def test_get_load_events(self, io_components):
        """The components' load events are exactly the Blocks' dependencies."""
        with gr.Blocks() as demo:
            instances = [cls(value=lambda: None, every=1) for cls in io_components]
        load_events = [instance.load_event for instance in instances]
        assert load_events == demo.dependencies

    def test_blocks_do_not_filter_none_values_from_updates(self, io_components):
        """An update with ``value=None`` is postprocessed rather than dropped."""
        excluded = (gr.State, gr.Button, gr.ScatterPlot, gr.LinePlot)
        instances = [cls() for cls in io_components if cls not in excluded]

        with gr.Blocks() as demo:
            for instance in instances:
                instance.render()
            reset = gr.Button(value="Reset")
            reset.click(
                lambda: [gr.update(value=None) for _ in instances],
                inputs=[],
                outputs=instances,
            )

        none_updates = [gr.update(value=None) for _ in instances]
        output = demo.postprocess_data(0, none_updates, state={})
        assert all(
            result["value"] == instance.postprocess(None)
            for result, instance in zip(output, instances)
        )

    def test_blocks_does_not_replace_keyword_literal(self):
        """The literal string ``"NO_VALUE"`` passes through as an ordinary value."""
        with gr.Blocks() as demo:
            target = gr.Textbox()
            reset = gr.Button(value="Reset")
            reset.click(
                lambda: gr.update(value="NO_VALUE"),
                inputs=[],
                outputs=target,
            )

        output = demo.postprocess_data(0, gr.update(value="NO_VALUE"), state={})
        first_output = output[0]
        assert first_output["value"] == "NO_VALUE"

    def test_blocks_returns_correct_output_dict_single_key(self):
        """A dict return keyed by a single output component is unpacked correctly."""
        with gr.Blocks() as demo:
            source = gr.Number()
            target = gr.Number()
            trigger = gr.Button(value="update")

            def update_values(val):
                return {target: gr.Number.update(value=42)}

            trigger.click(update_values, inputs=[source], outputs=[target])

        as_update = demo.postprocess_data(
            0, {target: gr.Number.update(value=42)}, state={}
        )
        assert as_update[0]["value"] == 42

        as_plain_value = demo.postprocess_data(0, {target: 23}, state={})
        assert as_plain_value[0] == 23

    @pytest.mark.asyncio
    async def test_blocks_update_dict_without_postprocessing(self):
        """With ``postprocess=False`` raw values and generic updates come back as-is."""

        def infer(x):
            return gr.media_data.BASE64_IMAGE, gr.update(visible=True)

        with gr.Blocks() as demo:
            prompt = gr.Textbox()
            image = gr.Image()
            run_button = gr.Button()
            share_button = gr.Button("share", visible=False)
            run_button.click(infer, prompt, [image, share_button], postprocess=False)

        output = await demo.process_api(0, ["test"], state={})
        image_data, button_update = output["data"][0], output["data"][1]
        assert image_data == gr.media_data.BASE64_IMAGE
        assert button_update == {"__type__": "update", "visible": True}

    @pytest.mark.asyncio
    async def test_blocks_update_dict_does_not_postprocess_value_if_postprocessing_false(
        self,
    ):
        """With ``postprocess=False`` the value inside an update dict is left untouched."""

        def infer(x):
            return gr.Image.update(value=gr.media_data.BASE64_IMAGE)

        with gr.Blocks() as demo:
            prompt = gr.Textbox()
            image = gr.Image()
            run_button = gr.Button()
            run_button.click(infer, [prompt], [image], postprocess=False)

        output = await demo.process_api(0, ["test"], state={})
        expected = {"__type__": "update", "value": gr.media_data.BASE64_IMAGE}
        assert output["data"][0] == expected

    @pytest.mark.asyncio
    async def test_blocks_update_interactive(
        self,
    ):
        """Component-specific and generic ``interactive`` updates give the same output."""

        def specific_update():
            return [
                gr.Image.update(interactive=True),
                gr.Textbox.update(interactive=True),
            ]

        def generic_update():
            return [gr.update(interactive=True), gr.update(interactive=True)]

        with gr.Blocks() as demo:
            run = gr.Button(value="Make interactive")
            image = gr.Image()
            textbox = gr.Text()
            run.click(specific_update, None, [image, textbox])
            run.click(generic_update, None, [image, textbox])

        expected_image = {
            "interactive": True,
            "__type__": "update",
            "mode": "dynamic",
        }
        expected_textbox = {"__type__": "update", "mode": "dynamic"}
        for fn_index in (0, 1):
            output = await demo.process_api(fn_index, [], state={})
            assert output["data"][0] == expected_image
            assert output["data"][1] == expected_textbox
|
|
|
|
|
|
class TestCallFunction:
    """Tests for ``Blocks.call_function`` and for calling a Blocks directly."""

    @pytest.mark.asyncio
    async def test_call_regular_function(self):
        """A plain function can be called through ``call_function`` and ``demo(...)``."""

        def greet(x):
            return "Hello, " + x

        with gr.Blocks() as demo:
            text = gr.Textbox()
            btn = gr.Button()
            btn.click(greet, inputs=text, outputs=text)

        result = await demo.call_function(0, ["World"])
        assert result["prediction"] == "Hello, World"
        direct = demo("World")
        assert direct == "Hello, World"

        result = await demo.call_function(0, ["Abubakar"])
        assert result["prediction"] == "Hello, Abubakar"

    @pytest.mark.asyncio
    async def test_call_multiple_functions(self):
        """Each registered event is addressable through its ``fn_index``."""

        def hello(x):
            return "Hello, " + x

        def hi(x):
            return "Hi, " + x

        with gr.Blocks() as demo:
            text = gr.Textbox()
            text2 = gr.Textbox()
            btn = gr.Button()
            btn.click(hello, inputs=text, outputs=text)
            text.change(hi, inputs=text, outputs=text2)

        result = await demo.call_function(0, ["World"])
        assert result["prediction"] == "Hello, World"
        direct = demo("World")
        assert direct == "Hello, World"

        result = await demo.call_function(1, ["World"])
        assert result["prediction"] == "Hi, World"
        direct = demo("World", fn_index=1)  # fn_index must be a keyword argument
        assert direct == "Hi, World"

    @pytest.mark.asyncio
    async def test_call_generator(self):
        """A generator yields one prediction per call, then reports it has finished."""

        def generator(x):
            for i in range(x):
                yield i

        with gr.Blocks() as demo:
            inp = gr.Number()
            out = gr.Number()
            btn = gr.Button()
            btn.click(generator, inputs=inp, outputs=out)

        demo.queue()
        assert demo.config["enable_queue"]

        step = await demo.call_function(0, [3])
        assert step["prediction"] == 0
        for expected in (1, 2):
            step = await demo.call_function(0, [3], iterator=step["iterator"])
            assert step["prediction"] == expected

        step = await demo.call_function(0, [3], iterator=step["iterator"])
        assert step["prediction"] == gr.components._Keywords.FINISHED_ITERATING
        assert step["iterator"] is None

        # With no iterator passed in, the next call starts again from 0.
        step = await demo.call_function(0, [3], iterator=step["iterator"])
        assert step["prediction"] == 0

    @pytest.mark.asyncio
    async def test_call_both_generator_and_function(self):
        """A regular function and a generator can coexist in one Blocks."""

        def generator(x):
            for i in range(x):
                yield i, x

        with gr.Blocks() as demo:
            inp = gr.Number()
            out1 = gr.Number()
            out2 = gr.Number()
            btn = gr.Button()
            inp.change(lambda x: x + x, inp, out1)
            btn.click(generator, inputs=inp, outputs=[out1, out2])

        demo.queue()

        doubled = await demo.call_function(0, [2])
        assert doubled["prediction"] == 4
        doubled = await demo.call_function(0, [-1])
        assert doubled["prediction"] == -2

        step = await demo.call_function(1, [3])
        assert step["prediction"] == (0, 3)
        for expected in ((1, 3), (2, 3)):
            step = await demo.call_function(1, [3], iterator=step["iterator"])
            assert step["prediction"] == expected

        step = await demo.call_function(1, [3], iterator=step["iterator"])
        assert step["prediction"] == (gr.components._Keywords.FINISHED_ITERATING,) * 2
        assert step["iterator"] is None

        # With no iterator passed in, the next call starts again from the top.
        step = await demo.call_function(1, [3], iterator=step["iterator"])
        assert step["prediction"] == (0, 3)
|
|
|
|
|
|
class TestBatchProcessing:
    """Tests for events registered with ``batch=True``."""

    def test_raise_exception_if_batching_an_event_thats_not_queued(self):
        """Launching with a batched event but no queue for it raises ``ValueError``."""

        def trim(words, lens):
            return [[word[: int(length)] for word, length in zip(words, lens)]]

        def build_blocks(**click_kwargs):
            # Same layout for both Blocks cases; only the click options differ.
            with gr.Blocks() as blocks:
                with gr.Row():
                    word = gr.Textbox(label="word")
                    leng = gr.Number(label="leng")
                    output = gr.Textbox(label="Output")
                with gr.Row():
                    run = gr.Button()

                run.click(
                    trim,
                    [word, leng],
                    output,
                    batch=True,
                    max_batch_size=16,
                    **click_kwargs,
                )
            return blocks

        msg = "In order to use batching, the queue must be enabled."

        with pytest.raises(ValueError, match=msg):
            demo = gr.Interface(
                trim, ["textbox", "number"], ["textbox"], batch=True, max_batch_size=16
            )
            demo.launch(prevent_thread_lock=True)

        with pytest.raises(ValueError, match=msg):
            demo = build_blocks()
            demo.launch(prevent_thread_lock=True)

        with pytest.raises(ValueError, match=msg):
            demo = build_blocks(queue=False)
            demo.queue()
            demo.launch(prevent_thread_lock=True)

    @pytest.mark.asyncio
    async def test_call_regular_function(self):
        """A batched function receives lists; ``demo(...)`` takes a single sample."""

        def batch_fn(x):
            return (["Hello " + word for word in x],)

        with gr.Blocks() as demo:
            text = gr.Textbox()
            btn = gr.Button()
            btn.click(batch_fn, inputs=text, outputs=text, batch=True)

        batched = await demo.call_function(0, [["Adam", "Yahya"]])
        assert batched["prediction"][0] == ["Hello Adam", "Hello Yahya"]
        single = demo("Abubakar")
        assert single == "Hello Abubakar"

    @pytest.mark.asyncio
    async def test_functions_multiple_parameters(self):
        """Batched and non-batched multi-argument functions work side by side."""

        def regular_fn(word1, word2):
            return len(word1) > len(word2)

        def batch_fn(words, lengths):
            pairs = list(zip(words, lengths))
            trim_words = [word[:length] for word, length in pairs]
            comparisons = [len(word) > length for word, length in pairs]
            return trim_words, comparisons

        with gr.Blocks() as demo:
            text1 = gr.Textbox()
            text2 = gr.Textbox()
            leng = gr.Number(precision=0)
            bigger = gr.Checkbox()
            btn1 = gr.Button("Check")
            btn2 = gr.Button("Trim")
            btn1.click(regular_fn, inputs=[text1, text2], outputs=bigger)
            btn2.click(
                batch_fn,
                inputs=[text1, leng],
                outputs=[text1, bigger],
                batch=True,
            )

        result = await demo.call_function(0, ["Adam", "Yahya"])
        assert result["prediction"] is False
        result = demo("Abubakar", "Abid")
        assert result

        result = await demo.call_function(1, [["Adam", "Mary"], [3, 5]])
        expected_batch = (["Ada", "Mary"], [True, False])
        assert result["prediction"] == expected_batch
        result = demo("Abubakar", 3, fn_index=1)
        assert result == ["Abu", True]

    @pytest.mark.asyncio
    async def test_invalid_batch_generator(self):
        """Using a generator function as a batched event raises ``ValueError``."""

        def batch_fn(x):
            greetings = []
            for word in x:
                greetings.append("Hello " + word)
                yield (greetings,)

        with pytest.raises(ValueError):
            with gr.Blocks() as demo:
                text = gr.Textbox()
                btn = gr.Button()
                btn.click(batch_fn, inputs=text, outputs=text, batch=True)

            await demo.process_api(0, [["Adam", "Yahya"]], state={})

    @pytest.mark.asyncio
    async def test_exceeds_max_batch_size(self):
        """Sending three samples to an event with ``max_batch_size=2`` raises."""

        def batch_fn(x):
            return (["Hello " + word for word in x],)

        with pytest.raises(ValueError):
            with gr.Blocks() as demo:
                text = gr.Textbox()
                btn = gr.Button()
                btn.click(
                    batch_fn, inputs=text, outputs=text, batch=True, max_batch_size=2
                )

            await demo.process_api(0, [["A", "B", "C"]], state={})

    @pytest.mark.asyncio
    async def test_unequal_batch_sizes(self):
        """Batched inputs of different lengths raise ``ValueError``."""

        def batch_fn(x, y):
            return (["Hello " + word1 + word2 for word1, word2 in zip(x, y)],)

        with pytest.raises(ValueError):
            with gr.Blocks() as demo:
                t1 = gr.Textbox()
                t2 = gr.Textbox()
                btn = gr.Button()
                btn.click(batch_fn, inputs=[t1, t2], outputs=t1, batch=True)

            await demo.process_api(0, [["A", "B", "C"], ["D", "E"]], state={})
|
|
|
|
|
|
class TestSpecificUpdate:
    """Tests for ``get_specific_update`` and component-specific update dicts."""

    def test_without_update(self):
        """A dict lacking the ``__type__`` marker raises ``KeyError``."""
        with pytest.raises(KeyError):
            gr.Textbox.get_specific_update({"lines": 4})

    def test_with_update(self):
        """``interactive`` is translated to ``mode`` and other fields get defaults."""
        # Fields expected in both results; each key is listed exactly once
        # (a repeated key in a dict literal is silently overwritten).
        common_fields = {
            "lines": 4,
            "max_lines": None,
            "placeholder": None,
            "label": None,
            "show_label": None,
            "type": None,
            "visible": None,
            "value": gr.components._Keywords.NO_VALUE,
            "__type__": "update",
        }

        specific_update = gr.Textbox.get_specific_update(
            {"lines": 4, "__type__": "update", "interactive": False}
        )
        assert specific_update == {**common_fields, "mode": "static"}

        specific_update = gr.Textbox.get_specific_update(
            {"lines": 4, "__type__": "update", "interactive": True}
        )
        assert specific_update == {**common_fields, "mode": "dynamic"}

    def test_with_generic_update(self):
        """A ``generic_update`` dict is converted into a Video-specific update."""
        specific_update = gr.Video.get_specific_update(
            {
                "visible": True,
                "value": "test.mp4",
                "__type__": "generic_update",
                "interactive": True,
            }
        )
        assert specific_update == {
            "source": None,
            "label": None,
            "show_label": None,
            "visible": True,
            "value": "test.mp4",
            "mode": "dynamic",
            "interactive": True,
            "__type__": "update",
        }

    @pytest.mark.asyncio
    async def test_accordion_update(self):
        """``gr.Accordion.update`` results are returned unchanged by ``process_api``."""
        with gr.Blocks() as demo:
            with gr.Accordion(label="Open for greeting", open=False) as accordion:
                gr.Textbox("Hello!")
            open_btn = gr.Button(label="Open Accordion")
            close_btn = gr.Button(label="Close Accordion")
            open_btn.click(
                lambda: gr.Accordion.update(open=True, label="Open Accordion"),
                inputs=None,
                outputs=[accordion],
            )
            close_btn.click(
                lambda: gr.Accordion.update(open=False, label="Closed Accordion"),
                inputs=None,
                outputs=[accordion],
            )
        result = await demo.process_api(
            fn_index=0, inputs=[None], request=None, state={}
        )
        assert result["data"][0] == {
            "open": True,
            "label": "Open Accordion",
            "__type__": "update",
        }
        result = await demo.process_api(
            fn_index=1, inputs=[None], request=None, state={}
        )
        assert result["data"][0] == {
            "open": False,
            "label": "Closed Accordion",
            "__type__": "update",
        }
|
|
|
|
|
|
class TestRender:
    """Tests for ``render()`` on components and interfaces inside a Blocks."""

    def test_duplicate_error(self):
        """Rendering the same block twice in one Blocks raises ``DuplicateBlockError``."""
        # A component rendered twice explicitly.
        with pytest.raises(DuplicateBlockError):
            textbox = gr.Textbox()
            with gr.Blocks():
                textbox.render()
                gr.Number()
                textbox.render()

        # A component created inside the Blocks, then rendered again.
        with pytest.raises(DuplicateBlockError):
            with gr.Blocks():
                textbox = gr.Textbox()
                textbox.render()

        # An interface rendered twice.
        with pytest.raises(DuplicateBlockError):
            interface = gr.Interface(lambda x: x, gr.Textbox(), gr.Textbox())
            with gr.Blocks():
                interface.render()
                interface.render()

        # A component rendered after the interface that uses it.
        with pytest.raises(DuplicateBlockError):
            textbox = gr.Textbox()
            interface = gr.Interface(lambda x: x, textbox, gr.Textbox())
            with gr.Blocks():
                interface.render()
                textbox.render()

    def test_no_error(self):
        """Rendering distinct blocks works and ``render()`` returns the block itself."""
        first_box = gr.Textbox()
        second_box = gr.Textbox()
        with gr.Blocks():
            first_box.render()
            rendered_box = second_box.render()
        assert second_box == rendered_box

        shared_box = gr.Textbox()
        interface = gr.Interface(lambda x: x, shared_box, gr.Textbox())
        with gr.Blocks():
            interface.render()
            gr.Textbox()

        first_interface = gr.Interface(lambda x: x, gr.Textbox(), gr.Textbox())
        second_interface = gr.Interface(lambda x: x, gr.Textbox(), gr.Textbox())
        with gr.Blocks():
            first_interface.render()
            rendered_interface = second_interface.render()
        assert second_interface == rendered_interface
|
|
|
|
|
|
class TestCancel:
    """Tests for cancelling running events through ``cancels=``."""

    # Shared marker for the tests that rely on asyncio task names.
    _requires_task_names = pytest.mark.skipif(
        sys.version_info < (3, 8),
        reason="Tasks dont have names in 3.7",
    )

    @_requires_task_names
    @pytest.mark.asyncio
    async def test_cancel_function(self, capsys):
        """The generated cancel function stops a running task named after the event."""

        async def long_job():
            await asyncio.sleep(10)
            print("HELLO FROM LONG JOB")

        with gr.Blocks() as demo:
            start_button = gr.Button(value="Start")
            click_event = start_button.click(long_job, None, None)
            cancel_button = gr.Button(value="Cancel")
            cancel_button.click(None, None, None, cancels=[click_event])

        cancel_fun = demo.fns[-1].fn
        running = asyncio.create_task(long_job())
        running.set_name("foo_0")
        # If cancel_fun didn't cancel long_job the message would be printed to the console
        # The test would also take 10 seconds
        await asyncio.gather(running, cancel_fun("foo"), return_exceptions=True)
        printed = capsys.readouterr().out
        assert "HELLO FROM LONG JOB" not in printed

    @_requires_task_names
    @pytest.mark.asyncio
    async def test_cancel_function_with_multiple_blocks(self, capsys):
        """Cancelling still works when the event comes from a nested, rendered Blocks."""

        async def long_job():
            await asyncio.sleep(10)
            print("HELLO FROM LONG JOB")

        with gr.Blocks() as demo1:
            textbox = gr.Textbox()
            button1 = gr.Button(value="Start")
            button1.click(lambda x: x, textbox, textbox)
        with gr.Blocks() as demo2:
            button2 = gr.Button(value="Start")
            click_event = button2.click(long_job, None, None)
            cancel_button = gr.Button(value="Cancel")
            cancel_button.click(None, None, None, cancels=[click_event])

        with gr.Blocks() as demo:
            with gr.Tab("Demo 1"):
                demo1.render()
            with gr.Tab("Demo 2"):
                demo2.render()

        cancel_fun = demo.fns[-1].fn

        running = asyncio.create_task(long_job())
        running.set_name("foo_1")
        await asyncio.gather(running, cancel_fun("foo"), return_exceptions=True)
        printed = capsys.readouterr().out
        assert "HELLO FROM LONG JOB" not in printed

    def test_raise_exception_if_cancelling_an_event_thats_not_queued(self):
        """Cancelling an event whose queue is not enabled raises ``ValueError``."""

        def iteration(a):
            yield a

        def build_blocks(**click_kwargs):
            # One event plus a button that cancels it; only the click options vary.
            with gr.Blocks() as blocks:
                button = gr.Button(value="Predict")
                click = button.click(None, None, None, **click_kwargs)
                cancel = gr.Button(value="Cancel")
                cancel.click(None, None, None, cancels=[click])
            return blocks

        msg = "In order to cancel an event, the queue for that event must be enabled!"

        with pytest.raises(ValueError, match=msg):
            interface = gr.Interface(iteration, inputs=gr.Number(), outputs=gr.Number())
            interface.launch(prevent_thread_lock=True)

        with pytest.raises(ValueError, match=msg):
            demo = build_blocks()
            demo.launch(prevent_thread_lock=True)

        with pytest.raises(ValueError, match=msg):
            demo = build_blocks(queue=False)
            queued = demo.queue()
            queued.launch(prevent_thread_lock=True)
|
|
|
|
|
|
class TestEvery:
    """Tests for continuous (``every=``) events and generator cancellation."""

    def test_raise_exception_if_parameters_invalid(self):
        """``every`` cannot be combined with ``batch`` and must not be negative."""
        with pytest.raises(
            ValueError, match="Cannot run change event in a batch and every 0.5 seconds"
        ):
            with gr.Blocks():
                num = gr.Number()
                num.change(
                    lambda s: s + 1, inputs=[num], outputs=[num], every=0.5, batch=True
                )

        with pytest.raises(
            ValueError, match="Parameter every must be positive or None"
        ):
            with gr.Blocks():
                num = gr.Number()
                num.change(lambda s: s + 1, inputs=[num], outputs=[num], every=-0.1)

    @pytest.mark.asyncio
    async def test_every_does_not_block_queue(self):
        """Joining the queue for an ``every=`` event leaves the queue size at zero."""
        with gr.Blocks() as demo:
            num = gr.Number(value=0)
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")
            name.change(lambda n: n + random.random(), num, num, every=0.5)
            button.click(lambda s: f"Hello, {s}!", name, greeting)
        app, _, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True)
        client = TestClient(app)

        try:
            async with websockets.connect(
                f"{demo.local_url.replace('http', 'ws')}queue/join"
            ) as ws:
                while True:
                    msg = json.loads(await ws.recv())
                    if msg["msg"] == "send_data":
                        await ws.send(json.dumps({"data": [0], "fn_index": 0}))
                    if msg["msg"] == "send_hash":
                        await ws.send(
                            json.dumps({"fn_index": 0, "session_hash": "shdce"})
                        )
                        status = client.get("/queue/status")
                        # If the continuous event got pushed to the queue, the
                        # size would be nonzero.
                        assert status.json()["queue_size"] == 0
                        break
        finally:
            # Shut the server down so it does not outlive the test.
            demo.close()

    @pytest.mark.asyncio
    async def test_generating_event_cancelled_if_ws_closed(self, capsys):
        """Closing the websocket mid-generation stops the generator early."""

        def generation():
            for i in range(10):
                time.sleep(0.1)
                print(f"At step {i}")
                yield i
            return "Hello!"

        with gr.Blocks() as demo:
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")
            button.click(generation, None, greeting)

        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        try:
            async with websockets.connect(
                f"{demo.local_url.replace('http', 'ws')}queue/join"
            ) as ws:
                n_steps = 0
                while True:
                    msg = json.loads(await ws.recv())
                    if msg["msg"] == "send_data":
                        await ws.send(json.dumps({"data": [0], "fn_index": 0}))
                    elif msg["msg"] == "send_hash":
                        await ws.send(
                            json.dumps({"fn_index": 0, "session_hash": "shdce"})
                        )
                    elif msg["msg"] == "process_generating":
                        if n_steps == 2:
                            # Leaving the `async with` closes the websocket
                            break
                        n_steps += 1
            await asyncio.sleep(1)
            # If the generation function did not get cancelled
            # it would have finished running and `At step 9` would
            # have been printed
            captured = capsys.readouterr()
            assert "At step 9" not in captured.out
        finally:
            # Shut the server down so it does not outlive the test.
            demo.close()
|
|
|
|
|
|
class TestProgressBar:
    """Tests for the ``progress`` messages sent over the queue websocket."""

    @staticmethod
    async def _collect_progress_updates(demo):
        """Drive event 0 of an already launched ``demo`` and gather its progress.

        Connects to the demo's ``queue/join`` websocket, answers the
        ``send_hash`` and ``send_data`` requests, and returns the
        ``progress_data`` payload of every ``progress`` message received
        before ``process_completed``, in arrival order.
        """
        progress_updates = []
        async with websockets.connect(
            f"{demo.local_url.replace('http', 'ws')}queue/join"
        ) as ws:
            while True:
                msg = json.loads(await ws.recv())
                if msg["msg"] == "send_data":
                    await ws.send(json.dumps({"data": [0], "fn_index": 0}))
                elif msg["msg"] == "send_hash":
                    await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
                elif msg["msg"] == "progress":
                    progress_updates.append(msg["progress_data"])
                elif msg["msg"] == "process_completed":
                    return progress_updates

    @pytest.mark.asyncio
    async def test_progress_bar(self):
        """Only explicit ``gr.Progress`` calls are reported by default.

        The plain ``tqdm`` loop in ``greet`` is expected to produce no
        progress messages because ``track_tqdm`` is not enabled.
        """
        from tqdm import tqdm

        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress()):
                prog(0, desc="start")
                time.sleep(0.25)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.25)
                time.sleep(1)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.25)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        progress_updates = await self._collect_progress_updates(demo)

        assert progress_updates == [
            [
                {
                    "index": None,
                    "length": None,
                    "unit": "steps",
                    "progress": 0.0,
                    "desc": "start",
                }
            ],
            # One update per position of the ``prog.tqdm`` loop (indices 0..4).
            *(
                [
                    {
                        "index": step,
                        "length": 4,
                        "unit": "iter",
                        "progress": None,
                        "desc": None,
                    }
                ]
                for step in range(5)
            ),
        ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm(self):
        """With ``track_tqdm=True`` the plain ``tqdm`` loop is reported as well."""
        from tqdm import tqdm

        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress(track_tqdm=True)):
                prog(0, desc="start")
                time.sleep(0.25)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.25)
                time.sleep(1)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.25)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        progress_updates = await self._collect_progress_updates(demo)

        assert progress_updates == [
            [
                {
                    "index": None,
                    "length": None,
                    "unit": "steps",
                    "progress": 0.0,
                    "desc": "start",
                }
            ],
            # One update per position of the ``prog.tqdm`` loop (indices 0..4).
            *(
                [
                    {
                        "index": step,
                        "length": 4,
                        "unit": "iter",
                        "progress": None,
                        "desc": None,
                    }
                ]
                for step in range(5)
            ),
            # Updates from the plain ``tqdm`` loop over the three letters.
            *(
                [
                    {
                        "index": step,
                        "length": 3,
                        "unit": "steps",
                        "progress": None,
                        "desc": "alphabet",
                    }
                ]
                for step in range(3)
            ),
        ]
|
|
|
|
|
|
class TestAddRequests:
    """Checks which inputs ``gr.helpers.special_args`` returns for a function."""

    @staticmethod
    def _processed_inputs(fn, inputs, request):
        """Return the first element of ``special_args`` for a copy of ``inputs``."""
        return gr.helpers.special_args(fn, copy.deepcopy(inputs), request)[0]

    def test_no_type_hints(self):
        """Unannotated parameters leave the inputs unchanged."""

        def moo(a, b):
            return a + b

        request = gr.Request()

        plain_inputs = [1, 2]
        assert self._processed_inputs(moo, plain_inputs, request) == plain_inputs

        # Same check with one argument pre-bound through ``partial``.
        bound = partial(moo, a=1)
        remaining_inputs = [2]
        assert (
            self._processed_inputs(bound, remaining_inputs, request)
            == remaining_inputs
        )

    def test_no_type_hints_with_request(self):
        """Type hints other than ``gr.Request`` leave the inputs unchanged."""

        def moo(a: str, b: int):
            return a + str(b)

        request = gr.Request()

        plain_inputs = ["abc", 2]
        assert self._processed_inputs(moo, plain_inputs, request) == plain_inputs

        # Same check with one argument pre-bound through ``partial``.
        bound = partial(moo, a="def")
        remaining_inputs = [2]
        assert (
            self._processed_inputs(bound, remaining_inputs, request)
            == remaining_inputs
        )

    def test_type_hints_with_request(self):
        """A ``gr.Request`` parameter receives the request at its position."""

        def request_last(a: str, b: gr.Request):
            return a

        request = gr.Request()
        processed = self._processed_inputs(request_last, ["abc"], request)
        assert processed == ["abc", request]

        def request_first(a: gr.Request, b, c: int):
            return c

        request = gr.Request()
        processed = self._processed_inputs(request_first, ["abc", 5], request)
        assert processed == [request, "abc", 5]

    def test_type_hints_with_multiple_requests(self):
        """Every ``gr.Request`` parameter receives the same request object."""

        def trailing_requests(a: str, b: gr.Request, c: gr.Request):
            return a

        request = gr.Request()
        processed = self._processed_inputs(trailing_requests, ["abc"], request)
        assert processed == ["abc", request, request]

        def surrounding_requests(a: gr.Request, b, c: int, d: gr.Request):
            return c

        request = gr.Request()
        processed = self._processed_inputs(surrounding_requests, ["abc", 5], request)
        assert processed == [request, "abc", 5, request]
|
|
|
|
|
|
def test_queue_enabled_for_fn():
    """``queue_enabled_for_fn`` follows per-event and app-wide queue settings."""
    with gr.Blocks() as demo:
        # Named ``text_input`` so the builtin ``input`` is not shadowed.
        text_input = gr.Textbox()
        output = gr.Textbox()
        number = gr.Number()
        button = gr.Button()
        button.click(lambda x: f"Hello, {x}!", text_input, output)
        button.click(lambda: 42, None, number, queue=True)

    # Before ``demo.queue()``: only the event registered with ``queue=True``
    # reports the queue as enabled.
    assert not demo.queue_enabled_for_fn(0)
    assert demo.queue_enabled_for_fn(1)

    # After ``demo.queue()``: both events report the queue as enabled.
    demo.queue()
    assert demo.queue_enabled_for_fn(0)
    assert demo.queue_enabled_for_fn(1)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_queue_when_using_auth():
    """The queue websocket rejects anonymous clients and serves logged-in ones.

    After logging in through ``/login``, three websocket clients submit jobs
    concurrently with the ``access-token`` cookie. Each job ``i`` must succeed,
    return ``Hello {i}!`` and complete later than
    ``start + sleep_time * (i + 1) - 1`` on the event-loop clock.
    """
    sleep_time = 1
    session_hash = "enwpitpex2q"

    async def say_hello(name):
        await asyncio.sleep(sleep_time)
        return f"Hello {name}!"

    with gr.Blocks() as demo:
        _input = gr.Textbox()
        _output = gr.Textbox()
        button = gr.Button()
        button.click(say_hello, _input, _output)
    demo.queue()
    app, _, _ = demo.launch(auth=("abc", "123"), prevent_thread_lock=True)
    client = TestClient(app)
    ws_url = f"{demo.local_url.replace('http', 'ws')}queue/join"

    resp = client.post(
        f"{demo.local_url}login",
        data={"username": "abc", "password": "123"},
        follow_redirects=False,
    )
    assert resp.status_code == 200
    token = resp.cookies.get("access-token")
    assert token

    # Connecting without the auth cookie must fail during the handshake.
    with pytest.raises(Exception) as e:
        async with websockets.connect(ws_url) as ws:
            await ws.recv()
    assert e.type == websockets.InvalidStatusCode

    async def run_ws(_loop, _time, i):
        async with websockets.connect(
            ws_url,
            extra_headers={"Cookie": f"access-token={token}"},
        ) as ws:
            while True:
                try:
                    msg = json.loads(await ws.recv())
                except websockets.ConnectionClosedOK:
                    break
                if msg["msg"] == "send_hash":
                    await ws.send(
                        json.dumps({"fn_index": 0, "session_hash": session_hash})
                    )
                elif msg["msg"] == "send_data":
                    await ws.send(
                        json.dumps(
                            {
                                "data": [str(i)],
                                "fn_index": 0,
                                "session_hash": session_hash,
                            }
                        )
                    )
                    # The message right after submitting data must announce
                    # that processing has started.
                    msg = json.loads(await ws.recv())
                    assert msg["msg"] == "process_starts"
                elif msg["msg"] == "process_completed":
                    assert msg["success"]
                    assert msg["output"]["data"] == [f"Hello {i}!"]
                    assert _loop.time() > _time
                    break

    # get_running_loop() is the documented way to obtain the loop from inside
    # a coroutine.
    loop = asyncio.get_running_loop()
    tm = loop.time()
    await asyncio.gather(
        *[run_ws(loop, tm + sleep_time * (i + 1) - 1, i) for i in range(3)]
    )
|
|
|
|
|
|
def test_temp_file_sets_get_extended():
    """Rendering two Blocks inside a third concatenates their temp file sets."""
    test_file_dir = pathlib.Path(__file__).parent / "test_files"
    video_path = str(test_file_dir / "video_sample.mp4")
    audio_path = str(test_file_dir / "audio_sample.wav")

    with gr.Blocks() as video_demo:
        gr.Video(video_path)

    with gr.Blocks() as audio_demo:
        gr.Audio(audio_path)

    with gr.Blocks() as combined_demo:
        video_demo.render()
        audio_demo.render()

    expected_sets = video_demo.temp_file_sets + audio_demo.temp_file_sets
    assert combined_demo.temp_file_sets == expected_sets
|