diff --git a/.changeset/fair-items-sort.md b/.changeset/fair-items-sort.md new file mode 100644 index 0000000000..45b7c0ad29 --- /dev/null +++ b/.changeset/fair-items-sort.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Handle special arguments when extracting parameter names for view API page diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index c57140e60c..ff45f82643 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -4,6 +4,7 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat from __future__ import annotations +import functools import inspect from typing import AsyncGenerator, Callable, Literal, Union, cast @@ -448,7 +449,36 @@ class ChatInterface(Blocks): ) def _setup_api(self) -> None: - api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn + if self.is_generator: + + @functools.wraps(self.fn) + async def api_fn(message, history, *args, **kwargs): # type: ignore + if self.is_async: + generator = self.fn(message, history, *args, **kwargs) + else: + generator = await anyio.to_thread.run_sync( + self.fn, message, history, *args, **kwargs, limiter=self.limiter + ) + generator = SyncToAsyncIterator(generator, self.limiter) + try: + first_response = await async_iteration(generator) + yield first_response, history + [[message, first_response]] + except StopIteration: + yield None, history + [[message, None]] + async for response in generator: + yield response, history + [[message, response]] + else: + + @functools.wraps(self.fn) + async def api_fn(message, history, *args, **kwargs): + if self.is_async: + response = await self.fn(message, history, *args, **kwargs) + else: + response = await anyio.to_thread.run_sync( + self.fn, message, history, *args, **kwargs, limiter=self.limiter + ) + history.append([message, response]) + return response, history self.fake_api_btn.click( api_fn, @@ -575,44 +605,6 @@ class ChatInterface(Blocks): update = history + [[message, response]] yield update, update - async def _api_submit_fn( - self, message: str, history: list[list[str | None]], request: Request, *args - ) -> tuple[str, list[list[str | None]]]: - inputs, _, _ = special_args( - self.fn, inputs=[message, history, *args], request=request - ) - - if self.is_async: - response = await self.fn(*inputs) - else: - response = await anyio.to_thread.run_sync( - self.fn, *inputs, limiter=self.limiter - ) - history.append([message, response]) - return response, history - - async def _api_stream_fn( - self, message: str, history: list[list[str | None]], request: Request, *args - ) -> AsyncGenerator: - inputs, _, _ = special_args( - self.fn, inputs=[message, history, *args], request=request - ) - - if self.is_async: - generator = self.fn(*inputs) - else: - generator = await anyio.to_thread.run_sync( - self.fn, *inputs, limiter=self.limiter - ) - generator = SyncToAsyncIterator(generator, self.limiter) - try: - first_response = await async_iteration(generator) - yield first_response, history + [[message, first_response]] - except StopIteration: - yield None, history + [[message, None]] - async for response in generator: - yield response, history + [[message, response]] - async def _examples_fn(self, message: str, *args) -> list[list[str | None]]: inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None) diff --git a/gradio/utils.py b/gradio/utils.py index df6cc313cf..6553bd08ab 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -1302,14 +1302,21 @@ def get_upload_folder() -> str: def get_function_params(func: Callable) -> list[tuple[str, bool, Any]]: + """ + Gets the parameters of a function as a list of tuples of the form (name, has_default, default_value). + Excludes *args and **kwargs, as well as args that are Gradio-specific, such as gr.Request, gr.EventData, gr.OAuthProfile, and gr.OAuthToken. + """ params_info = [] signature = inspect.signature(func) + type_hints = get_type_hints(func) for name, parameter in signature.parameters.items(): if parameter.kind in ( inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD, ): break + if is_special_typed_parameter(name, type_hints): + continue if parameter.default is inspect.Parameter.empty: params_info.append((name, False, None)) else: diff --git a/test/test_utils.py b/test/test_utils.py index 2e7b4d91ed..f297818e91 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -495,6 +495,12 @@ class TestFunctionParams: assert get_function_params(func) == [("a", False, None)] + def test_function_with_special_args(self): + def func(a, r: Request, b=10): + pass + + assert get_function_params(func) == [("a", False, None), ("b", True, 10)] + def test_class_method_skip_first_param(self): class MyClass: def method(self, arg1, arg2=42):