diff --git a/.changeset/old-dolls-pump.md b/.changeset/old-dolls-pump.md new file mode 100644 index 0000000000..aaef531db0 --- /dev/null +++ b/.changeset/old-dolls-pump.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Decrease latency: do not run pre and postprocess in threadpool diff --git a/gradio/blocks.py b/gradio/blocks.py index 1d04be05cd..897641fdb5 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1603,15 +1603,6 @@ Received outputs: return data - def run_fn_batch(self, fn, batch, fn_index, state, explicit_call=None): - output = [] - for i in zip(*batch): - args = [fn_index, list(i), state] - if explicit_call is not None: - args.append(explicit_call) - output.append(fn(*args)) - return output - async def process_api( self, fn_index: int, @@ -1662,15 +1653,10 @@ Received outputs: raise ValueError( f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})" ) - inputs = await anyio.to_thread.run_sync( - self.run_fn_batch, - self.preprocess_data, - inputs, - fn_index, - state, - explicit_call, - limiter=self.limiter, - ) + inputs = [ + self.preprocess_data(fn_index, list(i), state, explicit_call) + for i in zip(*inputs) + ] result = await self.call_function( fn_index, list(zip(*inputs)), @@ -1681,14 +1667,9 @@ Received outputs: in_event_listener, ) preds = result["prediction"] - data = await anyio.to_thread.run_sync( - self.run_fn_batch, - self.postprocess_data, - preds, - fn_index, - state, - limiter=self.limiter, - ) + data = [ + self.postprocess_data(fn_index, list(o), state) for o in zip(*preds) + ] if root_path is not None: data = processing_utils.add_root_url(data, root_path, None) data = list(zip(*data)) @@ -1698,14 +1679,7 @@ Received outputs: if old_iterator: inputs = [] else: - inputs = await anyio.to_thread.run_sync( - self.preprocess_data, - fn_index, - inputs, - state, - explicit_call, - limiter=self.limiter, - ) + inputs = self.preprocess_data(fn_index, inputs, state, explicit_call) was_generating = old_iterator is not None result = await self.call_function( fn_index, @@ -1716,13 +1690,7 @@ Received outputs: event_data, in_event_listener, ) - data = await anyio.to_thread.run_sync( - self.postprocess_data, - fn_index, # type: ignore - result["prediction"], - state, - limiter=self.limiter, - ) + data = self.postprocess_data(fn_index, result["prediction"], state) if root_path is not None: data = processing_utils.add_root_url(data, root_path, None) is_generating, iterator = result["is_generating"], result["iterator"]