mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Refactor CORS Middleware to be much faster (#7801)
* changes * add changeset * log * add changeset * changes * changes * lint * middlware * lint * lint * s implified docstring * fix * revert test change * docstring * remove print * update --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
8b099a07a5
commit
05db0c4a59
5
.changeset/stale-grapes-roll.md
Normal file
5
.changeset/stale-grapes-roll.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
feat:Refactor CORS Middleware to be much faster
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
@ -33,9 +34,10 @@ import httpx
|
||||
import multipart
|
||||
from gradio_client.documentation import document
|
||||
from multipart.multipart import parse_options_header
|
||||
from starlette.datastructures import FormData, Headers, UploadFile
|
||||
from starlette.datastructures import FormData, Headers, MutableHeaders, UploadFile
|
||||
from starlette.formparsers import MultiPartException, MultipartPart
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
from gradio import processing_utils, utils
|
||||
from gradio.data_classes import PredictBody
|
||||
@ -648,41 +650,86 @@ def get_hostname(url: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
class CustomCORSMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: fastapi.Request, call_next):
|
||||
host: str = request.headers.get("host", "")
|
||||
origin: str = request.headers.get("origin", "")
|
||||
host_name = get_hostname(host)
|
||||
origin_name = get_hostname(origin)
|
||||
class CustomCORSMiddleware:
|
||||
# This is a modified version of the Starlette CORSMiddleware that restricts the allowed origins when the host is localhost.
|
||||
# Adapted from: https://github.com/encode/starlette/blob/89fae174a1ea10f59ae248fe030d9b7e83d0b8a0/starlette/middleware/cors.py
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.all_methods = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
|
||||
self.preflight_headers = {
|
||||
"Access-Control-Allow-Methods": ", ".join(self.all_methods),
|
||||
"Access-Control-Max-Age": str(600),
|
||||
}
|
||||
self.simple_headers = {"Access-Control-Allow-Credentials": "true"}
|
||||
# Any of these hosts suggests that the Gradio app is running locally.
|
||||
# Note: "null" is a special case that happens if a Gradio app is running
|
||||
# as an embedded web component in a local static webpage.
|
||||
localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]
|
||||
is_preflight = (
|
||||
request.method == "OPTIONS"
|
||||
and "access-control-request-method" in request.headers
|
||||
self.localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
headers = Headers(scope=scope)
|
||||
origin = headers.get("origin")
|
||||
if origin is None:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
if scope["method"] == "OPTIONS" and "access-control-request-method" in headers:
|
||||
response = self.preflight_response(request_headers=headers)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
await self.simple_response(scope, receive, send, request_headers=headers)
|
||||
|
||||
def preflight_response(self, request_headers: Headers) -> Response:
|
||||
headers = dict(self.preflight_headers)
|
||||
origin = request_headers["Origin"]
|
||||
if self.is_valid_origin(request_headers):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
requested_headers = request_headers.get("access-control-request-headers")
|
||||
if requested_headers is not None:
|
||||
headers["Access-Control-Allow-Headers"] = requested_headers
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(
|
||||
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
send = functools.partial(self._send, send=send, request_headers=request_headers)
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def _send(
|
||||
self, message: Message, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(scope=message)
|
||||
headers.update(self.simple_headers)
|
||||
has_cookie = "cookie" in request_headers
|
||||
origin = request_headers["Origin"]
|
||||
if has_cookie or self.is_valid_origin(request_headers):
|
||||
self.allow_explicit_origin(headers, origin)
|
||||
await send(message)
|
||||
|
||||
def is_valid_origin(self, request_headers: Headers) -> bool:
|
||||
origin = request_headers["Origin"]
|
||||
host = request_headers["Host"]
|
||||
host_name = get_hostname(host)
|
||||
origin_name = get_hostname(origin)
|
||||
return (
|
||||
host_name not in self.localhost_aliases
|
||||
or origin_name in self.localhost_aliases
|
||||
)
|
||||
|
||||
if host_name in localhost_aliases and origin_name not in localhost_aliases:
|
||||
allow_origin_header = None
|
||||
else:
|
||||
allow_origin_header = origin
|
||||
|
||||
if is_preflight:
|
||||
response = fastapi.Response()
|
||||
else:
|
||||
response = await call_next(request)
|
||||
|
||||
if allow_origin_header:
|
||||
response.headers["Access-Control-Allow-Origin"] = allow_origin_header
|
||||
response.headers[
|
||||
"Access-Control-Allow-Methods"
|
||||
] = "GET, POST, PUT, DELETE, OPTIONS"
|
||||
response.headers[
|
||||
"Access-Control-Allow-Headers"
|
||||
] = "Origin, Content-Type, Accept"
|
||||
return response
|
||||
@staticmethod
|
||||
def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
headers.add_vary_header("Origin")
|
||||
|
||||
|
||||
def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user