Allows sources to be a string for gr.Image (#6378)

* fixes

* components

* add changeset

* sources

* video

* add changeset

* fixes

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2023-11-11 15:18:02 -08:00 committed by GitHub
parent c55f927f5b
commit d31d8c6ad8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 27 additions and 18 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
fix:Allows `sources` to be a string for `gr.Image`

View File

@ -108,19 +108,21 @@ class Audio(
waveform_options: A dictionary of options for the waveform display. Options include: waveform_color (str), waveform_progress_color (str), show_controls (bool), skip_length (int). Default is None, which uses the default values for these options.
"""
valid_sources: list[Literal["upload", "microphone"]] = ["upload", "microphone"]
if sources is None:
sources = ["microphone"] if streaming else valid_sources
self.sources = ["microphone"] if streaming else valid_sources
elif isinstance(sources, str) and sources in valid_sources:
sources = [sources]
self.sources = [sources]
elif isinstance(sources, list):
pass
self.sources = sources
else:
raise ValueError(
f"`sources` must be a list consisting of elements in {valid_sources}"
)
self.sources = sources
for source in self.sources:
if source not in valid_sources:
raise ValueError(
f"`sources` must a list consisting of elements in {valid_sources}"
)
valid_types = ["numpy", "filepath"]
if type not in valid_types:
raise ValueError(

View File

@ -116,8 +116,6 @@ class Image(StreamingInput, Component):
raise ValueError(
f"`sources` must a list consisting of elements in {valid_sources}"
)
self.sources = sources
self.streaming = streaming
self.show_download_button = show_download_button
if streaming and self.sources != ["webcam"]:

View File

@ -111,27 +111,29 @@ class Video(Component):
min_length: The minimum length of video (in seconds) that the user can pass into the prediction function. If None, there is no minimum length.
max_length: The maximum length of video (in seconds) that the user can pass into the prediction function. If None, there is no maximum length.
"""
self.format = format
self.autoplay = autoplay
valid_sources: list[Literal["upload", "webcam"]] = ["webcam", "upload"]
if sources is None:
sources = valid_sources
self.sources = valid_sources
elif isinstance(sources, str) and sources in valid_sources:
sources = [sources]
self.sources = [sources]
elif isinstance(sources, list):
pass
self.sources = sources
else:
raise ValueError(
f"`sources` must be a list consisting of elements in {valid_sources}"
)
self.sources = sources
for source in self.sources:
if source not in valid_sources:
raise ValueError(
f"`sources` must a list consisting of elements in {valid_sources}"
)
self.format = format
self.autoplay = autoplay
self.height = height
self.width = width
self.mirror_webcam = mirror_webcam
self.include_audio = (
include_audio if include_audio is not None else "upload" in sources
include_audio if include_audio is not None else "upload" in self.sources
)
self.show_share_button = (
(utils.get_space() is not None)

View File

@ -572,7 +572,7 @@ class TestImage:
image_input = gr.Image(type="pil", label="Upload Your Image")
assert image_input.get_config() == {
"image_mode": "RGB",
"sources": None,
"sources": ["upload", "webcam", "clipboard"],
"name": "image",
"show_share_button": False,
"show_download_button": True,
@ -604,6 +604,8 @@ class TestImage:
with pytest.raises(ValueError):
gr.Image(type="unknown")
string_source = gr.Image(sources="upload")
assert string_source.sources == ["upload"]
# Output functionalities
image_output = gr.Image(type="pil")
processed_image = image_output.postprocess(