mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-17 11:29:58 +08:00
Run before_fn and after_fn for each generator iteration (#7029)
* Run after_fn for async generators * add changeset * Add code * add changeset * Add back before/after for generators in function_wrapper * Rework function_wrapper * move back to closure --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
c17533c6b0
commit
ac735551bb
5
.changeset/big-bears-cover.md
Normal file
5
.changeset/big-bears-cover.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Run before_fn and after_fn for each generator iteration
|
@ -635,12 +635,19 @@ def function_wrapper(
|
||||
|
||||
@functools.wraps(f)
|
||||
async def asyncgen_wrapper(*args, **kwargs):
|
||||
if before_fn:
|
||||
before_fn(*before_args)
|
||||
async for response in f(*args, **kwargs):
|
||||
iterator = f(*args, **kwargs)
|
||||
while True:
|
||||
if before_fn:
|
||||
before_fn(*before_args)
|
||||
try:
|
||||
response = await iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
if after_fn:
|
||||
after_fn(*after_args)
|
||||
break
|
||||
if after_fn:
|
||||
after_fn(*after_args)
|
||||
yield response
|
||||
if after_fn:
|
||||
after_fn(*after_args)
|
||||
|
||||
return asyncgen_wrapper
|
||||
|
||||
@ -661,11 +668,19 @@ def function_wrapper(
|
||||
|
||||
@functools.wraps(f)
|
||||
def gen_wrapper(*args, **kwargs):
|
||||
if before_fn:
|
||||
before_fn(*before_args)
|
||||
yield from f(*args, **kwargs)
|
||||
if after_fn:
|
||||
after_fn(*after_args)
|
||||
iterator = f(*args, **kwargs)
|
||||
while True:
|
||||
if before_fn:
|
||||
before_fn(*before_args)
|
||||
try:
|
||||
response = next(iterator)
|
||||
except StopIteration:
|
||||
if after_fn:
|
||||
after_fn(*after_args)
|
||||
break
|
||||
if after_fn:
|
||||
after_fn(*after_args)
|
||||
yield response
|
||||
|
||||
return gen_wrapper
|
||||
|
||||
@ -705,7 +720,10 @@ def get_function_with_locals(
|
||||
LocalContext.request.set(None)
|
||||
|
||||
return function_wrapper(
|
||||
fn, before_fn=before_fn, before_args=(blocks, event_id), after_fn=after_fn
|
||||
fn,
|
||||
before_fn=before_fn,
|
||||
before_args=(blocks, event_id),
|
||||
after_fn=after_fn,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user