Use local demos for client tests (#3975)

* Fix tests

* Fix tests

* Address comments
Freddy Boulton 2023-04-26 11:11:28 -04:00 committed by GitHub
parent ee78458c64
commit f886045535
2 changed files with 315 additions and 158 deletions
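
The change swaps hosted Hugging Face Spaces for locally launched demos in the Python client tests: each demo becomes a pytest fixture, and the tests connect to a local URL instead of a Space. A minimal sketch of the pattern (the echo demo here is a hypothetical stand-in; the real fixtures and the connect helper appear in the diff below):

import gradio as gr
from gradio_client import Client

# Before: client tests talked to a hosted Space and were marked flaky, e.g.
#     client = Client(src="gradio/calculator")
# After: the demo is built locally, launched without blocking the test thread,
# and the client points at the local URL, removing the network dependency.
demo = gr.Interface(lambda x: x, "textbox", "textbox")  # stand-in echo demo
_, local_url, _ = demo.launch(prevent_thread_lock=True)
client = Client(local_url)
assert client.predict("abc", api_name="/predict") == "abc"
demo.close()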

View File

@@ -1,4 +1,159 @@
import random
import time

import gradio as gr
import pytest


def pytest_configure(config):
    config.addinivalue_line(
        "markers", "flaky: mark test as flaky. Failure will not cause te"
    )


@pytest.fixture
def calculator_demo():
    def calculator(num1, operation, num2):
        if operation == "add":
            return num1 + num2
        elif operation == "subtract":
            return num1 - num2
        elif operation == "multiply":
            return num1 * num2
        elif operation == "divide":
            if num2 == 0:
                raise gr.Error("Cannot divide by zero!")
            return num1 / num2

    demo = gr.Interface(
        calculator,
        ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
        "number",
        examples=[
            [5, "add", 3],
            [4, "divide", 2],
            [-4, "multiply", 2.5],
            [0, "subtract", 1.2],
        ],
    )
    return demo.queue()


@pytest.fixture
def increment_demo():
    with gr.Blocks() as demo:
        btn1 = gr.Button("Increment")
        btn2 = gr.Button("Increment")
        numb = gr.Number()
        state = gr.State(0)

        btn1.click(
            lambda x: (x + 1, x + 1),
            state,
            [state, numb],
            api_name="increment_with_queue",
        )
        btn2.click(
            lambda x: (x + 1, x + 1),
            state,
            [state, numb],
            queue=False,
            api_name="increment_without_queue",
        )

    return demo.queue()


@pytest.fixture
def progress_demo():
    def my_function(x, progress=gr.Progress()):
        progress(0, desc="Starting...")
        for _ in progress.tqdm(range(20)):
            time.sleep(0.1)
        return x

    return gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue()


@pytest.fixture
def yield_demo():
    def spell(x):
        for i in range(len(x)):
            time.sleep(0.5)
            yield x[:i]

    return gr.Interface(spell, "textbox", "textbox").queue()


@pytest.fixture
def cancel_from_client_demo():
    def iteration():
        for i in range(20):
            print(f"i: {i}")
            yield i
            time.sleep(0.5)

    def long_process():
        time.sleep(10)
        print("DONE!")
        return 10

    with gr.Blocks() as demo:
        num = gr.Number()
        btn = gr.Button(value="Iterate")
        btn.click(iteration, None, num, api_name="iterate")
        btn2 = gr.Button(value="Long Process")
        btn2.click(long_process, None, num, api_name="long")

    return demo.queue(concurrency_count=40)


@pytest.fixture
def sentiment_classification_demo():
    def classifier(text):
        return {label: random.random() for label in ["POSITIVE", "NEGATIVE", "NEUTRAL"]}

    def sleep_for_test():
        time.sleep(10)
        return 2

    with gr.Blocks(theme="gstaff/xkcd") as demo:
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input Text")
                with gr.Row():
                    classify = gr.Button("Classify Sentiment")
            with gr.Column():
                label = gr.Label(label="Predicted Sentiment")
        number = gr.Number()
        btn = gr.Button("Sleep then print")

        classify.click(classifier, input_text, label, api_name="classify")
        btn.click(sleep_for_test, None, number, api_name="sleep")

    return demo


@pytest.fixture
def count_generator_demo():
    def count(n):
        for i in range(int(n)):
            time.sleep(0.5)
            yield i

    def show(n):
        return str(list(range(int(n))))

    with gr.Blocks() as demo:
        with gr.Column():
            num = gr.Number(value=10)
            with gr.Row():
                count_btn = gr.Button("Count")
                list_btn = gr.Button("List")
        with gr.Column():
            out = gr.Textbox()

        count_btn.click(count, num, out)
        list_btn.click(show, num, out)

    return demo.queue()
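
Most of these fixtures return demo.queue() because generator and progress endpoints in Gradio 3.x need the queue enabled to stream intermediate values back to the client, and cancel_from_client_demo bumps concurrency_count so several jobs can run at once. A rough standalone sketch of streaming from such a demo (hypothetical script; count mirrors the count_generator_demo fixture above, and gradio_client is assumed to be installed):

import time

import gradio as gr
from gradio_client import Client


def count(n):
    for i in range(int(n)):
        time.sleep(0.1)
        yield i


# Generator endpoints need the queue enabled, hence .queue() as in the fixtures.
demo = gr.Interface(count, gr.Number(), gr.Textbox()).queue()
_, local_url, _ = demo.launch(prevent_thread_lock=True)

client = Client(local_url)
job = client.submit(3, api_name="/predict")  # runs in the background
while not job.done():
    time.sleep(0.1)
print(job.outputs())  # intermediate values, e.g. ['0', '1', '2']
demo.close()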

View File

@@ -5,6 +5,7 @@ import tempfile
import time
import uuid
from concurrent.futures import CancelledError, TimeoutError
from contextlib import contextmanager
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
@@ -21,6 +22,23 @@ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
HF_TOKEN = "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" # Intentionally revealing this key for testing purposes
@contextmanager
def connect(demo: gr.Blocks):
    _, local_url, _ = demo.launch(prevent_thread_lock=True)
    try:
        yield Client(local_url)
    finally:
        # A more verbose version of .close(): we need to set a timeout here,
        # because tests that call .cancel() can otherwise get stuck waiting
        # for the server thread to join.
        if demo.enable_queue:
            demo._queue.close()
        demo.is_running = False
        demo.server.should_exit = True
        demo.server.thread.join(timeout=1)
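
# Illustrative sketch, not part of the original diff: a hypothetical test showing
# how the connect helper above pairs with a conftest fixture such as
# calculator_demo; the expected value assumes the calculator fixture adds 5 + 3.
def test_calculator_add_locally(calculator_demo):
    with connect(calculator_demo) as client:
        assert client.predict(5, "add", 3, api_name="/predict") == 8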
class TestPredictionsFromSpaces:
    @pytest.mark.flaky
    def test_raise_error_invalid_state(self):
@@ -49,108 +67,102 @@ class TestPredictionsFromSpaces:
        output = client.predict("abc", api_name="/predict")
        assert output == "abc"

    @pytest.mark.flaky
    def test_state(self):
        client = Client("gradio-tests/increment")
        output = client.predict(api_name="/increment_without_queue")
        assert output == 1
        output = client.predict(api_name="/increment_without_queue")
        assert output == 2
        output = client.predict(api_name="/increment_without_queue")
        assert output == 3
        client.reset_session()
        output = client.predict(api_name="/increment_without_queue")
        assert output == 1
        output = client.predict(api_name="/increment_with_queue")
        assert output == 2
        client.reset_session()
        output = client.predict(api_name="/increment_with_queue")
        assert output == 1
        output = client.predict(api_name="/increment_with_queue")
        assert output == 2

    def test_state(self, increment_demo):
        with connect(increment_demo) as client:
            output = client.predict(api_name="/increment_without_queue")
            assert output == 1
            output = client.predict(api_name="/increment_without_queue")
            assert output == 2
            output = client.predict(api_name="/increment_without_queue")
            assert output == 3
            client.reset_session()
            output = client.predict(api_name="/increment_without_queue")
            assert output == 1
            output = client.predict(api_name="/increment_with_queue")
            assert output == 2
            client.reset_session()
            output = client.predict(api_name="/increment_with_queue")
            assert output == 1
            output = client.predict(api_name="/increment_with_queue")
            assert output == 2

    def test_job_status(self, calculator_demo):
        with connect(calculator_demo) as client:
            statuses = []
            job = client.submit(5, "add", 4)
            while not job.done():
                time.sleep(0.1)
                statuses.append(job.status())

            assert statuses
            # Messages are sorted by time
            assert sorted([s.time for s in statuses if s]) == [
                s.time for s in statuses if s
            ]
            assert sorted([s.code for s in statuses if s]) == [
                s.code for s in statuses if s
            ]

    @pytest.mark.flaky
    def test_job_status(self):
        statuses = []
        client = Client(src="gradio/calculator")
        job = client.submit(5, "add", 4)
        while not job.done():
            time.sleep(0.1)

    def test_job_status_queue_disabled(self, sentiment_classification_demo):
        with connect(sentiment_classification_demo) as client:
            statuses = []
            job = client.submit("I love the gradio python client", api_name="/classify")
            while not job.done():
                time.sleep(0.02)
                statuses.append(job.status())
            statuses.append(job.status())
            assert statuses
            # Messages are sorted by time
            assert sorted([s.time for s in statuses if s]) == [
                s.time for s in statuses if s
            ]
            assert sorted([s.code for s in statuses if s]) == [
                s.code for s in statuses if s
            ]
            assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
            assert not any(s.progress_data for s in statuses)

    @pytest.mark.flaky
    def test_job_status_queue_disabled(self):
        statuses = []
        client = Client(src="freddyaboulton/sentiment-classification")
        job = client.submit("I love the gradio python client", api_name="/classify")
        while not job.done():
            time.sleep(0.02)
            statuses.append(job.status())
        statuses.append(job.status())
        assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
        assert not any(s.progress_data for s in statuses)
    def test_intermediate_outputs(self, count_generator_demo):
        with connect(count_generator_demo) as client:
            job = client.submit(3, fn_index=0)
            while not job.done():
                time.sleep(0.1)
            assert job.outputs() == [str(i) for i in range(3)]

            outputs = []
            for o in client.submit(3, fn_index=0):
                outputs.append(o)
            assert outputs == [str(i) for i in range(3)]

    def test_break_in_loop_if_error(self, calculator_demo):
        with connect(calculator_demo) as client:
            job = client.submit("foo", "add", 4, fn_index=0)
            output = [o for o in job]
            assert output == []

    @pytest.mark.flaky
    def test_intermediate_outputs(
        self,
    ):
        client = Client(src="gradio/count_generator")
        job = client.submit(3, fn_index=0)
        while not job.done():
            time.sleep(0.1)
        assert job.outputs() == [str(i) for i in range(3)]

        outputs = []
        for o in client.submit(3, fn_index=0):
            outputs.append(o)
        assert outputs == [str(i) for i in range(3)]

    @pytest.mark.flaky
    def test_break_in_loop_if_error(self):
        calculator = Client(src="gradio/calculator")
        job = calculator.submit("foo", "add", 4, fn_index=0)
        output = [o for o in job]
        assert output == []

    @pytest.mark.flaky
    def test_timeout(self):
    def test_timeout(self, sentiment_classification_demo):
        with pytest.raises(TimeoutError):
            client = Client(src="gradio-tests/sleep")
            job = client.submit("ping", api_name="/predict")
            job.result(timeout=0.05)
            with connect(sentiment_classification_demo.queue()) as client:
                job = client.submit(api_name="/sleep")
                job.result(timeout=0.05)

    @pytest.mark.flaky
    def test_timeout_no_queue(self):
    def test_timeout_no_queue(self, sentiment_classification_demo):
        with pytest.raises(TimeoutError):
            client = Client(src="freddyaboulton/sentiment-classification")
            job = client.submit(api_name="/sleep")
            job.result(timeout=0.1)
            with connect(sentiment_classification_demo) as client:
                job = client.submit(api_name="/sleep")
                job.result(timeout=0.1)

    def test_raises_exception(self, calculator_demo):
        with pytest.raises(Exception):
            with connect(calculator_demo) as client:
                job = client.submit("foo", "add", 9, fn_index=0)
                job.result()

    def test_raises_exception_no_queue(self, sentiment_classification_demo):
        with pytest.raises(Exception):
            with connect(sentiment_classification_demo) as client:
                job = client.submit([5], api_name="/sleep")
                job.result()

    @pytest.mark.flaky
    def test_raises_exception(self):
        with pytest.raises(Exception):
            client = Client(src="freddyaboulton/calculator")
            job = client.submit("foo", "add", 9, fn_index=0)
            job.result()

    @pytest.mark.flaky
    def test_raises_exception_no_queue(self):
        with pytest.raises(Exception):
            client = Client(src="freddyaboulton/sentiment-classification")
            job = client.submit([5], api_name="/sleep")
            job.result()

    def test_job_output_video(self):
        client = Client(src="gradio/video_component")
        job = client.submit(
@@ -159,20 +171,9 @@
        )
        assert pathlib.Path(job.result()).exists()

    def test_progress_updates(self):
        def my_function(x, progress=gr.Progress()):
            progress(0, desc="Starting...")
            for i in progress.tqdm(range(20)):
                time.sleep(0.1)
            return x

    def test_progress_updates(self, progress_demo):
        demo = gr.Interface(my_function, gr.Textbox(), gr.Textbox()).queue(
            concurrency_count=20
        )
        _, local_url, _ = demo.launch(prevent_thread_lock=True)
        try:
            client = Client(src=local_url)
        with connect(progress_demo) as client:
            job = client.submit("hello", api_name="/predict")
            statuses = []
            while not job.done():
@@ -190,63 +191,60 @@
                )
                count += unit in all_progress_data
            assert count
        finally:
            demo.close()

    @pytest.mark.flaky
    def test_cancel_from_client_queued(self):
        client = Client(src="gradio-tests/test-cancel-from-client")
        start = time.time()
        job = client.submit(api_name="/long")
        while not job.done():
            if job.status().code == Status.STARTING:
                job.cancel()
                break
        with pytest.raises(CancelledError):
            job.result()
        # The whole prediction takes 10 seconds to run
        # and does not iterate. So this tests that we can cancel
        # halfway through a prediction
        assert time.time() - start < 10
        assert job.status().code == Status.CANCELLED

        job = client.submit(api_name="/iterate")
        iteration_count = 0
        while not job.done():
            if job.status().code == Status.ITERATING:
                iteration_count += 1
                if iteration_count == 3:
    def test_cancel_from_client_queued(self, cancel_from_client_demo):
        with connect(cancel_from_client_demo) as client:
            start = time.time()
            job = client.submit(api_name="/long")
            while not job.done():
                if job.status().code == Status.STARTING:
                    job.cancel()
                    break
                time.sleep(0.5)
            # Result for iterative jobs is always the first result
            assert job.result() == 0
            # The whole prediction takes 10 seconds to run
            # and does not iterate. So this tests that we can cancel
            # halfway through a prediction
            assert time.time() - start < 10
            with pytest.raises(CancelledError):
                job.result()
            # The whole prediction takes 10 seconds to run
            # and does not iterate. So this tests that we can cancel
            # halfway through a prediction
            assert time.time() - start < 10
            assert job.status().code == Status.CANCELLED
            # Test that we did not iterate all the way to the end
            assert all(o in [0, 1, 2, 3, 4, 5] for o in job.outputs())
            assert job.status().code == Status.CANCELLED

            job = client.submit(api_name="/iterate")
            iteration_count = 0
            while not job.done():
                if job.status().code == Status.ITERATING:
                    iteration_count += 1
                    if iteration_count == 3:
                        job.cancel()
                        break
                time.sleep(0.5)
            # Result for iterative jobs is always the first result
            assert job.result() == 0
            # The whole prediction takes 10 seconds to run
            # and does not iterate. So this tests that we can cancel
            # halfway through a prediction
            assert time.time() - start < 10
            # Test that we did not iterate all the way to the end
            assert all(o in [0, 1, 2, 3, 4, 5] for o in job.outputs())
            assert job.status().code == Status.CANCELLED
    def test_cancel_subsequent_jobs_state_reset(self, yield_demo):
        with connect(yield_demo) as client:
            job1 = client.submit("abcdefefadsadfs")
            time.sleep(3)
            job1.cancel()

            assert len(job1.outputs()) < len("abcdefefadsadfs")
            assert job1.status().code == Status.CANCELLED

            job2 = client.submit("abcd")
            while not job2.done():
                time.sleep(0.1)
            # Ran all iterations from scratch
            assert job2.status().code == Status.FINISHED
            assert len(job2.outputs()) == 5

    @pytest.mark.flaky
    def test_cancel_subsequent_jobs_state_reset(self):
        client = Client("abidlabs/test-yield")
        job1 = client.submit("abcdefefadsadfs")
        time.sleep(5)
        job1.cancel()

        assert len(job1.outputs()) < len("abcdefefadsadfs")
        assert job1.status().code == Status.CANCELLED

        job2 = client.submit("abcd")
        while not job2.done():
            time.sleep(0.1)
        # Ran all iterations from scratch
        assert job2.status().code == Status.FINISHED
        assert len(job2.outputs()) == 5

    def test_upload_file_private_space(self):
        client = Client(
@@ -289,6 +287,7 @@ class TestPredictionsFromSpaces:
            assert open(output[1]).read() == "File2"
            upload.assert_called_once()

    @pytest.mark.flaky
    def test_upload_file_upload_route_does_not_exist(self):
        client = Client(
            src="gradio-tests/not-actually-private-file-upload-old-version",
@@ -587,11 +586,14 @@ class TestAPIInfo:
        }

    @pytest.mark.flaky
    def test_serializable_in_mapping(self):
        client = Client("freddyaboulton/calculator")
        assert all(
            [c.__class__ == SimpleSerializable for c in client.endpoints[0].serializers]
        )

    def test_serializable_in_mapping(self, calculator_demo):
        with connect(calculator_demo) as client:
            assert all(
                [
                    isinstance(c, SimpleSerializable)
                    for c in client.endpoints[0].serializers
                ]
            )

    @pytest.mark.flaky
    def test_private_space(self):