mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-11 11:19:58 +08:00
Allow displaying SVG images securely in gr.Image
and gr.Gallery
components (#10269)
* changes * changes * add changeset * changes * add changeset * changes * changes * changes * add changeset * add changeset * add changeset * format fe * changes * changes * changes * revert * revert more * revert * add changeset * more changes * add changeset * changes * add changeset * format * add changeset * changes * changes * svg * changes * format * add changeset * fix tests --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
99123e75f5
commit
890eaa3a9e
7
.changeset/eleven-suits-itch.md
Normal file
7
.changeset/eleven-suits-itch.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/gallery": patch
|
||||
"@gradio/image": patch
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Allow displaying SVG images securely in `gr.Image` and `gr.Gallery` components
|
@ -12,7 +12,7 @@ from typing import (
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@ -21,9 +21,9 @@ from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document
|
||||
from gradio_client.utils import is_http_url_like
|
||||
|
||||
from gradio import processing_utils, utils, wasm_utils
|
||||
from gradio import image_utils, processing_utils, utils, wasm_utils
|
||||
from gradio.components.base import Component
|
||||
from gradio.data_classes import FileData, GradioModel, GradioRootModel
|
||||
from gradio.data_classes import FileData, GradioModel, GradioRootModel, ImageData
|
||||
from gradio.events import Events
|
||||
from gradio.exceptions import Error
|
||||
|
||||
@ -35,7 +35,7 @@ CaptionedGalleryMediaType = tuple[GalleryMediaType, str]
|
||||
|
||||
|
||||
class GalleryImage(GradioModel):
|
||||
image: FileData
|
||||
image: ImageData
|
||||
caption: Optional[str] = None
|
||||
|
||||
|
||||
@ -188,7 +188,7 @@ class Gallery(Component):
|
||||
if isinstance(gallery_element, GalleryVideo):
|
||||
file_path = gallery_element.video.path
|
||||
else:
|
||||
file_path = gallery_element.image.path
|
||||
file_path = gallery_element.image.path or ""
|
||||
if self.file_types and not client_utils.is_valid_file(
|
||||
file_path, self.file_types
|
||||
):
|
||||
@ -216,6 +216,10 @@ class Gallery(Component):
|
||||
"""
|
||||
if value is None:
|
||||
return GalleryData(root=[])
|
||||
if isinstance(value, str):
|
||||
raise ValueError(
|
||||
"The `value` passed into `gr.Gallery` must be a list of images or videos, or list of (media, caption) tuples."
|
||||
)
|
||||
output = []
|
||||
|
||||
def _save(img):
|
||||
@ -236,14 +240,20 @@ class Gallery(Component):
|
||||
)
|
||||
file_path = str(utils.abspath(file))
|
||||
elif isinstance(img, str):
|
||||
file_path = img
|
||||
mime_type = client_utils.get_mimetype(file_path)
|
||||
if is_http_url_like(img):
|
||||
mime_type = client_utils.get_mimetype(img)
|
||||
if img.lower().endswith(".svg"):
|
||||
svg_content = image_utils.extract_svg_content(img)
|
||||
orig_name = Path(img).name
|
||||
url = f"data:image/svg+xml,{quote(svg_content)}"
|
||||
file_path = None
|
||||
elif is_http_url_like(img):
|
||||
url = img
|
||||
orig_name = Path(urlparse(img).path).name
|
||||
file_path = img
|
||||
else:
|
||||
url = None
|
||||
orig_name = Path(img).name
|
||||
file_path = img
|
||||
elif isinstance(img, Path):
|
||||
file_path = str(img)
|
||||
orig_name = img.name
|
||||
@ -253,7 +263,7 @@ class Gallery(Component):
|
||||
if mime_type is not None and "video" in mime_type:
|
||||
return GalleryVideo(
|
||||
video=FileData(
|
||||
path=file_path,
|
||||
path=file_path, # type: ignore
|
||||
url=url,
|
||||
orig_name=orig_name,
|
||||
mime_type=mime_type,
|
||||
@ -262,7 +272,7 @@ class Gallery(Component):
|
||||
)
|
||||
else:
|
||||
return GalleryImage(
|
||||
image=FileData(
|
||||
image=ImageData(
|
||||
path=file_path,
|
||||
url=url,
|
||||
orig_name=orig_name,
|
||||
|
@ -5,18 +5,18 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from urllib.parse import quote
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
from PIL import ImageOps
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from gradio import image_utils, utils
|
||||
from gradio.components.base import Component, StreamingInput
|
||||
from gradio.data_classes import GradioModel
|
||||
from gradio.data_classes import Base64ImageData, ImageData
|
||||
from gradio.events import Events
|
||||
from gradio.exceptions import Error
|
||||
|
||||
@ -26,28 +26,6 @@ if TYPE_CHECKING:
|
||||
PIL.Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843
|
||||
|
||||
|
||||
class ImageData(GradioModel):
|
||||
path: Optional[str] = Field(default=None, description="Path to a local file")
|
||||
url: Optional[str] = Field(
|
||||
default=None, description="Publicly available url or base64 encoded image"
|
||||
)
|
||||
size: Optional[int] = Field(default=None, description="Size of image in bytes")
|
||||
orig_name: Optional[str] = Field(default=None, description="Original filename")
|
||||
mime_type: Optional[str] = Field(default=None, description="mime type of image")
|
||||
is_stream: bool = Field(default=False, description="Can always be set to False")
|
||||
meta: dict = {"_type": "gradio.FileData"}
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"description": "For input, either path or url must be provided. For output, path is always provided."
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Base64ImageData(GradioModel):
|
||||
url: str = Field(description="base64 encoded image")
|
||||
|
||||
|
||||
@document()
|
||||
class Image(StreamingInput, Component):
|
||||
"""
|
||||
@ -112,7 +90,7 @@ class Image(StreamingInput, Component):
|
||||
width: The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed image file or numpy array, but will affect the displayed image.
|
||||
image_mode: The pixel format and color depth that the image should be loaded and preprocessed as. "RGB" will load the image as a color image, or "L" as black-and-white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. This parameter has no effect on SVG or GIF files. If set to None, the image_mode will be inferred from the image file type (e.g. "RGBA" for a .png image, "RGB" in most other cases).
|
||||
sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. If None, defaults to ["upload", "webcam", "clipboard"] if streaming is False, otherwise defaults to ["webcam"].
|
||||
type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. To support animated GIFs in input, the `type` should be set to "filepath" or "pil".
|
||||
type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. To support animated GIFs in input, the `type` should be set to "filepath" or "pil". To support SVGs, the `type` should be set to "filepath".
|
||||
label: the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.
|
||||
every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
|
||||
inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.
|
||||
@ -198,7 +176,7 @@ class Image(StreamingInput, Component):
|
||||
Parameters:
|
||||
payload: image data in the form of a FileData object
|
||||
Returns:
|
||||
Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`. For SVGs, the `type` parameter is ignored and the filepath of the SVG is returned.
|
||||
Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`.
|
||||
"""
|
||||
if payload is None:
|
||||
return payload
|
||||
@ -227,7 +205,7 @@ class Image(StreamingInput, Component):
|
||||
if suffix.lower() == "svg":
|
||||
if self.type == "filepath":
|
||||
return str(file_path)
|
||||
raise Error("SVG files are not supported as input images.")
|
||||
raise Error("SVG files are not supported as input images for this app.")
|
||||
|
||||
im = PIL.Image.open(file_path)
|
||||
if self.type == "filepath" and (self.image_mode in [None, im.mode]):
|
||||
@ -267,7 +245,11 @@ class Image(StreamingInput, Component):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str) and value.lower().endswith(".svg"):
|
||||
return ImageData(path=value, orig_name=Path(value).name)
|
||||
svg_content = image_utils.extract_svg_content(value)
|
||||
return ImageData(
|
||||
orig_name=Path(value).name,
|
||||
url=f"data:image/svg+xml,{quote(svg_content)}",
|
||||
)
|
||||
if self.streaming:
|
||||
if isinstance(value, np.ndarray):
|
||||
return Base64ImageData(
|
||||
|
@ -24,6 +24,8 @@ from gradio_client.documentation import document
|
||||
from gradio_client.utils import is_file_obj_with_meta, traverse
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
GetCoreSchemaHandler,
|
||||
GetJsonSchemaHandler,
|
||||
RootModel,
|
||||
@ -391,3 +393,25 @@ class MediaStreamChunk(TypedDict):
|
||||
duration: float
|
||||
extension: str
|
||||
id: NotRequired[str]
|
||||
|
||||
|
||||
class ImageData(GradioModel):
|
||||
path: Optional[str] = Field(default=None, description="Path to a local file")
|
||||
url: Optional[str] = Field(
|
||||
default=None, description="Publicly available url or base64 encoded image"
|
||||
)
|
||||
size: Optional[int] = Field(default=None, description="Size of image in bytes")
|
||||
orig_name: Optional[str] = Field(default=None, description="Original filename")
|
||||
mime_type: Optional[str] = Field(default=None, description="mime type of image")
|
||||
is_stream: bool = Field(default=False, description="Can always be set to False")
|
||||
meta: dict = {"_type": "gradio.FileData"}
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"description": "For input, either path or url must be provided. For output, path is always provided."
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Base64ImageData(GradioModel):
|
||||
url: str = Field(description="base64 encoded image")
|
||||
|
@ -5,9 +5,10 @@ from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal, cast
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client.utils import get_mimetype
|
||||
from gradio_client.utils import get_mimetype, is_http_url_like
|
||||
from PIL import ImageOps
|
||||
|
||||
from gradio import processing_utils
|
||||
@ -152,3 +153,22 @@ def encode_image_file_to_base64(image_file: str | Path) -> str:
|
||||
bytes_data = f.read()
|
||||
base64_str = str(base64.b64encode(bytes_data), "utf-8")
|
||||
return f"data:{mime_type};base64," + base64_str
|
||||
|
||||
|
||||
def extract_svg_content(image_file: str | Path) -> str:
|
||||
"""
|
||||
Provided a path or URL to an SVG file, return the SVG content as a string.
|
||||
Parameters:
|
||||
image_file: Local file path or URL to an SVG file
|
||||
Returns:
|
||||
str: The SVG content as a string
|
||||
"""
|
||||
image_file = str(image_file)
|
||||
if is_http_url_like(image_file):
|
||||
response = httpx.get(image_file)
|
||||
response.raise_for_status() # Raise an error for bad status codes
|
||||
return response.text
|
||||
else:
|
||||
with open(image_file) as file:
|
||||
svg_content = file.read()
|
||||
return svg_content
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
<script lang="ts">
|
||||
import type { GalleryImage, GalleryVideo } from "./types";
|
||||
import type { FileData } from "@gradio/client";
|
||||
import type { Gradio, ShareData, SelectData } from "@gradio/utils";
|
||||
import { Block, UploadText } from "@gradio/atoms";
|
||||
import Gallery from "./shared/Gallery.svelte";
|
||||
@ -52,6 +53,30 @@
|
||||
|
||||
$: no_value = value === null ? true : value.length === 0;
|
||||
$: selected_index, dispatch("prop_change", { selected_index });
|
||||
|
||||
async function process_upload_files(
|
||||
files: FileData[]
|
||||
): Promise<GalleryData[]> {
|
||||
const processed_files = await Promise.all(
|
||||
files.map(async (x) => {
|
||||
if (x.path?.toLowerCase().endsWith(".svg") && x.url) {
|
||||
const response = await fetch(x.url);
|
||||
const svgContent = await response.text();
|
||||
return {
|
||||
...x,
|
||||
url: `data:image/svg+xml,${encodeURIComponent(svgContent)}`
|
||||
};
|
||||
}
|
||||
return x;
|
||||
})
|
||||
);
|
||||
|
||||
return processed_files.map((x) =>
|
||||
x.mime_type?.includes("video")
|
||||
? { video: x, caption: null }
|
||||
: { image: x, caption: null }
|
||||
);
|
||||
}
|
||||
</script>
|
||||
|
||||
<Block
|
||||
@ -83,13 +108,9 @@
|
||||
i18n={gradio.i18n}
|
||||
upload={(...args) => gradio.client.upload(...args)}
|
||||
stream_handler={(...args) => gradio.client.stream(...args)}
|
||||
on:upload={(e) => {
|
||||
on:upload={async (e) => {
|
||||
const files = Array.isArray(e.detail) ? e.detail : [e.detail];
|
||||
value = files.map((x) =>
|
||||
x.mime_type?.includes("video")
|
||||
? { video: x, caption: null }
|
||||
: { image: x, caption: null }
|
||||
);
|
||||
value = await process_upload_files(files);
|
||||
gradio.dispatch("upload", value);
|
||||
}}
|
||||
on:error={({ detail }) => {
|
||||
|
@ -45,10 +45,20 @@
|
||||
|
||||
export let webcam_constraints: { [key: string]: any } | undefined = undefined;
|
||||
|
||||
function handle_upload({ detail }: CustomEvent<FileData>): void {
|
||||
// only trigger streaming event if streaming
|
||||
async function handle_upload({
|
||||
detail
|
||||
}: CustomEvent<FileData>): Promise<void> {
|
||||
if (!streaming) {
|
||||
value = detail;
|
||||
if (detail.path?.toLowerCase().endsWith(".svg") && detail.url) {
|
||||
const response = await fetch(detail.url);
|
||||
const svgContent = await response.text();
|
||||
value = {
|
||||
...detail,
|
||||
url: `data:image/svg+xml,${encodeURIComponent(svgContent)}`
|
||||
};
|
||||
} else {
|
||||
value = detail;
|
||||
}
|
||||
dispatch("upload");
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import PIL
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components.gallery import GalleryImage
|
||||
from gradio.data_classes import FileData
|
||||
from gradio.data_classes import ImageData
|
||||
|
||||
|
||||
class TestGallery:
|
||||
@ -96,7 +96,7 @@ class TestGallery:
|
||||
from gradio.components.gallery import GalleryData, GalleryImage
|
||||
|
||||
gallery = gr.Gallery()
|
||||
img = GalleryImage(image=FileData(path="test/test_files/bus.png"))
|
||||
img = GalleryImage(image=ImageData(path="test/test_files/bus.png"))
|
||||
data = GalleryData(root=[img])
|
||||
|
||||
assert (preprocessed := gallery.preprocess(data))
|
||||
@ -115,7 +115,7 @@ class TestGallery:
|
||||
)
|
||||
|
||||
img_captions = GalleryImage(
|
||||
image=FileData(path="test/test_files/bus.png"), caption="bus"
|
||||
image=ImageData(path="test/test_files/bus.png"), caption="bus"
|
||||
)
|
||||
data = GalleryData(root=[img_captions])
|
||||
assert (preprocess := gr.Gallery().preprocess(data))
|
||||
@ -127,4 +127,6 @@ class TestGallery:
|
||||
[np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)]
|
||||
)
|
||||
if isinstance(output.root[0], GalleryImage):
|
||||
assert output.root[0].image.path.endswith(".jpeg")
|
||||
assert output.root[0].image.path and output.root[0].image.path.endswith(
|
||||
".jpeg"
|
||||
)
|
||||
|
1
test/test_files/file_icon.svg
Normal file
1
test/test_files/file_icon.svg
Normal file
@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-file"><path d="M13 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V9z"></path><polyline points="13 2 13 9 20 9"></polyline></svg>
|
After Width: | Height: | Size: 339 B |
26
test/test_image_utils.py
Normal file
26
test/test_image_utils.py
Normal file
@ -0,0 +1,26 @@
|
||||
from gradio.image_utils import extract_svg_content
|
||||
|
||||
|
||||
def test_extract_svg_content_local_file():
|
||||
svg_path = "test/test_files/file_icon.svg"
|
||||
svg_content = extract_svg_content(svg_path)
|
||||
assert (
|
||||
svg_content
|
||||
== '<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-file"><path d="M13 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V9z"></path><polyline points="13 2 13 9 20 9"></polyline></svg>'
|
||||
)
|
||||
|
||||
|
||||
def test_extract_svg_content_from_url(monkeypatch):
|
||||
class MockResponse:
|
||||
def __init__(self):
|
||||
self.text = "<svg>mock svg content</svg>"
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
def mock_get(*args, **kwargs):
|
||||
return MockResponse()
|
||||
|
||||
monkeypatch.setattr("httpx.get", mock_get)
|
||||
svg_content = extract_svg_content("https://example.com/test.svg")
|
||||
assert svg_content == "<svg>mock svg content</svg>"
|
Loading…
Reference in New Issue
Block a user