Supports gr.update() in example caching (#2309)

* support for update

* formatting

* testing

* fixes to tests

* fixes to tests

* fix tests

* fix review comments

* Update blocks.py
This commit is contained in:
Abubakar Abid 2022-09-21 11:53:06 -05:00 committed by GitHub
parent eaae71fdd1
commit c977ef1fa8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 163 additions and 71 deletions

View File

@ -44,7 +44,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import wandb
from fastapi.applications import FastAPI
from gradio.components import Component, StatusTracker
from gradio.components import Component, IOComponent
class Block:
@ -183,6 +183,12 @@ class Block:
"style": self._style,
}
@classmethod
def get_specific_update(cls, generic_update):
    """
    Converts a generic update dictionary (as produced by gr.update()) into an
    update specific to this component class by passing its kwargs through
    cls.update(), which fills in every updatable field of the component.

    Raises:
        KeyError: if the dict has no "__type__" key, i.e. it was not created
            by gr.update() / a component's update() classmethod.
    """
    # Work on a shallow copy so the caller's dict is not mutated; `del` on
    # the copy still raises KeyError when "__type__" is absent, preserving
    # the documented error behavior.
    generic_update = dict(generic_update)
    del generic_update["__type__"]
    return cls.update(**generic_update)
class BlockContext(Block):
def __init__(
@ -304,14 +310,39 @@ def update(**kwargs) -> dict:
return kwargs
def is_update(val):
    """Returns True if `val` is an update dict created by gr.update() or skip()."""
    if type(val) is not dict:
        return False
    return "update" in val.get("__type__", "")
def skip() -> dict:
    """Returns an empty update dict, i.e. leaves the assigned output component unchanged."""
    return update()
def postprocess_update_dict(block: Block, update_dict: Dict):
    """
    Converts a generic update dict into one specific to `block` and runs the
    block's postprocess() on the "value" entry, if one was provided.
    """
    specific_update = block.get_specific_update(update_dict)
    # NO_VALUE is the sentinel meaning the caller did not set "value";
    # drop the key entirely rather than sending the sentinel onward.
    if specific_update.get("value") is components._Keywords.NO_VALUE:
        specific_update.pop("value")
    specific_update = delete_none(specific_update, skip_value=True)
    if "value" in specific_update:
        specific_update["value"] = block.postprocess(specific_update["value"])
    return specific_update
def convert_update_dict_to_list(outputs_ids: List[int], predictions: Dict) -> List:
    """
    Converts a dictionary of {Component: value} predictions into a list of
    values ordered according to `outputs_ids` (the _id's of the output
    components). Output components not assigned a value receive skip().
    If the keys are not Components, `predictions` is returned unchanged.

    Raises:
        ValueError: if a Component key is not among the specified outputs, or
            if the keys are a mix of Components and non-Components.
    """
    keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
    if all(keys_are_blocks):
        reordered_predictions = [skip() for _ in outputs_ids]
        for component, value in predictions.items():
            if component._id not in outputs_ids:
                # Bug fix: this previously did `return ValueError(...)`, which
                # handed the exception object back to the caller as the
                # predictions value instead of raising it.
                raise ValueError(
                    f"Returned component {component} not specified as output of function."
                )
            output_index = outputs_ids.index(component._id)
            reordered_predictions[output_index] = value
        predictions = utils.resolve_singleton(reordered_predictions)
    elif any(keys_are_blocks):
        raise ValueError(
            "Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order."
        )
    return predictions
@document("load")
class Blocks(BlockContext):
"""
@ -656,21 +687,10 @@ class Blocks(BlockContext):
dependency = self.dependencies[fn_index]
if type(predictions) is dict and len(predictions) > 0:
keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
if all(keys_are_blocks):
reordered_predictions = [skip() for _ in dependency["outputs"]]
for component, value in predictions.items():
if component._id not in dependency["outputs"]:
return ValueError(
f"Returned component {component} not specified as output of function."
)
output_index = dependency["outputs"].index(component._id)
reordered_predictions[output_index] = value
predictions = utils.resolve_singleton(reordered_predictions)
elif any(keys_are_blocks):
raise ValueError(
"Returned dictionary included some keys as Components. Either all keys must be Components to assign Component values, or return a List of values to assign output values in order."
)
predictions = convert_update_dict_to_list(
dependency["outputs"], predictions
)
if len(dependency["outputs"]) == 1:
predictions = (predictions,)
@ -682,38 +702,15 @@ class Blocks(BlockContext):
break
block = self.blocks[output_id]
if getattr(block, "stateful", False):
if not is_update(predictions[i]):
if not utils.is_update(predictions[i]):
state[output_id] = predictions[i]
output.append(None)
else:
prediction_value = predictions[i]
if is_update(prediction_value):
if prediction_value["__type__"] == "generic_update":
del prediction_value["__type__"]
prediction_value = block.__class__.update(
**prediction_value
)
# If the prediction is the default (NO_VALUE) enum then the user did
# not specify a value for the 'value' key and we can get rid of it
if (
prediction_value.get("value")
== components._Keywords.NO_VALUE
):
prediction_value.pop("value")
prediction_value = delete_none(
prediction_value, skip_value=True
)
if "value" in prediction_value:
prediction_value["value"] = block.postprocess(
prediction_value["value"]
)
output_value = prediction_value
if utils.is_update(prediction_value):
output_value = postprocess_update_dict(block, prediction_value)
else:
output_value = (
block.postprocess(prediction_value)
if prediction_value is not None
else None
)
output_value = block.postprocess(prediction_value)
output.append(output_value)
else:

View File

@ -3,6 +3,7 @@ Defines helper methods useful for loading and caching Interface examples.
"""
from __future__ import annotations
import ast
import csv
import inspect
import os
@ -13,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional
import anyio
from gradio import utils
from gradio.blocks import convert_update_dict_to_list, postprocess_update_dict
from gradio.components import Dataset
from gradio.context import Context
from gradio.documentation import document, set_documentation_group
@ -188,11 +190,8 @@ class Examples:
for example in self.processed_examples
]
if cache_examples:
for ex in non_none_examples:
if (
len([sample for sample in ex if sample is not None])
!= self.inputs_with_examples
):
for example in self.examples:
if len([ex for ex in example if ex is not None]) != len(self.inputs):
warnings.warn(
"Examples are being cached but not all input components have "
"example values. This may result in an exception being thrown by "
@ -275,13 +274,23 @@ class Examples:
predictions = await self.fn(*processed_input)
else:
predictions = await anyio.to_thread.run_sync(self.fn, *processed_input)
output_ids = [output._id for output in self.outputs]
if type(predictions) is dict and len(predictions) > 0:
predictions = convert_update_dict_to_list(output_ids, predictions)
if len(self.outputs) == 1:
predictions = [predictions]
if not self._api_mode:
predictions = [
output_component.postprocess(predictions[i])
for i, output_component in enumerate(self.outputs)
]
predictions_ = []
for i, output_component in enumerate(self.outputs):
if utils.is_update(predictions[i]):
predictions_.append(
postprocess_update_dict(output_component, predictions[i])
)
else:
predictions_.append(output_component.postprocess(predictions[i]))
predictions = predictions_
return predictions
async def load_from_cache(self, example_id: int) -> List[Any]:
@ -294,5 +303,10 @@ class Examples:
example = examples[example_id + 1] # +1 to adjust for header
output = []
for component, value in zip(self.outputs, example):
output.append(component.serialize(value, self.cached_folder))
try:
value_as_dict = ast.literal_eval(value)
assert utils.is_update(value_as_dict)
output.append(value_as_dict)
except (ValueError, TypeError, SyntaxError, AssertionError):
output.append(component.serialize(value, self.cached_folder))
return output

View File

@ -206,15 +206,18 @@ class CSVLogger(FlaggingCallback):
component.label or f"component {idx}"
),
)
csv_data.append(
component.deserialize(
sample,
save_dir=save_dir,
encryption_key=self.encryption_key,
if utils.is_update(sample):
csv_data.append(str(sample))
else:
csv_data.append(
component.deserialize(
sample,
save_dir=save_dir,
encryption_key=self.encryption_key,
)
if sample is not None
else ""
)
if sample is not None
else ""
)
csv_data.append(flag_option if flag_option is not None else "")
csv_data.append(username if username is not None else "")
csv_data.append(str(datetime.datetime.now()))

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import ast
import asyncio
import copy
import inspect
@ -675,3 +676,7 @@ def validate_url(possible_url: str) -> bool:
except Exception:
pass
return False
def is_update(val):
    """Returns True iff `val` is an update dictionary produced by gr.update()."""
    marker = val.get("__type__", "") if type(val) is dict else ""
    return "update" in marker

View File

@ -408,5 +408,40 @@ class TestCallFunction:
assert output["prediction"] == (0, 3)
class TestSpecificUpdate:
    def test_without_update(self):
        # A dict not produced by gr.update() lacks "__type__", so
        # get_specific_update must raise KeyError.
        with pytest.raises(KeyError):
            gr.Textbox.get_specific_update({"lines": 4})

    def test_with_update(self):
        result = gr.Textbox.get_specific_update({"lines": 4, "__type__": "update"})
        expected = {
            "lines": 4,
            "max_lines": None,
            "placeholder": None,
            "label": None,
            "show_label": None,
            "visible": None,
            "value": gr.components._Keywords.NO_VALUE,
            "__type__": "update",
        }
        assert result == expected

    def test_with_generic_update(self):
        # A "generic_update" is resolved to the component's own update type.
        result = gr.Video.get_specific_update(
            {"visible": True, "value": "test.mp4", "__type__": "generic_update"}
        )
        expected = {
            "source": None,
            "label": None,
            "show_label": None,
            "interactive": None,
            "visible": True,
            "value": "test.mp4",
            "__type__": "update",
        }
        assert result == expected
if __name__ == "__main__":
unittest.main()

View File

@ -1,4 +1,5 @@
import os
import tempfile
from unittest.mock import patch
import pytest
@ -61,6 +62,7 @@ class TestExamplesDataset:
assert examples.dataset.headers == ["im", ""]
@patch("gradio.examples.CACHED_FOLDER", tempfile.mkdtemp())
class TestProcessExamples:
@pytest.mark.asyncio
async def test_predict_example(self):
@ -85,18 +87,54 @@ class TestProcessExamples:
"text",
examples=[["World"], ["Dunya"], ["Monde"]],
)
io.launch(prevent_thread_lock=True)
await io.examples_handler.cache_interface_examples()
prediction = await io.examples_handler.load_from_cache(1)
io.close()
assert prediction[0] == "Hello Dunya"
@pytest.mark.asyncio
async def test_caching_with_update(self):
    # A cached gr.update() output should round-trip through the cache as a dict.
    interface = gr.Interface(
        lambda x: gr.update(visible=False),
        "text",
        "image",
        examples=[["World"], ["Dunya"], ["Monde"]],
    )
    await interface.examples_handler.cache_interface_examples()
    cached = await interface.examples_handler.load_from_cache(1)
    assert cached[0] == {"visible": False, "__type__": "update"}
def test_raise_helpful_error_message_if_providing_partial_examples(tmp_path):
def foo(a, b):
return a + b
@pytest.mark.asyncio
async def test_caching_with_mix_update(self):
    # Mixing an update dict with a regular (file) output should cache both.
    interface = gr.Interface(
        lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
        "text",
        ["text", "image"],
        examples=[["World"], ["Dunya"], ["Monde"]],
    )
    await interface.examples_handler.cache_interface_examples()
    cached = await interface.examples_handler.load_from_cache(1)
    assert cached[0] == {"lines": 4, "value": "hello", "__type__": "update"}
@pytest.mark.asyncio
async def test_caching_with_dict(self):
    # A {Component: value} return should be reordered into a list and cached.
    textbox = gr.Textbox()
    label = gr.Label()
    interface = gr.Interface(
        lambda _: {textbox: gr.update(lines=4), label: "lion"},
        "textbox",
        [textbox, label],
        examples=["abc"],
        cache_examples=True,
    )
    await interface.examples_handler.cache_interface_examples()
    cached = await interface.examples_handler.load_from_cache(0)
    assert cached == [{"lines": 4, "__type__": "update"}, {"label": "lion"}]
def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
def foo(a, b):
return a + b
with patch("gradio.examples.CACHED_FOLDER", tmp_path):
with pytest.warns(
UserWarning,
match="^Examples are being cached but not all input components have",