Mirror of https://github.com/gradio-app/gradio.git (synced 2025-04-06 12:30:29 +08:00)
Refactor some python tests (#4834)
* Refactor
* Remove tests
* Fix image file
* trigger ci

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
commit 6436e4ea5b
parent cd551f70dd
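
The refactored tests below drive each launched demo through a `connect` pytest fixture and the `gradio_client` API instead of hand-rolled websocket sessions. The fixture itself lives in the test suite's conftest, not in this diff; as a rough sketch only, it presumably looks something like the following (the launch and teardown details here are assumptions for illustration):

# Minimal sketch of a `connect`-style fixture (assumed shape; not part of this diff).
from contextlib import contextmanager

import pytest

import gradio as gr
from gradio_client import Client


@pytest.fixture
def connect():
    @contextmanager
    def _connect(demo: gr.Blocks):
        # Launch the demo without blocking the test thread, talk to it
        # through the regular client API, then shut it down.
        _, local_url, _ = demo.launch(prevent_thread_lock=True)
        try:
            yield Client(local_url)
        finally:
            demo.close()

    return _connect
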
@@ -11,6 +11,7 @@ import time
 import unittest.mock as mock
 import uuid
 import warnings
+from concurrent.futures import wait
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
@@ -257,43 +258,17 @@ class TestBlocksMethods:
         assert block.css == css
 
     @pytest.mark.asyncio
-    async def test_restart_after_close(self):
+    async def test_restart_after_close(self, connect):
         io = gr.Interface(lambda s: s, gr.Textbox(), gr.Textbox()).queue()
-        io.launch(prevent_thread_lock=True)
 
-        async with websockets.connect(
-            f"{io.local_url.replace('http', 'ws')}queue/join"
-        ) as ws:
-            completed = False
-            while not completed:
-                msg = json.loads(await ws.recv())
-                if msg["msg"] == "send_data":
-                    await ws.send(json.dumps({"data": ["freddy"], "fn_index": 0}))
-                if msg["msg"] == "send_hash":
-                    await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
-                if msg["msg"] == "process_completed":
-                    completed = True
-            assert msg["output"]["data"][0] == "freddy"
-
-        io.close()
-        io.launch(prevent_thread_lock=True)
-
-        async with websockets.connect(
-            f"{io.local_url.replace('http', 'ws')}queue/join"
-        ) as ws:
-            completed = False
-            while not completed:
-                msg = json.loads(await ws.recv())
-                if msg["msg"] == "send_data":
-                    await ws.send(json.dumps({"data": ["Victor"], "fn_index": 0}))
-                if msg["msg"] == "send_hash":
-                    await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
-                if msg["msg"] == "process_completed":
-                    completed = True
-            assert msg["output"]["data"][0] == "Victor"
+        with connect(io) as client:
+            assert client.predict("freddy", api_name="/predict") == "freddy"
+        # connect launches the interface which is what we need to test
+        with connect(io) as client:
+            assert client.predict("Victor", api_name="/predict") == "Victor"
 
     @pytest.mark.asyncio
-    async def test_async_generators(self):
+    async def test_async_generators(self, connect):
         async def async_iteration(count: int):
             for i in range(count):
                 yield i
@@ -317,36 +292,15 @@ class TestBlocksMethods:
             iterate = gr.Button(value="Iterate")
             iterate.click(iteration, num2, o2)
 
-        demo.queue(concurrency_count=2).launch(prevent_thread_lock=True)
+        demo.queue(concurrency_count=2)
 
-        def _get_ws_pred(data, fn_index):
-            async def wrapped():
-                async with websockets.connect(
-                    f"{demo.local_url.replace('http', 'ws')}queue/join"
-                ) as ws:
-                    completed = False
-                    while not completed:
-                        msg = json.loads(await ws.recv())
-                        if msg["msg"] == "send_data":
-                            await ws.send(
-                                json.dumps({"data": [data], "fn_index": fn_index})
-                            )
-                        if msg["msg"] == "send_hash":
-                            await ws.send(
-                                json.dumps(
-                                    {"fn_index": fn_index, "session_hash": "shdce"}
-                                )
-                            )
-                        if msg["msg"] == "process_completed":
-                            completed = True
-                    assert msg["output"]["data"][0] == data - 1
+        with connect(demo) as client:
+            job_1 = client.submit(3, fn_index=0)
+            job_2 = client.submit(4, fn_index=1)
+            wait([job_1, job_2])
 
-            return wrapped
-
-        try:
-            await asyncio.gather(_get_ws_pred(3, 0)(), _get_ws_pred(4, 1)())
-        finally:
-            demo.close()
+        assert job_1.outputs()[-1] == 2
+        assert job_2.outputs()[-1] == 3
 
     def test_async_generators_interface(self, connect):
         async def async_iteration(count: int):
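
The `wait([job_1, job_2])` call above works because the `Job` objects returned by `Client.submit` behave like `concurrent.futures.Future` objects, and `job.outputs()` collects every value a generator endpoint has streamed so far. A self-contained sketch of that pattern, using a made-up toy demo rather than the demo built in the test:

# Standalone sketch of the submit / wait / outputs pattern (toy demo, illustrative only).
from concurrent.futures import wait

import gradio as gr
from gradio_client import Client


def count_up(n):
    # A generator endpoint: each yielded value becomes one streamed output.
    for i in range(int(n)):
        yield str(i)


demo = gr.Interface(count_up, gr.Number(), gr.Textbox()).queue()
_, local_url, _ = demo.launch(prevent_thread_lock=True)

client = Client(local_url)
job = client.submit(3, api_name="/predict")  # returns immediately with a Job
wait([job])                                  # Jobs are Future-like, so concurrent.futures.wait applies
assert job.outputs()[-1] == "2"              # outputs() holds every streamed value; [-1] is the final one
demo.close()
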
@@ -1373,7 +1327,7 @@ class TestCancel:
         demo.queue().launch(prevent_thread_lock=True)
 
     @pytest.mark.asyncio
-    async def test_cancel_button_for_interfaces(self):
+    async def test_cancel_button_for_interfaces(self, connect):
         def generate(x):
             for i in range(4):
                 yield i
@@ -1385,23 +1339,10 @@
         )
         assert not io.blocks[stop_btn_id].visible
 
-        io.launch(prevent_thread_lock=True)
-
-        async with websockets.connect(
-            f"{io.local_url.replace('http', 'ws')}queue/join"
-        ) as ws:
-            completed = False
-            while not completed:
-                msg = json.loads(await ws.recv())
-                if msg["msg"] == "send_data":
-                    await ws.send(json.dumps({"data": ["freddy"], "fn_index": 1}))
-                if msg["msg"] == "send_hash":
-                    await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
-                if msg["msg"] == "process_completed":
-                    assert msg["output"]["data"] == ["3"]
-                    completed = True
-
-        io.close()
+        with connect(io) as client:
+            job = client.submit("freddy", fn_index=1)
+            wait([job])
+            assert job.outputs()[-1] == "3"
 
 
 class TestEvery:
@@ -1453,7 +1394,7 @@ class TestEvery:
                     break
 
     @pytest.mark.asyncio
-    async def test_generating_event_cancelled_if_ws_closed(self, capsys):
+    async def test_generating_event_cancelled_if_ws_closed(self, connect, capsys):
         def generation():
             for i in range(10):
                 time.sleep(0.1)
@@ -1466,26 +1407,12 @@
             button = gr.Button(value="Greet")
             button.click(generation, None, greeting)
 
-        app, _, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True)
+        with connect(demo) as client:
+            job = client.submit(0, fn_index=0)
+            for i, _ in enumerate(job):
+                if i == 2:
+                    job.cancel()
 
-        async with websockets.connect(
-            f"{demo.local_url.replace('http', 'ws')}queue/join"
-        ) as ws:
-            completed = False
-            n_steps = 0
-            while not completed:
-                msg = json.loads(await ws.recv())
-                if msg["msg"] == "send_data":
-                    await ws.send(json.dumps({"data": [0], "fn_index": 0}))
-                elif msg["msg"] == "send_hash":
-                    await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
-                elif msg["msg"] == "process_generating":
-                    if n_steps == 2:
-                        # Close the websocket
-                        break
-                    n_steps += 1
-                else:
-                    continue
         await asyncio.sleep(1)
         # If the generation function did not get cancelled
         # it would have finished running and `At step 9` would
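
Instead of closing a raw websocket mid-stream, the refactored test above iterates the streaming `Job` and calls `job.cancel()` after a couple of partial outputs. Roughly, outside the fixture, that pattern looks like the sketch below (the demo here is a made-up stand-in for the one in the test, and the placeholder `0` argument simply mirrors what the test passes):

# Sketch of cancelling a streaming job part-way through (illustrative demo, not the test's).
import time

import gradio as gr
from gradio_client import Client


def generation():
    for i in range(10):
        time.sleep(0.1)
        yield str(i)


with gr.Blocks() as demo:
    greeting = gr.Textbox()
    button = gr.Button(value="Greet")
    button.click(generation, None, greeting)

demo.queue(max_size=1)
_, local_url, _ = demo.launch(prevent_thread_lock=True)

client = Client(local_url)
job = client.submit(0, fn_index=0)   # the click event declares no inputs; the test passes a placeholder 0 too
for i, _ in enumerate(job):          # partial outputs arrive as the generator yields
    if i == 2:
        job.cancel()                 # ask the server to stop the running generator
        break

demo.close()
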
@@ -709,7 +709,7 @@ class TestImage:
 
     @pytest.mark.flaky
     def test_serialize_url(self):
-        img = "https://gradio.app/assets/img/header-image.jpg"
+        img = "https://gradio-builds.s3.amazonaws.com/demo-files/cheetah-002.jpg"
         expected = client_utils.encode_url_or_file_to_base64(img)
         assert gr.Image().serialize(img) == expected
 
@@ -183,7 +183,7 @@ class TestLoadInterface:
 
     def test_english_to_spanish(self):
         with pytest.warns(UserWarning):
-            io = gr.load("spaces/abidlabs/english_to_spanish", title="hi")
+            io = gr.load("spaces/gradio-tests/english_to_spanish", title="hi")
         assert isinstance(io.input_components[0], gr.Textbox)
         assert isinstance(io.output_components[0], gr.Textbox)
 
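
The remaining hunks only repoint `gr.load` at Spaces hosted under the gradio-tests account. For context, `gr.load("spaces/<owner>/<name>")` recreates a hosted Space as a local Interface/Blocks object that can be called directly, which is the behaviour these tests exercise. A minimal example, assuming network access and using the Space name from the updated test:

# Minimal sketch of the gr.load pattern exercised by these tests.
import gradio as gr

io = gr.load("spaces/gradio-tests/english_to_spanish")
# A loaded Space is callable directly, without launching a local server.
print(io("Hello"))
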
@@ -212,7 +212,7 @@ class TestLoadInterface:
             pass
 
     def test_numerical_to_label_space(self):
-        io = gr.load("spaces/abidlabs/titanic-survival")
+        io = gr.load("spaces/gradio-tests/titanic-survival")
         try:
             assert io.theme.name == "soft"
             with open(io("male", 77, 10)) as f:
@@ -366,11 +366,11 @@ class TestLoadInterfaceWithExamples:
 
     def test_interface_with_examples(self):
         # This demo has the "fake_event" correctly removed
-        demo = gr.load("spaces/freddyaboulton/calculator")
+        demo = gr.load("spaces/gradio-tests/test-calculator-1")
         assert demo(2, "add", 3) == 5
 
         # This demo still has the "fake_event". both should work
-        demo = gr.load("spaces/abidlabs/test-calculator-2")
+        demo = gr.load("spaces/gradio-tests/test-calculator-2")
         assert demo(2, "add", 4) == 6
 
 
@@ -441,13 +441,13 @@ def check_dataset(config, readme_examples):
 
 
 def test_load_blocks_with_default_values():
-    io = gr.load("spaces/abidlabs/min-dalle")
+    io = gr.load("spaces/gradio-tests/min-dalle")
     assert isinstance(io.get_config_file()["components"][0]["props"]["value"], list)
 
-    io = gr.load("spaces/abidlabs/min-dalle-later")
+    io = gr.load("spaces/gradio-tests/min-dalle-later")
     assert isinstance(io.get_config_file()["components"][0]["props"]["value"], list)
 
-    io = gr.load("spaces/freddyaboulton/dataframe_load")
+    io = gr.load("spaces/gradio-tests/dataframe_load")
     assert io.get_config_file()["components"][0]["props"]["value"] == {
         "headers": ["a", "b"],
         "data": [[1, 4], [2, 5], [3, 6]],
@@ -24,8 +24,8 @@ class TestSeries:
 
     @pytest.mark.flaky
     def test_with_external(self):
-        io1 = gr.load("spaces/abidlabs/image-identity")
-        io2 = gr.load("spaces/abidlabs/image-classifier")
+        io1 = gr.load("spaces/gradio-tests/image-identity")
+        io2 = gr.load("spaces/gradio-tests/image-classifier")
         series = mix.Series(io1, io2)
         try:
             with open(series("gradio/test_data/lion.jpg")) as f:
@@ -55,8 +55,8 @@ class TestParallel:
 
     @pytest.mark.flaky
    def test_with_external(self):
-        io1 = gr.load("spaces/abidlabs/english_to_spanish")
-        io2 = gr.load("spaces/abidlabs/english2german")
+        io1 = gr.load("spaces/gradio-tests/english_to_spanish")
+        io2 = gr.load("spaces/gradio-tests/english2german")
         parallel = mix.Parallel(io1, io2)
         try:
             hello_es, hello_de = parallel("Hello")