Interface types: handle input-only, output-only, and unified interfaces (#1108)

* optional labels

* added prop

* Added IOComponent class

* get component fix

* fixed component function

* fixed test components

* formatting

* fixed output tests

* working on blocks tests

* fixed test blocks

* cleanup

* merged

* unrender

* add article

* formatting

* fixed render()

* added demo

* formatting

* merge main

* add interface types

* added output only

* added input only

* formatting

* added demos

* formatting

* removed unnecessary import

* updated demos

* fixed duplication

* fix for state
This commit is contained in:
Abubakar Abid 2022-04-28 03:06:16 -07:00 committed by GitHub
parent f47e65b14d
commit 93e5a82ff2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 197 additions and 67 deletions

View File

@ -0,0 +1,32 @@
# This demo needs to be run from the repo folder.
# python demo/fake_gan/run.py
import random
import time
import gradio as gr
def fake_gan():
time.sleep(1)
image = random.choice(
[
"https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
"https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
"https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
"https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
"https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
]
)
return image
demo = gr.Interface(
fn=fake_gan,
inputs=None,
outputs=gr.Image(label="Generated Image"),
title="FD-GAN",
description="This is a fake demo of a GAN. In reality, the images are randomly chosen from Unsplash.",
)
if __name__ == "__main__":
demo.launch()

14
demo/gpt_j_unified/run.py Normal file
View File

@ -0,0 +1,14 @@
import gradio as gr
component = gr.Textbox(lines=5, label="Text")
api = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B")
demo = gr.Interface(
fn=lambda x: x[:-50] + api(x[-50:]),
inputs=component,
outputs=component,
title="GPT-J-6B",
)
if __name__ == "__main__":
demo.launch()

View File

@ -129,7 +129,7 @@ def get_huggingface_interface(model_name, api_key, alias):
},
"text-classification": {
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(type="confidences", label="Classification"),
"outputs": components.Label(label="Classification"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: {
i["label"].split(", ")[0]: i["score"] for i in r.json()[0]
@ -159,7 +159,7 @@ def get_huggingface_interface(model_name, api_key, alias):
components.Textbox(label="Possible class names (" "comma-separated)"),
components.Checkbox(label="Allow multiple true classes"),
],
"outputs": components.Label(type="confidences", label="Classification"),
"outputs": components.Label(label="Classification"),
"preprocess": lambda i, c, m: {
"inputs": i,
"parameters": {"candidate_labels": c, "multi_class": m},

View File

@ -13,6 +13,7 @@ import random
import re
import warnings
import weakref
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from markdown_it import MarkdownIt
@ -48,6 +49,12 @@ class Interface(Blocks):
# stores references to all currently existing Interface instances
instances: weakref.WeakSet = weakref.WeakSet()
class InterfaceTypes(Enum):
STANDARD = auto()
INPUT_ONLY = auto()
OUTPUT_ONLY = auto()
UNIFIED = auto()
@classmethod
def get_instances(cls) -> List[Interface]:
"""
@ -103,8 +110,8 @@ class Interface(Blocks):
def __init__(
self,
fn: Callable | List[Callable],
inputs: str | Component | List[str | Component] = None,
outputs: str | Component | List[str | Component] = None,
inputs: Optional[str | Component | List[str | Component]] = None,
outputs: Optional[str | Component | List[str | Component]] = None,
examples: Optional[List[Any] | List[List[Any]] | str] = None,
cache_examples: Optional[bool] = None,
examples_per_page: int = 10,
@ -157,10 +164,15 @@ class Interface(Blocks):
analytics_enabled=analytics_enabled, mode="interface", **kwargs
)
if inputs is None:
inputs = []
if outputs is None:
self.interface_type = self.InterfaceTypes.STANDARD
if inputs is None and outputs is None:
raise ValueError("Must provide at least one of `inputs` or `outputs`")
elif outputs is None:
outputs = []
self.interface_type = self.InterfaceTypes.INPUT_ONLY
elif inputs is None:
inputs = []
self.interface_type = self.InterfaceTypes.OUTPUT_ONLY
if not isinstance(fn, list):
fn = [fn]
@ -195,10 +207,20 @@ class Interface(Blocks):
self.input_components = [get_component_instance(i).unrender() for i in inputs]
self.output_components = [get_component_instance(o).unrender() for o in outputs]
for o in self.output_components:
o.interactive = (
False # Force output components to be treated as non-interactive
)
if len(self.input_components) == len(self.output_components):
same_components = [
i is o for i, o in zip(self.input_components, self.output_components)
]
if all(same_components):
self.interface_type = self.InterfaceTypes.UNIFIED
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.OUTPUT_ONLY,
]:
for o in self.output_components:
o.interactive = False # Force output components to be non-interactive
if repeat_outputs_per_model:
self.output_components *= len(fn)
@ -441,49 +463,85 @@ class Interface(Blocks):
if self.description:
Markdown(self.description)
with Row():
with Column(
css={
"background-color": "rgb(249,250,251)",
"padding": "0.5rem",
"border-radius": "0.5rem",
}
):
input_component_column = Column()
with input_component_column:
for component in self.input_components:
component.render()
if self.interpretation:
interpret_component_column = Column(visible=False)
interpretation_set = []
with interpret_component_column:
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.INPUT_ONLY,
self.InterfaceTypes.UNIFIED,
]:
with Column(
css={
"background-color": "rgb(249,250,251)",
"padding": "0.5rem",
"border-radius": "0.5rem",
}
):
input_component_column = Column()
if self.interface_type in [
self.InterfaceTypes.INPUT_ONLY,
self.InterfaceTypes.UNIFIED,
]:
status_tracker = StatusTracker(cover_container=True)
with input_component_column:
for component in self.input_components:
interpretation_set.append(Interpretation(component))
with Row():
clear_btn = Button("Clear")
if not self.live:
submit_btn = Button("Submit")
with Column(
css={
"background-color": "rgb(249,250,251)",
"padding": "0.5rem",
"border-radius": "0.5rem",
}
):
status_tracker = StatusTracker(cover_container=True)
for component in self.output_components:
component.render()
with Row():
if self.allow_flagging == "manual":
flag_btn = Button("Flag")
flag_btn._click_no_preprocess(
lambda *flag_data: self.flagging_callback.flag(
flag_data
),
inputs=self.input_components + self.output_components,
outputs=[],
)
component.render()
if self.interpretation:
interpretation_btn = Button("Interpret")
interpret_component_column = Column(visible=False)
interpretation_set = []
with interpret_component_column:
for component in self.input_components:
interpretation_set.append(Interpretation(component))
with Row():
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.INPUT_ONLY,
]:
clear_btn = Button("Clear")
if not self.live:
submit_btn = Button("Submit")
elif self.interface_type == self.InterfaceTypes.UNIFIED:
clear_btn = Button("Clear")
submit_btn = Button("Submit")
if self.allow_flagging == "manual":
flag_btn = Button("Flag")
flag_btn._click_no_preprocess(
lambda *flag_data: self.flagging_callback.flag(
flag_data
),
inputs=self.input_components,
outputs=[],
)
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.OUTPUT_ONLY,
]:
with Column(
css={
"background-color": "rgb(249,250,251)",
"padding": "0.5rem",
"border-radius": "0.5rem",
}
):
status_tracker = StatusTracker(cover_container=True)
for component in self.output_components:
component.render()
with Row():
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
clear_btn = Button("Clear")
submit_btn = Button("Generate")
if self.allow_flagging == "manual":
flag_btn = Button("Flag")
flag_btn._click_no_preprocess(
lambda *flag_data: self.flagging_callback.flag(
flag_data
),
inputs=self.input_components
+ self.output_components,
outputs=[],
)
if self.interpretation:
interpretation_btn = Button("Interpret")
submit_fn = (
lambda *args: self.run_prediction(args)[0]
if len(self.output_components) == 1
@ -503,20 +561,41 @@ class Interface(Blocks):
)
clear_btn.click(
(
lambda: [
component.default_value
if hasattr(component, "default_value")
else None
for component in self.input_components + self.output_components
]
+ [True]
+ ([False] if self.interpretation else [])
lambda: utils.resolve_singleton(
[
component.default_value
if hasattr(component, "default_value")
else None
for component in self.input_components
+ self.output_components
]
+ (
[True]
if self.interface_type
in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.INPUT_ONLY,
self.InterfaceTypes.UNIFIED,
]
else []
)
+ ([False] if self.interpretation else [])
)
),
[],
(
self.input_components
+ self.output_components
+ [input_component_column]
+ (
[input_component_column]
if self.interface_type
in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.INPUT_ONLY,
self.InterfaceTypes.UNIFIED,
]
else []
)
+ ([interpret_component_column] if self.interpretation else [])
),
)
@ -614,17 +693,15 @@ class Interface(Blocks):
for predict_fn in self.predict:
prediction = predict_fn(*processed_input)
if len(self.output_components) == len(self.predict):
if len(self.output_components) == len(self.predict) or prediction is None:
prediction = [prediction]
if self.api_mode: # Serialize the input
prediction_ = copy.deepcopy(prediction)
prediction = []
for (
pred
) in (
prediction_
): # Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
# Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
for pred in prediction_:
prediction.append(
self.output_components[output_component_counter].deserialize(
pred

View File

@ -369,3 +369,10 @@ def delete_none(_dict):
_dict = type(_dict)(delete_none(item) for item in _dict if item is not None)
return _dict
def resolve_singleton(_list):
if len(_list) == 1:
return _list[0]
else:
return _list