mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Exposing examples as a component for Blocks (#1639)
* examples as component * renamed examples * simplify internal logic * fix tests * cleanup * fixed parallel and series * cleaning up examples * examples * formatting * fixes * added unique ids * added demo * formatting * fixed test_examples * fixed test_interfaces * fixed tests * removed test from now * raise ValueError for bad parameter values * fixing series * fixed series * formatting * speed up by preprocessing examples * fixed parameter validation logic
This commit is contained in:
parent
44e9a4d054
commit
a1c391668a
BIN
demo/blocks_inputs/lion.jpg
Normal file
BIN
demo/blocks_inputs/lion.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 18 KiB |
@ -1,19 +1,36 @@
|
||||
import gradio as gr
|
||||
import os
|
||||
|
||||
str = """Hello friends
|
||||
hello friends
|
||||
|
||||
Hello friends
|
||||
|
||||
"""
|
||||
def combine(a, b):
|
||||
return a + " " + b
|
||||
|
||||
def mirror(x):
|
||||
return x
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
txt = gr.Textbox(label="Input", lines=5)
|
||||
txt_2 = gr.Textbox(label="Output-Interactive")
|
||||
txt_3 = gr.Textbox(str, label="Output", interactive=False)
|
||||
|
||||
txt = gr.Textbox(label="Input", lines=2)
|
||||
txt_2 = gr.Textbox(label="Input 2")
|
||||
txt_3 = gr.Textbox("", label="Output")
|
||||
btn = gr.Button("Submit")
|
||||
btn.click(lambda a: a, inputs=[txt], outputs=[txt_2])
|
||||
btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
|
||||
|
||||
with gr.Row():
|
||||
im = gr.Image()
|
||||
im_2 = gr.Image()
|
||||
|
||||
btn = gr.Button("Mirror Image")
|
||||
btn.click(mirror, inputs=[im], outputs=[im_2])
|
||||
|
||||
gr.Markdown("## Text Examples")
|
||||
gr.Examples([["hi", "Adam"], ["hello", "Eve"]], [txt, txt_2], txt_3, combine, cache_examples=True)
|
||||
gr.Markdown("## Image Examples")
|
||||
gr.Examples(
|
||||
examples=[os.path.join(os.path.dirname(__file__), "lion.jpg")],
|
||||
inputs=im,
|
||||
outputs=im_2,
|
||||
fn=mirror,
|
||||
cache_examples=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
@ -42,6 +42,7 @@ from gradio.components import (
|
||||
Video,
|
||||
component,
|
||||
)
|
||||
from gradio.examples import Examples
|
||||
from gradio.flagging import (
|
||||
CSVLogger,
|
||||
FlaggingCallback,
|
||||
|
200
gradio/examples.py
Normal file
200
gradio/examples.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""
|
||||
Defines helper methods useful for loading and caching Interface examples.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
from gradio.components import Dataset
|
||||
from gradio.flagging import CSVLogger
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio import Interface
|
||||
from gradio.components import Component
|
||||
|
||||
CACHED_FOLDER = "gradio_cached_examples"
|
||||
|
||||
|
||||
class Examples:
|
||||
def __init__(
|
||||
self,
|
||||
examples: List[Any] | List[List[Any]] | str,
|
||||
inputs: Component | List[Component],
|
||||
outputs: Optional[Component | List[Component]] = None,
|
||||
fn: Optional[Callable] = None,
|
||||
cache_examples: bool = False,
|
||||
examples_per_page: int = 10,
|
||||
):
|
||||
"""
|
||||
This class is a wrapper over the Dataset component can be used to create Examples
|
||||
for Blocks / Interfaces. Populates the Dataset component with examples and
|
||||
assigns event listener so that clicking on an example populates the input/output
|
||||
components. Optionally handles example caching for fast inference.
|
||||
|
||||
Parameters:
|
||||
examples (List[Any] | List[List[Any]] | str): example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided.
|
||||
inputs: (Component | List[Component]): the component or list of components corresponding to the examples
|
||||
outputs: (Component | List[Component] | None): optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
|
||||
fn: (Callable | None): optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
|
||||
cache_examples (bool): if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
|
||||
examples_per_page (int): how many examples to show per page (this parameter currently has no effect)
|
||||
"""
|
||||
if cache_examples and (fn is None or outputs is None):
|
||||
raise ValueError("If caching examples, `fn` and `outputs` must be provided")
|
||||
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
if examples is None:
|
||||
raise ValueError("The parameter `examples` cannot be None")
|
||||
elif isinstance(examples, list) and (
|
||||
len(examples) == 0 or isinstance(examples[0], list)
|
||||
):
|
||||
pass
|
||||
elif (
|
||||
isinstance(examples, list) and len(inputs) == 1
|
||||
): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
|
||||
examples = [[e] for e in examples]
|
||||
elif isinstance(examples, str):
|
||||
if not os.path.exists(examples):
|
||||
raise FileNotFoundError(
|
||||
"Could not find examples directory: " + examples
|
||||
)
|
||||
log_file = os.path.join(examples, "log.csv")
|
||||
if not os.path.exists(log_file):
|
||||
if len(inputs) == 1:
|
||||
exampleset = [
|
||||
[os.path.join(examples, item)] for item in os.listdir(examples)
|
||||
]
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
"Could not find log file (required for multiple inputs): "
|
||||
+ log_file
|
||||
)
|
||||
else:
|
||||
with open(log_file) as logs:
|
||||
exampleset = list(csv.reader(logs))
|
||||
exampleset = exampleset[1:] # remove header
|
||||
for i, example in enumerate(exampleset):
|
||||
for j, (component, cell) in enumerate(
|
||||
zip(
|
||||
inputs + outputs,
|
||||
example,
|
||||
)
|
||||
):
|
||||
exampleset[i][j] = component.restore_flagged(
|
||||
examples,
|
||||
cell,
|
||||
None,
|
||||
)
|
||||
examples = exampleset
|
||||
else:
|
||||
raise ValueError(
|
||||
"The parameter `examples` must either be a directory or a nested "
|
||||
"list, where each sublist represents a set of inputs."
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
components=inputs,
|
||||
samples=examples,
|
||||
type="index",
|
||||
)
|
||||
|
||||
self.examples = examples
|
||||
self.inputs = inputs
|
||||
self.outputs = outputs
|
||||
self.fn = fn
|
||||
self.cache_examples = cache_examples
|
||||
self.examples_per_page = examples_per_page
|
||||
|
||||
self.processed_examples = [
|
||||
[
|
||||
component.preprocess_example(sample)
|
||||
for component, sample in zip(inputs, example)
|
||||
]
|
||||
for example in examples
|
||||
]
|
||||
|
||||
self.cached_folder = os.path.join(CACHED_FOLDER, str(dataset._id))
|
||||
self.cached_file = os.path.join(self.cached_folder, "log.csv")
|
||||
if cache_examples:
|
||||
self.cache_interface_examples()
|
||||
|
||||
def load_example(example_id):
|
||||
processed_example = self.processed_examples[example_id]
|
||||
if cache_examples:
|
||||
processed_example += self.load_from_cache(example_id)
|
||||
if len(processed_example) == 1:
|
||||
return processed_example[0]
|
||||
else:
|
||||
return processed_example
|
||||
|
||||
dataset.click(
|
||||
load_example,
|
||||
inputs=[dataset],
|
||||
outputs=inputs + (outputs if cache_examples else []),
|
||||
_postprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
def cache_interface_examples(self) -> None:
|
||||
"""Caches all of the examples from an interface."""
|
||||
if os.path.exists(self.cached_file):
|
||||
print(
|
||||
f"Using cache from '{os.path.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
|
||||
)
|
||||
else:
|
||||
print(f"Caching examples at: '{os.path.abspath(self.cached_file)}'")
|
||||
cache_logger = CSVLogger()
|
||||
cache_logger.setup(self.outputs, self.cached_folder)
|
||||
for example_id, _ in enumerate(self.examples):
|
||||
try:
|
||||
prediction = self.process_example(example_id)
|
||||
cache_logger.flag(prediction)
|
||||
except Exception as e:
|
||||
shutil.rmtree(self.cached_folder)
|
||||
raise e
|
||||
|
||||
def process_example(self, example_id: int) -> Tuple[List[Any], List[float]]:
|
||||
"""Loads an example from the interface and returns its prediction."""
|
||||
example_set = self.examples[example_id]
|
||||
raw_input = [
|
||||
self.inputs[i].preprocess_example(example)
|
||||
for i, example in enumerate(example_set)
|
||||
]
|
||||
processed_input = [
|
||||
input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.inputs)
|
||||
]
|
||||
predictions = self.fn(*processed_input)
|
||||
if len(self.outputs) == 1:
|
||||
predictions = [predictions]
|
||||
processed_output = [
|
||||
output_component.postprocess(predictions[i])
|
||||
if predictions[i] is not None
|
||||
else None
|
||||
for i, output_component in enumerate(self.outputs)
|
||||
]
|
||||
|
||||
return processed_output
|
||||
|
||||
def load_from_cache(self, example_id: int) -> List[Any]:
|
||||
"""Loads a particular cached example for the interface."""
|
||||
with open(self.cached_file) as cache:
|
||||
examples = list(csv.reader(cache, quotechar="'"))
|
||||
example = examples[example_id + 1] # +1 to adjust for header
|
||||
output = []
|
||||
for component, cell in zip(self.outputs, example):
|
||||
output.append(
|
||||
component.restore_flagged(
|
||||
self.cached_folder,
|
||||
cell,
|
||||
None,
|
||||
)
|
||||
)
|
||||
return output
|
@ -125,11 +125,11 @@ class CSVLogger(FlaggingCallback):
|
||||
|
||||
if flag_index is None:
|
||||
csv_data = []
|
||||
for component, sample in zip(self.components, flag_data):
|
||||
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
|
||||
csv_data.append(
|
||||
component.save_flagged(
|
||||
flagging_dir,
|
||||
component.label,
|
||||
component.label or f"component {idx}",
|
||||
sample,
|
||||
self.encryption_key,
|
||||
)
|
||||
@ -140,7 +140,10 @@ class CSVLogger(FlaggingCallback):
|
||||
csv_data.append(username if username is not None else "")
|
||||
csv_data.append(str(datetime.datetime.now()))
|
||||
if is_new:
|
||||
headers = [component.label for component in self.components] + [
|
||||
headers = [
|
||||
component.label or f"component {idx}"
|
||||
for idx, component in enumerate(self.components)
|
||||
] + [
|
||||
"flag",
|
||||
"username",
|
||||
"timestamp",
|
||||
|
@ -25,7 +25,6 @@ from gradio.blocks import Blocks
|
||||
from gradio.components import (
|
||||
Button,
|
||||
Component,
|
||||
Dataset,
|
||||
Interpretation,
|
||||
IOComponent,
|
||||
Markdown,
|
||||
@ -34,10 +33,10 @@ from gradio.components import (
|
||||
get_component_instance,
|
||||
)
|
||||
from gradio.events import Changeable, Streamable
|
||||
from gradio.examples import Examples
|
||||
from gradio.external import load_from_pipeline # type: ignore
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
|
||||
from gradio.layouts import Column, Row, TabItem, Tabs
|
||||
from gradio.process_examples import cache_interface_examples, load_from_cache
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import transformers
|
||||
@ -126,7 +125,6 @@ class Interface(Blocks):
|
||||
flagging_dir: str = "flagged",
|
||||
flagging_callback: FlaggingCallback = CSVLogger(),
|
||||
analytics_enabled: Optional[bool] = None,
|
||||
_repeat_outputs_per_model: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -174,14 +172,13 @@ class Interface(Blocks):
|
||||
inputs = []
|
||||
self.interface_type = self.InterfaceTypes.OUTPUT_ONLY
|
||||
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
else:
|
||||
if isinstance(fn, list):
|
||||
raise DeprecationWarning(
|
||||
"The `fn` parameter only accepts a single function, support for a list "
|
||||
"of functions has been deprecated. Please use gradio.mix.Parallel "
|
||||
"instead."
|
||||
)
|
||||
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
if not isinstance(outputs, list):
|
||||
@ -199,7 +196,7 @@ class Interface(Blocks):
|
||||
raise ValueError(
|
||||
"If using 'state', there must be exactly one state input and one state output."
|
||||
)
|
||||
default = utils.get_default_args(fn[0])[inputs.index("state")]
|
||||
default = utils.get_default_args(fn)[inputs.index("state")]
|
||||
state_variable = Variable(value=default)
|
||||
inputs[inputs.index("state")] = state_variable
|
||||
outputs[outputs.index("state")] = state_variable
|
||||
@ -240,9 +237,6 @@ class Interface(Blocks):
|
||||
for o in self.output_components:
|
||||
o.interactive = False # Force output components to be non-interactive
|
||||
|
||||
if _repeat_outputs_per_model:
|
||||
self.output_components *= len(fn)
|
||||
|
||||
if (
|
||||
interpretation is None
|
||||
or isinstance(interpretation, list)
|
||||
@ -257,10 +251,9 @@ class Interface(Blocks):
|
||||
raise ValueError("Invalid value for parameter: interpretation")
|
||||
|
||||
self.api_mode = False
|
||||
self.predict = fn
|
||||
self.predict_durations = [[0, 0]] * len(fn)
|
||||
self.function_names = [func.__name__ for func in fn]
|
||||
self.__name__ = ", ".join(self.function_names)
|
||||
self.fn = fn
|
||||
self.fn_durations = [0, 0]
|
||||
self.__name__ = fn.__name__
|
||||
self.live = live
|
||||
self.title = title
|
||||
|
||||
@ -295,53 +288,7 @@ class Interface(Blocks):
|
||||
if not (self.theme == "default"):
|
||||
warnings.warn("Currently, only the 'default' theme is supported.")
|
||||
|
||||
if examples is None or (
|
||||
isinstance(examples, list)
|
||||
and (len(examples) == 0 or isinstance(examples[0], list))
|
||||
):
|
||||
self.examples = examples
|
||||
elif (
|
||||
isinstance(examples, list) and len(self.input_components) == 1
|
||||
): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
|
||||
self.examples = [[e] for e in examples]
|
||||
elif isinstance(examples, str):
|
||||
if not os.path.exists(examples):
|
||||
raise FileNotFoundError(
|
||||
"Could not find examples directory: " + examples
|
||||
)
|
||||
log_file = os.path.join(examples, "log.csv")
|
||||
if not os.path.exists(log_file):
|
||||
if len(self.input_components) == 1:
|
||||
exampleset = [
|
||||
[os.path.join(examples, item)] for item in os.listdir(examples)
|
||||
]
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
"Could not find log file (required for multiple inputs): "
|
||||
+ log_file
|
||||
)
|
||||
else:
|
||||
with open(log_file) as logs:
|
||||
exampleset = list(csv.reader(logs))
|
||||
exampleset = exampleset[1:] # remove header
|
||||
for i, example in enumerate(exampleset):
|
||||
for j, (component, cell) in enumerate(
|
||||
zip(
|
||||
self.input_components + self.output_components,
|
||||
example,
|
||||
)
|
||||
):
|
||||
exampleset[i][j] = component.restore_flagged(
|
||||
examples,
|
||||
cell,
|
||||
None,
|
||||
)
|
||||
self.examples = exampleset
|
||||
else:
|
||||
raise ValueError(
|
||||
"Examples argument must either be a directory or a nested "
|
||||
"list, where each sublist represents a set of inputs."
|
||||
)
|
||||
self.examples = examples
|
||||
self.num_shap = num_shap
|
||||
self.examples_per_page = examples_per_page
|
||||
|
||||
@ -415,7 +362,7 @@ class Interface(Blocks):
|
||||
utils.version_check()
|
||||
Interface.instances.add(self)
|
||||
|
||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||
param_names = inspect.getfullargspec(self.fn)[0]
|
||||
for component, param_name in zip(self.input_components, param_names):
|
||||
if component.label is None:
|
||||
component.label = param_name
|
||||
@ -426,9 +373,6 @@ class Interface(Blocks):
|
||||
else:
|
||||
component.label = "output " + str(i)
|
||||
|
||||
if self.cache_examples and examples:
|
||||
cache_interface_examples(self)
|
||||
|
||||
if self.allow_flagging != "never":
|
||||
if self.interface_type == self.InterfaceTypes.UNIFIED:
|
||||
self.flagging_callback.setup(self.input_components, self.flagging_dir)
|
||||
@ -625,34 +569,16 @@ class Interface(Blocks):
|
||||
non_state_inputs = [
|
||||
c for c in self.input_components if not isinstance(c, Variable)
|
||||
]
|
||||
|
||||
examples = Dataset(
|
||||
components=non_state_inputs,
|
||||
samples=self.examples,
|
||||
type="index",
|
||||
)
|
||||
|
||||
def load_example(example_id):
|
||||
processed_examples = [
|
||||
component.preprocess_example(sample)
|
||||
for component, sample in zip(
|
||||
self.input_components, self.examples[example_id]
|
||||
)
|
||||
]
|
||||
if self.cache_examples:
|
||||
processed_examples += load_from_cache(self, example_id)
|
||||
if len(processed_examples) == 1:
|
||||
return processed_examples[0]
|
||||
else:
|
||||
return processed_examples
|
||||
|
||||
examples.click(
|
||||
load_example,
|
||||
inputs=[examples],
|
||||
outputs=non_state_inputs
|
||||
+ (self.output_components if self.cache_examples else []),
|
||||
_postprocess=False,
|
||||
queue=False,
|
||||
non_state_outputs = [
|
||||
c for c in self.output_components if not isinstance(c, Variable)
|
||||
]
|
||||
self.examples_handler = Examples(
|
||||
examples=examples,
|
||||
inputs=non_state_inputs,
|
||||
outputs=non_state_outputs,
|
||||
fn=self.fn,
|
||||
cache_examples=self.cache_examples,
|
||||
examples_per_page=examples_per_page,
|
||||
)
|
||||
|
||||
if self.interpretation:
|
||||
@ -684,9 +610,7 @@ class Interface(Blocks):
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
repr = "Gradio Interface for: {}".format(
|
||||
", ".join(fn.__name__ for fn in self.predict)
|
||||
)
|
||||
repr = f"Gradio Interface for: {self.__name__}"
|
||||
repr += "\n" + "-" * len(repr)
|
||||
repr += "\ninputs:"
|
||||
for component in self.input_components:
|
||||
@ -715,31 +639,19 @@ class Interface(Blocks):
|
||||
input_component.serialize(processed_input[i], called_directly)
|
||||
for i, input_component in enumerate(self.input_components)
|
||||
]
|
||||
predictions = []
|
||||
output_component_counter = 0
|
||||
|
||||
for predict_fn in self.predict:
|
||||
prediction = predict_fn(*processed_input)
|
||||
prediction = self.fn(*processed_input)
|
||||
|
||||
if len(self.output_components) == len(self.predict) or prediction is None:
|
||||
prediction = [prediction]
|
||||
if prediction is None or len(self.output_components) == 1:
|
||||
prediction = [prediction]
|
||||
|
||||
if self.api_mode: # Serialize the input
|
||||
prediction_ = copy.deepcopy(prediction)
|
||||
prediction = []
|
||||
if self.api_mode: # Deerialize the input
|
||||
prediction = [
|
||||
output_component.deserialize(prediction[i])
|
||||
for i, output_component in enumerate(self.output_components)
|
||||
]
|
||||
|
||||
# Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
|
||||
for pred in prediction_:
|
||||
prediction.append(
|
||||
self.output_components[output_component_counter].deserialize(
|
||||
pred
|
||||
)
|
||||
)
|
||||
output_component_counter += 1
|
||||
|
||||
predictions.extend(prediction)
|
||||
|
||||
return predictions
|
||||
return prediction
|
||||
|
||||
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
@ -777,19 +689,17 @@ class Interface(Blocks):
|
||||
Passes a few samples through the function to test if the inputs/outputs
|
||||
components are consistent with the function parameter and return values.
|
||||
"""
|
||||
for predict_fn in self.predict:
|
||||
print("Test launch: {}()...".format(predict_fn.__name__), end=" ")
|
||||
raw_input = []
|
||||
for input_component in self.input_components:
|
||||
if input_component.test_input is None:
|
||||
print("SKIPPED")
|
||||
break
|
||||
else:
|
||||
raw_input.append(input_component.test_input)
|
||||
print("Test launch: {}()...".format(self.__name__), end=" ")
|
||||
raw_input = []
|
||||
for input_component in self.input_components:
|
||||
if input_component.test_input is None:
|
||||
print("SKIPPED")
|
||||
break
|
||||
else:
|
||||
self.process(raw_input)
|
||||
print("PASSED")
|
||||
continue
|
||||
raw_input.append(input_component.test_input)
|
||||
else:
|
||||
self.process(raw_input)
|
||||
print("PASSED")
|
||||
|
||||
def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None:
|
||||
"""
|
||||
|
@ -1,10 +1,13 @@
|
||||
"""
|
||||
Ways to transform interfaces to produce new interfaces
|
||||
"""
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import gradio
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio.components import IOComponent
|
||||
|
||||
|
||||
class Parallel(gradio.Interface):
|
||||
"""
|
||||
@ -12,7 +15,7 @@ class Parallel(gradio.Interface):
|
||||
The Interfaces to put in Parallel must share the same input components (but can have different output components).
|
||||
"""
|
||||
|
||||
def __init__(self, *interfaces, **options):
|
||||
def __init__(self, *interfaces: gradio.Interface, **options):
|
||||
"""
|
||||
Parameters:
|
||||
*interfaces (Interface): any number of Interface objects that are to be compared in parallel
|
||||
@ -20,38 +23,29 @@ class Parallel(gradio.Interface):
|
||||
Returns:
|
||||
(Interface): an Interface object comparing the given models
|
||||
"""
|
||||
fns = []
|
||||
outputs = []
|
||||
outputs: List[IOComponent] = []
|
||||
|
||||
for io in interfaces:
|
||||
if not (isinstance(io, gradio.Interface)):
|
||||
warnings.warn(
|
||||
"Parallel may not work properly with non-Interface objects."
|
||||
)
|
||||
fns.extend(io.predict)
|
||||
outputs.extend(io.output_components)
|
||||
for interface in interfaces:
|
||||
outputs.extend(interface.output_components)
|
||||
|
||||
def parallel_fn(*args):
|
||||
return_values = []
|
||||
for fn in fns:
|
||||
value = fn(*args)
|
||||
if isinstance(value, tuple):
|
||||
return_values.extend(value)
|
||||
else:
|
||||
return_values.append(value)
|
||||
for interface in interfaces:
|
||||
value = interface.run_prediction(args)
|
||||
return_values.extend(value)
|
||||
if len(outputs) == 1:
|
||||
return return_values[0]
|
||||
return return_values
|
||||
|
||||
parallel_fn.__name__ = " | ".join([io.__name__ for io in interfaces])
|
||||
|
||||
kwargs = {
|
||||
"fn": parallel_fn,
|
||||
"inputs": interfaces[0].input_components,
|
||||
"outputs": outputs,
|
||||
"_repeat_outputs_per_model": False,
|
||||
}
|
||||
kwargs.update(options)
|
||||
super().__init__(**kwargs)
|
||||
self.api_mode = interfaces[
|
||||
0
|
||||
].api_mode # TODO(abidlabs): make api_mode a per-function attribute
|
||||
|
||||
|
||||
class Series(gradio.Interface):
|
||||
@ -60,7 +54,7 @@ class Series(gradio.Interface):
|
||||
and so the input and output components must agree between the interfaces).
|
||||
"""
|
||||
|
||||
def __init__(self, *interfaces, **options):
|
||||
def __init__(self, *interfaces: gradio.Interface, **options):
|
||||
"""
|
||||
Parameters:
|
||||
*interfaces (Interface): any number of Interface objects that are to be connected in series
|
||||
@ -68,41 +62,35 @@ class Series(gradio.Interface):
|
||||
Returns:
|
||||
(Interface): an Interface object connecting the given models
|
||||
"""
|
||||
fns = []
|
||||
for io in interfaces:
|
||||
if not (isinstance(io, gradio.Interface)):
|
||||
warnings.warn(
|
||||
"Series may not work properly with non-Interface objects."
|
||||
)
|
||||
fns.append(io.predict)
|
||||
|
||||
def connected_fn(
|
||||
*data,
|
||||
): # Run each function with the appropriate preprocessing and postprocessing
|
||||
for idx, io in enumerate(interfaces):
|
||||
def connected_fn(*data):
|
||||
for idx, interface in enumerate(interfaces):
|
||||
# skip preprocessing for first interface since the Series interface will include it
|
||||
if idx > 0 and not (io.api_mode):
|
||||
if idx > 0 and not (interface.api_mode):
|
||||
data = [
|
||||
input_component.preprocess(data[i])
|
||||
for i, input_component in enumerate(io.input_components)
|
||||
for i, input_component in enumerate(interface.input_components)
|
||||
]
|
||||
|
||||
# run all of predictions sequentially
|
||||
predictions = []
|
||||
for predict_fn in io.predict:
|
||||
prediction = predict_fn(*data)
|
||||
predictions.append(prediction)
|
||||
data = predictions
|
||||
data = interface.fn(*data)
|
||||
if len(interface.output_components) == 1:
|
||||
data = [data]
|
||||
|
||||
# skip postprocessing for final interface since the Series interface will include it
|
||||
if idx < len(interfaces) - 1 and not (io.api_mode):
|
||||
if idx < len(interfaces) - 1 and not (interface.api_mode):
|
||||
data = [
|
||||
output_component.postprocess(data[i])
|
||||
for i, output_component in enumerate(io.output_components)
|
||||
for i, output_component in enumerate(
|
||||
interface.output_components
|
||||
)
|
||||
]
|
||||
|
||||
return data[0]
|
||||
if len(interface.output_components) == 1:
|
||||
return data[0]
|
||||
return data
|
||||
|
||||
connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns])
|
||||
connected_fn.__name__ = " => ".join([io.__name__ for io in interfaces])
|
||||
|
||||
kwargs = {
|
||||
"fn": connected_fn,
|
||||
@ -111,6 +99,4 @@ class Series(gradio.Interface):
|
||||
}
|
||||
kwargs.update(options)
|
||||
super().__init__(**kwargs)
|
||||
self.api_mode = interfaces[
|
||||
0
|
||||
].api_mode # TODO(abidlabs): make api_mode a per-function attribute
|
||||
self.api_mode = interfaces[0].api_mode # TODO: set api_mode per-function
|
||||
|
@ -1,68 +0,0 @@
|
||||
"""
|
||||
Defines helper methods useful for loading and caching Interface examples.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
from typing import TYPE_CHECKING, Any, List, Tuple
|
||||
|
||||
from gradio.flagging import CSVLogger
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio import Interface
|
||||
|
||||
CACHED_FOLDER = "gradio_cached_examples"
|
||||
CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv")
|
||||
|
||||
|
||||
def process_example(
|
||||
interface: Interface, example_id: int
|
||||
) -> Tuple[List[Any], List[float]]:
|
||||
"""Loads an example from the interface and returns its prediction."""
|
||||
example_set = interface.examples[example_id]
|
||||
raw_input = [
|
||||
interface.input_components[i].preprocess_example(example)
|
||||
for i, example in enumerate(example_set)
|
||||
]
|
||||
prediction = interface.process(raw_input)
|
||||
return prediction
|
||||
|
||||
|
||||
def cache_interface_examples(interface: Interface) -> None:
|
||||
"""Caches all of the examples from an interface."""
|
||||
if os.path.exists(CACHE_FILE):
|
||||
print(
|
||||
f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Cache at {os.path.abspath(CACHE_FILE)} not found. Caching now in '{CACHED_FOLDER}/' directory."
|
||||
)
|
||||
cache_logger = CSVLogger()
|
||||
cache_logger.setup(interface.output_components, CACHED_FOLDER)
|
||||
for example_id, _ in enumerate(interface.examples):
|
||||
try:
|
||||
prediction = process_example(interface, example_id)
|
||||
cache_logger.flag(prediction)
|
||||
except Exception as e:
|
||||
shutil.rmtree(CACHED_FOLDER)
|
||||
raise e
|
||||
|
||||
|
||||
def load_from_cache(interface: Interface, example_id: int) -> List[Any]:
|
||||
"""Loads a particular cached example for the interface."""
|
||||
with open(CACHE_FILE) as cache:
|
||||
examples = list(csv.reader(cache, quotechar="'"))
|
||||
example = examples[example_id + 1] # +1 to adjust for header
|
||||
output = []
|
||||
for component, cell in zip(interface.output_components, example):
|
||||
output.append(
|
||||
component.restore_flagged(
|
||||
CACHED_FOLDER,
|
||||
cell,
|
||||
interface.encryption_key if interface.encrypt else None,
|
||||
)
|
||||
)
|
||||
return output
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from gradio import Interface, process_examples
|
||||
from gradio import Interface, examples
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -9,7 +9,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
class TestProcessExamples(unittest.TestCase):
|
||||
def test_process_example(self):
|
||||
io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
|
||||
prediction = process_examples.process_example(io, 0)
|
||||
prediction = io.examples_handler.process_example(0)
|
||||
self.assertEquals(prediction[0], "Hello World")
|
||||
|
||||
def test_caching(self):
|
||||
@ -20,8 +20,8 @@ class TestProcessExamples(unittest.TestCase):
|
||||
examples=[["World"], ["Dunya"], ["Monde"]],
|
||||
)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
process_examples.cache_interface_examples(io)
|
||||
prediction = process_examples.load_from_cache(io, 1)
|
||||
io.examples_handler.cache_interface_examples()
|
||||
prediction = io.examples_handler.load_from_cache(1)
|
||||
io.close()
|
||||
self.assertEquals(prediction[0], "Hello Dunya")
|
||||
|
@ -27,7 +27,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
src="models",
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Audio)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
|
||||
|
||||
@ -36,14 +36,14 @@ class TestLoadInterface(unittest.TestCase):
|
||||
interface = gr.Blocks.load(
|
||||
name="lysandre/tiny-vit-random", src="models", alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Image)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Label)
|
||||
|
||||
def test_text_generation(self):
|
||||
model_type = "text_generation"
|
||||
interface = gr.Interface.load("models/gpt2", alias=model_type)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
@ -52,7 +52,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
interface = gr.Interface.load(
|
||||
"models/facebook/bart-large-cnn", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
@ -61,7 +61,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
interface = gr.Interface.load(
|
||||
"models/facebook/bart-large-cnn", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
@ -70,7 +70,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
interface = gr.Interface.load(
|
||||
"models/sshleifer/tiny-mbart", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
@ -81,7 +81,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Label)
|
||||
|
||||
@ -90,7 +90,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
interface = gr.Interface.load(
|
||||
"models/bert-base-uncased", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Label)
|
||||
|
||||
@ -99,7 +99,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
interface = gr.Interface.load(
|
||||
"models/facebook/bart-large-mnli", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.input_components[1], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.input_components[2], gr.components.Checkbox)
|
||||
@ -110,7 +110,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
interface = gr.Interface.load(
|
||||
"models/facebook/wav2vec2-base-960h", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Audio)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Textbox)
|
||||
|
||||
@ -119,7 +119,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
interface = gr.Interface.load(
|
||||
"models/google/vit-base-patch16-224", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Image)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Label)
|
||||
|
||||
@ -130,7 +130,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Dataframe)
|
||||
|
||||
@ -141,7 +141,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
|
||||
|
||||
@ -152,7 +152,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Audio)
|
||||
|
||||
@ -161,7 +161,7 @@ class TestLoadInterface(unittest.TestCase):
|
||||
interface = gr.Interface.load(
|
||||
"models/osanseviero/BigGAN-deep-128", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface.predict[0].__name__, model_type)
|
||||
self.assertEqual(interface.__name__, model_type)
|
||||
self.assertIsInstance(interface.input_components[0], gr.components.Textbox)
|
||||
self.assertIsInstance(interface.output_components[0], gr.components.Image)
|
||||
|
||||
|
@ -120,8 +120,8 @@ class TestInterface(unittest.TestCase):
|
||||
examples = ["test1", "test2"]
|
||||
interface = Interface(lambda x: x, "textbox", "label", examples=examples)
|
||||
interface.launch(prevent_thread_lock=True)
|
||||
self.assertEqual(len(interface.examples), 2)
|
||||
self.assertEqual(len(interface.examples[0]), 1)
|
||||
self.assertEqual(len(interface.examples_handler.examples), 2)
|
||||
self.assertEqual(len(interface.examples_handler.examples[0]), 1)
|
||||
interface.close()
|
||||
|
||||
@mock.patch("IPython.display.display")
|
||||
|
Loading…
x
Reference in New Issue
Block a user