Run before_fn and after_fn for each generator iteration (#7029)

* Run after_fn for async generators

* add changeset

* Add code

* add changeset

* Add back before/after for generators in function_wrapper

* Rework function_wrapper

* move back to closure

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2024-01-22 09:57:57 -08:00 committed by GitHub
parent c17533c6b0
commit ac735551bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 11 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Run before_fn and after_fn for each generator iteration

View File

@ -635,12 +635,19 @@ def function_wrapper(
@functools.wraps(f)
async def asyncgen_wrapper(*args, **kwargs):
if before_fn:
before_fn(*before_args)
async for response in f(*args, **kwargs):
iterator = f(*args, **kwargs)
while True:
if before_fn:
before_fn(*before_args)
try:
response = await iterator.__anext__()
except StopAsyncIteration:
if after_fn:
after_fn(*after_args)
break
if after_fn:
after_fn(*after_args)
yield response
if after_fn:
after_fn(*after_args)
return asyncgen_wrapper
@ -661,11 +668,19 @@ def function_wrapper(
@functools.wraps(f)
def gen_wrapper(*args, **kwargs):
if before_fn:
before_fn(*before_args)
yield from f(*args, **kwargs)
if after_fn:
after_fn(*after_args)
iterator = f(*args, **kwargs)
while True:
if before_fn:
before_fn(*before_args)
try:
response = next(iterator)
except StopIteration:
if after_fn:
after_fn(*after_args)
break
if after_fn:
after_fn(*after_args)
yield response
return gen_wrapper
@ -705,7 +720,10 @@ def get_function_with_locals(
LocalContext.request.set(None)
return function_wrapper(
fn, before_fn=before_fn, before_args=(blocks, event_id), after_fn=after_fn
fn,
before_fn=before_fn,
before_args=(blocks, event_id),
after_fn=after_fn,
)