modify preprocess to use pydantic models (#6181)

* modify preprocess to use pydantic models

* changes

* add changeset

* fix

* fix

* fix typing

* save

* revert queuing changes

* fix

* fix

* notebook

* fix

* changes

* add changeset

* fix functional tests

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2023-10-31 06:48:10 -07:00 committed by GitHub
parent e16b4abc37
commit 62ec2075cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 488 additions and 620 deletions

View File

@ -0,0 +1,6 @@
---
"@gradio/uploadbutton": minor
"gradio": minor
---
feat:modify preprocess to use pydantic models

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch(allowed_paths=[\"avatar.png\"])\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -51,4 +51,4 @@ with gr.Blocks() as demo:
demo.queue()
if __name__ == "__main__":
demo.launch()
demo.launch(allowed_paths=["avatar.png"])

View File

@ -153,6 +153,7 @@ def evaluate_values(*args):
are_false.append(a == "#000000")
else:
are_false.append(not a)
print(args)
return all(are_false)

View File

@ -341,7 +341,7 @@ class BlockFunction:
self.postprocess = postprocess
self.tracks_progress = tracks_progress
self.concurrency_limit = concurrency_limit
self.concurrency_id = concurrency_id or id(fn)
self.concurrency_id = concurrency_id or str(id(fn))
self.batch = batch
self.max_batch_size = max_batch_size
self.total_runtime = 0
@ -1260,6 +1260,14 @@ Received inputs:
inputs_cached = processing_utils.move_files_to_cache(
inputs[i], block
)
if getattr(block, "data_model", None) and inputs_cached is not None:
if issubclass(block.data_model, GradioModel): # type: ignore
print("block.data_model", block.data_model, block)
print("1inputs_cached", inputs_cached)
inputs_cached = block.data_model(**inputs_cached) # type: ignore
elif issubclass(block.data_model, GradioRootModel): # type: ignore
print("2inputs_cached", inputs_cached)
inputs_cached = block.data_model(root=inputs_cached) # type: ignore
processed_input.append(block.preprocess(inputs_cached))
else:
processed_input = inputs

View File

@ -103,20 +103,21 @@ class AnnotatedImage(Component):
def postprocess(
self,
y: tuple[
value: tuple[
np.ndarray | _Image.Image | str,
list[tuple[np.ndarray | tuple[int, int, int, int], str]],
],
]
| None,
) -> AnnotatedImageData | None:
"""
Parameters:
y: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label.
value: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label.
Returns:
Tuple of base image file and list of subsections, with each subsection a two-part tuple where the first element image path of the mask, and the second element is the label.
"""
if y is None:
if value is None:
return None
base_img = y[0]
base_img = value[0]
if isinstance(base_img, str):
base_img_path = base_img
base_img = np.array(_Image.open(base_img))
@ -144,7 +145,7 @@ class AnnotatedImage(Component):
lv = len(value)
return [int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)]
for mask, label in y[1]:
for mask, label in value[1]:
mask_array = np.zeros((base_img.shape[0], base_img.shape[1]))
if isinstance(mask, np.ndarray):
mask_array = mask
@ -188,5 +189,7 @@ class AnnotatedImage(Component):
def example_inputs(self) -> Any:
return {}
def preprocess(self, x: Any) -> Any:
return x
def preprocess(
self, payload: AnnotatedImageData | None
) -> AnnotatedImageData | None:
return payload

View File

@ -162,20 +162,12 @@ class Audio(
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
def preprocess(
self, x: dict[str, Any] | None
self, payload: FileData | None
) -> tuple[int, np.ndarray] | str | None:
"""
Parameters:
x: dictionary with keys "path", "crop_min", "crop_max".
Returns:
audio in requested format
"""
if x is None:
return x
if payload is None:
return payload
payload: FileData = FileData(**x)
assert payload.path
# Need a unique name for the file to avoid re-using the same audio file if
# a user submits the same audio file twice
temp_file_path = Path(payload.path)
@ -211,50 +203,50 @@ class Audio(
)
def postprocess(
self, y: tuple[int, np.ndarray] | str | Path | bytes | None
) -> FileData | None | bytes:
self, value: tuple[int, np.ndarray] | str | Path | bytes | None
) -> FileData | bytes | None:
"""
Parameters:
y: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None.
value: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None.
Returns:
base64 url data
"""
if y is None:
if value is None:
return None
if isinstance(y, bytes):
if isinstance(value, bytes):
if self.streaming:
return y
return value
file_path = processing_utils.save_bytes_to_cache(
y, "audio", cache_dir=self.GRADIO_CACHE
value, "audio", cache_dir=self.GRADIO_CACHE
)
elif isinstance(y, tuple):
sample_rate, data = y
elif isinstance(value, tuple):
sample_rate, data = value
file_path = processing_utils.save_audio_to_cache(
data, sample_rate, format=self.format, cache_dir=self.GRADIO_CACHE
)
else:
if not isinstance(y, (str, Path)):
raise ValueError(f"Cannot process {y} as Audio")
file_path = str(y)
if not isinstance(value, (str, Path)):
raise ValueError(f"Cannot process {value} as Audio")
file_path = str(value)
return FileData(path=file_path)
def stream_output(
self, y, output_id: str, first_chunk: bool
self, value, output_id: str, first_chunk: bool
) -> tuple[bytes | None, Any]:
output_file = {
"path": output_id,
"is_stream": True,
}
if y is None:
if value is None:
return None, output_file
if isinstance(y, bytes):
return y, output_file
if client_utils.is_http_url_like(y["path"]):
response = requests.get(y["path"])
if isinstance(value, bytes):
return value, output_file
if client_utils.is_http_url_like(value["path"]):
response = requests.get(value["path"])
binary_data = response.content
else:
output_file["orig_name"] = y["orig_name"]
file_path = y["path"]
output_file["orig_name"] = value["orig_name"]
file_path = value["path"]
is_wav = file_path.endswith(".wav")
with open(file_path, "rb") as f:
binary_data = f.read()

View File

@ -258,15 +258,15 @@ class BarPlot(Plot):
return chart
def postprocess(
self, y: pd.DataFrame | dict | None
self, value: pd.DataFrame | dict | None
) -> AltairPlotData | dict | None:
# if None or update
if y is None or isinstance(y, dict):
return y
if value is None or isinstance(value, dict):
return value
if self.x is None or self.y is None:
raise ValueError("No value provided for required parameters `x` and `y`.")
chart = self.create_plot(
value=y,
value=value,
x=self.x,
y=self.y,
color=self.color,
@ -288,12 +288,10 @@ class BarPlot(Plot):
sort=self.sort, # type: ignore
)
return AltairPlotData(
**{"type": "altair", "plot": chart.to_json(), "chart": "bar"}
)
return AltairPlotData(type="altair", plot=chart.to_json(), chart="bar")
def example_inputs(self) -> dict[str, Any]:
return {}
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: AltairPlotData) -> AltairPlotData:
return payload

View File

@ -48,21 +48,21 @@ class ComponentBase(ABC, metaclass=ComponentMeta):
EVENTS: list[EventListener | str] = []
@abstractmethod
def preprocess(self, x: Any) -> Any:
def preprocess(self, payload: Any) -> Any:
"""
Any preprocessing needed to be performed on function input.
"""
return x
return payload
@abstractmethod
def postprocess(self, y):
def postprocess(self, value):
"""
Any postprocessing needed to be performed on function output.
"""
return y
return value
@abstractmethod
def as_example(self, y):
def as_example(self, value):
"""
Return the input data in a way that can be displayed by the examples dataset component in the front-end.
@ -88,7 +88,7 @@ class ComponentBase(ABC, metaclass=ComponentMeta):
pass
@abstractmethod
def flag(self, x: Any | GradioDataModel, flag_dir: str | Path = "") -> str:
def flag(self, payload: Any | GradioDataModel, flag_dir: str | Path = "") -> str:
"""
Write the component's value to a format that can be stored in a csv or jsonl format for flagging.
"""
@ -97,13 +97,13 @@ class ComponentBase(ABC, metaclass=ComponentMeta):
@abstractmethod
def read_from_flag(
self,
x: Any,
payload: Any,
flag_dir: str | Path | None = None,
) -> GradioDataModel | Any:
"""
Convert the data from the csv or jsonl file into the component state.
"""
return x
return payload
@property
@abstractmethod
@ -267,26 +267,26 @@ class Component(ComponentBase, Block):
f"The api_info method has not been implemented for {self.get_block_name()}"
)
def flag(self, x: Any, flag_dir: str | Path = "") -> str:
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
"""
Write the component's value to a format that can be stored in a csv or jsonl format for flagging.
"""
if self.data_model:
x = self.data_model.from_json(x)
return x.copy_to_dir(flag_dir).model_dump_json()
return x
payload = self.data_model.from_json(payload)
return payload.copy_to_dir(flag_dir).model_dump_json()
return payload
def read_from_flag(
self,
x: Any,
payload: Any,
flag_dir: str | Path | None = None,
):
"""
Convert the data from the csv or jsonl file into the component state.
"""
if self.data_model:
return self.data_model.from_json(json.loads(x))
return x
return self.data_model.from_json(json.loads(payload))
return payload
class FormComponent(Component):
@ -295,11 +295,11 @@ class FormComponent(Component):
return None
return Form
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: Any) -> Any:
return payload
def postprocess(self, y):
return y
def postprocess(self, value):
return value
class StreamingOutput(metaclass=abc.ABCMeta):
@ -308,7 +308,9 @@ class StreamingOutput(metaclass=abc.ABCMeta):
self.streaming: bool
@abc.abstractmethod
def stream_output(self, y, output_id: str, first_chunk: bool) -> tuple[bytes, Any]:
def stream_output(
self, value, output_id: str, first_chunk: bool
) -> tuple[bytes, Any]:
pass

View File

@ -77,11 +77,11 @@ class Button(Component):
def skip_api(self):
return True
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: str) -> str:
return payload
def postprocess(self, y):
return y
def postprocess(self, value: str) -> str:
return value
def example_inputs(self) -> Any:
return None

View File

@ -14,8 +14,6 @@ from gradio.components.base import Component
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.events import Events
# from pydantic import Field, TypeAdapter
set_documentation_group("component")
@ -129,26 +127,28 @@ class Chatbot(Component):
)
def _preprocess_chat_messages(
self, chat_message: str | dict | None
) -> str | tuple[str] | tuple[str, str] | None:
self, chat_message: str | FileMessage | None
) -> str | tuple[str | None] | tuple[str | None, str] | None:
if chat_message is None:
return None
elif isinstance(chat_message, dict):
if chat_message.get("alt_text"):
return (chat_message["file"]["path"], chat_message["alt_text"])
elif isinstance(chat_message, FileMessage):
if chat_message.alt_text is not None:
return (chat_message.file.path, chat_message.alt_text)
else:
return (chat_message["file"]["path"],)
else: # string
return (chat_message.file.path,)
elif isinstance(chat_message, str):
return chat_message
else:
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
def preprocess(
self,
y: list[list[str | dict | None] | tuple[str | dict | None, str | dict | None]],
payload: ChatbotData,
) -> list[list[str | tuple[str] | tuple[str, str] | None]]:
if y is None:
return y
if payload is None:
return payload
processed_messages = []
for message_pair in y:
for message_pair in payload.root:
if not isinstance(message_pair, (tuple, list)):
raise TypeError(
f"Expected a list of lists or list of tuples. Received: {message_pair}"
@ -186,18 +186,12 @@ class Chatbot(Component):
def postprocess(
self,
y: list[list[str | tuple[str] | tuple[str, str] | None] | tuple],
value: list[list[str | tuple[str] | tuple[str, str] | None] | tuple],
) -> ChatbotData:
"""
Parameters:
y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string or pathlib.Path filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
Returns:
List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed.
"""
if y is None:
if value is None:
return ChatbotData(root=[])
processed_messages = []
for message_pair in y:
for message_pair in value:
if not isinstance(message_pair, (tuple, list)):
raise TypeError(
f"Expected a list of lists or list of tuples. Received: {message_pair}"

View File

@ -79,3 +79,9 @@ class Checkbox(FormComponent):
def example_inputs(self) -> bool:
return True
def preprocess(self, payload: bool | None) -> bool | None:
return payload
def postprocess(self, value: bool | None) -> bool | None:
return value

View File

@ -101,21 +101,15 @@ class CheckboxGroup(FormComponent):
}
def preprocess(
self, x: list[str | int | float]
self, payload: list[str | int | float]
) -> list[str | int | float] | list[int | None]:
"""
Parameters:
x: list of selected choices
Returns:
list of selected choice values as strings or indices within choice list
"""
if self.type == "value":
return x
return payload
elif self.type == "index":
choice_values = [value for _, value in self.choices]
return [
choice_values.index(choice) if choice in choice_values else None
for choice in x
for choice in payload
]
else:
raise ValueError(
@ -123,19 +117,13 @@ class CheckboxGroup(FormComponent):
)
def postprocess(
self, y: list[str | int | float] | str | int | float | None
self, value: list[str | int | float] | str | int | float | None
) -> list[str | int | float]:
"""
Parameters:
y: List of selected choice values. If a single choice is selected, it can be passed in as a string
Returns:
List of selected choices
"""
if y is None:
if value is None:
return []
if not isinstance(y, list):
y = [y]
return y
if not isinstance(value, list):
value = [value]
return value
def as_example(self, input_data):
if input_data is None:

View File

@ -75,16 +75,17 @@ class ClearButton(Button):
none = component.postprocess(None)
if isinstance(none, (GradioModel, GradioRootModel)):
none = none.model_dump()
print(none)
none_values.append(none)
clear_values = json.dumps(none_values)
self.click(None, [], components, _js=f"() => {clear_values}")
return self
def postprocess(self, y):
return y
def postprocess(self, value: str | None) -> str | None:
return value
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: str | None) -> str | None:
return payload
def example_inputs(self) -> Any:
return None

View File

@ -105,20 +105,20 @@ class Code(Component):
value=value,
)
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: Any) -> Any:
return payload
def postprocess(self, y: tuple | str | None) -> None | str:
if y is None:
def postprocess(self, value: tuple | str | None) -> None | str:
if value is None:
return None
elif isinstance(y, tuple):
with open(y[0]) as file_data:
elif isinstance(value, tuple):
with open(value[0]) as file_data:
return file_data.read()
else:
return y.strip()
return value.strip()
def flag(self, x: Any, flag_dir: str | Path = "") -> str:
return super().flag(x, flag_dir)
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
return super().flag(payload, flag_dir)
def api_info(self) -> dict[str, Any]:
return {"type": "string"}

View File

@ -2,7 +2,6 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Callable
from gradio_client.documentation import document, set_documentation_group
@ -77,37 +76,17 @@ class ColorPicker(Component):
def example_inputs(self) -> str:
return "#000000"
def flag(self, x: Any, flag_dir: str | Path = "") -> str:
return x
def read_from_flag(self, x: Any, flag_dir: str | Path | None = None):
return x
def api_info(self) -> dict[str, Any]:
return {"type": "string"}
def preprocess(self, x: str | None) -> str | None:
"""
Any preprocessing needed to be performed on function input.
Parameters:
x: text
Returns:
text
"""
if x is None:
def preprocess(self, payload: str | None) -> str | None:
if payload is None:
return None
else:
return str(x)
return str(payload)
def postprocess(self, y: str | None) -> str | None:
"""
Any postprocessing needed to be performed on function output.
Parameters:
y: text
Returns:
text
"""
if y is None:
def postprocess(self, value: str | None) -> str | None:
if value is None:
return None
else:
return str(y)
return str(value)

View File

@ -162,23 +162,16 @@ class Dataframe(Component):
value=value,
)
def preprocess(self, x: dict) -> pd.DataFrame | np.ndarray | list:
"""
Parameters:
x: Dictionary equivalent of DataframeData containing `headers`, `data`, and optionally `metadata` keys
Returns:
The Dataframe data in requested format
"""
value = DataframeData(**x)
def preprocess(self, payload: DataframeData) -> pd.DataFrame | np.ndarray | list:
if self.type == "pandas":
if value.headers is not None:
return pd.DataFrame(value.data, columns=value.headers)
if payload.headers is not None:
return pd.DataFrame(payload.data, columns=payload.headers)
else:
return pd.DataFrame(value.data)
return pd.DataFrame(payload.data)
if self.type == "numpy":
return np.array(value.data)
return np.array(payload.data)
elif self.type == "array":
return value.data
return payload.data
else:
raise ValueError(
"Unknown type: "
@ -188,26 +181,27 @@ class Dataframe(Component):
def postprocess(
self,
y: pd.DataFrame | Styler | np.ndarray | list | list[list] | dict | str | None,
value: pd.DataFrame
| Styler
| np.ndarray
| list
| list[list]
| dict
| str
| None,
) -> DataframeData | dict:
"""
Parameters:
y: dataframe in given format
Returns:
JSON object with key 'headers' for list of header names, 'data' for 2D array of string or numeric data
"""
if y is None:
if value is None:
return self.postprocess(self.empty_input)
if isinstance(y, dict):
return y
if isinstance(y, (str, pd.DataFrame)):
if isinstance(y, str):
y = pd.read_csv(y) # type: ignore
if isinstance(value, dict):
return value
if isinstance(value, (str, pd.DataFrame)):
if isinstance(value, str):
value = pd.read_csv(value) # type: ignore
return DataframeData(
headers=list(y.columns), # type: ignore
data=y.to_dict(orient="split")["data"], # type: ignore
headers=list(value.columns), # type: ignore
data=value.to_dict(orient="split")["data"], # type: ignore
)
elif isinstance(y, Styler):
elif isinstance(value, Styler):
if semantic_version.Version(pd.__version__) < semantic_version.Version(
"1.5.0"
):
@ -218,39 +212,38 @@ class Dataframe(Component):
warnings.warn(
"Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead."
)
df: pd.DataFrame = y.data # type: ignore
value = DataframeData(
df: pd.DataFrame = value.data # type: ignore
return DataframeData(
headers=list(df.columns),
data=df.to_dict(orient="split")["data"], # type: ignore
metadata=self.__extract_metadata(y),
metadata=self.__extract_metadata(value),
)
elif isinstance(y, (str, pd.DataFrame)):
df = pd.read_csv(y) if isinstance(y, str) else y # type: ignore
value = DataframeData(
elif isinstance(value, (str, pd.DataFrame)):
df = pd.read_csv(value) if isinstance(value, str) else value # type: ignore
return DataframeData(
headers=list(df.columns),
data=df.to_dict(orient="split")["data"], # type: ignore
)
elif isinstance(y, (np.ndarray, list)):
if len(y) == 0:
elif isinstance(value, (np.ndarray, list)):
if len(value) == 0:
return self.postprocess([[]])
if isinstance(y, np.ndarray):
y = y.tolist()
if not isinstance(y, list):
if isinstance(value, np.ndarray):
value = value.tolist()
if not isinstance(value, list):
raise ValueError("output cannot be converted to list")
_headers = self.headers
if len(self.headers) < len(y[0]):
if len(self.headers) < len(value[0]):
_headers: list[str] = [
*self.headers,
*[str(i) for i in range(len(self.headers) + 1, len(y[0]) + 1)],
*[str(i) for i in range(len(self.headers) + 1, len(value[0]) + 1)],
]
elif len(self.headers) > len(y[0]):
_headers = self.headers[: len(y[0])]
elif len(self.headers) > len(value[0]):
_headers = self.headers[: len(value[0])]
value = DataframeData(headers=_headers, data=y)
return DataframeData(headers=_headers, data=value)
else:
raise ValueError("Cannot process value as a Dataframe")
return value
@staticmethod
def __get_cell_style(cell_id: str, cell_styles: list[dict]) -> str:

View File

@ -121,16 +121,13 @@ class Dataset(Component):
return config
def preprocess(self, x: Any) -> Any:
"""
Any preprocessing needed to be performed on function input.
"""
def preprocess(self, payload: int) -> int | list[list] | None:
if self.type == "index":
return x
return payload
elif self.type == "values":
return self.samples[x]
return self.samples[payload]
def postprocess(self, samples: list[list[Any]]) -> dict:
def postprocess(self, samples: list[list]) -> dict:
return {
"samples": samples,
"__type__": "update",

View File

@ -134,48 +134,48 @@ class Dropdown(FormComponent):
return self.choices[0][1] if self.choices else None
def preprocess(
self, x: str | int | float | list[str | int | float] | None
self, payload: str | int | float | list[str | int | float] | None
) -> str | int | float | list[str | int | float] | list[int | None] | None:
"""
Parameters:
x: selected choice(s)
Returns:
selected choice(s) as string or index within choice list or list of string or indices
"""
if self.type == "value":
return x
return payload
elif self.type == "index":
choice_values = [value for _, value in self.choices]
if x is None:
if payload is None:
return None
elif self.multiselect:
assert isinstance(x, list)
assert isinstance(payload, list)
return [
choice_values.index(choice) if choice in choice_values else None
for choice in x
for choice in payload
]
else:
return choice_values.index(x) if x in choice_values else None
return (
choice_values.index(payload) if payload in choice_values else None
)
else:
raise ValueError(
f"Unknown type: {self.type}. Please choose from: 'value', 'index'."
)
def _warn_if_invalid_choice(self, y):
if self.allow_custom_value or y in [value for _, value in self.choices]:
def _warn_if_invalid_choice(self, value):
if self.allow_custom_value or value in [value for _, value in self.choices]:
return
warnings.warn(
f"The value passed into gr.Dropdown() is not in the list of choices. Please update the list of choices to include: {y} or set allow_custom_value=True."
f"The value passed into gr.Dropdown() is not in the list of choices. Please update the list of choices to include: {value} or set allow_custom_value=True."
)
def postprocess(self, y):
if y is None:
def postprocess(
self, value: str | int | float | list[str | int | float] | None
) -> str | int | float | list[str | int | float] | None:
if value is None:
return None
if self.multiselect:
[self._warn_if_invalid_choice(_y) for _y in y]
if not isinstance(value, list):
value = [value]
[self._warn_if_invalid_choice(_y) for _y in value]
else:
self._warn_if_invalid_choice(y)
return y
self._warn_if_invalid_choice(value)
return value
def as_example(self, input_data):
if self.multiselect:

View File

@ -2,11 +2,11 @@ from gradio.components.base import Component
class Fallback(Component):
def preprocess(self, x):
return x
def preprocess(self, payload):
return payload
def postprocess(self, x):
return x
def postprocess(self, value):
return value
def example_inputs(self):
return {"foo": "bar"}

View File

@ -20,6 +20,12 @@ set_documentation_group("component")
class ListFiles(GradioRootModel):
root: List[FileData]
def __getitem__(self, index):
return self.root[index]
def __iter__(self):
return iter(self.root)
@document()
class File(Component):
@ -111,13 +117,12 @@ class File(Component):
self.type = type
self.height = height
def _process_single_file(self, f: dict[str, Any]) -> bytes | NamedString:
file_name = f["path"]
def _process_single_file(self, f: FileData) -> NamedString | bytes:
file_name = f.path
if self.type == "filepath":
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
file.name = file_name
return NamedString(file.name)
return NamedString(file_name)
elif self.type == "binary":
with open(file_name, "rb") as file_data:
return file_data.read()
@ -129,38 +134,25 @@ class File(Component):
)
def preprocess(
self, x: list[dict[str, Any]] | dict[str, Any] | None
self, payload: ListFiles | FileData | None
) -> bytes | NamedString | list[bytes | NamedString] | None:
"""
Parameters:
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
Returns:
File objects in requested format
"""
if x is None:
if payload is None:
return None
if self.file_count == "single":
if isinstance(x, list):
return self._process_single_file(x[0])
if isinstance(payload, ListFiles):
return self._process_single_file(payload[0])
else:
return self._process_single_file(x)
return self._process_single_file(payload)
else:
if isinstance(x, list):
return [self._process_single_file(f) for f in x]
if isinstance(payload, ListFiles):
return [self._process_single_file(f) for f in payload]
else:
return [self._process_single_file(x)]
return [self._process_single_file(payload)]
def postprocess(self, y: str | list[str] | None) -> ListFiles | FileData | None:
"""
Parameters:
y: file path
Returns:
JSON object with key 'name' for filename, 'data' for base64 url, and 'size' for filesize in bytes
"""
if y is None:
def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None:
if value is None:
return None
if isinstance(y, list):
if isinstance(value, list):
return ListFiles(
root=[
FileData(
@ -168,14 +160,14 @@ class File(Component):
orig_name=Path(file).name,
size=Path(file).stat().st_size,
)
for file in y
for file in value
]
)
else:
return FileData(
path=y,
orig_name=Path(y).name,
size=Path(y).stat().st_size,
path=value,
orig_name=Path(value).name,
size=Path(value).stat().st_size,
)
def as_example(self, input_data: str | list | None) -> str:

View File

@ -104,49 +104,39 @@ class FileExplorer(Component):
def example_inputs(self) -> Any:
return ["Users", "gradio", "app.py"]
def preprocess(self, x: list[list[str]] | None) -> list[str] | str | None:
"""
Parameters:
x: File path segments as a list of list of strings for each file relative to the root.
Returns:
File path selected, as an absolute path.
"""
if x is None:
def preprocess(self, payload: list[list[str]] | None) -> list[str] | str | None:
if payload is None:
return None
if self.file_count == "single":
if len(x) > 1:
raise ValueError(f"Expected only one file, but {len(x)} were selected.")
return self._safe_join(x[0])
if len(payload) > 1:
raise ValueError(
f"Expected only one file, but {len(payload)} were selected."
)
return self._safe_join(payload[0])
return [self._safe_join(file) for file in (x)]
return [self._safe_join(file) for file in (payload)]
def _strip_root(self, path):
if path.startswith(self.root):
return path[len(self.root) + 1 :]
return path
def postprocess(self, y: str | list[str] | None) -> FileExplorerData | None:
"""
Parameters:
y: file path
Returns:
list representing filepath, where each string is a directory level relative to the root.
"""
if y is None:
def postprocess(self, value: str | list[str] | None) -> FileExplorerData | None:
if value is None:
return None
files = [y] if isinstance(y, str) else y
files = [value] if isinstance(value, str) else value
return FileExplorerData(
root=[self._strip_root(file).split(os.path.sep) for file in files]
)
@server
def ls(self, y=None) -> list[dict[str, str]] | None:
def ls(self, value=None) -> list[dict[str, str]] | None:
"""
Parameters:
y: file path as a list of strings for each directory level relative to the root.
value: file path as a list of strings for each directory level relative to the root.
Returns:
tuple of list of files in directory, then list of folders in directory
"""

View File

@ -124,20 +124,20 @@ class Gallery(Component):
def postprocess(
self,
y: list[np.ndarray | _Image.Image | str]
value: list[np.ndarray | _Image.Image | str]
| list[tuple[np.ndarray | _Image.Image | str, str]]
| None,
) -> GalleryData:
"""
Parameters:
y: list of images, or list of (image, caption) tuples
value: list of images, or list of (image, caption) tuples
Returns:
list of string file paths to images in temp directory
"""
if y is None:
if value is None:
return GalleryData(root=[])
output = []
for img in y:
for img in value:
caption = None
if isinstance(img, (tuple, list)):
img, caption = img
@ -160,8 +160,10 @@ class Gallery(Component):
output.append(entry)
return GalleryData(root=output)
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: GalleryData | None) -> GalleryData | None:
if payload is None or not payload.root:
return None
return payload
def example_inputs(self) -> Any:
return [

View File

@ -99,27 +99,27 @@ class HighlightedText(Component):
return {"value": [{"token": "Hello", "class_or_confidence": "1"}]}
def postprocess(
self, y: list[tuple[str, str | float | None]] | dict | None
self, value: list[tuple[str, str | float | None]] | dict | None
) -> HighlightedTextData | None:
"""
Parameters:
y: List of (word, category) tuples, or a dictionary of two keys: "text", and "entities", which itself is a list of dictionaries, each of which have the keys: "entity" (or "entity_group"), "start", and "end"
value: List of (word, category) tuples, or a dictionary of two keys: "text", and "entities", which itself is a list of dictionaries, each of which have the keys: "entity" (or "entity_group"), "start", and "end"
Returns:
List of (word, category) tuples
"""
if y is None:
if value is None:
return None
if isinstance(y, dict):
if isinstance(value, dict):
try:
text = y["text"]
entities = y["entities"]
text = value["text"]
entities = value["entities"]
except KeyError as ke:
raise ValueError(
"Expected a dictionary with keys 'text' and 'entities' "
"for the value of the HighlightedText component."
) from ke
if len(entities) == 0:
y = [(text, None)]
value = [(text, None)]
else:
list_format = []
index = 0
@ -132,11 +132,11 @@ class HighlightedText(Component):
)
index = entity["end"]
list_format.append((text[index:], None))
y = list_format
value = list_format
if self.combine_adjacent:
output = []
running_text, running_category = None, None
for text, category in y:
for text, category in value:
if running_text is None:
running_text = text
running_category = category
@ -160,8 +160,13 @@ class HighlightedText(Component):
)
else:
return HighlightedTextData(
root=[HighlightedToken(token=o[0], class_or_confidence=o[1]) for o in y]
root=[
HighlightedToken(token=o[0], class_or_confidence=o[1])
for o in value
]
)
def preprocess(self, x: Any) -> Any:
return super().preprocess(x)
def preprocess(self, payload: HighlightedTextData | None) -> dict | None:
if payload is None:
return None
return payload.model_dump()

View File

@ -62,11 +62,11 @@ class HTML(Component):
def example_inputs(self) -> Any:
return "<p>Hello</p>"
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: str | None) -> str | None:
return payload
def postprocess(self, y):
return y
def postprocess(self, value: str | None) -> str | None:
return value
def api_info(self) -> dict[str, Any]:
return {"type": "string"}

View File

@ -38,6 +38,7 @@ class Image(StreamingInput, Component):
Events.select,
Events.upload,
]
data_model = FileData
def __init__(
@ -141,38 +142,25 @@ class Image(StreamingInput, Component):
value=value,
)
def preprocess(self, x: dict | None) -> np.ndarray | _Image.Image | str | None:
"""
Parameters:
x: FileData containing an image path pointing to the user's image
Returns:
image in requested format, or (if tool == "sketch") a dict of image and mask in requested format
"""
if x is None:
return x
im = _Image.open(x["path"])
def preprocess(
self, payload: FileData | None
) -> np.ndarray | _Image.Image | str | None:
if payload is None:
return payload
im = _Image.open(payload.path)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
return image_utils.format_image(
im, cast(Literal["numpy", "pil", "filepath"], self.type), self.GRADIO_CACHE
)
def postprocess(
self, y: np.ndarray | _Image.Image | str | Path | None
self, value: np.ndarray | _Image.Image | str | Path | None
) -> FileData | None:
"""
Parameters:
y: image as a numpy array, PIL Image, string/Path filepath, or string URL
Returns:
base64 url data
"""
if y is None:
if value is None:
return None
return FileData(path=image_utils.save_image(y, self.GRADIO_CACHE))
return FileData(path=image_utils.save_image(value, self.GRADIO_CACHE))
def check_streamable(self):
if self.streaming and self.sources != ("webcam"):

View File

@ -69,31 +69,25 @@ class JSON(Component):
value=value,
)
def postprocess(self, y: dict | list | str | None) -> dict | list | None:
"""
Parameters:
y: either a string filepath to a JSON file, or a Python list or dict that can be converted to JSON
Returns:
JSON output in Python list or dict format
"""
if y is None:
def postprocess(self, value: dict | list | str | None) -> dict | list | None:
if value is None:
return None
if isinstance(y, str):
return json.loads(y)
if isinstance(value, str):
return json.loads(value)
else:
return y
return value
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: dict | list | str | None) -> dict | list | str | None:
return payload
def example_inputs(self) -> Any:
return {"foo": "bar"}
def flag(self, x: Any, flag_dir: str | Path = "") -> str:
return json.dumps(x)
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
return json.dumps(payload)
def read_from_flag(self, x: Any, flag_dir: str | Path | None = None):
return json.loads(x)
def read_from_flag(self, payload: Any, flag_dir: str | Path | None = None):
return json.loads(payload)
def api_info(self) -> dict[str, Any]:
return {"type": {}, "description": "any valid json"}

View File

@ -22,7 +22,7 @@ class LabelConfidence(GradioModel):
class LabelData(GradioModel):
label: Union[str, int, float]
label: Optional[Union[str, int, float]] = None
confidences: Optional[List[LabelConfidence]] = None
@ -91,44 +91,46 @@ class Label(Component):
)
def postprocess(
self, y: dict[str, float] | str | float | None
self, value: dict[str, float] | str | float | None
) -> LabelData | dict | None:
"""
Parameters:
y: a dictionary mapping labels to confidence value, or just a string/numerical label by itself
Returns:
Object with key 'label' representing primary label, and key 'confidences' representing a list of label-confidence pairs
"""
if y is None or y == {}:
if value is None or value == {}:
return {}
if isinstance(y, str) and y.endswith(".json") and Path(y).exists():
return LabelData(**json.loads(Path(y).read_text()))
if isinstance(y, (str, float, int)):
return LabelData(label=str(y))
if isinstance(y, dict):
if "confidences" in y and isinstance(y["confidences"], dict):
y = y["confidences"]
y = {c["label"]: c["confidence"] for c in y}
sorted_pred = sorted(y.items(), key=operator.itemgetter(1), reverse=True)
if isinstance(value, str) and value.endswith(".json") and Path(value).exists():
return LabelData(**json.loads(Path(value).read_text()))
if isinstance(value, (str, float, int)):
return LabelData(label=str(value))
if isinstance(value, dict):
if "confidences" in value and isinstance(value["confidences"], dict):
value = value["confidences"]
value = {c["label"]: c["confidence"] for c in value}
sorted_pred = sorted(
value.items(), key=operator.itemgetter(1), reverse=True
)
if self.num_top_classes is not None:
sorted_pred = sorted_pred[: self.num_top_classes]
return LabelData(
**{
"label": sorted_pred[0][0],
"confidences": [
{"label": pred[0], "confidence": pred[1]}
for pred in sorted_pred
],
}
label=sorted_pred[0][0],
confidences=[
LabelConfidence(label=pred[0], confidence=pred[1])
for pred in sorted_pred
],
)
raise ValueError(
"The `Label` output interface expects one of: a string label, or an int label, a "
"float label, or a dictionary whose keys are labels and values are confidences. "
f"Instead, got a {type(y)}"
f"Instead, got a {type(value)}"
)
def preprocess(self, x: Any) -> Any:
return x
def preprocess(
self, payload: LabelData | None
) -> dict[str, float] | str | float | None:
if payload is None:
return None
if payload.confidences is None:
return payload.label
return {
d["label"]: d["confidence"] for d in payload.model_dump()["confidences"]
}
def example_inputs(self) -> Any:
return {

View File

@ -287,15 +287,15 @@ class LinePlot(Plot):
return chart
def postprocess(
self, y: pd.DataFrame | dict | None
self, value: pd.DataFrame | dict | None
) -> AltairPlotData | dict | None:
# if None or update
if y is None or isinstance(y, dict):
return y
if value is None or isinstance(value, dict):
return value
if self.x is None or self.y is None:
raise ValueError("No value provided for required parameters `x` and `y`.")
chart = self.create_plot(
value=y,
value=value,
x=self.x,
y=self.y,
color=self.color,
@ -318,12 +318,10 @@ class LinePlot(Plot):
width=self.width,
)
return AltairPlotData(
**{"type": "altair", "plot": chart.to_json(), "chart": "line"}
)
return AltairPlotData(type="altair", plot=chart.to_json(), chart="line")
def example_inputs(self) -> Any:
return None
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, value: AltairPlotData | None) -> AltairPlotData | None:
return value

View File

@ -75,24 +75,18 @@ class Markdown(Component):
value=value,
)
def postprocess(self, y: str | None) -> str | None:
"""
Parameters:
y: markdown representation
Returns:
HTML rendering of markdown
"""
if y is None:
def postprocess(self, value: str | None) -> str | None:
if value is None:
return None
unindented_y = inspect.cleandoc(y)
unindented_y = inspect.cleandoc(value)
return unindented_y
def as_example(self, input_data: str | None) -> str:
postprocessed = self.postprocess(input_data)
return postprocessed if postprocessed else ""
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: str | None) -> str | None:
return payload
def example_inputs(self) -> Any:
return "# Hello!"

View File

@ -93,27 +93,15 @@ class Model3D(Component):
value=value,
)
def preprocess(self, x: dict[str, str] | None) -> str | None:
"""
Parameters:
x: JSON object with filename as 'name' property and base64 data as 'data' property
Returns:
string file path to temporary file with the 3D image model
"""
if x is None:
return x
return x["path"]
def preprocess(self, payload: FileData | None) -> str | None:
if payload is None:
return payload
return payload.path
def postprocess(self, y: str | Path | None) -> FileData | None:
"""
Parameters:
y: path to the model
Returns:
file name mapped to base64 url data
"""
if y is None:
return y
return FileData(path=str(y))
def postprocess(self, value: str | Path | None) -> FileData | None:
if value is None:
return value
return FileData(path=str(value))
def as_example(self, input_data: str | None) -> str:
return Path(input_data).name if input_data else ""

View File

@ -108,33 +108,21 @@ class Number(FormComponent):
else:
return round(num, precision)
def preprocess(self, x: float | None) -> float | None:
"""
Parameters:
x: numeric input
Returns:
number representing function input
"""
if x is None:
def preprocess(self, payload: float | None) -> float | None:
if payload is None:
return None
elif self.minimum is not None and x < self.minimum:
raise Error(f"Value {x} is less than minimum value {self.minimum}.")
elif self.maximum is not None and x > self.maximum:
raise Error(f"Value {x} is greater than maximum value {self.maximum}.")
return self._round_to_precision(x, self.precision)
elif self.minimum is not None and payload < self.minimum:
raise Error(f"Value {payload} is less than minimum value {self.minimum}.")
elif self.maximum is not None and payload > self.maximum:
raise Error(
f"Value {payload} is greater than maximum value {self.maximum}."
)
return self._round_to_precision(payload, self.precision)
def postprocess(self, y: float | None) -> float | None:
"""
Any postprocessing needed to be performed on function output.
Parameters:
y: numeric output
Returns:
number representing function output
"""
if y is None:
def postprocess(self, value: float | None) -> float | None:
if value is None:
return None
return self._round_to_precision(y, self.precision)
return self._round_to_precision(value, self.precision)
def api_info(self) -> dict[str, str]:
return {"type": "number"}

View File

@ -97,36 +97,30 @@ class Plot(Component):
config["bokeh_version"] = bokeh_version
return config
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: PlotData | None) -> PlotData | None:
return payload
def example_inputs(self) -> Any:
return None
def postprocess(self, y) -> PlotData | None:
"""
Parameters:
y: plot data
Returns:
plot type mapped to plot base64 data
"""
def postprocess(self, value) -> PlotData | None:
import matplotlib.figure
if y is None:
if value is None:
return None
if isinstance(y, (ModuleType, matplotlib.figure.Figure)): # type: ignore
if isinstance(value, (ModuleType, matplotlib.figure.Figure)): # type: ignore
dtype = "matplotlib"
out_y = processing_utils.encode_plot_to_base64(y)
elif "bokeh" in y.__module__:
out_y = processing_utils.encode_plot_to_base64(value)
elif "bokeh" in value.__module__:
dtype = "bokeh"
from bokeh.embed import json_item # type: ignore
out_y = json.dumps(json_item(y))
out_y = json.dumps(json_item(value))
else:
is_altair = "altair" in y.__module__
is_altair = "altair" in value.__module__
dtype = "altair" if is_altair else "plotly"
out_y = y.to_json()
return PlotData(**{"type": dtype, "plot": out_y})
out_y = value.to_json()
return PlotData(type=dtype, plot=out_y)
class AltairPlot:

View File

@ -94,28 +94,30 @@ class Radio(FormComponent):
def example_inputs(self) -> Any:
return self.choices[0][1] if self.choices else None
def preprocess(self, x: str | int | float | None) -> str | int | float | None:
def preprocess(self, payload: str | int | float | None) -> str | int | float | None:
"""
Parameters:
x: selected choice
payload: selected choice
Returns:
value of the selected choice as string or index within choice list
"""
if self.type == "value":
return x
return payload
elif self.type == "index":
if x is None:
if payload is None:
return None
else:
choice_values = [value for _, value in self.choices]
return choice_values.index(x) if x in choice_values else None
return (
choice_values.index(payload) if payload in choice_values else None
)
else:
raise ValueError(
f"Unknown type: {self.type}. Please choose from: 'value', 'index'."
)
def postprocess(self, y):
return y
def postprocess(self, value: str | int | float | None) -> str | int | float | None:
return value
def api_info(self) -> dict[str, Any]:
return {

View File

@ -310,15 +310,15 @@ class ScatterPlot(Plot):
return chart
def postprocess(
self, y: pd.DataFrame | dict | None
self, value: pd.DataFrame | dict | None
) -> AltairPlotData | dict | None:
# if None or update
if y is None or isinstance(y, dict):
return y
if value is None or isinstance(value, dict):
return value
if self.x is None or self.y is None:
raise ValueError("No value provided for required parameters `x` and `y`.")
chart = self.create_plot(
value=y,
value=value,
x=self.x,
y=self.y,
color=self.color,
@ -343,12 +343,10 @@ class ScatterPlot(Plot):
y_lim=self.y_lim,
)
return AltairPlotData(
**{"type": "altair", "plot": chart.to_json(), "chart": "scatter"}
)
return AltairPlotData(type="altair", plot=chart.to_json(), chart="scatter")
def example_inputs(self) -> Any:
return None
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: AltairPlotData | None) -> AltairPlotData | None:
return payload

View File

@ -114,15 +114,8 @@ class Slider(FormComponent):
value = round(value, n_decimals)
return value
def postprocess(self, y: float | None) -> float | None:
"""
Any postprocessing needed to be performed on function output.
Parameters:
y: numeric output
Returns:
numeric output or minimum number if None
"""
return self.minimum if y is None else y
def postprocess(self, value: float | None) -> float:
return self.minimum if value is None else value
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: float) -> float:
return payload

View File

@ -46,11 +46,11 @@ class State(Component):
) from err
super().__init__(value=self.value)
def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: Any) -> Any:
return payload
def postprocess(self, y):
return y
def postprocess(self, value: Any) -> Any:
return value
def api_info(self) -> dict[str, Any]:
return {"type": {}, "description": "any valid json"}

View File

@ -117,25 +117,11 @@ class Textbox(FormComponent):
self.rtl = rtl
self.text_align = text_align
def preprocess(self, x: str | None) -> str | None:
"""
Preprocesses input (converts it to a string) before passing it to the function.
Parameters:
x: text
Returns:
text
"""
return None if x is None else str(x)
def preprocess(self, payload: str | None) -> str | None:
return None if payload is None else str(payload)
def postprocess(self, y: str | None) -> str | None:
"""
Postproccess the function output y by converting it to a str before passing it to the frontend.
Parameters:
y: function output to postprocess.
Returns:
text
"""
return None if y is None else str(y)
def postprocess(self, value: str | None) -> str | None:
return None if value is None else str(value)
def api_info(self) -> dict[str, Any]:
return {"type": "string"}

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import tempfile
import warnings
from pathlib import Path
from typing import Any, Callable, List, Literal
from gradio_client.documentation import document, set_documentation_group
@ -19,6 +20,12 @@ set_documentation_group("component")
class ListFiles(GradioRootModel):
root: List[FileData]
def __getitem__(self, index):
return self.root[index]
def __iter__(self):
return iter(self.root)
@document()
class UploadButton(Component):
@ -87,6 +94,10 @@ class UploadButton(Component):
raise ValueError(
f"Parameter file_types must be a list. Received {file_types.__class__.__name__}"
)
if self.file_count == "multiple":
self.data_model = ListFiles
else:
self.data_model = FileData
self.size = size
self.file_types = file_types
self.label = label
@ -118,12 +129,12 @@ class UploadButton(Component):
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
]
def _process_single_file(self, f: dict[str, Any]) -> bytes | NamedString:
file_name = f["path"]
def _process_single_file(self, f: FileData) -> bytes | NamedString:
file_name = f.path
if self.type == "filepath":
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
file.name = file_name
return NamedString(file.name)
return NamedString(file_name)
elif self.type == "binary":
with open(file_name, "rb") as file_data:
return file_data.read()
@ -135,30 +146,42 @@ class UploadButton(Component):
)
def preprocess(
self, x: list[dict[str, Any]] | dict[str, Any] | None
self, payload: ListFiles | FileData | None
) -> bytes | NamedString | list[bytes | NamedString] | None:
"""
Parameters:
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
Returns:
File objects in requested format
"""
if x is None:
if payload is None:
return None
if self.file_count == "single":
if isinstance(x, list):
return self._process_single_file(x[0])
if isinstance(payload, ListFiles):
return self._process_single_file(payload[0])
else:
return self._process_single_file(x)
return self._process_single_file(payload)
else:
if isinstance(x, list):
return [self._process_single_file(f) for f in x]
if isinstance(payload, ListFiles):
return [self._process_single_file(f) for f in payload]
else:
return [self._process_single_file(x)]
return [self._process_single_file(payload)]
def postprocess(self, y):
return super().postprocess(y)
def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None:
if value is None:
return None
if isinstance(value, list):
return ListFiles(
root=[
FileData(
path=file,
orig_name=Path(file).name,
size=Path(file).stat().st_size,
)
for file in value
]
)
else:
return FileData(
path=value,
orig_name=Path(value).name,
size=Path(value).stat().st_size,
)
@property
def skip_api(self):

View File

@ -43,7 +43,7 @@ class Video(Component):
"""
data_model = VideoData
input_data_model = FileData
EVENTS = [
Events.change,
Events.clear,
@ -155,18 +155,11 @@ class Video(Component):
value=value,
)
def preprocess(self, x: dict | VideoData) -> str | None:
"""
Parameters:
x: A tuple of (video file data, subtitle file data) or just video file data.
Returns:
A string file path or URL to the preprocessed video. Subtitle file data is ignored.
"""
if x is None:
def preprocess(self, payload: VideoData | None) -> str | None:
if payload is None:
return None
data: VideoData = VideoData(**x) if isinstance(x, dict) else x
assert data.video.path
file_name = Path(data.video.path)
assert payload.video.path
file_name = Path(payload.video.path)
uploaded_format = file_name.suffix.replace(".", "")
needs_formatting = self.format is not None and uploaded_format != self.format
flip = self.sources == ["webcam"] and self.mirror_webcam
@ -221,24 +214,6 @@ class Video(Component):
def postprocess(
self, y: str | Path | tuple[str | Path, str | Path | None] | None
) -> VideoData | None:
"""
Processes a video to ensure that it is in the correct format before returning it to the front end.
Parameters:
y: video data in either of the following formats: a tuple of (video filepath, optional subtitle filepath), or just a filepath or URL to an video file, or None.
Returns:
a tuple with the two dictionary, reresent to video and (optional) subtitle, which following formats:
- The first dictionary represents the video file and contains the following keys:
- 'name': a file path to a temporary copy of the processed video.
- 'data': None
- 'is_file': True
- The second dictionary represents the subtitle file and contains the following keys:
- 'name': None
- 'data': Base64 encode the processed subtitle data.
- 'is_file': False
- If subtitle is None, returns (video, None).
- If both video and subtitle are None, returns None.
"""
if y is None or y == [None, None] or y == (None, None):
return None
if isinstance(y, (str, Path)):
@ -269,14 +244,6 @@ class Video(Component):
def _format_video(self, video: str | Path | None) -> FileData | None:
"""
Processes a video to ensure that it is in the correct format.
Parameters:
video: video data in either of the following formats: a string filepath or URL to an video file, or None.
Returns:
a dictionary with the following keys:
- 'name': a file path to a temporary copy of the processed video.
- 'data': None
- 'is_file': True
"""
if video is None:
return None
@ -328,13 +295,6 @@ class Video(Component):
def _format_subtitle(self, subtitle: str | Path | None) -> FileData | None:
"""
Convert subtitle format to VTT and process the video to ensure it meets the HTML5 requirements.
Parameters:
subtitle: subtitle path in either of the VTT and SRT format.
Returns:
a dictionary with the following keys:
- 'name': None
- 'data': base64-encoded subtitle data.
- 'is_file': False
"""
def srt_to_vtt(srt_file_path, vtt_file_path):

View File

@ -173,8 +173,13 @@ class Queue:
if concurrency_limit is None or existing_worker_count < concurrency_limit:
batch = block_fn.batch
if batch:
remaining_worker_count = concurrency_limit - existing_worker_count
batch_size = block_fn.max_batch_size
if concurrency_limit is None:
remaining_worker_count = batch_size - 1
else:
remaining_worker_count = (
concurrency_limit - existing_worker_count
)
rest_of_batch = [
event
for event in self.event_queue[index:]

View File

@ -52,9 +52,9 @@
all_file_data = (await upload(all_file_data, root))?.filter(
(x) => x !== null
) as FileData[];
dispatch("change", all_file_data);
dispatch("upload", all_file_data);
value = all_file_data;
value = file_count === "single" ? all_file_data?.[0] : all_file_data
dispatch("change", value);
dispatch("upload", value);
}
async function loadFilesFromUpload(e: Event): Promise<void> {

View File

@ -30,6 +30,8 @@ except ImportError:
import gradio as gr
from gradio import processing_utils, utils
from gradio.components.dataframe import DataframeData
from gradio.components.video import VideoData
from gradio.data_classes import FileData
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -480,9 +482,9 @@ class TestDropdown:
"""
dropdown_input = gr.Dropdown(["a", "b", ("c", "c full")], multiselect=True)
assert dropdown_input.preprocess("a") == "a"
assert dropdown_input.postprocess("a") == "a"
assert dropdown_input.postprocess("a") == ["a"]
assert dropdown_input.preprocess("c full") == "c full"
assert dropdown_input.postprocess("c full") == "c full"
assert dropdown_input.postprocess("c full") == ["c full"]
# When a Gradio app is loaded with gr.load, the tuples are converted to lists,
# so we need to test that case as well
@ -558,7 +560,7 @@ class TestImage:
type: pil, file, filepath, numpy
"""
img = dict(FileData(path="test/test_files/bus.png"))
img = FileData(path="test/test_files/bus.png")
image_input = gr.Image()
image_input = gr.Image(type="filepath")
@ -605,12 +607,12 @@ class TestImage:
# Output functionalities
image_output = gr.Image(type="pil")
processed_image = image_output.postprocess(
PIL.Image.open(img["path"])
PIL.Image.open(img.path)
).model_dump()
assert processed_image is not None
if processed_image is not None:
processed = PIL.Image.open(cast(dict, processed_image).get("path", ""))
source = PIL.Image.open(img["path"])
source = PIL.Image.open(img.path)
assert processed.size == source.size
def test_in_interface_as_output(self):
@ -696,7 +698,7 @@ class TestAudio:
Preprocess, postprocess serialize, get_config, deserialize
type: filepath, numpy, file
"""
x_wav = deepcopy(media_data.BASE64_AUDIO)
x_wav = FileData(path=media_data.BASE64_AUDIO["path"])
audio_input = gr.Audio()
output1 = audio_input.preprocess(x_wav)
assert output1[0] == 8000
@ -735,11 +737,6 @@ class TestAudio:
"_selectable": False,
}
assert audio_input.preprocess(None) is None
x_wav["is_example"] = True
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
output2 = audio_input.preprocess(x_wav)
assert output2 is not None
assert output1 != output2
audio_input = gr.Audio(type="filepath")
assert isinstance(audio_input.preprocess(x_wav), str)
@ -821,21 +818,21 @@ class TestAudio:
assert iface(100).endswith(".wav")
def test_audio_preprocess_can_be_read_by_scipy(self, gradio_temp_dir):
x_wav = {
"path": processing_utils.save_base64_to_cache(
x_wav = FileData(
path=processing_utils.save_base64_to_cache(
media_data.BASE64_MICROPHONE["data"], cache_dir=gradio_temp_dir
),
}
)
)
audio_input = gr.Audio(type="filepath")
output = audio_input.preprocess(x_wav)
wavfile.read(output)
def test_prepost_process_to_mp3(self, gradio_temp_dir):
x_wav = {
"path": processing_utils.save_base64_to_cache(
x_wav = FileData(
path=processing_utils.save_base64_to_cache(
media_data.BASE64_MICROPHONE["data"], cache_dir=gradio_temp_dir
),
}
)
)
audio_input = gr.Audio(type="filepath", format="mp3")
output = audio_input.preprocess(x_wav)
assert output.endswith("mp3")
@ -850,13 +847,13 @@ class TestFile:
"""
Preprocess, serialize, get_config, value
"""
x_file = deepcopy(media_data.BASE64_FILE)
x_file = FileData(path=media_data.BASE64_FILE["path"])
file_input = gr.File()
output = file_input.preprocess({"path": x_file["path"]})
output = file_input.preprocess(x_file)
assert isinstance(output, str)
input1 = file_input.preprocess({"path": x_file["path"]})
input2 = file_input.preprocess({"path": x_file["path"]})
input1 = file_input.preprocess(x_file)
input2 = file_input.preprocess(x_file)
assert input1 == input1.name # Testing backwards compatibility
assert input1 == input2
assert Path(input1).name == "sample_file.pdf"
@ -884,7 +881,7 @@ class TestFile:
assert file_input.preprocess(None) is None
assert file_input.preprocess(x_file) is not None
zero_size_file = {"path": "document.txt", "size": 0}
zero_size_file = FileData(path="document.txt", size=0)
temp_file = file_input.preprocess(zero_size_file)
assert not Path(temp_file.name).exists()
@ -933,13 +930,13 @@ class TestUploadButton:
"""
preprocess
"""
x_file = deepcopy(media_data.BASE64_FILE)
x_file = FileData(path=media_data.BASE64_FILE["path"])
upload_input = gr.UploadButton()
input = upload_input.preprocess({"path": x_file})
input = upload_input.preprocess(x_file)
assert isinstance(input, str)
input1 = upload_input.preprocess({"path": x_file})
input2 = upload_input.preprocess({"path": x_file})
input1 = upload_input.preprocess(x_file)
input2 = upload_input.preprocess(x_file)
assert input1 == input1.name # Testing backwards compatibility
assert input1 == input2
@ -961,7 +958,7 @@ class TestDataframe:
"metadata": None,
}
dataframe_input = gr.Dataframe(headers=["Name", "Age", "Member"])
output = dataframe_input.preprocess(x_data)
output = dataframe_input.preprocess(DataframeData(**x_data))
assert output["Age"][1] == 24
assert not output["Member"][0]
assert dataframe_input.postprocess(x_data) == x_data
@ -998,7 +995,7 @@ class TestDataframe:
"column_widths": [],
}
dataframe_input = gr.Dataframe()
output = dataframe_input.preprocess(x_data)
output = dataframe_input.preprocess(DataframeData(**x_data))
assert output["Age"][1] == 24
with pytest.raises(ValueError):
gr.Dataframe(type="unknown")
@ -1315,7 +1312,9 @@ class TestVideo:
"""
Preprocess, serialize, deserialize, get_config
"""
x_video = {"video": {"path": deepcopy(media_data.BASE64_VIDEO)["path"]}}
x_video = VideoData(
video=FileData(path=deepcopy(media_data.BASE64_VIDEO)["path"])
)
video_input = gr.Video()
x_video = processing_utils.move_files_to_cache([x_video], video_input)[0]
@ -1357,8 +1356,6 @@ class TestVideo:
"_selectable": False,
}
assert video_input.preprocess(None) is None
x_video["is_example"] = True
assert video_input.preprocess(x_video) is not None
video_input = gr.Video(format="avi")
output_video = video_input.preprocess(x_video)
assert output_video[-3:] == "avi"
@ -1468,7 +1465,7 @@ class TestVideo:
@patch("gradio.components.video.FFmpeg")
def test_video_preprocessing_flips_video_for_webcam(self, mock_ffmpeg):
# Ensures that the cached temp video file is not used so that ffmpeg is called for each test
x_video = {"video": deepcopy(media_data.BASE64_VIDEO)}
x_video = VideoData(video=FileData(path=media_data.BASE64_VIDEO["path"]))
video_input = gr.Video(sources=["webcam"])
_ = video_input.preprocess(x_video)

Binary file not shown.

View File

@ -288,18 +288,19 @@ class TestThemeUploadDownload:
)
assert next_version == "3.20.2"
@pytest.mark.flaky
def test_theme_download(self):
assert (
gr.themes.Base.from_hub("gradio/dracula_test@0.0.2").to_dict()
== dracula.to_dict()
)
## Commenting out until after 4.0 Spaces are up
# @pytest.mark.flaky
# def test_theme_download(self):
# assert (
# gr.themes.Base.from_hub("gradio/dracula_test@0.0.2").to_dict()
# == dracula.to_dict()
# )
with gr.Blocks(theme="gradio/dracula_test@0.0.2") as demo:
pass
# with gr.Blocks(theme="gradio/dracula_test@0.0.2") as demo:
# pass
assert demo.theme.to_dict() == dracula.to_dict()
assert demo.theme.name == "gradio/dracula_test"
# assert demo.theme.to_dict() == dracula.to_dict()
# assert demo.theme.name == "gradio/dracula_test"
def test_theme_download_raises_error_if_theme_does_not_exist(self):
with pytest.raises(