mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-17 11:29:58 +08:00
Fixes components when loading private spaces (#3068)
* file routes * adding access token * add reverse proxy * adding access token * context * rewrite * frontend * formatting * changelog * formatting * fix tests * fixed image issue * fix frontend * os removal * Update test_external.py * fixes to normalise * version * fixes so that functions work * lint * formatting
This commit is contained in:
parent
f062c7e1fd
commit
f37d17089d
1
.gitignore
vendored
1
.gitignore
vendored
@ -22,6 +22,7 @@ gradio/templates/frontend
|
||||
gradio/launches.json
|
||||
flagged/
|
||||
gradio_cached_examples/
|
||||
tmp.zip
|
||||
|
||||
# Tests
|
||||
.coverage
|
||||
|
@ -41,6 +41,7 @@ By [@maxaudron](https://github.com/maxaudron) in [PR 3075](https://github.com/gr
|
||||
- Fixes URL resolution on Windows by [@abidlabs](https://github.com/abidlabs) in [PR 3108](https://github.com/gradio-app/gradio/pull/3108)
|
||||
- Example caching now works with components without a label attribute (e.g. `Column`) by [@abidlabs](https://github.com/abidlabs) in [PR 3123](https://github.com/gradio-app/gradio/pull/3123)
|
||||
- Ensure the Video component correctly resets the UI state whe a new video source is loaded and reduce choppiness of UI by [@pngwn](https://github.com/abidlabs) in [PR 3117](https://github.com/gradio-app/gradio/pull/3117)
|
||||
- Fixes loading private Spaces by [@abidlabs](https://github.com/abidlabs) in [PR 3068](https://github.com/gradio-app/gradio/pull/3068)
|
||||
- Added a warning when attempting to launch an `Interface` via the `%%blocks` jupyter notebook magic command by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3126](https://github.com/gradio-app/gradio/pull/3126)
|
||||
|
||||
## Documentation Changes:
|
||||
@ -420,8 +421,8 @@ No changes to highlight.
|
||||
No changes to highlight.
|
||||
|
||||
## Bug Fixes:
|
||||
No changes to highlight.
|
||||
|
||||
*No changes to highlight.
|
||||
*
|
||||
## Documentation Changes:
|
||||
* Improves documentation of several queuing-related parameters by [@abidlabs](https://github.com/abidlabs) in [PR 2825](https://github.com/gradio-app/gradio/pull/2825)
|
||||
|
||||
|
@ -528,7 +528,10 @@ class Blocks(BlockContext):
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: dict, fns: List[Callable], root_url: str | None = None
|
||||
cls,
|
||||
config: dict,
|
||||
fns: List[Callable],
|
||||
root_url: str | None = None,
|
||||
) -> Blocks:
|
||||
"""
|
||||
Factory method that creates a Blocks from a config and list of functions.
|
||||
|
@ -1678,7 +1678,13 @@ class Image(
|
||||
)
|
||||
|
||||
def as_example(self, input_data: str | None) -> str:
|
||||
return "" if input_data is None else str(utils.abspath(input_data))
|
||||
if input_data is None:
|
||||
return ""
|
||||
elif (
|
||||
self.root_url
|
||||
): # If an externally hosted image, don't convert to absolute path
|
||||
return input_data
|
||||
return str(utils.abspath(input_data))
|
||||
|
||||
|
||||
@document("change", "clear", "play", "pause", "stop", "style")
|
||||
@ -4974,6 +4980,8 @@ class Dataset(Clickable, Component):
|
||||
[isinstance(c, IOComponent) for c in self.components]
|
||||
), "All components in a `Dataset` must be subclasses of `IOComponent`"
|
||||
self.components = [c for c in self.components if isinstance(c, IOComponent)]
|
||||
for component in self.components:
|
||||
component.root_url = self.root_url
|
||||
|
||||
self.samples = [[]] if samples is None else samples
|
||||
for example in self.samples:
|
||||
|
@ -12,4 +12,7 @@ class Context:
|
||||
root_block: Blocks | None = None # The current root block that holds all blocks.
|
||||
block: BlockContext | None = None # The current block that children are added to.
|
||||
id: int = 0 # Running id to uniquely refer to any block that gets defined
|
||||
ip_address: str | None = None
|
||||
ip_address: str | None = None # The IP address of the user.
|
||||
access_token: str | None = (
|
||||
None # The HF token that is provided when loading private models or Spaces
|
||||
)
|
||||
|
@ -14,6 +14,7 @@ import requests
|
||||
|
||||
import gradio
|
||||
from gradio import components, utils
|
||||
from gradio.context import Context
|
||||
from gradio.exceptions import TooManyRequestsError
|
||||
from gradio.external_utils import (
|
||||
cols_to_rows,
|
||||
@ -59,6 +60,13 @@ def load_blocks_from_repo(
|
||||
factory_methods.keys()
|
||||
)
|
||||
|
||||
if api_key is not None:
|
||||
if Context.access_token is not None and Context.access_token != api_key:
|
||||
warnings.warn(
|
||||
"""You are loading a model/Space with a different access token than the one you used to load a previous model/Space. This is not recommended, as it may cause unexpected behavior."""
|
||||
)
|
||||
Context.access_token = api_key
|
||||
|
||||
blocks: gradio.Blocks = factory_methods[src](name, api_key, alias, **kwargs)
|
||||
return blocks
|
||||
|
||||
|
@ -410,24 +410,39 @@ class TempFileManager:
|
||||
return full_temp_file_path
|
||||
|
||||
|
||||
def create_tmp_copy_of_file(file_path, dir=None):
|
||||
def download_tmp_copy_of_file(
|
||||
url_path: str, access_token: str | None = None, dir: str | None = None
|
||||
) -> tempfile._TemporaryFileWrapper:
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
file_name = Path(file_path).name
|
||||
prefix, extension = file_name, None
|
||||
if "." in file_name:
|
||||
prefix = file_name[0 : file_name.index(".")]
|
||||
extension = file_name[file_name.index(".") + 1 :]
|
||||
prefix = utils.strip_invalid_filename_characters(prefix)
|
||||
if extension is None:
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
|
||||
else:
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
prefix=prefix,
|
||||
suffix="." + extension,
|
||||
dir=dir,
|
||||
)
|
||||
headers = {"Authorization": "Bearer " + access_token} if access_token else {}
|
||||
prefix = Path(url_path).stem
|
||||
suffix = Path(url_path).suffix
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
dir=dir,
|
||||
)
|
||||
with requests.get(url_path, headers=headers, stream=True) as r:
|
||||
with open(file_obj.name, "wb") as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
return file_obj
|
||||
|
||||
|
||||
def create_tmp_copy_of_file(
|
||||
file_path: str, dir: str | None = None
|
||||
) -> tempfile._TemporaryFileWrapper:
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
prefix = Path(file_path).stem
|
||||
suffix = Path(file_path).suffix
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
dir=dir,
|
||||
)
|
||||
shutil.copy2(file_path, file_obj.name)
|
||||
return file_obj
|
||||
|
||||
|
@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
import httpx
|
||||
import markupsafe
|
||||
import orjson
|
||||
import pkg_resources
|
||||
@ -31,12 +32,14 @@ from fastapi.responses import (
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from jinja2.exceptions import TemplateNotFound
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.background import BackgroundTask
|
||||
from starlette.responses import RedirectResponse, StreamingResponse
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
import gradio
|
||||
import gradio.ranged_response as ranged_response
|
||||
from gradio import utils
|
||||
from gradio.context import Context
|
||||
from gradio.data_classes import PredictBody, ResetBody
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
from gradio.exceptions import Error
|
||||
@ -85,6 +88,7 @@ def toorjson(value):
|
||||
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
templates.env.filters["toorjson"] = toorjson
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
|
||||
###########
|
||||
# Auth
|
||||
@ -261,6 +265,23 @@ class App(FastAPI):
|
||||
else:
|
||||
return FileResponse(blocks.favicon_path)
|
||||
|
||||
@app.head("/proxy={url_path:path}", dependencies=[Depends(login_check)])
|
||||
@app.get("/proxy={url_path:path}", dependencies=[Depends(login_check)])
|
||||
async def reverse_proxy(url_path: str):
|
||||
# Adapted from: https://github.com/tiangolo/fastapi/issues/1788
|
||||
url = httpx.URL(url_path)
|
||||
headers = {}
|
||||
if Context.access_token is not None:
|
||||
headers["Authorization"] = f"Bearer {Context.access_token}"
|
||||
rp_req = client.build_request("GET", url, headers=headers)
|
||||
rp_resp = await client.send(rp_req, stream=True)
|
||||
return StreamingResponse(
|
||||
rp_resp.aiter_raw(),
|
||||
status_code=rp_resp.status_code,
|
||||
headers=rp_resp.headers, # type: ignore
|
||||
background=BackgroundTask(rp_resp.aclose),
|
||||
)
|
||||
|
||||
@app.head("/file={path_or_url:path}", dependencies=[Depends(login_check)])
|
||||
@app.get("/file={path_or_url:path}", dependencies=[Depends(login_check)])
|
||||
async def file(path_or_url: str, request: fastapi.Request):
|
||||
|
@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from gradio import processing_utils, utils
|
||||
from gradio.context import Context
|
||||
|
||||
|
||||
class Serializable(ABC):
|
||||
@ -164,9 +164,11 @@ class FileSerializable(Serializable):
|
||||
elif isinstance(x, dict):
|
||||
if x.get("is_file", False):
|
||||
if root_url is not None:
|
||||
x["name"] = urllib.parse.urljoin(root_url, "file=" + x["name"])
|
||||
if utils.validate_url(x["name"]):
|
||||
file_name = x["name"]
|
||||
file_name = processing_utils.download_tmp_copy_of_file(
|
||||
root_url + "file=" + x["name"],
|
||||
access_token=Context.access_token,
|
||||
dir=save_dir,
|
||||
).name
|
||||
else:
|
||||
file_name = processing_utils.create_tmp_copy_of_file(
|
||||
x["name"], dir=save_dir
|
||||
|
@ -448,12 +448,12 @@ def async_iteration(iterator):
|
||||
class AsyncRequest:
|
||||
"""
|
||||
The AsyncRequest class is a low-level API that allow you to create asynchronous HTTP requests without a context manager.
|
||||
Compared to making calls by using httpx directly, AsyncRequest offers more flexibility and control over:
|
||||
Compared to making calls by using httpx directly, AsyncRequest offers several advantages:
|
||||
(1) Includes response validation functionality both using validation models and functions.
|
||||
(2) Since we're still using httpx.Request class by wrapping it, we have all it's functionalities.
|
||||
(3) Exceptions are handled silently during the request call, which gives us the ability to inspect each one
|
||||
individually in the case of multiple asynchronous request calls and some of them failing.
|
||||
(4) Provides HTTP request types with AsyncRequest.Method Enum class for ease of usage
|
||||
(2) Exceptions are handled silently during the request call, which provides the ability to inspect each one
|
||||
request call individually in the case where there are multiple asynchronous request calls and some of them fail.
|
||||
(3) Provides HTTP request types with AsyncRequest.Method Enum class for ease of usage
|
||||
|
||||
AsyncRequest also offers some util functions such as has_exception, is_valid and status to inspect get detailed
|
||||
information about executed request call.
|
||||
|
||||
|
@ -1 +1 @@
|
||||
3.17.1
|
||||
3.17.1b2
|
||||
|
@ -12,6 +12,7 @@ from fastapi.testclient import TestClient
|
||||
import gradio
|
||||
import gradio as gr
|
||||
from gradio import media_data, utils
|
||||
from gradio.context import Context
|
||||
from gradio.exceptions import InvalidApiName
|
||||
from gradio.external import (
|
||||
TooManyRequestsError,
|
||||
@ -298,6 +299,40 @@ class TestLoadInterface:
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
def test_private_space_audio(self):
|
||||
api_key = "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" # Intentionally revealing this key for testing purposes
|
||||
io = gr.Interface.load(
|
||||
"spaces/gradio-tests/not-actually-private-space-audio", api_key=api_key
|
||||
)
|
||||
try:
|
||||
output = io(media_data.BASE64_AUDIO["name"])
|
||||
assert output.endswith(".wav")
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
def test_multiple_spaces_one_private(self):
|
||||
api_key = "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" # Intentionally revealing this key for testing purposes
|
||||
with gr.Blocks():
|
||||
gr.Interface.load(
|
||||
"spaces/gradio-tests/not-actually-private-space", api_key=api_key
|
||||
)
|
||||
gr.Interface.load(
|
||||
"spaces/gradio/test-loading-examples",
|
||||
)
|
||||
assert Context.access_token == api_key
|
||||
|
||||
def test_loading_files_via_proxy_works(self):
|
||||
api_key = "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" # Intentionally revealing this key for testing purposes
|
||||
io = gr.Interface.load(
|
||||
"spaces/gradio-tests/test-loading-examples-private", api_key=api_key
|
||||
)
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
test_client = TestClient(app)
|
||||
r = test_client.get(
|
||||
"/proxy=https://gradio-tests-test-loading-examples-private.hf.space/file=/tmp/tmprahzj703/Bunnyual53t2x.obj"
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
class TestLoadInterfaceWithExamples:
|
||||
def test_interface_load_examples(self, tmp_path):
|
||||
|
@ -319,3 +319,12 @@ class TestVideoProcessing:
|
||||
)
|
||||
# If the conversion succeeded it'd be .mp4
|
||||
assert pathlib.Path(playable_vid).suffix == ".avi"
|
||||
|
||||
|
||||
def test_download_private_file():
|
||||
url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg"
|
||||
access_token = "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" # Intentionally revealing this key for testing purposes
|
||||
file = processing_utils.download_tmp_copy_of_file(
|
||||
url_path=url_path, access_token=access_token
|
||||
)
|
||||
assert file.name.endswith(".jpg")
|
||||
|
@ -34,7 +34,7 @@
|
||||
export let loading_status: LoadingStatus;
|
||||
|
||||
let _value: null | FileData;
|
||||
$: _value = normalise_file(value, root_url ?? root);
|
||||
$: _value = normalise_file(value, root, root_url);
|
||||
|
||||
let dragging: boolean;
|
||||
</script>
|
||||
|
@ -16,7 +16,9 @@
|
||||
|
||||
const dispatch = createEventDispatcher<{ click: number }>();
|
||||
|
||||
let samples_dir: string = (root_url ?? root) + "file=";
|
||||
let samples_dir: string = root_url
|
||||
? "proxy=" + root_url + "file="
|
||||
: root + "file=";
|
||||
let page = 0;
|
||||
$: gallery = components.length < 2;
|
||||
let paginate = samples.length > samples_per_page;
|
||||
|
@ -43,6 +43,10 @@
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
})
|
||||
.catch((e) => {
|
||||
loaded_value = value;
|
||||
loaded = true;
|
||||
});
|
||||
}
|
||||
</script>
|
||||
@ -55,26 +59,30 @@
|
||||
on:mouseenter={() => (hovered = true)}
|
||||
on:mouseleave={() => (hovered = false)}
|
||||
>
|
||||
<table class="">
|
||||
{#each loaded_value.slice(0, 3) as row, i}
|
||||
<tr>
|
||||
{#each row.slice(0, 3) as cell, j}
|
||||
<td>{cell}</td>
|
||||
{/each}
|
||||
{#if row.length > 3}
|
||||
<td>…</td>
|
||||
{/if}
|
||||
</tr>
|
||||
{/each}
|
||||
{#if value.length > 3}
|
||||
<div
|
||||
class="overlay"
|
||||
class:odd={index % 2 != 0}
|
||||
class:even={index % 2 == 0}
|
||||
class:button={type === "gallery"}
|
||||
/>
|
||||
{/if}
|
||||
</table>
|
||||
{#if typeof loaded_value === "string"}
|
||||
{loaded_value}
|
||||
{:else}
|
||||
<table class="">
|
||||
{#each loaded_value.slice(0, 3) as row, i}
|
||||
<tr>
|
||||
{#each row.slice(0, 3) as cell, j}
|
||||
<td>{cell}</td>
|
||||
{/each}
|
||||
{#if row.length > 3}
|
||||
<td>…</td>
|
||||
{/if}
|
||||
</tr>
|
||||
{/each}
|
||||
{#if value.length > 3}
|
||||
<div
|
||||
class="overlay"
|
||||
class:odd={index % 2 != 0}
|
||||
class:even={index % 2 == 0}
|
||||
class:button={type === "gallery"}
|
||||
/>
|
||||
{/if}
|
||||
</table>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
import { createEventDispatcher } from "svelte";
|
||||
import { File, FileUpload } from "@gradio/file";
|
||||
import type { FileData } from "@gradio/upload";
|
||||
import { normalise_files } from "@gradio/upload";
|
||||
import { normalise_file } from "@gradio/upload";
|
||||
import { Block } from "@gradio/atoms";
|
||||
import UploadText from "../UploadText.svelte";
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
export let loading_status: LoadingStatus;
|
||||
|
||||
let _value: null | FileData | Array<FileData>;
|
||||
$: _value = normalise_files(value, root_url ?? root);
|
||||
$: _value = normalise_file(value, root, root_url);
|
||||
|
||||
let dragging = false;
|
||||
|
||||
|
@ -23,7 +23,7 @@
|
||||
export let show_label: boolean;
|
||||
|
||||
let _value: null | FileData;
|
||||
$: _value = normalise_file(value, root_url ?? root);
|
||||
$: _value = normalise_file(value, root, root_url);
|
||||
|
||||
let dragging = false;
|
||||
</script>
|
||||
|
@ -29,7 +29,7 @@
|
||||
export let mode: "static" | "dynamic";
|
||||
|
||||
let _value: null | FileData;
|
||||
$: _value = normalise_file(value, root_url ?? root);
|
||||
$: _value = normalise_file(value, root, root_url);
|
||||
|
||||
let dragging = false;
|
||||
|
||||
|
@ -21,8 +21,8 @@
|
||||
? null
|
||||
: value.map((img) =>
|
||||
Array.isArray(img)
|
||||
? [normalise_file(img[0], root_url ?? root), img[1]]
|
||||
: [normalise_file(img, root_url ?? root), null]
|
||||
? [normalise_file(img[0], root, root_url), img[1]]
|
||||
: [normalise_file(img, root, root_url), null]
|
||||
);
|
||||
|
||||
let prevValue: string[] | FileData[] | null = null;
|
||||
|
@ -2,27 +2,19 @@ import type { FileData } from "./types";
|
||||
|
||||
export function normalise_file(
|
||||
file: string | FileData | null,
|
||||
root: string
|
||||
): FileData | null {
|
||||
if (file == null) return null;
|
||||
if (typeof file === "string") {
|
||||
return {
|
||||
name: "file_data",
|
||||
data: file
|
||||
};
|
||||
} else if (Array.isArray(file)) {
|
||||
for (const x of file) {
|
||||
normalise_file(x, root);
|
||||
}
|
||||
} else if (file.is_file) {
|
||||
file.data = root + "file=" + file.name;
|
||||
}
|
||||
return file;
|
||||
}
|
||||
root: string,
|
||||
root_url: string | null
|
||||
): FileData | null;
|
||||
export function normalise_file(
|
||||
file: Array<FileData> | FileData | null,
|
||||
root: string,
|
||||
root_url: string | null
|
||||
): Array<FileData> | FileData | null;
|
||||
|
||||
export function normalise_files(
|
||||
export function normalise_file(
|
||||
file: string | FileData | Array<FileData> | null,
|
||||
root: string
|
||||
root: string,
|
||||
root_url: string | null
|
||||
): FileData | Array<FileData> | null {
|
||||
if (file == null) return null;
|
||||
if (typeof file === "string") {
|
||||
@ -32,10 +24,14 @@ export function normalise_files(
|
||||
};
|
||||
} else if (Array.isArray(file)) {
|
||||
for (const x of file) {
|
||||
normalise_file(x, root);
|
||||
normalise_file(x, root, root_url);
|
||||
}
|
||||
} else if (file.is_file) {
|
||||
file.data = root + "file=" + file.name;
|
||||
if (root_url == null) {
|
||||
file.data = "file=" + file.name;
|
||||
} else {
|
||||
file.data = "proxy=" + root_url + "file=" + file.name;
|
||||
}
|
||||
}
|
||||
return file;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user