Decrease latency: do not run pre and postprocess in threadpool (#7796)

* revert

* add changeset

* lint

* explicit call

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2024-03-22 12:38:52 -07:00 committed by GitHub
parent d831040032
commit aad209f0c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 41 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Decrease latency: do not run pre and postprocess in threadpool

View File

@ -1603,15 +1603,6 @@ Received outputs:
return data
def run_fn_batch(self, fn, batch, fn_index, state, explicit_call=None):
    """Call ``fn`` once per sample of a column-oriented batch.

    ``batch`` is unpacked with ``zip(*batch)``, so each of its elements is
    treated as one column and every resulting row is handed to ``fn`` as a
    list. ``fn`` is invoked as ``fn(fn_index, row, state)``; when
    ``explicit_call`` is not ``None`` it is passed as a fourth positional
    argument.

    Returns:
        A list with the return value of ``fn`` for each row, in row order.
    """
    # Only forward explicit_call when the caller actually supplied one, so
    # callables that take just three arguments keep working.
    trailing = () if explicit_call is None else (explicit_call,)
    return [fn(fn_index, list(row), state, *trailing) for row in zip(*batch)]
async def process_api(
self,
fn_index: int,
@ -1662,15 +1653,10 @@ Received outputs:
raise ValueError(
f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
)
inputs = await anyio.to_thread.run_sync(
self.run_fn_batch,
self.preprocess_data,
inputs,
fn_index,
state,
explicit_call,
limiter=self.limiter,
)
inputs = [
self.preprocess_data(fn_index, list(i), state, explicit_call)
for i in zip(*inputs)
]
result = await self.call_function(
fn_index,
list(zip(*inputs)),
@ -1681,14 +1667,9 @@ Received outputs:
in_event_listener,
)
preds = result["prediction"]
data = await anyio.to_thread.run_sync(
self.run_fn_batch,
self.postprocess_data,
preds,
fn_index,
state,
limiter=self.limiter,
)
data = [
self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
]
if root_path is not None:
data = processing_utils.add_root_url(data, root_path, None)
data = list(zip(*data))
@ -1698,14 +1679,7 @@ Received outputs:
if old_iterator:
inputs = []
else:
inputs = await anyio.to_thread.run_sync(
self.preprocess_data,
fn_index,
inputs,
state,
explicit_call,
limiter=self.limiter,
)
inputs = self.preprocess_data(fn_index, inputs, state, explicit_call)
was_generating = old_iterator is not None
result = await self.call_function(
fn_index,
@ -1716,13 +1690,7 @@ Received outputs:
event_data,
in_event_listener,
)
data = await anyio.to_thread.run_sync(
self.postprocess_data,
fn_index, # type: ignore
result["prediction"],
state,
limiter=self.limiter,
)
data = self.postprocess_data(fn_index, result["prediction"], state)
if root_path is not None:
data = processing_utils.add_root_url(data, root_path, None)
is_generating, iterator = result["is_generating"], result["iterator"]