mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-25 12:10:31 +08:00
Support for iterative outputs (#2189)
* Support for iterative outputs (#2162) (#2188) * added generator demo * fixed demo structure * fixes * fix failing tests due to refactor * test components * adding generators * fixes * iterative * formatting * add all * added demo * demo * formatting * fixed frontend * 3.2.1b release * removed test queue * iterative * formatting * formatting * Support for iterative outputs (#2149) * added generator demo * fixed demo structure * fixes * fix failing tests due to refactor * test components * adding generators * fixes * iterative * formatting * add all * added demo * demo * formatting * fixed frontend * 3.2.1b release * iterative * formatting * formatting * reverted queue everywhere * added queue to demos * added fake diffusion with gif * add to demos * more complex counter * fixes * image gif * fixes * version * merged * added support for state * formatting * generating animation * fix * tests, iterator * tests * formatting * tests for queuing * version * generating orange border animation * testings * added to documentation Co-authored-by: Ali Abid <aabid94@gmail.com>
This commit is contained in:
parent
f8523868d0
commit
bf1510165d
27
demo/count_generator/run.py
Normal file
27
demo/count_generator/run.py
Normal file
@ -0,0 +1,27 @@
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
def count(n):
|
||||
for i in range(int(n)):
|
||||
time.sleep(0.5)
|
||||
yield i
|
||||
|
||||
def show(n):
|
||||
return str(list(range(int(n))))
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Column():
|
||||
num = gr.Number(value=10)
|
||||
with gr.Row():
|
||||
count_btn = gr.Button("Count")
|
||||
list_btn = gr.Button("List")
|
||||
with gr.Column():
|
||||
out = gr.Textbox()
|
||||
|
||||
count_btn.click(count, num, out)
|
||||
list_btn.click(show, num, out)
|
||||
|
||||
demo.queue()
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
22
demo/fake_diffusion/run.py
Normal file
22
demo/fake_diffusion/run.py
Normal file
@ -0,0 +1,22 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
def fake_diffusion(steps):
|
||||
for _ in range(steps):
|
||||
time.sleep(1)
|
||||
image = np.random.random((600, 600, 3))
|
||||
yield image
|
||||
|
||||
image = "https://i.picsum.photos/id/867/600/600.jpg?hmac=qE7QFJwLmlE_WKI7zMH6SgH5iY5fx8ec6ZJQBwKRT44"
|
||||
yield image
|
||||
|
||||
|
||||
demo = gr.Interface(fake_diffusion,
|
||||
inputs=gr.Slider(1, 10, 3),
|
||||
outputs="image")
|
||||
demo.queue()
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
BIN
demo/fake_diffusion_with_gif/image.gif
Normal file
BIN
demo/fake_diffusion_with_gif/image.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.5 MiB |
48
demo/fake_diffusion_with_gif/run.py
Normal file
48
demo/fake_diffusion_with_gif/run.py
Normal file
@ -0,0 +1,48 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import time
|
||||
import os
|
||||
from PIL import Image
|
||||
import requests
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def create_gif(images):
|
||||
pil_images = []
|
||||
for image in images:
|
||||
if isinstance(image, str):
|
||||
response = requests.get(image)
|
||||
image = Image.open(BytesIO(response.content))
|
||||
else:
|
||||
image = Image.fromarray((image * 255).astype(np.uint8))
|
||||
pil_images.append(image)
|
||||
fp_out = os.path.join(os.path.dirname(__file__), "image.gif")
|
||||
img = pil_images.pop(0)
|
||||
img.save(fp=fp_out, format='GIF', append_images=pil_images,
|
||||
save_all=True, duration=400, loop=0)
|
||||
return fp_out
|
||||
|
||||
|
||||
def fake_diffusion(steps):
|
||||
images = []
|
||||
for _ in range(steps):
|
||||
time.sleep(1)
|
||||
image = np.random.random((600, 600, 3))
|
||||
images.append(image)
|
||||
yield image, gr.Image.update(visible=False)
|
||||
|
||||
time.sleep(1)
|
||||
image = "https://i.picsum.photos/id/867/600/600.jpg?hmac=qE7QFJwLmlE_WKI7zMH6SgH5iY5fx8ec6ZJQBwKRT44"
|
||||
images.append(image)
|
||||
gif_path = create_gif(images)
|
||||
|
||||
yield image, gr.Image.update(value=gif_path, visible=True)
|
||||
|
||||
|
||||
demo = gr.Interface(fake_diffusion,
|
||||
inputs=gr.Slider(1, 10, 3),
|
||||
outputs=["image", gr.Image(label="All Images", visible=False)])
|
||||
demo.queue()
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -614,19 +614,47 @@ class Blocks(BlockContext):
|
||||
processed_input = raw_input
|
||||
return processed_input
|
||||
|
||||
async def call_function(self, fn_index, processed_input):
|
||||
async def call_function(self, fn_index, processed_input, iterator=None):
|
||||
"""Calls and times function with given index and preprocessed input."""
|
||||
block_fn = self.fns[fn_index]
|
||||
|
||||
is_generating = False
|
||||
start = time.time()
|
||||
if inspect.iscoroutinefunction(block_fn.fn):
|
||||
prediction = await block_fn.fn(*processed_input)
|
||||
else:
|
||||
prediction = await anyio.to_thread.run_sync(
|
||||
block_fn.fn, *processed_input, limiter=self.limiter
|
||||
)
|
||||
|
||||
if iterator is None: # If not a generator function that has already run
|
||||
if inspect.iscoroutinefunction(block_fn.fn):
|
||||
prediction = await block_fn.fn(*processed_input)
|
||||
else:
|
||||
prediction = await anyio.to_thread.run_sync(
|
||||
block_fn.fn, *processed_input, limiter=self.limiter
|
||||
)
|
||||
|
||||
if inspect.isasyncgenfunction(block_fn.fn):
|
||||
raise ValueError("Gradio does not support async generators.")
|
||||
if inspect.isgeneratorfunction(block_fn.fn):
|
||||
if not self.enable_queue:
|
||||
raise ValueError("Need to enable queue to use generators.")
|
||||
try:
|
||||
if iterator is None:
|
||||
iterator = prediction
|
||||
prediction = next(iterator)
|
||||
is_generating = True
|
||||
except StopIteration:
|
||||
n_outputs = len(self.dependencies[fn_index].get("outputs"))
|
||||
prediction = (
|
||||
components._Keywords.FINISHED_ITERATING
|
||||
if n_outputs == 1
|
||||
else (components._Keywords.FINISHED_ITERATING,) * n_outputs
|
||||
)
|
||||
iterator = None
|
||||
|
||||
duration = time.time() - start
|
||||
return prediction, duration
|
||||
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"duration": duration,
|
||||
"is_generating": is_generating,
|
||||
"iterator": iterator,
|
||||
}
|
||||
|
||||
def postprocess_data(self, fn_index, predictions, state):
|
||||
block_fn = self.fns[fn_index]
|
||||
@ -654,6 +682,9 @@ class Blocks(BlockContext):
|
||||
if block_fn.postprocess:
|
||||
output = []
|
||||
for i, output_id in enumerate(dependency["outputs"]):
|
||||
if predictions[i] is components._Keywords.FINISHED_ITERATING:
|
||||
output.append(None)
|
||||
break
|
||||
block = self.blocks[output_id]
|
||||
if getattr(block, "stateful", False):
|
||||
if not is_update(predictions[i]):
|
||||
@ -697,7 +728,8 @@ class Blocks(BlockContext):
|
||||
fn_index: int,
|
||||
inputs: List[Any],
|
||||
username: str = None,
|
||||
state: Optional[Dict[int, any]] = None,
|
||||
state: Optional[Dict[int, Any]] = None,
|
||||
iterators: Dict[int, Any] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Processes API calls from the frontend. First preprocesses the data,
|
||||
@ -711,16 +743,19 @@ class Blocks(BlockContext):
|
||||
block_fn = self.fns[fn_index]
|
||||
|
||||
inputs = self.preprocess_data(fn_index, inputs, state)
|
||||
iterator = iterators.get(fn_index, None)
|
||||
|
||||
predictions, duration = await self.call_function(fn_index, inputs)
|
||||
block_fn.total_runtime += duration
|
||||
result = await self.call_function(fn_index, inputs, iterator)
|
||||
block_fn.total_runtime += result["duration"]
|
||||
block_fn.total_runs += 1
|
||||
|
||||
predictions = self.postprocess_data(fn_index, predictions, state)
|
||||
predictions = self.postprocess_data(fn_index, result["prediction"], state)
|
||||
|
||||
return {
|
||||
"data": predictions,
|
||||
"duration": duration,
|
||||
"is_generating": result["is_generating"],
|
||||
"iterator": result["iterator"],
|
||||
"duration": result["duration"],
|
||||
"average_duration": block_fn.total_runtime / block_fn.total_runs,
|
||||
}
|
||||
|
||||
@ -967,6 +1002,7 @@ class Blocks(BlockContext):
|
||||
"The `enable_queue` parameter has been deprecated. Please use the `.queue()` method instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if self.is_space:
|
||||
self.enable_queue = self.enable_queue is not False
|
||||
else:
|
||||
|
@ -62,6 +62,7 @@ set_documentation_group("component")
|
||||
|
||||
class _Keywords(Enum):
|
||||
NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()`
|
||||
FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state)
|
||||
|
||||
|
||||
class Component(Block):
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import fastapi
|
||||
from pydantic import BaseModel
|
||||
@ -210,6 +210,14 @@ class Queue:
|
||||
queue_eta=self.queue_duration,
|
||||
)
|
||||
|
||||
async def call_prediction(self, event: Event):
|
||||
response = await Request(
|
||||
method=Request.Method.POST,
|
||||
url=f"{self.server_path}api/predict",
|
||||
json=event.data,
|
||||
)
|
||||
return response
|
||||
|
||||
async def process_event(self, event: Event) -> None:
|
||||
client_awake = await self.gather_event_data(event)
|
||||
if not client_awake:
|
||||
@ -218,27 +226,44 @@ class Queue:
|
||||
if not client_awake:
|
||||
return
|
||||
begin_time = time.time()
|
||||
response = await Request(
|
||||
method=Request.Method.POST,
|
||||
url=f"{self.server_path}api/predict",
|
||||
json=event.data,
|
||||
)
|
||||
response = await self.call_prediction(event)
|
||||
if response.json.get("is_generating", False):
|
||||
while response.json.get("is_generating", False):
|
||||
old_response = response
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_generating",
|
||||
"output": old_response.json,
|
||||
"success": old_response.status == 200,
|
||||
},
|
||||
)
|
||||
response = await self.call_prediction(event)
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_completed",
|
||||
"output": old_response.json,
|
||||
"success": old_response.status == 200,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_completed",
|
||||
"output": response.json,
|
||||
"success": response.status == 200,
|
||||
},
|
||||
)
|
||||
end_time = time.time()
|
||||
success = response.status == 200
|
||||
if success:
|
||||
if response.status == 200:
|
||||
self.update_estimation(end_time - begin_time)
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_completed",
|
||||
"output": response.json,
|
||||
"success": success,
|
||||
},
|
||||
)
|
||||
|
||||
await event.disconnect()
|
||||
await self.clean_event(event)
|
||||
|
||||
async def send_message(self, event, data: json) -> bool:
|
||||
async def send_message(self, event, data: Dict) -> bool:
|
||||
try:
|
||||
await event.websocket.send_json(data=data)
|
||||
return True
|
||||
@ -246,7 +271,7 @@ class Queue:
|
||||
await self.clean_event(event)
|
||||
return False
|
||||
|
||||
async def get_message(self, event) -> Optional[json]:
|
||||
async def get_message(self, event) -> Optional[Dict]:
|
||||
try:
|
||||
data = await event.websocket.receive_json()
|
||||
return data
|
||||
|
@ -20,7 +20,7 @@ async def run_interpret(interface, raw_input):
|
||||
for i, input_component in enumerate(interface.input_components)
|
||||
]
|
||||
original_output = await interface.call_function(0, processed_input)
|
||||
original_output = original_output[0]
|
||||
original_output = original_output["prediction"]
|
||||
|
||||
if len(interface.output_components) == 1:
|
||||
original_output = [original_output]
|
||||
@ -47,7 +47,7 @@ async def run_interpret(interface, raw_input):
|
||||
neighbor_output = await interface.call_function(
|
||||
0, processed_neighbor_input
|
||||
)
|
||||
neighbor_output = neighbor_output[0]
|
||||
neighbor_output = neighbor_output["prediction"]
|
||||
if len(interface.output_components) == 1:
|
||||
neighbor_output = [neighbor_output]
|
||||
processed_neighbor_output = [
|
||||
@ -91,7 +91,7 @@ async def run_interpret(interface, raw_input):
|
||||
neighbor_output = await interface.call_function(
|
||||
0, processed_neighbor_input
|
||||
)
|
||||
neighbor_output = neighbor_output[0]
|
||||
neighbor_output = neighbor_output["prediction"]
|
||||
if len(interface.output_components) == 1:
|
||||
neighbor_output = [neighbor_output]
|
||||
processed_neighbor_output = [
|
||||
@ -144,7 +144,7 @@ async def run_interpret(interface, raw_input):
|
||||
new_output = utils.synchronize_async(
|
||||
interface.call_function, 0, processed_masked_input
|
||||
)
|
||||
new_output = new_output[0]
|
||||
new_output = new_output["prediction"]
|
||||
if len(interface.output_components) == 1:
|
||||
new_output = [new_output]
|
||||
pred = get_regression_or_classification_value(
|
||||
|
@ -40,7 +40,7 @@ class Parallel(gradio.Interface):
|
||||
return_values_with_durations = await asyncio.gather(
|
||||
*[interface.call_function(0, args) for interface in interfaces]
|
||||
)
|
||||
return_values = [rv[0] for rv in return_values_with_durations]
|
||||
return_values = [rv["prediction"] for rv in return_values_with_durations]
|
||||
combined_list = []
|
||||
for interface, return_value in zip(interfaces, return_values):
|
||||
if len(interface.output_components) == 1:
|
||||
@ -91,7 +91,7 @@ class Series(gradio.Interface):
|
||||
]
|
||||
|
||||
# run all of predictions sequentially
|
||||
data = (await interface.call_function(0, data))[0]
|
||||
data = (await interface.call_function(0, data))["prediction"]
|
||||
if len(interface.output_components) == 1:
|
||||
data = [data]
|
||||
|
||||
|
@ -10,6 +10,7 @@ import os
|
||||
import posixpath
|
||||
import secrets
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Type
|
||||
@ -87,6 +88,9 @@ class App(FastAPI):
|
||||
self.tokens = None
|
||||
self.auth = None
|
||||
self.blocks: Optional[gradio.Blocks] = None
|
||||
self.state_holder = {}
|
||||
self.iterators = defaultdict(dict)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def configure_app(self, blocks: gradio.Blocks) -> None:
|
||||
@ -115,7 +119,6 @@ class App(FastAPI):
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.state_holder = {}
|
||||
|
||||
@app.get("/user")
|
||||
@app.get("/user/")
|
||||
@ -255,14 +258,19 @@ class App(FastAPI):
|
||||
if getattr(block, "stateful", False)
|
||||
}
|
||||
session_state = app.state_holder[body.session_hash]
|
||||
iterators = app.iterators[body.session_hash]
|
||||
else:
|
||||
session_state = {}
|
||||
iterator = {}
|
||||
raw_input = body.data
|
||||
fn_index = body.fn_index
|
||||
try:
|
||||
output = await app.blocks.process_api(
|
||||
fn_index, raw_input, username, session_state
|
||||
fn_index, raw_input, username, session_state, iterators
|
||||
)
|
||||
iterator = output.pop("iterator", None)
|
||||
if hasattr(body, "session_hash"):
|
||||
app.iterators[body.session_hash][fn_index] = iterator
|
||||
if isinstance(output, Error):
|
||||
raise output
|
||||
except BaseException as error:
|
||||
|
@ -1 +1 @@
|
||||
3.2
|
||||
3.2.1b2
|
||||
|
@ -160,4 +160,24 @@ with gr.Blocks() as demo2:
|
||||
demo2.launch()
|
||||
```
|
||||
|
||||
Docs: Examples
|
||||
## Iterative Outputs
|
||||
|
||||
In some cases, you may want to show a sequence of outputs rather than a single output. For example, you might have an image generation model and you want to show the image that is generated at each step, leading up to the final image.
|
||||
|
||||
In such cases, you can supply a **generator** function into Gradio instead of a regular function. Creating generators in Python is very simple: instead of a single `return` value, a function should `yield` a series of values instead. Usually the `yield` statement is put in some kind of loop. Here's an example of an generator that simply counts up to a given number:
|
||||
|
||||
```python
|
||||
def my_generator(x):
|
||||
for i in range(x):
|
||||
yield i
|
||||
```
|
||||
|
||||
You supply a generator into Gradio the same way as you would a regular function. For example, here's a a (fake) image generation model that generates noise for several steps before outputting an image:
|
||||
|
||||
$code_fake_diffusion
|
||||
$demo_fake_diffusion
|
||||
|
||||
Note that we've added a `time.sleep(1)` in the iterator to create an artificial pause between steps so that you are able to observe the steps of the iterator (in a real image generation model, this probably wouldn't be necessary).
|
||||
|
||||
Supplying a generator into Gradio **requires** you to enable queuing in the underlying Interface or Blocks (see the queuing section above).
|
||||
|
||||
|
@ -16,6 +16,7 @@ def copy_all_demos(source_dir: str, dest_dir: str):
|
||||
"blocks_update",
|
||||
"calculator",
|
||||
"fake_gan",
|
||||
"fake_diffusion_with_gif",
|
||||
"gender_sentence_default_interpretation",
|
||||
"image_mod_default_image",
|
||||
"interface_parallel_load",
|
||||
|
@ -246,7 +246,6 @@ def test_io_components_attach_load_events_when_value_is_fn(io_components):
|
||||
|
||||
|
||||
def test_blocks_do_not_filter_none_values_from_updates(io_components):
|
||||
|
||||
io_components = [c() for c in io_components if c not in [gr.State, gr.Button]]
|
||||
with gr.Blocks() as demo:
|
||||
for component in io_components:
|
||||
@ -267,7 +266,6 @@ def test_blocks_do_not_filter_none_values_from_updates(io_components):
|
||||
|
||||
|
||||
def test_blocks_does_not_replace_keyword_literal():
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
text = gr.Textbox()
|
||||
btn = gr.Button(value="Reset")
|
||||
@ -281,5 +279,90 @@ def test_blocks_does_not_replace_keyword_literal():
|
||||
assert output[0]["value"] == "NO_VALUE"
|
||||
|
||||
|
||||
class TestCallFunction:
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_regular_function(self):
|
||||
with gr.Blocks() as demo:
|
||||
text = gr.Textbox()
|
||||
btn = gr.Button()
|
||||
btn.click(
|
||||
lambda x: "Hello, " + x,
|
||||
inputs=text,
|
||||
outputs=text,
|
||||
)
|
||||
|
||||
output = await demo.call_function(0, ["World"])
|
||||
assert output["prediction"] == "Hello, World"
|
||||
output = await demo.call_function(0, ["Abubakar"])
|
||||
assert output["prediction"] == "Hello, Abubakar"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_generator(self):
|
||||
def generator(x):
|
||||
for i in range(x):
|
||||
yield i
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
inp = gr.Number()
|
||||
out = gr.Number()
|
||||
btn = gr.Button()
|
||||
btn.click(
|
||||
generator,
|
||||
inputs=inp,
|
||||
outputs=out,
|
||||
)
|
||||
|
||||
demo.queue()
|
||||
|
||||
output = await demo.call_function(0, [3])
|
||||
assert output["prediction"] == 0
|
||||
output = await demo.call_function(0, [3], iterator=output["iterator"])
|
||||
assert output["prediction"] == 1
|
||||
output = await demo.call_function(0, [3], iterator=output["iterator"])
|
||||
assert output["prediction"] == 2
|
||||
output = await demo.call_function(0, [3], iterator=output["iterator"])
|
||||
assert output["prediction"] == gr.components._Keywords.FINISHED_ITERATING
|
||||
assert output["iterator"] is None
|
||||
output = await demo.call_function(0, [3], iterator=output["iterator"])
|
||||
assert output["prediction"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_both_generator_and_function(self):
|
||||
def generator(x):
|
||||
for i in range(x):
|
||||
yield i, x
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
inp = gr.Number()
|
||||
out1 = gr.Number()
|
||||
out2 = gr.Number()
|
||||
btn = gr.Button()
|
||||
inp.change(lambda x: x + x, inp, out1)
|
||||
btn.click(
|
||||
generator,
|
||||
inputs=inp,
|
||||
outputs=[out1, out2],
|
||||
)
|
||||
|
||||
demo.queue()
|
||||
|
||||
output = await demo.call_function(0, [2])
|
||||
assert output["prediction"] == 4
|
||||
output = await demo.call_function(0, [-1])
|
||||
assert output["prediction"] == -2
|
||||
|
||||
output = await demo.call_function(1, [3])
|
||||
assert output["prediction"] == (0, 3)
|
||||
output = await demo.call_function(1, [3], iterator=output["iterator"])
|
||||
assert output["prediction"] == (1, 3)
|
||||
output = await demo.call_function(1, [3], iterator=output["iterator"])
|
||||
assert output["prediction"] == (2, 3)
|
||||
output = await demo.call_function(1, [3], iterator=output["iterator"])
|
||||
assert output["prediction"] == (gr.components._Keywords.FINISHED_ITERATING,) * 2
|
||||
assert output["iterator"] is None
|
||||
output = await demo.call_function(1, [3], iterator=output["iterator"])
|
||||
assert output["prediction"] == (0, 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1753,9 +1753,9 @@ class TestState:
|
||||
|
||||
io = gr.Interface(test, ["text", "state"], ["text", "state"])
|
||||
result = await io.call_function(0, ["abc"])
|
||||
assert result[0][0] == "abc def"
|
||||
result = await io.call_function(0, ["abc", result[0][0]])
|
||||
assert result[0][0] == "abcabc def"
|
||||
assert result["prediction"][0] == "abc def"
|
||||
result = await io.call_function(0, ["abc", result["prediction"][0]])
|
||||
assert result["prediction"][0] == "abcabc def"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_in_blocks(self):
|
||||
@ -1765,9 +1765,9 @@ class TestState:
|
||||
btn.click(lambda x: x + 1, score, score)
|
||||
|
||||
result = await demo.call_function(0, [0])
|
||||
assert result[0] == 1
|
||||
result = await demo.call_function(0, [result[0]])
|
||||
assert result[0] == 2
|
||||
assert result["prediction"] == 1
|
||||
result = await demo.call_function(0, [result["prediction"]])
|
||||
assert result["prediction"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_for_backwards_compatibility(self):
|
||||
@ -1777,9 +1777,9 @@ class TestState:
|
||||
btn.click(lambda x: x + 1, score, score)
|
||||
|
||||
result = await demo.call_function(0, [0])
|
||||
assert result[0] == 1
|
||||
result = await demo.call_function(0, [result[0]])
|
||||
assert result[0] == 2
|
||||
assert result["prediction"] == 1
|
||||
result = await demo.call_function(0, [result["prediction"]])
|
||||
assert result["prediction"] == 2
|
||||
|
||||
|
||||
def test_dataframe_as_example_converts_dataframes():
|
||||
|
@ -10,42 +10,4 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestQueue:
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_with_single_event(self):
|
||||
async def wait(data):
|
||||
await asyncio.sleep(0.1)
|
||||
return data
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
text = gr.Textbox()
|
||||
button = gr.Button()
|
||||
button.click(wait, [text], [text])
|
||||
app, local_url, _ = demo.launch(prevent_thread_lock=True, enable_queue=True)
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/queue/join") as _: # websocket
|
||||
"""#Unable to make this part work, seems like there is an issue with thread acquire and exiting the scope
|
||||
websocket.send_json({"hash": "0001"})
|
||||
assert {
|
||||
"avg_event_concurrent_process_time": 1.0,
|
||||
"avg_event_process_time": 1.0,
|
||||
"msg": "estimation",
|
||||
"queue_eta": 1,
|
||||
"queue_size": 0,
|
||||
"rank": -1,
|
||||
"rank_eta": -1,
|
||||
} == websocket.receive_json()
|
||||
|
||||
while True:
|
||||
message = websocket.receive_json()
|
||||
if "estimation" == message["msg"]:
|
||||
continue
|
||||
elif "send_data" == message["msg"]:
|
||||
websocket.send_json({"data": [1], "fn": 0})
|
||||
elif "process_starts" == message["msg"]:
|
||||
continue
|
||||
elif "process_completed" == message["msg"]:
|
||||
assert message["output"]["data"] == ["1"]
|
||||
break
|
||||
"""
|
||||
|
||||
demo.close()
|
||||
pass # TODO
|
||||
|
@ -142,6 +142,52 @@ class TestRoutes(unittest.TestCase):
|
||||
close_all()
|
||||
|
||||
|
||||
class TestGeneratorRoutes:
|
||||
def test_generator(self):
|
||||
def generator(string):
|
||||
for char in string:
|
||||
yield char
|
||||
|
||||
io = Interface(generator, "text", "text")
|
||||
app, _, _ = io.queue().launch(prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"] == ["a"]
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"] == ["b"]
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"] == ["c"]
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"] == [None]
|
||||
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["abc"], "fn_index": 0, "session_hash": "11"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
assert output["data"] == ["a"]
|
||||
|
||||
|
||||
class TestApp:
|
||||
def test_create_app(self):
|
||||
app = routes.App.create_app(Interface(lambda x: x, "text", "text"))
|
||||
@ -152,7 +198,9 @@ class TestAuthenticatedRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(
|
||||
auth=("test", "correct_password"), prevent_thread_lock=True
|
||||
auth=("test", "correct_password"),
|
||||
prevent_thread_lock=True,
|
||||
enable_queue=False,
|
||||
)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
|
@ -26,6 +26,7 @@ type StatusResponse =
|
||||
interface Payload {
|
||||
data: Array<unknown>;
|
||||
fn_index: number;
|
||||
session_hash?: string;
|
||||
}
|
||||
|
||||
declare let BUILD_MODE: string;
|
||||
@ -89,6 +90,7 @@ export const fn =
|
||||
}): Promise<unknown> => {
|
||||
const fn_index = payload.fn_index;
|
||||
|
||||
payload.session_hash = session_hash;
|
||||
if (frontend_fn !== undefined) {
|
||||
payload.data = await frontend_fn(payload.data.concat(output_data));
|
||||
}
|
||||
@ -172,6 +174,20 @@ export const fn =
|
||||
null
|
||||
);
|
||||
break;
|
||||
case "process_generating":
|
||||
loading_status.update(
|
||||
fn_index,
|
||||
data.success ? "generating" : "error",
|
||||
queue,
|
||||
null,
|
||||
null,
|
||||
data.output.average_duration,
|
||||
!data.success ? data.output.error : null
|
||||
);
|
||||
if (data.success) {
|
||||
queue_callback(data.output);
|
||||
}
|
||||
break;
|
||||
case "process_completed":
|
||||
loading_status.update(
|
||||
fn_index,
|
||||
|
@ -52,7 +52,7 @@
|
||||
export let queue: boolean = false;
|
||||
export let queue_position: number | null;
|
||||
export let queue_size: number | null;
|
||||
export let status: "complete" | "pending" | "error";
|
||||
export let status: "complete" | "pending" | "error" | "generating";
|
||||
export let scroll_to_output: boolean = false;
|
||||
export let timer: boolean = true;
|
||||
export let visible: boolean = true;
|
||||
@ -141,6 +141,8 @@
|
||||
<div
|
||||
class="wrap"
|
||||
class:opacity-0={!status || status === "complete"}
|
||||
class:cover-bg={status === "pending" || status === "error"}
|
||||
class:generating={status === "generating"}
|
||||
class:!hidden={!visible}
|
||||
bind:this={el}
|
||||
>
|
||||
@ -187,13 +189,21 @@
|
||||
|
||||
<style lang="postcss">
|
||||
.wrap {
|
||||
@apply absolute inset-0 z-50 flex flex-col justify-center items-center bg-white dark:bg-gray-800 pointer-events-none transition-opacity max-h-screen;
|
||||
@apply absolute inset-0 z-50 flex flex-col justify-center items-center dark:bg-gray-800 pointer-events-none transition-opacity max-h-screen;
|
||||
}
|
||||
|
||||
:global(.dark) .wrap {
|
||||
:global(.dark) .cover-bg {
|
||||
@apply bg-gray-800;
|
||||
}
|
||||
|
||||
.cover-bg {
|
||||
@apply bg-white;
|
||||
}
|
||||
|
||||
.generating {
|
||||
@apply border-2 border-orange-500 animate-pulse;
|
||||
}
|
||||
|
||||
.progress-bar {
|
||||
@apply absolute inset-0 origin-left bg-slate-100 dark:bg-gray-700 top-0 left-0 z-10 opacity-80;
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ import { writable } from "svelte/store";
|
||||
|
||||
export interface LoadingStatus {
|
||||
eta: number | null;
|
||||
status: "pending" | "error" | "complete";
|
||||
status: "pending" | "error" | "complete" | "generating";
|
||||
queue: boolean;
|
||||
queue_position: number | null;
|
||||
queue_size: number | null;
|
||||
|
Loading…
x
Reference in New Issue
Block a user