Refactor CORS Middleware to be much faster (#7801)

* changes

* add changeset

* log

* add changeset

* changes

* changes

* lint

* middlware

* lint

* lint

* s
implified docstring

* fix

* revert test change

* docstring

* remove print

* update

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-03-26 13:17:38 -07:00 committed by GitHub
parent 8b099a07a5
commit 05db0c4a59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 31 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Refactor CORS Middleware to be much faster

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import functools
import hashlib
import hmac
import json
@ -33,9 +34,10 @@ import httpx
import multipart
from gradio_client.documentation import document
from multipart.multipart import parse_options_header
from starlette.datastructures import FormData, Headers, UploadFile
from starlette.datastructures import FormData, Headers, MutableHeaders, UploadFile
from starlette.formparsers import MultiPartException, MultipartPart
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from gradio import processing_utils, utils
from gradio.data_classes import PredictBody
@ -648,41 +650,86 @@ def get_hostname(url: str) -> str:
return ""
class CustomCORSMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: fastapi.Request, call_next):
host: str = request.headers.get("host", "")
origin: str = request.headers.get("origin", "")
host_name = get_hostname(host)
origin_name = get_hostname(origin)
class CustomCORSMiddleware:
# This is a modified version of the Starlette CORSMiddleware that restricts the allowed origins when the host is localhost.
# Adapted from: https://github.com/encode/starlette/blob/89fae174a1ea10f59ae248fe030d9b7e83d0b8a0/starlette/middleware/cors.py
def __init__(
self,
app: ASGIApp,
) -> None:
self.app = app
self.all_methods = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
self.preflight_headers = {
"Access-Control-Allow-Methods": ", ".join(self.all_methods),
"Access-Control-Max-Age": str(600),
}
self.simple_headers = {"Access-Control-Allow-Credentials": "true"}
# Any of these hosts suggests that the Gradio app is running locally.
# Note: "null" is a special case that happens if a Gradio app is running
# as an embedded web component in a local static webpage.
localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]
is_preflight = (
request.method == "OPTIONS"
and "access-control-request-method" in request.headers
self.localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
origin = headers.get("origin")
if origin is None:
await self.app(scope, receive, send)
return
if scope["method"] == "OPTIONS" and "access-control-request-method" in headers:
response = self.preflight_response(request_headers=headers)
await response(scope, receive, send)
return
await self.simple_response(scope, receive, send, request_headers=headers)
def preflight_response(self, request_headers: Headers) -> Response:
headers = dict(self.preflight_headers)
origin = request_headers["Origin"]
if self.is_valid_origin(request_headers):
headers["Access-Control-Allow-Origin"] = origin
requested_headers = request_headers.get("access-control-request-headers")
if requested_headers is not None:
headers["Access-Control-Allow-Headers"] = requested_headers
return PlainTextResponse("OK", status_code=200, headers=headers)
async def simple_response(
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
) -> None:
send = functools.partial(self._send, send=send, request_headers=request_headers)
await self.app(scope, receive, send)
async def _send(
self, message: Message, send: Send, request_headers: Headers
) -> None:
if message["type"] != "http.response.start":
await send(message)
return
message.setdefault("headers", [])
headers = MutableHeaders(scope=message)
headers.update(self.simple_headers)
has_cookie = "cookie" in request_headers
origin = request_headers["Origin"]
if has_cookie or self.is_valid_origin(request_headers):
self.allow_explicit_origin(headers, origin)
await send(message)
def is_valid_origin(self, request_headers: Headers) -> bool:
origin = request_headers["Origin"]
host = request_headers["Host"]
host_name = get_hostname(host)
origin_name = get_hostname(origin)
return (
host_name not in self.localhost_aliases
or origin_name in self.localhost_aliases
)
if host_name in localhost_aliases and origin_name not in localhost_aliases:
allow_origin_header = None
else:
allow_origin_header = origin
if is_preflight:
response = fastapi.Response()
else:
response = await call_next(request)
if allow_origin_header:
response.headers["Access-Control-Allow-Origin"] = allow_origin_header
response.headers[
"Access-Control-Allow-Methods"
] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers[
"Access-Control-Allow-Headers"
] = "Origin, Content-Type, Accept"
return response
@staticmethod
def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
headers["Access-Control-Allow-Origin"] = origin
headers.add_vary_header("Origin")
def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None: