mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Supports gr.update()
in example caching (#2309)
* support for update * formatting * testing * fixes to tests * fixes to tests * fix tests * fix review comments * Update blocks.py
This commit is contained in:
parent
eaae71fdd1
commit
c977ef1fa8
@ -44,7 +44,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import wandb
|
||||
from fastapi.applications import FastAPI
|
||||
|
||||
from gradio.components import Component, StatusTracker
|
||||
from gradio.components import Component, IOComponent
|
||||
|
||||
|
||||
class Block:
|
||||
@ -183,6 +183,12 @@ class Block:
|
||||
"style": self._style,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_specific_update(cls, generic_update):
|
||||
del generic_update["__type__"]
|
||||
generic_update = cls.update(**generic_update)
|
||||
return generic_update
|
||||
|
||||
|
||||
class BlockContext(Block):
|
||||
def __init__(
|
||||
@ -304,14 +310,39 @@ def update(**kwargs) -> dict:
|
||||
return kwargs
|
||||
|
||||
|
||||
def is_update(val):
|
||||
return type(val) is dict and "update" in val.get("__type__", "")
|
||||
|
||||
|
||||
def skip() -> dict:
|
||||
return update()
|
||||
|
||||
|
||||
def postprocess_update_dict(block: Block, update_dict: Dict):
|
||||
prediction_value = block.get_specific_update(update_dict)
|
||||
if prediction_value.get("value") is components._Keywords.NO_VALUE:
|
||||
prediction_value.pop("value")
|
||||
prediction_value = delete_none(prediction_value, skip_value=True)
|
||||
if "value" in prediction_value:
|
||||
prediction_value["value"] = block.postprocess(prediction_value["value"])
|
||||
return prediction_value
|
||||
|
||||
|
||||
def convert_update_dict_to_list(outputs_ids: List[int], predictions: Dict) -> List:
|
||||
keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
|
||||
if all(keys_are_blocks):
|
||||
reordered_predictions = [skip() for _ in outputs_ids]
|
||||
for component, value in predictions.items():
|
||||
if component._id not in outputs_ids:
|
||||
return ValueError(
|
||||
f"Returned component {component} not specified as output of function."
|
||||
)
|
||||
output_index = outputs_ids.index(component._id)
|
||||
reordered_predictions[output_index] = value
|
||||
predictions = utils.resolve_singleton(reordered_predictions)
|
||||
elif any(keys_are_blocks):
|
||||
raise ValueError(
|
||||
"Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order."
|
||||
)
|
||||
return predictions
|
||||
|
||||
|
||||
@document("load")
|
||||
class Blocks(BlockContext):
|
||||
"""
|
||||
@ -656,21 +687,10 @@ class Blocks(BlockContext):
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
if type(predictions) is dict and len(predictions) > 0:
|
||||
keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
|
||||
if all(keys_are_blocks):
|
||||
reordered_predictions = [skip() for _ in dependency["outputs"]]
|
||||
for component, value in predictions.items():
|
||||
if component._id not in dependency["outputs"]:
|
||||
return ValueError(
|
||||
f"Returned component {component} not specified as output of function."
|
||||
)
|
||||
output_index = dependency["outputs"].index(component._id)
|
||||
reordered_predictions[output_index] = value
|
||||
predictions = utils.resolve_singleton(reordered_predictions)
|
||||
elif any(keys_are_blocks):
|
||||
raise ValueError(
|
||||
"Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order."
|
||||
)
|
||||
predictions = convert_update_dict_to_list(
|
||||
dependency["outputs"], predictions
|
||||
)
|
||||
|
||||
if len(dependency["outputs"]) == 1:
|
||||
predictions = (predictions,)
|
||||
|
||||
@ -682,38 +702,15 @@ class Blocks(BlockContext):
|
||||
break
|
||||
block = self.blocks[output_id]
|
||||
if getattr(block, "stateful", False):
|
||||
if not is_update(predictions[i]):
|
||||
if not utils.is_update(predictions[i]):
|
||||
state[output_id] = predictions[i]
|
||||
output.append(None)
|
||||
else:
|
||||
prediction_value = predictions[i]
|
||||
if is_update(prediction_value):
|
||||
if prediction_value["__type__"] == "generic_update":
|
||||
del prediction_value["__type__"]
|
||||
prediction_value = block.__class__.update(
|
||||
**prediction_value
|
||||
)
|
||||
# If the prediction is the default (NO_VALUE) enum then the user did
|
||||
# not specify a value for the 'value' key and we can get rid of it
|
||||
if (
|
||||
prediction_value.get("value")
|
||||
== components._Keywords.NO_VALUE
|
||||
):
|
||||
prediction_value.pop("value")
|
||||
prediction_value = delete_none(
|
||||
prediction_value, skip_value=True
|
||||
)
|
||||
if "value" in prediction_value:
|
||||
prediction_value["value"] = block.postprocess(
|
||||
prediction_value["value"]
|
||||
)
|
||||
output_value = prediction_value
|
||||
if utils.is_update(prediction_value):
|
||||
output_value = postprocess_update_dict(block, prediction_value)
|
||||
else:
|
||||
output_value = (
|
||||
block.postprocess(prediction_value)
|
||||
if prediction_value is not None
|
||||
else None
|
||||
)
|
||||
output_value = block.postprocess(prediction_value)
|
||||
output.append(output_value)
|
||||
|
||||
else:
|
||||
|
@ -3,6 +3,7 @@ Defines helper methods useful for loading and caching Interface examples.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import csv
|
||||
import inspect
|
||||
import os
|
||||
@ -13,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional
|
||||
import anyio
|
||||
|
||||
from gradio import utils
|
||||
from gradio.blocks import convert_update_dict_to_list, postprocess_update_dict
|
||||
from gradio.components import Dataset
|
||||
from gradio.context import Context
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
@ -188,11 +190,8 @@ class Examples:
|
||||
for example in self.processed_examples
|
||||
]
|
||||
if cache_examples:
|
||||
for ex in non_none_examples:
|
||||
if (
|
||||
len([sample for sample in ex if sample is not None])
|
||||
!= self.inputs_with_examples
|
||||
):
|
||||
for example in self.examples:
|
||||
if len([ex for ex in example if ex is not None]) != len(self.inputs):
|
||||
warnings.warn(
|
||||
"Examples are being cached but not all input components have "
|
||||
"example values. This may result in an exception being thrown by "
|
||||
@ -275,13 +274,23 @@ class Examples:
|
||||
predictions = await self.fn(*processed_input)
|
||||
else:
|
||||
predictions = await anyio.to_thread.run_sync(self.fn, *processed_input)
|
||||
|
||||
output_ids = [output._id for output in self.outputs]
|
||||
if type(predictions) is dict and len(predictions) > 0:
|
||||
predictions = convert_update_dict_to_list(output_ids, predictions)
|
||||
|
||||
if len(self.outputs) == 1:
|
||||
predictions = [predictions]
|
||||
if not self._api_mode:
|
||||
predictions = [
|
||||
output_component.postprocess(predictions[i])
|
||||
for i, output_component in enumerate(self.outputs)
|
||||
]
|
||||
predictions_ = []
|
||||
for i, output_component in enumerate(self.outputs):
|
||||
if utils.is_update(predictions[i]):
|
||||
predictions_.append(
|
||||
postprocess_update_dict(output_component, predictions[i])
|
||||
)
|
||||
else:
|
||||
predictions_.append(output_component.postprocess(predictions[i]))
|
||||
predictions = predictions_
|
||||
return predictions
|
||||
|
||||
async def load_from_cache(self, example_id: int) -> List[Any]:
|
||||
@ -294,5 +303,10 @@ class Examples:
|
||||
example = examples[example_id + 1] # +1 to adjust for header
|
||||
output = []
|
||||
for component, value in zip(self.outputs, example):
|
||||
output.append(component.serialize(value, self.cached_folder))
|
||||
try:
|
||||
value_as_dict = ast.literal_eval(value)
|
||||
assert utils.is_update(value_as_dict)
|
||||
output.append(value_as_dict)
|
||||
except (ValueError, TypeError, SyntaxError, AssertionError):
|
||||
output.append(component.serialize(value, self.cached_folder))
|
||||
return output
|
||||
|
@ -206,15 +206,18 @@ class CSVLogger(FlaggingCallback):
|
||||
component.label or f"component {idx}"
|
||||
),
|
||||
)
|
||||
csv_data.append(
|
||||
component.deserialize(
|
||||
sample,
|
||||
save_dir=save_dir,
|
||||
encryption_key=self.encryption_key,
|
||||
if utils.is_update(sample):
|
||||
csv_data.append(str(sample))
|
||||
else:
|
||||
csv_data.append(
|
||||
component.deserialize(
|
||||
sample,
|
||||
save_dir=save_dir,
|
||||
encryption_key=self.encryption_key,
|
||||
)
|
||||
if sample is not None
|
||||
else ""
|
||||
)
|
||||
if sample is not None
|
||||
else ""
|
||||
)
|
||||
csv_data.append(flag_option if flag_option is not None else "")
|
||||
csv_data.append(username if username is not None else "")
|
||||
csv_data.append(str(datetime.datetime.now()))
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
@ -675,3 +676,7 @@ def validate_url(possible_url: str) -> bool:
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def is_update(val):
|
||||
return type(val) is dict and "update" in val.get("__type__", "")
|
||||
|
@ -408,5 +408,40 @@ class TestCallFunction:
|
||||
assert output["prediction"] == (0, 3)
|
||||
|
||||
|
||||
class TestSpecificUpdate:
|
||||
def test_without_update(self):
|
||||
with pytest.raises(KeyError):
|
||||
gr.Textbox.get_specific_update({"lines": 4})
|
||||
|
||||
def test_with_update(self):
|
||||
specific_update = gr.Textbox.get_specific_update(
|
||||
{"lines": 4, "__type__": "update"}
|
||||
)
|
||||
assert specific_update == {
|
||||
"lines": 4,
|
||||
"max_lines": None,
|
||||
"placeholder": None,
|
||||
"label": None,
|
||||
"show_label": None,
|
||||
"visible": None,
|
||||
"value": gr.components._Keywords.NO_VALUE,
|
||||
"__type__": "update",
|
||||
}
|
||||
|
||||
def test_with_generic_update(self):
|
||||
specific_update = gr.Video.get_specific_update(
|
||||
{"visible": True, "value": "test.mp4", "__type__": "generic_update"}
|
||||
)
|
||||
assert specific_update == {
|
||||
"source": None,
|
||||
"label": None,
|
||||
"show_label": None,
|
||||
"interactive": None,
|
||||
"visible": True,
|
||||
"value": "test.mp4",
|
||||
"__type__": "update",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -61,6 +62,7 @@ class TestExamplesDataset:
|
||||
assert examples.dataset.headers == ["im", ""]
|
||||
|
||||
|
||||
@patch("gradio.examples.CACHED_FOLDER", tempfile.mkdtemp())
|
||||
class TestProcessExamples:
|
||||
@pytest.mark.asyncio
|
||||
async def test_predict_example(self):
|
||||
@ -85,18 +87,54 @@ class TestProcessExamples:
|
||||
"text",
|
||||
examples=[["World"], ["Dunya"], ["Monde"]],
|
||||
)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
await io.examples_handler.cache_interface_examples()
|
||||
prediction = await io.examples_handler.load_from_cache(1)
|
||||
io.close()
|
||||
assert prediction[0] == "Hello Dunya"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_update(self):
|
||||
io = gr.Interface(
|
||||
lambda x: gr.update(visible=False),
|
||||
"text",
|
||||
"image",
|
||||
examples=[["World"], ["Dunya"], ["Monde"]],
|
||||
)
|
||||
await io.examples_handler.cache_interface_examples()
|
||||
prediction = await io.examples_handler.load_from_cache(1)
|
||||
assert prediction[0] == {"visible": False, "__type__": "update"}
|
||||
|
||||
def test_raise_helpful_error_message_if_providing_partial_examples(tmp_path):
|
||||
def foo(a, b):
|
||||
return a + b
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_mix_update(self):
|
||||
io = gr.Interface(
|
||||
lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
|
||||
"text",
|
||||
["text", "image"],
|
||||
examples=[["World"], ["Dunya"], ["Monde"]],
|
||||
)
|
||||
await io.examples_handler.cache_interface_examples()
|
||||
prediction = await io.examples_handler.load_from_cache(1)
|
||||
assert prediction[0] == {"lines": 4, "value": "hello", "__type__": "update"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_dict(self):
|
||||
text = gr.Textbox()
|
||||
out = gr.Label()
|
||||
|
||||
io = gr.Interface(
|
||||
lambda _: {text: gr.update(lines=4), out: "lion"},
|
||||
"textbox",
|
||||
[text, out],
|
||||
examples=["abc"],
|
||||
cache_examples=True,
|
||||
)
|
||||
await io.examples_handler.cache_interface_examples()
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
assert prediction == [{"lines": 4, "__type__": "update"}, {"label": "lion"}]
|
||||
|
||||
def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
|
||||
def foo(a, b):
|
||||
return a + b
|
||||
|
||||
with patch("gradio.examples.CACHED_FOLDER", tmp_path):
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match="^Examples are being cached but not all input components have",
|
||||
|
Loading…
Reference in New Issue
Block a user