mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Use ContextVar instead of threading.local()
This commit is contained in:
parent
eebf9d71f9
commit
2a76eb46e3
@ -91,9 +91,9 @@ BUILT_IN_THEMES: dict[str, Theme] = {
|
||||
|
||||
|
||||
def in_event_listener():
|
||||
from gradio import context
|
||||
from gradio.context import LocalContext
|
||||
|
||||
return getattr(context.thread_data, "in_event_listener", False)
|
||||
return LocalContext.in_event_listener.get()
|
||||
|
||||
|
||||
def updateable(fn):
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
@ -17,4 +17,7 @@ class Context:
|
||||
hf_token: str | None = None # The token provided when loading private HF repos
|
||||
|
||||
|
||||
thread_data = threading.local()
|
||||
class LocalContext:
|
||||
blocks: ContextVar[Blocks | None] = ContextVar("blocks", default=None)
|
||||
in_event_listener: ContextVar[bool] = ContextVar("in_event_listener", default=False)
|
||||
event_id: ContextVar[str | None] = ContextVar("event_id", default=None)
|
||||
|
@ -1098,22 +1098,23 @@ class EventData:
|
||||
|
||||
|
||||
def log_message(message: str, level: Literal["info", "warning"] = "info"):
|
||||
from gradio import context
|
||||
from gradio.context import LocalContext
|
||||
|
||||
if not hasattr(context.thread_data, "blocks"): # Function called outside of Gradio
|
||||
blocks = LocalContext.blocks.get()
|
||||
if blocks is None: # Function called outside of Gradio
|
||||
if level == "info":
|
||||
print(message)
|
||||
elif level == "warning":
|
||||
warnings.warn(message)
|
||||
return
|
||||
if not context.thread_data.blocks.enable_queue:
|
||||
if not blocks.enable_queue:
|
||||
warnings.warn(
|
||||
f"Queueing must be enabled to issue {level.capitalize()}: '{message}'."
|
||||
)
|
||||
return
|
||||
context.thread_data.blocks._queue.log_message(
|
||||
event_id=context.thread_data.event_id, log=message, level=level
|
||||
)
|
||||
event_id = LocalContext.event_id.get()
|
||||
assert event_id
|
||||
blocks._queue.log_message(event_id=event_id, log=message, level=level)
|
||||
|
||||
|
||||
@document()
|
||||
|
@ -663,16 +663,16 @@ def get_function_with_locals(
|
||||
fn: Callable, blocks: Blocks, event_id: str | None, in_event_listener: bool
|
||||
):
|
||||
def before_fn(blocks, event_id):
|
||||
from gradio.context import thread_data
|
||||
from gradio.context import LocalContext
|
||||
|
||||
thread_data.blocks = blocks
|
||||
thread_data.in_event_listener = in_event_listener
|
||||
thread_data.event_id = event_id
|
||||
LocalContext.blocks.set(blocks)
|
||||
LocalContext.in_event_listener.set(in_event_listener)
|
||||
LocalContext.event_id.set(event_id)
|
||||
|
||||
def after_fn():
|
||||
from gradio.context import thread_data
|
||||
from gradio.context import LocalContext
|
||||
|
||||
thread_data.in_event_listener = False
|
||||
LocalContext.in_event_listener.set(False)
|
||||
|
||||
return function_wrapper(
|
||||
fn, before_fn=before_fn, before_args=(blocks, event_id), after_fn=after_fn
|
||||
|
Loading…
Reference in New Issue
Block a user