mirror of https://github.com/gradio-app/gradio.git, synced 2024-11-27 01:40:20 +08:00
Decrease latency: do not run pre and postprocess in threadpool (#7796)
* revert
* add changeset
* lint
* explicit call

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
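The commit title carries the substance of the change: preprocess_data and postprocess_data are synchronous steps, and dispatching each of them to a worker thread through anyio adds thread handoff and scheduling overhead to every request. A minimal sketch of the pattern being removed versus the one being introduced, using a stand-in preprocess function rather than the real gradio signature:

    import anyio
    from anyio import to_thread


    def preprocess(text: str) -> str:
        # stand-in for the real preprocess_data: fast, CPU-light, synchronous
        return text.strip().lower()


    async def handle_request_old(text: str) -> str:
        # old pattern: hop to a worker thread for every request,
        # paying handoff and scheduling overhead even for trivial work
        return await to_thread.run_sync(preprocess, text)


    async def handle_request_new(text: str) -> str:
        # new pattern: run the synchronous step inline on the event loop
        return preprocess(text)


    print(anyio.run(handle_request_old, "  Hello "))  # hello
    print(anyio.run(handle_request_new, "  Hello "))  # hello

The trade-off assumed by the change is that these steps are cheap enough to run inline without blocking the event loop for a noticeable amount of time.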
This commit is contained in:
parent d831040032
commit aad209f0c0
.changeset/old-dolls-pump.md  +5  Normal file
@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+
+feat:Decrease latency: do not run pre and postprocess in threadpool
@@ -1603,15 +1603,6 @@ Received outputs:
         return data
 
-    def run_fn_batch(self, fn, batch, fn_index, state, explicit_call=None):
-        output = []
-        for i in zip(*batch):
-            args = [fn_index, list(i), state]
-            if explicit_call is not None:
-                args.append(explicit_call)
-            output.append(fn(*args))
-        return output
-
     async def process_api(
         self,
         fn_index: int,
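For reference, the removed run_fn_batch helper simply mapped a per-sample function over a component-major batch; the list comprehensions added in the hunks below are behaviourally equivalent, minus the thread-pool dispatch. A standalone sketch with a hypothetical toy function standing in for the real preprocess_data:

    def run_fn_batch(fn, batch, fn_index, state, explicit_call=None):
        # the removed helper: apply fn once per sample of a component-major batch
        output = []
        for i in zip(*batch):
            args = [fn_index, list(i), state]
            if explicit_call is not None:
                args.append(explicit_call)
            output.append(fn(*args))
        return output


    def toy_preprocess(fn_index, sample, state):
        # hypothetical stand-in for the real per-sample preprocess_data
        return [value * 2 for value in sample]


    batch = [[1, 2, 3], [10, 20, 30]]  # two input components, three samples each
    state = {}

    via_helper = run_fn_batch(toy_preprocess, batch, 0, state)
    inline = [toy_preprocess(0, list(i), state) for i in zip(*batch)]
    assert via_helper == inline == [[2, 20], [4, 40], [6, 60]]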
@@ -1662,15 +1653,10 @@ Received outputs:
                 raise ValueError(
                     f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
                 )
-            inputs = await anyio.to_thread.run_sync(
-                self.run_fn_batch,
-                self.preprocess_data,
-                inputs,
-                fn_index,
-                state,
-                explicit_call,
-                limiter=self.limiter,
-            )
+            inputs = [
+                self.preprocess_data(fn_index, list(i), state, explicit_call)
+                for i in zip(*inputs)
+            ]
             result = await self.call_function(
                 fn_index,
                 list(zip(*inputs)),
@@ -1681,14 +1667,9 @@ Received outputs:
                 in_event_listener,
             )
             preds = result["prediction"]
-            data = await anyio.to_thread.run_sync(
-                self.run_fn_batch,
-                self.postprocess_data,
-                preds,
-                fn_index,
-                state,
-                limiter=self.limiter,
-            )
+            data = [
+                self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
+            ]
             if root_path is not None:
                 data = processing_utils.add_root_url(data, root_path, None)
             data = list(zip(*data))
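Note the two transpositions in the batch output path above: zip(*preds) walks the predictions sample by sample so each can be postprocessed, and list(zip(*data)) puts the per-sample results back into component-major order. A toy illustration with a hypothetical stand-in for the real postprocess_data:

    def toy_postprocess(fn_index, sample_outputs, state):
        # hypothetical stand-in for the real per-sample postprocess_data
        return [str(value) for value in sample_outputs]


    # two output components, three samples each (component-major)
    preds = [[1, 2, 3], ["a", "b", "c"]]
    state = {}

    data = [toy_postprocess(0, list(o), state) for o in zip(*preds)]
    # sample-major: [['1', 'a'], ['2', 'b'], ['3', 'c']]

    data = list(zip(*data))
    # component-major again: [('1', '2', '3'), ('a', 'b', 'c')]
    print(data)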
@@ -1698,14 +1679,7 @@ Received outputs:
             if old_iterator:
                 inputs = []
             else:
-                inputs = await anyio.to_thread.run_sync(
-                    self.preprocess_data,
-                    fn_index,
-                    inputs,
-                    state,
-                    explicit_call,
-                    limiter=self.limiter,
-                )
+                inputs = self.preprocess_data(fn_index, inputs, state, explicit_call)
             was_generating = old_iterator is not None
             result = await self.call_function(
                 fn_index,
@@ -1716,13 +1690,7 @@ Received outputs:
                 event_data,
                 in_event_listener,
             )
-            data = await anyio.to_thread.run_sync(
-                self.postprocess_data,
-                fn_index,  # type: ignore
-                result["prediction"],
-                state,
-                limiter=self.limiter,
-            )
+            data = self.postprocess_data(fn_index, result["prediction"], state)
             if root_path is not None:
                 data = processing_utils.add_root_url(data, root_path, None)
             is_generating, iterator = result["is_generating"], result["iterator"]
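The latency claim in the title comes from removing that per-call thread hop. A rough way to observe the dispatch overhead on its own, assuming anyio is installed; the numbers are machine-dependent and only illustrative:

    import time

    import anyio
    from anyio import to_thread


    def noop(x: int) -> int:
        return x


    async def main() -> None:
        n = 1_000

        start = time.perf_counter()
        for i in range(n):
            noop(i)  # direct call, as in the new code path
        direct = time.perf_counter() - start

        start = time.perf_counter()
        for i in range(n):
            await to_thread.run_sync(noop, i)  # per-call thread-pool hop, as before
        threaded = time.perf_counter() - start

        print(f"direct:   {direct / n * 1e6:8.1f} us/call")
        print(f"threaded: {threaded / n * 1e6:8.1f} us/call")


    anyio.run(main)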