From a773eaf7504abb53b99885b3454dc1e027adbb42 Mon Sep 17 00:00:00 2001 From: aliabid94 Date: Fri, 18 Aug 2023 07:54:58 -0700 Subject: [PATCH] Stop passing inputs and preprocessing on iterators (#5260) * changes * add changeset * changes * add changeset * Update blocks.py * changes --------- Co-authored-by: gradio-pr-bot --- .changeset/hip-queens-grin.md | 5 +++++ gradio/blocks.py | 27 ++++++++++++++------------- 2 files changed, 19 insertions(+), 13 deletions(-) create mode 100644 .changeset/hip-queens-grin.md diff --git a/.changeset/hip-queens-grin.md b/.changeset/hip-queens-grin.md new file mode 100644 index 0000000000..4a8418cc96 --- /dev/null +++ b/.changeset/hip-queens-grin.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Stop passing inputs and preprocessing on iterators diff --git a/gradio/blocks.py b/gradio/blocks.py index 46be3864ec..98f390d480 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1080,23 +1080,21 @@ class Blocks(BlockContext): block_fn = self.fns[fn_index] assert block_fn.fn, f"function with index {fn_index} not defined." is_generating = False - - if block_fn.inputs_as_dict: - processed_input = [dict(zip(block_fn.inputs, processed_input))] - request = requests[0] if isinstance(requests, list) else requests - processed_input, progress_index, _ = special_args( - block_fn.fn, processed_input, request, event_data - ) - progress_tracker = ( - processed_input[progress_index] if progress_index is not None else None - ) - start = time.time() - fn = utils.get_function_with_locals(block_fn.fn, self, event_id) if iterator is None: # If not a generator function that has already run + if block_fn.inputs_as_dict: + processed_input = [dict(zip(block_fn.inputs, processed_input))] + + processed_input, progress_index, _ = special_args( + block_fn.fn, processed_input, request, event_data + ) + progress_tracker = ( + processed_input[progress_index] if progress_index is not None else None + ) + if progress_tracker is not None and progress_index is not None: progress_tracker, fn = create_tracker( self, event_id, fn, progress_tracker.track_tqdm @@ -1425,8 +1423,11 @@ Received outputs: data = list(zip(*data)) is_generating, iterator = None, None else: - inputs = self.preprocess_data(fn_index, inputs, state) old_iterator = iterators.get(fn_index, None) if iterators else None + if old_iterator: + inputs = [] + else: + inputs = self.preprocess_data(fn_index, inputs, state) was_generating = old_iterator is not None result = await self.call_function( fn_index, inputs, old_iterator, request, event_id, event_data