Allow caching generators and async generators (#4927)
* helpers
* helpers
* async and tests
* docstring
* changelog
* type
* typing
* helpers
This commit is contained in:
parent 9137f1caa0
commit e90ad010ad
@@ -2,6 +2,7 @@
 ## New Features:

 - Chatbot messages now show hyperlinks to download files uploaded to `gr.Chatbot()` by [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 4848](https://github.com/gradio-app/gradio/pull/4848)
+- Cached examples now work with generators and async generators by [@abidlabs](https://github.com/abidlabs) in [PR 4927](https://github.com/gradio-app/gradio/pull/4927)

 ## Bug Fixes:

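For illustration, a minimal sketch of how a cached generator example could look with `gr.Interface` (the streaming function and example value below are assumptions, not part of this commit; under this change the example cache stores the generator's last yielded value):

```python
import gradio as gr

# Hypothetical streaming function: yields progressively longer prefixes of the input.
def stream_echo(text):
    for i in range(len(text)):
        yield "Echo: " + text[: i + 1]

demo = gr.Interface(
    fn=stream_echo,
    inputs="textbox",
    outputs="textbox",
    examples=["hello"],
    cache_examples=True,  # cache would store "Echo: hello", the final yield
)

if __name__ == "__main__":
    demo.launch()
```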
@@ -110,7 +110,7 @@ class Examples:
             inputs: the component or list of components corresponding to the examples
             outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
             fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
-            cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
+            cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` must be provided. If `fn` is a generator function, then the last yielded value will be used as the output.
             examples_per_page: how many examples to show per page.
             label: the label to use for the examples component (by default, "Examples")
             elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
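A sketch of the same behavior through `gr.Examples` inside a Blocks app (the component labels, function, and example value here are illustrative assumptions):

```python
import gradio as gr

# Hypothetical generator; only its last yield would be written to the example cache.
def count_up(text):
    for i in range(len(text)):
        yield text[: i + 1]

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Examples(
        examples=[["hello"]],
        inputs=inp,
        outputs=out,
        fn=count_up,
        cache_examples=True,  # requires fn and outputs, per the docstring above
    )
```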
@@ -289,7 +289,7 @@ class Examples:
         """
         if Path(self.cached_file).exists():
             print(
-                f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
+                f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache.\n"
             )
         else:
             if Context.root_block is None:
@@ -298,10 +298,31 @@ class Examples:
             print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
             cache_logger = CSVLogger()

+            if inspect.isgeneratorfunction(self.fn):
+
+                def get_final_item(args):  # type: ignore
+                    x = None
+                    for x in self.fn(args):  # noqa: B007 # type: ignore
+                        pass
+                    return x
+
+                fn = get_final_item
+            elif inspect.isasyncgenfunction(self.fn):
+
+                async def get_final_item(args):
+                    x = None
+                    async for x in self.fn(args):  # noqa: B007 # type: ignore
+                        pass
+                    return x
+
+                fn = get_final_item
+            else:
+                fn = self.fn
+
             # create a fake dependency to process the examples and get the predictions
             dependency, fn_index = Context.root_block.set_event_trigger(
                 event_name="fake_event",
-                fn=self.fn,
+                fn=fn,
                 inputs=self.inputs_with_examples,  # type: ignore
                 outputs=self.outputs,  # type: ignore
                 preprocess=self.preprocess and not self._api_mode,
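The `get_final_item` wrappers above simply drain the (async) generator and keep the last value it yielded. A standalone sketch of that pattern, with illustrative names not taken from the Gradio codebase:

```python
import asyncio

def last_yielded(gen_fn, *args):
    """Exhaust a sync generator and return its final yielded value."""
    x = None
    for x in gen_fn(*args):
        pass
    return x

async def last_yielded_async(agen_fn, *args):
    """Exhaust an async generator and return its final yielded value."""
    x = None
    async for x in agen_fn(*args):
        pass
    return x

def countdown(n):
    for i in range(n, 0, -1):
        yield i

async def countdown_async(n):
    for i in range(n, 0, -1):
        yield i

print(last_yielded(countdown, 3))                           # prints 1
print(asyncio.run(last_yielded_async(countdown_async, 3)))  # prints 1
```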
@@ -312,6 +333,7 @@ class Examples:
             assert self.outputs is not None
             cache_logger.setup(self.outputs, self.cached_folder)
             for example_id, _ in enumerate(self.examples):
+                print(f"Caching example {example_id + 1}/{len(self.examples)}")
                 processed_input = self.processed_examples[example_id]
                 if self.batch:
                     processed_input = [[value] for value in processed_input]
@@ -329,6 +351,7 @@ class Examples:
             # Remove the "fake_event" to prevent bugs in loading interfaces from spaces
             Context.root_block.dependencies.remove(dependency)
             Context.root_block.fns.pop(fn_index)
+        print("Caching complete\n")

     async def load_from_cache(self, example_id: int) -> list[Any]:
         """Loads a particular cached example for the interface.
@@ -154,7 +154,7 @@ class Interface(Blocks):
             inputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn. If set to None, then only the output components will be displayed.
             outputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn. If set to None, then only the input components will be displayed.
             examples: sample inputs for the function; if provided, appear below the UI components and can be clicked to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided, but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
-            cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
+            cache_examples: If True, caches examples in the server for fast runtime in examples. If `fn` is a generator function, then the last yielded value will be used as the output. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
             examples_per_page: If examples are provided, how many to display per page.
             live: whether the interface should automatically rerun if any of the inputs change.
             interpretation: function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function. For more information on the different interpretation methods, see the Advanced Interface Features guide.
@@ -204,6 +204,38 @@ class TestProcessExamples:
             {"label": "lion"},
         ]

+    @pytest.mark.asyncio
+    async def test_caching_with_generators(self):
+        def test_generator(x):
+            for y in range(len(x)):
+                yield "Your output: " + x[: y + 1]
+
+        io = gr.Interface(
+            test_generator,
+            "textbox",
+            "textbox",
+            examples=["abcdef"],
+            cache_examples=True,
+        )
+        prediction = await io.examples_handler.load_from_cache(0)
+        assert prediction[0] == "Your output: abcdef"
+
+    @pytest.mark.asyncio
+    async def test_caching_with_async_generators(self):
+        async def test_generator(x):
+            for y in range(len(x)):
+                yield "Your output: " + x[: y + 1]
+
+        io = gr.Interface(
+            test_generator,
+            "textbox",
+            "textbox",
+            examples=["abcdef"],
+            cache_examples=True,
+        )
+        prediction = await io.examples_handler.load_from_cache(0)
+        assert prediction[0] == "Your output: abcdef"
+
     def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
         def foo(a, b):
             return a + b