mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-17 11:29:58 +08:00
Allow users to add a custom API route (#10332)
* changes * add changeset * changes * add changeset * changes * changes * changes * chagnes * changes --------- Co-authored-by: Ali Abid <aliabid94@gmail.com> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
decb594455
commit
e742dcccb3
8
.changeset/sour-apples-boil.md
Normal file
8
.changeset/sour-apples-boil.md
Normal file
@ -0,0 +1,8 @@
|
||||
---
|
||||
"@gradio/client": patch
|
||||
"@gradio/core": patch
|
||||
"gradio": patch
|
||||
"gradio_client": patch
|
||||
---
|
||||
|
||||
fix:Allow users to add a custom API route
|
@ -181,6 +181,7 @@ export function get_type(
|
||||
serializer: string,
|
||||
signature_type: "return" | "parameter"
|
||||
): string | undefined {
|
||||
if (component === "Api") return type.type;
|
||||
switch (type?.type) {
|
||||
case "string":
|
||||
return "string";
|
||||
|
@ -4,6 +4,7 @@ import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
@ -19,7 +20,17 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
TypedDict,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
import fsspec.asyn
|
||||
import httpx
|
||||
@ -994,6 +1005,93 @@ def _json_schema_to_python_type(schema: Any, defs) -> str:
|
||||
raise APIInfoParseError(f"Cannot parse schema {schema}")
|
||||
|
||||
|
||||
def python_type_to_json_schema(type_hint: Any) -> dict:
|
||||
try:
|
||||
return _python_type_to_json_schema(type_hint)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _python_type_to_json_schema(type_hint: Any) -> dict:
|
||||
"""Convert a Python type hint to a JSON schema."""
|
||||
if type_hint is type(None):
|
||||
return {"type": "null"}
|
||||
if type_hint is str:
|
||||
return {"type": "string"}
|
||||
if type_hint is int:
|
||||
return {"type": "integer"}
|
||||
if type_hint is float:
|
||||
return {"type": "number"}
|
||||
if type_hint is bool:
|
||||
return {"type": "boolean"}
|
||||
|
||||
origin = get_origin(type_hint)
|
||||
|
||||
if origin is Literal:
|
||||
literal_values = get_args(type_hint)
|
||||
if len(literal_values) == 1:
|
||||
return {"const": literal_values[0]}
|
||||
return {"enum": list(literal_values)}
|
||||
|
||||
if origin is Union or str(origin) == "|":
|
||||
types = get_args(type_hint)
|
||||
if len(types) == 2 and type(None) in types:
|
||||
other_type = next(t for t in types if t is not type(None))
|
||||
schema = _python_type_to_json_schema(other_type)
|
||||
if "type" in schema:
|
||||
schema["type"] = [schema["type"], "null"]
|
||||
else:
|
||||
schema["oneOf"] = [{"type": "null"}, schema]
|
||||
return schema
|
||||
return {"anyOf": [_python_type_to_json_schema(t) for t in types]}
|
||||
|
||||
if origin is list:
|
||||
item_type = get_args(type_hint)[0]
|
||||
return {"type": "array", "items": _python_type_to_json_schema(item_type)}
|
||||
if origin is tuple:
|
||||
types = get_args(type_hint)
|
||||
return {
|
||||
"type": "array",
|
||||
"prefixItems": [_python_type_to_json_schema(t) for t in types],
|
||||
"minItems": len(types),
|
||||
"maxItems": len(types),
|
||||
}
|
||||
|
||||
if origin is dict:
|
||||
key_type, value_type = get_args(type_hint)
|
||||
if key_type is not str:
|
||||
raise ValueError("JSON Schema only supports string keys in objects")
|
||||
schema = {
|
||||
"type": "object",
|
||||
"additionalProperties": _python_type_to_json_schema(value_type),
|
||||
}
|
||||
return schema
|
||||
|
||||
if inspect.isclass(type_hint) and hasattr(type_hint, "__annotations__"):
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
hints = get_type_hints(type_hint)
|
||||
for field_name, field_type in hints.items():
|
||||
properties[field_name] = _python_type_to_json_schema(field_type)
|
||||
if hasattr(type_hint, "__total__"):
|
||||
if type_hint.__total__:
|
||||
required.append(field_name)
|
||||
elif (
|
||||
not hasattr(type_hint, "__dataclass_fields__")
|
||||
or not type_hint.__dataclass_fields__[field_name].default
|
||||
):
|
||||
required.append(field_name)
|
||||
|
||||
schema = {"type": "object", "properties": properties}
|
||||
if required:
|
||||
schema["required"] = required
|
||||
return schema
|
||||
|
||||
if type_hint is Any:
|
||||
return {}
|
||||
|
||||
|
||||
def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any:
|
||||
"""
|
||||
Traverse a JSON object and apply a function to each element that satisfies the is_root condition.
|
||||
|
@ -75,6 +75,7 @@ from gradio.events import (
|
||||
RetryData,
|
||||
SelectData,
|
||||
UndoData,
|
||||
api,
|
||||
on,
|
||||
)
|
||||
from gradio.exceptions import Error
|
||||
|
@ -779,11 +779,12 @@ class BlocksConfig:
|
||||
if fn is not None and not cancels:
|
||||
check_function_inputs_match(fn, inputs, inputs_as_dict)
|
||||
|
||||
if _targets[0][1] in ["change", "key_up"] and trigger_mode is None:
|
||||
trigger_mode = "always_last"
|
||||
elif _targets[0][1] in ["stream"] and trigger_mode is None:
|
||||
trigger_mode = "multiple"
|
||||
elif trigger_mode is None:
|
||||
if len(_targets) and trigger_mode is None:
|
||||
if _targets[0][1] in ["change", "key_up"]:
|
||||
trigger_mode = "always_last"
|
||||
elif _targets[0][1] in ["stream"]:
|
||||
trigger_mode = "multiple"
|
||||
if trigger_mode is None:
|
||||
trigger_mode = "once"
|
||||
elif trigger_mode not in ["once", "multiple", "always_last"]:
|
||||
raise ValueError(
|
||||
|
@ -48,7 +48,7 @@ INTERFACE_TEMPLATE = '''
|
||||
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
|
||||
inputs: list of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
|
||||
outputs: list of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
|
||||
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
|
||||
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the functions name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
|
||||
scroll_to_output: if True, will scroll to output component on completion
|
||||
show_progress: how to show the progress animation while event is running: "full" shows a spinner which covers the output component area as well as a runtime display in the upper right corner, "minimal" only shows the runtime display, "hidden" shows no progress animation at all
|
||||
queue: if True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
|
||||
|
46
gradio/components/api_component.py
Normal file
46
gradio/components/api_component.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""gr.Api() component."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from gradio.components.base import Component
|
||||
|
||||
|
||||
class Api(Component):
|
||||
"""
|
||||
A generic component that holds any value. Used for generating APIs with no actual frontend component.
|
||||
"""
|
||||
|
||||
EVENTS = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: Any,
|
||||
_api_info: dict[str, str],
|
||||
label: str = "API",
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
value: default value.
|
||||
"""
|
||||
self._api_info = _api_info
|
||||
super().__init__(value=value, label=label)
|
||||
|
||||
def preprocess(self, payload: Any) -> Any:
|
||||
return payload
|
||||
|
||||
def postprocess(self, value: Any) -> Any:
|
||||
return value
|
||||
|
||||
def api_info(self) -> dict[str, str]:
|
||||
return self._api_info
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return self.value if self.value is not None else "..."
|
||||
|
||||
def example_value(self) -> Any:
|
||||
return self.value if self.value is not None else "..."
|
||||
|
||||
# def get_block_name(self) -> str:
|
||||
# return "state" # so that it does not render in the frontend, just like state
|
126
gradio/events.py
126
gradio/events.py
@ -23,8 +23,10 @@ if TYPE_CHECKING:
|
||||
from gradio.blocks import Block, BlockContext, Component
|
||||
from gradio.components import Timer
|
||||
|
||||
from gradio_client.utils import python_type_to_json_schema
|
||||
|
||||
from gradio.context import get_blocks_context
|
||||
from gradio.utils import get_cancelled_fn_indices
|
||||
from gradio.utils import get_cancelled_fn_indices, get_function_params, get_return_types
|
||||
|
||||
|
||||
def set_cancel_events(
|
||||
@ -760,7 +762,7 @@ def on(
|
||||
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
|
||||
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
|
||||
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
|
||||
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
|
||||
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the functions name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
|
||||
scroll_to_output: If True, will scroll to output component on completion
|
||||
show_progress: how to show the progress animation while event is running: "full" shows a spinner which covers the output component area as well as a runtime display in the upper right corner, "minimal" only shows the runtime display, "hidden" shows no progress animation at all
|
||||
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
|
||||
@ -886,6 +888,126 @@ def on(
|
||||
return Dependency(None, dep.get_config(), dep_index, fn)
|
||||
|
||||
|
||||
@document()
|
||||
def api(
|
||||
fn: Callable | Literal["decorator"] = "decorator",
|
||||
*,
|
||||
api_name: str | None | Literal[False] = None,
|
||||
queue: bool = True,
|
||||
batch: bool = False,
|
||||
max_batch_size: int = 4,
|
||||
concurrency_limit: int | None | Literal["default"] = "default",
|
||||
concurrency_id: str | None = None,
|
||||
show_api: bool = True,
|
||||
time_limit: int | None = None,
|
||||
stream_every: float = 0.5,
|
||||
) -> Dependency:
|
||||
"""
|
||||
Sets up an API endpoint for a generic function that can be called via the gradio client. Derives its type from type-hints in the function signature.
|
||||
|
||||
Parameters:
|
||||
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
|
||||
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the functions name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
|
||||
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
|
||||
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
|
||||
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
|
||||
concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
|
||||
concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
|
||||
show_api: whether to show this event in the "view API" page of the Gradio app, or in the ".view_api()" method of the Gradio clients. Unlike setting api_name to False, setting show_api to False will still allow downstream apps as well as the Clients to use this event. If fn is None, show_api will automatically be set to False.
|
||||
time_limit: The time limit for the function to run. Parameter only used for the `.stream()` event.
|
||||
stream_every: The latency (in seconds) at which stream chunks are sent to the backend. Defaults to 0.5 seconds. Parameter only used for the `.stream()` event.
|
||||
Example:
|
||||
import gradio as gr
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
input = gr.Textbox()
|
||||
button = gr.Button("Submit")
|
||||
output = gr.Textbox()
|
||||
gr.on(
|
||||
triggers=[button.click, input.submit],
|
||||
fn=lambda x: x,
|
||||
inputs=[input],
|
||||
outputs=[output]
|
||||
)
|
||||
demo.launch()
|
||||
"""
|
||||
if fn == "decorator":
|
||||
|
||||
def wrapper(func):
|
||||
api(
|
||||
fn=func,
|
||||
api_name=api_name,
|
||||
queue=queue,
|
||||
batch=batch,
|
||||
max_batch_size=max_batch_size,
|
||||
concurrency_limit=concurrency_limit,
|
||||
concurrency_id=concurrency_id,
|
||||
show_api=show_api,
|
||||
time_limit=time_limit,
|
||||
stream_every=stream_every,
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return Dependency(None, {}, None, wrapper)
|
||||
|
||||
root_block = get_blocks_context()
|
||||
if root_block is None:
|
||||
raise Exception("Cannot call api() outside of a gradio.Blocks context.")
|
||||
|
||||
from gradio.components.api_component import Api
|
||||
|
||||
fn_params = get_function_params(fn)
|
||||
return_types = get_return_types(fn)
|
||||
|
||||
def ordinal(n):
|
||||
return f"{n}{'th' if 10 <= n % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')}"
|
||||
|
||||
if any(param[3] is None for param in fn_params):
|
||||
raise ValueError(
|
||||
"API endpoints must have type hints. Please specify a type hint for all parameters."
|
||||
)
|
||||
inputs = [
|
||||
Api(
|
||||
default_value if has_default else None,
|
||||
python_type_to_json_schema(_type),
|
||||
ordinal(i + 1),
|
||||
)
|
||||
for i, (_, has_default, default_value, _type) in enumerate(fn_params)
|
||||
]
|
||||
outputs = [
|
||||
Api(None, python_type_to_json_schema(type), ordinal(i + 1))
|
||||
for i, type in enumerate(return_types)
|
||||
]
|
||||
|
||||
dep, dep_index = root_block.set_event_trigger(
|
||||
[],
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
preprocess=False,
|
||||
postprocess=False,
|
||||
scroll_to_output=False,
|
||||
show_progress="hidden",
|
||||
api_name=api_name,
|
||||
js=None,
|
||||
concurrency_limit=concurrency_limit,
|
||||
concurrency_id=concurrency_id,
|
||||
queue=queue,
|
||||
batch=batch,
|
||||
max_batch_size=max_batch_size,
|
||||
show_api=show_api,
|
||||
trigger_mode=None,
|
||||
time_limit=time_limit,
|
||||
stream_every=stream_every,
|
||||
)
|
||||
return Dependency(None, dep.get_config(), dep_index, fn)
|
||||
|
||||
|
||||
class Events:
|
||||
change = EventListener(
|
||||
"change",
|
||||
|
@ -42,7 +42,7 @@ from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from types import ModuleType, NoneType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -50,6 +50,8 @@ from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
import anyio
|
||||
@ -1299,9 +1301,9 @@ def get_upload_folder() -> str:
|
||||
)
|
||||
|
||||
|
||||
def get_function_params(func: Callable) -> list[tuple[str, bool, Any]]:
|
||||
def get_function_params(func: Callable) -> list[tuple[str, bool, Any, Any]]:
|
||||
"""
|
||||
Gets the parameters of a function as a list of tuples of the form (name, has_default, default_value).
|
||||
Gets the parameters of a function as a list of tuples of the form (name, has_default, default_value, type_hint).
|
||||
Excludes *args and **kwargs, as well as args that are Gradio-specific, such as gr.Request, gr.EventData, gr.OAuthProfile, and gr.OAuthToken.
|
||||
"""
|
||||
params_info = []
|
||||
@ -1316,12 +1318,26 @@ def get_function_params(func: Callable) -> list[tuple[str, bool, Any]]:
|
||||
if is_special_typed_parameter(name, type_hints):
|
||||
continue
|
||||
if parameter.default is inspect.Parameter.empty:
|
||||
params_info.append((name, False, None))
|
||||
params_info.append((name, False, None, type_hints.get(name, None)))
|
||||
else:
|
||||
params_info.append((name, True, parameter.default))
|
||||
params_info.append(
|
||||
(name, True, parameter.default, type_hints.get(name, None))
|
||||
)
|
||||
return params_info
|
||||
|
||||
|
||||
def get_return_types(func: Callable) -> list:
|
||||
return_hint = inspect.signature(func).return_annotation
|
||||
|
||||
if return_hint in {inspect.Signature.empty, None, NoneType}:
|
||||
return []
|
||||
|
||||
if get_origin(return_hint) == tuple:
|
||||
return list(get_args(return_hint))
|
||||
|
||||
return [return_hint]
|
||||
|
||||
|
||||
def simplify_file_data_in_str(s):
|
||||
"""
|
||||
If a FileData dictionary has been dumped as part of a string, this function will replace the dict with just the str filepath
|
||||
|
@ -368,7 +368,7 @@
|
||||
.endpoint-container {
|
||||
margin-top: var(--size-3);
|
||||
margin-bottom: var(--size-3);
|
||||
border: 1px solid var(--body-text-color);
|
||||
border: 1px solid var(--block-border-color);
|
||||
border-radius: var(--radius-xl);
|
||||
padding: var(--size-3);
|
||||
padding-top: 0;
|
||||
|
@ -9,7 +9,7 @@
|
||||
let bash_install = "curl --version";
|
||||
</script>
|
||||
|
||||
<Block border_mode="contrast">
|
||||
<Block>
|
||||
<code>
|
||||
{#if current_language === "python"}
|
||||
<div class="copy">
|
||||
|
@ -638,6 +638,7 @@ export function get_component(
|
||||
} {
|
||||
let example_component_map: Map<ComponentMeta["type"], LoadingComponent> =
|
||||
new Map();
|
||||
if (type === "api") type = "state";
|
||||
if (type === "dataset" && example_components) {
|
||||
(example_components as string[]).forEach((name: string) => {
|
||||
if (example_component_map.has(name)) {
|
||||
|
@ -12009,7 +12009,7 @@ snapshots:
|
||||
|
||||
code-red@1.0.4:
|
||||
dependencies:
|
||||
'@jridgewell/sourcemap-codec': 1.4.15
|
||||
'@jridgewell/sourcemap-codec': 1.5.0
|
||||
'@types/estree': 1.0.5
|
||||
acorn: 8.11.3
|
||||
estree-walker: 3.0.3
|
||||
|
@ -694,6 +694,18 @@ class TestRoutes:
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_api_listener(connect):
|
||||
with gr.Blocks() as demo:
|
||||
|
||||
def fn(a: int, b: int, c: str) -> tuple[int, str]:
|
||||
return a + b, c[a:b]
|
||||
|
||||
gr.api(fn, api_name="addition")
|
||||
|
||||
with connect(demo) as client:
|
||||
assert client.predict(a=1, b=3, c="testing", api_name="/addition") == (4, "es")
|
||||
|
||||
|
||||
class TestApp:
|
||||
def test_create_app(self):
|
||||
app = routes.App.create_app(Interface(lambda x: x, "text", "text"))
|
||||
|
@ -533,14 +533,14 @@ def test_diff(old, new, expected_diff):
|
||||
|
||||
class TestFunctionParams:
|
||||
def test_regular_function(self):
|
||||
def func(a, b=10, c="default", d=None):
|
||||
def func(a: int, b: int = 10, c: str = "default", d=None):
|
||||
pass
|
||||
|
||||
assert get_function_params(func) == [
|
||||
("a", False, None),
|
||||
("b", True, 10),
|
||||
("c", True, "default"),
|
||||
("d", True, None),
|
||||
("a", False, None, int),
|
||||
("b", True, 10, int),
|
||||
("c", True, "default", str),
|
||||
("d", True, None, None),
|
||||
]
|
||||
|
||||
def test_function_no_params(self):
|
||||
@ -551,32 +551,38 @@ class TestFunctionParams:
|
||||
|
||||
def test_lambda_function(self):
|
||||
assert get_function_params(lambda x, y: x + y) == [
|
||||
("x", False, None),
|
||||
("y", False, None),
|
||||
("x", False, None, None),
|
||||
("y", False, None, None),
|
||||
]
|
||||
|
||||
def test_function_with_args(self):
|
||||
def func(a, *args):
|
||||
pass
|
||||
|
||||
assert get_function_params(func) == [("a", False, None)]
|
||||
assert get_function_params(func) == [("a", False, None, None)]
|
||||
|
||||
def test_function_with_kwargs(self):
|
||||
def func(a, **kwargs):
|
||||
pass
|
||||
|
||||
assert get_function_params(func) == [("a", False, None)]
|
||||
assert get_function_params(func) == [("a", False, None, None)]
|
||||
|
||||
def test_function_with_special_args(self):
|
||||
def func(a, r: Request, b=10):
|
||||
pass
|
||||
|
||||
assert get_function_params(func) == [("a", False, None), ("b", True, 10)]
|
||||
assert get_function_params(func) == [
|
||||
("a", False, None, None),
|
||||
("b", True, 10, None),
|
||||
]
|
||||
|
||||
def func2(a, r: Request | None = None, b="abc"):
|
||||
pass
|
||||
|
||||
assert get_function_params(func2) == [("a", False, None), ("b", True, "abc")]
|
||||
assert get_function_params(func2) == [
|
||||
("a", False, None, None),
|
||||
("b", True, "abc", None),
|
||||
]
|
||||
|
||||
def test_class_method_skip_first_param(self):
|
||||
class MyClass:
|
||||
@ -584,8 +590,8 @@ class TestFunctionParams:
|
||||
pass
|
||||
|
||||
assert get_function_params(MyClass().method) == [
|
||||
("arg1", False, None),
|
||||
("arg2", True, 42),
|
||||
("arg1", False, None, None),
|
||||
("arg2", True, 42, None),
|
||||
]
|
||||
|
||||
def test_static_method_no_skip(self):
|
||||
@ -595,8 +601,8 @@ class TestFunctionParams:
|
||||
pass
|
||||
|
||||
assert get_function_params(MyClass.method) == [
|
||||
("arg1", False, None),
|
||||
("arg2", True, 42),
|
||||
("arg1", False, None, None),
|
||||
("arg2", True, 42, None),
|
||||
]
|
||||
|
||||
def test_class_method_with_args(self):
|
||||
@ -604,13 +610,13 @@ class TestFunctionParams:
|
||||
def method(self, a, *args, b=42):
|
||||
pass
|
||||
|
||||
assert get_function_params(MyClass().method) == [("a", False, None)]
|
||||
assert get_function_params(MyClass().method) == [("a", False, None, None)]
|
||||
|
||||
def test_lambda_with_args(self):
|
||||
assert get_function_params(lambda x, *args: x) == [("x", False, None)]
|
||||
assert get_function_params(lambda x, *args: x) == [("x", False, None, None)]
|
||||
|
||||
def test_lambda_with_kwargs(self):
|
||||
assert get_function_params(lambda x, **kwargs: x) == [("x", False, None)]
|
||||
assert get_function_params(lambda x, **kwargs: x) == [("x", False, None, None)]
|
||||
|
||||
|
||||
def test_parse_file_size():
|
||||
|
Loading…
Reference in New Issue
Block a user