Exposing examples as a component for Blocks (#1639)

* examples as component

* renamed examples

* simplify internal logic

* fix tests

* cleanup

* fixed parallel and series

* cleaning up examples

* examples

* formatting

* fixes

* added unique ids

* added demo

* formatting

* fixed test_examples

* fixed test_interfaces

* fixed tests

* removed test for now

* raise ValueError for bad parameter values

* fixing series

* fixed series

* formatting

* speed up by preprocessing examples

* fixed parameter validation logic
Abubakar Abid authored on 2022-07-06 11:23:35 -07:00; committed by GitHub
parent 44e9a4d054
commit a1c391668a
11 changed files with 327 additions and 278 deletions
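For orientation, here is a minimal sketch of the new API this PR exposes, adapted from the demo/blocks_inputs demo updated below (not part of the diff; the variable names and example strings mirror that demo):

import gradio as gr

def combine(a, b):
    return a + " " + b

with gr.Blocks() as demo:
    txt = gr.Textbox(label="Input")
    txt_2 = gr.Textbox(label="Input 2")
    txt_3 = gr.Textbox(label="Output")
    btn = gr.Button("Submit")
    btn.click(combine, inputs=[txt, txt_2], outputs=txt_3)

    # gr.Examples can now be placed anywhere in a Blocks layout; clicking a row
    # populates the input components it is bound to.
    gr.Examples(
        examples=[["hi", "Adam"], ["hello", "Eve"]],
        inputs=[txt, txt_2],
    )

demo.launch()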

BIN demo/blocks_inputs/lion.jpg (new binary file, 18 KiB)

demo/blocks_inputs/run.py

@ -1,19 +1,36 @@
import gradio as gr
import os
str = """Hello friends
hello friends
Hello friends
"""
def combine(a, b):
return a + " " + b
def mirror(x):
return x
with gr.Blocks() as demo:
txt = gr.Textbox(label="Input", lines=5)
txt_2 = gr.Textbox(label="Output-Interactive")
txt_3 = gr.Textbox(str, label="Output", interactive=False)
txt = gr.Textbox(label="Input", lines=2)
txt_2 = gr.Textbox(label="Input 2")
txt_3 = gr.Textbox("", label="Output")
btn = gr.Button("Submit")
btn.click(lambda a: a, inputs=[txt], outputs=[txt_2])
btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
with gr.Row():
im = gr.Image()
im_2 = gr.Image()
btn = gr.Button("Mirror Image")
btn.click(mirror, inputs=[im], outputs=[im_2])
gr.Markdown("## Text Examples")
gr.Examples([["hi", "Adam"], ["hello", "Eve"]], [txt, txt_2], txt_3, combine, cache_examples=True)
gr.Markdown("## Image Examples")
gr.Examples(
examples=[os.path.join(os.path.dirname(__file__), "lion.jpg")],
inputs=im,
outputs=im_2,
fn=mirror,
cache_examples=True)
if __name__ == "__main__":
demo.launch()

gradio/__init__.py

@ -42,6 +42,7 @@ from gradio.components import (
Video,
component,
)
from gradio.examples import Examples
from gradio.flagging import (
CSVLogger,
FlaggingCallback,

gradio/examples.py (new file, 200 lines)

@ -0,0 +1,200 @@
"""
Defines helper methods useful for loading and caching Interface examples.
"""
from __future__ import annotations
import csv
import os
import shutil
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from gradio.components import Dataset
from gradio.flagging import CSVLogger
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio import Interface
from gradio.components import Component
CACHED_FOLDER = "gradio_cached_examples"
class Examples:
def __init__(
self,
examples: List[Any] | List[List[Any]] | str,
inputs: Component | List[Component],
outputs: Optional[Component | List[Component]] = None,
fn: Optional[Callable] = None,
cache_examples: bool = False,
examples_per_page: int = 10,
):
"""
This class is a wrapper over the Dataset component and can be used to create Examples
for Blocks / Interfaces. It populates the Dataset component with examples and
assigns an event listener so that clicking on an example populates the input/output
components. Optionally handles example caching for fast inference.
Parameters:
examples (List[Any] | List[List[Any]] | str): example inputs that can be clicked to populate specific components. Should be a nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided.
inputs: (Component | List[Component]): the component or list of components corresponding to the examples
outputs: (Component | List[Component] | None): optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache_examples` is True.
fn: (Callable | None): optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache_examples` is True.
cache_examples (bool): if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
examples_per_page (int): how many examples to show per page (this parameter currently has no effect)
"""
if cache_examples and (fn is None or outputs is None):
raise ValueError("If caching examples, `fn` and `outputs` must be provided")
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(outputs, list):
outputs = [outputs]
if examples is None:
raise ValueError("The parameter `examples` cannot be None")
elif isinstance(examples, list) and (
len(examples) == 0 or isinstance(examples[0], list)
):
pass
elif (
isinstance(examples, list) and len(inputs) == 1
): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
examples = [[e] for e in examples]
elif isinstance(examples, str):
if not os.path.exists(examples):
raise FileNotFoundError(
"Could not find examples directory: " + examples
)
log_file = os.path.join(examples, "log.csv")
if not os.path.exists(log_file):
if len(inputs) == 1:
exampleset = [
[os.path.join(examples, item)] for item in os.listdir(examples)
]
else:
raise FileNotFoundError(
"Could not find log file (required for multiple inputs): "
+ log_file
)
else:
with open(log_file) as logs:
exampleset = list(csv.reader(logs))
exampleset = exampleset[1:] # remove header
for i, example in enumerate(exampleset):
for j, (component, cell) in enumerate(
zip(
inputs + outputs,
example,
)
):
exampleset[i][j] = component.restore_flagged(
examples,
cell,
None,
)
examples = exampleset
else:
raise ValueError(
"The parameter `examples` must either be a directory or a nested "
"list, where each sublist represents a set of inputs."
)
dataset = Dataset(
components=inputs,
samples=examples,
type="index",
)
self.examples = examples
self.inputs = inputs
self.outputs = outputs
self.fn = fn
self.cache_examples = cache_examples
self.examples_per_page = examples_per_page
self.processed_examples = [
[
component.preprocess_example(sample)
for component, sample in zip(inputs, example)
]
for example in examples
]
self.cached_folder = os.path.join(CACHED_FOLDER, str(dataset._id))
self.cached_file = os.path.join(self.cached_folder, "log.csv")
if cache_examples:
self.cache_interface_examples()
def load_example(example_id):
processed_example = self.processed_examples[example_id]
if cache_examples:
processed_example += self.load_from_cache(example_id)
if len(processed_example) == 1:
return processed_example[0]
else:
return processed_example
dataset.click(
load_example,
inputs=[dataset],
outputs=inputs + (outputs if cache_examples else []),
_postprocess=False,
queue=False,
)
def cache_interface_examples(self) -> None:
"""Caches all of the examples from an interface."""
if os.path.exists(self.cached_file):
print(
f"Using cache from '{os.path.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
)
else:
print(f"Caching examples at: '{os.path.abspath(self.cached_file)}'")
cache_logger = CSVLogger()
cache_logger.setup(self.outputs, self.cached_folder)
for example_id, _ in enumerate(self.examples):
try:
prediction = self.process_example(example_id)
cache_logger.flag(prediction)
except Exception as e:
shutil.rmtree(self.cached_folder)
raise e
def process_example(self, example_id: int) -> Tuple[List[Any], List[float]]:
"""Loads an example from the interface and returns its prediction."""
example_set = self.examples[example_id]
raw_input = [
self.inputs[i].preprocess_example(example)
for i, example in enumerate(example_set)
]
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.inputs)
]
predictions = self.fn(*processed_input)
if len(self.outputs) == 1:
predictions = [predictions]
processed_output = [
output_component.postprocess(predictions[i])
if predictions[i] is not None
else None
for i, output_component in enumerate(self.outputs)
]
return processed_output
def load_from_cache(self, example_id: int) -> List[Any]:
"""Loads a particular cached example for the interface."""
with open(self.cached_file) as cache:
examples = list(csv.reader(cache, quotechar="'"))
example = examples[example_id + 1] # +1 to adjust for header
output = []
for component, cell in zip(self.outputs, example):
output.append(
component.restore_flagged(
self.cached_folder,
cell,
None,
)
)
return output
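Usage note (a sketch based on the validation and caching logic above, not part of the commit): with `cache_examples=True` both `fn` and `outputs` must be supplied; predictions are computed once into `gradio_cached_examples/<dataset id>/log.csv`, and clicking an example then fills the output components from that cache. The `reverse` function below is a made-up placeholder:

import gradio as gr

def reverse(text):
    return text[::-1]  # placeholder function for illustration

with gr.Blocks() as demo:
    inp = gr.Textbox(label="In")
    out = gr.Textbox(label="Out")

    # Valid: fn and outputs are provided, so outputs are cached up front.
    gr.Examples(
        examples=[["hello"], ["world"]],
        inputs=inp,
        outputs=out,
        fn=reverse,
        cache_examples=True,
    )

    # Invalid: raises ValueError("If caching examples, `fn` and `outputs` must be provided")
    # gr.Examples([["hello"]], inputs=inp, cache_examples=True)

demo.launch()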

gradio/flagging.py

@ -125,11 +125,11 @@ class CSVLogger(FlaggingCallback):
if flag_index is None:
csv_data = []
for component, sample in zip(self.components, flag_data):
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
csv_data.append(
component.save_flagged(
flagging_dir,
component.label,
component.label or f"component {idx}",
sample,
self.encryption_key,
)
@ -140,7 +140,10 @@ class CSVLogger(FlaggingCallback):
csv_data.append(username if username is not None else "")
csv_data.append(str(datetime.datetime.now()))
if is_new:
headers = [component.label for component in self.components] + [
headers = [
component.label or f"component {idx}"
for idx, component in enumerate(self.components)
] + [
"flag",
"username",
"timestamp",

gradio/interface.py

@ -25,7 +25,6 @@ from gradio.blocks import Blocks
from gradio.components import (
Button,
Component,
Dataset,
Interpretation,
IOComponent,
Markdown,
@ -34,10 +33,10 @@ from gradio.components import (
get_component_instance,
)
from gradio.events import Changeable, Streamable
from gradio.examples import Examples
from gradio.external import load_from_pipeline # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.layouts import Column, Row, TabItem, Tabs
from gradio.process_examples import cache_interface_examples, load_from_cache
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import transformers
@ -126,7 +125,6 @@ class Interface(Blocks):
flagging_dir: str = "flagged",
flagging_callback: FlaggingCallback = CSVLogger(),
analytics_enabled: Optional[bool] = None,
_repeat_outputs_per_model: bool = True,
**kwargs,
):
"""
@ -174,14 +172,13 @@ class Interface(Blocks):
inputs = []
self.interface_type = self.InterfaceTypes.OUTPUT_ONLY
if not isinstance(fn, list):
fn = [fn]
else:
if isinstance(fn, list):
raise DeprecationWarning(
"The `fn` parameter only accepts a single function, support for a list "
"of functions has been deprecated. Please use gradio.mix.Parallel "
"instead."
)
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(outputs, list):
@ -199,7 +196,7 @@ class Interface(Blocks):
raise ValueError(
"If using 'state', there must be exactly one state input and one state output."
)
default = utils.get_default_args(fn[0])[inputs.index("state")]
default = utils.get_default_args(fn)[inputs.index("state")]
state_variable = Variable(value=default)
inputs[inputs.index("state")] = state_variable
outputs[outputs.index("state")] = state_variable
@ -240,9 +237,6 @@ class Interface(Blocks):
for o in self.output_components:
o.interactive = False # Force output components to be non-interactive
if _repeat_outputs_per_model:
self.output_components *= len(fn)
if (
interpretation is None
or isinstance(interpretation, list)
@ -257,10 +251,9 @@ class Interface(Blocks):
raise ValueError("Invalid value for parameter: interpretation")
self.api_mode = False
self.predict = fn
self.predict_durations = [[0, 0]] * len(fn)
self.function_names = [func.__name__ for func in fn]
self.__name__ = ", ".join(self.function_names)
self.fn = fn
self.fn_durations = [0, 0]
self.__name__ = fn.__name__
self.live = live
self.title = title
@ -295,53 +288,7 @@ class Interface(Blocks):
if not (self.theme == "default"):
warnings.warn("Currently, only the 'default' theme is supported.")
if examples is None or (
isinstance(examples, list)
and (len(examples) == 0 or isinstance(examples[0], list))
):
self.examples = examples
elif (
isinstance(examples, list) and len(self.input_components) == 1
): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
self.examples = [[e] for e in examples]
elif isinstance(examples, str):
if not os.path.exists(examples):
raise FileNotFoundError(
"Could not find examples directory: " + examples
)
log_file = os.path.join(examples, "log.csv")
if not os.path.exists(log_file):
if len(self.input_components) == 1:
exampleset = [
[os.path.join(examples, item)] for item in os.listdir(examples)
]
else:
raise FileNotFoundError(
"Could not find log file (required for multiple inputs): "
+ log_file
)
else:
with open(log_file) as logs:
exampleset = list(csv.reader(logs))
exampleset = exampleset[1:] # remove header
for i, example in enumerate(exampleset):
for j, (component, cell) in enumerate(
zip(
self.input_components + self.output_components,
example,
)
):
exampleset[i][j] = component.restore_flagged(
examples,
cell,
None,
)
self.examples = exampleset
else:
raise ValueError(
"Examples argument must either be a directory or a nested "
"list, where each sublist represents a set of inputs."
)
self.examples = examples
self.num_shap = num_shap
self.examples_per_page = examples_per_page
@ -415,7 +362,7 @@ class Interface(Blocks):
utils.version_check()
Interface.instances.add(self)
param_names = inspect.getfullargspec(self.predict[0])[0]
param_names = inspect.getfullargspec(self.fn)[0]
for component, param_name in zip(self.input_components, param_names):
if component.label is None:
component.label = param_name
@ -426,9 +373,6 @@ class Interface(Blocks):
else:
component.label = "output " + str(i)
if self.cache_examples and examples:
cache_interface_examples(self)
if self.allow_flagging != "never":
if self.interface_type == self.InterfaceTypes.UNIFIED:
self.flagging_callback.setup(self.input_components, self.flagging_dir)
@ -625,34 +569,16 @@ class Interface(Blocks):
non_state_inputs = [
c for c in self.input_components if not isinstance(c, Variable)
]
examples = Dataset(
components=non_state_inputs,
samples=self.examples,
type="index",
)
def load_example(example_id):
processed_examples = [
component.preprocess_example(sample)
for component, sample in zip(
self.input_components, self.examples[example_id]
)
]
if self.cache_examples:
processed_examples += load_from_cache(self, example_id)
if len(processed_examples) == 1:
return processed_examples[0]
else:
return processed_examples
examples.click(
load_example,
inputs=[examples],
outputs=non_state_inputs
+ (self.output_components if self.cache_examples else []),
_postprocess=False,
queue=False,
non_state_outputs = [
c for c in self.output_components if not isinstance(c, Variable)
]
self.examples_handler = Examples(
examples=examples,
inputs=non_state_inputs,
outputs=non_state_outputs,
fn=self.fn,
cache_examples=self.cache_examples,
examples_per_page=examples_per_page,
)
if self.interpretation:
@ -684,9 +610,7 @@ class Interface(Blocks):
return self.__repr__()
def __repr__(self):
repr = "Gradio Interface for: {}".format(
", ".join(fn.__name__ for fn in self.predict)
)
repr = f"Gradio Interface for: {self.__name__}"
repr += "\n" + "-" * len(repr)
repr += "\ninputs:"
for component in self.input_components:
@ -715,31 +639,19 @@ class Interface(Blocks):
input_component.serialize(processed_input[i], called_directly)
for i, input_component in enumerate(self.input_components)
]
predictions = []
output_component_counter = 0
for predict_fn in self.predict:
prediction = predict_fn(*processed_input)
prediction = self.fn(*processed_input)
if len(self.output_components) == len(self.predict) or prediction is None:
prediction = [prediction]
if prediction is None or len(self.output_components) == 1:
prediction = [prediction]
if self.api_mode: # Serialize the input
prediction_ = copy.deepcopy(prediction)
prediction = []
if self.api_mode:  # Deserialize the output
prediction = [
output_component.deserialize(prediction[i])
for i, output_component in enumerate(self.output_components)
]
# Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
for pred in prediction_:
prediction.append(
self.output_components[output_component_counter].deserialize(
pred
)
)
output_component_counter += 1
predictions.extend(prediction)
return predictions
return prediction
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
"""
@ -777,19 +689,17 @@ class Interface(Blocks):
Passes a few samples through the function to test if the inputs/outputs
components are consistent with the function parameter and return values.
"""
for predict_fn in self.predict:
print("Test launch: {}()...".format(predict_fn.__name__), end=" ")
raw_input = []
for input_component in self.input_components:
if input_component.test_input is None:
print("SKIPPED")
break
else:
raw_input.append(input_component.test_input)
print("Test launch: {}()...".format(self.__name__), end=" ")
raw_input = []
for input_component in self.input_components:
if input_component.test_input is None:
print("SKIPPED")
break
else:
self.process(raw_input)
print("PASSED")
continue
raw_input.append(input_component.test_input)
else:
self.process(raw_input)
print("PASSED")
def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None:
"""

gradio/mix.py

@ -1,10 +1,13 @@
"""
Ways to transform interfaces to produce new interfaces
"""
import warnings
from typing import TYPE_CHECKING, List
import gradio
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.components import IOComponent
class Parallel(gradio.Interface):
"""
@ -12,7 +15,7 @@ class Parallel(gradio.Interface):
The Interfaces to put in Parallel must share the same input components (but can have different output components).
"""
def __init__(self, *interfaces, **options):
def __init__(self, *interfaces: gradio.Interface, **options):
"""
Parameters:
*interfaces (Interface): any number of Interface objects that are to be compared in parallel
@ -20,38 +23,29 @@ class Parallel(gradio.Interface):
Returns:
(Interface): an Interface object comparing the given models
"""
fns = []
outputs = []
outputs: List[IOComponent] = []
for io in interfaces:
if not (isinstance(io, gradio.Interface)):
warnings.warn(
"Parallel may not work properly with non-Interface objects."
)
fns.extend(io.predict)
outputs.extend(io.output_components)
for interface in interfaces:
outputs.extend(interface.output_components)
def parallel_fn(*args):
return_values = []
for fn in fns:
value = fn(*args)
if isinstance(value, tuple):
return_values.extend(value)
else:
return_values.append(value)
for interface in interfaces:
value = interface.run_prediction(args)
return_values.extend(value)
if len(outputs) == 1:
return return_values[0]
return return_values
parallel_fn.__name__ = " | ".join([io.__name__ for io in interfaces])
kwargs = {
"fn": parallel_fn,
"inputs": interfaces[0].input_components,
"outputs": outputs,
"_repeat_outputs_per_model": False,
}
kwargs.update(options)
super().__init__(**kwargs)
self.api_mode = interfaces[
0
].api_mode # TODO(abidlabs): make api_mode a per-function attribute
class Series(gradio.Interface):
@ -60,7 +54,7 @@ class Series(gradio.Interface):
and so the input and output components must agree between the interfaces).
"""
def __init__(self, *interfaces, **options):
def __init__(self, *interfaces: gradio.Interface, **options):
"""
Parameters:
*interfaces (Interface): any number of Interface objects that are to be connected in series
@ -68,41 +62,35 @@ class Series(gradio.Interface):
Returns:
(Interface): an Interface object connecting the given models
"""
fns = []
for io in interfaces:
if not (isinstance(io, gradio.Interface)):
warnings.warn(
"Series may not work properly with non-Interface objects."
)
fns.append(io.predict)
def connected_fn(
*data,
): # Run each function with the appropriate preprocessing and postprocessing
for idx, io in enumerate(interfaces):
def connected_fn(*data):
for idx, interface in enumerate(interfaces):
# skip preprocessing for first interface since the Series interface will include it
if idx > 0 and not (io.api_mode):
if idx > 0 and not (interface.api_mode):
data = [
input_component.preprocess(data[i])
for i, input_component in enumerate(io.input_components)
for i, input_component in enumerate(interface.input_components)
]
# run all of predictions sequentially
predictions = []
for predict_fn in io.predict:
prediction = predict_fn(*data)
predictions.append(prediction)
data = predictions
data = interface.fn(*data)
if len(interface.output_components) == 1:
data = [data]
# skip postprocessing for final interface since the Series interface will include it
if idx < len(interfaces) - 1 and not (io.api_mode):
if idx < len(interfaces) - 1 and not (interface.api_mode):
data = [
output_component.postprocess(data[i])
for i, output_component in enumerate(io.output_components)
for i, output_component in enumerate(
interface.output_components
)
]
return data[0]
if len(interface.output_components) == 1:
return data[0]
return data
connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns])
connected_fn.__name__ = " => ".join([io.__name__ for io in interfaces])
kwargs = {
"fn": connected_fn,
@ -111,6 +99,4 @@ class Series(gradio.Interface):
}
kwargs.update(options)
super().__init__(**kwargs)
self.api_mode = interfaces[
0
].api_mode # TODO(abidlabs): make api_mode a per-function attribute
self.api_mode = interfaces[0].api_mode # TODO: set api_mode per-function
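Because each Interface now carries a single `fn`, Parallel and Series compose whole interfaces via `run_prediction` / `interface.fn` instead of flattening lists of prediction functions. A short usage sketch (the lambdas are illustrative, not from this commit):

import gradio as gr
from gradio.mix import Parallel, Series

upper = gr.Interface(lambda s: s.upper(), "textbox", "textbox")
exclaim = gr.Interface(lambda s: s + "!", "textbox", "textbox")

# Series feeds the output of `upper` into `exclaim`.
chained = Series(upper, exclaim)

# Parallel runs both interfaces on the same input and shows each output.
compared = Parallel(upper, exclaim)

chained.launch()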

gradio/process_examples.py (file deleted)

@ -1,68 +0,0 @@
"""
Defines helper methods useful for loading and caching Interface examples.
"""
from __future__ import annotations
import csv
import os
import shutil
from typing import TYPE_CHECKING, Any, List, Tuple
from gradio.flagging import CSVLogger
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio import Interface
CACHED_FOLDER = "gradio_cached_examples"
CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv")
def process_example(
interface: Interface, example_id: int
) -> Tuple[List[Any], List[float]]:
"""Loads an example from the interface and returns its prediction."""
example_set = interface.examples[example_id]
raw_input = [
interface.input_components[i].preprocess_example(example)
for i, example in enumerate(example_set)
]
prediction = interface.process(raw_input)
return prediction
def cache_interface_examples(interface: Interface) -> None:
"""Caches all of the examples from an interface."""
if os.path.exists(CACHE_FILE):
print(
f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache."
)
else:
print(
f"Cache at {os.path.abspath(CACHE_FILE)} not found. Caching now in '{CACHED_FOLDER}/' directory."
)
cache_logger = CSVLogger()
cache_logger.setup(interface.output_components, CACHED_FOLDER)
for example_id, _ in enumerate(interface.examples):
try:
prediction = process_example(interface, example_id)
cache_logger.flag(prediction)
except Exception as e:
shutil.rmtree(CACHED_FOLDER)
raise e
def load_from_cache(interface: Interface, example_id: int) -> List[Any]:
"""Loads a particular cached example for the interface."""
with open(CACHE_FILE) as cache:
examples = list(csv.reader(cache, quotechar="'"))
example = examples[example_id + 1] # +1 to adjust for header
output = []
for component, cell in zip(interface.output_components, example):
output.append(
component.restore_flagged(
CACHED_FOLDER,
cell,
interface.encryption_key if interface.encrypt else None,
)
)
return output

test/test_examples.py

@ -1,7 +1,7 @@
import os
import unittest
from gradio import Interface, process_examples
from gradio import Interface, examples
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -9,7 +9,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestProcessExamples(unittest.TestCase):
def test_process_example(self):
io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
prediction = process_examples.process_example(io, 0)
prediction = io.examples_handler.process_example(0)
self.assertEquals(prediction[0], "Hello World")
def test_caching(self):
@ -20,8 +20,8 @@ class TestProcessExamples(unittest.TestCase):
examples=[["World"], ["Dunya"], ["Monde"]],
)
io.launch(prevent_thread_lock=True)
process_examples.cache_interface_examples(io)
prediction = process_examples.load_from_cache(io, 1)
io.examples_handler.cache_interface_examples()
prediction = io.examples_handler.load_from_cache(1)
io.close()
self.assertEquals(prediction[0], "Hello Dunya")

test/test_external.py

@ -27,7 +27,7 @@ class TestLoadInterface(unittest.TestCase):
src="models",
alias=model_type,
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Audio)
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
@ -36,14 +36,14 @@ class TestLoadInterface(unittest.TestCase):
interface = gr.Blocks.load(
name="lysandre/tiny-vit-random", src="models", alias=model_type
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Image)
self.assertIsInstance(interface.output_components[0], gr.components.Label)
def test_text_generation(self):
model_type = "text_generation"
interface = gr.Interface.load("models/gpt2", alias=model_type)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
@ -52,7 +52,7 @@ class TestLoadInterface(unittest.TestCase):
interface = gr.Interface.load(
"models/facebook/bart-large-cnn", api_key=None, alias=model_type
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
@ -61,7 +61,7 @@ class TestLoadInterface(unittest.TestCase):
interface = gr.Interface.load(
"models/facebook/bart-large-cnn", api_key=None, alias=model_type
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
@ -70,7 +70,7 @@ class TestLoadInterface(unittest.TestCase):
interface = gr.Interface.load(
"models/sshleifer/tiny-mbart", api_key=None, alias=model_type
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
@ -81,7 +81,7 @@ class TestLoadInterface(unittest.TestCase):
api_key=None,
alias=model_type,
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Label)
@ -90,7 +90,7 @@ class TestLoadInterface(unittest.TestCase):
interface = gr.Interface.load(
"models/bert-base-uncased", api_key=None, alias=model_type
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Label)
@ -99,7 +99,7 @@ class TestLoadInterface(unittest.TestCase):
interface = gr.Interface.load(
"models/facebook/bart-large-mnli", api_key=None, alias=model_type
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.input_components[1], gr.components.Textbox)
self.assertIsInstance(interface.input_components[2], gr.components.Checkbox)
@ -110,7 +110,7 @@ class TestLoadInterface(unittest.TestCase):
interface = gr.Interface.load(
"models/facebook/wav2vec2-base-960h", api_key=None, alias=model_type
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Audio)
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
@ -119,7 +119,7 @@ class TestLoadInterface(unittest.TestCase):
interface = gr.Interface.load(
"models/google/vit-base-patch16-224", api_key=None, alias=model_type
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Image)
self.assertIsInstance(interface.output_components[0], gr.components.Label)
@ -130,7 +130,7 @@ class TestLoadInterface(unittest.TestCase):
api_key=None,
alias=model_type,
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Dataframe)
@ -141,7 +141,7 @@ class TestLoadInterface(unittest.TestCase):
api_key=None,
alias=model_type,
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
@ -152,7 +152,7 @@ class TestLoadInterface(unittest.TestCase):
api_key=None,
alias=model_type,
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
@ -161,7 +161,7 @@ class TestLoadInterface(unittest.TestCase):
interface = gr.Interface.load(
"models/osanseviero/BigGAN-deep-128", api_key=None, alias=model_type
)
self.assertEqual(interface.predict[0].__name__, model_type)
self.assertEqual(interface.__name__, model_type)
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
self.assertIsInstance(interface.output_components[0], gr.components.Image)

test/test_interfaces.py

@ -120,8 +120,8 @@ class TestInterface(unittest.TestCase):
examples = ["test1", "test2"]
interface = Interface(lambda x: x, "textbox", "label", examples=examples)
interface.launch(prevent_thread_lock=True)
self.assertEqual(len(interface.examples), 2)
self.assertEqual(len(interface.examples[0]), 1)
self.assertEqual(len(interface.examples_handler.examples), 2)
self.assertEqual(len(interface.examples_handler.examples[0]), 1)
interface.close()
@mock.patch("IPython.display.display")