diff --git a/.changeset/eleven-suits-itch.md b/.changeset/eleven-suits-itch.md
new file mode 100644
index 0000000000..e38f4b3fe4
--- /dev/null
+++ b/.changeset/eleven-suits-itch.md
@@ -0,0 +1,7 @@
+---
+"@gradio/gallery": patch
+"@gradio/image": patch
+"gradio": patch
+---
+
+fix:Allow displaying SVG images securely in `gr.Image` and `gr.Gallery` components
diff --git a/gradio/components/gallery.py b/gradio/components/gallery.py
index e7a8759452..97a38af1cc 100644
--- a/gradio/components/gallery.py
+++ b/gradio/components/gallery.py
@@ -12,7 +12,7 @@ from typing import (
Optional,
Union,
)
-from urllib.parse import urlparse
+from urllib.parse import quote, urlparse
import numpy as np
import PIL.Image
@@ -21,9 +21,9 @@ from gradio_client import utils as client_utils
from gradio_client.documentation import document
from gradio_client.utils import is_http_url_like
-from gradio import processing_utils, utils, wasm_utils
+from gradio import image_utils, processing_utils, utils, wasm_utils
from gradio.components.base import Component
-from gradio.data_classes import FileData, GradioModel, GradioRootModel
+from gradio.data_classes import FileData, GradioModel, GradioRootModel, ImageData
from gradio.events import Events
from gradio.exceptions import Error
@@ -35,7 +35,7 @@ CaptionedGalleryMediaType = tuple[GalleryMediaType, str]
class GalleryImage(GradioModel):
- image: FileData
+ image: ImageData
caption: Optional[str] = None
@@ -188,7 +188,7 @@ class Gallery(Component):
if isinstance(gallery_element, GalleryVideo):
file_path = gallery_element.video.path
else:
- file_path = gallery_element.image.path
+ file_path = gallery_element.image.path or ""
if self.file_types and not client_utils.is_valid_file(
file_path, self.file_types
):
@@ -216,6 +216,10 @@ class Gallery(Component):
"""
if value is None:
return GalleryData(root=[])
+ if isinstance(value, str):
+ raise ValueError(
+ "The `value` passed into `gr.Gallery` must be a list of images or videos, or list of (media, caption) tuples."
+ )
output = []
def _save(img):
@@ -236,14 +240,20 @@ class Gallery(Component):
)
file_path = str(utils.abspath(file))
elif isinstance(img, str):
- file_path = img
- mime_type = client_utils.get_mimetype(file_path)
- if is_http_url_like(img):
+ mime_type = client_utils.get_mimetype(img)
+ if img.lower().endswith(".svg"):
+ svg_content = image_utils.extract_svg_content(img)
+ orig_name = Path(img).name
+ url = f"data:image/svg+xml,{quote(svg_content)}"
+ file_path = None
+ elif is_http_url_like(img):
url = img
orig_name = Path(urlparse(img).path).name
+ file_path = img
else:
url = None
orig_name = Path(img).name
+ file_path = img
elif isinstance(img, Path):
file_path = str(img)
orig_name = img.name
@@ -253,7 +263,7 @@ class Gallery(Component):
if mime_type is not None and "video" in mime_type:
return GalleryVideo(
video=FileData(
- path=file_path,
+ path=file_path, # type: ignore
url=url,
orig_name=orig_name,
mime_type=mime_type,
@@ -262,7 +272,7 @@ class Gallery(Component):
)
else:
return GalleryImage(
- image=FileData(
+ image=ImageData(
path=file_path,
url=url,
orig_name=orig_name,
diff --git a/gradio/components/image.py b/gradio/components/image.py
index da1a0534c7..dc7335ee00 100644
--- a/gradio/components/image.py
+++ b/gradio/components/image.py
@@ -5,18 +5,18 @@ from __future__ import annotations
import warnings
from collections.abc import Callable, Sequence
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Optional, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
+from urllib.parse import quote
import numpy as np
import PIL.Image
from gradio_client import handle_file
from gradio_client.documentation import document
from PIL import ImageOps
-from pydantic import ConfigDict, Field
from gradio import image_utils, utils
from gradio.components.base import Component, StreamingInput
-from gradio.data_classes import GradioModel
+from gradio.data_classes import Base64ImageData, ImageData
from gradio.events import Events
from gradio.exceptions import Error
@@ -26,28 +26,6 @@ if TYPE_CHECKING:
PIL.Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843
-class ImageData(GradioModel):
- path: Optional[str] = Field(default=None, description="Path to a local file")
- url: Optional[str] = Field(
- default=None, description="Publicly available url or base64 encoded image"
- )
- size: Optional[int] = Field(default=None, description="Size of image in bytes")
- orig_name: Optional[str] = Field(default=None, description="Original filename")
- mime_type: Optional[str] = Field(default=None, description="mime type of image")
- is_stream: bool = Field(default=False, description="Can always be set to False")
- meta: dict = {"_type": "gradio.FileData"}
-
- model_config = ConfigDict(
- json_schema_extra={
- "description": "For input, either path or url must be provided. For output, path is always provided."
- }
- )
-
-
-class Base64ImageData(GradioModel):
- url: str = Field(description="base64 encoded image")
-
-
@document()
class Image(StreamingInput, Component):
"""
@@ -112,7 +90,7 @@ class Image(StreamingInput, Component):
width: The width of the component, specified in pixels if a number is passed, or in CSS units if a string is passed. This has no effect on the preprocessed image file or numpy array, but will affect the displayed image.
image_mode: The pixel format and color depth that the image should be loaded and preprocessed as. "RGB" will load the image as a color image, or "L" as black-and-white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. This parameter has no effect on SVG or GIF files. If set to None, the image_mode will be inferred from the image file type (e.g. "RGBA" for a .png image, "RGB" in most other cases).
sources: List of sources for the image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "clipboard" allows users to paste an image from the clipboard. If None, defaults to ["upload", "webcam", "clipboard"] if streaming is False, otherwise defaults to ["webcam"].
- type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. To support animated GIFs in input, the `type` should be set to "filepath" or "pil".
+ type: The format the image is converted before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image. To support animated GIFs in input, the `type` should be set to "filepath" or "pil". To support SVGs, the `type` should be set to "filepath".
label: the label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.
every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.
@@ -198,7 +176,7 @@ class Image(StreamingInput, Component):
Parameters:
payload: image data in the form of a FileData object
Returns:
- Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`. For SVGs, the `type` parameter is ignored and the filepath of the SVG is returned.
+ Passes the uploaded image as a `numpy.array`, `PIL.Image` or `str` filepath depending on `type`.
"""
if payload is None:
return payload
@@ -227,7 +205,7 @@ class Image(StreamingInput, Component):
if suffix.lower() == "svg":
if self.type == "filepath":
return str(file_path)
- raise Error("SVG files are not supported as input images.")
+ raise Error("SVG files are not supported as input images for this app.")
im = PIL.Image.open(file_path)
if self.type == "filepath" and (self.image_mode in [None, im.mode]):
@@ -267,7 +245,11 @@ class Image(StreamingInput, Component):
if value is None:
return None
if isinstance(value, str) and value.lower().endswith(".svg"):
- return ImageData(path=value, orig_name=Path(value).name)
+ svg_content = image_utils.extract_svg_content(value)
+ return ImageData(
+ orig_name=Path(value).name,
+ url=f"data:image/svg+xml,{quote(svg_content)}",
+ )
if self.streaming:
if isinstance(value, np.ndarray):
return Base64ImageData(
diff --git a/gradio/data_classes.py b/gradio/data_classes.py
index a04ea90c0d..c324be3001 100644
--- a/gradio/data_classes.py
+++ b/gradio/data_classes.py
@@ -24,6 +24,8 @@ from gradio_client.documentation import document
from gradio_client.utils import is_file_obj_with_meta, traverse
from pydantic import (
BaseModel,
+ ConfigDict,
+ Field,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
RootModel,
@@ -391,3 +393,25 @@ class MediaStreamChunk(TypedDict):
duration: float
extension: str
id: NotRequired[str]
+
+
+class ImageData(GradioModel):
+ path: Optional[str] = Field(default=None, description="Path to a local file")
+ url: Optional[str] = Field(
+ default=None, description="Publicly available url or base64 encoded image"
+ )
+ size: Optional[int] = Field(default=None, description="Size of image in bytes")
+ orig_name: Optional[str] = Field(default=None, description="Original filename")
+ mime_type: Optional[str] = Field(default=None, description="mime type of image")
+ is_stream: bool = Field(default=False, description="Can always be set to False")
+ meta: dict = {"_type": "gradio.FileData"}
+
+ model_config = ConfigDict(
+ json_schema_extra={
+ "description": "For input, either path or url must be provided. For output, path is always provided."
+ }
+ )
+
+
+class Base64ImageData(GradioModel):
+ url: str = Field(description="base64 encoded image")
diff --git a/gradio/image_utils.py b/gradio/image_utils.py
index a0a40efcfe..f71fdf8c3f 100644
--- a/gradio/image_utils.py
+++ b/gradio/image_utils.py
@@ -5,9 +5,10 @@ from io import BytesIO
from pathlib import Path
from typing import Literal, cast
+import httpx
import numpy as np
import PIL.Image
-from gradio_client.utils import get_mimetype
+from gradio_client.utils import get_mimetype, is_http_url_like
from PIL import ImageOps
from gradio import processing_utils
@@ -152,3 +153,22 @@ def encode_image_file_to_base64(image_file: str | Path) -> str:
bytes_data = f.read()
base64_str = str(base64.b64encode(bytes_data), "utf-8")
return f"data:{mime_type};base64," + base64_str
+
+
+def extract_svg_content(image_file: str | Path) -> str:
+ """
+ Provided a path or URL to an SVG file, return the SVG content as a string.
+ Parameters:
+ image_file: Local file path or URL to an SVG file
+ Returns:
+ str: The SVG content as a string
+ """
+ image_file = str(image_file)
+ if is_http_url_like(image_file):
+ response = httpx.get(image_file)
+ response.raise_for_status() # Raise an error for bad status codes
+ return response.text
+ else:
+ with open(image_file) as file:
+ svg_content = file.read()
+ return svg_content
diff --git a/js/gallery/Index.svelte b/js/gallery/Index.svelte
index c23f0a7d23..623c61f362 100644
--- a/js/gallery/Index.svelte
+++ b/js/gallery/Index.svelte
@@ -4,6 +4,7 @@
gradio.client.upload(...args)}
stream_handler={(...args) => gradio.client.stream(...args)}
- on:upload={(e) => {
+ on:upload={async (e) => {
const files = Array.isArray(e.detail) ? e.detail : [e.detail];
- value = files.map((x) =>
- x.mime_type?.includes("video")
- ? { video: x, caption: null }
- : { image: x, caption: null }
- );
+ value = await process_upload_files(files);
gradio.dispatch("upload", value);
}}
on:error={({ detail }) => {
diff --git a/js/image/shared/ImageUploader.svelte b/js/image/shared/ImageUploader.svelte
index 415df761c8..b1e6cab946 100644
--- a/js/image/shared/ImageUploader.svelte
+++ b/js/image/shared/ImageUploader.svelte
@@ -45,10 +45,20 @@
export let webcam_constraints: { [key: string]: any } | undefined = undefined;
- function handle_upload({ detail }: CustomEvent): void {
- // only trigger streaming event if streaming
+ async function handle_upload({
+ detail
+ }: CustomEvent): Promise {
if (!streaming) {
- value = detail;
+ if (detail.path?.toLowerCase().endsWith(".svg") && detail.url) {
+ const response = await fetch(detail.url);
+ const svgContent = await response.text();
+ value = {
+ ...detail,
+ url: `data:image/svg+xml,${encodeURIComponent(svgContent)}`
+ };
+ } else {
+ value = detail;
+ }
dispatch("upload");
}
}
diff --git a/test/components/test_gallery.py b/test/components/test_gallery.py
index eac9ce763a..a5d79a6a77 100644
--- a/test/components/test_gallery.py
+++ b/test/components/test_gallery.py
@@ -5,7 +5,7 @@ import PIL
import gradio as gr
from gradio.components.gallery import GalleryImage
-from gradio.data_classes import FileData
+from gradio.data_classes import ImageData
class TestGallery:
@@ -96,7 +96,7 @@ class TestGallery:
from gradio.components.gallery import GalleryData, GalleryImage
gallery = gr.Gallery()
- img = GalleryImage(image=FileData(path="test/test_files/bus.png"))
+ img = GalleryImage(image=ImageData(path="test/test_files/bus.png"))
data = GalleryData(root=[img])
assert (preprocessed := gallery.preprocess(data))
@@ -115,7 +115,7 @@ class TestGallery:
)
img_captions = GalleryImage(
- image=FileData(path="test/test_files/bus.png"), caption="bus"
+ image=ImageData(path="test/test_files/bus.png"), caption="bus"
)
data = GalleryData(root=[img_captions])
assert (preprocess := gr.Gallery().preprocess(data))
@@ -127,4 +127,6 @@ class TestGallery:
[np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)]
)
if isinstance(output.root[0], GalleryImage):
- assert output.root[0].image.path.endswith(".jpeg")
+ assert output.root[0].image.path and output.root[0].image.path.endswith(
+ ".jpeg"
+ )
diff --git a/test/test_files/file_icon.svg b/test/test_files/file_icon.svg
new file mode 100644
index 0000000000..8855359467
--- /dev/null
+++ b/test/test_files/file_icon.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/test/test_image_utils.py b/test/test_image_utils.py
new file mode 100644
index 0000000000..7666c3792d
--- /dev/null
+++ b/test/test_image_utils.py
@@ -0,0 +1,26 @@
+from gradio.image_utils import extract_svg_content
+
+
+def test_extract_svg_content_local_file():
+ svg_path = "test/test_files/file_icon.svg"
+ svg_content = extract_svg_content(svg_path)
+ assert (
+ svg_content
+ == ''
+ )
+
+
+def test_extract_svg_content_from_url(monkeypatch):
+ class MockResponse:
+ def __init__(self):
+ self.text = ""
+
+ def raise_for_status(self):
+ pass
+
+ def mock_get(*args, **kwargs):
+ return MockResponse()
+
+ monkeypatch.setattr("httpx.get", mock_get)
+ svg_content = extract_svg_content("https://example.com/test.svg")
+ assert svg_content == ""