Fix bug where you cannot cache examples with Interface.load (#1949)

* Possible fix

* Remove breakpoint

* Implementation

* Add unit test

* Fix test

* Lint

* Add _api_mode to interface signature
This commit is contained in:
Freddy Boulton 2022-08-04 20:20:44 -04:00 committed by GitHub
parent 2903b74160
commit 299ba1bd1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 47 additions and 5 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@ -0,0 +1,29 @@
import gradio as gr
images = ["cheetah1.jpeg", "cheetah1.jpg", "lion.jpg"]
img_classifier = gr.Interface.load(
"models/google/vit-base-patch16-224", examples=images, cache_examples=True
)
def func(img, text):
return img_classifier(img), text
using_img_classifier_as_function = gr.Interface(
func,
[gr.Image(type="filepath"), "text"],
["label", "text"],
examples=[
["cheetah1.jpeg", None],
["cheetah1.jpg", "cheetah"],
["lion.jpg", "lion"],
],
cache_examples=True,
)
demo = gr.TabbedInterface([using_img_classifier_as_function, img_classifier])
if __name__ == "__main__":
demo.launch()

View File

@ -291,8 +291,8 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
}
kwargs = dict(interface_info, **kwargs)
kwargs["_api_mode"] = True # So interface doesn't run pre/postprocess.
interface = gradio.Interface(**kwargs)
interface.api_mode = True # So interface doesn't run pre/postprocess.
return interface
@ -414,8 +414,8 @@ def get_spaces_interface(model_name, config, alias, **kwargs):
config["fn"] = fn
kwargs = dict(config, **kwargs)
kwargs["_api_mode"] = True
interface = gradio.Interface(**kwargs)
interface.api_mode = True # So interface doesn't run pre/postprocess.
return interface

View File

@ -149,6 +149,7 @@ class Interface(Blocks):
flagging_dir: str = "flagged",
flagging_callback: FlaggingCallback = CSVLogger(),
analytics_enabled: Optional[bool] = None,
_api_mode: bool = False,
**kwargs,
):
"""
@ -274,7 +275,7 @@ class Interface(Blocks):
else:
raise ValueError("Invalid value for parameter: interpretation")
self.api_mode = False
self.api_mode = _api_mode
self.fn = fn
self.fn_durations = [0, 0]
self.__name__ = fn.__name__
@ -600,7 +601,7 @@ class Interface(Blocks):
examples=examples,
inputs=non_state_inputs,
outputs=non_state_outputs,
fn=self.fn,
fn=submit_fn,
cache_examples=self.cache_examples,
examples_per_page=examples_per_page,
)
@ -668,7 +669,7 @@ class Interface(Blocks):
if prediction is None or len(self.output_components) == 1:
prediction = [prediction]
if self.api_mode: # Deerialize the input
if self.api_mode: # Deserialize the input
prediction = [
output_component.deserialize(prediction[i])
for i, output_component in enumerate(self.output_components)

View File

@ -1,5 +1,7 @@
import os
import pathlib
import unittest
from unittest.mock import patch
import pytest
import transformers
@ -229,5 +231,15 @@ class TestLoadFromPipeline(unittest.TestCase):
self.assertIsNotNone(output)
def test_interface_load_cache_examples(tmp_path):
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
with patch("gradio.examples.CACHED_FOLDER", tmp_path):
gr.Interface.load(
name="models/google/vit-base-patch16-224",
examples=[pathlib.Path(test_file_dir, "cheetah1.jpg")],
cache_examples=True,
)
if __name__ == "__main__":
unittest.main()