mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
modify preprocess to use pydantic models (#6181)
* modify preprocess to use pydantic models * changes * add changeset * fix * fix * fix typing * save * revert queuing changes * fix * fix * notebook * fix * changes * add changeset * fix functional tests --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
e16b4abc37
commit
62ec2075cc
6
.changeset/short-doodles-lose.md
Normal file
6
.changeset/short-doodles-lose.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
"@gradio/uploadbutton": minor
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:modify preprocess to use pydantic models
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch(allowed_paths=[\"avatar.png\"])\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -51,4 +51,4 @@ with gr.Blocks() as demo:
|
||||
|
||||
demo.queue()
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
demo.launch(allowed_paths=["avatar.png"])
|
||||
|
@ -153,6 +153,7 @@ def evaluate_values(*args):
|
||||
are_false.append(a == "#000000")
|
||||
else:
|
||||
are_false.append(not a)
|
||||
print(args)
|
||||
return all(are_false)
|
||||
|
||||
|
||||
|
@ -341,7 +341,7 @@ class BlockFunction:
|
||||
self.postprocess = postprocess
|
||||
self.tracks_progress = tracks_progress
|
||||
self.concurrency_limit = concurrency_limit
|
||||
self.concurrency_id = concurrency_id or id(fn)
|
||||
self.concurrency_id = concurrency_id or str(id(fn))
|
||||
self.batch = batch
|
||||
self.max_batch_size = max_batch_size
|
||||
self.total_runtime = 0
|
||||
@ -1260,6 +1260,14 @@ Received inputs:
|
||||
inputs_cached = processing_utils.move_files_to_cache(
|
||||
inputs[i], block
|
||||
)
|
||||
if getattr(block, "data_model", None) and inputs_cached is not None:
|
||||
if issubclass(block.data_model, GradioModel): # type: ignore
|
||||
print("block.data_model", block.data_model, block)
|
||||
print("1inputs_cached", inputs_cached)
|
||||
inputs_cached = block.data_model(**inputs_cached) # type: ignore
|
||||
elif issubclass(block.data_model, GradioRootModel): # type: ignore
|
||||
print("2inputs_cached", inputs_cached)
|
||||
inputs_cached = block.data_model(root=inputs_cached) # type: ignore
|
||||
processed_input.append(block.preprocess(inputs_cached))
|
||||
else:
|
||||
processed_input = inputs
|
||||
|
@ -103,20 +103,21 @@ class AnnotatedImage(Component):
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
y: tuple[
|
||||
value: tuple[
|
||||
np.ndarray | _Image.Image | str,
|
||||
list[tuple[np.ndarray | tuple[int, int, int, int], str]],
|
||||
],
|
||||
]
|
||||
| None,
|
||||
) -> AnnotatedImageData | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label.
|
||||
value: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label.
|
||||
Returns:
|
||||
Tuple of base image file and list of subsections, with each subsection a two-part tuple where the first element image path of the mask, and the second element is the label.
|
||||
"""
|
||||
if y is None:
|
||||
if value is None:
|
||||
return None
|
||||
base_img = y[0]
|
||||
base_img = value[0]
|
||||
if isinstance(base_img, str):
|
||||
base_img_path = base_img
|
||||
base_img = np.array(_Image.open(base_img))
|
||||
@ -144,7 +145,7 @@ class AnnotatedImage(Component):
|
||||
lv = len(value)
|
||||
return [int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)]
|
||||
|
||||
for mask, label in y[1]:
|
||||
for mask, label in value[1]:
|
||||
mask_array = np.zeros((base_img.shape[0], base_img.shape[1]))
|
||||
if isinstance(mask, np.ndarray):
|
||||
mask_array = mask
|
||||
@ -188,5 +189,7 @@ class AnnotatedImage(Component):
|
||||
def example_inputs(self) -> Any:
|
||||
return {}
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(
|
||||
self, payload: AnnotatedImageData | None
|
||||
) -> AnnotatedImageData | None:
|
||||
return payload
|
||||
|
@ -162,20 +162,12 @@ class Audio(
|
||||
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
|
||||
|
||||
def preprocess(
|
||||
self, x: dict[str, Any] | None
|
||||
self, payload: FileData | None
|
||||
) -> tuple[int, np.ndarray] | str | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: dictionary with keys "path", "crop_min", "crop_max".
|
||||
Returns:
|
||||
audio in requested format
|
||||
"""
|
||||
if x is None:
|
||||
return x
|
||||
if payload is None:
|
||||
return payload
|
||||
|
||||
payload: FileData = FileData(**x)
|
||||
assert payload.path
|
||||
|
||||
# Need a unique name for the file to avoid re-using the same audio file if
|
||||
# a user submits the same audio file twice
|
||||
temp_file_path = Path(payload.path)
|
||||
@ -211,50 +203,50 @@ class Audio(
|
||||
)
|
||||
|
||||
def postprocess(
|
||||
self, y: tuple[int, np.ndarray] | str | Path | bytes | None
|
||||
) -> FileData | None | bytes:
|
||||
self, value: tuple[int, np.ndarray] | str | Path | bytes | None
|
||||
) -> FileData | bytes | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None.
|
||||
value: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None.
|
||||
Returns:
|
||||
base64 url data
|
||||
"""
|
||||
if y is None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(y, bytes):
|
||||
if isinstance(value, bytes):
|
||||
if self.streaming:
|
||||
return y
|
||||
return value
|
||||
file_path = processing_utils.save_bytes_to_cache(
|
||||
y, "audio", cache_dir=self.GRADIO_CACHE
|
||||
value, "audio", cache_dir=self.GRADIO_CACHE
|
||||
)
|
||||
elif isinstance(y, tuple):
|
||||
sample_rate, data = y
|
||||
elif isinstance(value, tuple):
|
||||
sample_rate, data = value
|
||||
file_path = processing_utils.save_audio_to_cache(
|
||||
data, sample_rate, format=self.format, cache_dir=self.GRADIO_CACHE
|
||||
)
|
||||
else:
|
||||
if not isinstance(y, (str, Path)):
|
||||
raise ValueError(f"Cannot process {y} as Audio")
|
||||
file_path = str(y)
|
||||
if not isinstance(value, (str, Path)):
|
||||
raise ValueError(f"Cannot process {value} as Audio")
|
||||
file_path = str(value)
|
||||
return FileData(path=file_path)
|
||||
|
||||
def stream_output(
|
||||
self, y, output_id: str, first_chunk: bool
|
||||
self, value, output_id: str, first_chunk: bool
|
||||
) -> tuple[bytes | None, Any]:
|
||||
output_file = {
|
||||
"path": output_id,
|
||||
"is_stream": True,
|
||||
}
|
||||
if y is None:
|
||||
if value is None:
|
||||
return None, output_file
|
||||
if isinstance(y, bytes):
|
||||
return y, output_file
|
||||
if client_utils.is_http_url_like(y["path"]):
|
||||
response = requests.get(y["path"])
|
||||
if isinstance(value, bytes):
|
||||
return value, output_file
|
||||
if client_utils.is_http_url_like(value["path"]):
|
||||
response = requests.get(value["path"])
|
||||
binary_data = response.content
|
||||
else:
|
||||
output_file["orig_name"] = y["orig_name"]
|
||||
file_path = y["path"]
|
||||
output_file["orig_name"] = value["orig_name"]
|
||||
file_path = value["path"]
|
||||
is_wav = file_path.endswith(".wav")
|
||||
with open(file_path, "rb") as f:
|
||||
binary_data = f.read()
|
||||
|
@ -258,15 +258,15 @@ class BarPlot(Plot):
|
||||
return chart
|
||||
|
||||
def postprocess(
|
||||
self, y: pd.DataFrame | dict | None
|
||||
self, value: pd.DataFrame | dict | None
|
||||
) -> AltairPlotData | dict | None:
|
||||
# if None or update
|
||||
if y is None or isinstance(y, dict):
|
||||
return y
|
||||
if value is None or isinstance(value, dict):
|
||||
return value
|
||||
if self.x is None or self.y is None:
|
||||
raise ValueError("No value provided for required parameters `x` and `y`.")
|
||||
chart = self.create_plot(
|
||||
value=y,
|
||||
value=value,
|
||||
x=self.x,
|
||||
y=self.y,
|
||||
color=self.color,
|
||||
@ -288,12 +288,10 @@ class BarPlot(Plot):
|
||||
sort=self.sort, # type: ignore
|
||||
)
|
||||
|
||||
return AltairPlotData(
|
||||
**{"type": "altair", "plot": chart.to_json(), "chart": "bar"}
|
||||
)
|
||||
return AltairPlotData(type="altair", plot=chart.to_json(), chart="bar")
|
||||
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: AltairPlotData) -> AltairPlotData:
|
||||
return payload
|
||||
|
@ -48,21 +48,21 @@ class ComponentBase(ABC, metaclass=ComponentMeta):
|
||||
EVENTS: list[EventListener | str] = []
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
def preprocess(self, payload: Any) -> Any:
|
||||
"""
|
||||
Any preprocessing needed to be performed on function input.
|
||||
"""
|
||||
return x
|
||||
return payload
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(self, y):
|
||||
def postprocess(self, value):
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
"""
|
||||
return y
|
||||
return value
|
||||
|
||||
@abstractmethod
|
||||
def as_example(self, y):
|
||||
def as_example(self, value):
|
||||
"""
|
||||
Return the input data in a way that can be displayed by the examples dataset component in the front-end.
|
||||
|
||||
@ -88,7 +88,7 @@ class ComponentBase(ABC, metaclass=ComponentMeta):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def flag(self, x: Any | GradioDataModel, flag_dir: str | Path = "") -> str:
|
||||
def flag(self, payload: Any | GradioDataModel, flag_dir: str | Path = "") -> str:
|
||||
"""
|
||||
Write the component's value to a format that can be stored in a csv or jsonl format for flagging.
|
||||
"""
|
||||
@ -97,13 +97,13 @@ class ComponentBase(ABC, metaclass=ComponentMeta):
|
||||
@abstractmethod
|
||||
def read_from_flag(
|
||||
self,
|
||||
x: Any,
|
||||
payload: Any,
|
||||
flag_dir: str | Path | None = None,
|
||||
) -> GradioDataModel | Any:
|
||||
"""
|
||||
Convert the data from the csv or jsonl file into the component state.
|
||||
"""
|
||||
return x
|
||||
return payload
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@ -267,26 +267,26 @@ class Component(ComponentBase, Block):
|
||||
f"The api_info method has not been implemented for {self.get_block_name()}"
|
||||
)
|
||||
|
||||
def flag(self, x: Any, flag_dir: str | Path = "") -> str:
|
||||
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
|
||||
"""
|
||||
Write the component's value to a format that can be stored in a csv or jsonl format for flagging.
|
||||
"""
|
||||
if self.data_model:
|
||||
x = self.data_model.from_json(x)
|
||||
return x.copy_to_dir(flag_dir).model_dump_json()
|
||||
return x
|
||||
payload = self.data_model.from_json(payload)
|
||||
return payload.copy_to_dir(flag_dir).model_dump_json()
|
||||
return payload
|
||||
|
||||
def read_from_flag(
|
||||
self,
|
||||
x: Any,
|
||||
payload: Any,
|
||||
flag_dir: str | Path | None = None,
|
||||
):
|
||||
"""
|
||||
Convert the data from the csv or jsonl file into the component state.
|
||||
"""
|
||||
if self.data_model:
|
||||
return self.data_model.from_json(json.loads(x))
|
||||
return x
|
||||
return self.data_model.from_json(json.loads(payload))
|
||||
return payload
|
||||
|
||||
|
||||
class FormComponent(Component):
|
||||
@ -295,11 +295,11 @@ class FormComponent(Component):
|
||||
return None
|
||||
return Form
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: Any) -> Any:
|
||||
return payload
|
||||
|
||||
def postprocess(self, y):
|
||||
return y
|
||||
def postprocess(self, value):
|
||||
return value
|
||||
|
||||
|
||||
class StreamingOutput(metaclass=abc.ABCMeta):
|
||||
@ -308,7 +308,9 @@ class StreamingOutput(metaclass=abc.ABCMeta):
|
||||
self.streaming: bool
|
||||
|
||||
@abc.abstractmethod
|
||||
def stream_output(self, y, output_id: str, first_chunk: bool) -> tuple[bytes, Any]:
|
||||
def stream_output(
|
||||
self, value, output_id: str, first_chunk: bool
|
||||
) -> tuple[bytes, Any]:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -77,11 +77,11 @@ class Button(Component):
|
||||
def skip_api(self):
|
||||
return True
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: str) -> str:
|
||||
return payload
|
||||
|
||||
def postprocess(self, y):
|
||||
return y
|
||||
def postprocess(self, value: str) -> str:
|
||||
return value
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return None
|
||||
|
@ -14,8 +14,6 @@ from gradio.components.base import Component
|
||||
from gradio.data_classes import FileData, GradioModel, GradioRootModel
|
||||
from gradio.events import Events
|
||||
|
||||
# from pydantic import Field, TypeAdapter
|
||||
|
||||
set_documentation_group("component")
|
||||
|
||||
|
||||
@ -129,26 +127,28 @@ class Chatbot(Component):
|
||||
)
|
||||
|
||||
def _preprocess_chat_messages(
|
||||
self, chat_message: str | dict | None
|
||||
) -> str | tuple[str] | tuple[str, str] | None:
|
||||
self, chat_message: str | FileMessage | None
|
||||
) -> str | tuple[str | None] | tuple[str | None, str] | None:
|
||||
if chat_message is None:
|
||||
return None
|
||||
elif isinstance(chat_message, dict):
|
||||
if chat_message.get("alt_text"):
|
||||
return (chat_message["file"]["path"], chat_message["alt_text"])
|
||||
elif isinstance(chat_message, FileMessage):
|
||||
if chat_message.alt_text is not None:
|
||||
return (chat_message.file.path, chat_message.alt_text)
|
||||
else:
|
||||
return (chat_message["file"]["path"],)
|
||||
else: # string
|
||||
return (chat_message.file.path,)
|
||||
elif isinstance(chat_message, str):
|
||||
return chat_message
|
||||
else:
|
||||
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
y: list[list[str | dict | None] | tuple[str | dict | None, str | dict | None]],
|
||||
payload: ChatbotData,
|
||||
) -> list[list[str | tuple[str] | tuple[str, str] | None]]:
|
||||
if y is None:
|
||||
return y
|
||||
if payload is None:
|
||||
return payload
|
||||
processed_messages = []
|
||||
for message_pair in y:
|
||||
for message_pair in payload.root:
|
||||
if not isinstance(message_pair, (tuple, list)):
|
||||
raise TypeError(
|
||||
f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
||||
@ -186,18 +186,12 @@ class Chatbot(Component):
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
y: list[list[str | tuple[str] | tuple[str, str] | None] | tuple],
|
||||
value: list[list[str | tuple[str] | tuple[str, str] | None] | tuple],
|
||||
) -> ChatbotData:
|
||||
"""
|
||||
Parameters:
|
||||
y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string or pathlib.Path filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
|
||||
Returns:
|
||||
List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed.
|
||||
"""
|
||||
if y is None:
|
||||
if value is None:
|
||||
return ChatbotData(root=[])
|
||||
processed_messages = []
|
||||
for message_pair in y:
|
||||
for message_pair in value:
|
||||
if not isinstance(message_pair, (tuple, list)):
|
||||
raise TypeError(
|
||||
f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
||||
|
@ -79,3 +79,9 @@ class Checkbox(FormComponent):
|
||||
|
||||
def example_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
def preprocess(self, payload: bool | None) -> bool | None:
|
||||
return payload
|
||||
|
||||
def postprocess(self, value: bool | None) -> bool | None:
|
||||
return value
|
||||
|
@ -101,21 +101,15 @@ class CheckboxGroup(FormComponent):
|
||||
}
|
||||
|
||||
def preprocess(
|
||||
self, x: list[str | int | float]
|
||||
self, payload: list[str | int | float]
|
||||
) -> list[str | int | float] | list[int | None]:
|
||||
"""
|
||||
Parameters:
|
||||
x: list of selected choices
|
||||
Returns:
|
||||
list of selected choice values as strings or indices within choice list
|
||||
"""
|
||||
if self.type == "value":
|
||||
return x
|
||||
return payload
|
||||
elif self.type == "index":
|
||||
choice_values = [value for _, value in self.choices]
|
||||
return [
|
||||
choice_values.index(choice) if choice in choice_values else None
|
||||
for choice in x
|
||||
for choice in payload
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -123,19 +117,13 @@ class CheckboxGroup(FormComponent):
|
||||
)
|
||||
|
||||
def postprocess(
|
||||
self, y: list[str | int | float] | str | int | float | None
|
||||
self, value: list[str | int | float] | str | int | float | None
|
||||
) -> list[str | int | float]:
|
||||
"""
|
||||
Parameters:
|
||||
y: List of selected choice values. If a single choice is selected, it can be passed in as a string
|
||||
Returns:
|
||||
List of selected choices
|
||||
"""
|
||||
if y is None:
|
||||
if value is None:
|
||||
return []
|
||||
if not isinstance(y, list):
|
||||
y = [y]
|
||||
return y
|
||||
if not isinstance(value, list):
|
||||
value = [value]
|
||||
return value
|
||||
|
||||
def as_example(self, input_data):
|
||||
if input_data is None:
|
||||
|
@ -75,16 +75,17 @@ class ClearButton(Button):
|
||||
none = component.postprocess(None)
|
||||
if isinstance(none, (GradioModel, GradioRootModel)):
|
||||
none = none.model_dump()
|
||||
print(none)
|
||||
none_values.append(none)
|
||||
clear_values = json.dumps(none_values)
|
||||
self.click(None, [], components, _js=f"() => {clear_values}")
|
||||
return self
|
||||
|
||||
def postprocess(self, y):
|
||||
return y
|
||||
def postprocess(self, value: str | None) -> str | None:
|
||||
return value
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: str | None) -> str | None:
|
||||
return payload
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return None
|
||||
|
@ -105,20 +105,20 @@ class Code(Component):
|
||||
value=value,
|
||||
)
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: Any) -> Any:
|
||||
return payload
|
||||
|
||||
def postprocess(self, y: tuple | str | None) -> None | str:
|
||||
if y is None:
|
||||
def postprocess(self, value: tuple | str | None) -> None | str:
|
||||
if value is None:
|
||||
return None
|
||||
elif isinstance(y, tuple):
|
||||
with open(y[0]) as file_data:
|
||||
elif isinstance(value, tuple):
|
||||
with open(value[0]) as file_data:
|
||||
return file_data.read()
|
||||
else:
|
||||
return y.strip()
|
||||
return value.strip()
|
||||
|
||||
def flag(self, x: Any, flag_dir: str | Path = "") -> str:
|
||||
return super().flag(x, flag_dir)
|
||||
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
|
||||
return super().flag(payload, flag_dir)
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
return {"type": "string"}
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
@ -77,37 +76,17 @@ class ColorPicker(Component):
|
||||
def example_inputs(self) -> str:
|
||||
return "#000000"
|
||||
|
||||
def flag(self, x: Any, flag_dir: str | Path = "") -> str:
|
||||
return x
|
||||
|
||||
def read_from_flag(self, x: Any, flag_dir: str | Path | None = None):
|
||||
return x
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
return {"type": "string"}
|
||||
|
||||
def preprocess(self, x: str | None) -> str | None:
|
||||
"""
|
||||
Any preprocessing needed to be performed on function input.
|
||||
Parameters:
|
||||
x: text
|
||||
Returns:
|
||||
text
|
||||
"""
|
||||
if x is None:
|
||||
def preprocess(self, payload: str | None) -> str | None:
|
||||
if payload is None:
|
||||
return None
|
||||
else:
|
||||
return str(x)
|
||||
return str(payload)
|
||||
|
||||
def postprocess(self, y: str | None) -> str | None:
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
Parameters:
|
||||
y: text
|
||||
Returns:
|
||||
text
|
||||
"""
|
||||
if y is None:
|
||||
def postprocess(self, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
return str(y)
|
||||
return str(value)
|
||||
|
@ -162,23 +162,16 @@ class Dataframe(Component):
|
||||
value=value,
|
||||
)
|
||||
|
||||
def preprocess(self, x: dict) -> pd.DataFrame | np.ndarray | list:
|
||||
"""
|
||||
Parameters:
|
||||
x: Dictionary equivalent of DataframeData containing `headers`, `data`, and optionally `metadata` keys
|
||||
Returns:
|
||||
The Dataframe data in requested format
|
||||
"""
|
||||
value = DataframeData(**x)
|
||||
def preprocess(self, payload: DataframeData) -> pd.DataFrame | np.ndarray | list:
|
||||
if self.type == "pandas":
|
||||
if value.headers is not None:
|
||||
return pd.DataFrame(value.data, columns=value.headers)
|
||||
if payload.headers is not None:
|
||||
return pd.DataFrame(payload.data, columns=payload.headers)
|
||||
else:
|
||||
return pd.DataFrame(value.data)
|
||||
return pd.DataFrame(payload.data)
|
||||
if self.type == "numpy":
|
||||
return np.array(value.data)
|
||||
return np.array(payload.data)
|
||||
elif self.type == "array":
|
||||
return value.data
|
||||
return payload.data
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type: "
|
||||
@ -188,26 +181,27 @@ class Dataframe(Component):
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
y: pd.DataFrame | Styler | np.ndarray | list | list[list] | dict | str | None,
|
||||
value: pd.DataFrame
|
||||
| Styler
|
||||
| np.ndarray
|
||||
| list
|
||||
| list[list]
|
||||
| dict
|
||||
| str
|
||||
| None,
|
||||
) -> DataframeData | dict:
|
||||
"""
|
||||
Parameters:
|
||||
y: dataframe in given format
|
||||
Returns:
|
||||
JSON object with key 'headers' for list of header names, 'data' for 2D array of string or numeric data
|
||||
"""
|
||||
if y is None:
|
||||
if value is None:
|
||||
return self.postprocess(self.empty_input)
|
||||
if isinstance(y, dict):
|
||||
return y
|
||||
if isinstance(y, (str, pd.DataFrame)):
|
||||
if isinstance(y, str):
|
||||
y = pd.read_csv(y) # type: ignore
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, (str, pd.DataFrame)):
|
||||
if isinstance(value, str):
|
||||
value = pd.read_csv(value) # type: ignore
|
||||
return DataframeData(
|
||||
headers=list(y.columns), # type: ignore
|
||||
data=y.to_dict(orient="split")["data"], # type: ignore
|
||||
headers=list(value.columns), # type: ignore
|
||||
data=value.to_dict(orient="split")["data"], # type: ignore
|
||||
)
|
||||
elif isinstance(y, Styler):
|
||||
elif isinstance(value, Styler):
|
||||
if semantic_version.Version(pd.__version__) < semantic_version.Version(
|
||||
"1.5.0"
|
||||
):
|
||||
@ -218,39 +212,38 @@ class Dataframe(Component):
|
||||
warnings.warn(
|
||||
"Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead."
|
||||
)
|
||||
df: pd.DataFrame = y.data # type: ignore
|
||||
value = DataframeData(
|
||||
df: pd.DataFrame = value.data # type: ignore
|
||||
return DataframeData(
|
||||
headers=list(df.columns),
|
||||
data=df.to_dict(orient="split")["data"], # type: ignore
|
||||
metadata=self.__extract_metadata(y),
|
||||
metadata=self.__extract_metadata(value),
|
||||
)
|
||||
elif isinstance(y, (str, pd.DataFrame)):
|
||||
df = pd.read_csv(y) if isinstance(y, str) else y # type: ignore
|
||||
value = DataframeData(
|
||||
elif isinstance(value, (str, pd.DataFrame)):
|
||||
df = pd.read_csv(value) if isinstance(value, str) else value # type: ignore
|
||||
return DataframeData(
|
||||
headers=list(df.columns),
|
||||
data=df.to_dict(orient="split")["data"], # type: ignore
|
||||
)
|
||||
elif isinstance(y, (np.ndarray, list)):
|
||||
if len(y) == 0:
|
||||
elif isinstance(value, (np.ndarray, list)):
|
||||
if len(value) == 0:
|
||||
return self.postprocess([[]])
|
||||
if isinstance(y, np.ndarray):
|
||||
y = y.tolist()
|
||||
if not isinstance(y, list):
|
||||
if isinstance(value, np.ndarray):
|
||||
value = value.tolist()
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("output cannot be converted to list")
|
||||
|
||||
_headers = self.headers
|
||||
if len(self.headers) < len(y[0]):
|
||||
if len(self.headers) < len(value[0]):
|
||||
_headers: list[str] = [
|
||||
*self.headers,
|
||||
*[str(i) for i in range(len(self.headers) + 1, len(y[0]) + 1)],
|
||||
*[str(i) for i in range(len(self.headers) + 1, len(value[0]) + 1)],
|
||||
]
|
||||
elif len(self.headers) > len(y[0]):
|
||||
_headers = self.headers[: len(y[0])]
|
||||
elif len(self.headers) > len(value[0]):
|
||||
_headers = self.headers[: len(value[0])]
|
||||
|
||||
value = DataframeData(headers=_headers, data=y)
|
||||
return DataframeData(headers=_headers, data=value)
|
||||
else:
|
||||
raise ValueError("Cannot process value as a Dataframe")
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def __get_cell_style(cell_id: str, cell_styles: list[dict]) -> str:
|
||||
|
@ -121,16 +121,13 @@ class Dataset(Component):
|
||||
|
||||
return config
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
"""
|
||||
Any preprocessing needed to be performed on function input.
|
||||
"""
|
||||
def preprocess(self, payload: int) -> int | list[list] | None:
|
||||
if self.type == "index":
|
||||
return x
|
||||
return payload
|
||||
elif self.type == "values":
|
||||
return self.samples[x]
|
||||
return self.samples[payload]
|
||||
|
||||
def postprocess(self, samples: list[list[Any]]) -> dict:
|
||||
def postprocess(self, samples: list[list]) -> dict:
|
||||
return {
|
||||
"samples": samples,
|
||||
"__type__": "update",
|
||||
|
@ -134,48 +134,48 @@ class Dropdown(FormComponent):
|
||||
return self.choices[0][1] if self.choices else None
|
||||
|
||||
def preprocess(
|
||||
self, x: str | int | float | list[str | int | float] | None
|
||||
self, payload: str | int | float | list[str | int | float] | None
|
||||
) -> str | int | float | list[str | int | float] | list[int | None] | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: selected choice(s)
|
||||
Returns:
|
||||
selected choice(s) as string or index within choice list or list of string or indices
|
||||
"""
|
||||
if self.type == "value":
|
||||
return x
|
||||
return payload
|
||||
elif self.type == "index":
|
||||
choice_values = [value for _, value in self.choices]
|
||||
if x is None:
|
||||
if payload is None:
|
||||
return None
|
||||
elif self.multiselect:
|
||||
assert isinstance(x, list)
|
||||
assert isinstance(payload, list)
|
||||
return [
|
||||
choice_values.index(choice) if choice in choice_values else None
|
||||
for choice in x
|
||||
for choice in payload
|
||||
]
|
||||
else:
|
||||
return choice_values.index(x) if x in choice_values else None
|
||||
return (
|
||||
choice_values.index(payload) if payload in choice_values else None
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown type: {self.type}. Please choose from: 'value', 'index'."
|
||||
)
|
||||
|
||||
def _warn_if_invalid_choice(self, y):
|
||||
if self.allow_custom_value or y in [value for _, value in self.choices]:
|
||||
def _warn_if_invalid_choice(self, value):
|
||||
if self.allow_custom_value or value in [value for _, value in self.choices]:
|
||||
return
|
||||
warnings.warn(
|
||||
f"The value passed into gr.Dropdown() is not in the list of choices. Please update the list of choices to include: {y} or set allow_custom_value=True."
|
||||
f"The value passed into gr.Dropdown() is not in the list of choices. Please update the list of choices to include: {value} or set allow_custom_value=True."
|
||||
)
|
||||
|
||||
def postprocess(self, y):
|
||||
if y is None:
|
||||
def postprocess(
|
||||
self, value: str | int | float | list[str | int | float] | None
|
||||
) -> str | int | float | list[str | int | float] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if self.multiselect:
|
||||
[self._warn_if_invalid_choice(_y) for _y in y]
|
||||
if not isinstance(value, list):
|
||||
value = [value]
|
||||
[self._warn_if_invalid_choice(_y) for _y in value]
|
||||
else:
|
||||
self._warn_if_invalid_choice(y)
|
||||
return y
|
||||
self._warn_if_invalid_choice(value)
|
||||
return value
|
||||
|
||||
def as_example(self, input_data):
|
||||
if self.multiselect:
|
||||
|
@ -2,11 +2,11 @@ from gradio.components.base import Component
|
||||
|
||||
|
||||
class Fallback(Component):
|
||||
def preprocess(self, x):
|
||||
return x
|
||||
def preprocess(self, payload):
|
||||
return payload
|
||||
|
||||
def postprocess(self, x):
|
||||
return x
|
||||
def postprocess(self, value):
|
||||
return value
|
||||
|
||||
def example_inputs(self):
|
||||
return {"foo": "bar"}
|
||||
|
@ -20,6 +20,12 @@ set_documentation_group("component")
|
||||
class ListFiles(GradioRootModel):
|
||||
root: List[FileData]
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.root[index]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.root)
|
||||
|
||||
|
||||
@document()
|
||||
class File(Component):
|
||||
@ -111,13 +117,12 @@ class File(Component):
|
||||
self.type = type
|
||||
self.height = height
|
||||
|
||||
def _process_single_file(self, f: dict[str, Any]) -> bytes | NamedString:
|
||||
file_name = f["path"]
|
||||
|
||||
def _process_single_file(self, f: FileData) -> NamedString | bytes:
|
||||
file_name = f.path
|
||||
if self.type == "filepath":
|
||||
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
|
||||
file.name = file_name
|
||||
return NamedString(file.name)
|
||||
return NamedString(file_name)
|
||||
elif self.type == "binary":
|
||||
with open(file_name, "rb") as file_data:
|
||||
return file_data.read()
|
||||
@ -129,38 +134,25 @@ class File(Component):
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self, x: list[dict[str, Any]] | dict[str, Any] | None
|
||||
self, payload: ListFiles | FileData | None
|
||||
) -> bytes | NamedString | list[bytes | NamedString] | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
|
||||
Returns:
|
||||
File objects in requested format
|
||||
"""
|
||||
if x is None:
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
if self.file_count == "single":
|
||||
if isinstance(x, list):
|
||||
return self._process_single_file(x[0])
|
||||
if isinstance(payload, ListFiles):
|
||||
return self._process_single_file(payload[0])
|
||||
else:
|
||||
return self._process_single_file(x)
|
||||
return self._process_single_file(payload)
|
||||
else:
|
||||
if isinstance(x, list):
|
||||
return [self._process_single_file(f) for f in x]
|
||||
if isinstance(payload, ListFiles):
|
||||
return [self._process_single_file(f) for f in payload]
|
||||
else:
|
||||
return [self._process_single_file(x)]
|
||||
return [self._process_single_file(payload)]
|
||||
|
||||
def postprocess(self, y: str | list[str] | None) -> ListFiles | FileData | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: file path
|
||||
Returns:
|
||||
JSON object with key 'name' for filename, 'data' for base64 url, and 'size' for filesize in bytes
|
||||
"""
|
||||
if y is None:
|
||||
def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(y, list):
|
||||
if isinstance(value, list):
|
||||
return ListFiles(
|
||||
root=[
|
||||
FileData(
|
||||
@ -168,14 +160,14 @@ class File(Component):
|
||||
orig_name=Path(file).name,
|
||||
size=Path(file).stat().st_size,
|
||||
)
|
||||
for file in y
|
||||
for file in value
|
||||
]
|
||||
)
|
||||
else:
|
||||
return FileData(
|
||||
path=y,
|
||||
orig_name=Path(y).name,
|
||||
size=Path(y).stat().st_size,
|
||||
path=value,
|
||||
orig_name=Path(value).name,
|
||||
size=Path(value).stat().st_size,
|
||||
)
|
||||
|
||||
def as_example(self, input_data: str | list | None) -> str:
|
||||
|
@ -104,49 +104,39 @@ class FileExplorer(Component):
|
||||
def example_inputs(self) -> Any:
|
||||
return ["Users", "gradio", "app.py"]
|
||||
|
||||
def preprocess(self, x: list[list[str]] | None) -> list[str] | str | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: File path segments as a list of list of strings for each file relative to the root.
|
||||
Returns:
|
||||
File path selected, as an absolute path.
|
||||
"""
|
||||
if x is None:
|
||||
def preprocess(self, payload: list[list[str]] | None) -> list[str] | str | None:
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
if self.file_count == "single":
|
||||
if len(x) > 1:
|
||||
raise ValueError(f"Expected only one file, but {len(x)} were selected.")
|
||||
return self._safe_join(x[0])
|
||||
if len(payload) > 1:
|
||||
raise ValueError(
|
||||
f"Expected only one file, but {len(payload)} were selected."
|
||||
)
|
||||
return self._safe_join(payload[0])
|
||||
|
||||
return [self._safe_join(file) for file in (x)]
|
||||
return [self._safe_join(file) for file in (payload)]
|
||||
|
||||
def _strip_root(self, path):
|
||||
if path.startswith(self.root):
|
||||
return path[len(self.root) + 1 :]
|
||||
return path
|
||||
|
||||
def postprocess(self, y: str | list[str] | None) -> FileExplorerData | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: file path
|
||||
Returns:
|
||||
list representing filepath, where each string is a directory level relative to the root.
|
||||
"""
|
||||
if y is None:
|
||||
def postprocess(self, value: str | list[str] | None) -> FileExplorerData | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
files = [y] if isinstance(y, str) else y
|
||||
files = [value] if isinstance(value, str) else value
|
||||
|
||||
return FileExplorerData(
|
||||
root=[self._strip_root(file).split(os.path.sep) for file in files]
|
||||
)
|
||||
|
||||
@server
|
||||
def ls(self, y=None) -> list[dict[str, str]] | None:
|
||||
def ls(self, value=None) -> list[dict[str, str]] | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: file path as a list of strings for each directory level relative to the root.
|
||||
value: file path as a list of strings for each directory level relative to the root.
|
||||
Returns:
|
||||
tuple of list of files in directory, then list of folders in directory
|
||||
"""
|
||||
|
@ -124,20 +124,20 @@ class Gallery(Component):
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
y: list[np.ndarray | _Image.Image | str]
|
||||
value: list[np.ndarray | _Image.Image | str]
|
||||
| list[tuple[np.ndarray | _Image.Image | str, str]]
|
||||
| None,
|
||||
) -> GalleryData:
|
||||
"""
|
||||
Parameters:
|
||||
y: list of images, or list of (image, caption) tuples
|
||||
value: list of images, or list of (image, caption) tuples
|
||||
Returns:
|
||||
list of string file paths to images in temp directory
|
||||
"""
|
||||
if y is None:
|
||||
if value is None:
|
||||
return GalleryData(root=[])
|
||||
output = []
|
||||
for img in y:
|
||||
for img in value:
|
||||
caption = None
|
||||
if isinstance(img, (tuple, list)):
|
||||
img, caption = img
|
||||
@ -160,8 +160,10 @@ class Gallery(Component):
|
||||
output.append(entry)
|
||||
return GalleryData(root=output)
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: GalleryData | None) -> GalleryData | None:
|
||||
if payload is None or not payload.root:
|
||||
return None
|
||||
return payload
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return [
|
||||
|
@ -99,27 +99,27 @@ class HighlightedText(Component):
|
||||
return {"value": [{"token": "Hello", "class_or_confidence": "1"}]}
|
||||
|
||||
def postprocess(
|
||||
self, y: list[tuple[str, str | float | None]] | dict | None
|
||||
self, value: list[tuple[str, str | float | None]] | dict | None
|
||||
) -> HighlightedTextData | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: List of (word, category) tuples, or a dictionary of two keys: "text", and "entities", which itself is a list of dictionaries, each of which have the keys: "entity" (or "entity_group"), "start", and "end"
|
||||
value: List of (word, category) tuples, or a dictionary of two keys: "text", and "entities", which itself is a list of dictionaries, each of which have the keys: "entity" (or "entity_group"), "start", and "end"
|
||||
Returns:
|
||||
List of (word, category) tuples
|
||||
"""
|
||||
if y is None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(y, dict):
|
||||
if isinstance(value, dict):
|
||||
try:
|
||||
text = y["text"]
|
||||
entities = y["entities"]
|
||||
text = value["text"]
|
||||
entities = value["entities"]
|
||||
except KeyError as ke:
|
||||
raise ValueError(
|
||||
"Expected a dictionary with keys 'text' and 'entities' "
|
||||
"for the value of the HighlightedText component."
|
||||
) from ke
|
||||
if len(entities) == 0:
|
||||
y = [(text, None)]
|
||||
value = [(text, None)]
|
||||
else:
|
||||
list_format = []
|
||||
index = 0
|
||||
@ -132,11 +132,11 @@ class HighlightedText(Component):
|
||||
)
|
||||
index = entity["end"]
|
||||
list_format.append((text[index:], None))
|
||||
y = list_format
|
||||
value = list_format
|
||||
if self.combine_adjacent:
|
||||
output = []
|
||||
running_text, running_category = None, None
|
||||
for text, category in y:
|
||||
for text, category in value:
|
||||
if running_text is None:
|
||||
running_text = text
|
||||
running_category = category
|
||||
@ -160,8 +160,13 @@ class HighlightedText(Component):
|
||||
)
|
||||
else:
|
||||
return HighlightedTextData(
|
||||
root=[HighlightedToken(token=o[0], class_or_confidence=o[1]) for o in y]
|
||||
root=[
|
||||
HighlightedToken(token=o[0], class_or_confidence=o[1])
|
||||
for o in value
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return super().preprocess(x)
|
||||
def preprocess(self, payload: HighlightedTextData | None) -> dict | None:
|
||||
if payload is None:
|
||||
return None
|
||||
return payload.model_dump()
|
||||
|
@ -62,11 +62,11 @@ class HTML(Component):
|
||||
def example_inputs(self) -> Any:
|
||||
return "<p>Hello</p>"
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: str | None) -> str | None:
|
||||
return payload
|
||||
|
||||
def postprocess(self, y):
|
||||
return y
|
||||
def postprocess(self, value: str | None) -> str | None:
|
||||
return value
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
return {"type": "string"}
|
||||
|
@ -38,6 +38,7 @@ class Image(StreamingInput, Component):
|
||||
Events.select,
|
||||
Events.upload,
|
||||
]
|
||||
|
||||
data_model = FileData
|
||||
|
||||
def __init__(
|
||||
@ -141,38 +142,25 @@ class Image(StreamingInput, Component):
|
||||
value=value,
|
||||
)
|
||||
|
||||
def preprocess(self, x: dict | None) -> np.ndarray | _Image.Image | str | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: FileData containing an image path pointing to the user's image
|
||||
Returns:
|
||||
image in requested format, or (if tool == "sketch") a dict of image and mask in requested format
|
||||
"""
|
||||
if x is None:
|
||||
return x
|
||||
|
||||
im = _Image.open(x["path"])
|
||||
def preprocess(
|
||||
self, payload: FileData | None
|
||||
) -> np.ndarray | _Image.Image | str | None:
|
||||
if payload is None:
|
||||
return payload
|
||||
im = _Image.open(payload.path)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
im = im.convert(self.image_mode)
|
||||
|
||||
return image_utils.format_image(
|
||||
im, cast(Literal["numpy", "pil", "filepath"], self.type), self.GRADIO_CACHE
|
||||
)
|
||||
|
||||
def postprocess(
|
||||
self, y: np.ndarray | _Image.Image | str | Path | None
|
||||
self, value: np.ndarray | _Image.Image | str | Path | None
|
||||
) -> FileData | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: image as a numpy array, PIL Image, string/Path filepath, or string URL
|
||||
Returns:
|
||||
base64 url data
|
||||
"""
|
||||
if y is None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
return FileData(path=image_utils.save_image(y, self.GRADIO_CACHE))
|
||||
return FileData(path=image_utils.save_image(value, self.GRADIO_CACHE))
|
||||
|
||||
def check_streamable(self):
|
||||
if self.streaming and self.sources != ("webcam"):
|
||||
|
@ -69,31 +69,25 @@ class JSON(Component):
|
||||
value=value,
|
||||
)
|
||||
|
||||
def postprocess(self, y: dict | list | str | None) -> dict | list | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: either a string filepath to a JSON file, or a Python list or dict that can be converted to JSON
|
||||
Returns:
|
||||
JSON output in Python list or dict format
|
||||
"""
|
||||
if y is None:
|
||||
def postprocess(self, value: dict | list | str | None) -> dict | list | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(y, str):
|
||||
return json.loads(y)
|
||||
if isinstance(value, str):
|
||||
return json.loads(value)
|
||||
else:
|
||||
return y
|
||||
return value
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: dict | list | str | None) -> dict | list | str | None:
|
||||
return payload
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return {"foo": "bar"}
|
||||
|
||||
def flag(self, x: Any, flag_dir: str | Path = "") -> str:
|
||||
return json.dumps(x)
|
||||
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
|
||||
return json.dumps(payload)
|
||||
|
||||
def read_from_flag(self, x: Any, flag_dir: str | Path | None = None):
|
||||
return json.loads(x)
|
||||
def read_from_flag(self, payload: Any, flag_dir: str | Path | None = None):
|
||||
return json.loads(payload)
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
return {"type": {}, "description": "any valid json"}
|
||||
|
@ -22,7 +22,7 @@ class LabelConfidence(GradioModel):
|
||||
|
||||
|
||||
class LabelData(GradioModel):
|
||||
label: Union[str, int, float]
|
||||
label: Optional[Union[str, int, float]] = None
|
||||
confidences: Optional[List[LabelConfidence]] = None
|
||||
|
||||
|
||||
@ -91,44 +91,46 @@ class Label(Component):
|
||||
)
|
||||
|
||||
def postprocess(
|
||||
self, y: dict[str, float] | str | float | None
|
||||
self, value: dict[str, float] | str | float | None
|
||||
) -> LabelData | dict | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: a dictionary mapping labels to confidence value, or just a string/numerical label by itself
|
||||
Returns:
|
||||
Object with key 'label' representing primary label, and key 'confidences' representing a list of label-confidence pairs
|
||||
"""
|
||||
if y is None or y == {}:
|
||||
if value is None or value == {}:
|
||||
return {}
|
||||
if isinstance(y, str) and y.endswith(".json") and Path(y).exists():
|
||||
return LabelData(**json.loads(Path(y).read_text()))
|
||||
if isinstance(y, (str, float, int)):
|
||||
return LabelData(label=str(y))
|
||||
if isinstance(y, dict):
|
||||
if "confidences" in y and isinstance(y["confidences"], dict):
|
||||
y = y["confidences"]
|
||||
y = {c["label"]: c["confidence"] for c in y}
|
||||
sorted_pred = sorted(y.items(), key=operator.itemgetter(1), reverse=True)
|
||||
if isinstance(value, str) and value.endswith(".json") and Path(value).exists():
|
||||
return LabelData(**json.loads(Path(value).read_text()))
|
||||
if isinstance(value, (str, float, int)):
|
||||
return LabelData(label=str(value))
|
||||
if isinstance(value, dict):
|
||||
if "confidences" in value and isinstance(value["confidences"], dict):
|
||||
value = value["confidences"]
|
||||
value = {c["label"]: c["confidence"] for c in value}
|
||||
sorted_pred = sorted(
|
||||
value.items(), key=operator.itemgetter(1), reverse=True
|
||||
)
|
||||
if self.num_top_classes is not None:
|
||||
sorted_pred = sorted_pred[: self.num_top_classes]
|
||||
return LabelData(
|
||||
**{
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{"label": pred[0], "confidence": pred[1]}
|
||||
for pred in sorted_pred
|
||||
],
|
||||
}
|
||||
label=sorted_pred[0][0],
|
||||
confidences=[
|
||||
LabelConfidence(label=pred[0], confidence=pred[1])
|
||||
for pred in sorted_pred
|
||||
],
|
||||
)
|
||||
raise ValueError(
|
||||
"The `Label` output interface expects one of: a string label, or an int label, a "
|
||||
"float label, or a dictionary whose keys are labels and values are confidences. "
|
||||
f"Instead, got a {type(y)}"
|
||||
f"Instead, got a {type(value)}"
|
||||
)
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(
|
||||
self, payload: LabelData | None
|
||||
) -> dict[str, float] | str | float | None:
|
||||
if payload is None:
|
||||
return None
|
||||
if payload.confidences is None:
|
||||
return payload.label
|
||||
return {
|
||||
d["label"]: d["confidence"] for d in payload.model_dump()["confidences"]
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return {
|
||||
|
@ -287,15 +287,15 @@ class LinePlot(Plot):
|
||||
return chart
|
||||
|
||||
def postprocess(
|
||||
self, y: pd.DataFrame | dict | None
|
||||
self, value: pd.DataFrame | dict | None
|
||||
) -> AltairPlotData | dict | None:
|
||||
# if None or update
|
||||
if y is None or isinstance(y, dict):
|
||||
return y
|
||||
if value is None or isinstance(value, dict):
|
||||
return value
|
||||
if self.x is None or self.y is None:
|
||||
raise ValueError("No value provided for required parameters `x` and `y`.")
|
||||
chart = self.create_plot(
|
||||
value=y,
|
||||
value=value,
|
||||
x=self.x,
|
||||
y=self.y,
|
||||
color=self.color,
|
||||
@ -318,12 +318,10 @@ class LinePlot(Plot):
|
||||
width=self.width,
|
||||
)
|
||||
|
||||
return AltairPlotData(
|
||||
**{"type": "altair", "plot": chart.to_json(), "chart": "line"}
|
||||
)
|
||||
return AltairPlotData(type="altair", plot=chart.to_json(), chart="line")
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return None
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, value: AltairPlotData | None) -> AltairPlotData | None:
|
||||
return value
|
||||
|
@ -75,24 +75,18 @@ class Markdown(Component):
|
||||
value=value,
|
||||
)
|
||||
|
||||
def postprocess(self, y: str | None) -> str | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: markdown representation
|
||||
Returns:
|
||||
HTML rendering of markdown
|
||||
"""
|
||||
if y is None:
|
||||
def postprocess(self, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
unindented_y = inspect.cleandoc(y)
|
||||
unindented_y = inspect.cleandoc(value)
|
||||
return unindented_y
|
||||
|
||||
def as_example(self, input_data: str | None) -> str:
|
||||
postprocessed = self.postprocess(input_data)
|
||||
return postprocessed if postprocessed else ""
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: str | None) -> str | None:
|
||||
return payload
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return "# Hello!"
|
||||
|
@ -93,27 +93,15 @@ class Model3D(Component):
|
||||
value=value,
|
||||
)
|
||||
|
||||
def preprocess(self, x: dict[str, str] | None) -> str | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: JSON object with filename as 'name' property and base64 data as 'data' property
|
||||
Returns:
|
||||
string file path to temporary file with the 3D image model
|
||||
"""
|
||||
if x is None:
|
||||
return x
|
||||
return x["path"]
|
||||
def preprocess(self, payload: FileData | None) -> str | None:
|
||||
if payload is None:
|
||||
return payload
|
||||
return payload.path
|
||||
|
||||
def postprocess(self, y: str | Path | None) -> FileData | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: path to the model
|
||||
Returns:
|
||||
file name mapped to base64 url data
|
||||
"""
|
||||
if y is None:
|
||||
return y
|
||||
return FileData(path=str(y))
|
||||
def postprocess(self, value: str | Path | None) -> FileData | None:
|
||||
if value is None:
|
||||
return value
|
||||
return FileData(path=str(value))
|
||||
|
||||
def as_example(self, input_data: str | None) -> str:
|
||||
return Path(input_data).name if input_data else ""
|
||||
|
@ -108,33 +108,21 @@ class Number(FormComponent):
|
||||
else:
|
||||
return round(num, precision)
|
||||
|
||||
def preprocess(self, x: float | None) -> float | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: numeric input
|
||||
Returns:
|
||||
number representing function input
|
||||
"""
|
||||
if x is None:
|
||||
def preprocess(self, payload: float | None) -> float | None:
|
||||
if payload is None:
|
||||
return None
|
||||
elif self.minimum is not None and x < self.minimum:
|
||||
raise Error(f"Value {x} is less than minimum value {self.minimum}.")
|
||||
elif self.maximum is not None and x > self.maximum:
|
||||
raise Error(f"Value {x} is greater than maximum value {self.maximum}.")
|
||||
return self._round_to_precision(x, self.precision)
|
||||
elif self.minimum is not None and payload < self.minimum:
|
||||
raise Error(f"Value {payload} is less than minimum value {self.minimum}.")
|
||||
elif self.maximum is not None and payload > self.maximum:
|
||||
raise Error(
|
||||
f"Value {payload} is greater than maximum value {self.maximum}."
|
||||
)
|
||||
return self._round_to_precision(payload, self.precision)
|
||||
|
||||
def postprocess(self, y: float | None) -> float | None:
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
|
||||
Parameters:
|
||||
y: numeric output
|
||||
Returns:
|
||||
number representing function output
|
||||
"""
|
||||
if y is None:
|
||||
def postprocess(self, value: float | None) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
return self._round_to_precision(y, self.precision)
|
||||
return self._round_to_precision(value, self.precision)
|
||||
|
||||
def api_info(self) -> dict[str, str]:
|
||||
return {"type": "number"}
|
||||
|
@ -97,36 +97,30 @@ class Plot(Component):
|
||||
config["bokeh_version"] = bokeh_version
|
||||
return config
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: PlotData | None) -> PlotData | None:
|
||||
return payload
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return None
|
||||
|
||||
def postprocess(self, y) -> PlotData | None:
|
||||
"""
|
||||
Parameters:
|
||||
y: plot data
|
||||
Returns:
|
||||
plot type mapped to plot base64 data
|
||||
"""
|
||||
def postprocess(self, value) -> PlotData | None:
|
||||
import matplotlib.figure
|
||||
|
||||
if y is None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(y, (ModuleType, matplotlib.figure.Figure)): # type: ignore
|
||||
if isinstance(value, (ModuleType, matplotlib.figure.Figure)): # type: ignore
|
||||
dtype = "matplotlib"
|
||||
out_y = processing_utils.encode_plot_to_base64(y)
|
||||
elif "bokeh" in y.__module__:
|
||||
out_y = processing_utils.encode_plot_to_base64(value)
|
||||
elif "bokeh" in value.__module__:
|
||||
dtype = "bokeh"
|
||||
from bokeh.embed import json_item # type: ignore
|
||||
|
||||
out_y = json.dumps(json_item(y))
|
||||
out_y = json.dumps(json_item(value))
|
||||
else:
|
||||
is_altair = "altair" in y.__module__
|
||||
is_altair = "altair" in value.__module__
|
||||
dtype = "altair" if is_altair else "plotly"
|
||||
out_y = y.to_json()
|
||||
return PlotData(**{"type": dtype, "plot": out_y})
|
||||
out_y = value.to_json()
|
||||
return PlotData(type=dtype, plot=out_y)
|
||||
|
||||
|
||||
class AltairPlot:
|
||||
|
@ -94,28 +94,30 @@ class Radio(FormComponent):
|
||||
def example_inputs(self) -> Any:
|
||||
return self.choices[0][1] if self.choices else None
|
||||
|
||||
def preprocess(self, x: str | int | float | None) -> str | int | float | None:
|
||||
def preprocess(self, payload: str | int | float | None) -> str | int | float | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: selected choice
|
||||
payload: selected choice
|
||||
Returns:
|
||||
value of the selected choice as string or index within choice list
|
||||
"""
|
||||
if self.type == "value":
|
||||
return x
|
||||
return payload
|
||||
elif self.type == "index":
|
||||
if x is None:
|
||||
if payload is None:
|
||||
return None
|
||||
else:
|
||||
choice_values = [value for _, value in self.choices]
|
||||
return choice_values.index(x) if x in choice_values else None
|
||||
return (
|
||||
choice_values.index(payload) if payload in choice_values else None
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown type: {self.type}. Please choose from: 'value', 'index'."
|
||||
)
|
||||
|
||||
def postprocess(self, y):
|
||||
return y
|
||||
def postprocess(self, value: str | int | float | None) -> str | int | float | None:
|
||||
return value
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
@ -310,15 +310,15 @@ class ScatterPlot(Plot):
|
||||
return chart
|
||||
|
||||
def postprocess(
|
||||
self, y: pd.DataFrame | dict | None
|
||||
self, value: pd.DataFrame | dict | None
|
||||
) -> AltairPlotData | dict | None:
|
||||
# if None or update
|
||||
if y is None or isinstance(y, dict):
|
||||
return y
|
||||
if value is None or isinstance(value, dict):
|
||||
return value
|
||||
if self.x is None or self.y is None:
|
||||
raise ValueError("No value provided for required parameters `x` and `y`.")
|
||||
chart = self.create_plot(
|
||||
value=y,
|
||||
value=value,
|
||||
x=self.x,
|
||||
y=self.y,
|
||||
color=self.color,
|
||||
@ -343,12 +343,10 @@ class ScatterPlot(Plot):
|
||||
y_lim=self.y_lim,
|
||||
)
|
||||
|
||||
return AltairPlotData(
|
||||
**{"type": "altair", "plot": chart.to_json(), "chart": "scatter"}
|
||||
)
|
||||
return AltairPlotData(type="altair", plot=chart.to_json(), chart="scatter")
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return None
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: AltairPlotData | None) -> AltairPlotData | None:
|
||||
return payload
|
||||
|
@ -114,15 +114,8 @@ class Slider(FormComponent):
|
||||
value = round(value, n_decimals)
|
||||
return value
|
||||
|
||||
def postprocess(self, y: float | None) -> float | None:
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
Parameters:
|
||||
y: numeric output
|
||||
Returns:
|
||||
numeric output or minimum number if None
|
||||
"""
|
||||
return self.minimum if y is None else y
|
||||
def postprocess(self, value: float | None) -> float:
|
||||
return self.minimum if value is None else value
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: float) -> float:
|
||||
return payload
|
||||
|
@ -46,11 +46,11 @@ class State(Component):
|
||||
) from err
|
||||
super().__init__(value=self.value)
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
return x
|
||||
def preprocess(self, payload: Any) -> Any:
|
||||
return payload
|
||||
|
||||
def postprocess(self, y):
|
||||
return y
|
||||
def postprocess(self, value: Any) -> Any:
|
||||
return value
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
return {"type": {}, "description": "any valid json"}
|
||||
|
@ -117,25 +117,11 @@ class Textbox(FormComponent):
|
||||
self.rtl = rtl
|
||||
self.text_align = text_align
|
||||
|
||||
def preprocess(self, x: str | None) -> str | None:
|
||||
"""
|
||||
Preprocesses input (converts it to a string) before passing it to the function.
|
||||
Parameters:
|
||||
x: text
|
||||
Returns:
|
||||
text
|
||||
"""
|
||||
return None if x is None else str(x)
|
||||
def preprocess(self, payload: str | None) -> str | None:
|
||||
return None if payload is None else str(payload)
|
||||
|
||||
def postprocess(self, y: str | None) -> str | None:
|
||||
"""
|
||||
Postproccess the function output y by converting it to a str before passing it to the frontend.
|
||||
Parameters:
|
||||
y: function output to postprocess.
|
||||
Returns:
|
||||
text
|
||||
"""
|
||||
return None if y is None else str(y)
|
||||
def postprocess(self, value: str | None) -> str | None:
|
||||
return None if value is None else str(value)
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
return {"type": "string"}
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Literal
|
||||
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
@ -19,6 +20,12 @@ set_documentation_group("component")
|
||||
class ListFiles(GradioRootModel):
|
||||
root: List[FileData]
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.root[index]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.root)
|
||||
|
||||
|
||||
@document()
|
||||
class UploadButton(Component):
|
||||
@ -87,6 +94,10 @@ class UploadButton(Component):
|
||||
raise ValueError(
|
||||
f"Parameter file_types must be a list. Received {file_types.__class__.__name__}"
|
||||
)
|
||||
if self.file_count == "multiple":
|
||||
self.data_model = ListFiles
|
||||
else:
|
||||
self.data_model = FileData
|
||||
self.size = size
|
||||
self.file_types = file_types
|
||||
self.label = label
|
||||
@ -118,12 +129,12 @@ class UploadButton(Component):
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
]
|
||||
|
||||
def _process_single_file(self, f: dict[str, Any]) -> bytes | NamedString:
|
||||
file_name = f["path"]
|
||||
def _process_single_file(self, f: FileData) -> bytes | NamedString:
|
||||
file_name = f.path
|
||||
if self.type == "filepath":
|
||||
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
|
||||
file.name = file_name
|
||||
return NamedString(file.name)
|
||||
return NamedString(file_name)
|
||||
elif self.type == "binary":
|
||||
with open(file_name, "rb") as file_data:
|
||||
return file_data.read()
|
||||
@ -135,30 +146,42 @@ class UploadButton(Component):
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self, x: list[dict[str, Any]] | dict[str, Any] | None
|
||||
self, payload: ListFiles | FileData | None
|
||||
) -> bytes | NamedString | list[bytes | NamedString] | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
|
||||
Returns:
|
||||
File objects in requested format
|
||||
"""
|
||||
if x is None:
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
if self.file_count == "single":
|
||||
if isinstance(x, list):
|
||||
return self._process_single_file(x[0])
|
||||
if isinstance(payload, ListFiles):
|
||||
return self._process_single_file(payload[0])
|
||||
else:
|
||||
return self._process_single_file(x)
|
||||
return self._process_single_file(payload)
|
||||
else:
|
||||
if isinstance(x, list):
|
||||
return [self._process_single_file(f) for f in x]
|
||||
if isinstance(payload, ListFiles):
|
||||
return [self._process_single_file(f) for f in payload]
|
||||
else:
|
||||
return [self._process_single_file(x)]
|
||||
return [self._process_single_file(payload)]
|
||||
|
||||
def postprocess(self, y):
|
||||
return super().postprocess(y)
|
||||
def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
return ListFiles(
|
||||
root=[
|
||||
FileData(
|
||||
path=file,
|
||||
orig_name=Path(file).name,
|
||||
size=Path(file).stat().st_size,
|
||||
)
|
||||
for file in value
|
||||
]
|
||||
)
|
||||
else:
|
||||
return FileData(
|
||||
path=value,
|
||||
orig_name=Path(value).name,
|
||||
size=Path(value).stat().st_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def skip_api(self):
|
||||
|
@ -43,7 +43,7 @@ class Video(Component):
|
||||
"""
|
||||
|
||||
data_model = VideoData
|
||||
input_data_model = FileData
|
||||
|
||||
EVENTS = [
|
||||
Events.change,
|
||||
Events.clear,
|
||||
@ -155,18 +155,11 @@ class Video(Component):
|
||||
value=value,
|
||||
)
|
||||
|
||||
def preprocess(self, x: dict | VideoData) -> str | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: A tuple of (video file data, subtitle file data) or just video file data.
|
||||
Returns:
|
||||
A string file path or URL to the preprocessed video. Subtitle file data is ignored.
|
||||
"""
|
||||
if x is None:
|
||||
def preprocess(self, payload: VideoData | None) -> str | None:
|
||||
if payload is None:
|
||||
return None
|
||||
data: VideoData = VideoData(**x) if isinstance(x, dict) else x
|
||||
assert data.video.path
|
||||
file_name = Path(data.video.path)
|
||||
assert payload.video.path
|
||||
file_name = Path(payload.video.path)
|
||||
uploaded_format = file_name.suffix.replace(".", "")
|
||||
needs_formatting = self.format is not None and uploaded_format != self.format
|
||||
flip = self.sources == ["webcam"] and self.mirror_webcam
|
||||
@ -221,24 +214,6 @@ class Video(Component):
|
||||
def postprocess(
|
||||
self, y: str | Path | tuple[str | Path, str | Path | None] | None
|
||||
) -> VideoData | None:
|
||||
"""
|
||||
Processes a video to ensure that it is in the correct format before returning it to the front end.
|
||||
Parameters:
|
||||
y: video data in either of the following formats: a tuple of (video filepath, optional subtitle filepath), or just a filepath or URL to an video file, or None.
|
||||
Returns:
|
||||
a tuple with the two dictionary, reresent to video and (optional) subtitle, which following formats:
|
||||
- The first dictionary represents the video file and contains the following keys:
|
||||
- 'name': a file path to a temporary copy of the processed video.
|
||||
- 'data': None
|
||||
- 'is_file': True
|
||||
- The second dictionary represents the subtitle file and contains the following keys:
|
||||
- 'name': None
|
||||
- 'data': Base64 encode the processed subtitle data.
|
||||
- 'is_file': False
|
||||
- If subtitle is None, returns (video, None).
|
||||
- If both video and subtitle are None, returns None.
|
||||
"""
|
||||
|
||||
if y is None or y == [None, None] or y == (None, None):
|
||||
return None
|
||||
if isinstance(y, (str, Path)):
|
||||
@ -269,14 +244,6 @@ class Video(Component):
|
||||
def _format_video(self, video: str | Path | None) -> FileData | None:
|
||||
"""
|
||||
Processes a video to ensure that it is in the correct format.
|
||||
Parameters:
|
||||
video: video data in either of the following formats: a string filepath or URL to an video file, or None.
|
||||
Returns:
|
||||
a dictionary with the following keys:
|
||||
|
||||
- 'name': a file path to a temporary copy of the processed video.
|
||||
- 'data': None
|
||||
- 'is_file': True
|
||||
"""
|
||||
if video is None:
|
||||
return None
|
||||
@ -328,13 +295,6 @@ class Video(Component):
|
||||
def _format_subtitle(self, subtitle: str | Path | None) -> FileData | None:
|
||||
"""
|
||||
Convert subtitle format to VTT and process the video to ensure it meets the HTML5 requirements.
|
||||
Parameters:
|
||||
subtitle: subtitle path in either of the VTT and SRT format.
|
||||
Returns:
|
||||
a dictionary with the following keys:
|
||||
- 'name': None
|
||||
- 'data': base64-encoded subtitle data.
|
||||
- 'is_file': False
|
||||
"""
|
||||
|
||||
def srt_to_vtt(srt_file_path, vtt_file_path):
|
||||
|
@ -173,8 +173,13 @@ class Queue:
|
||||
if concurrency_limit is None or existing_worker_count < concurrency_limit:
|
||||
batch = block_fn.batch
|
||||
if batch:
|
||||
remaining_worker_count = concurrency_limit - existing_worker_count
|
||||
batch_size = block_fn.max_batch_size
|
||||
if concurrency_limit is None:
|
||||
remaining_worker_count = batch_size - 1
|
||||
else:
|
||||
remaining_worker_count = (
|
||||
concurrency_limit - existing_worker_count
|
||||
)
|
||||
rest_of_batch = [
|
||||
event
|
||||
for event in self.event_queue[index:]
|
||||
|
@ -52,9 +52,9 @@
|
||||
all_file_data = (await upload(all_file_data, root))?.filter(
|
||||
(x) => x !== null
|
||||
) as FileData[];
|
||||
dispatch("change", all_file_data);
|
||||
dispatch("upload", all_file_data);
|
||||
value = all_file_data;
|
||||
value = file_count === "single" ? all_file_data?.[0] : all_file_data
|
||||
dispatch("change", value);
|
||||
dispatch("upload", value);
|
||||
}
|
||||
|
||||
async function loadFilesFromUpload(e: Event): Promise<void> {
|
||||
|
@ -30,6 +30,8 @@ except ImportError:
|
||||
|
||||
import gradio as gr
|
||||
from gradio import processing_utils, utils
|
||||
from gradio.components.dataframe import DataframeData
|
||||
from gradio.components.video import VideoData
|
||||
from gradio.data_classes import FileData
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
@ -480,9 +482,9 @@ class TestDropdown:
|
||||
"""
|
||||
dropdown_input = gr.Dropdown(["a", "b", ("c", "c full")], multiselect=True)
|
||||
assert dropdown_input.preprocess("a") == "a"
|
||||
assert dropdown_input.postprocess("a") == "a"
|
||||
assert dropdown_input.postprocess("a") == ["a"]
|
||||
assert dropdown_input.preprocess("c full") == "c full"
|
||||
assert dropdown_input.postprocess("c full") == "c full"
|
||||
assert dropdown_input.postprocess("c full") == ["c full"]
|
||||
|
||||
# When a Gradio app is loaded with gr.load, the tuples are converted to lists,
|
||||
# so we need to test that case as well
|
||||
@ -558,7 +560,7 @@ class TestImage:
|
||||
type: pil, file, filepath, numpy
|
||||
"""
|
||||
|
||||
img = dict(FileData(path="test/test_files/bus.png"))
|
||||
img = FileData(path="test/test_files/bus.png")
|
||||
image_input = gr.Image()
|
||||
|
||||
image_input = gr.Image(type="filepath")
|
||||
@ -605,12 +607,12 @@ class TestImage:
|
||||
# Output functionalities
|
||||
image_output = gr.Image(type="pil")
|
||||
processed_image = image_output.postprocess(
|
||||
PIL.Image.open(img["path"])
|
||||
PIL.Image.open(img.path)
|
||||
).model_dump()
|
||||
assert processed_image is not None
|
||||
if processed_image is not None:
|
||||
processed = PIL.Image.open(cast(dict, processed_image).get("path", ""))
|
||||
source = PIL.Image.open(img["path"])
|
||||
source = PIL.Image.open(img.path)
|
||||
assert processed.size == source.size
|
||||
|
||||
def test_in_interface_as_output(self):
|
||||
@ -696,7 +698,7 @@ class TestAudio:
|
||||
Preprocess, postprocess serialize, get_config, deserialize
|
||||
type: filepath, numpy, file
|
||||
"""
|
||||
x_wav = deepcopy(media_data.BASE64_AUDIO)
|
||||
x_wav = FileData(path=media_data.BASE64_AUDIO["path"])
|
||||
audio_input = gr.Audio()
|
||||
output1 = audio_input.preprocess(x_wav)
|
||||
assert output1[0] == 8000
|
||||
@ -735,11 +737,6 @@ class TestAudio:
|
||||
"_selectable": False,
|
||||
}
|
||||
assert audio_input.preprocess(None) is None
|
||||
x_wav["is_example"] = True
|
||||
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
|
||||
output2 = audio_input.preprocess(x_wav)
|
||||
assert output2 is not None
|
||||
assert output1 != output2
|
||||
|
||||
audio_input = gr.Audio(type="filepath")
|
||||
assert isinstance(audio_input.preprocess(x_wav), str)
|
||||
@ -821,21 +818,21 @@ class TestAudio:
|
||||
assert iface(100).endswith(".wav")
|
||||
|
||||
def test_audio_preprocess_can_be_read_by_scipy(self, gradio_temp_dir):
|
||||
x_wav = {
|
||||
"path": processing_utils.save_base64_to_cache(
|
||||
x_wav = FileData(
|
||||
path=processing_utils.save_base64_to_cache(
|
||||
media_data.BASE64_MICROPHONE["data"], cache_dir=gradio_temp_dir
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
audio_input = gr.Audio(type="filepath")
|
||||
output = audio_input.preprocess(x_wav)
|
||||
wavfile.read(output)
|
||||
|
||||
def test_prepost_process_to_mp3(self, gradio_temp_dir):
|
||||
x_wav = {
|
||||
"path": processing_utils.save_base64_to_cache(
|
||||
x_wav = FileData(
|
||||
path=processing_utils.save_base64_to_cache(
|
||||
media_data.BASE64_MICROPHONE["data"], cache_dir=gradio_temp_dir
|
||||
),
|
||||
}
|
||||
)
|
||||
)
|
||||
audio_input = gr.Audio(type="filepath", format="mp3")
|
||||
output = audio_input.preprocess(x_wav)
|
||||
assert output.endswith("mp3")
|
||||
@ -850,13 +847,13 @@ class TestFile:
|
||||
"""
|
||||
Preprocess, serialize, get_config, value
|
||||
"""
|
||||
x_file = deepcopy(media_data.BASE64_FILE)
|
||||
x_file = FileData(path=media_data.BASE64_FILE["path"])
|
||||
file_input = gr.File()
|
||||
output = file_input.preprocess({"path": x_file["path"]})
|
||||
output = file_input.preprocess(x_file)
|
||||
assert isinstance(output, str)
|
||||
|
||||
input1 = file_input.preprocess({"path": x_file["path"]})
|
||||
input2 = file_input.preprocess({"path": x_file["path"]})
|
||||
input1 = file_input.preprocess(x_file)
|
||||
input2 = file_input.preprocess(x_file)
|
||||
assert input1 == input1.name # Testing backwards compatibility
|
||||
assert input1 == input2
|
||||
assert Path(input1).name == "sample_file.pdf"
|
||||
@ -884,7 +881,7 @@ class TestFile:
|
||||
assert file_input.preprocess(None) is None
|
||||
assert file_input.preprocess(x_file) is not None
|
||||
|
||||
zero_size_file = {"path": "document.txt", "size": 0}
|
||||
zero_size_file = FileData(path="document.txt", size=0)
|
||||
temp_file = file_input.preprocess(zero_size_file)
|
||||
assert not Path(temp_file.name).exists()
|
||||
|
||||
@ -933,13 +930,13 @@ class TestUploadButton:
|
||||
"""
|
||||
preprocess
|
||||
"""
|
||||
x_file = deepcopy(media_data.BASE64_FILE)
|
||||
x_file = FileData(path=media_data.BASE64_FILE["path"])
|
||||
upload_input = gr.UploadButton()
|
||||
input = upload_input.preprocess({"path": x_file})
|
||||
input = upload_input.preprocess(x_file)
|
||||
assert isinstance(input, str)
|
||||
|
||||
input1 = upload_input.preprocess({"path": x_file})
|
||||
input2 = upload_input.preprocess({"path": x_file})
|
||||
input1 = upload_input.preprocess(x_file)
|
||||
input2 = upload_input.preprocess(x_file)
|
||||
assert input1 == input1.name # Testing backwards compatibility
|
||||
assert input1 == input2
|
||||
|
||||
@ -961,7 +958,7 @@ class TestDataframe:
|
||||
"metadata": None,
|
||||
}
|
||||
dataframe_input = gr.Dataframe(headers=["Name", "Age", "Member"])
|
||||
output = dataframe_input.preprocess(x_data)
|
||||
output = dataframe_input.preprocess(DataframeData(**x_data))
|
||||
assert output["Age"][1] == 24
|
||||
assert not output["Member"][0]
|
||||
assert dataframe_input.postprocess(x_data) == x_data
|
||||
@ -998,7 +995,7 @@ class TestDataframe:
|
||||
"column_widths": [],
|
||||
}
|
||||
dataframe_input = gr.Dataframe()
|
||||
output = dataframe_input.preprocess(x_data)
|
||||
output = dataframe_input.preprocess(DataframeData(**x_data))
|
||||
assert output["Age"][1] == 24
|
||||
with pytest.raises(ValueError):
|
||||
gr.Dataframe(type="unknown")
|
||||
@ -1315,7 +1312,9 @@ class TestVideo:
|
||||
"""
|
||||
Preprocess, serialize, deserialize, get_config
|
||||
"""
|
||||
x_video = {"video": {"path": deepcopy(media_data.BASE64_VIDEO)["path"]}}
|
||||
x_video = VideoData(
|
||||
video=FileData(path=deepcopy(media_data.BASE64_VIDEO)["path"])
|
||||
)
|
||||
video_input = gr.Video()
|
||||
|
||||
x_video = processing_utils.move_files_to_cache([x_video], video_input)[0]
|
||||
@ -1357,8 +1356,6 @@ class TestVideo:
|
||||
"_selectable": False,
|
||||
}
|
||||
assert video_input.preprocess(None) is None
|
||||
x_video["is_example"] = True
|
||||
assert video_input.preprocess(x_video) is not None
|
||||
video_input = gr.Video(format="avi")
|
||||
output_video = video_input.preprocess(x_video)
|
||||
assert output_video[-3:] == "avi"
|
||||
@ -1468,7 +1465,7 @@ class TestVideo:
|
||||
@patch("gradio.components.video.FFmpeg")
|
||||
def test_video_preprocessing_flips_video_for_webcam(self, mock_ffmpeg):
|
||||
# Ensures that the cached temp video file is not used so that ffmpeg is called for each test
|
||||
x_video = {"video": deepcopy(media_data.BASE64_VIDEO)}
|
||||
x_video = VideoData(video=FileData(path=media_data.BASE64_VIDEO["path"]))
|
||||
video_input = gr.Video(sources=["webcam"])
|
||||
_ = video_input.preprocess(x_video)
|
||||
|
||||
|
Binary file not shown.
@ -288,18 +288,19 @@ class TestThemeUploadDownload:
|
||||
)
|
||||
assert next_version == "3.20.2"
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_theme_download(self):
|
||||
assert (
|
||||
gr.themes.Base.from_hub("gradio/dracula_test@0.0.2").to_dict()
|
||||
== dracula.to_dict()
|
||||
)
|
||||
## Commenting out until after 4.0 Spaces are up
|
||||
# @pytest.mark.flaky
|
||||
# def test_theme_download(self):
|
||||
# assert (
|
||||
# gr.themes.Base.from_hub("gradio/dracula_test@0.0.2").to_dict()
|
||||
# == dracula.to_dict()
|
||||
# )
|
||||
|
||||
with gr.Blocks(theme="gradio/dracula_test@0.0.2") as demo:
|
||||
pass
|
||||
# with gr.Blocks(theme="gradio/dracula_test@0.0.2") as demo:
|
||||
# pass
|
||||
|
||||
assert demo.theme.to_dict() == dracula.to_dict()
|
||||
assert demo.theme.name == "gradio/dracula_test"
|
||||
# assert demo.theme.to_dict() == dracula.to_dict()
|
||||
# assert demo.theme.name == "gradio/dracula_test"
|
||||
|
||||
def test_theme_download_raises_error_if_theme_does_not_exist(self):
|
||||
with pytest.raises(
|
||||
|
Loading…
x
Reference in New Issue
Block a user