Support for iterative outputs (#2189)

* Support for iterative outputs (#2162) (#2188)

* added generator demo

* fixed demo structure

* fixes

* fix failing tests due to refactor

* test components

* adding generators

* fixes

* iterative

* formatting

* add all

* added demo

* demo

* formatting

* fixed frontend

* 3.2.1b release

* removed test queue

* iterative

* formatting

* formatting

* Support for iterative outputs (#2149)

* reverted queue everywhere

* added queue to demos

* added fake diffusion with gif

* add to demos

* more complex counter

* fixes

* image gif

* fixes

* version

* merged

* added support for state

* formatting

* generating animation

* fix

* tests, iterator

* tests

* formatting

* tests for queuing

* version

* generating orange border animation

* testings

* added to documentation

Co-authored-by: Ali Abid <aabid94@gmail.com>
Abubakar Abid 2022-09-08 07:35:31 -07:00 committed by GitHub
parent f8523868d0
commit bf1510165d
20 changed files with 404 additions and 97 deletions

View File

@@ -0,0 +1,27 @@
import gradio as gr
import time


def count(n):
    for i in range(int(n)):
        time.sleep(0.5)
        yield i


def show(n):
    return str(list(range(int(n))))


with gr.Blocks() as demo:
    with gr.Column():
        num = gr.Number(value=10)
        with gr.Row():
            count_btn = gr.Button("Count")
            list_btn = gr.Button("List")
    with gr.Column():
        out = gr.Textbox()
    count_btn.click(count, num, out)
    list_btn.click(show, num, out)
demo.queue()

if __name__ == "__main__":
    demo.launch()

View File

@@ -0,0 +1,22 @@
import gradio as gr
import numpy as np
import time


def fake_diffusion(steps):
    for _ in range(steps):
        time.sleep(1)
        image = np.random.random((600, 600, 3))
        yield image
    image = "https://i.picsum.photos/id/867/600/600.jpg?hmac=qE7QFJwLmlE_WKI7zMH6SgH5iY5fx8ec6ZJQBwKRT44"
    yield image


demo = gr.Interface(fake_diffusion,
                    inputs=gr.Slider(1, 10, 3),
                    outputs="image")
demo.queue()

if __name__ == "__main__":
    demo.launch()

Binary file not shown (new image added, 1.5 MiB).

View File

@@ -0,0 +1,48 @@
import gradio as gr
import numpy as np
import time
import os
from PIL import Image
import requests
from io import BytesIO


def create_gif(images):
    pil_images = []
    for image in images:
        if isinstance(image, str):
            response = requests.get(image)
            image = Image.open(BytesIO(response.content))
        else:
            image = Image.fromarray((image * 255).astype(np.uint8))
        pil_images.append(image)
    fp_out = os.path.join(os.path.dirname(__file__), "image.gif")
    img = pil_images.pop(0)
    img.save(fp=fp_out, format='GIF', append_images=pil_images,
             save_all=True, duration=400, loop=0)
    return fp_out


def fake_diffusion(steps):
    images = []
    for _ in range(steps):
        time.sleep(1)
        image = np.random.random((600, 600, 3))
        images.append(image)
        yield image, gr.Image.update(visible=False)
    time.sleep(1)
    image = "https://i.picsum.photos/id/867/600/600.jpg?hmac=qE7QFJwLmlE_WKI7zMH6SgH5iY5fx8ec6ZJQBwKRT44"
    images.append(image)
    gif_path = create_gif(images)
    yield image, gr.Image.update(value=gif_path, visible=True)


demo = gr.Interface(fake_diffusion,
                    inputs=gr.Slider(1, 10, 3),
                    outputs=["image", gr.Image(label="All Images", visible=False)])
demo.queue()

if __name__ == "__main__":
    demo.launch()

View File

@@ -614,19 +614,47 @@ class Blocks(BlockContext):
             processed_input = raw_input
         return processed_input

-    async def call_function(self, fn_index, processed_input):
+    async def call_function(self, fn_index, processed_input, iterator=None):
         """Calls and times function with given index and preprocessed input."""
         block_fn = self.fns[fn_index]
+        is_generating = False

         start = time.time()
-        if inspect.iscoroutinefunction(block_fn.fn):
-            prediction = await block_fn.fn(*processed_input)
-        else:
-            prediction = await anyio.to_thread.run_sync(
-                block_fn.fn, *processed_input, limiter=self.limiter
-            )
+        if iterator is None:  # If not a generator function that has already run
+            if inspect.iscoroutinefunction(block_fn.fn):
+                prediction = await block_fn.fn(*processed_input)
+            else:
+                prediction = await anyio.to_thread.run_sync(
+                    block_fn.fn, *processed_input, limiter=self.limiter
+                )
+        if inspect.isasyncgenfunction(block_fn.fn):
+            raise ValueError("Gradio does not support async generators.")
+        if inspect.isgeneratorfunction(block_fn.fn):
+            if not self.enable_queue:
+                raise ValueError("Need to enable queue to use generators.")
+            try:
+                if iterator is None:
+                    iterator = prediction
+                prediction = next(iterator)
+                is_generating = True
+            except StopIteration:
+                n_outputs = len(self.dependencies[fn_index].get("outputs"))
+                prediction = (
+                    components._Keywords.FINISHED_ITERATING
+                    if n_outputs == 1
+                    else (components._Keywords.FINISHED_ITERATING,) * n_outputs
+                )
+                iterator = None
+
         duration = time.time() - start
-        return prediction, duration
+        return {
+            "prediction": prediction,
+            "duration": duration,
+            "is_generating": is_generating,
+            "iterator": iterator,
+        }

     def postprocess_data(self, fn_index, predictions, state):
         block_fn = self.fns[fn_index]
def postprocess_data(self, fn_index, predictions, state):
block_fn = self.fns[fn_index]
@@ -654,6 +682,9 @@ class Blocks(BlockContext):
         if block_fn.postprocess:
             output = []
             for i, output_id in enumerate(dependency["outputs"]):
+                if predictions[i] is components._Keywords.FINISHED_ITERATING:
+                    output.append(None)
+                    break
                 block = self.blocks[output_id]
                 if getattr(block, "stateful", False):
                     if not is_update(predictions[i]):
@@ -697,7 +728,8 @@ class Blocks(BlockContext):
         fn_index: int,
         inputs: List[Any],
         username: str = None,
-        state: Optional[Dict[int, any]] = None,
+        state: Optional[Dict[int, Any]] = None,
+        iterators: Dict[int, Any] = None,
     ) -> Dict[str, Any]:
         """
         Processes API calls from the frontend. First preprocesses the data,
@@ -711,16 +743,19 @@ class Blocks(BlockContext):
         block_fn = self.fns[fn_index]

         inputs = self.preprocess_data(fn_index, inputs, state)
+        iterator = iterators.get(fn_index, None)

-        predictions, duration = await self.call_function(fn_index, inputs)
-        block_fn.total_runtime += duration
+        result = await self.call_function(fn_index, inputs, iterator)
+        block_fn.total_runtime += result["duration"]
         block_fn.total_runs += 1

-        predictions = self.postprocess_data(fn_index, predictions, state)
+        predictions = self.postprocess_data(fn_index, result["prediction"], state)

         return {
             "data": predictions,
-            "duration": duration,
+            "is_generating": result["is_generating"],
+            "iterator": result["iterator"],
+            "duration": result["duration"],
             "average_duration": block_fn.total_runtime / block_fn.total_runs,
         }
@@ -967,6 +1002,7 @@ class Blocks(BlockContext):
                 "The `enable_queue` parameter has been deprecated. Please use the `.queue()` method instead.",
                 DeprecationWarning,
             )
         if self.is_space:
             self.enable_queue = self.enable_queue is not False
         else:

View File

@@ -62,6 +62,7 @@ set_documentation_group("component")
 class _Keywords(Enum):
     NO_VALUE = "NO_VALUE"  # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()`
+    FINISHED_ITERATING = "FINISHED_ITERATING"  # Used to skip processing of a component's value (needed for generators + state)

 class Component(Block):

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 import json
 import time
-from typing import List, Optional
+from typing import Dict, List, Optional

 import fastapi
 from pydantic import BaseModel
@@ -210,6 +210,14 @@ class Queue:
             queue_eta=self.queue_duration,
         )

+    async def call_prediction(self, event: Event):
+        response = await Request(
+            method=Request.Method.POST,
+            url=f"{self.server_path}api/predict",
+            json=event.data,
+        )
+        return response
+
     async def process_event(self, event: Event) -> None:
         client_awake = await self.gather_event_data(event)
         if not client_awake:
@@ -218,27 +226,44 @@ class Queue:
         if not client_awake:
             return
         begin_time = time.time()
-        response = await Request(
-            method=Request.Method.POST,
-            url=f"{self.server_path}api/predict",
-            json=event.data,
-        )
+        response = await self.call_prediction(event)
+        if response.json.get("is_generating", False):
+            while response.json.get("is_generating", False):
+                old_response = response
+                await self.send_message(
+                    event,
+                    {
+                        "msg": "process_generating",
+                        "output": old_response.json,
+                        "success": old_response.status == 200,
+                    },
+                )
+                response = await self.call_prediction(event)
+            await self.send_message(
+                event,
+                {
+                    "msg": "process_completed",
+                    "output": old_response.json,
+                    "success": old_response.status == 200,
+                },
+            )
+        else:
+            await self.send_message(
+                event,
+                {
+                    "msg": "process_completed",
+                    "output": response.json,
+                    "success": response.status == 200,
+                },
+            )
         end_time = time.time()
-        success = response.status == 200
-        if success:
+        if response.status == 200:
             self.update_estimation(end_time - begin_time)
-        await self.send_message(
-            event,
-            {
-                "msg": "process_completed",
-                "output": response.json,
-                "success": success,
-            },
-        )
         await event.disconnect()
         await self.clean_event(event)

-    async def send_message(self, event, data: json) -> bool:
+    async def send_message(self, event, data: Dict) -> bool:
         try:
             await event.websocket.send_json(data=data)
             return True
@@ -246,7 +271,7 @@ class Queue:
             await self.clean_event(event)
             return False

-    async def get_message(self, event) -> Optional[json]:
+    async def get_message(self, event) -> Optional[Dict]:
         try:
             data = await event.websocket.receive_json()
             return data

View File

@@ -20,7 +20,7 @@ async def run_interpret(interface, raw_input):
         for i, input_component in enumerate(interface.input_components)
     ]
     original_output = await interface.call_function(0, processed_input)
-    original_output = original_output[0]
+    original_output = original_output["prediction"]

     if len(interface.output_components) == 1:
         original_output = [original_output]
@@ -47,7 +47,7 @@ async def run_interpret(interface, raw_input):
                 neighbor_output = await interface.call_function(
                     0, processed_neighbor_input
                 )
-                neighbor_output = neighbor_output[0]
+                neighbor_output = neighbor_output["prediction"]
                 if len(interface.output_components) == 1:
                     neighbor_output = [neighbor_output]
                 processed_neighbor_output = [
@@ -91,7 +91,7 @@ async def run_interpret(interface, raw_input):
             neighbor_output = await interface.call_function(
                 0, processed_neighbor_input
             )
-            neighbor_output = neighbor_output[0]
+            neighbor_output = neighbor_output["prediction"]
             if len(interface.output_components) == 1:
                 neighbor_output = [neighbor_output]
             processed_neighbor_output = [
@@ -144,7 +144,7 @@ async def run_interpret(interface, raw_input):
         new_output = utils.synchronize_async(
             interface.call_function, 0, processed_masked_input
         )
-        new_output = new_output[0]
+        new_output = new_output["prediction"]
         if len(interface.output_components) == 1:
             new_output = [new_output]
         pred = get_regression_or_classification_value(

View File

@@ -40,7 +40,7 @@ class Parallel(gradio.Interface):
             return_values_with_durations = await asyncio.gather(
                 *[interface.call_function(0, args) for interface in interfaces]
             )
-            return_values = [rv[0] for rv in return_values_with_durations]
+            return_values = [rv["prediction"] for rv in return_values_with_durations]
             combined_list = []
             for interface, return_value in zip(interfaces, return_values):
                 if len(interface.output_components) == 1:
@@ -91,7 +91,7 @@ class Series(gradio.Interface):
                 ]
                 # run all of predictions sequentially
-                data = (await interface.call_function(0, data))[0]
+                data = (await interface.call_function(0, data))["prediction"]
                 if len(interface.output_components) == 1:
                     data = [data]

View File

@@ -10,6 +10,7 @@ import os
 import posixpath
 import secrets
 import traceback
+from collections import defaultdict
 from copy import deepcopy
 from pathlib import Path
 from typing import Any, List, Optional, Type
@@ -87,6 +88,9 @@ class App(FastAPI):
         self.tokens = None
         self.auth = None
         self.blocks: Optional[gradio.Blocks] = None
+        self.state_holder = {}
+        self.iterators = defaultdict(dict)
+
         super().__init__(**kwargs)

     def configure_app(self, blocks: gradio.Blocks) -> None:
@@ -115,7 +119,6 @@ class App(FastAPI):
             allow_methods=["*"],
             allow_headers=["*"],
         )
-        app.state_holder = {}

         @app.get("/user")
         @app.get("/user/")
@@ -255,14 +258,19 @@ class App(FastAPI):
                 if getattr(block, "stateful", False)
             }
             session_state = app.state_holder[body.session_hash]
+            iterators = app.iterators[body.session_hash]
         else:
             session_state = {}
+            iterators = {}
         raw_input = body.data
         fn_index = body.fn_index
         try:
             output = await app.blocks.process_api(
-                fn_index, raw_input, username, session_state
+                fn_index, raw_input, username, session_state, iterators
             )
+            iterator = output.pop("iterator", None)
+            if hasattr(body, "session_hash"):
+                app.iterators[body.session_hash][fn_index] = iterator
             if isinstance(output, Error):
                 raise output
         except BaseException as error:

View File

@@ -1 +1 @@
-3.2
+3.2.1b2

View File

@@ -160,4 +160,24 @@ with gr.Blocks() as demo2:
demo2.launch()
```

Docs: Examples

## Iterative Outputs

In some cases, you may want to show a sequence of outputs rather than a single output. For example, you might have an image generation model and want to show the image generated at each step, leading up to the final image.

In such cases, you can supply a **generator** function into Gradio instead of a regular function. Creating generators in Python is very simple: instead of a single `return` value, a function should `yield` a series of values. Usually the `yield` statement is put in some kind of loop. Here's an example of a generator that simply counts up to a given number:

```python
def my_generator(x):
    for i in range(x):
        yield i
```
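
Each event trigger advances the generator by one step: under the hood, the new `call_function` holds onto the iterator and calls Python's built-in `next()` on it, catching `StopIteration` once it is exhausted. As a quick sketch of that plain-Python behavior:

```python
gen = my_generator(3)
assert next(gen) == 0  # first yielded value
assert next(gen) == 1
assert next(gen) == 2
# One more next(gen) raises StopIteration, which signals that the generator
# is done (Gradio maps this to the FINISHED_ITERATING sentinel internally).
```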
You supply a generator into Gradio the same way as you would a regular function. For example, here's a (fake) image generation model that generates noise for several steps before outputting an image:

$code_fake_diffusion
$demo_fake_diffusion

Note that we've added a `time.sleep(1)` in the iterator to create an artificial pause between steps so that you can observe the steps of the iterator (in a real image generation model, this probably wouldn't be necessary).

Supplying a generator into Gradio **requires** you to enable queuing in the underlying Interface or Blocks (see the queuing section above).
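
For instance, here is a minimal sketch, adapted from the counting demo added in this PR, that wires a generator into Blocks with the queue enabled:

```python
import time

import gradio as gr


def count(n):
    # Yields a new value every half-second instead of returning once
    for i in range(int(n)):
        time.sleep(0.5)
        yield i


with gr.Blocks() as demo:
    num = gr.Number(value=10)
    out = gr.Textbox()
    gr.Button("Count").click(count, num, out)

demo.queue()  # required: generators raise an error if the queue is disabled

if __name__ == "__main__":
    demo.launch()
```

Without the `demo.queue()` call, triggering the generator raises `ValueError("Need to enable queue to use generators.")`, as enforced in `call_function`.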

View File

@@ -16,6 +16,7 @@ def copy_all_demos(source_dir: str, dest_dir: str):
         "blocks_update",
         "calculator",
         "fake_gan",
+        "fake_diffusion_with_gif",
         "gender_sentence_default_interpretation",
         "image_mod_default_image",
         "interface_parallel_load",

View File

@@ -246,7 +246,6 @@ def test_io_components_attach_load_events_when_value_is_fn(io_components):
def test_blocks_do_not_filter_none_values_from_updates(io_components):
    io_components = [c() for c in io_components if c not in [gr.State, gr.Button]]
    with gr.Blocks() as demo:
        for component in io_components:
@@ -267,7 +266,6 @@ def test_blocks_do_not_filter_none_values_from_updates(io_components):
def test_blocks_does_not_replace_keyword_literal():
    with gr.Blocks() as demo:
        text = gr.Textbox()
        btn = gr.Button(value="Reset")
@@ -281,5 +279,90 @@ def test_blocks_does_not_replace_keyword_literal():
    assert output[0]["value"] == "NO_VALUE"


class TestCallFunction:
    @pytest.mark.asyncio
    async def test_call_regular_function(self):
        with gr.Blocks() as demo:
            text = gr.Textbox()
            btn = gr.Button()
            btn.click(
                lambda x: "Hello, " + x,
                inputs=text,
                outputs=text,
            )

        output = await demo.call_function(0, ["World"])
        assert output["prediction"] == "Hello, World"
        output = await demo.call_function(0, ["Abubakar"])
        assert output["prediction"] == "Hello, Abubakar"

    @pytest.mark.asyncio
    async def test_call_generator(self):
        def generator(x):
            for i in range(x):
                yield i

        with gr.Blocks() as demo:
            inp = gr.Number()
            out = gr.Number()
            btn = gr.Button()
            btn.click(
                generator,
                inputs=inp,
                outputs=out,
            )

        demo.queue()
        output = await demo.call_function(0, [3])
        assert output["prediction"] == 0
        output = await demo.call_function(0, [3], iterator=output["iterator"])
        assert output["prediction"] == 1
        output = await demo.call_function(0, [3], iterator=output["iterator"])
        assert output["prediction"] == 2
        output = await demo.call_function(0, [3], iterator=output["iterator"])
        assert output["prediction"] == gr.components._Keywords.FINISHED_ITERATING
        assert output["iterator"] is None
        output = await demo.call_function(0, [3], iterator=output["iterator"])
        assert output["prediction"] == 0

    @pytest.mark.asyncio
    async def test_call_both_generator_and_function(self):
        def generator(x):
            for i in range(x):
                yield i, x

        with gr.Blocks() as demo:
            inp = gr.Number()
            out1 = gr.Number()
            out2 = gr.Number()
            btn = gr.Button()
            inp.change(lambda x: x + x, inp, out1)
            btn.click(
                generator,
                inputs=inp,
                outputs=[out1, out2],
            )

        demo.queue()
        output = await demo.call_function(0, [2])
        assert output["prediction"] == 4
        output = await demo.call_function(0, [-1])
        assert output["prediction"] == -2

        output = await demo.call_function(1, [3])
        assert output["prediction"] == (0, 3)
        output = await demo.call_function(1, [3], iterator=output["iterator"])
        assert output["prediction"] == (1, 3)
        output = await demo.call_function(1, [3], iterator=output["iterator"])
        assert output["prediction"] == (2, 3)
        output = await demo.call_function(1, [3], iterator=output["iterator"])
        assert output["prediction"] == (gr.components._Keywords.FINISHED_ITERATING,) * 2
        assert output["iterator"] is None
        output = await demo.call_function(1, [3], iterator=output["iterator"])
        assert output["prediction"] == (0, 3)


if __name__ == "__main__":
    unittest.main()

View File

@@ -1753,9 +1753,9 @@ class TestState:
         io = gr.Interface(test, ["text", "state"], ["text", "state"])
         result = await io.call_function(0, ["abc"])
-        assert result[0][0] == "abc def"
-        result = await io.call_function(0, ["abc", result[0][0]])
-        assert result[0][0] == "abcabc def"
+        assert result["prediction"][0] == "abc def"
+        result = await io.call_function(0, ["abc", result["prediction"][0]])
+        assert result["prediction"][0] == "abcabc def"

     @pytest.mark.asyncio
     async def test_in_blocks(self):
@@ -1765,9 +1765,9 @@ class TestState:
             btn.click(lambda x: x + 1, score, score)

         result = await demo.call_function(0, [0])
-        assert result[0] == 1
-        result = await demo.call_function(0, [result[0]])
-        assert result[0] == 2
+        assert result["prediction"] == 1
+        result = await demo.call_function(0, [result["prediction"]])
+        assert result["prediction"] == 2

     @pytest.mark.asyncio
     async def test_variable_for_backwards_compatibility(self):
@@ -1777,9 +1777,9 @@ class TestState:
             btn.click(lambda x: x + 1, score, score)

         result = await demo.call_function(0, [0])
-        assert result[0] == 1
-        result = await demo.call_function(0, [result[0]])
-        assert result[0] == 2
+        assert result["prediction"] == 1
+        result = await demo.call_function(0, [result["prediction"]])
+        assert result["prediction"] == 2


 def test_dataframe_as_example_converts_dataframes():

View File

@@ -10,42 +10,4 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 class TestQueue:
     @pytest.mark.asyncio
     async def test_queue_with_single_event(self):
-        async def wait(data):
-            await asyncio.sleep(0.1)
-            return data
-
-        with gr.Blocks() as demo:
-            text = gr.Textbox()
-            button = gr.Button()
-            button.click(wait, [text], [text])
-
-        app, local_url, _ = demo.launch(prevent_thread_lock=True, enable_queue=True)
-        client = TestClient(app)
-
-        with client.websocket_connect("/queue/join") as _:  # websocket
-            """#Unable to make this part work, seems like there is an issue with thread acquire and exiting the scope
-            websocket.send_json({"hash": "0001"})
-            assert {
-                "avg_event_concurrent_process_time": 1.0,
-                "avg_event_process_time": 1.0,
-                "msg": "estimation",
-                "queue_eta": 1,
-                "queue_size": 0,
-                "rank": -1,
-                "rank_eta": -1,
-            } == websocket.receive_json()
-            while True:
-                message = websocket.receive_json()
-                if "estimation" == message["msg"]:
-                    continue
-                elif "send_data" == message["msg"]:
-                    websocket.send_json({"data": [1], "fn": 0})
-                elif "process_starts" == message["msg"]:
-                    continue
-                elif "process_completed" == message["msg"]:
-                    assert message["output"]["data"] == ["1"]
-                    break
-            """
-        demo.close()
+        pass  # TODO

View File

@@ -142,6 +142,52 @@ class TestRoutes(unittest.TestCase):
        close_all()


class TestGeneratorRoutes:
    def test_generator(self):
        def generator(string):
            for char in string:
                yield char

        io = Interface(generator, "text", "text")
        app, _, _ = io.queue().launch(prevent_thread_lock=True)
        client = TestClient(app)

        response = client.post(
            "/api/predict/",
            json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
        )
        output = dict(response.json())
        assert output["data"] == ["a"]

        response = client.post(
            "/api/predict/",
            json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
        )
        output = dict(response.json())
        assert output["data"] == ["b"]

        response = client.post(
            "/api/predict/",
            json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
        )
        output = dict(response.json())
        assert output["data"] == ["c"]

        response = client.post(
            "/api/predict/",
            json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
        )
        output = dict(response.json())
        assert output["data"] == [None]

        response = client.post(
            "/api/predict/",
            json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
        )
        output = dict(response.json())
        assert output["data"] == ["a"]
class TestApp:
def test_create_app(self):
app = routes.App.create_app(Interface(lambda x: x, "text", "text"))
@@ -152,7 +198,9 @@ class TestAuthenticatedRoutes(unittest.TestCase):
     def setUp(self) -> None:
         self.io = Interface(lambda x: x, "text", "text")
         self.app, _, _ = self.io.launch(
-            auth=("test", "correct_password"), prevent_thread_lock=True
+            auth=("test", "correct_password"),
+            prevent_thread_lock=True,
+            enable_queue=False,
         )
         self.client = TestClient(self.app)

View File

@@ -26,6 +26,7 @@ type StatusResponse =
 interface Payload {
     data: Array<unknown>;
     fn_index: number;
+    session_hash?: string;
 }

 declare let BUILD_MODE: string;
@@ -89,6 +90,7 @@ export const fn =
     }): Promise<unknown> => {
         const fn_index = payload.fn_index;
+        payload.session_hash = session_hash;

         if (frontend_fn !== undefined) {
             payload.data = await frontend_fn(payload.data.concat(output_data));
         }
@@ -172,6 +174,20 @@ export const fn =
                     null
                 );
                 break;
+            case "process_generating":
+                loading_status.update(
+                    fn_index,
+                    data.success ? "generating" : "error",
+                    queue,
+                    null,
+                    null,
+                    data.output.average_duration,
+                    !data.success ? data.output.error : null
+                );
+                if (data.success) {
+                    queue_callback(data.output);
+                }
+                break;
             case "process_completed":
                 loading_status.update(
                     fn_index,
View File

@@ -52,7 +52,7 @@
     export let queue: boolean = false;
     export let queue_position: number | null;
     export let queue_size: number | null;
-    export let status: "complete" | "pending" | "error";
+    export let status: "complete" | "pending" | "error" | "generating";
     export let scroll_to_output: boolean = false;
     export let timer: boolean = true;
     export let visible: boolean = true;
@@ -141,6 +141,8 @@
 <div
     class="wrap"
     class:opacity-0={!status || status === "complete"}
+    class:cover-bg={status === "pending" || status === "error"}
+    class:generating={status === "generating"}
     class:!hidden={!visible}
     bind:this={el}
 >
@@ -187,13 +189,21 @@
 <style lang="postcss">
     .wrap {
-        @apply absolute inset-0 z-50 flex flex-col justify-center items-center bg-white dark:bg-gray-800 pointer-events-none transition-opacity max-h-screen;
+        @apply absolute inset-0 z-50 flex flex-col justify-center items-center dark:bg-gray-800 pointer-events-none transition-opacity max-h-screen;
     }

-    :global(.dark) .wrap {
+    :global(.dark) .cover-bg {
         @apply bg-gray-800;
     }

+    .cover-bg {
+        @apply bg-white;
+    }
+
+    .generating {
+        @apply border-2 border-orange-500 animate-pulse;
+    }
+
     .progress-bar {
         @apply absolute inset-0 origin-left bg-slate-100 dark:bg-gray-700 top-0 left-0 z-10 opacity-80;
     }

View File

@@ -2,7 +2,7 @@ import { writable } from "svelte/store";
 export interface LoadingStatus {
     eta: number | null;
-    status: "pending" | "error" | "complete";
+    status: "pending" | "error" | "complete" | "generating";
     queue: boolean;
     queue_position: number | null;
     queue_size: number | null;