Fix the download button of the gr.Gallery() component to work (#6487)

* Fix the download button of the `gr.Gallery()` component to work

* Refactoring js/gallery/shared/Gallery.svelte

* Fix `gr.Gallery()` to set `orig_name` for URLs

* Fix Gallery.postprocess()

* Fix `download()` to fallback to `window.open()` when CORS is not allowed

* Fix `gr.Gallery` to leave  as None so it will be replaced with a local cache path and restore the `<a>` tag-based download feature on the frontend

* Align a variable name to its type name

* Fix Gallery's tests

* Fix the frontend test for gallery

* Revert "Fix `gr.Gallery` to leave  as None so it will be replaced with a local cache path and restore the `<a>` tag-based download feature on the frontend"

This reverts commit d754980cc27ded760bfc26df4310f913c2c6944a.

* Revert "Fix Gallery's tests"

This reverts commit 4e2aa3fff1ef7b586839fa6c485d1a3b8738fd03.

* Revert "Fix the frontend test for gallery"

This reverts commit 007caa23e7b9dbab36376307137a30f277fca297.

* Fix for linter

* Add a test about the download button

* Fix type defs on Gallery.postprocess

* Improve TestGallery

* add changeset

* Update gradio/components/gallery.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/components/gallery.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Revert "Update gradio/components/gallery.py"

This reverts commit 4d6e12730511fe9840a5372787ad81fd83cbb44c.

* Revert "Update gradio/components/gallery.py"

This reverts commit f2bfad0744d20e121c8eef40a335979a7c703517.

* Use `tuple` instead of `typing.Tuple`

* Revert "Use `tuple` instead of `typing.Tuple`"

This reverts commit 69ab93cad4f39fe38f0e0f88126be572bf12cecf.

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Yuichiro Tachibana (Tsuchiya) 2023-12-09 17:58:19 +01:00 committed by GitHub
parent 5177132d71
commit 9a5811df92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 171 additions and 84 deletions

View File

@ -0,0 +1,6 @@
---
"@gradio/gallery": patch
"gradio": patch
---
fix:Fix the download button of the `gr.Gallery()` component to work

View File

@ -3,7 +3,8 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Callable, List, Literal, Optional
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
from urllib.parse import urlparse
import numpy as np
from gradio_client.documentation import document, set_documentation_group
@ -18,6 +19,10 @@ from gradio.events import Events
set_documentation_group("component")
GalleryImageType = Union[np.ndarray, _Image.Image, Path, str]
CaptionedGalleryImageType = Tuple[GalleryImageType, str]
class GalleryImage(GradioModel):
image: FileData
caption: Optional[str] = None
@ -125,9 +130,7 @@ class Gallery(Component):
def postprocess(
self,
value: list[np.ndarray | _Image.Image | str]
| list[tuple[np.ndarray | _Image.Image | str, str]]
| None,
value: list[GalleryImageType | CaptionedGalleryImageType] | None,
) -> GalleryData:
"""
Parameters:
@ -141,6 +144,7 @@ class Gallery(Component):
for img in value:
url = None
caption = None
orig_name = None
if isinstance(img, (tuple, list)):
img, caption = img
if isinstance(img, np.ndarray):
@ -155,13 +159,20 @@ class Gallery(Component):
file_path = str(utils.abspath(file))
elif isinstance(img, str):
file_path = img
url = img if is_http_url_like(img) else None
if is_http_url_like(img):
url = img
orig_name = Path(urlparse(img).path).name
else:
url = None
orig_name = Path(img).name
elif isinstance(img, Path):
file_path = str(img)
orig_name = img.name
else:
raise ValueError(f"Cannot process type as image: {type(img)}")
entry = GalleryImage(
image=FileData(path=file_path, url=url), caption=caption
image=FileData(path=file_path, url=url, orig_name=orig_name),
caption=caption,
)
output.append(entry)
return GalleryData(root=output)

View File

@ -15,10 +15,17 @@ test("Gallery preview mode displays all images correctly.", async ({
).toEqual("https://gradio-builds.s3.amazonaws.com/assets/cheetah-003.jpg");
});
test("Gallery select event returns the right value", async ({ page }) => {
test("Gallery select event returns the right value and the download button works correctly", async ({
page
}) => {
await page.getByRole("button", { name: "Run" }).click();
await page.getByLabel("Thumbnail 2 of 3").click();
await expect(page.getByLabel("Select Data")).toHaveValue(
"https://gradio-builds.s3.amazonaws.com/assets/lite-logo.png"
);
const downloadPromise = page.waitForEvent("download");
await page.getByLabel("Download").click();
const download = await downloadPromise;
expect(download.suggestedFilename()).toBe("lite-logo.png");
});

View File

@ -12,11 +12,14 @@
import { IconButton } from "@gradio/atoms";
import type { I18nFormatter } from "@gradio/utils";
type GalleryImage = { image: FileData; caption: string | null };
type GalleryData = GalleryImage[];
export let show_label = true;
export let label: string;
export let root = "";
export let proxy_url: null | string = null;
export let value: { image: FileData; caption: string | null }[] | null = null;
export let value: GalleryData | null = null;
export let columns: number | number[] | undefined = [2];
export let rows: number | number[] | undefined = undefined;
export let height: number | "auto" = "auto";
@ -37,25 +40,24 @@
// tracks whether the value of the gallery was reset
let was_reset = true;
$: was_reset = value == null || value.length == 0 ? true : was_reset;
$: was_reset = value == null || value.length === 0 ? true : was_reset;
let _value: { image: FileData; caption: string | null }[] | null = null;
$: _value =
value === null
let resolved_value: GalleryData | null = null;
$: resolved_value =
value == null
? null
: value.map((data) => ({
image: normalise_file(data.image, root, proxy_url) as FileData,
caption: data.caption
}));
let prevValue: { image: FileData; caption: string | null }[] | null | null =
value;
if (selected_index === null && preview && value?.length) {
let prev_value: GalleryData | null = value;
if (selected_index == null && preview && value?.length) {
selected_index = 0;
}
let old_selected_index: number | null = selected_index;
$: if (!dequal(prevValue, value)) {
$: if (!dequal(prev_value, value)) {
// When value is falsy (clear button or first load),
// preview determines the selected image
if (was_reset) {
@ -65,19 +67,18 @@
// gallery has at least as many elements as it did before
} else {
selected_index =
selected_index !== null &&
value !== null &&
selected_index < value.length
selected_index != null && value != null && selected_index < value.length
? selected_index
: null;
}
dispatch("change");
prevValue = value;
prev_value = value;
}
$: previous =
((selected_index ?? 0) + (_value?.length ?? 0) - 1) % (_value?.length ?? 0);
$: next = ((selected_index ?? 0) + 1) % (_value?.length ?? 0);
((selected_index ?? 0) + (resolved_value?.length ?? 0) - 1) %
(resolved_value?.length ?? 0);
$: next = ((selected_index ?? 0) + 1) % (resolved_value?.length ?? 0);
function handle_preview_click(event: MouseEvent): void {
const element = event.target as HTMLElement;
@ -111,28 +112,13 @@
}
}
function isFileData(obj: any): obj is FileData {
return typeof obj === "object" && obj !== null && "data" in obj;
}
function getHrefValue(selected: any): string {
if (isFileData(selected)) {
return selected.path;
} else if (typeof selected === "string") {
return selected;
} else if (Array.isArray(selected)) {
return getHrefValue(selected[0]);
}
return "";
}
$: {
if (selected_index !== old_selected_index) {
old_selected_index = selected_index;
if (selected_index !== null) {
dispatch("select", {
index: selected_index,
value: _value?.[selected_index]
value: resolved_value?.[selected_index]
});
}
}
@ -175,6 +161,39 @@
let client_height = 0;
let window_height = 0;
// Unlike `gr.Image()`, images specified via remote URLs are not cached in the server
// and their remote URLs are directly passed to the client as `value[].image.url`.
// The `download` attribute of the <a> tag doesn't work for remote URLs (https://developer.mozilla.org/en-US/docs/Web/HTML/Element/a#download),
// so we need to download the image via JS as below.
async function download(file_url: string, name: string): Promise<void> {
let response;
try {
response = await fetch(file_url);
} catch (error) {
if (error instanceof TypeError) {
// If CORS is not allowed (https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API/Using_Fetch#checking_that_the_fetch_was_successful),
// open the link in a new tab instead, mimicing the behavior of the `download` attribute for remote URLs,
// which is not ideal, but a reasonable fallback.
window.open(file_url, "_blank", "noreferrer");
return;
}
throw error;
}
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const link = document.createElement("a");
link.href = url;
link.download = name;
link.click();
URL.revokeObjectURL(url);
}
$: selected_image =
selected_index != null && resolved_value != null
? resolved_value[selected_index]
: null;
</script>
<svelte:window bind:innerHeight={window_height} />
@ -182,20 +201,29 @@
{#if show_label}
<BlockLabel {show_label} Icon={Image} label={label || "Gallery"} />
{/if}
{#if value === null || _value === null || _value.length === 0}
{#if value == null || resolved_value == null || resolved_value.length === 0}
<Empty unpadded_box={true} size="large"><Image /></Empty>
{:else}
{#if selected_index !== null && allow_preview}
{#if selected_image && allow_preview}
<button on:keydown={on_keydown} class="preview">
<div class="icon-buttons">
{#if show_download_button}
<a
href={getHrefValue(value[selected_index])}
target={window.__is_colab__ ? "_blank" : null}
download="image"
>
<IconButton Icon={Download} label={i18n("common.download")} />
</a>
<div class="download-button-container">
<IconButton
Icon={Download}
label={i18n("common.download")}
on:click={() => {
const image = selected_image?.image;
if (image == null) {
return;
}
const { url, orig_name } = image;
if (url) {
download(url, orig_name ?? "image");
}
}}
/>
</div>
{/if}
<ModifyUpload
@ -207,23 +235,21 @@
<button
class="image-button"
on:click={(event) => handle_preview_click(event)}
style="height: calc(100% - {_value[selected_index].caption
? '80px'
: '60px'})"
style="height: calc(100% - {selected_image.caption ? '80px' : '60px'})"
aria-label="detailed view of selected image"
>
<img
data-testid="detailed-image"
src={_value[selected_index].image.url}
alt={_value[selected_index].caption || ""}
title={_value[selected_index].caption || null}
class:with-caption={!!_value[selected_index].caption}
src={selected_image.image.url}
alt={selected_image.caption || ""}
title={selected_image.caption || null}
class:with-caption={!!selected_image.caption}
loading="lazy"
/>
</button>
{#if _value[selected_index]?.caption}
{#if selected_image?.caption}
<caption class="caption">
{_value[selected_index].caption}
{selected_image.caption}
</caption>
{/if}
<div
@ -231,13 +257,13 @@
class="thumbnails scroll-hide"
data-testid="container_el"
>
{#each _value as image, i}
{#each resolved_value as image, i}
<button
bind:this={el[i]}
on:click={() => (selected_index = i)}
class="thumbnail-item thumbnail-small"
class:selected={selected_index === i}
aria-label={"Thumbnail " + (i + 1) + " of " + _value.length}
aria-label={"Thumbnail " + (i + 1) + " of " + resolved_value.length}
>
<img
src={image.image.url}
@ -268,17 +294,17 @@
{i18n}
on:share
on:error
value={_value}
value={resolved_value}
formatter={format_gallery_for_sharing}
/>
</div>
{/if}
{#each _value as entry, i}
{#each resolved_value as entry, i}
<button
class="thumbnail-item thumbnail-lg"
class:selected={selected_index === i}
on:click={() => (selected_index = i)}
aria-label={"Thumbnail " + (i + 1) + " of " + _value.length}
aria-label={"Thumbnail " + (i + 1) + " of " + resolved_value.length}
>
<img
alt={entry.caption || ""}
@ -465,7 +491,7 @@
right: 0;
}
.icon-buttons a {
.icon-buttons .download-button-container {
margin: var(--size-1) 0;
}
</style>

View File

@ -2117,36 +2117,73 @@ class TestGallery:
def test_postprocess(self):
url = "https://huggingface.co/Norod78/SDXL-VintageMagStyle-Lora/resolve/main/Examples/00015-20230906102032-7778-Wonderwoman VintageMagStyle _lora_SDXL-VintageMagStyle-Lora_1_, Very detailed, clean, high quality, sharp image.jpg"
gallery = gr.Gallery([url])
assert gallery.get_config()["value"][0]["image"]["path"] == url
@patch("uuid.uuid4", return_value="my-uuid")
def test_gallery(self, mock_uuid):
gallery = gr.Gallery()
test_file_dir = Path(Path(__file__).parent, "test_files")
[
client_utils.encode_file_to_base64(Path(test_file_dir, "bus.png")),
client_utils.encode_file_to_base64(Path(test_file_dir, "cheetah1.jpg")),
]
postprocessed_gallery = gallery.postprocess(
[Path("test/test_files/bus.png")]
).model_dump()
processed_gallery = [
assert gallery.get_config()["value"] == [
{
"image": {
"path": "bus.png",
"orig_name": None,
"path": url,
"orig_name": "00015-20230906102032-7778-Wonderwoman VintageMagStyle _lora_SDXL-VintageMagStyle-Lora_1_, Very detailed, clean, high quality, sharp image.jpg",
"mime_type": None,
"size": None,
"url": url,
},
"caption": None,
}
]
def test_gallery(self):
gallery = gr.Gallery()
Path(Path(__file__).parent, "test_files")
postprocessed_gallery = gallery.postprocess(
[
("test/test_files/foo.png", "foo_caption"),
(Path("test/test_files/bar.png"), "bar_caption"),
"test/test_files/baz.png",
Path("test/test_files/qux.png"),
]
).model_dump()
assert postprocessed_gallery == [
{
"image": {
"path": "test/test_files/foo.png",
"orig_name": "foo.png",
"mime_type": None,
"size": None,
"url": None,
},
"caption": "foo_caption",
},
{
"image": {
"path": "test/test_files/bar.png",
"orig_name": "bar.png",
"mime_type": None,
"size": None,
"url": None,
},
"caption": "bar_caption",
},
{
"image": {
"path": "test/test_files/baz.png",
"orig_name": "baz.png",
"mime_type": None,
"size": None,
"url": None,
},
"caption": None,
}
},
{
"image": {
"path": "test/test_files/qux.png",
"orig_name": "qux.png",
"mime_type": None,
"size": None,
"url": None,
},
"caption": None,
},
]
postprocessed_gallery[0]["image"]["path"] = os.path.basename(
postprocessed_gallery[0]["image"]["path"]
)
assert processed_gallery == postprocessed_gallery
class TestState: