* batch

* formatting

* added parameter

* batch

* added docstrings

* correct arguments

* docstring

* adapt process_api for batch

* backend

* __call__

* more regular tests

* formatting

* cleaning up blocks.py

* __call__ works

* api route works

* first attempt at queue

* fixing tests

* fix some tests

* formatting

* removed print

* merge

* queue works!

* removed batch timeout

* removed batch timeout

* updated documentation

* fixing tests

* fixing tests

* fixing queue tests

* fixing queue tests

* formatting

* fix blocks config

* fix tests

* update documentation

* updated tests

* blocks

* blocks

* blocks

* tests

* test fixes

* more tests

* faster

* formatting

* test fixes

* dataclasses

* fix

* revert to fix test

* fix

* fix test

* formatting

* fix tests

* refactoring examples

* formatting

* changelog

* fix examples

* formatting

* fix tests

* formatting

* catch error

* formatting

* fix tests

* fix cancel with batch

* final tests and docs

* test routes

* formatting
Abubakar Abid 2022-10-24 16:32:37 -07:00 committed by GitHub
parent 834d945b1a
commit 0e168c4dff
14 changed files with 1174 additions and 554 deletions


@ -1,7 +1,52 @@
# Upcoming Release
## New Features:
No changes to highlight.
### Batched Functions
Gradio now supports the ability to pass *batched* functions. Batched functions are just
functions which take in a list of inputs and return a list of predictions.
For example, here is a batched function that takes in two lists of inputs (a list of
words and a list of ints), and returns a list of trimmed words as output:
```py
import time

def trim_words(words, lens):
    trimmed_words = []
    time.sleep(5)
    for w, l in zip(words, lens):
        trimmed_words.append(w[:l])
    return [trimmed_words]
```
The advantage of using batched functions is that if you enable queuing, the Gradio
server can automatically *batch* incoming requests and process them in parallel,
potentially speeding up your demo. Here's what the Gradio code looks like (notice
the `batch=True` and `max_batch_size=16` -- both of these parameters can be passed
into event triggers or into the `Interface` class)
```py
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        word = gr.Textbox(label="word", value="abc")
        leng = gr.Number(label="leng", precision=0, value=1)
        output = gr.Textbox(label="Output")
    with gr.Row():
        run = gr.Button()

    event = run.click(trim_words, [word, leng], output, batch=True, max_batch_size=16)

demo.queue()
demo.launch()
```
In the example above, 16 requests could be processed in parallel (for a total inference
time of 5 seconds), instead of each request being processed separately (for a total
inference time of 80 seconds).
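The same `batch=True` and `max_batch_size` parameters can also be passed to the `Interface` class directly. A minimal sketch reusing `trim_words` from above (the component choices here are illustrative, not taken from this PR):
```py
import gradio as gr

# Sketch: the Interface-level equivalent of the Blocks demo above.
# Reuses trim_words; the queue must be enabled for batching to work.
demo = gr.Interface(
    trim_words,
    [gr.Textbox(label="word"), gr.Number(label="leng", precision=0)],
    gr.Textbox(label="Output"),
    batch=True,
    max_batch_size=16,
)
demo.queue()
demo.launch()
```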
## Bug Fixes:
* Fixes issue where plotly animations, interactivity, titles, legends, were not working properly. [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 2486](https://github.com/gradio-app/gradio/pull/2486)
@ -19,6 +64,7 @@ No changes to highlight.
* Fixes the error message if a user builds Gradio locally and tries to use `share=True` by [@abidlabs](https://github.com/abidlabs) in [PR 2502](https://github.com/gradio-app/gradio/pull/2502)
* Allows the render() function to return self by [@Raul9595](https://github.com/Raul9595) in [PR 2514](https://github.com/gradio-app/gradio/pull/2514)
* Fixes issue where plotly animations, interactivity, titles, legends, were not working properly. [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 2486](https://github.com/gradio-app/gradio/pull/2486)
* Gradio now supports batched functions by [@abidlabs](https://github.com/abidlabs) in [PR 2218](https://github.com/gradio-app/gradio/pull/2218)
## Contributors Shoutout:
No changes to highlight.


@ -11,7 +11,17 @@ import time
import warnings
import webbrowser
from types import ModuleType
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
)
import anyio
import requests
@ -109,17 +119,19 @@ class Block:
def set_event_trigger(
self,
event_name: str,
fn: Optional[Callable],
inputs: Optional[Component | List[Component]],
outputs: Optional[Component | List[Component]],
fn: Callable | None,
inputs: Component | List[Component] | None,
outputs: Component | List[Component] | None,
preprocess: bool = True,
postprocess: bool = True,
scroll_to_output: bool = False,
show_progress: bool = True,
api_name: Optional[AnyStr] = None,
js: Optional[str] = None,
api_name: AnyStr | None = None,
js: str | None = None,
no_target: bool = False,
queue: Optional[bool] = None,
queue: bool | None = None,
batch: bool = False,
max_batch_size: int = 4,
cancels: List[int] | None = None,
) -> Dict[str, Any]:
"""
@ -136,6 +148,9 @@ class Block:
api_name: Defining this parameter exposes the endpoint in the api docs
js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
no_target: if True, sets "targets" to [], used for Blocks "load" event
batch: whether this function takes in a batch of inputs
max_batch_size: the maximum batch size to send to the function
cancels: a list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
Returns: None
"""
# Support for singular parameter
@ -168,7 +183,9 @@ class Block:
"api_name": api_name,
"scroll_to_output": scroll_to_output,
"show_progress": show_progress,
"cancels": cancels if cancels else [],
"batch": batch,
"max_batch_size": max_batch_size,
"cancels": cancels or [],
}
if api_name is not None:
dependency["documentation"] = [
@ -570,52 +587,6 @@ class Blocks(BlockContext):
return blocks
def __call__(self, *params, fn_index=0):
"""
Allows Blocks objects to be called as functions
Parameters:
*params: the parameters to pass to the function
fn_index: the index of the function to call (defaults to 0, which for Interfaces, is the default prediction function)
"""
dependency = self.dependencies[fn_index]
block_fn = self.fns[fn_index]
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
block = self.blocks[input_id]
if getattr(block, "stateful", False):
raise ValueError(
"Cannot call Blocks object as a function if any of"
" the inputs are stateful."
)
else:
serialized_input = block.serialize(params[i])
processed_input.append(serialized_input)
processed_input = self.preprocess_data(fn_index, processed_input, None)
if inspect.iscoroutinefunction(block_fn.fn):
predictions = utils.synchronize_async(block_fn.fn, *processed_input)
else:
predictions = block_fn.fn(*processed_input)
predictions = self.postprocess_data(fn_index, predictions, None)
output_copy = copy.deepcopy(predictions)
predictions = []
for o, output_id in enumerate(dependency["outputs"]):
block = self.blocks[output_id]
if getattr(block, "stateful", False):
raise ValueError(
"Cannot call Blocks object as a function if any of"
" the outputs are stateful."
)
else:
deserialized = block.deserialize(output_copy[o])
predictions.append(deserialized)
return utils.resolve_singleton(predictions)
def __str__(self):
return self.__repr__()
@ -668,26 +639,67 @@ class Blocks(BlockContext):
if Context.block is not None:
Context.block.children.extend(self.children)
return self
def preprocess_data(self, fn_index, raw_input, state):
def is_callable(self, fn_index: int = 0) -> bool:
"""Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
block_fn = self.fns[fn_index]
if block_fn.preprocess:
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
block = self.blocks[input_id]
if getattr(block, "stateful", False):
processed_input.append(state.get(input_id))
else:
processed_input.append(block.preprocess(raw_input[i]))
else:
processed_input = raw_input
return processed_input
if inspect.isasyncgenfunction(block_fn.fn):
return False
if inspect.isgeneratorfunction(block_fn.fn):
return False
for input_id in dependency["inputs"]:
block = self.blocks[input_id]
if getattr(block, "stateful", False):
return False
for output_id in dependency["outputs"]:
block = self.blocks[output_id]
if getattr(block, "stateful", False):
return False
async def call_function(self, fn_index, processed_input, iterator=None):
return True
def __call__(self, *inputs, fn_index: int = 0):
"""
Allows Blocks objects to be called as functions. Supply the parameters to the
function as positional arguments. To choose which function to call, use the
fn_index parameter, which must be a keyword argument.
Parameters:
*inputs: the parameters to pass to the function
fn_index: the index of the function to call (defaults to 0, which for Interfaces, is the default prediction function)
"""
if not (self.is_callable(fn_index)):
raise ValueError(
"This function is not callable because it is either stateful or is a generator. Please use the .launch() method instead to create an interactive user interface."
)
inputs = list(inputs)
processed_inputs = self.serialize_data(fn_index, inputs)
batch = self.dependencies[fn_index]["batch"]
if batch:
processed_inputs = [[inp] for inp in processed_inputs]
outputs = utils.synchronize_async(self.process_api, fn_index, processed_inputs)
outputs = outputs["data"]
if batch:
outputs = [out[0] for out in outputs]
processed_outputs = self.deserialize_data(fn_index, outputs)
processed_outputs = utils.resolve_singleton(processed_outputs)
return processed_outputs
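As a usage note for the reworked `__call__`, a hedged sketch; it assumes the `demo` Blocks object from the CHANGELOG example above, before `launch()` is called:
```py
# Hedged sketch: calling the Blocks app as a plain function (fn_index 0 is the click event).
# With batch=True, __call__ wraps the single sample into a batch of one and unwraps the result.
result = demo("hello world", 5, fn_index=0)
print(result)  # expected: "hello"
```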
async def call_function(
self,
fn_index: int,
processed_input: List[Any],
iterator: Iterator[Any] | None = None,
):
"""Calls and times function with given index and preprocessed input."""
block_fn = self.fns[fn_index]
is_generating = False
@ -731,16 +743,57 @@ class Blocks(BlockContext):
"iterator": iterator,
}
def postprocess_data(self, fn_index, predictions, state):
def serialize_data(self, fn_index: int, inputs: List[Any]) -> List[Any]:
dependency = self.dependencies[fn_index]
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
block: IOComponent = self.blocks[input_id]
serialized_input = block.serialize(inputs[i])
processed_input.append(serialized_input)
return processed_input
def deserialize_data(self, fn_index: int, outputs: List[Any]) -> List[Any]:
dependency = self.dependencies[fn_index]
predictions = []
for o, output_id in enumerate(dependency["outputs"]):
block: IOComponent = self.blocks[output_id]
deserialized = block.deserialize(outputs[o])
predictions.append(deserialized)
return predictions
def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
if block_fn.preprocess:
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
block: IOComponent = self.blocks[input_id]
if getattr(block, "stateful", False):
processed_input.append(state.get(input_id))
else:
processed_input.append(block.preprocess(inputs[i]))
else:
processed_input = inputs
return processed_input
def postprocess_data(
self, fn_index: int, predictions: List[Any], state: Dict[int, Any]
):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
batch = dependency["batch"]
if type(predictions) is dict and len(predictions) > 0:
predictions = convert_component_dict_to_list(
dependency["outputs"], predictions
)
if len(dependency["outputs"]) == 1:
if len(dependency["outputs"]) == 1 and not (batch):
predictions = (predictions,)
output = []
@ -771,34 +824,61 @@ class Blocks(BlockContext):
fn_index: int,
inputs: List[Any],
username: str = None,
state: Dict[int, Any] | None = None,
state: Dict[int, Any] | List[Dict[int, Any]] | None = None,
iterators: Dict[int, Any] | None = None,
) -> Dict[str, Any]:
"""
Processes API calls from the frontend. First preprocesses the data,
then runs the relevant function, then postprocesses the output.
Parameters:
data: data received from the frontend
inputs: the list of raw inputs to pass to the function
fn_index: Index of function to run.
inputs: input data received from the frontend
username: name of user if authentication is set up (not used)
state: data stored from stateful components for session (key is input block id)
iterators: the in-progress iterators for each generator function (key is function index)
Returns: None
"""
block_fn = self.fns[fn_index]
batch = self.dependencies[fn_index]["batch"]
inputs = self.preprocess_data(fn_index, inputs, state)
iterator = iterators.get(fn_index, None) if iterators else None
if batch:
max_batch_size = self.dependencies[fn_index]["max_batch_size"]
batch_sizes = [len(inp) for inp in inputs]
batch_size = batch_sizes[0]
if inspect.isasyncgenfunction(block_fn.fn) or inspect.isgeneratorfunction(
block_fn.fn
):
raise ValueError("Gradio does not support generators in batch mode.")
if not all(x == batch_size for x in batch_sizes):
raise ValueError(
f"All inputs to a batch function must have the same length but instead have sizes: {batch_sizes}."
)
if batch_size > max_batch_size:
raise ValueError(
f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
)
inputs = [self.preprocess_data(fn_index, i, state) for i in zip(*inputs)]
result = await self.call_function(fn_index, zip(*inputs), None)
preds = result["prediction"]
data = [self.postprocess_data(fn_index, o, state) for o in zip(*preds)]
data = list(zip(*data))
is_generating, iterator = None, None
else:
inputs = self.preprocess_data(fn_index, inputs, state)
iterator = iterators.get(fn_index, None) if iterators else None
result = await self.call_function(fn_index, inputs, iterator)
data = self.postprocess_data(fn_index, result["prediction"], state)
is_generating, iterator = result["is_generating"], result["iterator"]
result = await self.call_function(fn_index, inputs, iterator)
block_fn.total_runtime += result["duration"]
block_fn.total_runs += 1
predictions = self.postprocess_data(fn_index, result["prediction"], state)
return {
"data": predictions,
"is_generating": result["is_generating"],
"iterator": result["iterator"],
"data": data,
"is_generating": is_generating,
"iterator": iterator,
"duration": result["duration"],
"average_duration": block_fn.total_runtime / block_fn.total_runs,
}
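To make the batch branch of `process_api` concrete, here is a small standalone sketch (plain Python, not part of this diff) of how `zip(*)` moves between component-wise lists and per-sample tuples:
```py
# Standalone sketch of the zip(*) transposition used in the batch branch of process_api.
inputs = [["cat", "dogs"], [2, 3]]       # raw request: one list per input component
per_sample = list(zip(*inputs))          # [("cat", 2), ("dogs", 3)] -- preprocessing runs per sample
per_component = list(zip(*per_sample))   # [("cat", "dogs"), (2, 3)] -- what the batched fn receives
preds = [[w[:n] for w, n in zip(*per_component)]]  # fn returns one list per output component
rows = list(zip(*preds))                 # [("ca",), ("dog",)] -- postprocessing runs per sample
data = list(zip(*rows))                  # [("ca", "dog")] -- back to one list per output component
```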
@ -958,6 +1038,7 @@ class Blocks(BlockContext):
status_update_rate: If "auto", Queue will send status estimations to all clients whenever a job is finished. Otherwise Queue will send status at regular intervals set by this parameter as the number of seconds.
client_position_to_load_data: Once a client's position in Queue is less that this value, the Queue will collect the input data from the client. You may make this smaller if clients can send large volumes of data, such as video, since the queued data is stored in memory.
default_enabled: If True, all event listeners will use queueing by default.
max_size: Maximum number of jobs that can be queued at once. Jobs beyond this limit simply return an error message to the user asking them to try again. If None, there is no limit.
Example:
demo = gr.Interface(gr.Textbox(), gr.Image(), image_generator)
demo.queue(concurrency_count=3)
@ -970,6 +1051,7 @@ class Blocks(BlockContext):
data_gathering_start=client_position_to_load_data,
update_intervals=status_update_rate if status_update_rate != "auto" else 1,
max_size=max_size,
blocks_dependencies=self.dependencies,
)
self.config = self.get_config_file()
return self
@ -1081,6 +1163,11 @@ class Blocks(BlockContext):
"another event without enabling the queue. Both can be solved by calling .queue() "
"before .launch()"
)
if dep["batch"] and (
dep["queue"] is False
or (dep["queue"] is None and not self.enable_queue)
):
raise ValueError("In order to use batching, the queue must be enabled.")
self.config = self.get_config_file()
self.share = share

gradio/dataclasses.py (new file, 17 additions)

@ -0,0 +1,17 @@
from typing import Any, List, Optional
from pydantic import BaseModel
class PredictBody(BaseModel):
session_hash: Optional[str]
data: List[Any]
fn_index: Optional[int]
batched: Optional[
bool
] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
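For orientation, a hedged sketch of the payload shape the queue would POST to `api/predict` for a batch of two queued samples (field names come from `PredictBody` above; the concrete values are invented):
```py
# Hypothetical /api/predict payload for a batch of two queued samples (values invented).
payload = {
    "session_hash": "foo",              # the queue's Event objects use this fixed hash
    "fn_index": 0,
    "data": [["cat", "dogs"], [2, 3]],  # one list per input component, built by zipping event data
    "batched": True,                    # set by the queue so the server treats data as a batch
}
```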
class ResetBody(BaseModel):
session_hash: str
fn_index: int


@ -40,8 +40,12 @@ def get_cancel_function(
)
def set_cancel_events(block: Block, event_name: str, cancels: List[Dict[str, Any]]):
def set_cancel_events(
block: Block, event_name: str, cancels: None | Dict[str, Any] | List[Dict[str, Any]]
):
if cancels:
if not isinstance(cancels, list):
cancels = [cancels]
cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
block.set_event_trigger(
event_name,
@ -58,16 +62,18 @@ class Changeable(Block):
def change(
self,
fn: Callable,
inputs: List[Component],
outputs: List[Component],
inputs: Component | List[Component] | None,
outputs: Component | List[Component] | None,
api_name: AnyStr = None,
status_tracker: Optional[StatusTracker] = None,
scroll_to_output: bool = False,
show_progress: bool = True,
queue: Optional[bool] = None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -75,13 +81,15 @@ class Changeable(Block):
or uploads an image). This method can be used when this component is in a Gradio Blocks.
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
@ -104,6 +112,8 @@ class Changeable(Block):
preprocess=preprocess,
postprocess=postprocess,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
)
set_cancel_events(self, "change", cancels)
return dep
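Since `set_cancel_events` now also accepts a single dependency (see above), a hedged usage sketch; `start_btn`, `stop_btn`, `slow_fn`, `inp`, and `out` are assumed names, not from this diff:
```py
# Hedged sketch: cancelling an in-flight event from another button.
click_event = start_btn.click(slow_fn, inp, out)
stop_btn.click(None, None, None, cancels=click_event)  # a single dependency now works
# the list form is still accepted: cancels=[click_event]
```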
@ -113,16 +123,18 @@ class Clickable(Block):
def click(
self,
fn: Callable,
inputs: List[Component],
inputs: Component | List[Component] | None,
outputs: List[Component],
api_name: AnyStr = None,
status_tracker: Optional[StatusTracker] = None,
scroll_to_output: bool = False,
show_progress: bool = True,
queue=None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -130,13 +142,15 @@ class Clickable(Block):
This method can be used when this component is in a Gradio Blocks.
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
@ -156,6 +170,8 @@ class Clickable(Block):
scroll_to_output=scroll_to_output,
show_progress=show_progress,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
js=_js,
preprocess=preprocess,
postprocess=postprocess,
@ -168,16 +184,18 @@ class Submittable(Block):
def submit(
self,
fn: Callable,
inputs: List[Component],
inputs: Component | List[Component] | None,
outputs: List[Component],
api_name: AnyStr = None,
status_tracker: Optional[StatusTracker] = None,
scroll_to_output: bool = False,
show_progress: bool = True,
queue: Optional[bool] = None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -186,13 +204,15 @@ class Submittable(Block):
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
@ -215,6 +235,8 @@ class Submittable(Block):
preprocess=preprocess,
postprocess=postprocess,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
)
set_cancel_events(self, "submit", cancels)
return dep
@ -224,16 +246,18 @@ class Editable(Block):
def edit(
self,
fn: Callable,
inputs: List[Component],
inputs: Component | List[Component] | None,
outputs: List[Component],
api_name: AnyStr = None,
status_tracker: Optional[StatusTracker] = None,
scroll_to_output: bool = False,
show_progress: bool = True,
queue: Optional[bool] = None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -241,13 +265,15 @@ class Editable(Block):
built-in editor. This method can be used when this component is in a Gradio Blocks.
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
@ -270,6 +296,8 @@ class Editable(Block):
preprocess=preprocess,
postprocess=postprocess,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
)
set_cancel_events(self, "edit", cancels)
return dep
@ -279,16 +307,18 @@ class Clearable(Block):
def clear(
self,
fn: Callable,
inputs: List[Component],
inputs: Component | List[Component] | None,
outputs: List[Component],
api_name: AnyStr = None,
status_tracker: Optional[StatusTracker] = None,
scroll_to_output: bool = False,
show_progress: bool = True,
queue: Optional[bool] = None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -296,13 +326,15 @@ class Clearable(Block):
using the X button for the component. This method can be used when this component is in a Gradio Blocks.
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
@ -325,6 +357,8 @@ class Clearable(Block):
preprocess=preprocess,
postprocess=postprocess,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
)
set_cancel_events(self, "submit", cancels)
return dep
@ -334,16 +368,18 @@ class Playable(Block):
def play(
self,
fn: Callable,
inputs: List[Component],
inputs: Component | List[Component] | None,
outputs: List[Component],
api_name: AnyStr = None,
status_tracker: Optional[StatusTracker] = None,
scroll_to_output: bool = False,
show_progress: bool = True,
queue: Optional[bool] = None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -351,13 +387,15 @@ class Playable(Block):
This method can be used when this component is in a Gradio Blocks.
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
@ -380,6 +418,8 @@ class Playable(Block):
preprocess=preprocess,
postprocess=postprocess,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
)
set_cancel_events(self, "play", cancels)
return dep
@ -387,16 +427,18 @@ class Playable(Block):
def pause(
self,
fn: Callable,
inputs: List[Component],
inputs: Component | List[Component] | None,
outputs: List[Component],
api_name: Optional[AnyStr] = None,
status_tracker: Optional[StatusTracker] = None,
scroll_to_output: bool = False,
show_progress: bool = True,
queue: Optional[bool] = None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -404,13 +446,15 @@ class Playable(Block):
This method can be used when this component is in a Gradio Blocks.
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
@ -433,6 +477,8 @@ class Playable(Block):
preprocess=preprocess,
postprocess=postprocess,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
)
set_cancel_events(self, "pause", cancels)
return dep
@ -440,16 +486,18 @@ class Playable(Block):
def stop(
self,
fn: Callable,
inputs: List[Component],
inputs: Component | List[Component] | None,
outputs: List[Component],
api_name: AnyStr = None,
status_tracker: Optional[StatusTracker] = None,
scroll_to_output: bool = False,
show_progress: bool = True,
queue: Optional[bool] = None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -457,13 +505,15 @@ class Playable(Block):
This method can be used when this component is in a Gradio Blocks.
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
@ -486,6 +536,8 @@ class Playable(Block):
preprocess=preprocess,
postprocess=postprocess,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
)
set_cancel_events(self, "stop", cancels)
return dep
@ -495,16 +547,18 @@ class Streamable(Block):
def stream(
self,
fn: Callable,
inputs: List[Component],
inputs: Component | List[Component] | None,
outputs: List[Component],
api_name: AnyStr = None,
status_tracker: Optional[StatusTracker] = None,
scroll_to_output: bool = False,
show_progress: bool = False,
queue: Optional[bool] = None,
batch: bool = False,
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -512,13 +566,15 @@ class Streamable(Block):
component). This method can be used when this component is in a Gradio Blocks.
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
queue: If True, will place the request on the queue, if the queue exists
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
@ -543,6 +599,8 @@ class Streamable(Block):
preprocess=preprocess,
postprocess=postprocess,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
)
set_cancel_events(self, "stream", cancels)
return dep
@ -552,7 +610,7 @@ class Blurrable(Block):
def blur(
self,
fn: Callable,
inputs: List[Component],
inputs: Component | List[Component] | None,
outputs: List[Component],
api_name: AnyStr = None,
status_tracker: Optional[StatusTracker] = None,
@ -561,7 +619,7 @@ class Blurrable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -569,8 +627,8 @@ class Blurrable(Block):
Parameters:
fn: Callable function
inputs: List of inputs
outputs: List of outputs
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defining this parameter exposes the endpoint in the api docs
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending


@ -5,16 +5,12 @@ from __future__ import annotations
import ast
import csv
import inspect
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional
import anyio
from gradio import utils
from gradio.blocks import convert_component_dict_to_list, postprocess_update_dict
from gradio.components import Dataset
from gradio.context import Context
from gradio.documentation import document, set_documentation_group
@ -42,6 +38,7 @@ def create_examples(
run_on_click: bool = False,
preprocess: bool = True,
postprocess: bool = True,
batch: bool = False,
):
"""Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
examples_obj = Examples(
@ -57,6 +54,7 @@ def create_examples(
run_on_click=run_on_click,
preprocess=preprocess,
postprocess=postprocess,
batch=batch,
_initiated_directly=False,
)
utils.synchronize_async(examples_obj.create)
@ -89,6 +87,7 @@ class Examples:
run_on_click: bool = False,
preprocess: bool = True,
postprocess: bool = True,
batch: bool = False,
_initiated_directly: bool = True,
):
"""
@ -104,6 +103,7 @@ class Examples:
run_on_click: if cache_examples is False, clicking on an example does not run the function when an example is clicked. Set this to True to run the function when an example is clicked. Has no effect if cache_examples is True.
preprocess: if True, preprocesses the example input before running the prediction function and caching the output. Only applies if cache_examples is True.
postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if cache_examples is True.
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is True.
"""
if _initiated_directly:
warnings.warn(
@ -186,6 +186,7 @@ class Examples:
self._api_mode = _api_mode
self.preprocess = preprocess
self.postprocess = postprocess
self.batch = batch
with utils.set_directory(working_directory):
self.processed_examples = [
@ -227,8 +228,6 @@ class Examples:
async def create(self) -> None:
"""Caches the examples if self.cache_examples is True and creates the Dataset
component to hold the examples"""
if self.cache_examples:
await self.cache_interface_examples()
async def load_example(example_id):
if self.cache_examples:
@ -255,55 +254,48 @@ class Examples:
outputs=self.outputs,
)
async def cache_interface_examples(self) -> None:
"""Caches all of the examples from an interface."""
if self.cache_examples:
await self.cache()
async def cache(self) -> None:
"""
Caches all of the examples so that their predictions can be shown immediately.
"""
if os.path.exists(self.cached_file):
print(
f"Using cache from '{os.path.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
)
else:
if Context.root_block is None:
raise ValueError("Cannot cache examples if not in a Blocks context")
print(f"Caching examples at: '{os.path.abspath(self.cached_file)}'")
cache_logger = CSVLogger()
# create a fake dependency to process the examples and get the predictions
dependency = Context.root_block.set_event_trigger(
event_name="fake_event",
fn=self.fn,
inputs=self.inputs_with_examples,
outputs=self.outputs,
preprocess=self.preprocess and not self._api_mode,
postprocess=self.postprocess and not self._api_mode,
batch=self.batch,
)
fn_index = Context.root_block.dependencies.index(dependency)
cache_logger.setup(self.outputs, self.cached_folder)
for example_id, _ in enumerate(self.examples):
prediction = await self.predict_example(example_id)
cache_logger.flag(prediction)
async def predict_example(self, example_id: int) -> List[Any]:
"""Loads an example from the interface and returns its prediction.
Parameters:
example_id: The id of the example to process (zero-indexed).
"""
processed_input = self.processed_examples[example_id]
if self.preprocess and not self._api_mode:
processed_input = [
input_component.preprocess(processed_input[i])
for i, input_component in enumerate(self.inputs_with_examples)
]
if inspect.iscoroutinefunction(self.fn):
predictions = await self.fn(*processed_input)
else:
predictions = await anyio.to_thread.run_sync(self.fn, *processed_input)
output_ids = [output._id for output in self.outputs]
if type(predictions) is dict and len(predictions) > 0:
predictions = convert_component_dict_to_list(output_ids, predictions)
if len(self.outputs) == 1:
predictions = [predictions]
if not self._api_mode:
predictions_ = []
for i, output_component in enumerate(self.outputs):
output = predictions[i]
if utils.is_update(predictions[i]):
output = postprocess_update_dict(
output_component, output, self.postprocess
)
elif self.postprocess:
output = output_component.postprocess(output)
predictions_.append(output)
predictions = predictions_
return predictions
processed_input = self.processed_examples[example_id]
if self.batch:
processed_input = [[value] for value in processed_input]
prediction = await Context.root_block.process_api(
fn_index, processed_input
)
output = prediction["data"]
if self.batch:
output = [value[0] for value in output]
cache_logger.flag(output)
async def load_from_cache(self, example_id: int) -> List[Any]:
"""Loads a particular cached example for the interface.


@ -151,12 +151,14 @@ class Interface(Blocks):
flagging_dir: str = "flagged",
flagging_callback: FlaggingCallback = CSVLogger(),
analytics_enabled: Optional[bool] = None,
batch: bool = False,
max_batch_size: int = 4,
_api_mode: bool = False,
**kwargs,
):
"""
Parameters:
fn: the function to wrap an interface around. Often a machine learning model's prediction function.
fn: the function to wrap an interface around. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn. If set to None, then only the output components will be displayed.
outputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn. If set to None, then only the input components will be displayed.
examples: sample inputs for the function; if provided, appear below the UI components and can be clicked to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
@ -176,6 +178,8 @@ class Interface(Blocks):
flagging_dir: what to name the directory where flagged data is stored.
flagging_callback: An instance of a subclass of FlaggingCallback which will be called when a sample is flagged. By default logs to a local CSV file.
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -522,6 +526,8 @@ class Interface(Blocks):
api_name="predict",
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=batch,
max_batch_size=max_batch_size,
)
else:
for component in self.input_components:
@ -560,6 +566,8 @@ class Interface(Blocks):
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=batch,
max_batch_size=max_batch_size,
)
if inspect.isgeneratorfunction(fn):
stop_btn.click(
@ -648,6 +656,7 @@ class Interface(Blocks):
cache_examples=self.cache_examples,
examples_per_page=examples_per_page,
_api_mode=_api_mode,
batch=batch,
)
if self.interpretation:


@ -1,13 +1,15 @@
from __future__ import annotations
import asyncio
import copy
import sys
import time
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import fastapi
from pydantic import BaseModel
from gradio.dataclasses import PredictBody
from gradio.utils import Request, run_coro_in_background
@ -21,6 +23,18 @@ class Estimation(BaseModel):
queue_eta: int
class Event:
def __init__(self, websocket: fastapi.WebSocket, fn_index: int | None = None):
self.websocket = websocket
self.data: PredictBody | None = None
self.lost_connection_time: float | None = None
self.fn_index: int | None = fn_index
self.session_hash: str = "foo"
async def disconnect(self, code=1000):
await self.websocket.close(code=code)
class Queue:
def __init__(
self,
@ -29,14 +43,15 @@ class Queue:
data_gathering_start: int,
update_intervals: int,
max_size: Optional[int],
blocks_dependencies: List,
):
self.event_queue = []
self.event_queue: List[Event] = []
self.events_pending_reconnection = []
self.stopped = False
self.max_thread_count = concurrency_count
self.data_gathering_start = data_gathering_start
self.update_intervals = update_intervals
self.active_jobs: List[None | Event] = [None] * concurrency_count
self.active_jobs: List[None | List[Event]] = [None] * concurrency_count
self.delete_lock = asyncio.Lock()
self.server_path = None
self.duration_history_total = 0
@ -47,6 +62,7 @@ class Queue:
self.live_updates = live_updates
self.sleep_when_free = 0.05
self.max_size = max_size
self.blocks_dependencies = blocks_dependencies
async def start(self):
run_coro_in_background(self.start_processing)
@ -69,6 +85,26 @@ class Queue:
count += 1
return count
def get_events_in_batch(self) -> Tuple[List[Event] | None, bool]:
if not (self.event_queue):
return None, False
first_event = self.event_queue.pop(0)
events = [first_event]
event_fn_index = first_event.fn_index
batch = self.blocks_dependencies[event_fn_index]["batch"]
if batch:
batch_size = self.blocks_dependencies[event_fn_index]["max_batch_size"]
rest_of_batch = [
event for event in self.event_queue if event.fn_index == event_fn_index
][: batch_size - 1]
events.extend(rest_of_batch)
[self.event_queue.remove(event) for event in rest_of_batch]
return events, batch
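To illustrate the grouping above, a self-contained toy sketch (namedtuples stand in for `Event`; the dependency dict shape mirrors `blocks_dependencies`):
```py
# Toy sketch of get_events_in_batch: group queued events that share a batched fn_index.
from collections import namedtuple

Ev = namedtuple("Ev", "fn_index")
event_queue = [Ev(0), Ev(1), Ev(0), Ev(0)]
deps = {0: {"batch": True, "max_batch_size": 2}, 1: {"batch": False}}

first = event_queue.pop(0)
events = [first]
if deps[first.fn_index]["batch"]:
    rest = [e for e in event_queue if e.fn_index == first.fn_index]
    rest = rest[: deps[first.fn_index]["max_batch_size"] - 1]
    events.extend(rest)
    for e in rest:
        event_queue.remove(e)

print(events)       # [Ev(fn_index=0), Ev(fn_index=0)] -- one batch of two
print(event_queue)  # [Ev(fn_index=1), Ev(fn_index=0)] -- the rest stays queued
```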
async def start_processing(self) -> None:
while not self.stopped:
if not self.event_queue:
@ -81,13 +117,16 @@ class Queue:
# Using mutex to avoid editing a list in use
async with self.delete_lock:
event = self.event_queue.pop(0)
events, batch = self.get_events_in_batch()
self.active_jobs[self.active_jobs.index(None)] = event
task = run_coro_in_background(self.process_event, event)
if sys.version_info >= (3, 8):
task.set_name(f"{event.session_hash}_{event.fn_index}")
run_coro_in_background(self.broadcast_live_estimations)
if events:
self.active_jobs[self.active_jobs.index(None)] = events
task = run_coro_in_background(self.process_events, events, batch)
run_coro_in_background(self.broadcast_live_estimations)
if sys.version_info >= (3, 8) and not (
batch
): # You shouldn't be able to cancel a task if it's part of a batch
task.set_name(f"{events[0].session_hash}_{events[0].fn_index}")
def push(self, event: Event) -> int | None:
"""
@ -106,8 +145,6 @@ class Queue:
if event in self.event_queue:
async with self.delete_lock:
self.event_queue.remove(event)
elif event in self.active_jobs:
self.active_jobs[self.active_jobs.index(event)] = None
async def broadcast_live_estimations(self) -> None:
"""
@ -128,7 +165,7 @@ class Queue:
]
)
async def gather_event_data(self, event: Event) -> None:
async def gather_event_data(self, event: Event) -> bool:
"""
Gather data for the event
@ -212,33 +249,43 @@ class Queue:
queue_eta=self.queue_duration,
)
async def call_prediction(self, event: Event):
async def call_prediction(self, events: List[Event], batch: bool):
data = events[0].data
if batch:
data.data = list(zip(*[event.data.data for event in events if event.data]))
data.batched = True
response = await Request(
method=Request.Method.POST,
url=f"{self.server_path}api/predict",
json=event.data,
json=dict(data),
)
return response
async def process_event(self, event: Event) -> None:
async def process_events(self, events: List[Event], batch: bool) -> None:
awake_events: List[Event] = []
try:
client_awake = await self.gather_event_data(event)
if not client_awake:
return
client_awake = await self.send_message(event, {"msg": "process_starts"})
if not client_awake:
for event in events:
client_awake = await self.gather_event_data(event)
if client_awake:
client_awake = await self.send_message(
event, {"msg": "process_starts"}
)
if client_awake:
awake_events.append(event)
if not (awake_events):
return
begin_time = time.time()
response = await self.call_prediction(event)
response = await self.call_prediction(awake_events, batch)
if response.has_exception:
await self.send_message(
event,
{
"msg": "process_completed",
"output": {"error": str(response.exception)},
"success": False,
},
)
for event in awake_events:
await self.send_message(
event,
{
"msg": "process_completed",
"output": {"error": str(response.exception)},
"success": False,
},
)
elif response.json.get("is_generating", False):
while response.json.get("is_generating", False):
# Python 3.7 doesn't have named tasks.
@ -249,41 +296,49 @@ class Queue:
if not is_alive:
return
old_response = response
for event in awake_events:
await self.send_message(
event,
{
"msg": "process_generating",
"output": old_response.json,
"success": old_response.status == 200,
},
)
response = await self.call_prediction(awake_events, batch)
for event in awake_events:
await self.send_message(
event,
{
"msg": "process_generating",
"msg": "process_completed",
"output": old_response.json,
"success": old_response.status == 200,
},
)
response = await self.call_prediction(event)
await self.send_message(
event,
{
"msg": "process_completed",
"output": old_response.json,
"success": old_response.status == 200,
},
)
else:
await self.send_message(
event,
{
"msg": "process_completed",
"output": response.json,
"success": response.status == 200,
},
)
output = copy.deepcopy(response.json)
for e, event in enumerate(awake_events):
if batch and "data" in output:
output["data"] = list(zip(*response.json.get("data")))[e]
await self.send_message(
event,
{
"msg": "process_completed",
"output": output,
"success": response.status == 200,
},
)
end_time = time.time()
if response.status == 200:
self.update_estimation(end_time - begin_time)
finally:
try:
await event.disconnect()
except Exception:
pass
finally:
for event in awake_events:
try:
await event.disconnect()
except Exception:
pass
self.active_jobs[self.active_jobs.index(events)] = None
for event in awake_events:
await self.clean_event(event)
# Always reset the state of the iterator
# If the job finished successfully, this has no effect
@ -306,24 +361,10 @@ class Queue:
await self.clean_event(event)
return False
async def get_message(self, event) -> Optional[Dict]:
async def get_message(self, event) -> Optional[PredictBody]:
try:
data = await event.websocket.receive_json()
return data
return PredictBody(**data)
except:
await self.clean_event(event)
return None
class Event:
def __init__(self, websocket: fastapi.WebSocket):
from gradio.routes import PredictBody
self.websocket = websocket
self.data: PredictBody | None = None
self.lost_connection_time: float | None = None
self.fn_index = 0
self.session_hash = "foo"
async def disconnect(self, code=1000):
await self.websocket.close(code=code)

@ -25,12 +25,12 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from jinja2.exceptions import TemplateNotFound
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from starlette.websockets import WebSocket, WebSocketState
import gradio
from gradio import encryptor, utils
from gradio.dataclasses import PredictBody, ResetBody
from gradio.documentation import document, set_documentation_group
from gradio.exceptions import Error
from gradio.queue import Estimation, Event
@ -55,33 +55,6 @@ class ORJSONResponse(JSONResponse):
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
###########
# Data Models
###########
class QueueStatusBody(BaseModel):
hash: str
class QueuePushBody(BaseModel):
fn_index: int
action: str
session_hash: str
data: Any
class PredictBody(BaseModel):
session_hash: Optional[str]
data: Any
fn_index: Optional[int]
class ResetBody(BaseModel):
session_hash: str
fn_index: int
###########
# Auth
###########
@ -300,6 +273,9 @@ class App(FastAPI):
iterators = {}
raw_input = body.data
fn_index = body.fn_index
batch = app.blocks.dependencies[fn_index]["batch"]
if not (body.batched) and batch:
raw_input = [raw_input]
try:
output = await app.blocks.process_api(
fn_index, raw_input, username, session_state, iterators
@ -319,6 +295,9 @@ class App(FastAPI):
content={"error": str(error) if show_error else None},
status_code=500,
)
if not (body.batched) and batch:
output["data"] = output["data"][0]
return output
@app.post("/api/{api_name}", dependencies=[Depends(login_check)])
@ -348,9 +327,7 @@ class App(FastAPI):
@app.websocket("/queue/join")
async def join_queue(websocket: WebSocket):
if app.blocks._queue.server_path is None:
print(f"WS: {str(websocket.url)}")
app_url = get_server_url_from_ws_url(str(websocket.url))
print(f"Server URL: {app_url}")
app.blocks._queue.set_url(app_url)
await websocket.accept()
event = Event(websocket)

@ -1,12 +1,10 @@
import gradio
XRAY_CONFIG = {
"version": "3.1.8b\n",
"version": "3.4b3\n",
"mode": "blocks",
"dev_mode": True,
"components": [
{
"id": 23,
"id": 27,
"type": "markdown",
"props": {
"value": "<h1>Detect Disease From Scan</h1>\n<p>With this model you can lorem ipsum</p>\n<ul>\n<li>ipsum 1</li>\n<li>ipsum 2</li>\n</ul>\n",
@ -16,7 +14,7 @@ XRAY_CONFIG = {
},
},
{
"id": 24,
"id": 28,
"type": "checkboxgroup",
"props": {
"choices": ["Covid", "Malaria", "Lung Cancer"],
@ -28,14 +26,14 @@ XRAY_CONFIG = {
"style": {},
},
},
{"id": 25, "type": "tabs", "props": {"visible": True, "style": {}}},
{"id": 29, "type": "tabs", "props": {"visible": True, "style": {}}},
{
"id": 26,
"id": 30,
"type": "tabitem",
"props": {"label": "X-ray", "visible": True, "style": {}},
},
{
"id": 27,
"id": 31,
"type": "row",
"props": {
"type": "row",
@ -45,7 +43,7 @@ XRAY_CONFIG = {
},
},
{
"id": 28,
"id": 32,
"type": "image",
"props": {
"image_mode": "RGB",
@ -59,44 +57,39 @@ XRAY_CONFIG = {
"style": {},
},
},
{
"id": 29,
"type": "json",
"props": {
"show_label": True,
"name": "json",
"visible": True,
"style": {},
},
},
{
"id": 30,
"type": "button",
"props": {
"value": "Run",
"variant": "secondary",
"name": "button",
"visible": True,
"style": {},
},
},
{
"id": 31,
"type": "tabitem",
"props": {"label": "CT Scan", "visible": True, "style": {}},
},
{
"id": 32,
"type": "row",
"props": {
"type": "row",
"variant": "default",
"visible": True,
"style": {},
},
},
{
"id": 33,
"type": "json",
"props": {"show_label": True, "name": "json", "visible": True, "style": {}},
},
{
"id": 34,
"type": "button",
"props": {
"value": "Run",
"variant": "secondary",
"name": "button",
"visible": True,
"style": {},
},
},
{
"id": 35,
"type": "tabitem",
"props": {"label": "CT Scan", "visible": True, "style": {}},
},
{
"id": 36,
"type": "row",
"props": {
"type": "row",
"variant": "default",
"visible": True,
"style": {},
},
},
{
"id": 37,
"type": "image",
"props": {
"image_mode": "RGB",
@ -111,17 +104,12 @@ XRAY_CONFIG = {
},
},
{
"id": 34,
"id": 38,
"type": "json",
"props": {
"show_label": True,
"name": "json",
"visible": True,
"style": {},
},
"props": {"show_label": True, "name": "json", "visible": True, "style": {}},
},
{
"id": 35,
"id": 39,
"type": "button",
"props": {
"value": "Run",
@ -132,7 +120,7 @@ XRAY_CONFIG = {
},
},
{
"id": 36,
"id": 40,
"type": "textbox",
"props": {
"lines": 1,
@ -145,12 +133,12 @@ XRAY_CONFIG = {
},
},
{
"id": 37,
"id": 41,
"type": "form",
"props": {"type": "form", "visible": True, "style": {}},
},
{
"id": 38,
"id": 42,
"type": "form",
"props": {"type": "form", "visible": True, "style": {}},
},
@ -161,82 +149,91 @@ XRAY_CONFIG = {
"is_space": False,
"enable_queue": None,
"show_error": False,
"show_api": True,
"layout": {
"id": 22,
"id": 26,
"children": [
{"id": 23},
{"id": 37, "children": [{"id": 24}]},
{"id": 27},
{"id": 41, "children": [{"id": 28}]},
{
"id": 25,
"id": 29,
"children": [
{
"id": 26,
"id": 30,
"children": [
{"id": 27, "children": [{"id": 28}, {"id": 29}]},
{"id": 30},
{"id": 31, "children": [{"id": 32}, {"id": 33}]},
{"id": 34},
],
},
{
"id": 31,
"id": 35,
"children": [
{"id": 32, "children": [{"id": 33}, {"id": 34}]},
{"id": 35},
{"id": 36, "children": [{"id": 37}, {"id": 38}]},
{"id": 39},
],
},
],
},
{"id": 38, "children": [{"id": 36}]},
{"id": 42, "children": [{"id": 40}]},
],
},
"dependencies": [
{
"targets": [30],
"targets": [34],
"trigger": "click",
"inputs": [24, 28],
"outputs": [29],
"inputs": [28, 32],
"outputs": [33],
"backend_fn": True,
"js": None,
"queue": None,
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"batch": False,
"max_batch_size": 4,
"cancels": [],
},
{
"targets": [35],
"targets": [39],
"trigger": "click",
"inputs": [24, 33],
"outputs": [34],
"inputs": [28, 37],
"outputs": [38],
"backend_fn": True,
"js": None,
"queue": None,
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"batch": False,
"max_batch_size": 4,
"cancels": [],
},
{
"targets": [],
"trigger": "load",
"inputs": [],
"outputs": [36],
"outputs": [40],
"backend_fn": True,
"js": None,
"queue": None,
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"batch": False,
"max_batch_size": 4,
"cancels": [],
},
],
}
XRAY_CONFIG_DIFF_IDS = {
"version": "3.4b3\n",
"mode": "blocks",
"dev_mode": True,
"components": [
{
"id": 1,
"id": 27,
"type": "markdown",
"props": {
"value": "<h1>Detect Disease From Scan</h1>\n<p>With this model you can lorem ipsum</p>\n<ul>\n<li>ipsum 1</li>\n<li>ipsum 2</li>\n</ul>\n",
@ -246,7 +243,7 @@ XRAY_CONFIG_DIFF_IDS = {
},
},
{
"id": 22,
"id": 28,
"type": "checkboxgroup",
"props": {
"choices": ["Covid", "Malaria", "Lung Cancer"],
@ -258,25 +255,14 @@ XRAY_CONFIG_DIFF_IDS = {
"style": {},
},
},
{"id": 29, "type": "tabs", "props": {"visible": True, "style": {}}},
{
"id": 3,
"type": "tabs",
"props": {
"visible": True,
"style": {},
},
},
{
"id": 444,
"id": 30,
"type": "tabitem",
"props": {
"label": "X-ray",
"visible": True,
"style": {},
},
"props": {"label": "X-ray", "visible": True, "style": {}},
},
{
"id": 5,
"id": 31,
"type": "row",
"props": {
"type": "row",
@ -286,7 +272,7 @@ XRAY_CONFIG_DIFF_IDS = {
},
},
{
"id": 6,
"id": 32,
"type": "image",
"props": {
"image_mode": "RGB",
@ -301,17 +287,12 @@ XRAY_CONFIG_DIFF_IDS = {
},
},
{
"id": 7,
"id": 33,
"type": "json",
"props": {
"show_label": True,
"name": "json",
"visible": True,
"style": {},
},
"props": {"show_label": True, "name": "json", "visible": True, "style": {}},
},
{
"id": 8888,
"id": 34,
"type": "button",
"props": {
"value": "Run",
@ -322,16 +303,12 @@ XRAY_CONFIG_DIFF_IDS = {
},
},
{
"id": 9,
"id": 35,
"type": "tabitem",
"props": {
"label": "CT Scan",
"visible": True,
"style": {},
},
"props": {"label": "CT Scan", "visible": True, "style": {}},
},
{
"id": 10,
"id": 36,
"type": "row",
"props": {
"type": "row",
@ -341,7 +318,7 @@ XRAY_CONFIG_DIFF_IDS = {
},
},
{
"id": 11,
"id": 37,
"type": "image",
"props": {
"image_mode": "RGB",
@ -356,17 +333,12 @@ XRAY_CONFIG_DIFF_IDS = {
},
},
{
"id": 12,
"id": 38,
"type": "json",
"props": {
"show_label": True,
"name": "json",
"visible": True,
"style": {},
},
"props": {"show_label": True, "name": "json", "visible": True, "style": {}},
},
{
"id": 13,
"id": 933,
"type": "button",
"props": {
"value": "Run",
@ -377,7 +349,7 @@ XRAY_CONFIG_DIFF_IDS = {
},
},
{
"id": 141,
"id": 40,
"type": "textbox",
"props": {
"lines": 1,
@ -390,89 +362,100 @@ XRAY_CONFIG_DIFF_IDS = {
},
},
{
"id": 37,
"id": 41,
"type": "form",
"props": {"type": "form", "visible": True, "style": {}},
},
{
"id": 38,
"id": 42,
"type": "form",
"props": {"type": "form", "visible": True, "style": {}},
},
],
"theme": "default",
"css": None,
"enable_queue": False,
"title": "Gradio",
"is_space": False,
"enable_queue": None,
"show_error": False,
"show_api": True,
"layout": {
"id": 0,
"id": 26,
"children": [
{"id": 1},
{"id": 37, "children": [{"id": 22}]},
{"id": 27},
{"id": 41, "children": [{"id": 28}]},
{
"id": 3,
"id": 29,
"children": [
{
"id": 444,
"id": 30,
"children": [
{"id": 5, "children": [{"id": 6}, {"id": 7}]},
{"id": 8888},
{"id": 31, "children": [{"id": 32}, {"id": 33}]},
{"id": 34},
],
},
{
"id": 9,
"id": 35,
"children": [
{"id": 10, "children": [{"id": 11}, {"id": 12}]},
{"id": 13},
{"id": 36, "children": [{"id": 37}, {"id": 38}]},
{"id": 933},
],
},
],
},
{"id": 38, "children": [{"id": 141}]},
{"id": 42, "children": [{"id": 40}]},
],
},
"dependencies": [
{
"targets": [8888],
"targets": [34],
"trigger": "click",
"inputs": [22, 6],
"outputs": [7],
"inputs": [28, 32],
"outputs": [33],
"backend_fn": True,
"js": None,
"queue": None,
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"batch": False,
"max_batch_size": 4,
"cancels": [],
},
{
"targets": [13],
"targets": [933],
"trigger": "click",
"inputs": [22, 11],
"outputs": [12],
"inputs": [28, 37],
"outputs": [38],
"backend_fn": True,
"js": None,
"queue": None,
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"batch": False,
"max_batch_size": 4,
"cancels": [],
},
{
"targets": [],
"trigger": "load",
"inputs": [],
"outputs": [141],
"outputs": [40],
"backend_fn": True,
"js": None,
"queue": None,
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"batch": False,
"max_batch_size": 4,
"cancels": [],
},
],
}
XRAY_CONFIG_WITH_MISTAKE = {
"mode": "blocks",
"dev_mode": True,

@ -182,3 +182,60 @@ Note that we've added a `time.sleep(1)` in the iterator to create an artificial
Supplying a generator into Gradio **requires** you to enable queuing in the underlying Interface or Blocks (see the queuing section above).
## Batch Functions
Gradio supports *batch* functions: functions that take in a list of inputs and return a list of predictions.
For example, here is a batched function that takes in two lists of inputs (a list of
words and a list of ints), and returns a list of trimmed words as output:
```py
import time
def trim_words(words, lens):
trimmed_words = []
time.sleep(5)
for w, l in zip(words, lens):
trimmed_words.append(w[:int(l)])
return [trimmed_words]
```
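To make the list-in/list-out contract concrete, here is how such a function could be invoked directly (a quick sketch; the sample words and lengths below are made up for illustration and are not part of the demo):
```py
# Each argument is a parallel list with one entry per request in the batch;
# the return value is a list containing one list per output component.
print(trim_words(["hello", "world"], [3, 2]))  # [['hel', 'wo']] (after the 5-second sleep)
```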
The advantage of using batched functions is that if you enable queuing, the Gradio
server can automatically *batch* incoming requests and process them in parallel,
potentially speeding up your demo. Here's what the Gradio code looks like (notice
the `batch=True` and `max_batch_size=16` -- both of these parameters can be passed
into event triggers or into the `Interface` class).
With `Interface`:
```python
demo = gr.Interface(trim_words, ["textbox", "number"], ["textbox"],
batch=True, max_batch_size=16)
demo.queue()
demo.launch()
```
With `Blocks`:
```py
import gradio as gr
with gr.Blocks() as demo:
with gr.Row():
word = gr.Textbox(label="word")
leng = gr.Number(label="leng")
output = gr.Textbox(label="Output")
with gr.Row():
run = gr.Button()
event = run.click(trim_words, [word, leng], output, batch=True, max_batch_size=16)
demo.queue()
demo.launch()
```
In the example above, 16 requests could be processed in parallel (for a total inference
time of 5 seconds), instead of each request being processed separately (for a total
inference time of 80 seconds).
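As a rough back-of-the-envelope check of those numbers (assuming 16 queued requests and treating the 5-second `time.sleep` in `trim_words` as the per-call inference time):
```py
requests = 16
seconds_per_call = 5                       # trim_words sleeps for 5 seconds per call
sequential = requests * seconds_per_call   # 80 seconds if each request is processed on its own
batched = seconds_per_call                 # 5 seconds if all 16 requests are grouped into one batch
print(sequential, batched)                 # 80 5
```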
Using batch functions with Gradio **requires** you to enable queuing in the underlying Interface or Blocks (see the queuing section above).

@ -387,9 +387,39 @@ class TestCallFunction:
output = await demo.call_function(0, ["World"])
assert output["prediction"] == "Hello, World"
output = demo("World")
assert output == "Hello, World"
output = await demo.call_function(0, ["Abubakar"])
assert output["prediction"] == "Hello, Abubakar"
@pytest.mark.asyncio
async def test_call_multiple_functions(self):
with gr.Blocks() as demo:
text = gr.Textbox()
text2 = gr.Textbox()
btn = gr.Button()
btn.click(
lambda x: "Hello, " + x,
inputs=text,
outputs=text,
)
text.change(
lambda x: "Hi, " + x,
inputs=text,
outputs=text2,
)
output = await demo.call_function(0, ["World"])
assert output["prediction"] == "Hello, World"
output = demo("World")
assert output == "Hello, World"
output = await demo.call_function(1, ["World"])
assert output["prediction"] == "Hi, World"
output = demo("World", fn_index=1) # fn_index must be a keyword argument
assert output == "Hi, World"
@pytest.mark.asyncio
async def test_call_generator(self):
def generator(x):
@ -459,6 +489,168 @@ class TestCallFunction:
assert output["prediction"] == (0, 3)
class TestBatchProcessing:
def test_raise_exception_if_batching_an_event_thats_not_queued(self):
def trim(words, lens):
trimmed_words = []
for w, l in zip(words, lens):
trimmed_words.append(w[: int(l)])
return [trimmed_words]
msg = "In order to use batching, the queue must be enabled."
with pytest.raises(ValueError, match=msg):
demo = gr.Interface(
trim, ["textbox", "number"], ["textbox"], batch=True, max_batch_size=16
)
demo.launch(prevent_thread_lock=True)
with pytest.raises(ValueError, match=msg):
with gr.Blocks() as demo:
with gr.Row():
word = gr.Textbox(label="word")
leng = gr.Number(label="leng")
output = gr.Textbox(label="Output")
with gr.Row():
run = gr.Button()
run.click(trim, [word, leng], output, batch=True, max_batch_size=16)
demo.launch(prevent_thread_lock=True)
with pytest.raises(ValueError, match=msg):
with gr.Blocks() as demo:
with gr.Row():
word = gr.Textbox(label="word")
leng = gr.Number(label="leng")
output = gr.Textbox(label="Output")
with gr.Row():
run = gr.Button()
run.click(
trim,
[word, leng],
output,
batch=True,
max_batch_size=16,
queue=False,
)
demo.queue()
demo.launch(prevent_thread_lock=True)
@pytest.mark.asyncio
async def test_call_regular_function(self):
def batch_fn(x):
results = []
for word in x:
results.append("Hello " + word)
return (results,)
with gr.Blocks() as demo:
text = gr.Textbox()
btn = gr.Button()
btn.click(batch_fn, inputs=text, outputs=text, batch=True)
output = await demo.call_function(0, [["Adam", "Yahya"]])
assert output["prediction"][0] == ["Hello Adam", "Hello Yahya"]
output = demo("Abubakar")
assert output == "Hello Abubakar"
@pytest.mark.asyncio
async def test_functions_multiple_parameters(self):
def regular_fn(word1, word2):
return len(word1) > len(word2)
def batch_fn(words, lengths):
comparisons = []
trim_words = []
for word, length in zip(words, lengths):
trim_words.append(word[:length])
comparisons.append(len(word) > length)
return trim_words, comparisons
with gr.Blocks() as demo:
text1 = gr.Textbox()
text2 = gr.Textbox()
leng = gr.Number(precision=0)
bigger = gr.Checkbox()
btn1 = gr.Button("Check")
btn2 = gr.Button("Trim")
btn1.click(regular_fn, inputs=[text1, text2], outputs=bigger)
btn2.click(
batch_fn,
inputs=[text1, leng],
outputs=[text1, bigger],
batch=True,
)
output = await demo.call_function(0, ["Adam", "Yahya"])
assert output["prediction"] is False
output = demo("Abubakar", "Abid")
assert output
output = await demo.call_function(1, [["Adam", "Mary"], [3, 5]])
assert output["prediction"] == (
["Ada", "Mary"],
[True, False],
)
output = demo("Abubakar", 3, fn_index=1)
assert output == ["Abu", True]
@pytest.mark.asyncio
async def test_invalid_batch_generator(self):
with pytest.raises(ValueError):
def batch_fn(x):
results = []
for word in x:
results.append("Hello " + word)
yield (results,)
with gr.Blocks() as demo:
text = gr.Textbox()
btn = gr.Button()
btn.click(batch_fn, inputs=text, outputs=text, batch=True)
await demo.process_api(0, [["Adam", "Yahya"]])
@pytest.mark.asyncio
async def test_exceeds_max_batch_size(self):
with pytest.raises(ValueError):
def batch_fn(x):
results = []
for word in x:
results.append("Hello " + word)
return (results,)
with gr.Blocks() as demo:
text = gr.Textbox()
btn = gr.Button()
btn.click(
batch_fn, inputs=text, outputs=text, batch=True, max_batch_size=2
)
await demo.process_api(0, [["A", "B", "C"]])
@pytest.mark.asyncio
async def test_unequal_batch_sizes(self):
with pytest.raises(ValueError):
def batch_fn(x, y):
results = []
for word1, word2 in zip(x, y):
results.append("Hello " + word1 + word2)
return (results,)
with gr.Blocks() as demo:
t1 = gr.Textbox()
t2 = gr.Textbox()
btn = gr.Button()
btn.click(batch_fn, inputs=[t1, t2], outputs=t1, batch=True)
await demo.process_api(0, [["A", "B", "C"], ["D", "E"]])
class TestSpecificUpdate:
def test_without_update(self):
with pytest.raises(KeyError):
@ -543,57 +735,57 @@ class TestRender:
assert io2 == io3
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Tasks dont have names in 3.7",
)
@pytest.mark.asyncio
async def test_cancel_function(capsys):
async def long_job():
await asyncio.sleep(10)
print("HELLO FROM LONG JOB")
class TestCancel:
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Tasks dont have names in 3.7",
)
@pytest.mark.asyncio
async def test_cancel_function(self, capsys):
async def long_job():
await asyncio.sleep(10)
print("HELLO FROM LONG JOB")
with gr.Blocks():
button = gr.Button(value="Start")
click = button.click(long_job, None, None)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
cancel_fun, _ = gradio.events.get_cancel_function(dependencies=[click])
task = asyncio.create_task(long_job())
task.set_name("foo_0")
# If cancel_fun didn't cancel long_job the message would be printed to the console
# The test would also take 10 seconds
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
captured = capsys.readouterr()
assert "HELLO FROM LONG JOB" not in captured.out
def test_raise_exception_if_cancelling_an_event_thats_not_queued():
def iteration(a):
yield a
msg = "In order to cancel an event, the queue for that event must be enabled!"
with pytest.raises(ValueError, match=msg):
gr.Interface(iteration, inputs=gr.Number(), outputs=gr.Number()).launch(
prevent_thread_lock=True
)
with pytest.raises(ValueError, match=msg):
with gr.Blocks() as demo:
button = gr.Button(value="Predict")
click = button.click(None, None, None)
with gr.Blocks():
button = gr.Button(value="Start")
click = button.click(long_job, None, None)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
demo.launch(prevent_thread_lock=True)
cancel_fun, _ = gradio.events.get_cancel_function(dependencies=[click])
with pytest.raises(ValueError, match=msg):
with gr.Blocks() as demo:
button = gr.Button(value="Predict")
click = button.click(None, None, None, queue=False)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
demo.queue().launch(prevent_thread_lock=True)
task = asyncio.create_task(long_job())
task.set_name("foo_0")
# If cancel_fun didn't cancel long_job the message would be printed to the console
# The test would also take 10 seconds
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
captured = capsys.readouterr()
assert "HELLO FROM LONG JOB" not in captured.out
def test_raise_exception_if_cancelling_an_event_thats_not_queued(self):
def iteration(a):
yield a
msg = "In order to cancel an event, the queue for that event must be enabled!"
with pytest.raises(ValueError, match=msg):
gr.Interface(iteration, inputs=gr.Number(), outputs=gr.Number()).launch(
prevent_thread_lock=True
)
with pytest.raises(ValueError, match=msg):
with gr.Blocks() as demo:
button = gr.Button(value="Predict")
click = button.click(None, None, None)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
demo.launch(prevent_thread_lock=True)
with pytest.raises(ValueError, match=msg):
with gr.Blocks() as demo:
button = gr.Button(value="Predict")
click = button.click(None, None, None, queue=False)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
demo.queue().launch(prevent_thread_lock=True)
if __name__ == "__main__":

@ -105,21 +105,6 @@ class TestExamplesDataset:
@patch("gradio.examples.CACHED_FOLDER", tempfile.mkdtemp())
class TestProcessExamples:
@pytest.mark.asyncio
async def test_predict_example(self):
io = gr.Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
prediction = await io.examples_handler.predict_example(0)
assert prediction[0] == "Hello World"
@pytest.mark.asyncio
async def test_coroutine_process_example(self):
async def coroutine(x):
return "Hello " + x
io = gr.Interface(coroutine, "text", "text", examples=[["World"]])
prediction = await io.examples_handler.predict_example(0)
assert prediction[0] == "Hello World"
@pytest.mark.asyncio
async def test_caching(self):
io = gr.Interface(
@ -127,8 +112,8 @@ class TestProcessExamples:
"text",
"text",
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
await io.examples_handler.cache_interface_examples()
prediction = await io.examples_handler.load_from_cache(1)
assert prediction[0] == "Hello Dunya"
@ -139,11 +124,9 @@ class TestProcessExamples:
"image",
"image",
examples=[["test/test_files/bus.png"]],
cache_examples=True,
)
io.launch(prevent_thread_lock=True)
await io.examples_handler.cache_interface_examples()
prediction = await io.examples_handler.load_from_cache(0)
io.close()
assert prediction[0].startswith("data:image/png;base64,iVBORw0KGgoAAA")
@pytest.mark.asyncio
@ -153,11 +136,9 @@ class TestProcessExamples:
"audio",
"audio",
examples=[["test/test_files/audio_sample.wav"]],
cache_examples=True,
)
io.launch(prevent_thread_lock=True)
await io.examples_handler.cache_interface_examples()
prediction = await io.examples_handler.load_from_cache(0)
io.close()
assert prediction[0]["data"].startswith("data:audio/wav;base64,UklGRgA/")
@pytest.mark.asyncio
@ -167,8 +148,8 @@ class TestProcessExamples:
"text",
"image",
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
await io.examples_handler.cache_interface_examples()
prediction = await io.examples_handler.load_from_cache(1)
assert prediction[0] == {"visible": False, "__type__": "update"}
@ -179,8 +160,8 @@ class TestProcessExamples:
"text",
["text", "image"],
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
await io.examples_handler.cache_interface_examples()
prediction = await io.examples_handler.load_from_cache(1)
assert prediction[0] == {"lines": 4, "value": "hello", "__type__": "update"}
@ -196,7 +177,6 @@ class TestProcessExamples:
examples=["abc"],
cache_examples=True,
)
await io.examples_handler.cache_interface_examples()
prediction = await io.examples_handler.load_from_cache(0)
assert prediction == [{"lines": 4, "__type__": "update"}, {"label": "lion"}]
@ -256,3 +236,43 @@ class TestProcessExamples:
examples=[["foo", None, None], ["bar", 2, 3]],
cache_examples=True,
)
@pytest.mark.asyncio
async def test_caching_with_batch(self):
def trim_words(words, lens):
trimmed_words = []
for w, l in zip(words, lens):
trimmed_words.append(w[:l])
return [trimmed_words]
io = gr.Interface(
trim_words,
["textbox", gr.Number(precision=0)],
["textbox"],
batch=True,
max_batch_size=16,
examples=[["hello", 3], ["hi", 4]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
assert prediction == ["hel"]
@pytest.mark.asyncio
async def test_caching_with_batch_multiple_outputs(self):
def trim_words(words, lens):
trimmed_words = []
for w, l in zip(words, lens):
trimmed_words.append(w[:l])
return trimmed_words, lens
io = gr.Interface(
trim_words,
["textbox", gr.Number(precision=0)],
["textbox", gr.Number(precision=0)],
batch=True,
max_batch_size=16,
examples=[["hello", 3], ["hi", 4]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
assert prediction == ["hel", "3"]

@ -23,6 +23,7 @@ def queue() -> Queue:
data_gathering_start=1,
update_intervals=1,
max_size=None,
blocks_dependencies=[],
)
yield queue_object
queue_object.close()
@ -31,7 +32,7 @@ def queue() -> Queue:
@pytest.fixture()
def mock_event() -> Event:
websocket = MagicMock()
event = Event(websocket=websocket)
event = Event(websocket=websocket, fn_index=0)
yield event
@ -108,7 +109,7 @@ class TestQueueMethods:
@pytest.mark.asyncio
async def test_gather_data_for_first_ranks(self, queue: Queue, mock_event: Event):
websocket = MagicMock()
mock_event2 = Event(websocket=websocket)
mock_event2 = Event(websocket=websocket, fn_index=0)
queue.send_message = AsyncMock()
queue.get_message = AsyncMock()
queue.send_message.return_value = True
@ -180,7 +181,9 @@ class TestQueueProcessEvents:
queue.call_prediction.return_value.json = {"is_generating": False}
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
await queue.process_event(mock_event)
queue.active_jobs = [[mock_event]]
await queue.process_events([mock_event], batch=False)
queue.call_prediction.assert_called_once()
mock_event.disconnect.assert_called_once()
@ -204,9 +207,11 @@ class TestQueueProcessEvents:
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
mock_event.data = None
await queue.process_event(mock_event)
queue.active_jobs = [[mock_event]]
await queue.process_events([mock_event], batch=False)
assert not queue.call_prediction.called
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@pytest.mark.asyncio
@ -219,9 +224,11 @@ class TestQueueProcessEvents:
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
mock_event.data = None
await queue.process_event(mock_event)
queue.active_jobs = [[mock_event]]
await queue.process_events([mock_event], batch=False)
assert not queue.call_prediction.called
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@pytest.mark.asyncio
@ -235,7 +242,10 @@ class TestQueueProcessEvents:
queue.call_prediction = AsyncMock(
return_value=MagicMock(has_exception=True, exception=ValueError("foo"))
)
await queue.process_event(mock_event)
queue.active_jobs = [[mock_event]]
await queue.process_events([mock_event], batch=False)
queue.call_prediction.assert_called_once()
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@ -256,7 +266,10 @@ class TestQueueProcessEvents:
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
mock_event.data = None
await queue.process_event(mock_event)
queue.active_jobs = [[mock_event]]
await queue.process_events([mock_event], batch=False)
queue.call_prediction.assert_called_once()
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@ -278,7 +291,8 @@ class TestQueueProcessEvents:
mock_event.disconnect = AsyncMock(side_effect=ValueError("..."))
queue.clean_event = AsyncMock()
mock_event.data = None
await queue.process_event(mock_event)
queue.active_jobs = [[mock_event]]
await queue.process_events([mock_event], batch=False)
mock_request.assert_called_with(
method=Request.Method.POST,
url=f"{queue.server_path}reset",
@ -287,3 +301,102 @@ class TestQueueProcessEvents:
"fn_index": mock_event.fn_index,
},
)
class TestQueueBatch:
@pytest.mark.asyncio
async def test_process_event(self, queue: Queue, mock_event: Event):
queue.gather_event_data = AsyncMock()
queue.gather_event_data.return_value = True
queue.send_message = AsyncMock()
queue.send_message.return_value = True
queue.call_prediction = AsyncMock()
queue.call_prediction.return_value = MagicMock()
queue.call_prediction.return_value.has_exception = False
queue.call_prediction.return_value.json = {
"is_generating": False,
"data": [[1, 2]],
}
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
websocket = MagicMock()
mock_event2 = Event(websocket=websocket, fn_index=0)
mock_event2.disconnect = AsyncMock()
queue.active_jobs = [[mock_event, mock_event2]]
await queue.process_events([mock_event, mock_event2], batch=True)
queue.call_prediction.assert_called_once() # called once for both events
mock_event.disconnect.assert_called_once()
mock_event2.disconnect.assert_called_once()
assert queue.clean_event.call_count == 2
class TestGetEventsInBatch:
def test_empty_event_queue(self, queue: Queue):
queue.event_queue = []
events, _ = queue.get_events_in_batch()
assert events is None
def test_single_type_of_event(self, queue: Queue):
queue.blocks_dependencies = [{"batch": True, "max_batch_size": 3}]
queue.event_queue = [
Event(websocket=MagicMock(), fn_index=0),
Event(websocket=MagicMock(), fn_index=0),
Event(websocket=MagicMock(), fn_index=0),
Event(websocket=MagicMock(), fn_index=0),
]
events, batch = queue.get_events_in_batch()
assert batch
assert [e.fn_index for e in events] == [0, 0, 0]
events, batch = queue.get_events_in_batch()
assert batch
assert [e.fn_index for e in events] == [0]
def test_multiple_batch_events(self, queue: Queue):
queue.blocks_dependencies = [
{"batch": True, "max_batch_size": 3},
{"batch": True, "max_batch_size": 2},
]
queue.event_queue = [
Event(websocket=MagicMock(), fn_index=0),
Event(websocket=MagicMock(), fn_index=1),
Event(websocket=MagicMock(), fn_index=0),
Event(websocket=MagicMock(), fn_index=1),
Event(websocket=MagicMock(), fn_index=0),
Event(websocket=MagicMock(), fn_index=0),
]
events, batch = queue.get_events_in_batch()
assert batch
assert [e.fn_index for e in events] == [0, 0, 0]
events, batch = queue.get_events_in_batch()
assert batch
assert [e.fn_index for e in events] == [1, 1]
events, batch = queue.get_events_in_batch()
assert batch
assert [e.fn_index for e in events] == [0]
def test_both_types_of_event(self, queue: Queue):
queue.blocks_dependencies = [
{"batch": True, "max_batch_size": 3},
{"batch": False},
]
queue.event_queue = [
Event(websocket=MagicMock(), fn_index=0),
Event(websocket=MagicMock(), fn_index=1),
Event(websocket=MagicMock(), fn_index=0),
Event(websocket=MagicMock(), fn_index=1),
Event(websocket=MagicMock(), fn_index=1),
]
events, batch = queue.get_events_in_batch()
assert batch
assert [e.fn_index for e in events] == [0, 0]
events, batch = queue.get_events_in_batch()
assert not (batch)
assert [e.fn_index for e in events] == [1]

@ -11,7 +11,7 @@ import websockets
from fastapi import FastAPI
from fastapi.testclient import TestClient
import gradio
import gradio as gr
from gradio import Blocks, Interface, Textbox, close_all, routes
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -115,6 +115,33 @@ class TestRoutes(unittest.TestCase):
output = dict(response.json())
self.assertEqual(output["data"], ["testtest"])
def test_predict_route_batching(self):
def batch_fn(x):
results = []
for word in x:
results.append("Hello " + word)
return (results,)
with gr.Blocks() as demo:
text = gr.Textbox()
btn = gr.Button()
btn.click(batch_fn, inputs=text, outputs=text, batch=True, api_name="pred")
demo.queue()
app, _, _ = demo.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/pred/", json={"data": ["test"]})
output = dict(response.json())
self.assertEqual(output["data"], ["Hello test"])
app, _, _ = demo.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post(
"/api/pred/", json={"data": [["test", "test2"]], "batched": True}
)
output = dict(response.json())
self.assertEqual(output["data"], [["Hello test", "Hello test2"]])
def test_state(self):
def predict(input, history):
if history is None:
@ -220,63 +247,64 @@ class TestAuthenticatedRoutes(unittest.TestCase):
close_all()
@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mocks don't work with async context managers in 3.7",
)
@patch("gradio.routes.get_server_url_from_ws_url", return_value="foo_url")
async def test_queue_join_routes_sets_url_if_none_set(mock_get_url):
io = Interface(lambda x: x, "text", "text").queue()
app, _, _ = io.launch(prevent_thread_lock=True)
io._queue.server_path = None
async with websockets.connect(
f"{io.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": ["foo"], "fn_index": 0}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
completed = msg["msg"] == "process_completed"
assert io._queue.server_path == "foo_url"
@pytest.mark.parametrize(
"ws_url,answer",
[
("ws://127.0.0.1:7861/queue/join", "http://127.0.0.1:7861/"),
(
"ws://127.0.0.1:7861/gradio/gradio/gradio/queue/join",
"http://127.0.0.1:7861/gradio/gradio/gradio/",
),
(
"wss://huggingface.co.tech/path/queue/join",
"https://huggingface.co.tech/path/",
),
],
)
def test_get_server_url_from_ws_url(ws_url, answer):
assert routes.get_server_url_from_ws_url(ws_url) == answer
def test_mount_gradio_app_set_dev_mode_false():
app = FastAPI()
@app.get("/")
def read_main():
return {"message": "Hello!"}
with gradio.Blocks() as blocks:
gradio.Textbox("Hello from gradio!")
app = routes.mount_gradio_app(app, blocks, path="/gradio")
gradio_fast_api = next(
route for route in app.routes if isinstance(route, starlette.routing.Mount)
class TestQueueRoutes:
@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mocks don't work with async context managers in 3.7",
)
assert not gradio_fast_api.app.blocks.dev_mode
@patch("gradio.routes.get_server_url_from_ws_url", return_value="foo_url")
async def test_queue_join_routes_sets_url_if_none_set(self, mock_get_url):
io = Interface(lambda x: x, "text", "text").queue()
io.launch(prevent_thread_lock=True)
io._queue.server_path = None
async with websockets.connect(
f"{io.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": ["foo"], "fn_index": 0}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
completed = msg["msg"] == "process_completed"
assert io._queue.server_path == "foo_url"
@pytest.mark.parametrize(
"ws_url,answer",
[
("ws://127.0.0.1:7861/queue/join", "http://127.0.0.1:7861/"),
(
"ws://127.0.0.1:7861/gradio/gradio/gradio/queue/join",
"http://127.0.0.1:7861/gradio/gradio/gradio/",
),
(
"wss://huggingface.co.tech/path/queue/join",
"https://huggingface.co.tech/path/",
),
],
)
def test_get_server_url_from_ws_url(self, ws_url, answer):
assert routes.get_server_url_from_ws_url(ws_url) == answer
class TestDevMode:
def test_mount_gradio_app_set_dev_mode_false(self):
app = FastAPI()
@app.get("/")
def read_main():
return {"message": "Hello!"}
with gr.Blocks() as blocks:
gr.Textbox("Hello from gradio!")
app = routes.mount_gradio_app(app, blocks, path="/gradio")
gradio_fast_api = next(
route for route in app.routes if isinstance(route, starlette.routing.Mount)
)
assert not gradio_fast_api.app.blocks.dev_mode
if __name__ == "__main__":