mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Use local demos for client tests (#3975)
* Fix tests * Fix tests * Address comments
This commit is contained in:
parent
ee78458c64
commit
f886045535
@ -1,4 +1,159 @@
|
||||
import random
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line(
|
||||
"markers", "flaky: mark test as flaky. Failure will not cause te"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def calculator_demo():
|
||||
def calculator(num1, operation, num2):
|
||||
if operation == "add":
|
||||
return num1 + num2
|
||||
elif operation == "subtract":
|
||||
return num1 - num2
|
||||
elif operation == "multiply":
|
||||
return num1 * num2
|
||||
elif operation == "divide":
|
||||
if num2 == 0:
|
||||
raise gr.Error("Cannot divide by zero!")
|
||||
return num1 / num2
|
||||
|
||||
demo = gr.Interface(
|
||||
calculator,
|
||||
["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
|
||||
"number",
|
||||
examples=[
|
||||
[5, "add", 3],
|
||||
[4, "divide", 2],
|
||||
[-4, "multiply", 2.5],
|
||||
[0, "subtract", 1.2],
|
||||
],
|
||||
)
|
||||
return demo.queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def increment_demo():
|
||||
with gr.Blocks() as demo:
|
||||
btn1 = gr.Button("Increment")
|
||||
btn2 = gr.Button("Increment")
|
||||
numb = gr.Number()
|
||||
|
||||
state = gr.State(0)
|
||||
|
||||
btn1.click(
|
||||
lambda x: (x + 1, x + 1),
|
||||
state,
|
||||
[state, numb],
|
||||
api_name="increment_with_queue",
|
||||
)
|
||||
btn2.click(
|
||||
lambda x: (x + 1, x + 1),
|
||||
state,
|
||||
[state, numb],
|
||||
queue=False,
|
||||
api_name="increment_without_queue",
|
||||
)
|
||||
return demo.queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def progress_demo():
|
||||
def my_function(x, progress=gr.Progress()):
|
||||
progress(0, desc="Starting...")
|
||||
for _ in progress.tqdm(range(20)):
|
||||
time.sleep(0.1)
|
||||
return x
|
||||
|
||||
return gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def yield_demo():
|
||||
def spell(x):
|
||||
for i in range(len(x)):
|
||||
time.sleep(0.5)
|
||||
yield x[:i]
|
||||
|
||||
return gr.Interface(spell, "textbox", "textbox").queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cancel_from_client_demo():
|
||||
def iteration():
|
||||
for i in range(20):
|
||||
print(f"i: {i}")
|
||||
yield i
|
||||
time.sleep(0.5)
|
||||
|
||||
def long_process():
|
||||
time.sleep(10)
|
||||
print("DONE!")
|
||||
return 10
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
num = gr.Number()
|
||||
|
||||
btn = gr.Button(value="Iterate")
|
||||
btn.click(iteration, None, num, api_name="iterate")
|
||||
btn2 = gr.Button(value="Long Process")
|
||||
btn2.click(long_process, None, num, api_name="long")
|
||||
|
||||
return demo.queue(concurrency_count=40)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sentiment_classification_demo():
|
||||
def classifier(text):
|
||||
return {label: random.random() for label in ["POSITIVE", "NEGATIVE", "NEUTRAL"]}
|
||||
|
||||
def sleep_for_test():
|
||||
time.sleep(10)
|
||||
return 2
|
||||
|
||||
with gr.Blocks(theme="gstaff/xkcd") as demo:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_text = gr.Textbox(label="Input Text")
|
||||
with gr.Row():
|
||||
classify = gr.Button("Classify Sentiment")
|
||||
with gr.Column():
|
||||
label = gr.Label(label="Predicted Sentiment")
|
||||
number = gr.Number()
|
||||
btn = gr.Button("Sleep then print")
|
||||
classify.click(classifier, input_text, label, api_name="classify")
|
||||
btn.click(sleep_for_test, None, number, api_name="sleep")
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def count_generator_demo():
|
||||
def count(n):
|
||||
for i in range(int(n)):
|
||||
time.sleep(0.5)
|
||||
yield i
|
||||
|
||||
def show(n):
|
||||
return str(list(range(int(n))))
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Column():
|
||||
num = gr.Number(value=10)
|
||||
with gr.Row():
|
||||
count_btn = gr.Button("Count")
|
||||
list_btn = gr.Button("List")
|
||||
with gr.Column():
|
||||
out = gr.Textbox()
|
||||
|
||||
count_btn.click(count, num, out)
|
||||
list_btn.click(show, num, out)
|
||||
|
||||
return demo.queue()
|
||||
|
@ -5,6 +5,7 @@ import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import CancelledError, TimeoutError
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -21,6 +22,23 @@ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
HF_TOKEN = "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" # Intentionally revealing this key for testing purposes
|
||||
|
||||
|
||||
@contextmanager
|
||||
def connect(demo: gr.Blocks):
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True)
|
||||
try:
|
||||
yield Client(local_url)
|
||||
finally:
|
||||
# A more verbose version of .close()
|
||||
# because we should set a timeout
|
||||
# the tests that call .cancel() can get stuck
|
||||
# waiting for the thread to join
|
||||
if demo.enable_queue:
|
||||
demo._queue.close()
|
||||
demo.is_running = False
|
||||
demo.server.should_exit = True
|
||||
demo.server.thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestPredictionsFromSpaces:
|
||||
@pytest.mark.flaky
|
||||
def test_raise_error_invalid_state(self):
|
||||
@ -49,108 +67,102 @@ class TestPredictionsFromSpaces:
|
||||
output = client.predict("abc", api_name="/predict")
|
||||
assert output == "abc"
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_state(self):
|
||||
client = Client("gradio-tests/increment")
|
||||
output = client.predict(api_name="/increment_without_queue")
|
||||
assert output == 1
|
||||
output = client.predict(api_name="/increment_without_queue")
|
||||
assert output == 2
|
||||
output = client.predict(api_name="/increment_without_queue")
|
||||
assert output == 3
|
||||
client.reset_session()
|
||||
output = client.predict(api_name="/increment_without_queue")
|
||||
assert output == 1
|
||||
output = client.predict(api_name="/increment_with_queue")
|
||||
assert output == 2
|
||||
client.reset_session()
|
||||
output = client.predict(api_name="/increment_with_queue")
|
||||
assert output == 1
|
||||
output = client.predict(api_name="/increment_with_queue")
|
||||
assert output == 2
|
||||
def test_state(self, increment_demo):
|
||||
with connect(increment_demo) as client:
|
||||
output = client.predict(api_name="/increment_without_queue")
|
||||
assert output == 1
|
||||
output = client.predict(api_name="/increment_without_queue")
|
||||
assert output == 2
|
||||
output = client.predict(api_name="/increment_without_queue")
|
||||
assert output == 3
|
||||
client.reset_session()
|
||||
output = client.predict(api_name="/increment_without_queue")
|
||||
assert output == 1
|
||||
output = client.predict(api_name="/increment_with_queue")
|
||||
assert output == 2
|
||||
client.reset_session()
|
||||
output = client.predict(api_name="/increment_with_queue")
|
||||
assert output == 1
|
||||
output = client.predict(api_name="/increment_with_queue")
|
||||
assert output == 2
|
||||
|
||||
def test_job_status(self, calculator_demo):
|
||||
with connect(calculator_demo) as client:
|
||||
statuses = []
|
||||
job = client.submit(5, "add", 4)
|
||||
while not job.done():
|
||||
time.sleep(0.1)
|
||||
statuses.append(job.status())
|
||||
|
||||
assert statuses
|
||||
# Messages are sorted by time
|
||||
assert sorted([s.time for s in statuses if s]) == [
|
||||
s.time for s in statuses if s
|
||||
]
|
||||
assert sorted([s.code for s in statuses if s]) == [
|
||||
s.code for s in statuses if s
|
||||
]
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_job_status(self):
|
||||
statuses = []
|
||||
client = Client(src="gradio/calculator")
|
||||
job = client.submit(5, "add", 4)
|
||||
while not job.done():
|
||||
time.sleep(0.1)
|
||||
def test_job_status_queue_disabled(self, sentiment_classification_demo):
|
||||
with connect(sentiment_classification_demo) as client:
|
||||
statuses = []
|
||||
job = client.submit("I love the gradio python client", api_name="/classify")
|
||||
while not job.done():
|
||||
time.sleep(0.02)
|
||||
statuses.append(job.status())
|
||||
statuses.append(job.status())
|
||||
|
||||
assert statuses
|
||||
# Messages are sorted by time
|
||||
assert sorted([s.time for s in statuses if s]) == [
|
||||
s.time for s in statuses if s
|
||||
]
|
||||
assert sorted([s.code for s in statuses if s]) == [
|
||||
s.code for s in statuses if s
|
||||
]
|
||||
assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
|
||||
assert not any(s.progress_data for s in statuses)
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_job_status_queue_disabled(self):
|
||||
statuses = []
|
||||
client = Client(src="freddyaboulton/sentiment-classification")
|
||||
job = client.submit("I love the gradio python client", api_name="/classify")
|
||||
while not job.done():
|
||||
time.sleep(0.02)
|
||||
statuses.append(job.status())
|
||||
statuses.append(job.status())
|
||||
assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
|
||||
assert not any(s.progress_data for s in statuses)
|
||||
def test_intermediate_outputs(self, count_generator_demo):
|
||||
with connect(count_generator_demo) as client:
|
||||
job = client.submit(3, fn_index=0)
|
||||
|
||||
while not job.done():
|
||||
time.sleep(0.1)
|
||||
|
||||
assert job.outputs() == [str(i) for i in range(3)]
|
||||
|
||||
outputs = []
|
||||
for o in client.submit(3, fn_index=0):
|
||||
outputs.append(o)
|
||||
assert outputs == [str(i) for i in range(3)]
|
||||
|
||||
def test_break_in_loop_if_error(self, calculator_demo):
|
||||
with connect(calculator_demo) as client:
|
||||
job = client.submit("foo", "add", 4, fn_index=0)
|
||||
output = [o for o in job]
|
||||
assert output == []
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_intermediate_outputs(
|
||||
self,
|
||||
):
|
||||
client = Client(src="gradio/count_generator")
|
||||
job = client.submit(3, fn_index=0)
|
||||
|
||||
while not job.done():
|
||||
time.sleep(0.1)
|
||||
|
||||
assert job.outputs() == [str(i) for i in range(3)]
|
||||
|
||||
outputs = []
|
||||
for o in client.submit(3, fn_index=0):
|
||||
outputs.append(o)
|
||||
assert outputs == [str(i) for i in range(3)]
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_break_in_loop_if_error(self):
|
||||
calculator = Client(src="gradio/calculator")
|
||||
job = calculator.submit("foo", "add", 4, fn_index=0)
|
||||
output = [o for o in job]
|
||||
assert output == []
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_timeout(self):
|
||||
def test_timeout(self, sentiment_classification_demo):
|
||||
with pytest.raises(TimeoutError):
|
||||
client = Client(src="gradio-tests/sleep")
|
||||
job = client.submit("ping", api_name="/predict")
|
||||
job.result(timeout=0.05)
|
||||
with connect(sentiment_classification_demo.queue()) as client:
|
||||
job = client.submit(api_name="/sleep")
|
||||
job.result(timeout=0.05)
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_timeout_no_queue(self):
|
||||
def test_timeout_no_queue(self, sentiment_classification_demo):
|
||||
with pytest.raises(TimeoutError):
|
||||
client = Client(src="freddyaboulton/sentiment-classification")
|
||||
job = client.submit(api_name="/sleep")
|
||||
job.result(timeout=0.1)
|
||||
with connect(sentiment_classification_demo) as client:
|
||||
job = client.submit(api_name="/sleep")
|
||||
job.result(timeout=0.1)
|
||||
|
||||
def test_raises_exception(self, calculator_demo):
|
||||
with pytest.raises(Exception):
|
||||
with connect(calculator_demo) as client:
|
||||
job = client.submit("foo", "add", 9, fn_index=0)
|
||||
job.result()
|
||||
|
||||
def test_raises_exception_no_queue(self, sentiment_classification_demo):
|
||||
with pytest.raises(Exception):
|
||||
with connect(sentiment_classification_demo) as client:
|
||||
job = client.submit([5], api_name="/sleep")
|
||||
job.result()
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_raises_exception(self):
|
||||
with pytest.raises(Exception):
|
||||
client = Client(src="freddyaboulton/calculator")
|
||||
job = client.submit("foo", "add", 9, fn_index=0)
|
||||
job.result()
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_raises_exception_no_queue(self):
|
||||
with pytest.raises(Exception):
|
||||
client = Client(src="freddyaboulton/sentiment-classification")
|
||||
job = client.submit([5], api_name="/sleep")
|
||||
job.result()
|
||||
|
||||
def test_job_output_video(self):
|
||||
client = Client(src="gradio/video_component")
|
||||
job = client.submit(
|
||||
@ -159,20 +171,9 @@ class TestPredictionsFromSpaces:
|
||||
)
|
||||
assert pathlib.Path(job.result()).exists()
|
||||
|
||||
def test_progress_updates(self):
|
||||
def my_function(x, progress=gr.Progress()):
|
||||
progress(0, desc="Starting...")
|
||||
for i in progress.tqdm(range(20)):
|
||||
time.sleep(0.1)
|
||||
return x
|
||||
def test_progress_updates(self, progress_demo):
|
||||
|
||||
demo = gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue(
|
||||
concurrency_count=20
|
||||
)
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True)
|
||||
|
||||
try:
|
||||
client = Client(src=local_url)
|
||||
with connect(progress_demo) as client:
|
||||
job = client.submit("hello", api_name="/predict")
|
||||
statuses = []
|
||||
while not job.done():
|
||||
@ -190,63 +191,60 @@ class TestPredictionsFromSpaces:
|
||||
)
|
||||
count += unit in all_progress_data
|
||||
assert count
|
||||
finally:
|
||||
demo.close()
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_cancel_from_client_queued(self):
|
||||
client = Client(src="gradio-tests/test-cancel-from-client")
|
||||
start = time.time()
|
||||
job = client.submit(api_name="/long")
|
||||
while not job.done():
|
||||
if job.status().code == Status.STARTING:
|
||||
job.cancel()
|
||||
break
|
||||
with pytest.raises(CancelledError):
|
||||
job.result()
|
||||
# The whole prediction takes 10 seconds to run
|
||||
# and does not iterate. So this tests that we can cancel
|
||||
# halfway through a prediction
|
||||
assert time.time() - start < 10
|
||||
assert job.status().code == Status.CANCELLED
|
||||
|
||||
job = client.submit(api_name="/iterate")
|
||||
iteration_count = 0
|
||||
while not job.done():
|
||||
if job.status().code == Status.ITERATING:
|
||||
iteration_count += 1
|
||||
if iteration_count == 3:
|
||||
def test_cancel_from_client_queued(self, cancel_from_client_demo):
|
||||
with connect(cancel_from_client_demo) as client:
|
||||
start = time.time()
|
||||
job = client.submit(api_name="/long")
|
||||
while not job.done():
|
||||
if job.status().code == Status.STARTING:
|
||||
job.cancel()
|
||||
break
|
||||
time.sleep(0.5)
|
||||
# Result for iterative jobs is always the first result
|
||||
assert job.result() == 0
|
||||
# The whole prediction takes 10 seconds to run
|
||||
# and does not iterate. So this tests that we can cancel
|
||||
# halfway through a prediction
|
||||
assert time.time() - start < 10
|
||||
with pytest.raises(CancelledError):
|
||||
job.result()
|
||||
# The whole prediction takes 10 seconds to run
|
||||
# and does not iterate. So this tests that we can cancel
|
||||
# halfway through a prediction
|
||||
assert time.time() - start < 10
|
||||
assert job.status().code == Status.CANCELLED
|
||||
|
||||
# Test that we did not iterate all the way to the end
|
||||
assert all(o in [0, 1, 2, 3, 4, 5] for o in job.outputs())
|
||||
assert job.status().code == Status.CANCELLED
|
||||
job = client.submit(api_name="/iterate")
|
||||
iteration_count = 0
|
||||
while not job.done():
|
||||
if job.status().code == Status.ITERATING:
|
||||
iteration_count += 1
|
||||
if iteration_count == 3:
|
||||
job.cancel()
|
||||
break
|
||||
time.sleep(0.5)
|
||||
# Result for iterative jobs is always the first result
|
||||
assert job.result() == 0
|
||||
# The whole prediction takes 10 seconds to run
|
||||
# and does not iterate. So this tests that we can cancel
|
||||
# halfway through a prediction
|
||||
assert time.time() - start < 10
|
||||
|
||||
# Test that we did not iterate all the way to the end
|
||||
assert all(o in [0, 1, 2, 3, 4, 5] for o in job.outputs())
|
||||
assert job.status().code == Status.CANCELLED
|
||||
|
||||
def test_cancel_subsequent_jobs_state_reset(self, yield_demo):
|
||||
with connect(yield_demo) as client:
|
||||
job1 = client.submit("abcdefefadsadfs")
|
||||
time.sleep(3)
|
||||
job1.cancel()
|
||||
|
||||
assert len(job1.outputs()) < len("abcdefefadsadfs")
|
||||
assert job1.status().code == Status.CANCELLED
|
||||
|
||||
job2 = client.submit("abcd")
|
||||
while not job2.done():
|
||||
time.sleep(0.1)
|
||||
# Ran all iterations from scratch
|
||||
assert job2.status().code == Status.FINISHED
|
||||
assert len(job2.outputs()) == 5
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_cancel_subsequent_jobs_state_reset(self):
|
||||
client = Client("abidlabs/test-yield")
|
||||
job1 = client.submit("abcdefefadsadfs")
|
||||
time.sleep(5)
|
||||
job1.cancel()
|
||||
|
||||
assert len(job1.outputs()) < len("abcdefefadsadfs")
|
||||
assert job1.status().code == Status.CANCELLED
|
||||
|
||||
job2 = client.submit("abcd")
|
||||
while not job2.done():
|
||||
time.sleep(0.1)
|
||||
# Ran all iterations from scratch
|
||||
assert job2.status().code == Status.FINISHED
|
||||
assert len(job2.outputs()) == 5
|
||||
|
||||
def test_upload_file_private_space(self):
|
||||
|
||||
client = Client(
|
||||
@ -289,6 +287,7 @@ class TestPredictionsFromSpaces:
|
||||
assert open(output[1]).read() == "File2"
|
||||
upload.assert_called_once()
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_upload_file_upload_route_does_not_exist(self):
|
||||
client = Client(
|
||||
src="gradio-tests/not-actually-private-file-upload-old-version",
|
||||
@ -587,11 +586,14 @@ class TestAPIInfo:
|
||||
}
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_serializable_in_mapping(self):
|
||||
client = Client("freddyaboulton/calculator")
|
||||
assert all(
|
||||
[c.__class__ == SimpleSerializable for c in client.endpoints[0].serializers]
|
||||
)
|
||||
def test_serializable_in_mapping(self, calculator_demo):
|
||||
with connect(calculator_demo) as client:
|
||||
assert all(
|
||||
[
|
||||
isinstance(c, SimpleSerializable)
|
||||
for c in client.endpoints[0].serializers
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_private_space(self):
|
||||
|
Loading…
Reference in New Issue
Block a user