Refactor some python tests (#4834)

* Refactor

* Remove tests

* Fix image file

* trigger ci

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-07-10 12:40:46 -05:00 committed by GitHub
parent cd551f70dd
commit 6436e4ea5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 111 deletions

View File

@ -11,6 +11,7 @@ import time
import unittest.mock as mock
import uuid
import warnings
from concurrent.futures import wait
from contextlib import contextmanager
from functools import partial
from pathlib import Path
@ -257,43 +258,17 @@ class TestBlocksMethods:
assert block.css == css
@pytest.mark.asyncio
async def test_restart_after_close(self):
async def test_restart_after_close(self, connect):
io = gr.Interface(lambda s: s, gr.Textbox(), gr.Textbox()).queue()
io.launch(prevent_thread_lock=True)
async with websockets.connect(
f"{io.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": ["freddy"], "fn_index": 0}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
if msg["msg"] == "process_completed":
completed = True
assert msg["output"]["data"][0] == "freddy"
io.close()
io.launch(prevent_thread_lock=True)
async with websockets.connect(
f"{io.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": ["Victor"], "fn_index": 0}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
if msg["msg"] == "process_completed":
completed = True
assert msg["output"]["data"][0] == "Victor"
with connect(io) as client:
assert client.predict("freddy", api_name="/predict") == "freddy"
# connect launches the interface which is what we need to test
with connect(io) as client:
assert client.predict("Victor", api_name="/predict") == "Victor"
@pytest.mark.asyncio
async def test_async_generators(self):
async def test_async_generators(self, connect):
async def async_iteration(count: int):
for i in range(count):
yield i
@ -317,36 +292,15 @@ class TestBlocksMethods:
iterate = gr.Button(value="Iterate")
iterate.click(iteration, num2, o2)
demo.queue(concurrency_count=2).launch(prevent_thread_lock=True)
demo.queue(concurrency_count=2)
def _get_ws_pred(data, fn_index):
async def wrapped():
async with websockets.connect(
f"{demo.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(
json.dumps({"data": [data], "fn_index": fn_index})
)
if msg["msg"] == "send_hash":
await ws.send(
json.dumps(
{"fn_index": fn_index, "session_hash": "shdce"}
)
)
if msg["msg"] == "process_completed":
completed = True
assert msg["output"]["data"][0] == data - 1
with connect(demo) as client:
job_1 = client.submit(3, fn_index=0)
job_2 = client.submit(4, fn_index=1)
wait([job_1, job_2])
return wrapped
try:
await asyncio.gather(_get_ws_pred(3, 0)(), _get_ws_pred(4, 1)())
finally:
demo.close()
assert job_1.outputs()[-1] == 2
assert job_2.outputs()[-1] == 3
def test_async_generators_interface(self, connect):
async def async_iteration(count: int):
@ -1373,7 +1327,7 @@ class TestCancel:
demo.queue().launch(prevent_thread_lock=True)
@pytest.mark.asyncio
async def test_cancel_button_for_interfaces(self):
async def test_cancel_button_for_interfaces(self, connect):
def generate(x):
for i in range(4):
yield i
@ -1385,23 +1339,10 @@ class TestCancel:
)
assert not io.blocks[stop_btn_id].visible
io.launch(prevent_thread_lock=True)
async with websockets.connect(
f"{io.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": ["freddy"], "fn_index": 1}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
if msg["msg"] == "process_completed":
assert msg["output"]["data"] == ["3"]
completed = True
io.close()
with connect(io) as client:
job = client.submit("freddy", fn_index=1)
wait([job])
assert job.outputs()[-1] == "3"
class TestEvery:
@ -1453,7 +1394,7 @@ class TestEvery:
break
@pytest.mark.asyncio
async def test_generating_event_cancelled_if_ws_closed(self, capsys):
async def test_generating_event_cancelled_if_ws_closed(self, connect, capsys):
def generation():
for i in range(10):
time.sleep(0.1)
@ -1466,26 +1407,12 @@ class TestEvery:
button = gr.Button(value="Greet")
button.click(generation, None, greeting)
app, _, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True)
with connect(demo) as client:
job = client.submit(0, fn_index=0)
for i, _ in enumerate(job):
if i == 2:
job.cancel()
async with websockets.connect(
f"{demo.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
n_steps = 0
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": [0], "fn_index": 0}))
elif msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
elif msg["msg"] == "process_generating":
if n_steps == 2:
# Close the websocket
break
n_steps += 1
else:
continue
await asyncio.sleep(1)
# If the generation function did not get cancelled
# it would have finished running and `At step 9` would

View File

@ -709,7 +709,7 @@ class TestImage:
@pytest.mark.flaky
def test_serialize_url(self):
img = "https://gradio.app/assets/img/header-image.jpg"
img = "https://gradio-builds.s3.amazonaws.com/demo-files/cheetah-002.jpg"
expected = client_utils.encode_url_or_file_to_base64(img)
assert gr.Image().serialize(img) == expected

View File

@ -183,7 +183,7 @@ class TestLoadInterface:
def test_english_to_spanish(self):
with pytest.warns(UserWarning):
io = gr.load("spaces/abidlabs/english_to_spanish", title="hi")
io = gr.load("spaces/gradio-tests/english_to_spanish", title="hi")
assert isinstance(io.input_components[0], gr.Textbox)
assert isinstance(io.output_components[0], gr.Textbox)
@ -212,7 +212,7 @@ class TestLoadInterface:
pass
def test_numerical_to_label_space(self):
io = gr.load("spaces/abidlabs/titanic-survival")
io = gr.load("spaces/gradio-tests/titanic-survival")
try:
assert io.theme.name == "soft"
with open(io("male", 77, 10)) as f:
@ -366,11 +366,11 @@ class TestLoadInterfaceWithExamples:
def test_interface_with_examples(self):
# This demo has the "fake_event" correctly removed
demo = gr.load("spaces/freddyaboulton/calculator")
demo = gr.load("spaces/gradio-tests/test-calculator-1")
assert demo(2, "add", 3) == 5
# This demo still has the "fake_event". both should work
demo = gr.load("spaces/abidlabs/test-calculator-2")
demo = gr.load("spaces/gradio-tests/test-calculator-2")
assert demo(2, "add", 4) == 6
@ -441,13 +441,13 @@ def check_dataset(config, readme_examples):
def test_load_blocks_with_default_values():
io = gr.load("spaces/abidlabs/min-dalle")
io = gr.load("spaces/gradio-tests/min-dalle")
assert isinstance(io.get_config_file()["components"][0]["props"]["value"], list)
io = gr.load("spaces/abidlabs/min-dalle-later")
io = gr.load("spaces/gradio-tests/min-dalle-later")
assert isinstance(io.get_config_file()["components"][0]["props"]["value"], list)
io = gr.load("spaces/freddyaboulton/dataframe_load")
io = gr.load("spaces/gradio-tests/dataframe_load")
assert io.get_config_file()["components"][0]["props"]["value"] == {
"headers": ["a", "b"],
"data": [[1, 4], [2, 5], [3, 6]],

View File

@ -24,8 +24,8 @@ class TestSeries:
@pytest.mark.flaky
def test_with_external(self):
io1 = gr.load("spaces/abidlabs/image-identity")
io2 = gr.load("spaces/abidlabs/image-classifier")
io1 = gr.load("spaces/gradio-tests/image-identity")
io2 = gr.load("spaces/gradio-tests/image-classifier")
series = mix.Series(io1, io2)
try:
with open(series("gradio/test_data/lion.jpg")) as f:
@ -55,8 +55,8 @@ class TestParallel:
@pytest.mark.flaky
def test_with_external(self):
io1 = gr.load("spaces/abidlabs/english_to_spanish")
io2 = gr.load("spaces/abidlabs/english2german")
io1 = gr.load("spaces/gradio-tests/english_to_spanish")
io2 = gr.load("spaces/gradio-tests/english2german")
parallel = mix.Parallel(io1, io2)
try:
hello_es, hello_de = parallel("Hello")