Use ContextVar instead of threading.local()

This commit is contained in:
cbensimon 2023-09-20 09:32:52 +00:00
parent eebf9d71f9
commit 2a76eb46e3
4 changed files with 20 additions and 16 deletions

View File

@ -91,9 +91,9 @@ BUILT_IN_THEMES: dict[str, Theme] = {
def in_event_listener():
from gradio import context
from gradio.context import LocalContext
return getattr(context.thread_data, "in_event_listener", False)
return LocalContext.in_event_listener.get()
def updateable(fn):

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import threading
from contextvars import ContextVar
from typing import TYPE_CHECKING
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
@ -17,4 +17,7 @@ class Context:
hf_token: str | None = None # The token provided when loading private HF repos
thread_data = threading.local()
class LocalContext:
blocks: ContextVar[Blocks | None] = ContextVar("blocks", default=None)
in_event_listener: ContextVar[bool] = ContextVar("in_event_listener", default=False)
event_id: ContextVar[str | None] = ContextVar("event_id", default=None)

View File

@ -1098,22 +1098,23 @@ class EventData:
def log_message(message: str, level: Literal["info", "warning"] = "info"):
from gradio import context
from gradio.context import LocalContext
if not hasattr(context.thread_data, "blocks"): # Function called outside of Gradio
blocks = LocalContext.blocks.get()
if blocks is None: # Function called outside of Gradio
if level == "info":
print(message)
elif level == "warning":
warnings.warn(message)
return
if not context.thread_data.blocks.enable_queue:
if not blocks.enable_queue:
warnings.warn(
f"Queueing must be enabled to issue {level.capitalize()}: '{message}'."
)
return
context.thread_data.blocks._queue.log_message(
event_id=context.thread_data.event_id, log=message, level=level
)
event_id = LocalContext.event_id.get()
assert event_id
blocks._queue.log_message(event_id=event_id, log=message, level=level)
@document()

View File

@ -663,16 +663,16 @@ def get_function_with_locals(
fn: Callable, blocks: Blocks, event_id: str | None, in_event_listener: bool
):
def before_fn(blocks, event_id):
from gradio.context import thread_data
from gradio.context import LocalContext
thread_data.blocks = blocks
thread_data.in_event_listener = in_event_listener
thread_data.event_id = event_id
LocalContext.blocks.set(blocks)
LocalContext.in_event_listener.set(in_event_listener)
LocalContext.event_id.set(event_id)
def after_fn():
from gradio.context import thread_data
from gradio.context import LocalContext
thread_data.in_event_listener = False
LocalContext.in_event_listener.set(False)
return function_wrapper(
fn, before_fn=before_fn, before_args=(blocks, event_id), after_fn=after_fn