mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Convert async methods in the Examples class into normal sync methods (#5822)
* Convert async methods in the Examples class into normal sync methods * add changeset * Fix test/test_chat_interface.py * Fix test/test_helpers.py * add changeset --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
1aa186220d
commit
7b63db2716
5
.changeset/evil-berries-teach.md
Normal file
5
.changeset/evil-berries-teach.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Convert async methods in the Examples class into normal sync methods
|
@ -369,10 +369,10 @@ class Examples:
|
||||
Context.root_block.dependencies.pop(index)
|
||||
Context.root_block.fns.pop(index)
|
||||
|
||||
async def load_example(example_id):
|
||||
def load_example(example_id):
|
||||
processed_example = self.non_none_processed_examples[
|
||||
example_id
|
||||
] + await self.load_from_cache(example_id)
|
||||
] + self.load_from_cache(example_id)
|
||||
return utils.resolve_singleton(processed_example)
|
||||
|
||||
self.load_input_event = self.dataset.click(
|
||||
@ -385,7 +385,7 @@ class Examples:
|
||||
api_name=self.api_name, # type: ignore
|
||||
)
|
||||
|
||||
async def load_from_cache(self, example_id: int) -> list[Any]:
|
||||
def load_from_cache(self, example_id: int) -> list[Any]:
|
||||
"""Loads a particular cached example for the interface.
|
||||
Parameters:
|
||||
example_id: The id of the example to process (zero-indexed).
|
||||
|
@ -72,52 +72,47 @@ class TestInit:
|
||||
None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_example_caching(self, monkeypatch):
|
||||
def test_example_caching(self, monkeypatch):
|
||||
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
|
||||
chatbot = gr.ChatInterface(
|
||||
double, examples=["hello", "hi"], cache_examples=True
|
||||
)
|
||||
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
|
||||
prediction_hello = chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = chatbot.examples_handler.load_from_cache(1)
|
||||
assert prediction_hello[0][0] == ["hello", "hello hello"]
|
||||
assert prediction_hi[0][0] == ["hi", "hi hi"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_example_caching_async(self, monkeypatch):
|
||||
def test_example_caching_async(self, monkeypatch):
|
||||
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
|
||||
chatbot = gr.ChatInterface(
|
||||
async_greet, examples=["abubakar", "tom"], cache_examples=True
|
||||
)
|
||||
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
|
||||
prediction_hello = chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = chatbot.examples_handler.load_from_cache(1)
|
||||
assert prediction_hello[0][0] == ["abubakar", "hi, abubakar"]
|
||||
assert prediction_hi[0][0] == ["tom", "hi, tom"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_example_caching_with_streaming(self, monkeypatch):
|
||||
def test_example_caching_with_streaming(self, monkeypatch):
|
||||
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
|
||||
chatbot = gr.ChatInterface(
|
||||
stream, examples=["hello", "hi"], cache_examples=True
|
||||
)
|
||||
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
|
||||
prediction_hello = chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = chatbot.examples_handler.load_from_cache(1)
|
||||
assert prediction_hello[0][0] == ["hello", "hello"]
|
||||
assert prediction_hi[0][0] == ["hi", "hi"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_example_caching_with_streaming_async(self, monkeypatch):
|
||||
def test_example_caching_with_streaming_async(self, monkeypatch):
|
||||
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
|
||||
chatbot = gr.ChatInterface(
|
||||
async_stream, examples=["hello", "hi"], cache_examples=True
|
||||
)
|
||||
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
|
||||
prediction_hello = chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = chatbot.examples_handler.load_from_cache(1)
|
||||
assert prediction_hello[0][0] == ["hello", "hello"]
|
||||
assert prediction_hi[0][0] == ["hi", "hi"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_example_caching_with_additional_inputs(self, monkeypatch):
|
||||
def test_example_caching_with_additional_inputs(self, monkeypatch):
|
||||
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
|
||||
chatbot = gr.ChatInterface(
|
||||
echo_system_prompt_plus_message,
|
||||
@ -125,15 +120,12 @@ class TestInit:
|
||||
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
|
||||
prediction_hello = chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = chatbot.examples_handler.load_from_cache(1)
|
||||
assert prediction_hello[0][0] == ["hello", "robot hello"]
|
||||
assert prediction_hi[0][0] == ["hi", "ro"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_example_caching_with_additional_inputs_already_rendered(
|
||||
self, monkeypatch
|
||||
):
|
||||
def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
|
||||
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
|
||||
with gr.Blocks():
|
||||
with gr.Accordion("Inputs"):
|
||||
@ -145,8 +137,8 @@ class TestInit:
|
||||
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
|
||||
prediction_hello = chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = chatbot.examples_handler.load_from_cache(1)
|
||||
assert prediction_hello[0][0] == ["hello", "robot hello"]
|
||||
assert prediction_hi[0][0] == ["hi", "ro"]
|
||||
|
||||
|
@ -58,8 +58,7 @@ class TestExamples:
|
||||
examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
|
||||
assert examples.dataset.get_config()["samples_per_page"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_preprocessing(self):
|
||||
def test_no_preprocessing(self):
|
||||
with gr.Blocks():
|
||||
image = gr.Image()
|
||||
textbox = gr.Textbox()
|
||||
@ -73,11 +72,10 @@ class TestExamples:
|
||||
preprocess=False,
|
||||
)
|
||||
|
||||
prediction = await examples.load_from_cache(0)
|
||||
prediction = examples.load_from_cache(0)
|
||||
assert prediction == [media_data.BASE64_IMAGE]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_postprocessing(self):
|
||||
def test_no_postprocessing(self):
|
||||
def im(x):
|
||||
return [media_data.BASE64_IMAGE]
|
||||
|
||||
@ -94,7 +92,7 @@ class TestExamples:
|
||||
postprocess=False,
|
||||
)
|
||||
|
||||
prediction = await examples.load_from_cache(0)
|
||||
prediction = examples.load_from_cache(0)
|
||||
file = prediction[0][0][0]["name"]
|
||||
assert utils.encode_url_or_file_to_base64(file) == media_data.BASE64_IMAGE
|
||||
|
||||
@ -158,8 +156,7 @@ def test_example_caching_relaunch(connect):
|
||||
|
||||
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
|
||||
class TestProcessExamples:
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching(self):
|
||||
def test_caching(self):
|
||||
io = gr.Interface(
|
||||
lambda x: f"Hello {x}",
|
||||
"text",
|
||||
@ -167,7 +164,7 @@ class TestProcessExamples:
|
||||
examples=[["World"], ["Dunya"], ["Monde"]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(1)
|
||||
prediction = io.examples_handler.load_from_cache(1)
|
||||
assert prediction[0] == "Hello Dunya"
|
||||
|
||||
def test_example_caching_relaunch(self, connect):
|
||||
@ -203,8 +200,7 @@ class TestProcessExamples:
|
||||
"hello Eve",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_image(self):
|
||||
def test_caching_image(self):
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
"image",
|
||||
@ -212,11 +208,10 @@ class TestProcessExamples:
|
||||
examples=[["test/test_files/bus.png"]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
assert prediction[0].startswith("")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_audio(self):
|
||||
def test_caching_audio(self):
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
"audio",
|
||||
@ -224,14 +219,13 @@ class TestProcessExamples:
|
||||
examples=[["test/test_files/audio_sample.wav"]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
file = prediction[0]["name"]
|
||||
assert utils.encode_url_or_file_to_base64(file).startswith(
|
||||
"data:audio/wav;base64,UklGRgA/"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_update(self):
|
||||
def test_caching_with_update(self):
|
||||
io = gr.Interface(
|
||||
lambda x: gr.update(visible=False),
|
||||
"text",
|
||||
@ -239,14 +233,13 @@ class TestProcessExamples:
|
||||
examples=[["World"], ["Dunya"], ["Monde"]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(1)
|
||||
prediction = io.examples_handler.load_from_cache(1)
|
||||
assert prediction[0] == {
|
||||
"visible": False,
|
||||
"__type__": "update",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_mix_update(self):
|
||||
def test_caching_with_mix_update(self):
|
||||
io = gr.Interface(
|
||||
lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
|
||||
"text",
|
||||
@ -254,15 +247,14 @@ class TestProcessExamples:
|
||||
examples=[["World"], ["Dunya"], ["Monde"]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(1)
|
||||
prediction = io.examples_handler.load_from_cache(1)
|
||||
assert prediction[0] == {
|
||||
"lines": 4,
|
||||
"value": "hello",
|
||||
"__type__": "update",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_dict(self):
|
||||
def test_caching_with_dict(self):
|
||||
text = gr.Textbox()
|
||||
out = gr.Label()
|
||||
|
||||
@ -273,14 +265,13 @@ class TestProcessExamples:
|
||||
examples=["abc"],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
assert prediction == [
|
||||
{"lines": 4, "__type__": "update", "mode": "static"},
|
||||
{"label": "lion"},
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_generators(self):
|
||||
def test_caching_with_generators(self):
|
||||
def test_generator(x):
|
||||
for y in range(len(x)):
|
||||
yield "Your output: " + x[: y + 1]
|
||||
@ -292,11 +283,10 @@ class TestProcessExamples:
|
||||
examples=["abcdef"],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
assert prediction[0] == "Your output: abcdef"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_generators_and_streamed_output(self):
|
||||
def test_caching_with_generators_and_streamed_output(self):
|
||||
file_dir = Path(Path(__file__).parent, "test_files")
|
||||
audio = str(file_dir / "audio_sample.wav")
|
||||
|
||||
@ -311,15 +301,14 @@ class TestProcessExamples:
|
||||
examples=[3],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
len_input_audio = len(AudioSegment.from_wav(audio))
|
||||
len_output_audio = len(AudioSegment.from_wav(prediction[0]["name"]))
|
||||
length_ratio = len_output_audio / len_input_audio
|
||||
assert round(length_ratio, 1) == 3.0 # might not be exactly 3x
|
||||
assert float(prediction[1]) == 10.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_async_generators(self):
|
||||
def test_caching_with_async_generators(self):
|
||||
async def test_generator(x):
|
||||
for y in range(len(x)):
|
||||
yield "Your output: " + x[: y + 1]
|
||||
@ -331,7 +320,7 @@ class TestProcessExamples:
|
||||
examples=["abcdef"],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
assert prediction[0] == "Your output: abcdef"
|
||||
|
||||
def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
|
||||
@ -391,8 +380,7 @@ class TestProcessExamples:
|
||||
cache_examples=True,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_batch(self):
|
||||
def test_caching_with_batch(self):
|
||||
def trim_words(words, lens):
|
||||
trimmed_words = [word[:length] for word, length in zip(words, lens)]
|
||||
return [trimmed_words]
|
||||
@ -406,11 +394,10 @@ class TestProcessExamples:
|
||||
examples=[["hello", 3], ["hi", 4]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
assert prediction == ["hel"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_batch_multiple_outputs(self):
|
||||
def test_caching_with_batch_multiple_outputs(self):
|
||||
def trim_words(words, lens):
|
||||
trimmed_words = [word[:length] for word, length in zip(words, lens)]
|
||||
return trimmed_words, lens
|
||||
@ -424,11 +411,10 @@ class TestProcessExamples:
|
||||
examples=[["hello", 3], ["hi", 4]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
assert prediction == ["hel", "3"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_non_io_component(self):
|
||||
def test_caching_with_non_io_component(self):
|
||||
def predict(name):
|
||||
return name, gr.update(visible=True)
|
||||
|
||||
@ -445,7 +431,7 @@ class TestProcessExamples:
|
||||
cache_examples=True,
|
||||
)
|
||||
|
||||
prediction = await examples.load_from_cache(0)
|
||||
prediction = examples.load_from_cache(0)
|
||||
assert prediction == ["John", {"visible": True, "__type__": "update"}]
|
||||
|
||||
def test_end_to_end(self):
|
||||
@ -500,8 +486,7 @@ class TestProcessExamples:
|
||||
assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_file_flagging(tmp_path):
|
||||
def test_multiple_file_flagging(tmp_path):
|
||||
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
|
||||
io = gr.Interface(
|
||||
fn=lambda *x: list(x),
|
||||
@ -513,14 +498,13 @@ async def test_multiple_file_flagging(tmp_path):
|
||||
examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
|
||||
assert len(prediction[0]) == 2
|
||||
assert all(isinstance(d, dict) for d in prediction[0])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_examples_keep_all_suffixes(tmp_path):
|
||||
def test_examples_keep_all_suffixes(tmp_path):
|
||||
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
|
||||
file_1 = tmp_path / "foo.bar.txt"
|
||||
file_1.write_text("file 1")
|
||||
@ -535,10 +519,10 @@ async def test_examples_keep_all_suffixes(tmp_path):
|
||||
examples=[[str(file_1)], [str(file_2)]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
prediction = io.examples_handler.load_from_cache(0)
|
||||
assert Path(prediction[0]["name"]).read_text() == "file 1"
|
||||
assert prediction[0]["orig_name"] == "foo.bar.txt"
|
||||
prediction = await io.examples_handler.load_from_cache(1)
|
||||
prediction = io.examples_handler.load_from_cache(1)
|
||||
assert Path(prediction[0]["name"]).read_text() == "file 2"
|
||||
assert prediction[0]["orig_name"] == "foo.bar.txt"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user