Handle special arguments when extracting parameter names for view API page (#8400)

* fix special args

* add changeset

* format

* ignore typecheck

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-06-04 08:12:10 -07:00 committed by GitHub
parent d393a4a224
commit 33c8081aa9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 49 additions and 39 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Handle special arguments when extracting parameter names for view API page

View File

@ -4,6 +4,7 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
from __future__ import annotations
import functools
import inspect
from typing import AsyncGenerator, Callable, Literal, Union, cast
@ -448,7 +449,36 @@ class ChatInterface(Blocks):
)
def _setup_api(self) -> None:
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
if self.is_generator:
@functools.wraps(self.fn)
async def api_fn(message, history, *args, **kwargs): # type: ignore
if self.is_async:
generator = self.fn(message, history, *args, **kwargs)
else:
generator = await anyio.to_thread.run_sync(
self.fn, message, history, *args, **kwargs, limiter=self.limiter
)
generator = SyncToAsyncIterator(generator, self.limiter)
try:
first_response = await async_iteration(generator)
yield first_response, history + [[message, first_response]]
except StopIteration:
yield None, history + [[message, None]]
async for response in generator:
yield response, history + [[message, response]]
else:
@functools.wraps(self.fn)
async def api_fn(message, history, *args, **kwargs):
if self.is_async:
response = await self.fn(message, history, *args, **kwargs)
else:
response = await anyio.to_thread.run_sync(
self.fn, message, history, *args, **kwargs, limiter=self.limiter
)
history.append([message, response])
return response, history
self.fake_api_btn.click(
api_fn,
@ -575,44 +605,6 @@ class ChatInterface(Blocks):
update = history + [[message, response]]
yield update, update
async def _api_submit_fn(
self, message: str, history: list[list[str | None]], request: Request, *args
) -> tuple[str, list[list[str | None]]]:
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *args], request=request
)
if self.is_async:
response = await self.fn(*inputs)
else:
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
history.append([message, response])
return response, history
async def _api_stream_fn(
self, message: str, history: list[list[str | None]], request: Request, *args
) -> AsyncGenerator:
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *args], request=request
)
if self.is_async:
generator = self.fn(*inputs)
else:
generator = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
generator = SyncToAsyncIterator(generator, self.limiter)
try:
first_response = await async_iteration(generator)
yield first_response, history + [[message, first_response]]
except StopIteration:
yield None, history + [[message, None]]
async for response in generator:
yield response, history + [[message, response]]
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)

View File

@ -1302,14 +1302,21 @@ def get_upload_folder() -> str:
def get_function_params(func: Callable) -> list[tuple[str, bool, Any]]:
"""
Gets the parameters of a function as a list of tuples of the form (name, has_default, default_value).
Excludes *args and **kwargs, as well as args that are Gradio-specific, such as gr.Request, gr.EventData, gr.OAuthProfile, and gr.OAuthToken.
"""
params_info = []
signature = inspect.signature(func)
type_hints = get_type_hints(func)
for name, parameter in signature.parameters.items():
if parameter.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
break
if is_special_typed_parameter(name, type_hints):
continue
if parameter.default is inspect.Parameter.empty:
params_info.append((name, False, None))
else:

View File

@ -495,6 +495,12 @@ class TestFunctionParams:
assert get_function_params(func) == [("a", False, None)]
def test_function_with_special_args(self):
def func(a, r: Request, b=10):
pass
assert get_function_params(func) == [("a", False, None), ("b", True, 10)]
def test_class_method_skip_first_param(self):
class MyClass:
def method(self, arg1, arg2=42):