Fixes issue 5781: Enables specifying a caching directory for Examples (#6803)

* issue 5781 first commit

* second commit

* unnecessary str removed

* backend formatted

* Update gradio/helpers.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update guides/02_building-interfaces/03_more-on-examples.md

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* tests added

* add changeset

* format

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
cswamy 2023-12-19 01:07:38 +00:00 committed by GitHub
parent 50496f967f
commit 77c900311e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 161 additions and 108 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Fixes issue 5781: Enables specifying a caching directory for Examples

View File

@ -33,7 +33,6 @@ from gradio.flagging import CSVLogger
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.components import Component
CACHED_FOLDER = "gradio_cached_examples"
LOG_FILE = "log.csv"
set_documentation_group("helpers")
@ -248,7 +247,7 @@ class Examples:
elem_id=elem_id,
)
self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
self.cached_folder = utils.get_cache_folder() / str(self.dataset._id)
self.cached_file = Path(self.cached_folder) / "log.csv"
self.cache_examples = cache_examples
self.run_on_click = run_on_click

View File

@ -54,7 +54,6 @@ from gradio import route_utils, utils, wasm_utils
from gradio.context import Context
from gradio.data_classes import ComponentServerBody, PredictBody, ResetBody
from gradio.exceptions import Error
from gradio.helpers import CACHED_FOLDER
from gradio.oauth import attach_oauth
from gradio.queueing import Estimation
from gradio.route_utils import ( # noqa: F401
@ -455,7 +454,7 @@ class App(FastAPI):
)
was_uploaded = utils.is_in_or_equal(abs_path, app.uploaded_file_dir)
is_cached_example = utils.is_in_or_equal(
abs_path, utils.abspath(CACHED_FOLDER)
abs_path, utils.abspath(utils.get_cache_folder())
)
if not (

View File

@ -1016,3 +1016,7 @@ class LRUCache(OrderedDict, Generic[K, V]):
elif len(self) >= self.max_size:
self.popitem(last=False)
super().__setitem__(key, value)
def get_cache_folder() -> Path:
return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples"))

View File

@ -34,7 +34,7 @@ Sometimes your app has many input components, but you would only like to provide
## Caching examples
You may wish to provide some cached examples of your model for users to quickly try out, in case your model takes a while to run normally.
If `cache_examples=True`, the `Interface` will run all of your examples through your app and save the outputs when you call the `launch()` method. This data will be saved in a directory called `gradio_cached_examples`.
If `cache_examples=True`, the `Interface` will run all of your examples through your app and save the outputs when you call the `launch()` method. This data will be saved in a directory called `gradio_cached_examples` in your working directory by default. You can also set this directory with the `GRADIO_EXAMPLES_CACHE` environment variable, which can be either an absolute path or a relative path to your working directory.
Whenever a user clicks on an example, the output will automatically be populated in the app now, using data from this cached directory instead of actually running the function. This is useful so users can quickly try out your model without adding any load!

View File

@ -1,10 +1,11 @@
import tempfile
from concurrent.futures import wait
from pathlib import Path
from unittest.mock import patch
import pytest
import gradio as gr
from gradio import helpers
def invalid_fn(message):
@ -79,44 +80,52 @@ class TestInit:
)
def test_example_caching(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
double, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello hello")
assert prediction_hi[0].root[0] == ("hi", "hi hi")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
double, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello hello")
assert prediction_hi[0].root[0] == ("hi", "hi hi")
def test_example_caching_async(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
async_greet, examples=["abubakar", "tom"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("abubakar", "hi, abubakar")
assert prediction_hi[0].root[0] == ("tom", "hi, tom")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
async_greet, examples=["abubakar", "tom"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("abubakar", "hi, abubakar")
assert prediction_hi[0].root[0] == ("tom", "hi, tom")
def test_example_caching_with_streaming(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello")
assert prediction_hi[0].root[0] == ("hi", "hi")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello")
assert prediction_hi[0].root[0] == ("hi", "hi")
def test_example_caching_with_streaming_async(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
async_stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello")
assert prediction_hi[0].root[0] == ("hi", "hi")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
async_stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello")
assert prediction_hi[0].root[0] == ("hi", "hi")
def test_default_accordion_params(self):
chatbot = gr.ChatInterface(
@ -146,34 +155,38 @@ class TestInit:
assert accordion.get_config().get("label") == "MOAR"
def test_example_caching_with_additional_inputs(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=["textbox", "slider"],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "robot hello")
assert prediction_hi[0].root[0] == ("hi", "ro")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=["textbox", "slider"],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "robot hello")
assert prediction_hi[0].root[0] == ("hi", "ro")
def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
with gr.Blocks():
with gr.Accordion("Inputs"):
text = gr.Textbox()
slider = gr.Slider()
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=[text, slider],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "robot hello")
assert prediction_hi[0].root[0] == ("hi", "ro")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
with gr.Blocks():
with gr.Accordion("Inputs"):
text = gr.Textbox()
slider = gr.Slider()
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=[text, slider],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "robot hello")
assert prediction_hi[0].root[0] == ("hi", "ro")
class TestAPI:

View File

@ -1,4 +1,5 @@
import os
import tempfile
import textwrap
import warnings
from pathlib import Path
@ -356,7 +357,7 @@ class TestLoadInterface:
class TestLoadInterfaceWithExamples:
def test_interface_load_examples(self, tmp_path):
test_file_dir = Path(Path(__file__).parent, "test_files")
with patch("gradio.helpers.CACHED_FOLDER", tmp_path):
with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
gr.load(
name="models/google/vit-base-patch16-224",
examples=[Path(test_file_dir, "cheetah1.jpg")],
@ -365,7 +366,9 @@ class TestLoadInterfaceWithExamples:
def test_interface_load_cache_examples(self, tmp_path):
test_file_dir = Path(Path(__file__).parent, "test_files")
with patch("gradio.helpers.CACHED_FOLDER", tmp_path):
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
gr.load(
name="models/google/vit-base-patch16-224",
examples=[Path(test_file_dir, "cheetah1.jpg")],

View File

@ -10,17 +10,19 @@ from unittest.mock import patch
import gradio_client as grc
import pytest
from gradio_client import media_data, utils
from gradio_client import media_data
from gradio_client import utils as client_utils
from pydub import AudioSegment
from starlette.testclient import TestClient
from tqdm import tqdm
import gradio as gr
from gradio import utils
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamples:
def test_handle_single_input(self):
def test_handle_single_input(self, patched_cache_folder):
examples = gr.Examples(["hello", "hi"], gr.Textbox())
assert examples.processed_examples == [["hello"], ["hi"]]
@ -29,37 +31,41 @@ class TestExamples:
examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
assert (
utils.encode_file_to_base64(examples.processed_examples[0][0]["path"])
client_utils.encode_file_to_base64(
examples.processed_examples[0][0]["path"]
)
== media_data.BASE64_IMAGE
)
def test_handle_multiple_inputs(self):
def test_handle_multiple_inputs(self, patched_cache_folder):
examples = gr.Examples(
[["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
)
assert examples.processed_examples[0][0] == "hello"
assert (
utils.encode_file_to_base64(examples.processed_examples[0][1]["path"])
client_utils.encode_file_to_base64(
examples.processed_examples[0][1]["path"]
)
== media_data.BASE64_IMAGE
)
def test_handle_directory(self):
def test_handle_directory(self, patched_cache_folder):
examples = gr.Examples("test/test_files/images", gr.Image())
assert len(examples.processed_examples) == 2
for row in examples.processed_examples:
for output in row:
assert (
utils.encode_file_to_base64(output["path"])
client_utils.encode_file_to_base64(output["path"])
== media_data.BASE64_IMAGE
)
def test_handle_directory_with_log_file(self):
def test_handle_directory_with_log_file(self, patched_cache_folder):
examples = gr.Examples(
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
)
ex = utils.traverse(
ex = client_utils.traverse(
examples.processed_examples,
lambda s: utils.encode_file_to_base64(s["path"]),
lambda s: client_utils.encode_file_to_base64(s["path"]),
lambda x: isinstance(x, dict) and Path(x["path"]).exists(),
)
assert ex == [
@ -69,11 +75,11 @@ class TestExamples:
for sample in examples.dataset.samples:
assert os.path.isabs(sample[0])
def test_examples_per_page(self):
def test_examples_per_page(self, patched_cache_folder):
examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
assert examples.dataset.get_config()["samples_per_page"] == 2
def test_no_preprocessing(self):
def test_no_preprocessing(self, patched_cache_folder):
with gr.Blocks():
image = gr.Image()
textbox = gr.Textbox()
@ -88,9 +94,11 @@ class TestExamples:
)
prediction = examples.load_from_cache(0)
assert utils.encode_file_to_base64(prediction[0]) == media_data.BASE64_IMAGE
assert (
client_utils.encode_file_to_base64(prediction[0]) == media_data.BASE64_IMAGE
)
def test_no_postprocessing(self):
def test_no_postprocessing(self, patched_cache_folder):
def im(x):
return [
{
@ -116,25 +124,45 @@ class TestExamples:
prediction = examples.load_from_cache(0)
file = prediction[0].root[0].image.path
assert utils.encode_url_or_file_to_base64(
assert client_utils.encode_url_or_file_to_base64(
file
) == utils.encode_url_or_file_to_base64("test/test_files/bus.png")
) == client_utils.encode_url_or_file_to_base64("test/test_files/bus.png")
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
def test_setting_cache_dir_env_variable(monkeypatch):
temp_dir = tempfile.mkdtemp()
monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", temp_dir)
with gr.Blocks():
image = gr.Image()
image2 = gr.Image()
examples = gr.Examples(
examples=["test/test_files/bus.png"],
inputs=image,
outputs=image2,
fn=lambda x: x,
cache_examples=True,
)
prediction = examples.load_from_cache(0)
path_to_cached_file = Path(prediction[0].path)
assert utils.is_in_or_equal(path_to_cached_file, temp_dir)
monkeypatch.delenv("GRADIO_EXAMPLES_CACHE", raising=False)
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamplesDataset:
def test_no_headers(self):
def test_no_headers(self, patched_cache_folder):
examples = gr.Examples("test/test_files/images_log", [gr.Image(), gr.Text()])
assert examples.dataset.headers == []
def test_all_headers(self):
def test_all_headers(self, patched_cache_folder):
examples = gr.Examples(
"test/test_files/images_log",
[gr.Image(label="im"), gr.Text(label="your text")],
)
assert examples.dataset.headers == ["im", "your text"]
def test_some_headers(self):
def test_some_headers(self, patched_cache_folder):
examples = gr.Examples(
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
)
@ -178,9 +206,9 @@ def test_example_caching_relaunch(connect):
)
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestProcessExamples:
def test_caching(self):
def test_caching(self, patched_cache_folder):
io = gr.Interface(
lambda x: f"Hello {x}",
"text",
@ -191,7 +219,7 @@ class TestProcessExamples:
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == "Hello Dunya"
def test_example_caching_relaunch(self, connect):
def test_example_caching_relaunch(self, patched_cache_folder, connect):
def combine(a, b):
return a + " " + b
@ -224,7 +252,7 @@ class TestProcessExamples:
"hello Eve",
)
def test_caching_image(self):
def test_caching_image(self, patched_cache_folder):
io = gr.Interface(
lambda x: x,
"image",
@ -233,11 +261,11 @@ class TestProcessExamples:
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
assert utils.encode_url_or_file_to_base64(prediction[0].path).startswith(
assert client_utils.encode_url_or_file_to_base64(prediction[0].path).startswith(
""
)
def test_caching_audio(self):
def test_caching_audio(self, patched_cache_folder):
io = gr.Interface(
lambda x: x,
"audio",
@ -247,11 +275,11 @@ class TestProcessExamples:
)
prediction = io.examples_handler.load_from_cache(0)
file = prediction[0].path
assert utils.encode_url_or_file_to_base64(file).startswith(
assert client_utils.encode_url_or_file_to_base64(file).startswith(
"data:audio/wav;base64,UklGRgA/"
)
def test_caching_with_update(self):
def test_caching_with_update(self, patched_cache_folder):
io = gr.Interface(
lambda x: gr.update(visible=False),
"text",
@ -265,7 +293,7 @@ class TestProcessExamples:
"__type__": "update",
}
def test_caching_with_mix_update(self):
def test_caching_with_mix_update(self, patched_cache_folder):
io = gr.Interface(
lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
"text",
@ -280,7 +308,7 @@ class TestProcessExamples:
"__type__": "update",
}
def test_caching_with_dict(self):
def test_caching_with_dict(self, patched_cache_folder):
text = gr.Textbox()
out = gr.Label()
@ -297,7 +325,7 @@ class TestProcessExamples:
gr.Label.data_model(**{"label": "lion", "confidences": None}),
]
def test_caching_with_generators(self):
def test_caching_with_generators(self, patched_cache_folder):
def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
@ -312,7 +340,7 @@ class TestProcessExamples:
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"
def test_caching_with_generators_and_streamed_output(self):
def test_caching_with_generators_and_streamed_output(self, patched_cache_folder):
file_dir = Path(Path(__file__).parent, "test_files")
audio = str(file_dir / "audio_sample.wav")
@ -334,7 +362,7 @@ class TestProcessExamples:
assert round(length_ratio, 1) == 3.0 # might not be exactly 3x
assert float(prediction[1]) == 10.0
def test_caching_with_async_generators(self):
def test_caching_with_async_generators(self, patched_cache_folder):
async def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
@ -349,7 +377,9 @@ class TestProcessExamples:
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"
def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
def test_raise_helpful_error_message_if_providing_partial_examples(
self, patched_cache_folder, tmp_path
):
def foo(a, b):
return a + b
@ -406,7 +436,7 @@ class TestProcessExamples:
cache_examples=True,
)
def test_caching_with_batch(self):
def test_caching_with_batch(self, patched_cache_folder):
def trim_words(words, lens):
trimmed_words = [word[:length] for word, length in zip(words, lens)]
return [trimmed_words]
@ -423,7 +453,7 @@ class TestProcessExamples:
prediction = io.examples_handler.load_from_cache(0)
assert prediction == ["hel"]
def test_caching_with_batch_multiple_outputs(self):
def test_caching_with_batch_multiple_outputs(self, patched_cache_folder):
def trim_words(words, lens):
trimmed_words = [word[:length] for word, length in zip(words, lens)]
return trimmed_words, lens
@ -440,7 +470,7 @@ class TestProcessExamples:
prediction = io.examples_handler.load_from_cache(0)
assert prediction == ["hel", "3"]
def test_caching_with_non_io_component(self):
def test_caching_with_non_io_component(self, patched_cache_folder):
def predict(name):
return name, gr.update(visible=True)
@ -460,7 +490,7 @@ class TestProcessExamples:
prediction = examples.load_from_cache(0)
assert prediction == ["John", {"visible": True, "__type__": "update"}]
def test_end_to_end(self):
def test_end_to_end(self, patched_cache_folder):
def concatenate(str1, str2):
return str1 + str2
@ -518,7 +548,7 @@ class TestProcessExamples:
}
]
def test_end_to_end_cache_examples(self):
def test_end_to_end_cache_examples(self, patched_cache_folder):
def concatenate(str1, str2):
return f"{str1} {str2}"
@ -547,7 +577,7 @@ class TestProcessExamples:
def test_multiple_file_flagging(tmp_path):
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
io = gr.Interface(
fn=lambda *x: list(x),
inputs=[
@ -565,7 +595,7 @@ def test_multiple_file_flagging(tmp_path):
def test_examples_keep_all_suffixes(tmp_path):
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
with patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())):
file_1 = tmp_path / "foo.bar.txt"
file_1.write_text("file 1")
file_2 = tmp_path / "file_2"