mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Fix the download button of the gr.Gallery()
component to work (#6487)
* Fix the download button of the `gr.Gallery()` component to work * Refactoring js/gallery/shared/Gallery.svelte * Fix `gr.Gallery()` to set `orig_name` for URLs * Fix Gallery.postprocess() * Fix `download()` to fallback to `window.open()` when CORS is not allowed * Fix `gr.Gallery` to leave as None so it will be replaced with a local cache path and restore the `<a>` tag-based download feature on the frontend * Align a variable name to its type name * Fix Gallery's tests * Fix the frontend test for gallery * Revert "Fix `gr.Gallery` to leave as None so it will be replaced with a local cache path and restore the `<a>` tag-based download feature on the frontend" This reverts commit d754980cc27ded760bfc26df4310f913c2c6944a. * Revert "Fix Gallery's tests" This reverts commit 4e2aa3fff1ef7b586839fa6c485d1a3b8738fd03. * Revert "Fix the frontend test for gallery" This reverts commit 007caa23e7b9dbab36376307137a30f277fca297. * Fix for linter * Add a test about the download button * Fix type defs on Gallery.postprocess * Improve TestGallery * add changeset * Update gradio/components/gallery.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * Update gradio/components/gallery.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * Revert "Update gradio/components/gallery.py" This reverts commit 4d6e12730511fe9840a5372787ad81fd83cbb44c. * Revert "Update gradio/components/gallery.py" This reverts commit f2bfad0744d20e121c8eef40a335979a7c703517. * Use `tuple` instead of `typing.Tuple` * Revert "Use `tuple` instead of `typing.Tuple`" This reverts commit 69ab93cad4f39fe38f0e0f88126be572bf12cecf. --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
5177132d71
commit
9a5811df92
6
.changeset/good-areas-trade.md
Normal file
6
.changeset/good-areas-trade.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
"@gradio/gallery": patch
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Fix the download button of the `gr.Gallery()` component to work
|
@ -3,7 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Literal, Optional
|
||||
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
@ -18,6 +19,10 @@ from gradio.events import Events
|
||||
set_documentation_group("component")
|
||||
|
||||
|
||||
GalleryImageType = Union[np.ndarray, _Image.Image, Path, str]
|
||||
CaptionedGalleryImageType = Tuple[GalleryImageType, str]
|
||||
|
||||
|
||||
class GalleryImage(GradioModel):
|
||||
image: FileData
|
||||
caption: Optional[str] = None
|
||||
@ -125,9 +130,7 @@ class Gallery(Component):
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
value: list[np.ndarray | _Image.Image | str]
|
||||
| list[tuple[np.ndarray | _Image.Image | str, str]]
|
||||
| None,
|
||||
value: list[GalleryImageType | CaptionedGalleryImageType] | None,
|
||||
) -> GalleryData:
|
||||
"""
|
||||
Parameters:
|
||||
@ -141,6 +144,7 @@ class Gallery(Component):
|
||||
for img in value:
|
||||
url = None
|
||||
caption = None
|
||||
orig_name = None
|
||||
if isinstance(img, (tuple, list)):
|
||||
img, caption = img
|
||||
if isinstance(img, np.ndarray):
|
||||
@ -155,13 +159,20 @@ class Gallery(Component):
|
||||
file_path = str(utils.abspath(file))
|
||||
elif isinstance(img, str):
|
||||
file_path = img
|
||||
url = img if is_http_url_like(img) else None
|
||||
if is_http_url_like(img):
|
||||
url = img
|
||||
orig_name = Path(urlparse(img).path).name
|
||||
else:
|
||||
url = None
|
||||
orig_name = Path(img).name
|
||||
elif isinstance(img, Path):
|
||||
file_path = str(img)
|
||||
orig_name = img.name
|
||||
else:
|
||||
raise ValueError(f"Cannot process type as image: {type(img)}")
|
||||
entry = GalleryImage(
|
||||
image=FileData(path=file_path, url=url), caption=caption
|
||||
image=FileData(path=file_path, url=url, orig_name=orig_name),
|
||||
caption=caption,
|
||||
)
|
||||
output.append(entry)
|
||||
return GalleryData(root=output)
|
||||
|
@ -15,10 +15,17 @@ test("Gallery preview mode displays all images correctly.", async ({
|
||||
).toEqual("https://gradio-builds.s3.amazonaws.com/assets/cheetah-003.jpg");
|
||||
});
|
||||
|
||||
test("Gallery select event returns the right value", async ({ page }) => {
|
||||
test("Gallery select event returns the right value and the download button works correctly", async ({
|
||||
page
|
||||
}) => {
|
||||
await page.getByRole("button", { name: "Run" }).click();
|
||||
await page.getByLabel("Thumbnail 2 of 3").click();
|
||||
await expect(page.getByLabel("Select Data")).toHaveValue(
|
||||
"https://gradio-builds.s3.amazonaws.com/assets/lite-logo.png"
|
||||
);
|
||||
|
||||
const downloadPromise = page.waitForEvent("download");
|
||||
await page.getByLabel("Download").click();
|
||||
const download = await downloadPromise;
|
||||
expect(download.suggestedFilename()).toBe("lite-logo.png");
|
||||
});
|
||||
|
@ -12,11 +12,14 @@
|
||||
import { IconButton } from "@gradio/atoms";
|
||||
import type { I18nFormatter } from "@gradio/utils";
|
||||
|
||||
type GalleryImage = { image: FileData; caption: string | null };
|
||||
type GalleryData = GalleryImage[];
|
||||
|
||||
export let show_label = true;
|
||||
export let label: string;
|
||||
export let root = "";
|
||||
export let proxy_url: null | string = null;
|
||||
export let value: { image: FileData; caption: string | null }[] | null = null;
|
||||
export let value: GalleryData | null = null;
|
||||
export let columns: number | number[] | undefined = [2];
|
||||
export let rows: number | number[] | undefined = undefined;
|
||||
export let height: number | "auto" = "auto";
|
||||
@ -37,25 +40,24 @@
|
||||
// tracks whether the value of the gallery was reset
|
||||
let was_reset = true;
|
||||
|
||||
$: was_reset = value == null || value.length == 0 ? true : was_reset;
|
||||
$: was_reset = value == null || value.length === 0 ? true : was_reset;
|
||||
|
||||
let _value: { image: FileData; caption: string | null }[] | null = null;
|
||||
$: _value =
|
||||
value === null
|
||||
let resolved_value: GalleryData | null = null;
|
||||
$: resolved_value =
|
||||
value == null
|
||||
? null
|
||||
: value.map((data) => ({
|
||||
image: normalise_file(data.image, root, proxy_url) as FileData,
|
||||
caption: data.caption
|
||||
}));
|
||||
|
||||
let prevValue: { image: FileData; caption: string | null }[] | null | null =
|
||||
value;
|
||||
if (selected_index === null && preview && value?.length) {
|
||||
let prev_value: GalleryData | null = value;
|
||||
if (selected_index == null && preview && value?.length) {
|
||||
selected_index = 0;
|
||||
}
|
||||
let old_selected_index: number | null = selected_index;
|
||||
|
||||
$: if (!dequal(prevValue, value)) {
|
||||
$: if (!dequal(prev_value, value)) {
|
||||
// When value is falsy (clear button or first load),
|
||||
// preview determines the selected image
|
||||
if (was_reset) {
|
||||
@ -65,19 +67,18 @@
|
||||
// gallery has at least as many elements as it did before
|
||||
} else {
|
||||
selected_index =
|
||||
selected_index !== null &&
|
||||
value !== null &&
|
||||
selected_index < value.length
|
||||
selected_index != null && value != null && selected_index < value.length
|
||||
? selected_index
|
||||
: null;
|
||||
}
|
||||
dispatch("change");
|
||||
prevValue = value;
|
||||
prev_value = value;
|
||||
}
|
||||
|
||||
$: previous =
|
||||
((selected_index ?? 0) + (_value?.length ?? 0) - 1) % (_value?.length ?? 0);
|
||||
$: next = ((selected_index ?? 0) + 1) % (_value?.length ?? 0);
|
||||
((selected_index ?? 0) + (resolved_value?.length ?? 0) - 1) %
|
||||
(resolved_value?.length ?? 0);
|
||||
$: next = ((selected_index ?? 0) + 1) % (resolved_value?.length ?? 0);
|
||||
|
||||
function handle_preview_click(event: MouseEvent): void {
|
||||
const element = event.target as HTMLElement;
|
||||
@ -111,28 +112,13 @@
|
||||
}
|
||||
}
|
||||
|
||||
function isFileData(obj: any): obj is FileData {
|
||||
return typeof obj === "object" && obj !== null && "data" in obj;
|
||||
}
|
||||
|
||||
function getHrefValue(selected: any): string {
|
||||
if (isFileData(selected)) {
|
||||
return selected.path;
|
||||
} else if (typeof selected === "string") {
|
||||
return selected;
|
||||
} else if (Array.isArray(selected)) {
|
||||
return getHrefValue(selected[0]);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
$: {
|
||||
if (selected_index !== old_selected_index) {
|
||||
old_selected_index = selected_index;
|
||||
if (selected_index !== null) {
|
||||
dispatch("select", {
|
||||
index: selected_index,
|
||||
value: _value?.[selected_index]
|
||||
value: resolved_value?.[selected_index]
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -175,6 +161,39 @@
|
||||
|
||||
let client_height = 0;
|
||||
let window_height = 0;
|
||||
|
||||
// Unlike `gr.Image()`, images specified via remote URLs are not cached in the server
|
||||
// and their remote URLs are directly passed to the client as `value[].image.url`.
|
||||
// The `download` attribute of the <a> tag doesn't work for remote URLs (https://developer.mozilla.org/en-US/docs/Web/HTML/Element/a#download),
|
||||
// so we need to download the image via JS as below.
|
||||
async function download(file_url: string, name: string): Promise<void> {
|
||||
let response;
|
||||
try {
|
||||
response = await fetch(file_url);
|
||||
} catch (error) {
|
||||
if (error instanceof TypeError) {
|
||||
// If CORS is not allowed (https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API/Using_Fetch#checking_that_the_fetch_was_successful),
|
||||
// open the link in a new tab instead, mimicing the behavior of the `download` attribute for remote URLs,
|
||||
// which is not ideal, but a reasonable fallback.
|
||||
window.open(file_url, "_blank", "noreferrer");
|
||||
return;
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
const blob = await response.blob();
|
||||
const url = URL.createObjectURL(blob);
|
||||
const link = document.createElement("a");
|
||||
link.href = url;
|
||||
link.download = name;
|
||||
link.click();
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
|
||||
$: selected_image =
|
||||
selected_index != null && resolved_value != null
|
||||
? resolved_value[selected_index]
|
||||
: null;
|
||||
</script>
|
||||
|
||||
<svelte:window bind:innerHeight={window_height} />
|
||||
@ -182,20 +201,29 @@
|
||||
{#if show_label}
|
||||
<BlockLabel {show_label} Icon={Image} label={label || "Gallery"} />
|
||||
{/if}
|
||||
{#if value === null || _value === null || _value.length === 0}
|
||||
{#if value == null || resolved_value == null || resolved_value.length === 0}
|
||||
<Empty unpadded_box={true} size="large"><Image /></Empty>
|
||||
{:else}
|
||||
{#if selected_index !== null && allow_preview}
|
||||
{#if selected_image && allow_preview}
|
||||
<button on:keydown={on_keydown} class="preview">
|
||||
<div class="icon-buttons">
|
||||
{#if show_download_button}
|
||||
<a
|
||||
href={getHrefValue(value[selected_index])}
|
||||
target={window.__is_colab__ ? "_blank" : null}
|
||||
download="image"
|
||||
>
|
||||
<IconButton Icon={Download} label={i18n("common.download")} />
|
||||
</a>
|
||||
<div class="download-button-container">
|
||||
<IconButton
|
||||
Icon={Download}
|
||||
label={i18n("common.download")}
|
||||
on:click={() => {
|
||||
const image = selected_image?.image;
|
||||
if (image == null) {
|
||||
return;
|
||||
}
|
||||
const { url, orig_name } = image;
|
||||
if (url) {
|
||||
download(url, orig_name ?? "image");
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<ModifyUpload
|
||||
@ -207,23 +235,21 @@
|
||||
<button
|
||||
class="image-button"
|
||||
on:click={(event) => handle_preview_click(event)}
|
||||
style="height: calc(100% - {_value[selected_index].caption
|
||||
? '80px'
|
||||
: '60px'})"
|
||||
style="height: calc(100% - {selected_image.caption ? '80px' : '60px'})"
|
||||
aria-label="detailed view of selected image"
|
||||
>
|
||||
<img
|
||||
data-testid="detailed-image"
|
||||
src={_value[selected_index].image.url}
|
||||
alt={_value[selected_index].caption || ""}
|
||||
title={_value[selected_index].caption || null}
|
||||
class:with-caption={!!_value[selected_index].caption}
|
||||
src={selected_image.image.url}
|
||||
alt={selected_image.caption || ""}
|
||||
title={selected_image.caption || null}
|
||||
class:with-caption={!!selected_image.caption}
|
||||
loading="lazy"
|
||||
/>
|
||||
</button>
|
||||
{#if _value[selected_index]?.caption}
|
||||
{#if selected_image?.caption}
|
||||
<caption class="caption">
|
||||
{_value[selected_index].caption}
|
||||
{selected_image.caption}
|
||||
</caption>
|
||||
{/if}
|
||||
<div
|
||||
@ -231,13 +257,13 @@
|
||||
class="thumbnails scroll-hide"
|
||||
data-testid="container_el"
|
||||
>
|
||||
{#each _value as image, i}
|
||||
{#each resolved_value as image, i}
|
||||
<button
|
||||
bind:this={el[i]}
|
||||
on:click={() => (selected_index = i)}
|
||||
class="thumbnail-item thumbnail-small"
|
||||
class:selected={selected_index === i}
|
||||
aria-label={"Thumbnail " + (i + 1) + " of " + _value.length}
|
||||
aria-label={"Thumbnail " + (i + 1) + " of " + resolved_value.length}
|
||||
>
|
||||
<img
|
||||
src={image.image.url}
|
||||
@ -268,17 +294,17 @@
|
||||
{i18n}
|
||||
on:share
|
||||
on:error
|
||||
value={_value}
|
||||
value={resolved_value}
|
||||
formatter={format_gallery_for_sharing}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
{#each _value as entry, i}
|
||||
{#each resolved_value as entry, i}
|
||||
<button
|
||||
class="thumbnail-item thumbnail-lg"
|
||||
class:selected={selected_index === i}
|
||||
on:click={() => (selected_index = i)}
|
||||
aria-label={"Thumbnail " + (i + 1) + " of " + _value.length}
|
||||
aria-label={"Thumbnail " + (i + 1) + " of " + resolved_value.length}
|
||||
>
|
||||
<img
|
||||
alt={entry.caption || ""}
|
||||
@ -465,7 +491,7 @@
|
||||
right: 0;
|
||||
}
|
||||
|
||||
.icon-buttons a {
|
||||
.icon-buttons .download-button-container {
|
||||
margin: var(--size-1) 0;
|
||||
}
|
||||
</style>
|
||||
|
@ -2117,36 +2117,73 @@ class TestGallery:
|
||||
def test_postprocess(self):
|
||||
url = "https://huggingface.co/Norod78/SDXL-VintageMagStyle-Lora/resolve/main/Examples/00015-20230906102032-7778-Wonderwoman VintageMagStyle _lora_SDXL-VintageMagStyle-Lora_1_, Very detailed, clean, high quality, sharp image.jpg"
|
||||
gallery = gr.Gallery([url])
|
||||
assert gallery.get_config()["value"][0]["image"]["path"] == url
|
||||
|
||||
@patch("uuid.uuid4", return_value="my-uuid")
|
||||
def test_gallery(self, mock_uuid):
|
||||
gallery = gr.Gallery()
|
||||
test_file_dir = Path(Path(__file__).parent, "test_files")
|
||||
[
|
||||
client_utils.encode_file_to_base64(Path(test_file_dir, "bus.png")),
|
||||
client_utils.encode_file_to_base64(Path(test_file_dir, "cheetah1.jpg")),
|
||||
]
|
||||
|
||||
postprocessed_gallery = gallery.postprocess(
|
||||
[Path("test/test_files/bus.png")]
|
||||
).model_dump()
|
||||
processed_gallery = [
|
||||
assert gallery.get_config()["value"] == [
|
||||
{
|
||||
"image": {
|
||||
"path": "bus.png",
|
||||
"orig_name": None,
|
||||
"path": url,
|
||||
"orig_name": "00015-20230906102032-7778-Wonderwoman VintageMagStyle _lora_SDXL-VintageMagStyle-Lora_1_, Very detailed, clean, high quality, sharp image.jpg",
|
||||
"mime_type": None,
|
||||
"size": None,
|
||||
"url": url,
|
||||
},
|
||||
"caption": None,
|
||||
}
|
||||
]
|
||||
|
||||
def test_gallery(self):
|
||||
gallery = gr.Gallery()
|
||||
Path(Path(__file__).parent, "test_files")
|
||||
|
||||
postprocessed_gallery = gallery.postprocess(
|
||||
[
|
||||
("test/test_files/foo.png", "foo_caption"),
|
||||
(Path("test/test_files/bar.png"), "bar_caption"),
|
||||
"test/test_files/baz.png",
|
||||
Path("test/test_files/qux.png"),
|
||||
]
|
||||
).model_dump()
|
||||
assert postprocessed_gallery == [
|
||||
{
|
||||
"image": {
|
||||
"path": "test/test_files/foo.png",
|
||||
"orig_name": "foo.png",
|
||||
"mime_type": None,
|
||||
"size": None,
|
||||
"url": None,
|
||||
},
|
||||
"caption": "foo_caption",
|
||||
},
|
||||
{
|
||||
"image": {
|
||||
"path": "test/test_files/bar.png",
|
||||
"orig_name": "bar.png",
|
||||
"mime_type": None,
|
||||
"size": None,
|
||||
"url": None,
|
||||
},
|
||||
"caption": "bar_caption",
|
||||
},
|
||||
{
|
||||
"image": {
|
||||
"path": "test/test_files/baz.png",
|
||||
"orig_name": "baz.png",
|
||||
"mime_type": None,
|
||||
"size": None,
|
||||
"url": None,
|
||||
},
|
||||
"caption": None,
|
||||
}
|
||||
},
|
||||
{
|
||||
"image": {
|
||||
"path": "test/test_files/qux.png",
|
||||
"orig_name": "qux.png",
|
||||
"mime_type": None,
|
||||
"size": None,
|
||||
"url": None,
|
||||
},
|
||||
"caption": None,
|
||||
},
|
||||
]
|
||||
postprocessed_gallery[0]["image"]["path"] = os.path.basename(
|
||||
postprocessed_gallery[0]["image"]["path"]
|
||||
)
|
||||
assert processed_gallery == postprocessed_gallery
|
||||
|
||||
|
||||
class TestState:
|
||||
|
Loading…
x
Reference in New Issue
Block a user