Convert async methods in the Examples class into normal sync methods (#5822)

* Convert async methods in the Examples class into normal sync methods

* add changeset

* Fix test/test_chat_interface.py

* Fix test/test_helpers.py

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Yuichiro Tachibana (Tsuchiya) · 2023-10-07 03:05:11 +09:00 · committed by GitHub
parent 1aa186220d
commit 7b63db2716
4 changed files with 59 additions and 78 deletions
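The practical effect is caller-facing: `Examples.load_from_cache` can now be called from ordinary synchronous code, with no event loop and no `await`. A minimal before/after sketch, condensed from the test changes below (the interface definition here is illustrative):

```python
import gradio as gr

io = gr.Interface(
    lambda x: f"Hello {x}",
    "text",
    "text",
    examples=[["World"], ["Dunya"], ["Monde"]],
    cache_examples=True,
)

# Before this commit, load_from_cache was a coroutine and had to be awaited:
#     prediction = await io.examples_handler.load_from_cache(1)
# After this commit, it is a plain method call:
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == "Hello Dunya"
```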

File: .changeset/*.md (new changeset entry)

@@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Convert async methods in the Examples class into normal sync methods

File: gradio/helpers.py

@@ -369,10 +369,10 @@ class Examples:
Context.root_block.dependencies.pop(index)
Context.root_block.fns.pop(index)
async def load_example(example_id):
def load_example(example_id):
processed_example = self.non_none_processed_examples[
example_id
] + await self.load_from_cache(example_id)
] + self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)
self.load_input_event = self.dataset.click(
@@ -385,7 +385,7 @@ class Examples:
api_name=self.api_name, # type: ignore
)
async def load_from_cache(self, example_id: int) -> list[Any]:
def load_from_cache(self, example_id: int) -> list[Any]:
"""Loads a particular cached example for the interface.
Parameters:
example_id: The id of the example to process (zero-indexed).
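The hunk above only shows the signature changing from `async def` to `def`; the body is elided. For readers wondering how a synchronous method can still serve caches produced by async prediction functions, the usual pattern is to resolve any awaitable internally. This is a generic sketch under that assumption, not the actual `load_from_cache` implementation:

```python
import asyncio
import inspect
from typing import Any, Callable

def call_possibly_async(fn: Callable, *args) -> Any:
    """Illustrative helper: call `fn` and return a concrete value even if
    `fn` is a coroutine function. Assumes no event loop is already running
    in the calling thread; the real code may use a more robust bridge."""
    result = fn(*args)
    if inspect.isawaitable(result):
        result = asyncio.run(result)  # drive the coroutine to completion
    return result

async def async_upper(text: str) -> str:
    return text.upper()

# A synchronous caller gets an ordinary string back, no `await` needed.
assert call_possibly_async(async_upper, "hello") == "HELLO"
```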

File: test/test_chat_interface.py

@@ -72,52 +72,47 @@ class TestInit:
None,
)
@pytest.mark.asyncio
async def test_example_caching(self, monkeypatch):
def test_example_caching(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
double, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "hello hello"]
assert prediction_hi[0][0] == ["hi", "hi hi"]
@pytest.mark.asyncio
async def test_example_caching_async(self, monkeypatch):
def test_example_caching_async(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
async_greet, examples=["abubakar", "tom"], cache_examples=True
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["abubakar", "hi, abubakar"]
assert prediction_hi[0][0] == ["tom", "hi, tom"]
@pytest.mark.asyncio
async def test_example_caching_with_streaming(self, monkeypatch):
def test_example_caching_with_streaming(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "hello"]
assert prediction_hi[0][0] == ["hi", "hi"]
@pytest.mark.asyncio
async def test_example_caching_with_streaming_async(self, monkeypatch):
def test_example_caching_with_streaming_async(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
async_stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "hello"]
assert prediction_hi[0][0] == ["hi", "hi"]
@pytest.mark.asyncio
async def test_example_caching_with_additional_inputs(self, monkeypatch):
def test_example_caching_with_additional_inputs(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
@@ -125,15 +120,12 @@ class TestInit:
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "robot hello"]
assert prediction_hi[0][0] == ["hi", "ro"]
@pytest.mark.asyncio
async def test_example_caching_with_additional_inputs_already_rendered(
self, monkeypatch
):
def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
with gr.Blocks():
with gr.Accordion("Inputs"):
@@ -145,8 +137,8 @@ class TestInit:
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "robot hello"]
assert prediction_hi[0][0] == ["hi", "ro"]

File: test/test_helpers.py

@@ -58,8 +58,7 @@ class TestExamples:
examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
assert examples.dataset.get_config()["samples_per_page"] == 2
@pytest.mark.asyncio
async def test_no_preprocessing(self):
def test_no_preprocessing(self):
with gr.Blocks():
image = gr.Image()
textbox = gr.Textbox()
@@ -73,11 +72,10 @@ class TestExamples:
preprocess=False,
)
prediction = await examples.load_from_cache(0)
prediction = examples.load_from_cache(0)
assert prediction == [media_data.BASE64_IMAGE]
@pytest.mark.asyncio
async def test_no_postprocessing(self):
def test_no_postprocessing(self):
def im(x):
return [media_data.BASE64_IMAGE]
@@ -94,7 +92,7 @@ class TestExamples:
postprocess=False,
)
prediction = await examples.load_from_cache(0)
prediction = examples.load_from_cache(0)
file = prediction[0][0][0]["name"]
assert utils.encode_url_or_file_to_base64(file) == media_data.BASE64_IMAGE
@@ -158,8 +156,7 @@ def test_example_caching_relaunch(connect):
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
class TestProcessExamples:
@pytest.mark.asyncio
async def test_caching(self):
def test_caching(self):
io = gr.Interface(
lambda x: f"Hello {x}",
"text",
@@ -167,7 +164,7 @@ class TestProcessExamples:
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(1)
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == "Hello Dunya"
def test_example_caching_relaunch(self, connect):
@@ -203,8 +200,7 @@ class TestProcessExamples:
"hello Eve",
)
@pytest.mark.asyncio
async def test_caching_image(self):
def test_caching_image(self):
io = gr.Interface(
lambda x: x,
"image",
@@ -212,11 +208,10 @@ class TestProcessExamples:
examples=[["test/test_files/bus.png"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0].startswith("data:image/png;base64,iVBORw0KGgoAAA")
@pytest.mark.asyncio
async def test_caching_audio(self):
def test_caching_audio(self):
io = gr.Interface(
lambda x: x,
"audio",
@@ -224,14 +219,13 @@ class TestProcessExamples:
examples=[["test/test_files/audio_sample.wav"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
file = prediction[0]["name"]
assert utils.encode_url_or_file_to_base64(file).startswith(
"data:audio/wav;base64,UklGRgA/"
)
@pytest.mark.asyncio
async def test_caching_with_update(self):
def test_caching_with_update(self):
io = gr.Interface(
lambda x: gr.update(visible=False),
"text",
@@ -239,14 +233,13 @@ class TestProcessExamples:
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(1)
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == {
"visible": False,
"__type__": "update",
}
@pytest.mark.asyncio
async def test_caching_with_mix_update(self):
def test_caching_with_mix_update(self):
io = gr.Interface(
lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
"text",
@@ -254,15 +247,14 @@ class TestProcessExamples:
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(1)
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == {
"lines": 4,
"value": "hello",
"__type__": "update",
}
@pytest.mark.asyncio
async def test_caching_with_dict(self):
def test_caching_with_dict(self):
text = gr.Textbox()
out = gr.Label()
@@ -273,14 +265,13 @@ class TestProcessExamples:
examples=["abc"],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction == [
{"lines": 4, "__type__": "update", "mode": "static"},
{"label": "lion"},
]
@pytest.mark.asyncio
async def test_caching_with_generators(self):
def test_caching_with_generators(self):
def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
@@ -292,11 +283,10 @@ class TestProcessExamples:
examples=["abcdef"],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"
@pytest.mark.asyncio
async def test_caching_with_generators_and_streamed_output(self):
def test_caching_with_generators_and_streamed_output(self):
file_dir = Path(Path(__file__).parent, "test_files")
audio = str(file_dir / "audio_sample.wav")
@@ -311,15 +301,14 @@ class TestProcessExamples:
examples=[3],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
len_input_audio = len(AudioSegment.from_wav(audio))
len_output_audio = len(AudioSegment.from_wav(prediction[0]["name"]))
length_ratio = len_output_audio / len_input_audio
assert round(length_ratio, 1) == 3.0 # might not be exactly 3x
assert float(prediction[1]) == 10.0
@pytest.mark.asyncio
async def test_caching_with_async_generators(self):
def test_caching_with_async_generators(self):
async def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
@@ -331,7 +320,7 @@ class TestProcessExamples:
examples=["abcdef"],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"
def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
@@ -391,8 +380,7 @@ class TestProcessExamples:
cache_examples=True,
)
@pytest.mark.asyncio
async def test_caching_with_batch(self):
def test_caching_with_batch(self):
def trim_words(words, lens):
trimmed_words = [word[:length] for word, length in zip(words, lens)]
return [trimmed_words]
@@ -406,11 +394,10 @@ class TestProcessExamples:
examples=[["hello", 3], ["hi", 4]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction == ["hel"]
@pytest.mark.asyncio
async def test_caching_with_batch_multiple_outputs(self):
def test_caching_with_batch_multiple_outputs(self):
def trim_words(words, lens):
trimmed_words = [word[:length] for word, length in zip(words, lens)]
return trimmed_words, lens
@@ -424,11 +411,10 @@ class TestProcessExamples:
examples=[["hello", 3], ["hi", 4]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction == ["hel", "3"]
@pytest.mark.asyncio
async def test_caching_with_non_io_component(self):
def test_caching_with_non_io_component(self):
def predict(name):
return name, gr.update(visible=True)
@@ -445,7 +431,7 @@ class TestProcessExamples:
cache_examples=True,
)
prediction = await examples.load_from_cache(0)
prediction = examples.load_from_cache(0)
assert prediction == ["John", {"visible": True, "__type__": "update"}]
def test_end_to_end(self):
@@ -500,8 +486,7 @@ class TestProcessExamples:
assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]
@pytest.mark.asyncio
async def test_multiple_file_flagging(tmp_path):
def test_multiple_file_flagging(tmp_path):
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
io = gr.Interface(
fn=lambda *x: list(x),
@@ -513,14 +498,13 @@ async def test_multiple_file_flagging(tmp_path):
examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert len(prediction[0]) == 2
assert all(isinstance(d, dict) for d in prediction[0])
@pytest.mark.asyncio
async def test_examples_keep_all_suffixes(tmp_path):
def test_examples_keep_all_suffixes(tmp_path):
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
file_1 = tmp_path / "foo.bar.txt"
file_1.write_text("file 1")
@@ -535,10 +519,10 @@ async def test_examples_keep_all_suffixes(tmp_path):
examples=[[str(file_1)], [str(file_2)]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert Path(prediction[0]["name"]).read_text() == "file 1"
assert prediction[0]["orig_name"] == "foo.bar.txt"
prediction = await io.examples_handler.load_from_cache(1)
prediction = io.examples_handler.load_from_cache(1)
assert Path(prediction[0]["name"]).read_text() == "file 2"
assert prediction[0]["orig_name"] == "foo.bar.txt"