mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Allow config to include non-pickle-able values (#7415)
* fixes * lint * add changeset * route utils --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
c2dfc592a4
commit
4ab399f40a
5
.changeset/mean-bushes-hide.md
Normal file
5
.changeset/mean-bushes-hide.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Allow config to include non-pickle-able values
|
@ -289,9 +289,11 @@ def move_files_to_cache(
|
||||
return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)
|
||||
|
||||
|
||||
def add_root_url(data, root_url) -> dict:
|
||||
def add_root_url(data: dict, root_url: str, previous_root_url: str | None) -> dict:
|
||||
def _add_root_url(file_dict: dict):
|
||||
if not client_utils.is_http_url_like(file_dict["url"]):
|
||||
if previous_root_url and file_dict["url"].startswith(previous_root_url):
|
||||
file_dict["url"] = file_dict["url"][len(previous_root_url) :]
|
||||
file_dict["url"] = f'{root_url}{file_dict["url"]}'
|
||||
return file_dict
|
||||
|
||||
|
@ -16,7 +16,7 @@ from multipart.multipart import parse_options_header
|
||||
from starlette.datastructures import FormData, Headers, UploadFile
|
||||
from starlette.formparsers import MultiPartException, MultipartPart
|
||||
|
||||
from gradio import utils
|
||||
from gradio import processing_utils, utils
|
||||
from gradio.data_classes import PredictBody
|
||||
from gradio.exceptions import Error
|
||||
from gradio.helpers import EventData
|
||||
@ -561,3 +561,16 @@ class GradioMultiPartParser:
|
||||
def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> None:
|
||||
for file, dest in zip(files, destinations):
|
||||
shutil.move(file, dest)
|
||||
|
||||
|
||||
def update_root_in_config(config: dict, root: str) -> dict:
|
||||
"""
|
||||
Updates the root "key" in the config dictionary to the new root url. If the
|
||||
root url has changed, all of the urls in the config that correspond to component
|
||||
file urls are updated to use the new root url.
|
||||
"""
|
||||
previous_root = config.get("root", None)
|
||||
if previous_root is None or previous_root != root:
|
||||
config["root"] = root
|
||||
config = processing_utils.add_root_url(config, root, previous_root)
|
||||
return config
|
||||
|
@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
@ -311,19 +310,18 @@ class App(FastAPI):
|
||||
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
blocks = app.get_blocks()
|
||||
root_path = route_utils.get_root_url(
|
||||
root = route_utils.get_root_url(
|
||||
request=request, route_path="/", root_path=app.root_path
|
||||
)
|
||||
if app.auth is None or user is not None:
|
||||
config = copy.deepcopy(app.get_blocks().config)
|
||||
config["root"] = root_path
|
||||
config = add_root_url(config, root_path)
|
||||
config = app.get_blocks().config
|
||||
config = route_utils.update_root_in_config(config, root)
|
||||
else:
|
||||
config = {
|
||||
"auth_required": True,
|
||||
"auth_message": blocks.auth_message,
|
||||
"space_id": app.get_blocks().space_id,
|
||||
"root": root_path,
|
||||
"root": root,
|
||||
}
|
||||
|
||||
try:
|
||||
@ -354,13 +352,12 @@ class App(FastAPI):
|
||||
@app.get("/config/", dependencies=[Depends(login_check)])
|
||||
@app.get("/config", dependencies=[Depends(login_check)])
|
||||
def get_config(request: fastapi.Request):
|
||||
config = copy.deepcopy(app.get_blocks().config)
|
||||
root_path = route_utils.get_root_url(
|
||||
config = app.get_blocks().config
|
||||
root = route_utils.get_root_url(
|
||||
request=request, route_path="/config", root_path=app.root_path
|
||||
)
|
||||
config["root"] = root_path
|
||||
config = add_root_url(config, root_path)
|
||||
return config
|
||||
config = route_utils.update_root_in_config(config, root)
|
||||
return ORJSONResponse(content=config)
|
||||
|
||||
@app.get("/static/{path:path}")
|
||||
def static_resource(path: str):
|
||||
@ -577,7 +574,7 @@ class App(FastAPI):
|
||||
root_path = route_utils.get_root_url(
|
||||
request=request, route_path=f"/api/{api_name}", root_path=app.root_path
|
||||
)
|
||||
output = add_root_url(output, root_path)
|
||||
output = add_root_url(output, root_path, None)
|
||||
return output
|
||||
|
||||
@app.get("/queue/data", dependencies=[Depends(login_check)])
|
||||
@ -634,7 +631,7 @@ class App(FastAPI):
|
||||
"success": False,
|
||||
}
|
||||
if message:
|
||||
add_root_url(message, root_path)
|
||||
add_root_url(message, root_path, None)
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
if message["msg"] == ServerMessage.process_completed:
|
||||
blocks._queue.pending_event_ids_session[
|
||||
|
@ -332,3 +332,42 @@ class TestVideoProcessing:
|
||||
)
|
||||
# If the conversion succeeded it'd be .mp4
|
||||
assert Path(playable_vid).suffix == ".avi"
|
||||
|
||||
|
||||
def test_add_root_url():
|
||||
data = {
|
||||
"file": {
|
||||
"path": "path",
|
||||
"url": "/file=path",
|
||||
},
|
||||
"file2": {
|
||||
"path": "path2",
|
||||
"url": "https://www.gradio.app",
|
||||
},
|
||||
}
|
||||
root_url = "http://localhost:7860"
|
||||
expected = {
|
||||
"file": {
|
||||
"path": "path",
|
||||
"url": f"{root_url}/file=path",
|
||||
},
|
||||
"file2": {
|
||||
"path": "path2",
|
||||
"url": "https://www.gradio.app",
|
||||
},
|
||||
}
|
||||
assert processing_utils.add_root_url(data, root_url, None) == expected
|
||||
new_root_url = "https://1234.gradio.live"
|
||||
new_expected = {
|
||||
"file": {
|
||||
"path": "path",
|
||||
"url": f"{root_url}/file=path",
|
||||
},
|
||||
"file2": {
|
||||
"path": "path2",
|
||||
"url": "https://www.gradio.app",
|
||||
},
|
||||
}
|
||||
assert (
|
||||
processing_utils.add_root_url(expected, root_url, new_root_url) == new_expected
|
||||
)
|
||||
|
@ -439,6 +439,18 @@ class TestRoutes:
|
||||
r = app.build_proxy_request("https://google.com")
|
||||
assert "authorization" not in dict(r.headers)
|
||||
|
||||
def test_can_get_config_that_includes_non_pickle_able_objects(self):
|
||||
my_dict = {"a": 1, "b": 2, "c": 3}
|
||||
with Blocks() as demo:
|
||||
gr.JSON(my_dict.keys())
|
||||
|
||||
app, _, _ = demo.launch(prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.is_success
|
||||
response = client.get("/config/")
|
||||
assert response.is_success
|
||||
|
||||
|
||||
class TestApp:
|
||||
def test_create_app(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user