mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Adding ability to access in Fastapi request object into your function (#2641)
* fastapi requests * formatting * implement * fix * formatting * formatting * changelog * added demo * remove print * added to guide * changes for queuing * changes to gr.Request * formatting * formatting * fixes * lint * fixed tests * fix batching * fixing tests * cleanup * lint * added tests; fixed review * improve docs
This commit is contained in:
parent
dbfa4dced1
commit
5e148c3752
22
CHANGELOG.md
22
CHANGELOG.md
@ -1,7 +1,23 @@
|
||||
# Upcoming Release
|
||||
|
||||
## New Features:
|
||||
No changes to highlight.
|
||||
|
||||
### Accessing the Requests Object Directly
|
||||
|
||||
You can now access the Request object directly in your Python function by [@abidlabs](https://github.com/abidlabs) in [PR 2641](https://github.com/gradio-app/gradio/pull/2641). This means that you can access request headers, the client IP address, and so on. In order to use it, add a parameter to your function and set its type hint to be `gr.Request`. Here's a simple example:
|
||||
|
||||
```py
|
||||
import gradio as gr
|
||||
|
||||
def echo(name, request: gr.Request):
|
||||
if request:
|
||||
print("Request headers dictionary:", request.headers)
|
||||
print("IP address:", request.client.host)
|
||||
return name
|
||||
|
||||
io = gr.Interface(echo, "textbox", "textbox").launch()
|
||||
```
|
||||
|
||||
|
||||
## Bug Fixes:
|
||||
No changes to highlight.
|
||||
@ -46,8 +62,6 @@ No changes to highlight.
|
||||
|
||||
|
||||
# 3.10.0
|
||||
|
||||
## New Features:
|
||||
* Add support for `'password'` and `'email'` types to `Textbox`. [@pngwn](https://github.com/pngwn) in [PR 2653](https://github.com/gradio-app/gradio/pull/2653)
|
||||
* `gr.Textbox` component will now raise an exception if `type` is not "text", "email", or "password" [@pngwn](https://github.com/pngwn) in [PR 2653](https://github.com/gradio-app/gradio/pull/2653). This will cause demos using the deprecated `gr.Textbox(type="number")` to raise an exception.
|
||||
|
||||
@ -69,7 +83,7 @@ No changes to highlight.
|
||||
No changes to highlight.
|
||||
|
||||
## Full Changelog:
|
||||
No changes to highlight.
|
||||
* Add support for `'password'` and `'email'` types to `Textbox`. [@pngwn](https://github.com/pngwn) in [PR 2653](https://github.com/gradio-app/gradio/pull/2653)
|
||||
|
||||
## Contributors Shoutout:
|
||||
No changes to highlight.
|
||||
|
@ -11,7 +11,14 @@ with gr.Blocks() as demo:
|
||||
with gr.Tab("Interface"):
|
||||
gr.Interface(lambda x:x, "audio", "audio", examples=[audio_file])
|
||||
with gr.Tab("console"):
|
||||
ip = gr.Textbox(label="User IP Address")
|
||||
gr.Interface(lambda cmd:subprocess.run([cmd], capture_output=True, shell=True).stdout.decode('utf-8').strip(), "text", "text")
|
||||
|
||||
def get_ip(request: gr.Request):
|
||||
return request.client.host
|
||||
|
||||
demo.load(get_ip, None, ip)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.queue()
|
||||
demo.launch()
|
||||
|
@ -59,7 +59,7 @@ from gradio.interface import Interface, TabbedInterface, close_all
|
||||
from gradio.ipython_ext import load_ipython_extension
|
||||
from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
|
||||
from gradio.mix import Parallel, Series
|
||||
from gradio.routes import mount_gradio_app
|
||||
from gradio.routes import Request, mount_gradio_app
|
||||
from gradio.templates import (
|
||||
Files,
|
||||
ImageMask,
|
||||
|
@ -9,6 +9,7 @@ import pkgutil
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
import warnings
|
||||
import webbrowser
|
||||
from types import ModuleType
|
||||
@ -58,6 +59,7 @@ from gradio.utils import (
|
||||
|
||||
set_documentation_group("blocks")
|
||||
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import comet_ml
|
||||
import mlflow
|
||||
@ -457,6 +459,23 @@ def convert_component_dict_to_list(outputs_ids: List[int], predictions: Dict) ->
|
||||
return predictions
|
||||
|
||||
|
||||
def add_request_to_inputs(
|
||||
fn: Callable, inputs: List[Any], request: routes.Request | List[routes.Request]
|
||||
):
|
||||
"""
|
||||
Adds the FastAPI Request object to the inputs of a function if the type of the parameter is FastAPI.Request.
|
||||
"""
|
||||
param_names = inspect.getfullargspec(fn)[0]
|
||||
try:
|
||||
parameter_types = typing.get_type_hints(fn)
|
||||
for idx, param_name in enumerate(param_names):
|
||||
if parameter_types.get(param_name, "") == routes.Request:
|
||||
inputs.insert(idx, request)
|
||||
except TypeError: # A TypeError is raised if the function is a partial or other rare cases.
|
||||
pass
|
||||
return inputs
|
||||
|
||||
|
||||
@document("load")
|
||||
class Blocks(BlockContext):
|
||||
"""
|
||||
@ -796,7 +815,12 @@ class Blocks(BlockContext):
|
||||
if batch:
|
||||
processed_inputs = [[inp] for inp in processed_inputs]
|
||||
|
||||
outputs = utils.synchronize_async(self.process_api, fn_index, processed_inputs)
|
||||
outputs = utils.synchronize_async(
|
||||
self.process_api,
|
||||
fn_index=fn_index,
|
||||
inputs=processed_inputs,
|
||||
request=None,
|
||||
)
|
||||
outputs = outputs["data"]
|
||||
|
||||
if batch:
|
||||
@ -812,11 +836,11 @@ class Blocks(BlockContext):
|
||||
fn_index: int,
|
||||
processed_input: List[Any],
|
||||
iterator: Iterator[Any] | None = None,
|
||||
request: routes.Request | List[routes.Request] | None = None,
|
||||
):
|
||||
"""Calls and times function with given index and preprocessed input."""
|
||||
block_fn = self.fns[fn_index]
|
||||
is_generating = False
|
||||
start = time.time()
|
||||
|
||||
if block_fn.inputs_as_dict:
|
||||
processed_input = [
|
||||
@ -826,6 +850,12 @@ class Blocks(BlockContext):
|
||||
}
|
||||
]
|
||||
|
||||
processed_input = add_request_to_inputs(
|
||||
block_fn.fn, list(processed_input), request
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
|
||||
if iterator is None: # If not a generator function that has already run
|
||||
if inspect.iscoroutinefunction(block_fn.fn):
|
||||
prediction = await block_fn.fn(*processed_input)
|
||||
@ -944,6 +974,7 @@ class Blocks(BlockContext):
|
||||
self,
|
||||
fn_index: int,
|
||||
inputs: List[Any],
|
||||
request: routes.Request | List[routes.Request] | None = None,
|
||||
username: str = None,
|
||||
state: Dict[int, Any] | List[Dict[int, Any]] | None = None,
|
||||
iterators: Dict[int, Any] | None = None,
|
||||
@ -980,7 +1011,7 @@ class Blocks(BlockContext):
|
||||
)
|
||||
|
||||
inputs = [self.preprocess_data(fn_index, i, state) for i in zip(*inputs)]
|
||||
result = await self.call_function(fn_index, zip(*inputs), None)
|
||||
result = await self.call_function(fn_index, zip(*inputs), None, request)
|
||||
preds = result["prediction"]
|
||||
data = [self.postprocess_data(fn_index, o, state) for o in zip(*preds)]
|
||||
data = list(zip(*data))
|
||||
@ -988,7 +1019,7 @@ class Blocks(BlockContext):
|
||||
else:
|
||||
inputs = self.preprocess_data(fn_index, inputs, state)
|
||||
iterator = iterators.get(fn_index, None) if iterators else None
|
||||
result = await self.call_function(fn_index, inputs, iterator)
|
||||
result = await self.call_function(fn_index, inputs, iterator, request)
|
||||
data = self.postprocess_data(fn_index, result["prediction"], state)
|
||||
is_generating, iterator = result["is_generating"], result["iterator"]
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -10,6 +10,9 @@ class PredictBody(BaseModel):
|
||||
batched: Optional[
|
||||
bool
|
||||
] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
|
||||
request: Optional[
|
||||
Union[Dict, List[Dict]]
|
||||
] = None # dictionary of request headers, query parameters, url, etc. (used to to pass in request for queuing)
|
||||
|
||||
|
||||
class ResetBody(BaseModel):
|
||||
|
@ -290,7 +290,7 @@ class Examples:
|
||||
if self.batch:
|
||||
processed_input = [[value] for value in processed_input]
|
||||
prediction = await Context.root_block.process_api(
|
||||
fn_index, processed_input
|
||||
fn_index=fn_index, inputs=processed_input, request=None
|
||||
)
|
||||
output = prediction["data"]
|
||||
if self.batch:
|
||||
|
@ -5,14 +5,13 @@ import copy
|
||||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from itertools import islice
|
||||
from typing import Deque, Dict, List, Optional, Tuple
|
||||
from typing import Any, Deque, Dict, List, Optional, Tuple
|
||||
|
||||
import fastapi
|
||||
from pydantic import BaseModel
|
||||
|
||||
from gradio.dataclasses import PredictBody
|
||||
from gradio.utils import Request, run_coro_in_background, set_task_name
|
||||
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
|
||||
|
||||
|
||||
class Estimation(BaseModel):
|
||||
@ -26,7 +25,11 @@ class Estimation(BaseModel):
|
||||
|
||||
|
||||
class Event:
|
||||
def __init__(self, websocket: fastapi.WebSocket, fn_index: int | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
websocket: fastapi.WebSocket,
|
||||
fn_index: int | None = None,
|
||||
):
|
||||
self.websocket = websocket
|
||||
self.data: PredictBody | None = None
|
||||
self.lost_connection_time: float | None = None
|
||||
@ -157,18 +160,6 @@ class Queue:
|
||||
if self.live_updates:
|
||||
await self.broadcast_estimations()
|
||||
|
||||
async def gather_data_for_first_ranks(self) -> None:
|
||||
"""
|
||||
Gather data for the first x events.
|
||||
"""
|
||||
# Send all messages concurrently
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self.gather_event_data(event)
|
||||
for event in islice(self.event_queue, self.data_gathering_start)
|
||||
]
|
||||
)
|
||||
|
||||
async def gather_event_data(self, event: Event) -> bool:
|
||||
"""
|
||||
Gather data for the event
|
||||
@ -253,14 +244,34 @@ class Queue:
|
||||
queue_eta=self.queue_duration,
|
||||
)
|
||||
|
||||
def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": str(websocket.url),
|
||||
"headers": dict(websocket.headers),
|
||||
"query_params": dict(websocket.query_params),
|
||||
"path_params": dict(websocket.path_params),
|
||||
"client": dict(host=websocket.client.host, port=websocket.client.port),
|
||||
}
|
||||
|
||||
async def call_prediction(self, events: List[Event], batch: bool):
|
||||
data = events[0].data
|
||||
token = events[0].token
|
||||
try:
|
||||
data.request = self.get_request_params(events[0].websocket)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if batch:
|
||||
data.data = list(zip(*[event.data.data for event in events if event.data]))
|
||||
data.request = [
|
||||
self.get_request_params(event.websocket)
|
||||
for event in events
|
||||
if event.data
|
||||
]
|
||||
data.batched = True
|
||||
response = await Request(
|
||||
method=Request.Method.POST,
|
||||
|
||||
response = await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url=f"{self.server_path}api/predict",
|
||||
json=dict(data),
|
||||
headers={"Authorization": f"Bearer {self.access_token}"},
|
||||
@ -370,8 +381,8 @@ class Queue:
|
||||
return None
|
||||
|
||||
async def reset_iterators(self, session_hash: str, fn_index: int):
|
||||
await Request(
|
||||
method=Request.Method.POST,
|
||||
await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url=f"{self.server_path}reset",
|
||||
json={
|
||||
"session_hash": session_hash,
|
||||
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import posixpath
|
||||
@ -13,20 +14,20 @@ import traceback
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
import orjson
|
||||
import pkg_resources
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi import Depends, FastAPI, HTTPException, WebSocket, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from jinja2.exceptions import TemplateNotFound
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.websockets import WebSocket, WebSocketState
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
import gradio
|
||||
from gradio import encryptor, utils
|
||||
@ -108,7 +109,7 @@ class App(FastAPI):
|
||||
|
||||
@app.get("/user")
|
||||
@app.get("/user/")
|
||||
def get_current_user(request: Request) -> Optional[str]:
|
||||
def get_current_user(request: fastapi.Request) -> Optional[str]:
|
||||
token = request.cookies.get("access-token")
|
||||
return app.tokens.get(token)
|
||||
|
||||
@ -127,13 +128,13 @@ class App(FastAPI):
|
||||
|
||||
@app.get("/token")
|
||||
@app.get("/token/")
|
||||
def get_token(request: Request) -> dict:
|
||||
def get_token(request: fastapi.Request) -> dict:
|
||||
token = request.cookies.get("access-token")
|
||||
return {"token": token, "user": app.tokens.get(token)}
|
||||
|
||||
@app.get("/app_id")
|
||||
@app.get("/app_id/")
|
||||
def app_id(request: Request) -> int:
|
||||
def app_id(request: fastapi.Request) -> int:
|
||||
return {"app_id": app.blocks.app_id}
|
||||
|
||||
@app.post("/login")
|
||||
@ -159,7 +160,7 @@ class App(FastAPI):
|
||||
|
||||
@app.head("/", response_class=HTMLResponse)
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def main(request: Request, user: str = Depends(get_current_user)):
|
||||
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
|
||||
if app.auth is None or not (user is None):
|
||||
@ -258,7 +259,9 @@ class App(FastAPI):
|
||||
return {"success": True}
|
||||
|
||||
async def run_predict(
|
||||
body: PredictBody, username: str = Depends(get_current_user)
|
||||
body: PredictBody,
|
||||
request: Request,
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
if hasattr(body, "session_hash"):
|
||||
if body.session_hash not in app.state_holder:
|
||||
@ -287,7 +290,12 @@ class App(FastAPI):
|
||||
raw_input = [raw_input]
|
||||
try:
|
||||
output = await app.blocks.process_api(
|
||||
fn_index, raw_input, username, session_state, iterators
|
||||
fn_index=fn_index,
|
||||
inputs=raw_input,
|
||||
request=request,
|
||||
username=username,
|
||||
state=session_state,
|
||||
iterators=iterators,
|
||||
)
|
||||
iterator = output.pop("iterator", None)
|
||||
if hasattr(body, "session_hash"):
|
||||
@ -317,7 +325,7 @@ class App(FastAPI):
|
||||
async def predict(
|
||||
api_name: str,
|
||||
body: PredictBody,
|
||||
request: Request,
|
||||
request: fastapi.Request,
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
if body.fn_index is None:
|
||||
@ -345,12 +353,20 @@ class App(FastAPI):
|
||||
# current session hash
|
||||
if app.blocks.dependencies[body.fn_index]["cancels"]:
|
||||
body.data = [body.session_hash]
|
||||
result = await run_predict(body=body, username=username)
|
||||
if body.request:
|
||||
if body.batched:
|
||||
request = [Request(**req) for req in body.request]
|
||||
else:
|
||||
request = Request(**body.request)
|
||||
else:
|
||||
request = Request(request)
|
||||
result = await run_predict(body=body, username=username, request=request)
|
||||
return result
|
||||
|
||||
@app.websocket("/queue/join")
|
||||
async def join_queue(
|
||||
websocket: WebSocket, token: str = Depends(ws_login_check)
|
||||
websocket: WebSocket,
|
||||
token: str = Depends(ws_login_check),
|
||||
):
|
||||
if app.auth is not None and token is None:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
@ -462,6 +478,64 @@ def get_server_url_from_ws_url(ws_url: str):
|
||||
set_documentation_group("routes")
|
||||
|
||||
|
||||
class Obj:
|
||||
"""
|
||||
Using a class to convert dictionaries into objects. Used by the `Request` class.
|
||||
Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/
|
||||
"""
|
||||
|
||||
def __init__(self, dict1):
|
||||
self.__dict__.update(dict1)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.__dict__)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self.__dict__)
|
||||
|
||||
|
||||
@document()
|
||||
class Request:
|
||||
"""
|
||||
A Gradio request object that can be used to access the request headers, cookies,
|
||||
query parameters and other information about the request from within the prediction
|
||||
function. The class is a thin wrapper around the fastapi.Request class. Attributes
|
||||
of this class include: `headers`, `client`, `query_params`, and `path_params`,
|
||||
|
||||
Example:
|
||||
import gradio as gr
|
||||
def echo(name, request: gr.Request):
|
||||
print("Request headers dictionary:", request.headers)
|
||||
print("IP address:", request.client.host)
|
||||
return name
|
||||
io = gr.Interface(echo, "textbox", "textbox").launch()
|
||||
"""
|
||||
|
||||
def __init__(self, request: fastapi.Request | None = None, **kwargs):
|
||||
"""
|
||||
Can be instantiated with either a fastapi.Request or by manually passing in
|
||||
attributes (needed for websocket-based queueing).
|
||||
"""
|
||||
self.request: fastapi.Request = request
|
||||
self.kwargs: Dict = kwargs
|
||||
|
||||
def dict_to_obj(self, d):
|
||||
if isinstance(d, dict):
|
||||
return json.loads(json.dumps(d), object_hook=Obj)
|
||||
else:
|
||||
return d
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self.request:
|
||||
return self.dict_to_obj(getattr(self.request, name))
|
||||
else:
|
||||
try:
|
||||
obj = self.kwargs[name]
|
||||
except KeyError:
|
||||
raise AttributeError(f"'Request' object has no attribute '{name}'")
|
||||
return self.dict_to_obj(obj)
|
||||
|
||||
|
||||
@document()
|
||||
def mount_gradio_app(
|
||||
app: fastapi.FastAPI,
|
||||
|
@ -12,6 +12,7 @@ import pkgutil
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import StrictVersion
|
||||
@ -395,19 +396,19 @@ def async_iteration(iterator):
|
||||
raise StopAsyncIteration()
|
||||
|
||||
|
||||
class Request:
|
||||
class AsyncRequest:
|
||||
"""
|
||||
The Request class is a low-level API that allow you to create asynchronous HTTP requests without a context manager.
|
||||
Compared to making calls by using httpx directly, Request offers more flexibility and control over:
|
||||
The AsyncRequest class is a low-level API that allow you to create asynchronous HTTP requests without a context manager.
|
||||
Compared to making calls by using httpx directly, AsyncRequest offers more flexibility and control over:
|
||||
(1) Includes response validation functionality both using validation models and functions.
|
||||
(2) Since we're still using httpx.Request class by wrapping it, we have all it's functionalities.
|
||||
(3) Exceptions are handled silently during the request call, which gives us the ability to inspect each one
|
||||
individually in the case of multiple asynchronous request calls and some of them failing.
|
||||
(4) Provides HTTP request types with Request.Method Enum class for ease of usage
|
||||
Request also offers some util functions such as has_exception, is_valid and status to inspect get detailed
|
||||
(4) Provides HTTP request types with AsyncRequest.Method Enum class for ease of usage
|
||||
AsyncRequest also offers some util functions such as has_exception, is_valid and status to inspect get detailed
|
||||
information about executed request call.
|
||||
|
||||
The basic usage of Request is as follows: create a Request object with inputs(method, url etc.). Then use it
|
||||
The basic usage of AsyncRequest is as follows: create a AsyncRequest object with inputs(method, url etc.). Then use it
|
||||
with the "await" statement, and then you can use util functions to do some post request checks depending on your use-case.
|
||||
Finally, call the get_validated_data function to get the response data.
|
||||
|
||||
@ -466,13 +467,13 @@ class Request:
|
||||
# Create request
|
||||
self._request = self._create_request(method, url, **kwargs)
|
||||
|
||||
def __await__(self) -> Generator[None, Any, "Request"]:
|
||||
def __await__(self) -> Generator[None, Any, "AsyncRequest"]:
|
||||
"""
|
||||
Wrap Request's __await__ magic function to create request calls which are executed in one line.
|
||||
"""
|
||||
return self.__run().__await__()
|
||||
|
||||
async def __run(self) -> Request:
|
||||
async def __run(self) -> AsyncRequest:
|
||||
"""
|
||||
Manage the request call lifecycle.
|
||||
Execute the request by sending it through the client, then check its status.
|
||||
@ -486,7 +487,9 @@ class Request:
|
||||
"""
|
||||
try:
|
||||
# Send the request and get the response.
|
||||
self._response: httpx.Response = await Request.client.send(self._request)
|
||||
self._response: httpx.Response = await AsyncRequest.client.send(
|
||||
self._request
|
||||
)
|
||||
# Raise for _status
|
||||
self._status = self._response.status_code
|
||||
if self._raise_for_status:
|
||||
@ -503,7 +506,7 @@ class Request:
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _create_request(method: Method, url: str, **kwargs) -> Request:
|
||||
def _create_request(method: Method, url: str, **kwargs) -> AsyncRequest:
|
||||
"""
|
||||
Create a request. This is a httpx request wrapper function.
|
||||
Args:
|
||||
@ -743,15 +746,24 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool
|
||||
Checks if the input component set matches the function
|
||||
Returns: None if valid, a string error message if mismatch
|
||||
"""
|
||||
|
||||
def is_special_typed_parameter(name):
|
||||
from gradio.routes import Request
|
||||
|
||||
"""Checks if parameter has a type hint designating it as a gr.Request"""
|
||||
return parameter_types.get(name, "") == Request
|
||||
|
||||
signature = inspect.signature(fn)
|
||||
parameter_types = typing.get_type_hints(fn) if inspect.isfunction(fn) else {}
|
||||
min_args = 0
|
||||
max_args = 0
|
||||
for param in signature.parameters.values():
|
||||
for name, param in signature.parameters.items():
|
||||
has_default = param.default != param.empty
|
||||
if param.kind in [param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD]:
|
||||
if not has_default:
|
||||
min_args += 1
|
||||
max_args += 1
|
||||
if not (is_special_typed_parameter(name)):
|
||||
if not has_default:
|
||||
min_args += 1
|
||||
max_args += 1
|
||||
elif param.kind == param.VAR_POSITIONAL:
|
||||
max_args = "infinity"
|
||||
elif param.kind == param.KEYWORD_ONLY:
|
||||
|
@ -106,6 +106,27 @@ def same_auth(username, password):
|
||||
demo.launch(auth=same_auth)
|
||||
```
|
||||
|
||||
## Accessing the Network Request Directly
|
||||
|
||||
When a user makes a prediction to your app, you may need the underlying network request, in order to get the request headers (e.g. for advanced authentication), log the client's IP address, or for other reasons. Gradio supports this in a similar manner to FastAPI: simply add a function parameter whose type hint is `gr.Request` and Gradio will pass in the network request as that parameter. Here is an example:
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
|
||||
def echo(name, request: gr.Request):
|
||||
if request:
|
||||
print("Request headers dictionary:", request.headers)
|
||||
print("IP address:", request.client.host)
|
||||
return name
|
||||
|
||||
io = gr.Interface(echo, "textbox", "textbox").launch()
|
||||
```
|
||||
|
||||
Note: if your function is called directly instead of through the UI (this happens, for
|
||||
example, when examples are cached), then `request` will be `None`. You should handle
|
||||
this case explicitly to ensure that your app does not throw any errors. That is why
|
||||
we have the explicit check `if request`.
|
||||
|
||||
## Mounting Within Another FastAPI App
|
||||
|
||||
In some cases, you might have an existing FastAPI app, and you'd like to add a path for a Gradio demo.
|
||||
|
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
@ -945,6 +946,70 @@ class TestEvery:
|
||||
break
|
||||
|
||||
|
||||
class TestAddRequests:
|
||||
def test_no_type_hints(self):
|
||||
def moo(a, b):
|
||||
return a + b
|
||||
|
||||
inputs = [1, 2]
|
||||
request = gr.Request()
|
||||
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
|
||||
assert inputs_ == inputs
|
||||
|
||||
boo = partial(moo, a=1)
|
||||
inputs = [2]
|
||||
inputs_ = gr.blocks.add_request_to_inputs(boo, copy.deepcopy(inputs), request)
|
||||
assert inputs_ == inputs
|
||||
|
||||
def test_no_type_hints_with_request(self):
|
||||
def moo(a: str, b: int):
|
||||
return a + str(b)
|
||||
|
||||
inputs = ["abc", 2]
|
||||
request = gr.Request()
|
||||
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
|
||||
assert inputs_ == inputs
|
||||
|
||||
boo = partial(moo, a="def")
|
||||
inputs = [2]
|
||||
inputs_ = gr.blocks.add_request_to_inputs(boo, copy.deepcopy(inputs), request)
|
||||
assert inputs_ == inputs
|
||||
|
||||
def test_type_hints_with_request(self):
|
||||
def moo(a: str, b: gr.Request):
|
||||
return a
|
||||
|
||||
inputs = ["abc"]
|
||||
request = gr.Request()
|
||||
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
|
||||
assert inputs_ == inputs + [request]
|
||||
|
||||
def moo(a: gr.Request, b, c: int):
|
||||
return c
|
||||
|
||||
inputs = ["abc", 5]
|
||||
request = gr.Request()
|
||||
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
|
||||
assert inputs_ == [request] + inputs
|
||||
|
||||
def test_type_hints_with_multiple_requests(self):
|
||||
def moo(a: str, b: gr.Request, c: gr.Request):
|
||||
return a
|
||||
|
||||
inputs = ["abc"]
|
||||
request = gr.Request()
|
||||
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
|
||||
assert inputs_ == inputs + [request, request]
|
||||
|
||||
def moo(a: gr.Request, b, c: int, d: gr.Request):
|
||||
return c
|
||||
|
||||
inputs = ["abc", 5]
|
||||
request = gr.Request()
|
||||
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
|
||||
assert inputs_ == [request] + inputs + [request]
|
||||
|
||||
|
||||
def test_queue_enabled_for_fn():
|
||||
with gr.Blocks() as demo:
|
||||
input = gr.Textbox()
|
||||
|
@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from gradio.queue import Event, Queue
|
||||
from gradio.utils import Request
|
||||
from gradio.utils import AsyncRequest
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -44,17 +44,6 @@ class TestQueueMethods:
|
||||
assert queue.stopped is False
|
||||
assert queue.get_active_worker_count() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dont_gather_data_while_broadcasting(self, queue: Queue):
|
||||
queue.broadcast_live_estimations = AsyncMock()
|
||||
queue.gather_data_for_first_ranks = AsyncMock()
|
||||
await queue.broadcast_live_estimations()
|
||||
|
||||
# Should not gather data while broadcasting estimations
|
||||
# Have seen weird race conditions come up in very viral
|
||||
# spaces
|
||||
queue.gather_data_for_first_ranks.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_resume(self, queue: Queue):
|
||||
await queue.start()
|
||||
@ -107,21 +96,6 @@ class TestQueueMethods:
|
||||
assert await queue.gather_event_data(mock_event)
|
||||
assert not (queue.send_message.called)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gather_data_for_first_ranks(self, queue: Queue, mock_event: Event):
|
||||
websocket = MagicMock()
|
||||
mock_event2 = Event(websocket=websocket, fn_index=0)
|
||||
queue.send_message = AsyncMock()
|
||||
queue.get_message = AsyncMock()
|
||||
queue.send_message.return_value = True
|
||||
queue.get_message.return_value = {"data": ["test"], "fn": 0}
|
||||
|
||||
queue.push(mock_event)
|
||||
queue.push(mock_event2)
|
||||
await queue.gather_data_for_first_ranks()
|
||||
assert mock_event.data is not None
|
||||
assert mock_event2.data is None
|
||||
|
||||
|
||||
class TestQueueEstimation:
|
||||
def test_get_update_estimation(self, queue: Queue):
|
||||
@ -170,7 +144,7 @@ class TestQueueProcessEvents:
|
||||
reason="Mocks of async context manager don't work for 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("gradio.queue.Request", new_callable=AsyncMock)
|
||||
@patch("gradio.queue.AsyncRequest", new_callable=AsyncMock)
|
||||
async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
|
||||
queue.gather_event_data = AsyncMock()
|
||||
queue.gather_event_data.return_value = True
|
||||
@ -190,7 +164,7 @@ class TestQueueProcessEvents:
|
||||
mock_event.disconnect.assert_called_once()
|
||||
queue.clean_event.assert_called_once()
|
||||
mock_request.assert_called_with(
|
||||
method=Request.Method.POST,
|
||||
method=AsyncRequest.Method.POST,
|
||||
url=f"{queue.server_path}reset",
|
||||
json={
|
||||
"session_hash": mock_event.session_hash,
|
||||
@ -280,7 +254,7 @@ class TestQueueProcessEvents:
|
||||
reason="Mocks of async context manager don't work for 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("gradio.queue.Request", new_callable=AsyncMock)
|
||||
@patch("gradio.queue.AsyncRequest", new_callable=AsyncMock)
|
||||
async def test_process_event_handles_exception_during_disconnect(
|
||||
self, mock_request, queue: Queue, mock_event: Event
|
||||
):
|
||||
@ -295,7 +269,7 @@ class TestQueueProcessEvents:
|
||||
queue.active_jobs = [[mock_event]]
|
||||
await queue.process_events([mock_event], batch=False)
|
||||
mock_request.assert_called_with(
|
||||
method=Request.Method.POST,
|
||||
method=AsyncRequest.Method.POST,
|
||||
url=f"{queue.server_path}reset",
|
||||
json={
|
||||
"session_hash": mock_event.session_hash,
|
||||
|
@ -311,6 +311,23 @@ class TestDevMode:
|
||||
assert not gradio_fast_api.app.blocks.dev_mode
|
||||
|
||||
|
||||
class TestPassingRequest:
|
||||
def test_request_included_with_regular_function(self):
|
||||
def identity(name, request: gr.Request):
|
||||
assert isinstance(request.client.host, str)
|
||||
return name
|
||||
|
||||
app, _, _ = gr.Interface(identity, "textbox", "textbox").launch(
|
||||
prevent_thread_lock=True,
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.post("/api/predict/", json={"data": ["test"]})
|
||||
assert response.status_code == 200
|
||||
output = dict(response.json())
|
||||
assert output["data"] == ["test"]
|
||||
|
||||
|
||||
def test_predict_route_is_blocked_if_api_open_false():
|
||||
io = Interface(lambda x: x, "text", "text", examples=[["freddy"]]).queue(
|
||||
api_open=False
|
||||
|
@ -18,7 +18,7 @@ from gradio.test_data.blocks_configs import (
|
||||
XRAY_CONFIG_WITH_MISTAKE,
|
||||
)
|
||||
from gradio.utils import (
|
||||
Request,
|
||||
AsyncRequest,
|
||||
append_unique_suffix,
|
||||
assert_configs_are_equivalent_besides_ids,
|
||||
colab_check,
|
||||
@ -256,15 +256,15 @@ async def client():
|
||||
A fixture to mock the async client object.
|
||||
"""
|
||||
async with AsyncClient() as mock_client:
|
||||
with mock.patch("gradio.utils.Request.client", mock_client):
|
||||
with mock.patch("gradio.utils.AsyncRequest.client", mock_client):
|
||||
yield
|
||||
|
||||
|
||||
class TestRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get(self):
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.GET,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.GET,
|
||||
url="http://headers.jsontest.com/",
|
||||
)
|
||||
validated_data = client_response.get_validated_data()
|
||||
@ -273,8 +273,8 @@ class TestRequest:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post(self):
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url="https://reqres.in/api/users",
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
)
|
||||
@ -291,8 +291,8 @@ class TestRequest:
|
||||
id: str
|
||||
createdAt: str
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url="https://reqres.in/api/users",
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
validation_model=TestModel,
|
||||
@ -305,8 +305,8 @@ class TestRequest:
|
||||
name: Literal["John"] = "John"
|
||||
job: str
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url="https://reqres.in/api/users",
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
validation_model=TestModel,
|
||||
@ -330,8 +330,8 @@ async def test_get(respx_mock):
|
||||
make_mock_response({"Host": "headers.jsontest.com"})
|
||||
)
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.GET,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.GET,
|
||||
url=MOCK_REQUEST_URL,
|
||||
)
|
||||
validated_data = client_response.get_validated_data()
|
||||
@ -345,8 +345,8 @@ async def test_post(respx_mock):
|
||||
payload = {"name": "morpheus", "job": "leader"}
|
||||
respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response(payload))
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url=MOCK_REQUEST_URL,
|
||||
json=payload,
|
||||
)
|
||||
@ -375,8 +375,8 @@ async def test_validate_with_model(respx_mock):
|
||||
id: str
|
||||
createdAt: str
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url=MOCK_REQUEST_URL,
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
validation_model=TestModel,
|
||||
@ -387,14 +387,14 @@ async def test_validate_with_model(respx_mock):
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_and_fail_with_model(respx_mock):
|
||||
class TestModel(BaseModel):
|
||||
name: Literal[str] = "John"
|
||||
name: Literal["John"]
|
||||
job: str
|
||||
|
||||
payload = {"name": "morpheus", "job": "leader"}
|
||||
respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response(payload))
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url=MOCK_REQUEST_URL,
|
||||
json=payload,
|
||||
validation_model=TestModel,
|
||||
@ -405,7 +405,7 @@ async def test_validate_and_fail_with_model(respx_mock):
|
||||
assert isinstance(client_response.exception, Exception)
|
||||
|
||||
|
||||
@mock.patch("gradio.utils.Request._validate_response_data")
|
||||
@mock.patch("gradio.utils.AsyncRequest._validate_response_data")
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_type(validate_response_data, respx_mock):
|
||||
class ResponseValidationException(Exception):
|
||||
@ -415,8 +415,8 @@ async def test_exception_type(validate_response_data, respx_mock):
|
||||
|
||||
respx_mock.get(MOCK_REQUEST_URL).mock(Response(201))
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.GET,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.GET,
|
||||
url=MOCK_REQUEST_URL,
|
||||
exception_type=ResponseValidationException,
|
||||
)
|
||||
@ -435,8 +435,8 @@ async def test_validate_with_function(respx_mock):
|
||||
return response
|
||||
raise Exception
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url=MOCK_REQUEST_URL,
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
validation_function=has_name,
|
||||
@ -457,8 +457,8 @@ async def test_validate_and_fail_with_function(respx_mock):
|
||||
|
||||
respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response({"name": "morpheus"}))
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
client_response: AsyncRequest = await AsyncRequest(
|
||||
method=AsyncRequest.Method.POST,
|
||||
url=MOCK_REQUEST_URL,
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
validation_function=has_name,
|
||||
|
Loading…
x
Reference in New Issue
Block a user