Allow config to include non-pickle-able values (#7415)

* fixes

* lint

* add changeset

* route utils

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-02-14 09:43:41 -08:00 committed by GitHub
parent c2dfc592a4
commit 4ab399f40a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 83 additions and 15 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Allow config to include non-pickle-able values

View File

@ -289,9 +289,11 @@ def move_files_to_cache(
return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)
def add_root_url(data, root_url) -> dict:
def add_root_url(data: dict, root_url: str, previous_root_url: str | None) -> dict:
def _add_root_url(file_dict: dict):
if not client_utils.is_http_url_like(file_dict["url"]):
if previous_root_url and file_dict["url"].startswith(previous_root_url):
file_dict["url"] = file_dict["url"][len(previous_root_url) :]
file_dict["url"] = f'{root_url}{file_dict["url"]}'
return file_dict

View File

@ -16,7 +16,7 @@ from multipart.multipart import parse_options_header
from starlette.datastructures import FormData, Headers, UploadFile
from starlette.formparsers import MultiPartException, MultipartPart
from gradio import utils
from gradio import processing_utils, utils
from gradio.data_classes import PredictBody
from gradio.exceptions import Error
from gradio.helpers import EventData
@ -561,3 +561,16 @@ class GradioMultiPartParser:
def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> None:
for file, dest in zip(files, destinations):
shutil.move(file, dest)
def update_root_in_config(config: dict, root: str) -> dict:
"""
Updates the root "key" in the config dictionary to the new root url. If the
root url has changed, all of the urls in the config that correspond to component
file urls are updated to use the new root url.
"""
previous_root = config.get("root", None)
if previous_root is None or previous_root != root:
config["root"] = root
config = processing_utils.add_root_url(config, root, previous_root)
return config

View File

@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio
import contextlib
import copy
import sys
if sys.version_info >= (3, 9):
@ -311,19 +310,18 @@ class App(FastAPI):
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
mimetypes.add_type("application/javascript", ".js")
blocks = app.get_blocks()
root_path = route_utils.get_root_url(
root = route_utils.get_root_url(
request=request, route_path="/", root_path=app.root_path
)
if app.auth is None or user is not None:
config = copy.deepcopy(app.get_blocks().config)
config["root"] = root_path
config = add_root_url(config, root_path)
config = app.get_blocks().config
config = route_utils.update_root_in_config(config, root)
else:
config = {
"auth_required": True,
"auth_message": blocks.auth_message,
"space_id": app.get_blocks().space_id,
"root": root_path,
"root": root,
}
try:
@ -354,13 +352,12 @@ class App(FastAPI):
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config(request: fastapi.Request):
config = copy.deepcopy(app.get_blocks().config)
root_path = route_utils.get_root_url(
config = app.get_blocks().config
root = route_utils.get_root_url(
request=request, route_path="/config", root_path=app.root_path
)
config["root"] = root_path
config = add_root_url(config, root_path)
return config
config = route_utils.update_root_in_config(config, root)
return ORJSONResponse(content=config)
@app.get("/static/{path:path}")
def static_resource(path: str):
@ -577,7 +574,7 @@ class App(FastAPI):
root_path = route_utils.get_root_url(
request=request, route_path=f"/api/{api_name}", root_path=app.root_path
)
output = add_root_url(output, root_path)
output = add_root_url(output, root_path, None)
return output
@app.get("/queue/data", dependencies=[Depends(login_check)])
@ -634,7 +631,7 @@ class App(FastAPI):
"success": False,
}
if message:
add_root_url(message, root_path)
add_root_url(message, root_path, None)
yield f"data: {json.dumps(message)}\n\n"
if message["msg"] == ServerMessage.process_completed:
blocks._queue.pending_event_ids_session[

View File

@ -332,3 +332,42 @@ class TestVideoProcessing:
)
# If the conversion succeeded it'd be .mp4
assert Path(playable_vid).suffix == ".avi"
def test_add_root_url():
data = {
"file": {
"path": "path",
"url": "/file=path",
},
"file2": {
"path": "path2",
"url": "https://www.gradio.app",
},
}
root_url = "http://localhost:7860"
expected = {
"file": {
"path": "path",
"url": f"{root_url}/file=path",
},
"file2": {
"path": "path2",
"url": "https://www.gradio.app",
},
}
assert processing_utils.add_root_url(data, root_url, None) == expected
new_root_url = "https://1234.gradio.live"
new_expected = {
"file": {
"path": "path",
"url": f"{root_url}/file=path",
},
"file2": {
"path": "path2",
"url": "https://www.gradio.app",
},
}
assert (
processing_utils.add_root_url(expected, root_url, new_root_url) == new_expected
)

View File

@ -439,6 +439,18 @@ class TestRoutes:
r = app.build_proxy_request("https://google.com")
assert "authorization" not in dict(r.headers)
def test_can_get_config_that_includes_non_pickle_able_objects(self):
my_dict = {"a": 1, "b": 2, "c": 3}
with Blocks() as demo:
gr.JSON(my_dict.keys())
app, _, _ = demo.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.get("/")
assert response.is_success
response = client.get("/config/")
assert response.is_success
class TestApp:
def test_create_app(self):