diff --git a/demo/fake_gan_no_input/run.py b/demo/fake_gan_no_input/run.py new file mode 100644 index 0000000000..9b7f515293 --- /dev/null +++ b/demo/fake_gan_no_input/run.py @@ -0,0 +1,32 @@ +# This demo needs to be run from the repo folder. +# python demo/fake_gan/run.py +import random +import time + +import gradio as gr + + +def fake_gan(): + time.sleep(1) + image = random.choice( + [ + "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80", + "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80", + "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80", + "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80", + "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80", + ] + ) + return image + + +demo = gr.Interface( + fn=fake_gan, + inputs=None, + outputs=gr.Image(label="Generated Image"), + title="FD-GAN", + description="This is a fake demo of a GAN. In reality, the images are randomly chosen from Unsplash.", +) + +if __name__ == "__main__": + demo.launch() diff --git a/demo/gpt_j_unified/run.py b/demo/gpt_j_unified/run.py new file mode 100644 index 0000000000..b561f89509 --- /dev/null +++ b/demo/gpt_j_unified/run.py @@ -0,0 +1,14 @@ +import gradio as gr + +component = gr.Textbox(lines=5, label="Text") +api = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B") + +demo = gr.Interface( + fn=lambda x: x[:-50] + api(x[-50:]), + inputs=component, + outputs=component, + title="GPT-J-6B", +) + +if __name__ == "__main__": + demo.launch() diff --git a/gradio/external.py b/gradio/external.py index 8f47b6cc65..642994a44b 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -129,7 +129,7 @@ def get_huggingface_interface(model_name, api_key, alias): }, "text-classification": { "inputs": components.Textbox(label="Input"), - "outputs": components.Label(type="confidences", label="Classification"), + "outputs": components.Label(label="Classification"), "preprocess": lambda x: {"inputs": x}, "postprocess": lambda r: { i["label"].split(", ")[0]: i["score"] for i in r.json()[0] @@ -159,7 +159,7 @@ def get_huggingface_interface(model_name, api_key, alias): components.Textbox(label="Possible class names (" "comma-separated)"), components.Checkbox(label="Allow multiple true classes"), ], - "outputs": components.Label(type="confidences", label="Classification"), + "outputs": components.Label(label="Classification"), "preprocess": lambda i, c, m: { "inputs": i, "parameters": {"candidate_labels": c, "multi_class": m}, diff --git a/gradio/interface.py b/gradio/interface.py index e50ae44192..865f045fec 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -13,6 +13,7 @@ import random import re import warnings import weakref +from enum import Enum, auto from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple from markdown_it import MarkdownIt @@ -48,6 +49,12 @@ class Interface(Blocks): # stores references to all currently existing Interface instances instances: weakref.WeakSet = weakref.WeakSet() + class InterfaceTypes(Enum): + STANDARD = auto() + INPUT_ONLY = auto() + OUTPUT_ONLY = auto() + UNIFIED = auto() + @classmethod def get_instances(cls) -> List[Interface]: """ @@ -103,8 +110,8 @@ class Interface(Blocks): def __init__( self, fn: Callable | List[Callable], - inputs: str | Component | List[str | Component] = None, - outputs: str | Component | List[str | Component] = None, + inputs: Optional[str | Component | List[str | Component]] = None, + outputs: Optional[str | Component | List[str | Component]] = None, examples: Optional[List[Any] | List[List[Any]] | str] = None, cache_examples: Optional[bool] = None, examples_per_page: int = 10, @@ -157,10 +164,15 @@ class Interface(Blocks): analytics_enabled=analytics_enabled, mode="interface", **kwargs ) - if inputs is None: - inputs = [] - if outputs is None: + self.interface_type = self.InterfaceTypes.STANDARD + if inputs is None and outputs is None: + raise ValueError("Must provide at least one of `inputs` or `outputs`") + elif outputs is None: outputs = [] + self.interface_type = self.InterfaceTypes.INPUT_ONLY + elif inputs is None: + inputs = [] + self.interface_type = self.InterfaceTypes.OUTPUT_ONLY if not isinstance(fn, list): fn = [fn] @@ -195,10 +207,20 @@ class Interface(Blocks): self.input_components = [get_component_instance(i).unrender() for i in inputs] self.output_components = [get_component_instance(o).unrender() for o in outputs] - for o in self.output_components: - o.interactive = ( - False # Force output components to be treated as non-interactive - ) + + if len(self.input_components) == len(self.output_components): + same_components = [ + i is o for i, o in zip(self.input_components, self.output_components) + ] + if all(same_components): + self.interface_type = self.InterfaceTypes.UNIFIED + + if self.interface_type in [ + self.InterfaceTypes.STANDARD, + self.InterfaceTypes.OUTPUT_ONLY, + ]: + for o in self.output_components: + o.interactive = False # Force output components to be non-interactive if repeat_outputs_per_model: self.output_components *= len(fn) @@ -441,49 +463,85 @@ class Interface(Blocks): if self.description: Markdown(self.description) with Row(): - with Column( - css={ - "background-color": "rgb(249,250,251)", - "padding": "0.5rem", - "border-radius": "0.5rem", - } - ): - input_component_column = Column() - with input_component_column: - for component in self.input_components: - component.render() - if self.interpretation: - interpret_component_column = Column(visible=False) - interpretation_set = [] - with interpret_component_column: + if self.interface_type in [ + self.InterfaceTypes.STANDARD, + self.InterfaceTypes.INPUT_ONLY, + self.InterfaceTypes.UNIFIED, + ]: + with Column( + css={ + "background-color": "rgb(249,250,251)", + "padding": "0.5rem", + "border-radius": "0.5rem", + } + ): + input_component_column = Column() + if self.interface_type in [ + self.InterfaceTypes.INPUT_ONLY, + self.InterfaceTypes.UNIFIED, + ]: + status_tracker = StatusTracker(cover_container=True) + with input_component_column: for component in self.input_components: - interpretation_set.append(Interpretation(component)) - with Row(): - clear_btn = Button("Clear") - if not self.live: - submit_btn = Button("Submit") - with Column( - css={ - "background-color": "rgb(249,250,251)", - "padding": "0.5rem", - "border-radius": "0.5rem", - } - ): - status_tracker = StatusTracker(cover_container=True) - for component in self.output_components: - component.render() - with Row(): - if self.allow_flagging == "manual": - flag_btn = Button("Flag") - flag_btn._click_no_preprocess( - lambda *flag_data: self.flagging_callback.flag( - flag_data - ), - inputs=self.input_components + self.output_components, - outputs=[], - ) + component.render() if self.interpretation: - interpretation_btn = Button("Interpret") + interpret_component_column = Column(visible=False) + interpretation_set = [] + with interpret_component_column: + for component in self.input_components: + interpretation_set.append(Interpretation(component)) + with Row(): + if self.interface_type in [ + self.InterfaceTypes.STANDARD, + self.InterfaceTypes.INPUT_ONLY, + ]: + clear_btn = Button("Clear") + if not self.live: + submit_btn = Button("Submit") + elif self.interface_type == self.InterfaceTypes.UNIFIED: + clear_btn = Button("Clear") + submit_btn = Button("Submit") + if self.allow_flagging == "manual": + flag_btn = Button("Flag") + flag_btn._click_no_preprocess( + lambda *flag_data: self.flagging_callback.flag( + flag_data + ), + inputs=self.input_components, + outputs=[], + ) + + if self.interface_type in [ + self.InterfaceTypes.STANDARD, + self.InterfaceTypes.OUTPUT_ONLY, + ]: + + with Column( + css={ + "background-color": "rgb(249,250,251)", + "padding": "0.5rem", + "border-radius": "0.5rem", + } + ): + status_tracker = StatusTracker(cover_container=True) + for component in self.output_components: + component.render() + with Row(): + if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY: + clear_btn = Button("Clear") + submit_btn = Button("Generate") + if self.allow_flagging == "manual": + flag_btn = Button("Flag") + flag_btn._click_no_preprocess( + lambda *flag_data: self.flagging_callback.flag( + flag_data + ), + inputs=self.input_components + + self.output_components, + outputs=[], + ) + if self.interpretation: + interpretation_btn = Button("Interpret") submit_fn = ( lambda *args: self.run_prediction(args)[0] if len(self.output_components) == 1 @@ -503,20 +561,41 @@ class Interface(Blocks): ) clear_btn.click( ( - lambda: [ - component.default_value - if hasattr(component, "default_value") - else None - for component in self.input_components + self.output_components - ] - + [True] - + ([False] if self.interpretation else []) + lambda: utils.resolve_singleton( + [ + component.default_value + if hasattr(component, "default_value") + else None + for component in self.input_components + + self.output_components + ] + + ( + [True] + if self.interface_type + in [ + self.InterfaceTypes.STANDARD, + self.InterfaceTypes.INPUT_ONLY, + self.InterfaceTypes.UNIFIED, + ] + else [] + ) + + ([False] if self.interpretation else []) + ) ), [], ( self.input_components + self.output_components - + [input_component_column] + + ( + [input_component_column] + if self.interface_type + in [ + self.InterfaceTypes.STANDARD, + self.InterfaceTypes.INPUT_ONLY, + self.InterfaceTypes.UNIFIED, + ] + else [] + ) + ([interpret_component_column] if self.interpretation else []) ), ) @@ -614,17 +693,15 @@ class Interface(Blocks): for predict_fn in self.predict: prediction = predict_fn(*processed_input) - if len(self.output_components) == len(self.predict): + if len(self.output_components) == len(self.predict) or prediction is None: prediction = [prediction] if self.api_mode: # Serialize the input prediction_ = copy.deepcopy(prediction) prediction = [] - for ( - pred - ) in ( - prediction_ - ): # Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces + + # Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces + for pred in prediction_: prediction.append( self.output_components[output_component_counter].deserialize( pred diff --git a/gradio/utils.py b/gradio/utils.py index 49b68b7a06..731dae7756 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -369,3 +369,10 @@ def delete_none(_dict): _dict = type(_dict)(delete_none(item) for item in _dict if item is not None) return _dict + + +def resolve_singleton(_list): + if len(_list) == 1: + return _list[0] + else: + return _list