Allow caching generators and async generators (#4927)

* helpers

* helpers

* async and tests

* docstring

* changelog

* type

* typing

* helpers
This commit is contained in:
Abubakar Abid 2023-07-14 19:38:22 -04:00 committed by GitHub
parent 9137f1caa0
commit e90ad010ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 4 deletions

View File

@ -2,6 +2,7 @@
## New Features:
- Chatbot messages now show hyperlinks to download files uploaded to `gr.Chatbot()` by [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 4848](https://github.com/gradio-app/gradio/pull/4848)
- Cached examples now work with generators and async generators by [@abidlabs](https://github.com/abidlabs) in [PR 4927](https://github.com/gradio-app/gradio/pull/4927)
## Bug Fixes:

View File

@ -110,7 +110,7 @@ class Examples:
inputs: the component or list of components corresponding to the examples
outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` must be provided. If `fn` is a generator function, then the last yielded value will be used as the output.
examples_per_page: how many examples to show per page.
label: the label to use for the examples component (by default, "Examples")
elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
@ -289,7 +289,7 @@ class Examples:
"""
if Path(self.cached_file).exists():
print(
f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache.\n"
)
else:
if Context.root_block is None:
@ -298,10 +298,31 @@ class Examples:
print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
cache_logger = CSVLogger()
if inspect.isgeneratorfunction(self.fn):
def get_final_item(args): # type: ignore
x = None
for x in self.fn(args): # noqa: B007 # type: ignore
pass
return x
fn = get_final_item
elif inspect.isasyncgenfunction(self.fn):
async def get_final_item(args):
x = None
async for x in self.fn(args): # noqa: B007 # type: ignore
pass
return x
fn = get_final_item
else:
fn = self.fn
# create a fake dependency to process the examples and get the predictions
dependency, fn_index = Context.root_block.set_event_trigger(
event_name="fake_event",
fn=self.fn,
fn=fn,
inputs=self.inputs_with_examples, # type: ignore
outputs=self.outputs, # type: ignore
preprocess=self.preprocess and not self._api_mode,
@ -312,6 +333,7 @@ class Examples:
assert self.outputs is not None
cache_logger.setup(self.outputs, self.cached_folder)
for example_id, _ in enumerate(self.examples):
print(f"Caching example {example_id + 1}/{len(self.examples)}")
processed_input = self.processed_examples[example_id]
if self.batch:
processed_input = [[value] for value in processed_input]
@ -329,6 +351,7 @@ class Examples:
# Remove the "fake_event" to prevent bugs in loading interfaces from spaces
Context.root_block.dependencies.remove(dependency)
Context.root_block.fns.pop(fn_index)
print("Caching complete\n")
async def load_from_cache(self, example_id: int) -> list[Any]:
"""Loads a particular cached example for the interface.

View File

@ -154,7 +154,7 @@ class Interface(Blocks):
inputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn. If set to None, then only the output components will be displayed.
outputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn. If set to None, then only the input components will be displayed.
examples: sample inputs for the function; if provided, appear below the UI components and can be clicked to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided, but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
cache_examples: If True, caches examples in the server for fast runtime in examples. If `fn` is a generator function, then the last yielded value will be used as the output. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
examples_per_page: If examples are provided, how many to display per page.
live: whether the interface should automatically rerun if any of the inputs change.
interpretation: function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function. For more information on the different interpretation methods, see the Advanced Interface Features guide.

View File

@ -204,6 +204,38 @@ class TestProcessExamples:
{"label": "lion"},
]
@pytest.mark.asyncio
async def test_caching_with_generators(self):
def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
io = gr.Interface(
test_generator,
"textbox",
"textbox",
examples=["abcdef"],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"
@pytest.mark.asyncio
async def test_caching_with_async_generators(self):
async def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
io = gr.Interface(
test_generator,
"textbox",
"textbox",
examples=["abcdef"],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"
def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
def foo(a, b):
return a + b