mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Interface types: handle input-only, output-only, and unified interfaces (#1108)
* optional labels * added prop * Added IOComponent class * get component fix * fixed component function * fixed test components * formatting * fixed output tests * working on blocks tests * fixed test blocks * cleanup * merged * unrender * add article * formatting * fixed render() * added demo * formatting * merge main * add interface types * added output only * added input only * formatting * added demos * formatting * removed unnecessary import * updated demos * fixed duplication * fix for state
This commit is contained in:
parent
f47e65b14d
commit
93e5a82ff2
32
demo/fake_gan_no_input/run.py
Normal file
32
demo/fake_gan_no_input/run.py
Normal file
@ -0,0 +1,32 @@
|
||||
# This demo needs to be run from the repo folder.
|
||||
# python demo/fake_gan/run.py
|
||||
import random
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def fake_gan():
|
||||
time.sleep(1)
|
||||
image = random.choice(
|
||||
[
|
||||
"https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
|
||||
"https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
|
||||
"https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
|
||||
"https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
|
||||
"https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
|
||||
]
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=fake_gan,
|
||||
inputs=None,
|
||||
outputs=gr.Image(label="Generated Image"),
|
||||
title="FD-GAN",
|
||||
description="This is a fake demo of a GAN. In reality, the images are randomly chosen from Unsplash.",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
14
demo/gpt_j_unified/run.py
Normal file
14
demo/gpt_j_unified/run.py
Normal file
@ -0,0 +1,14 @@
|
||||
import gradio as gr
|
||||
|
||||
component = gr.Textbox(lines=5, label="Text")
|
||||
api = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B")
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=lambda x: x[:-50] + api(x[-50:]),
|
||||
inputs=component,
|
||||
outputs=component,
|
||||
title="GPT-J-6B",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -129,7 +129,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
},
|
||||
"text-classification": {
|
||||
"inputs": components.Textbox(label="Input"),
|
||||
"outputs": components.Label(type="confidences", label="Classification"),
|
||||
"outputs": components.Label(label="Classification"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: {
|
||||
i["label"].split(", ")[0]: i["score"] for i in r.json()[0]
|
||||
@ -159,7 +159,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
components.Textbox(label="Possible class names (" "comma-separated)"),
|
||||
components.Checkbox(label="Allow multiple true classes"),
|
||||
],
|
||||
"outputs": components.Label(type="confidences", label="Classification"),
|
||||
"outputs": components.Label(label="Classification"),
|
||||
"preprocess": lambda i, c, m: {
|
||||
"inputs": i,
|
||||
"parameters": {"candidate_labels": c, "multi_class": m},
|
||||
|
@ -13,6 +13,7 @@ import random
|
||||
import re
|
||||
import warnings
|
||||
import weakref
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
@ -48,6 +49,12 @@ class Interface(Blocks):
|
||||
# stores references to all currently existing Interface instances
|
||||
instances: weakref.WeakSet = weakref.WeakSet()
|
||||
|
||||
class InterfaceTypes(Enum):
|
||||
STANDARD = auto()
|
||||
INPUT_ONLY = auto()
|
||||
OUTPUT_ONLY = auto()
|
||||
UNIFIED = auto()
|
||||
|
||||
@classmethod
|
||||
def get_instances(cls) -> List[Interface]:
|
||||
"""
|
||||
@ -103,8 +110,8 @@ class Interface(Blocks):
|
||||
def __init__(
|
||||
self,
|
||||
fn: Callable | List[Callable],
|
||||
inputs: str | Component | List[str | Component] = None,
|
||||
outputs: str | Component | List[str | Component] = None,
|
||||
inputs: Optional[str | Component | List[str | Component]] = None,
|
||||
outputs: Optional[str | Component | List[str | Component]] = None,
|
||||
examples: Optional[List[Any] | List[List[Any]] | str] = None,
|
||||
cache_examples: Optional[bool] = None,
|
||||
examples_per_page: int = 10,
|
||||
@ -157,10 +164,15 @@ class Interface(Blocks):
|
||||
analytics_enabled=analytics_enabled, mode="interface", **kwargs
|
||||
)
|
||||
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
if outputs is None:
|
||||
self.interface_type = self.InterfaceTypes.STANDARD
|
||||
if inputs is None and outputs is None:
|
||||
raise ValueError("Must provide at least one of `inputs` or `outputs`")
|
||||
elif outputs is None:
|
||||
outputs = []
|
||||
self.interface_type = self.InterfaceTypes.INPUT_ONLY
|
||||
elif inputs is None:
|
||||
inputs = []
|
||||
self.interface_type = self.InterfaceTypes.OUTPUT_ONLY
|
||||
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
@ -195,10 +207,20 @@ class Interface(Blocks):
|
||||
|
||||
self.input_components = [get_component_instance(i).unrender() for i in inputs]
|
||||
self.output_components = [get_component_instance(o).unrender() for o in outputs]
|
||||
for o in self.output_components:
|
||||
o.interactive = (
|
||||
False # Force output components to be treated as non-interactive
|
||||
)
|
||||
|
||||
if len(self.input_components) == len(self.output_components):
|
||||
same_components = [
|
||||
i is o for i, o in zip(self.input_components, self.output_components)
|
||||
]
|
||||
if all(same_components):
|
||||
self.interface_type = self.InterfaceTypes.UNIFIED
|
||||
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.OUTPUT_ONLY,
|
||||
]:
|
||||
for o in self.output_components:
|
||||
o.interactive = False # Force output components to be non-interactive
|
||||
|
||||
if repeat_outputs_per_model:
|
||||
self.output_components *= len(fn)
|
||||
@ -441,49 +463,85 @@ class Interface(Blocks):
|
||||
if self.description:
|
||||
Markdown(self.description)
|
||||
with Row():
|
||||
with Column(
|
||||
css={
|
||||
"background-color": "rgb(249,250,251)",
|
||||
"padding": "0.5rem",
|
||||
"border-radius": "0.5rem",
|
||||
}
|
||||
):
|
||||
input_component_column = Column()
|
||||
with input_component_column:
|
||||
for component in self.input_components:
|
||||
component.render()
|
||||
if self.interpretation:
|
||||
interpret_component_column = Column(visible=False)
|
||||
interpretation_set = []
|
||||
with interpret_component_column:
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.INPUT_ONLY,
|
||||
self.InterfaceTypes.UNIFIED,
|
||||
]:
|
||||
with Column(
|
||||
css={
|
||||
"background-color": "rgb(249,250,251)",
|
||||
"padding": "0.5rem",
|
||||
"border-radius": "0.5rem",
|
||||
}
|
||||
):
|
||||
input_component_column = Column()
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.INPUT_ONLY,
|
||||
self.InterfaceTypes.UNIFIED,
|
||||
]:
|
||||
status_tracker = StatusTracker(cover_container=True)
|
||||
with input_component_column:
|
||||
for component in self.input_components:
|
||||
interpretation_set.append(Interpretation(component))
|
||||
with Row():
|
||||
clear_btn = Button("Clear")
|
||||
if not self.live:
|
||||
submit_btn = Button("Submit")
|
||||
with Column(
|
||||
css={
|
||||
"background-color": "rgb(249,250,251)",
|
||||
"padding": "0.5rem",
|
||||
"border-radius": "0.5rem",
|
||||
}
|
||||
):
|
||||
status_tracker = StatusTracker(cover_container=True)
|
||||
for component in self.output_components:
|
||||
component.render()
|
||||
with Row():
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btn = Button("Flag")
|
||||
flag_btn._click_no_preprocess(
|
||||
lambda *flag_data: self.flagging_callback.flag(
|
||||
flag_data
|
||||
),
|
||||
inputs=self.input_components + self.output_components,
|
||||
outputs=[],
|
||||
)
|
||||
component.render()
|
||||
if self.interpretation:
|
||||
interpretation_btn = Button("Interpret")
|
||||
interpret_component_column = Column(visible=False)
|
||||
interpretation_set = []
|
||||
with interpret_component_column:
|
||||
for component in self.input_components:
|
||||
interpretation_set.append(Interpretation(component))
|
||||
with Row():
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.INPUT_ONLY,
|
||||
]:
|
||||
clear_btn = Button("Clear")
|
||||
if not self.live:
|
||||
submit_btn = Button("Submit")
|
||||
elif self.interface_type == self.InterfaceTypes.UNIFIED:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Submit")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btn = Button("Flag")
|
||||
flag_btn._click_no_preprocess(
|
||||
lambda *flag_data: self.flagging_callback.flag(
|
||||
flag_data
|
||||
),
|
||||
inputs=self.input_components,
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.OUTPUT_ONLY,
|
||||
]:
|
||||
|
||||
with Column(
|
||||
css={
|
||||
"background-color": "rgb(249,250,251)",
|
||||
"padding": "0.5rem",
|
||||
"border-radius": "0.5rem",
|
||||
}
|
||||
):
|
||||
status_tracker = StatusTracker(cover_container=True)
|
||||
for component in self.output_components:
|
||||
component.render()
|
||||
with Row():
|
||||
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Generate")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btn = Button("Flag")
|
||||
flag_btn._click_no_preprocess(
|
||||
lambda *flag_data: self.flagging_callback.flag(
|
||||
flag_data
|
||||
),
|
||||
inputs=self.input_components
|
||||
+ self.output_components,
|
||||
outputs=[],
|
||||
)
|
||||
if self.interpretation:
|
||||
interpretation_btn = Button("Interpret")
|
||||
submit_fn = (
|
||||
lambda *args: self.run_prediction(args)[0]
|
||||
if len(self.output_components) == 1
|
||||
@ -503,20 +561,41 @@ class Interface(Blocks):
|
||||
)
|
||||
clear_btn.click(
|
||||
(
|
||||
lambda: [
|
||||
component.default_value
|
||||
if hasattr(component, "default_value")
|
||||
else None
|
||||
for component in self.input_components + self.output_components
|
||||
]
|
||||
+ [True]
|
||||
+ ([False] if self.interpretation else [])
|
||||
lambda: utils.resolve_singleton(
|
||||
[
|
||||
component.default_value
|
||||
if hasattr(component, "default_value")
|
||||
else None
|
||||
for component in self.input_components
|
||||
+ self.output_components
|
||||
]
|
||||
+ (
|
||||
[True]
|
||||
if self.interface_type
|
||||
in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.INPUT_ONLY,
|
||||
self.InterfaceTypes.UNIFIED,
|
||||
]
|
||||
else []
|
||||
)
|
||||
+ ([False] if self.interpretation else [])
|
||||
)
|
||||
),
|
||||
[],
|
||||
(
|
||||
self.input_components
|
||||
+ self.output_components
|
||||
+ [input_component_column]
|
||||
+ (
|
||||
[input_component_column]
|
||||
if self.interface_type
|
||||
in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.INPUT_ONLY,
|
||||
self.InterfaceTypes.UNIFIED,
|
||||
]
|
||||
else []
|
||||
)
|
||||
+ ([interpret_component_column] if self.interpretation else [])
|
||||
),
|
||||
)
|
||||
@ -614,17 +693,15 @@ class Interface(Blocks):
|
||||
for predict_fn in self.predict:
|
||||
prediction = predict_fn(*processed_input)
|
||||
|
||||
if len(self.output_components) == len(self.predict):
|
||||
if len(self.output_components) == len(self.predict) or prediction is None:
|
||||
prediction = [prediction]
|
||||
|
||||
if self.api_mode: # Serialize the input
|
||||
prediction_ = copy.deepcopy(prediction)
|
||||
prediction = []
|
||||
for (
|
||||
pred
|
||||
) in (
|
||||
prediction_
|
||||
): # Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
|
||||
|
||||
# Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
|
||||
for pred in prediction_:
|
||||
prediction.append(
|
||||
self.output_components[output_component_counter].deserialize(
|
||||
pred
|
||||
|
@ -369,3 +369,10 @@ def delete_none(_dict):
|
||||
_dict = type(_dict)(delete_none(item) for item in _dict if item is not None)
|
||||
|
||||
return _dict
|
||||
|
||||
|
||||
def resolve_singleton(_list):
|
||||
if len(_list) == 1:
|
||||
return _list[0]
|
||||
else:
|
||||
return _list
|
||||
|
Loading…
x
Reference in New Issue
Block a user