mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Fix bug where you cannot cache examples with Interface.load (#1949)
* Possible fix * Remove breakpoint * Implementation * Add unit test * Fix test * Lint * Add _api_mode to interface signature
This commit is contained in:
parent
2903b74160
commit
299ba1bd1a
BIN
demo/image_classifier_interface_load/cheetah1.jpeg
Normal file
BIN
demo/image_classifier_interface_load/cheetah1.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
BIN
demo/image_classifier_interface_load/cheetah1.jpg
Normal file
BIN
demo/image_classifier_interface_load/cheetah1.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
BIN
demo/image_classifier_interface_load/lion.jpg
Normal file
BIN
demo/image_classifier_interface_load/lion.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 18 KiB |
29
demo/image_classifier_interface_load/run.py
Normal file
29
demo/image_classifier_interface_load/run.py
Normal file
@ -0,0 +1,29 @@
|
||||
import gradio as gr
|
||||
|
||||
images = ["cheetah1.jpeg", "cheetah1.jpg", "lion.jpg"]
|
||||
|
||||
|
||||
img_classifier = gr.Interface.load(
|
||||
"models/google/vit-base-patch16-224", examples=images, cache_examples=True
|
||||
)
|
||||
|
||||
|
||||
def func(img, text):
|
||||
return img_classifier(img), text
|
||||
|
||||
|
||||
using_img_classifier_as_function = gr.Interface(
|
||||
func,
|
||||
[gr.Image(type="filepath"), "text"],
|
||||
["label", "text"],
|
||||
examples=[
|
||||
["cheetah1.jpeg", None],
|
||||
["cheetah1.jpg", "cheetah"],
|
||||
["lion.jpg", "lion"],
|
||||
],
|
||||
cache_examples=True,
|
||||
)
|
||||
demo = gr.TabbedInterface([using_img_classifier_as_function, img_classifier])
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -291,8 +291,8 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
|
||||
}
|
||||
|
||||
kwargs = dict(interface_info, **kwargs)
|
||||
kwargs["_api_mode"] = True # So interface doesn't run pre/postprocess.
|
||||
interface = gradio.Interface(**kwargs)
|
||||
interface.api_mode = True # So interface doesn't run pre/postprocess.
|
||||
return interface
|
||||
|
||||
|
||||
@ -414,8 +414,8 @@ def get_spaces_interface(model_name, config, alias, **kwargs):
|
||||
config["fn"] = fn
|
||||
|
||||
kwargs = dict(config, **kwargs)
|
||||
kwargs["_api_mode"] = True
|
||||
interface = gradio.Interface(**kwargs)
|
||||
interface.api_mode = True # So interface doesn't run pre/postprocess.
|
||||
return interface
|
||||
|
||||
|
||||
|
@ -149,6 +149,7 @@ class Interface(Blocks):
|
||||
flagging_dir: str = "flagged",
|
||||
flagging_callback: FlaggingCallback = CSVLogger(),
|
||||
analytics_enabled: Optional[bool] = None,
|
||||
_api_mode: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -274,7 +275,7 @@ class Interface(Blocks):
|
||||
else:
|
||||
raise ValueError("Invalid value for parameter: interpretation")
|
||||
|
||||
self.api_mode = False
|
||||
self.api_mode = _api_mode
|
||||
self.fn = fn
|
||||
self.fn_durations = [0, 0]
|
||||
self.__name__ = fn.__name__
|
||||
@ -600,7 +601,7 @@ class Interface(Blocks):
|
||||
examples=examples,
|
||||
inputs=non_state_inputs,
|
||||
outputs=non_state_outputs,
|
||||
fn=self.fn,
|
||||
fn=submit_fn,
|
||||
cache_examples=self.cache_examples,
|
||||
examples_per_page=examples_per_page,
|
||||
)
|
||||
@ -668,7 +669,7 @@ class Interface(Blocks):
|
||||
if prediction is None or len(self.output_components) == 1:
|
||||
prediction = [prediction]
|
||||
|
||||
if self.api_mode: # Deerialize the input
|
||||
if self.api_mode: # Deserialize the input
|
||||
prediction = [
|
||||
output_component.deserialize(prediction[i])
|
||||
for i, output_component in enumerate(self.output_components)
|
||||
|
@ -1,5 +1,7 @@
|
||||
import os
|
||||
import pathlib
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import transformers
|
||||
@ -229,5 +231,15 @@ class TestLoadFromPipeline(unittest.TestCase):
|
||||
self.assertIsNotNone(output)
|
||||
|
||||
|
||||
def test_interface_load_cache_examples(tmp_path):
|
||||
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
with patch("gradio.examples.CACHED_FOLDER", tmp_path):
|
||||
gr.Interface.load(
|
||||
name="models/google/vit-base-patch16-224",
|
||||
examples=[pathlib.Path(test_file_dir, "cheetah1.jpg")],
|
||||
cache_examples=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user