mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
Handle special arguments when extracting parameter names for view API page (#8400)
* fix special args * add changeset * format * ignore typecheck --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
d393a4a224
commit
33c8081aa9
5
.changeset/fair-items-sort.md
Normal file
5
.changeset/fair-items-sort.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Handle special arguments when extracting parameter names for view API page
|
@ -4,6 +4,7 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
||||
|
||||
@ -448,7 +449,36 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
|
||||
def _setup_api(self) -> None:
|
||||
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
|
||||
if self.is_generator:
|
||||
|
||||
@functools.wraps(self.fn)
|
||||
async def api_fn(message, history, *args, **kwargs): # type: ignore
|
||||
if self.is_async:
|
||||
generator = self.fn(message, history, *args, **kwargs)
|
||||
else:
|
||||
generator = await anyio.to_thread.run_sync(
|
||||
self.fn, message, history, *args, **kwargs, limiter=self.limiter
|
||||
)
|
||||
generator = SyncToAsyncIterator(generator, self.limiter)
|
||||
try:
|
||||
first_response = await async_iteration(generator)
|
||||
yield first_response, history + [[message, first_response]]
|
||||
except StopIteration:
|
||||
yield None, history + [[message, None]]
|
||||
async for response in generator:
|
||||
yield response, history + [[message, response]]
|
||||
else:
|
||||
|
||||
@functools.wraps(self.fn)
|
||||
async def api_fn(message, history, *args, **kwargs):
|
||||
if self.is_async:
|
||||
response = await self.fn(message, history, *args, **kwargs)
|
||||
else:
|
||||
response = await anyio.to_thread.run_sync(
|
||||
self.fn, message, history, *args, **kwargs, limiter=self.limiter
|
||||
)
|
||||
history.append([message, response])
|
||||
return response, history
|
||||
|
||||
self.fake_api_btn.click(
|
||||
api_fn,
|
||||
@ -575,44 +605,6 @@ class ChatInterface(Blocks):
|
||||
update = history + [[message, response]]
|
||||
yield update, update
|
||||
|
||||
async def _api_submit_fn(
|
||||
self, message: str, history: list[list[str | None]], request: Request, *args
|
||||
) -> tuple[str, list[list[str | None]]]:
|
||||
inputs, _, _ = special_args(
|
||||
self.fn, inputs=[message, history, *args], request=request
|
||||
)
|
||||
|
||||
if self.is_async:
|
||||
response = await self.fn(*inputs)
|
||||
else:
|
||||
response = await anyio.to_thread.run_sync(
|
||||
self.fn, *inputs, limiter=self.limiter
|
||||
)
|
||||
history.append([message, response])
|
||||
return response, history
|
||||
|
||||
async def _api_stream_fn(
|
||||
self, message: str, history: list[list[str | None]], request: Request, *args
|
||||
) -> AsyncGenerator:
|
||||
inputs, _, _ = special_args(
|
||||
self.fn, inputs=[message, history, *args], request=request
|
||||
)
|
||||
|
||||
if self.is_async:
|
||||
generator = self.fn(*inputs)
|
||||
else:
|
||||
generator = await anyio.to_thread.run_sync(
|
||||
self.fn, *inputs, limiter=self.limiter
|
||||
)
|
||||
generator = SyncToAsyncIterator(generator, self.limiter)
|
||||
try:
|
||||
first_response = await async_iteration(generator)
|
||||
yield first_response, history + [[message, first_response]]
|
||||
except StopIteration:
|
||||
yield None, history + [[message, None]]
|
||||
async for response in generator:
|
||||
yield response, history + [[message, response]]
|
||||
|
||||
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
|
||||
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
||||
|
||||
|
@ -1302,14 +1302,21 @@ def get_upload_folder() -> str:
|
||||
|
||||
|
||||
def get_function_params(func: Callable) -> list[tuple[str, bool, Any]]:
|
||||
"""
|
||||
Gets the parameters of a function as a list of tuples of the form (name, has_default, default_value).
|
||||
Excludes *args and **kwargs, as well as args that are Gradio-specific, such as gr.Request, gr.EventData, gr.OAuthProfile, and gr.OAuthToken.
|
||||
"""
|
||||
params_info = []
|
||||
signature = inspect.signature(func)
|
||||
type_hints = get_type_hints(func)
|
||||
for name, parameter in signature.parameters.items():
|
||||
if parameter.kind in (
|
||||
inspect.Parameter.VAR_POSITIONAL,
|
||||
inspect.Parameter.VAR_KEYWORD,
|
||||
):
|
||||
break
|
||||
if is_special_typed_parameter(name, type_hints):
|
||||
continue
|
||||
if parameter.default is inspect.Parameter.empty:
|
||||
params_info.append((name, False, None))
|
||||
else:
|
||||
|
@ -495,6 +495,12 @@ class TestFunctionParams:
|
||||
|
||||
assert get_function_params(func) == [("a", False, None)]
|
||||
|
||||
def test_function_with_special_args(self):
|
||||
def func(a, r: Request, b=10):
|
||||
pass
|
||||
|
||||
assert get_function_params(func) == [("a", False, None), ("b", True, 10)]
|
||||
|
||||
def test_class_method_skip_first_param(self):
|
||||
class MyClass:
|
||||
def method(self, arg1, arg2=42):
|
||||
|
Loading…
Reference in New Issue
Block a user