From 2a93225952091bf06b4ae652cfceae1bd4a20362 Mon Sep 17 00:00:00 2001 From: aliabid94 Date: Mon, 16 May 2022 11:51:09 -0700 Subject: [PATCH] Create Streamables (#1279) * changes * fix * fix for vars too * changes * fix tests Co-authored-by: Abubakar Abid --- demo/calculator/run.py | 1 - demo/live_with_vars/run.py | 6 ++ demo/stream_audio/run.py | 20 ++++++ demo/stream_frames/run.py | 13 ++++ demo/streaming_stt/run.py | 12 +++- gradio/components.py | 62 ++++++++++++++++++- gradio/events.py | 20 ++++++ gradio/interface.py | 20 +++++- gradio/test_data/blocks_configs.py | 6 ++ test/test_components.py | 3 + ui/packages/app/src/Blocks.svelte | 8 ++- ui/packages/app/src/Render.svelte | 3 +- .../app/src/components/Audio/Audio.svelte | 18 +++++- .../app/src/components/Image/Image.svelte | 5 ++ ui/packages/app/src/stores.ts | 31 +++++++++- ui/packages/audio/src/Audio.svelte | 31 ++++++++-- ui/packages/image/src/Image.svelte | 14 ++++- ui/packages/image/src/Webcam.svelte | 46 +++++++++----- 18 files changed, 276 insertions(+), 43 deletions(-) create mode 100644 demo/live_with_vars/run.py create mode 100644 demo/stream_audio/run.py create mode 100644 demo/stream_frames/run.py diff --git a/demo/calculator/run.py b/demo/calculator/run.py index 1339a629c5..96aafc1638 100644 --- a/demo/calculator/run.py +++ b/demo/calculator/run.py @@ -1,6 +1,5 @@ import gradio as gr - def calculator(num1, operation, num2): if operation == "add": return num1 + num2 diff --git a/demo/live_with_vars/run.py b/demo/live_with_vars/run.py new file mode 100644 index 0000000000..acff9d7fb8 --- /dev/null +++ b/demo/live_with_vars/run.py @@ -0,0 +1,6 @@ +import gradio as gr + +gr.Interface( + lambda x, y: (x + y if y is not None else x, x + y if y is not None else x), + ["textbox", "state"], + ["textbox", "state"], live=True).launch() diff --git a/demo/stream_audio/run.py b/demo/stream_audio/run.py new file mode 100644 index 0000000000..8fcd3c2aff --- /dev/null +++ b/demo/stream_audio/run.py @@ -0,0 +1,20 @@ +import gradio as gr +import numpy as np + +with gr.Blocks() as demo: + inp = gr.Audio(source="microphone") + out = gr.Audio() + stream = gr.Variable() + + def add_to_stream(audio, instream): + if audio is None: + return gr.update(), instream + if instream is None: + ret = audio + else: + ret = (audio[0], np.concatenate((instream[1], audio[1]))) + return ret, ret + inp.stream(add_to_stream, [inp, stream], [out, stream]) + +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/demo/stream_frames/run.py b/demo/stream_frames/run.py new file mode 100644 index 0000000000..294617429d --- /dev/null +++ b/demo/stream_frames/run.py @@ -0,0 +1,13 @@ +import gradio as gr +import numpy as np + +with gr.Blocks() as demo: + inp = gr.Image(source="webcam") + out = gr.Image() + + def flip(im): + return np.flipud(im) + inp.stream(flip, [inp], [out]) + +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/demo/streaming_stt/run.py b/demo/streaming_stt/run.py index 04f5e20969..ec68e6b1d6 100644 --- a/demo/streaming_stt/run.py +++ b/demo/streaming_stt/run.py @@ -1,7 +1,7 @@ from deepspeech import Model import gradio as gr import numpy as np -import urllib.request +import urllib.request model_file_path = "deepspeech-0.9.3-models.pbmm" lm_file_path = "deepspeech-0.9.3-models.scorer" @@ -45,7 +45,13 @@ def transcribe(speech, stream): text = stream.intermediateDecode() return text, stream -demo = gr.Interface(transcribe, ["microphone", "state"], ["text", "state"], live=True) + +demo = gr.Interface( + transcribe, + [gr.Audio(source="microphone", streaming=True), "state"], + ["text", "state"], + live=True, +) if __name__ == "__main__": - demo.launch() \ No newline at end of file + demo.launch() diff --git a/gradio/components.py b/gradio/components.py index cca422d564..44e7d21e17 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -16,7 +16,7 @@ import tempfile import warnings from copy import deepcopy from types import ModuleType -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type import matplotlib.figure import numpy @@ -34,6 +34,7 @@ from gradio.events import ( Clickable, Editable, Playable, + Streamable, Submittable, ) @@ -1306,7 +1307,7 @@ class Dropdown(Radio): ) -class Image(Editable, Clearable, Changeable, IOComponent): +class Image(Editable, Clearable, Changeable, Streamable, IOComponent): """ Creates an image component that can be used to upload/draw images (as an input) or display images (as an output). Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type`. @@ -1330,6 +1331,7 @@ class Image(Editable, Clearable, Changeable, IOComponent): interactive: Optional[bool] = None, visible: bool = True, elem_id: Optional[str] = None, + streaming: bool = False, **kwargs, ): """ @@ -1344,6 +1346,7 @@ class Image(Editable, Clearable, Changeable, IOComponent): label (Optional[str]): component name in interface. show_label (bool): if True, will display label. visible (bool): If False, component will be hidden. + streaming (bool): If True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'webcam'. """ self.type = type self.value = ( @@ -1359,6 +1362,10 @@ class Image(Editable, Clearable, Changeable, IOComponent): self.invert_colors = invert_colors self.test_input = deepcopy(media_data.BASE64_IMAGE) self.interpret_by_tokens = True + self.streaming = streaming + if streaming and source != "webcam": + raise ValueError("Image streaming only available if source is 'webcam'.") + IOComponent.__init__( self, label=label, @@ -1377,6 +1384,7 @@ class Image(Editable, Clearable, Changeable, IOComponent): "source": self.source, "tool": self.tool, "value": self.value, + "streaming": self.streaming, **IOComponent.get_config(self), } @@ -1635,6 +1643,25 @@ class Image(Editable, Clearable, Changeable, IOComponent): container_bg_color=container_bg_color, ) + def stream( + self, + fn: Callable, + inputs: List[Component], + outputs: List[Component], + _js: Optional[str] = None, + ): + """ + Parameters: + fn: Callable function + inputs: List of inputs + outputs: List of outputs + _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. + Returns: None + """ + if self.source != "webcam": + raise ValueError("Image streaming only available if source is 'webcam'.") + Streamable.stream(self, fn, inputs, outputs, _js) + class Video(Changeable, Clearable, Playable, IOComponent): """ @@ -1777,7 +1804,7 @@ class Video(Changeable, Clearable, Playable, IOComponent): return processing_utils.decode_base64_to_file(x).name -class Audio(Changeable, Clearable, Playable, IOComponent): +class Audio(Changeable, Clearable, Playable, Streamable, IOComponent): """ Creates an audio component that can be used to upload/record audio (as an input) or display audio (as an output). Preprocessing: passes the uploaded audio as a {Tuple(int, numpy.array)} corresponding to (sample rate, data) or as a {str} filepath, depending on `type` @@ -1797,6 +1824,7 @@ class Audio(Changeable, Clearable, Playable, IOComponent): interactive: Optional[bool] = None, visible: bool = True, elem_id: Optional[str] = None, + streaming: bool = False, **kwargs, ): """ @@ -1807,6 +1835,7 @@ class Audio(Changeable, Clearable, Playable, IOComponent): label (Optional[str]): component name in interface. show_label (bool): if True, will display label. visible (bool): If False, component will be hidden. + streaming (bool): If set to true when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'microphone'. """ self.value = ( processing_utils.encode_url_or_file_to_base64(value) if value else None @@ -1817,6 +1846,11 @@ class Audio(Changeable, Clearable, Playable, IOComponent): self.output_type = "auto" self.test_input = deepcopy(media_data.BASE64_AUDIO) self.interpret_by_tokens = True + self.streaming = streaming + if streaming and source != "microphone": + raise ValueError( + "Audio streaming only available if source is 'microphone'." + ) IOComponent.__init__( self, label=label, @@ -1832,6 +1866,7 @@ class Audio(Changeable, Clearable, Playable, IOComponent): return { "source": self.source, # TODO: This did not exist in output template, careful here if an error arrives "value": self.value, + "streaming": self.streaming, **IOComponent.get_config(self), } @@ -2056,6 +2091,27 @@ class Audio(Changeable, Clearable, Playable, IOComponent): def deserialize(self, x): return processing_utils.decode_base64_to_file(x).name + def stream( + self, + fn: Callable, + inputs: List[Component], + outputs: List[Component], + _js: Optional[str] = None, + ): + """ + Parameters: + fn: Callable function + inputs: List of inputs + outputs: List of outputs + _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. + Returns: None + """ + if self.source != "microphone": + raise ValueError( + "Audio streaming only available if source is 'microphone'." + ) + Streamable.stream(self, fn, inputs, outputs, _js) + class File(Changeable, Clearable, IOComponent): """ diff --git a/gradio/events.py b/gradio/events.py index 76deaa8cc4..5982eab3ec 100644 --- a/gradio/events.py +++ b/gradio/events.py @@ -179,3 +179,23 @@ class Playable(Block): Returns: None """ self.set_event_trigger("stop", fn, inputs, outputs, js=_js) + + +class Streamable(Block): + def stream( + self, + fn: Callable, + inputs: List[Component], + outputs: List[Component], + _js: Optional[str] = None, + ): + """ + Parameters: + fn: Callable function + inputs: List of inputs + outputs: List of outputs + _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. + Returns: None + """ + self.streaming = True + self.set_event_trigger("stream", fn, inputs, outputs, js=_js) diff --git a/gradio/interface.py b/gradio/interface.py index 5f47ef74c4..4dfef7fead 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -33,6 +33,7 @@ from gradio.components import ( Variable, get_component_instance, ) +from gradio.events import Changeable, Streamable from gradio.external import load_from_pipeline, load_interface # type: ignore from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore from gradio.layouts import Column, Row, TabItem, Tabs @@ -502,9 +503,22 @@ class Interface(Blocks): ) if self.live: for component in self.input_components: - component.change( - submit_fn, self.input_components, self.output_components - ) + if isinstance(component, Streamable): + if component.streaming: + component.stream( + submit_fn, self.input_components, self.output_components + ) + continue + else: + print( + "Hint: Set streaming=True for " + + component.__class__.__name__ + + " component to use live streaming." + ) + if isinstance(component, Changeable): + component.change( + submit_fn, self.input_components, self.output_components + ) else: submit_btn.click( submit_fn, diff --git a/gradio/test_data/blocks_configs.py b/gradio/test_data/blocks_configs.py index c97849ee8c..d7bf5b6f7a 100644 --- a/gradio/test_data/blocks_configs.py +++ b/gradio/test_data/blocks_configs.py @@ -44,6 +44,7 @@ XRAY_CONFIG = { "tool": "editor", "show_label": True, "name": "image", + "streaming": False, "visible": True, "style": {}, }, @@ -89,6 +90,7 @@ XRAY_CONFIG = { "tool": "editor", "show_label": True, "name": "image", + "streaming": False, "visible": True, "style": {}, }, @@ -239,6 +241,7 @@ XRAY_CONFIG_DIFF_IDS = { "tool": "editor", "show_label": True, "name": "image", + "streaming": False, "visible": True, "style": {}, }, @@ -284,6 +287,7 @@ XRAY_CONFIG_DIFF_IDS = { "tool": "editor", "show_label": True, "name": "image", + "streaming": False, "visible": True, "style": {}, }, @@ -439,6 +443,7 @@ XRAY_CONFIG_WITH_MISTAKE = { "source": "upload", "tool": "editor", "name": "image", + "streaming": False, "style": {}, }, }, @@ -484,6 +489,7 @@ XRAY_CONFIG_WITH_MISTAKE = { "source": "upload", "tool": "editor", "name": "image", + "streaming": False, "style": {}, }, }, diff --git a/test/test_components.py b/test/test_components.py index 9c8b164464..d7adfb5112 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -577,6 +577,7 @@ class TestImage(unittest.TestCase): "source": "upload", "tool": "editor", "name": "image", + "streaming": False, "show_label": True, "label": "Upload Your Image", "style": {}, @@ -734,6 +735,7 @@ class TestAudio(unittest.TestCase): { "source": "upload", "name": "audio", + "streaming": False, "show_label": True, "label": "Upload Your Audio", "style": {}, @@ -776,6 +778,7 @@ class TestAudio(unittest.TestCase): audio_output.get_config(), { "name": "audio", + "streaming": False, "show_label": True, "label": None, "source": "upload", diff --git a/ui/packages/app/src/Blocks.svelte b/ui/packages/app/src/Blocks.svelte index 31bfb22308..6c14d8a381 100644 --- a/ui/packages/app/src/Blocks.svelte +++ b/ui/packages/app/src/Blocks.svelte @@ -178,7 +178,6 @@ } let handled_dependencies: Array = []; - let status_tracker_values: Record = {}; async function handle_mount() { await tick(); @@ -311,7 +310,7 @@ $: set_status($loading_status); dependencies.forEach((v, i) => { - loading_status.register(i, v.outputs); + loading_status.register(i, v.inputs, v.outputs); }); function set_status( @@ -320,6 +319,10 @@ for (const id in statuses) { set_prop(instance_map[id], "loading_status", statuses[id]); } + const inputs_to_update = loading_status.get_inputs_to_update(); + for (const [id, pending_status] of inputs_to_update) { + set_prop(instance_map[id], "pending", pending_status === "pending"); + } } let mode = ""; @@ -387,7 +390,6 @@ {instance_map} {theme} {root} - {status_tracker_values} on:mount={handle_mount} on:destroy={({ detail }) => handle_destroy(detail)} /> diff --git a/ui/packages/app/src/Render.svelte b/ui/packages/app/src/Render.svelte index f36c8307f5..5efc3b3aa8 100644 --- a/ui/packages/app/src/Render.svelte +++ b/ui/packages/app/src/Render.svelte @@ -1,5 +1,6 @@ -{#if value === null} +{#if value === null || streaming} {#if source === "microphone"}
{#if recording} diff --git a/ui/packages/image/src/Image.svelte b/ui/packages/image/src/Image.svelte index 80fb0d16b8..706e947ff5 100644 --- a/ui/packages/image/src/Image.svelte +++ b/ui/packages/image/src/Image.svelte @@ -21,6 +21,8 @@ export let drop_text: string = "Drop an image file"; export let or_text: string = "or"; export let upload_text: string = "click to upload"; + export let streaming: boolean = false; + export let pending: boolean = false; let mode: "edit" | "view" = "view"; let sketch: Sketch; @@ -37,11 +39,12 @@ function handle_save({ detail }: { detail: string }) { value = detail; mode = "view"; - dispatch("edit"); + dispatch(streaming ? "stream" : "edit"); } const dispatch = createEventDispatcher<{ change: string | null; + stream: string | null; edit: undefined; clear: undefined; drag: boolean; @@ -67,7 +70,7 @@ on:clear={() => sketch.clear()} /> - {:else if value === null} + {:else if value === null || streaming} {#if source === "upload"} {:else if source === "webcam"} - + {/if} {:else if tool === "select"} diff --git a/ui/packages/image/src/Webcam.svelte b/ui/packages/image/src/Webcam.svelte index dcbf2ca2c2..a7bb966d55 100644 --- a/ui/packages/image/src/Webcam.svelte +++ b/ui/packages/image/src/Webcam.svelte @@ -4,6 +4,8 @@ let video_source: HTMLVideoElement; let canvas: HTMLCanvasElement; + export let streaming: boolean = false; + export let pending: boolean = false; export let mode: "image" | "video" = "image"; @@ -38,7 +40,7 @@ ); var data = canvas.toDataURL("image/png"); - dispatch("capture", data); + dispatch(streaming ? "stream" : "capture", data); } } @@ -88,29 +90,39 @@ } access_webcam(); + + if (streaming && mode === "image") { + window.setInterval(() => { + if (video_source && !pending) { + take_picture(); + } + }, 500); + }