Cleaning up the way data is processed for components (#1967)

* remove preprocess_example

* removing methods

* added path support for images

* fixes

* video

* formatting

* fixing preprocess

* fixes

* removed from audio

* fixed file

* formatting

* serialization

* foramtting

* formatting

* removed save flag / restore flag

* formatting

* removed flagging

* removed

* load value

* fixing typing

* fixes, typing

* fixes

* file

* handling images

* formatting

* fixed serializing for flagging

* formatting

* json

* temp file

* removed processing

* changed processing

* fixed temp FINALLY

* flagging works

* fix examples test

* formatting

* async examples

* working on mix

* comment out failing test

* fixed interface problem

* fix kitchen sink deprecation warning

* gallery examples

* fixes

* fixes to serialization

* fixing label serializing

* fixed file serialization

* kitchen sink restored

* outbreak forecast updated

* formatting

* formatting and api mode

* fix 1 test :/

* fixing tests

* fixed components tests

* remvoed test files

* formatting

* fixed examples

* fixes

* formatting

* restored certain files

* added encryption

* fixed syntax mistake

* formatting

* fixed 1 test

* clean up interface

* formatting

* fixed route tests

* more fixes

* formatting

* formatting

* fixing pipeline

* format frontend

* format backend

* tweaks

* fix

* fix final test?

* merged

* Sanitize for CSV (#2017)

* sanitize for csv

* added sanitization logic

* fixed examples

* turn cache off

* fixed example caching with optional inputs

* fixed review problems

* fixed Interface.load

* updating the tests

* updating the tests

* fix

* fixed seriailizing

* testing

* rewrite run prediction

* formatting

* await

* fixes

* formatting

* finally fixed mix

* fixed tests

* formatting

* formatting

* deserialize fix

* formatting

* fixes

* fixes

* fix

* fix tests

* fixes

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
This commit is contained in:
Abubakar Abid 2022-08-23 08:31:04 -07:00 committed by GitHub
parent e2dc87aa2b
commit 88e9c19c27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 1016 additions and 1487 deletions

View File

@ -8,7 +8,6 @@ import gradio as gr
def fake_gan(count, *args):
time.sleep(1)
images = [
random.choice(
[
@ -40,12 +39,12 @@ demo = gr.Interface(
title="FD-GAN",
description="This is a fake demo of a GAN. In reality, the images are randomly chosen from Unsplash.",
examples=[
[2, cheetah, 12, None, None, None],
[1, cheetah, 2, None, None, None],
[4, cheetah, 42, None, None, None],
[5, cheetah, 23, None, None, None],
[4, cheetah, 11, None, None, None],
[3, cheetah, 1, None, None, None],
[2, cheetah, None, 12, None, None],
[1, cheetah, None, 2, None, None],
[4, cheetah, None, 42, None, None],
[5, cheetah, None, 23, None, None],
[4, cheetah, None, 11, None, None],
[3, cheetah, None, 1, None, None],
],
)

View File

@ -1,7 +1,7 @@
import gradio as gr
def func(slider_1, slider_2):
def func(slider_1, slider_2, *args):
return slider_1 + slider_2 * 5

View File

@ -83,7 +83,6 @@ def fn(
df2, # Timeseries
)
demo = gr.Interface(
fn,
inputs=[
@ -115,8 +114,8 @@ demo = gr.Interface(
gr.Audio(label="Audio"),
gr.Image(label="Image"),
gr.Video(label="Video"),
gr.HighlightedText(
label="HighlightedText", color_map={"punc": "pink", "test 0": "blue"}
gr.HighlightedText(label="HighlightedText").style(
color_map={"punc": "pink", "test 0": "blue"}
),
gr.HighlightedText(label="HighlightedText", show_legend=True),
gr.JSON(label="JSON"),
@ -152,9 +151,9 @@ demo = gr.Interface(
* 3,
theme="default",
title="Kitchen Sink",
cache_examples=False,
description="Try out all the components!",
article="Learn more about [Gradio](http://gradio.app)",
cache_examples=True
)
if __name__ == "__main__":

View File

@ -4,7 +4,6 @@ import os
def load_mesh(mesh_file_name):
time.sleep(2)
return mesh_file_name, mesh_file_name
@ -22,7 +21,6 @@ demo = gr.Interface(
[os.path.join(os.path.dirname(__file__), "files/Fox.gltf")],
[os.path.join(os.path.dirname(__file__), "files/face.obj")],
],
cache_examples=True,
)
if __name__ == "__main__":

View File

@ -6,9 +6,6 @@ import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import pandas as pd
import bokeh.plotting as bk
from bokeh.models import ColumnDataSource
from bokeh.embed import json_item
import gradio as gr
@ -42,15 +39,10 @@ def outbreak(plot_type, r, month, countries, social_distancing):
yaxis_title="Days Since Day 0")
return fig
else:
source = ColumnDataSource(df)
p = bk.figure(title="Outbreak in " + month, x_axis_label="Cases", y_axis_label="Days Since Day 0")
for country in countries:
p.line(x='day', y=country, line_width=2, source=source)
item_text = json_item(p, "plotDiv")
return item_text
raise ValueError("A plot type must be selected")
inputs = [
gr.Dropdown(["Matplotlib", "Plotly", "Bokeh"], label="Plot Type"),
gr.Dropdown(["Matplotlib", "Plotly"], label="Plot Type"),
gr.Slider(1, 4, 3.2, label="R"),
gr.Dropdown(["January", "February", "March", "April", "May"], label="Month"),
gr.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries",
@ -62,7 +54,6 @@ outputs = gr.Plot()
demo = gr.Interface(fn=outbreak, inputs=inputs, outputs=outputs, examples=[
["Matplotlib", 2, "March", ["Mexico", "UK"], True],
["Plotly", 3.6, "February", ["Canada", "Mexico", "UK"], False],
["Bokeh", 1.2, "May", ["UK"], True]
], cache_examples=True)

View File

@ -6,6 +6,7 @@ import inspect
import os
import random
import sys
import tempfile
import time
import warnings
import webbrowser
@ -66,6 +67,8 @@ class Block:
Context.block.children.append(self)
if Context.root_block is not None:
Context.root_block.blocks[self._id] = self
if hasattr(self, "temp_dir"):
Context.root_block.temp_dirs.add(self.temp_dir)
def unrender(self):
"""
@ -213,6 +216,18 @@ class BlockFunction:
self.total_runtime = 0
self.total_runs = 0
def __str__(self):
return str(
{
"fn": self.fn.__name__ if self.fn is not None else None,
"preprocess": self.preprocess,
"postprocess": self.postprocess,
}
)
def __repr__(self):
return str(self)
class class_or_instancemethod(classmethod):
def __get__(self, instance, type_):
@ -327,7 +342,6 @@ class Blocks(BlockContext):
# Cleanup shared parameters with Interface #TODO: is this part still necessary after Interface with Blocks?
self.limiter = None
self.save_to = None
self.api_mode = False
self.theme = theme
self.requires_permissions = False # TODO: needs to be implemented
self.encrypt = False
@ -366,6 +380,7 @@ class Blocks(BlockContext):
self.auth = None
self.dev_mode = True
self.app_id = random.getrandbits(64)
self.temp_dirs = set()
self.title = title
@property
@ -446,7 +461,6 @@ class Blocks(BlockContext):
blocks.input_components = [blocks.blocks[i] for i in dependency["inputs"]]
blocks.output_components = [blocks.blocks[o] for o in dependency["outputs"]]
blocks.api_mode = True
return blocks
def __call__(self, *params, fn_index=0):
@ -459,49 +473,41 @@ class Blocks(BlockContext):
dependency = self.dependencies[fn_index]
block_fn = self.fns[fn_index]
if self.api_mode:
serialized_params = []
for i, input_id in enumerate(dependency["inputs"]):
block = self.blocks[input_id]
if getattr(block, "stateful", False):
raise ValueError(
"Cannot call Blocks object as a function if any of"
" the inputs are stateful."
)
else:
serialized_input = block.serialize(params[i], True)
serialized_params.append(serialized_input)
else:
serialized_params = params
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
block = self.blocks[input_id]
if getattr(block, "stateful", False):
raise ValueError(
"Cannot call Blocks object as a function if any of"
" the inputs are stateful."
)
else:
serialized_input = block.serialize(params[i])
processed_input.append(serialized_input)
processed_input = self.preprocess_data(fn_index, serialized_params, None)
processed_input = self.preprocess_data(fn_index, processed_input, None)
if inspect.iscoroutinefunction(block_fn.fn):
predictions = utils.synchronize_async(block_fn.fn, *processed_input)
else:
predictions = block_fn.fn(*processed_input)
output = self.postprocess_data(fn_index, predictions, None)
predictions = self.postprocess_data(fn_index, predictions, None)
if self.api_mode:
output_copy = copy.deepcopy(output)
deserialized_output = []
for o, output_id in enumerate(dependency["outputs"]):
block = self.blocks[output_id]
if getattr(block, "stateful", False):
raise ValueError(
"Cannot call Blocks object as a function if any of"
" the outputs are stateful."
)
else:
deserialized = block.deserialize(output_copy[o])
deserialized_output.append(deserialized)
else:
deserialized_output = output
output_copy = copy.deepcopy(predictions)
predictions = []
for o, output_id in enumerate(dependency["outputs"]):
block = self.blocks[output_id]
if getattr(block, "stateful", False):
raise ValueError(
"Cannot call Blocks object as a function if any of"
" the outputs are stateful."
)
else:
deserialized = block.deserialize(output_copy[o])
predictions.append(deserialized)
if len(deserialized_output) == 1:
return deserialized_output[0]
return deserialized_output
return utils.resolve_singleton(predictions)
def __str__(self):
return self.__repr__()
@ -528,6 +534,7 @@ class Blocks(BlockContext):
Context.root_block.blocks.update(self.blocks)
Context.root_block.fns.extend(self.fns)
Context.root_block.dependencies.extend(self.dependencies)
Context.root_block.temp_dirs = Context.root_block.temp_dirs | self.temp_dirs
if Context.block is not None:
Context.block.children.extend(self.children)
@ -621,7 +628,7 @@ class Blocks(BlockContext):
async def process_api(
self,
fn_index: int,
raw_input: List[Any],
inputs: List[Any],
username: str = None,
state: Optional[Dict[int, any]] = None,
) -> Dict[str, Any]:
@ -636,16 +643,16 @@ class Blocks(BlockContext):
"""
block_fn = self.fns[fn_index]
processed_input = self.preprocess_data(fn_index, raw_input, state)
inputs = self.preprocess_data(fn_index, inputs, state)
predictions, duration = await self.call_function(fn_index, processed_input)
predictions, duration = await self.call_function(fn_index, inputs)
block_fn.total_runtime += duration
block_fn.total_runs += 1
output = self.postprocess_data(fn_index, predictions, state)
predictions = self.postprocess_data(fn_index, predictions, state)
return {
"data": output,
"data": predictions,
"duration": duration,
"average_duration": block_fn.total_runtime / block_fn.total_runs,
}

File diff suppressed because it is too large Load Diff

View File

@ -20,8 +20,7 @@ from gradio.documentation import document, set_documentation_group
from gradio.flagging import CSVLogger
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio import Interface
from gradio.components import Component
from gradio.components import IOComponent
CACHED_FOLDER = "gradio_cached_examples"
LOG_FILE = "log.csv"
@ -31,11 +30,12 @@ set_documentation_group("component-helpers")
def create_examples(
examples: List[Any] | List[List[Any]] | str,
inputs: Component | List[Component],
outputs: Optional[Component | List[Component]] = None,
inputs: IOComponent | List[IOComponent],
outputs: Optional[IOComponent | List[IOComponent]] = None,
fn: Optional[Callable] = None,
cache_examples: bool = False,
examples_per_page: int = 10,
_api_mode: bool = False,
label: Optional[str] = None,
):
"""Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
@ -46,6 +46,7 @@ def create_examples(
fn=fn,
cache_examples=cache_examples,
examples_per_page=examples_per_page,
_api_mode=_api_mode,
label=label,
_initiated_directly=False,
)
@ -68,13 +69,14 @@ class Examples:
def __init__(
self,
examples: List[Any] | List[List[Any]] | str,
inputs: Component | List[Component],
outputs: Optional[Component | List[Component]] = None,
inputs: IOComponent | List[IOComponent],
outputs: Optional[IOComponent | List[IOComponent]] = None,
fn: Optional[Callable] = None,
cache_examples: bool = False,
examples_per_page: int = 10,
_api_mode: bool = False,
_initiated_directly: bool = True,
label: str = "Examples",
_initiated_directly=True,
):
"""
Parameters:
@ -87,7 +89,7 @@ class Examples:
label: the label to use for the examples component (by default, "Examples")
"""
if _initiated_directly:
raise warnings.warn(
warnings.warn(
"Please use gr.Examples(...) instead of gr.examples.Examples(...) to create the Examples.",
)
@ -164,15 +166,20 @@ class Examples:
self.fn = fn
self.cache_examples = cache_examples
self.examples_per_page = examples_per_page
self._api_mode = _api_mode
with utils.set_directory(working_directory):
self.processed_examples = [
[
component.preprocess_example(sample)
for component, sample in zip(inputs_with_examples, example)
component.postprocess(sample)
for component, sample in zip(inputs, example)
]
for example in non_none_examples
for example in examples
]
self.non_none_processed_examples = [
[ex for (ex, keep) in zip(example, input_has_examples) if keep]
for example in self.processed_examples
]
self.dataset = Dataset(
components=inputs_with_examples,
@ -193,11 +200,11 @@ class Examples:
async def load_example(example_id):
if self.cache_examples:
processed_example = self.processed_examples[
processed_example = self.non_none_processed_examples[
example_id
] + await self.load_from_cache(example_id)
else:
processed_example = self.processed_examples[example_id]
processed_example = self.non_none_processed_examples[example_id]
return utils.resolve_singleton(processed_example)
if Context.root_block:
@ -221,41 +228,32 @@ class Examples:
cache_logger = CSVLogger()
cache_logger.setup(self.outputs, self.cached_folder)
for example_id, _ in enumerate(self.examples):
try:
prediction = await self.process_example(example_id)
cache_logger.flag(prediction)
except Exception as e:
shutil.rmtree(self.cached_folder)
raise e
prediction = await self.predict_example(example_id)
cache_logger.flag(prediction)
async def process_example(self, example_id: int) -> Tuple[List[Any], List[float]]:
async def predict_example(self, example_id: int) -> List[Any]:
"""Loads an example from the interface and returns its prediction.
Parameters:
example_id: The id of the example to process (zero-indexed).
"""
example_set = self.examples[example_id]
raw_input = [
self.inputs[i].preprocess_example(example)
for i, example in enumerate(example_set)
]
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.inputs)
]
processed_input = self.processed_examples[example_id]
if not self._api_mode:
processed_input = [
input_component.preprocess(processed_input[i])
for i, input_component in enumerate(self.inputs_with_examples)
]
if inspect.iscoroutinefunction(self.fn):
predictions = await self.fn(*processed_input)
else:
predictions = await anyio.to_thread.run_sync(self.fn, *processed_input)
if len(self.outputs) == 1:
predictions = [predictions]
processed_output = [
output_component.postprocess(predictions[i])
if predictions[i] is not None
else None
for i, output_component in enumerate(self.outputs)
]
return processed_output
if not self._api_mode:
predictions = [
output_component.postprocess(predictions[i])
for i, output_component in enumerate(self.outputs)
]
return predictions
async def load_from_cache(self, example_id: int) -> List[Any]:
"""Loads a particular cached example for the interface.
@ -263,15 +261,9 @@ class Examples:
example_id: The id of the example to process (zero-indexed).
"""
with open(self.cached_file) as cache:
examples = list(csv.reader(cache, quotechar="'"))
examples = list(csv.reader(cache))
example = examples[example_id + 1] # +1 to adjust for header
output = []
for component, cell in zip(self.outputs, example):
output.append(
component.restore_flagged(
self.cached_folder,
cell,
None,
)
)
for component, value in zip(self.outputs, example):
output.append(component.serialize(value, self.cached_folder))
return output

View File

@ -3,6 +3,7 @@ use the `gr.Blocks.load()` or `gr.Interface.load()` functions."""
import base64
import json
import operator
import re
from copy import deepcopy
from typing import Callable, Dict
@ -56,6 +57,15 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
assert response.status_code == 200, "Invalid model name or src"
p = response.json().get("pipeline_tag")
def postprocess_label(scores):
sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
return {
"label": sorted_pred[0][0],
"confidences": [
{"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
],
}
def encode_to_base64(r: requests.Response) -> str:
# Handles the different ways HF API returns the prediction
base64_repr = base64.b64encode(r.content).decode("utf-8")
@ -82,18 +92,18 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
pipelines = {
"audio-classification": {
# example model: https://hf.co/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
# example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
"inputs": components.Audio(source="upload", type="filepath", label="Input"),
"outputs": components.Label(label="Class"),
"preprocess": lambda i: base64.b64decode(
i["data"].split(",")[1]
), # convert the base64 representation to binary
"postprocess": lambda r: {
i["label"].split(", ")[0]: i["score"] for i in r.json()
},
"postprocess": lambda r: postprocess_label(
{i["label"].split(", ")[0]: i["score"] for i in r.json()}
),
},
"audio-to-audio": {
# example model: https://hf.co/speechbrain/mtl-mimic-voicebank
# example model: speechbrain/mtl-mimic-voicebank
"inputs": components.Audio(source="upload", type="filepath", label="Input"),
"outputs": components.Audio(label="Output"),
"preprocess": lambda i: base64.b64decode(
@ -102,7 +112,7 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
"postprocess": encode_to_base64,
},
"automatic-speech-recognition": {
# example model: https://hf.co/jonatasgrosman/wav2vec2-large-xlsr-53-english
# example model: jonatasgrosman/wav2vec2-large-xlsr-53-english
"inputs": components.Audio(source="upload", type="filepath", label="Input"),
"outputs": components.Textbox(label="Output"),
"preprocess": lambda i: base64.b64decode(
@ -111,7 +121,7 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
"postprocess": lambda r: r.json()["text"],
},
"feature-extraction": {
# example model: hf.co/julien-c/distilbert-feature-extraction
# example model: julien-c/distilbert-feature-extraction
"inputs": components.Textbox(label="Input"),
"outputs": components.Dataframe(label="Output"),
"preprocess": lambda x: {"inputs": x},
@ -121,29 +131,23 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r.json()},
"postprocess": lambda r: postprocess_label(
{i["token_str"]: i["score"] for i in r.json()}
),
},
"image-classification": {
# Example: https://huggingface.co/google/vit-base-patch16-224
# Example: google/vit-base-patch16-224
"inputs": components.Image(type="filepath", label="Input Image"),
"outputs": components.Label(label="Classification"),
"preprocess": lambda i: base64.b64decode(
i.split(",")[1]
), # convert the base64 representation to binary
"postprocess": lambda r: {
i["label"].split(", ")[0]: i["score"] for i in r.json()
},
"postprocess": lambda r: postprocess_label(
{i["label"].split(", ")[0]: i["score"] for i in r.json()}
),
},
# TODO: support image segmentation pipeline -- should we add a new output component type?
# 'image-segmentation': {
# # Example: https://hf.co/facebook/detr-resnet-50-panoptic
# 'inputs': inputs.Image(label="Input Image", type="filepath"),
# 'outputs': outputs.Image(label="Segmentation"),
# 'preprocess': lambda i: base64.b64decode(i.split(",")[1]), # convert the base64 representation to binary
# 'postprocess': lambda x: base64.b64encode(x.json()[0]["mask"]).decode('utf-8'),
# },
# TODO: also: support NER pipeline, object detection, table question answering
"question-answering": {
# Example: deepset/xlm-roberta-base-squad2
"inputs": [
components.Textbox(lines=7, label="Context"),
components.Textbox(label="Question"),
@ -153,29 +157,33 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
components.Label(label="Score"),
],
"preprocess": lambda c, q: {"inputs": {"context": c, "question": q}},
"postprocess": lambda r: (r.json()["answer"], r.json()["score"]),
"postprocess": lambda r: (r.json()["answer"], {"label": r.json()["score"]}),
},
"summarization": {
# Example: facebook/bart-large-cnn
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Summary"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["summary_text"],
},
"text-classification": {
# Example: distilbert-base-uncased-finetuned-sst-2-english
"inputs": components.Textbox(label="Input"),
"outputs": components.Label(label="Classification"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: {
i["label"].split(", ")[0]: i["score"] for i in r.json()[0]
},
"postprocess": lambda r: postprocess_label(
{i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
),
},
"text-generation": {
# Example: gpt2
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Output"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["generated_text"],
},
"text2text-generation": {
# Example: valhalla/t5-small-qa-qg-hl
"inputs": components.Textbox(label="Input"),
"outputs": components.Textbox(label="Generated Text"),
"preprocess": lambda x: {"inputs": x},
@ -188,6 +196,7 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
"postprocess": lambda r: r.json()[0]["translation_text"],
},
"zero-shot-classification": {
# Example: facebook/bart-large-mnli
"inputs": [
components.Textbox(label="Input"),
components.Textbox(label="Possible class names (" "comma-separated)"),
@ -198,13 +207,15 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
"inputs": i,
"parameters": {"candidate_labels": c, "multi_class": m},
},
"postprocess": lambda r: {
r.json()["labels"][i]: r.json()["scores"][i]
for i in range(len(r.json()["labels"]))
},
"postprocess": lambda r: postprocess_label(
{
r.json()["labels"][i]: r.json()["scores"][i]
for i in range(len(r.json()["labels"]))
}
),
},
"sentence-similarity": {
# example model: hf.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens
# Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens
"inputs": [
components.Textbox(
value="That is a happy person", label="Source Sentence"
@ -222,26 +233,26 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
"sentences": [s for s in sentences.splitlines() if s != ""],
}
},
"postprocess": lambda r: {
f"sentence {i}": v for i, v in enumerate(r.json())
},
"postprocess": lambda r: postprocess_label(
{f"sentence {i}": v for i, v in enumerate(r.json())}
),
},
"text-to-speech": {
# example model: hf.co/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
# Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
"inputs": components.Textbox(label="Input"),
"outputs": components.Audio(label="Audio"),
"preprocess": lambda x: {"inputs": x},
"postprocess": encode_to_base64,
},
"text-to-image": {
# example model: hf.co/osanseviero/BigGAN-deep-128
# example model: osanseviero/BigGAN-deep-128
"inputs": components.Textbox(label="Input"),
"outputs": components.Image(label="Output"),
"preprocess": lambda x: {"inputs": x},
"postprocess": encode_to_base64,
},
"token-classification": {
# example model: hf.co/huggingface-course/bert-finetuned-ner
# example model: huggingface-course/bert-finetuned-ner
"inputs": components.Textbox(label="Input"),
"outputs": components.HighlightedText(label="Output"),
"preprocess": lambda x: {"inputs": x},
@ -591,7 +602,6 @@ def load_from_pipeline(pipeline):
data = pipeline(*data)
else:
data = pipeline(**data)
# print("Before postprocessing", data)
output = pipeline_info["postprocess"](data)
return output

View File

@ -13,7 +13,7 @@ from gradio import encryptor, utils
from gradio.documentation import document, set_documentation_group
if TYPE_CHECKING:
from gradio.components import Component
from gradio.components import IOComponent
set_documentation_group("flagging")
@ -24,7 +24,7 @@ class FlaggingCallback(ABC):
"""
@abstractmethod
def setup(self, components: List[Component], flagging_dir: str):
def setup(self, components: List[IOComponent], flagging_dir: str):
"""
This method should be overridden and ensure that everything is set up correctly for flag().
This method gets called once at the beginning of the Interface.launch() method.
@ -74,7 +74,7 @@ class SimpleCSVLogger(FlaggingCallback):
def __init__(self):
pass
def setup(self, components: List[Component], flagging_dir: str):
def setup(self, components: List[IOComponent], flagging_dir: str):
self.components = components
self.flagging_dir = flagging_dir
os.makedirs(flagging_dir, exist_ok=True)
@ -91,18 +91,20 @@ class SimpleCSVLogger(FlaggingCallback):
csv_data = []
for component, sample in zip(self.components, flag_data):
save_dir = os.path.join(
flagging_dir, utils.strip_invalid_filename_characters(component.label)
)
csv_data.append(
component.save_flagged(
flagging_dir,
component.label,
component.deserialize(
sample,
save_dir,
None,
)
)
with open(log_filepath, "a", newline="") as csvfile:
writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
writer.writerow(csv_data)
writer = csv.writer(csvfile)
writer.writerow(utils.sanitize_list_for_csv(csv_data))
with open(log_filepath, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
@ -128,7 +130,7 @@ class CSVLogger(FlaggingCallback):
def setup(
self,
components: List[Component],
components: List[IOComponent],
flagging_dir: str,
encryption_key: Optional[str] = None,
):
@ -151,12 +153,17 @@ class CSVLogger(FlaggingCallback):
if flag_index is None:
csv_data = []
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
save_dir = os.path.join(
flagging_dir,
utils.strip_invalid_filename_characters(
component.label or f"component {idx}"
),
)
csv_data.append(
component.save_flagged(
flagging_dir,
component.label or f"component {idx}",
component.deserialize(
sample,
self.encryption_key,
save_dir=save_dir,
encryption_key=self.encryption_key,
)
if sample is not None
else ""
@ -181,8 +188,8 @@ class CSVLogger(FlaggingCallback):
flag_col_index = header.index("flag")
content[flag_index][flag_col_index] = flag_option
output = io.StringIO()
writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
writer.writerows(content)
writer = csv.writer(output)
writer.writerows(utils.sanitize_list_for_csv(content))
return output.getvalue()
if self.encryption_key:
@ -197,11 +204,11 @@ class CSVLogger(FlaggingCallback):
if flag_index is not None:
file_content = replace_flag_at_index(file_content)
output.write(file_content)
writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
writer = csv.writer(output)
if flag_index is None:
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
writer.writerow(utils.sanitize_list_for_csv(headers))
writer.writerow(utils.sanitize_list_for_csv(csv_data))
with open(log_filepath, "wb", encoding="utf-8") as csvfile:
csvfile.write(
encryptor.encrypt(self.encryption_key, output.getvalue().encode())
@ -209,12 +216,10 @@ class CSVLogger(FlaggingCallback):
else:
if flag_index is None:
with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
writer = csv.writer(
csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'"
)
writer = csv.writer(csvfile)
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
writer.writerow(utils.sanitize_list_for_csv(headers))
writer.writerow(utils.sanitize_list_for_csv(csv_data))
else:
with open(log_filepath, encoding="utf-8") as csvfile:
file_content = csvfile.read()
@ -222,7 +227,7 @@ class CSVLogger(FlaggingCallback):
with open(
log_filepath, "w", newline="", encoding="utf-8"
) as csvfile: # newline parameter needed for Windows
csvfile.write(file_content)
csvfile.write(utils.sanitize_list_for_csv(file_content))
with open(log_filepath, "r", encoding="utf-8") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count
@ -262,7 +267,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
self.organization_name = organization
self.dataset_private = private
def setup(self, components: List[Component], flagging_dir: str):
def setup(self, components: List[IOComponent], flagging_dir: str):
"""
Params:
flagging_dir (str): local directory where the dataset is cloned,
@ -346,21 +351,23 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
"_type": "Value",
}
writer.writerow(headers)
writer.writerow(utils.sanitize_list_for_csv(headers))
# Generate the row corresponding to the flagged sample
csv_data = []
for component, sample in zip(self.components, flag_data):
filepath = component.save_flagged(
self.dataset_dir, component.label, sample, None
save_dir = os.path.join(
self.dataset_dir,
utils.strip_invalid_filename_characters(component.label),
)
filepath = component.deserialize(sample, save_dir, None)
csv_data.append(filepath)
if isinstance(component, tuple(file_preview_types)):
csv_data.append(
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
)
csv_data.append(flag_option if flag_option is not None else "")
writer.writerow(csv_data)
writer.writerow(utils.sanitize_list_for_csv(csv_data))
if is_new:
json.dump(infos, open(self.infos_file, "w"))

View File

@ -185,8 +185,6 @@ class Interface(Blocks):
**kwargs,
)
# TODO(faruk): Can we remove or move init configurations into Blocks? This long init function feels like coming from pre-Blocks era.
self.interface_type = self.InterfaceTypes.STANDARD
if (inputs is None or inputs == []) and (outputs is None or outputs == []):
raise ValueError("Must provide at least one of `inputs` or `outputs`")
@ -402,7 +400,6 @@ class Interface(Blocks):
else:
component.label = "output " + str(i)
# TODO(faruk): Can we move these into the flag component, when it is implemented?
if self.allow_flagging != "never":
if self.interface_type == self.InterfaceTypes.UNIFIED:
self.flagging_callback.setup(self.input_components, self.flagging_dir)
@ -435,7 +432,6 @@ class Interface(Blocks):
for flag_option in flagging_options
]
# TODO(faruk): Can we remove the interface types?
with Row().style(equal_height=False):
if self.interface_type in [
self.InterfaceTypes.STANDARD,
@ -489,25 +485,29 @@ class Interface(Blocks):
flag_btns = render_flag_btns(self.flagging_options)
if self.interpretation:
interpretation_btn = Button("Interpret")
submit_fn = self.submit_func
if self.live:
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
super().load(submit_fn, None, self.output_components)
super().load(self.fn, None, self.output_components)
submit_btn.click(
submit_fn,
self.fn,
None,
self.output_components,
api_name="predict",
status_tracker=status_tracker,
_preprocess=not (self.api_mode),
_postprocess=not (self.api_mode),
)
else:
for component in self.input_components:
if isinstance(component, Streamable):
if component.streaming:
component.stream(
submit_fn,
self.fn,
self.input_components,
self.output_components,
api_name="predict",
_preprocess=not (self.api_mode),
_postprocess=not (self.api_mode),
)
continue
else:
@ -518,16 +518,23 @@ class Interface(Blocks):
)
if isinstance(component, Changeable):
component.change(
submit_fn, self.input_components, self.output_components
self.fn,
self.input_components,
self.output_components,
api_name="predict",
_preprocess=not (self.api_mode),
_postprocess=not (self.api_mode),
)
else:
submit_btn.click(
submit_fn,
self.fn,
self.input_components,
self.output_components,
api_name="predict",
scroll_to_output=True,
status_tracker=status_tracker,
_preprocess=not (self.api_mode),
_postprocess=not (self.api_mode),
)
clear_btn.click(
None,
@ -568,11 +575,11 @@ class Interface(Blocks):
def __init__(self, flagging_callback, flag_option=None):
self.flagging_callback = flagging_callback
self.flag_option = flag_option
self.__name__ = "Flag"
def __call__(self, *flag_data):
self.flagging_callback.flag(flag_data, flag_option=self.flag_option)
# TODO(faruk): Change with flag component when it is implemented..
if self.allow_flagging == "manual":
if self.interface_type in [
self.InterfaceTypes.STANDARD,
@ -604,12 +611,12 @@ class Interface(Blocks):
examples=examples,
inputs=non_state_inputs,
outputs=non_state_outputs,
fn=submit_fn,
fn=self.fn,
cache_examples=self.cache_examples,
examples_per_page=examples_per_page,
_api_mode=_api_mode,
)
# TODO(faruk): Change with interpretation component when implemented.
if self.interpretation:
interpretation_btn.click(
self.interpret_func,
@ -625,17 +632,6 @@ class Interface(Blocks):
self.config = self.get_config_file()
def __call__(self, *params):
if (
self.api_mode
): # skip the preprocessing/postprocessing if sending to a remote API
output = utils.synchronize_async(
self.run_prediction, params, called_directly=True
)
else:
output = utils.synchronize_async(self.process, params)
return output[0] if len(output) == 1 else output
def __str__(self):
return self.__repr__()
@ -650,75 +646,6 @@ class Interface(Blocks):
repr += "\n|-{}".format(str(component))
return repr
async def submit_func(self, *args):
prediction = await self.run_prediction(args)
# TODO(faruk): We don't have tuple or array clearence in Blocks, can we remove this and have one standart?
return prediction[0] if len(self.output_components) == 1 else prediction
async def run_prediction(
self,
processed_input: List[Any],
called_directly: bool = False,
) -> List[Any] | Tuple[List[Any], List[float]]:
"""
Runs the prediction function with the given (already processed) inputs.
Parameters:
processed_input (list): A list of processed inputs.
called_directly (bool): Whether the prediction is being called directly (i.e. as a function, not through the GUI).
Returns:
predictions (list): A list of predictions (not post-processed).
"""
# TODO(faruk): We might keep this function in interface for usage in mix or interpretation.
# However we need to use "call_function" instead of manually serializing, and deserializing and running prediction.
if self.api_mode: # Serialize the input
processed_input = [
input_component.serialize(processed_input[i], called_directly)
for i, input_component in enumerate(self.input_components)
]
if inspect.iscoroutinefunction(self.fn):
prediction = await self.fn(*processed_input)
else:
prediction = await anyio.to_thread.run_sync(
self.fn, *processed_input, limiter=self.limiter
)
if prediction is None or len(self.output_components) == 1:
prediction = [prediction]
if self.api_mode: # Deserialize the input
prediction = [
output_component.deserialize(prediction[i])
for i, output_component in enumerate(self.output_components)
]
return prediction
async def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
"""
First preprocesses the input, then runs prediction using
self.run_prediction(), then postprocesses the output.
Parameters:
raw_input: a list of raw inputs to process and apply the prediction(s) on.
Returns:
processed output: a list of processed outputs to return as the prediction(s).
duration: a list of time deltas measuring inference time for each prediction fn.
"""
# TODO(faruk): We might keep this function in interface for usage in mix or interpretation.
# However we need to use process_api instead of manually processing and running prediction.
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)
]
predictions = await self.run_prediction(processed_input)
processed_output = [
output_component.postprocess(predictions[i])
if predictions[i] is not None
else None
for i, output_component in enumerate(self.output_components)
]
return processed_output
async def interpret_func(self, *args):
return await self.interpret(args) + [
Column.update(visible=False),
@ -747,7 +674,7 @@ class Interface(Blocks):
else:
raw_input.append(input_component.test_input)
else:
self.process(raw_input)
self(raw_input)
print("PASSED")

View File

@ -19,7 +19,12 @@ async def run_interpret(interface, raw_input):
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)
]
original_output = await interface.run_prediction(processed_input)
original_output = await interface.call_function(0, processed_input)
original_output = original_output[0]
if len(interface.output_components) == 1:
original_output = [original_output]
scores, alternative_outputs = [], []
for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
@ -39,9 +44,12 @@ async def run_interpret(interface, raw_input):
)
]
neighbor_output = await interface.run_prediction(
processed_neighbor_input
neighbor_output = await interface.call_function(
0, processed_neighbor_input
)
neighbor_output = neighbor_output[0]
if len(interface.output_components) == 1:
neighbor_output = [neighbor_output]
processed_neighbor_output = [
output_component.postprocess(neighbor_output[i])
for i, output_component in enumerate(
@ -80,9 +88,12 @@ async def run_interpret(interface, raw_input):
interface.input_components
)
]
neighbor_output = await interface.run_prediction(
processed_neighbor_input
neighbor_output = await interface.call_function(
0, processed_neighbor_input
)
neighbor_output = neighbor_output[0]
if len(interface.output_components) == 1:
neighbor_output = [neighbor_output]
processed_neighbor_output = [
output_component.postprocess(neighbor_output[i])
for i, output_component in enumerate(
@ -131,8 +142,11 @@ async def run_interpret(interface, raw_input):
processed_masked_input = copy.deepcopy(processed_input)
processed_masked_input[i] = input_component.preprocess(masked_x)
new_output = utils.synchronize_async(
interface.run_prediction, processed_masked_input
interface.call_function, 0, processed_masked_input
)
new_output = new_output[0]
if len(interface.output_components) == 1:
new_output = [new_output]
pred = get_regression_or_classification_value(
interface, original_output, new_output
)

View File

@ -37,12 +37,16 @@ class Parallel(gradio.Interface):
outputs.extend(interface.output_components)
async def parallel_fn(*args):
return_values = await asyncio.gather(
*[interface.run_prediction(args) for interface in interfaces]
return_values_with_durations = await asyncio.gather(
*[interface.call_function(0, args) for interface in interfaces]
)
return_values = [rv[0] for rv in return_values_with_durations]
combined_list = []
for value in return_values:
combined_list.extend(value)
for interface, return_value in zip(interfaces, return_values):
if len(interface.output_components) == 1:
combined_list.append(return_value)
else:
combined_list.extend(return_value)
if len(outputs) == 1:
return combined_list[0]
return combined_list
@ -110,7 +114,7 @@ class Series(gradio.Interface):
"fn": connected_fn,
"inputs": interfaces[0].input_components,
"outputs": interfaces[-1].output_components,
"_api_mode": interfaces[0].api_mode, # TODO: set api_mode per-interface
}
kwargs.update(options)
super().__init__(**kwargs)
self.api_mode = interfaces[0].api_mode # TODO: set api_mode per-function

View File

@ -30,12 +30,12 @@ def decode_base64_to_image(encoding):
return Image.open(BytesIO(base64.b64decode(image_encoded)))
def encode_url_or_file_to_base64(path):
def encode_url_or_file_to_base64(path, encryption_key=None):
try:
requests.get(path)
return encode_url_to_base64(path)
return encode_url_to_base64(path, encryption_key=encryption_key)
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
return encode_file_to_base64(path)
return encode_file_to_base64(path, encryption_key=encryption_key)
def get_mimetype(filename):
@ -71,8 +71,10 @@ def encode_file_to_base64(f, encryption_key=None):
)
def encode_url_to_base64(url):
def encode_url_to_base64(url, encryption_key=None):
encoded_string = base64.b64encode(requests.get(url).content)
if encryption_key:
encoded_string = encryptor.decrypt(encryption_key, encoded_string)
base64_str = str(encoded_string, "utf-8")
mimetype = get_mimetype(url)
return (
@ -88,10 +90,23 @@ def encode_plot_to_base64(plt):
return "data:image/png;base64," + base64_str
def save_array_to_file(image_array, dir=None):
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj)
return file_obj
def save_pil_to_file(pil_image, dir=None):
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj)
return file_obj
def encode_array_to_base64(image_array):
with BytesIO() as output_bytes:
PIL_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
PIL_image.save(output_bytes, "PNG")
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
pil_image.save(output_bytes, "PNG")
bytes_data = output_bytes.getvalue()
base64_str = str(base64.b64encode(bytes_data), "utf-8")
return "data:image/png;base64," + base64_str
@ -201,7 +216,9 @@ def decode_base64_to_binary(encoding):
return base64.b64decode(data), extension
def decode_base64_to_file(encoding, encryption_key=None, file_path=None):
def decode_base64_to_file(encoding, encryption_key=None, file_path=None, dir=None):
if dir is not None:
os.makedirs(dir, exist_ok=True)
data, extension = decode_base64_to_binary(encoding)
prefix = None
if file_path is not None:
@ -211,10 +228,13 @@ def decode_base64_to_file(encoding, encryption_key=None, file_path=None):
prefix = filename[0 : filename.index(".")]
extension = filename[filename.index(".") + 1 :]
if extension is None:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix)
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
else:
file_obj = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix, suffix="." + extension
delete=False,
prefix=prefix,
suffix="." + extension,
dir=dir,
)
if encryption_key is not None:
data = encryptor.encrypt(encryption_key, data)
@ -223,17 +243,55 @@ def decode_base64_to_file(encoding, encryption_key=None, file_path=None):
return file_obj
def create_tmp_copy_of_file(file_path):
def create_tmp_copy_of_file_or_url(file_path_or_url: str, dir=None):
try:
response = requests.get(file_path_or_url, stream=True)
if file_path_or_url.find("/"):
new_file_path = file_path_or_url.rsplit("/", 1)[1]
else:
new_file_path = "file.txt"
with open(new_file_path, "wb") as out_file:
shutil.copyfileobj(response.raw, out_file)
del response
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
return create_tmp_copy_of_file(file_path_or_url, dir)
def dict_or_str_to_json_file(jsn, dir=None):
if dir is not None:
os.makedirs(dir, exist_ok=True)
file_obj = tempfile.NamedTemporaryFile(
delete=False, suffix=".json", dir=dir, mode="w+"
)
if isinstance(jsn, str):
jsn = json.loads(jsn)
json.dump(jsn, file_obj)
file_obj.flush()
return file_obj
def file_to_json(file_path):
return json.load(open(file_path))
def create_tmp_copy_of_file(file_path, dir=None):
if dir is not None:
os.makedirs(dir, exist_ok=True)
file_name = os.path.basename(file_path)
prefix, extension = file_name, None
if "." in file_name:
prefix = file_name[0 : file_name.index(".")]
extension = file_name[file_name.index(".") + 1 :]
if extension is None:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix)
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
else:
file_obj = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix, suffix="." + extension
delete=False,
prefix=prefix,
suffix="." + extension,
dir=dir,
)
shutil.copy2(file_path, file_obj.name)
return file_obj
@ -520,10 +578,6 @@ def _convert(image, dtype, force_copy=False, uniform=False):
return image.astype(dtype_out)
def strip_invalid_filename_characters(filename: str) -> str:
return "".join([char for char in filename if char.isalnum() or char in "._- "])
def ffmpeg_installed() -> bool:
return shutil.which("ffmpeg") is not None

View File

@ -225,9 +225,16 @@ class App(FastAPI):
return FileResponse(
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
)
elif Path(app.cwd).resolve() in Path(path).resolve().parents or any(
Path(temp_dir).resolve() in Path(path).resolve().parents
for temp_dir in app.blocks.temp_dirs
):
return FileResponse(Path(path).resolve())
else:
if Path(app.cwd).resolve() in Path(path).resolve().parents:
return FileResponse(Path(path).resolve())
raise ValueError(
f"File cannot be fetched: {path}, perhaps because "
f"it is not in any of {app.blocks.temp_dirs}"
)
async def run_predict(
body: PredictBody, username: str = Depends(get_current_user)

179
gradio/serializing.py Normal file
View File

@ -0,0 +1,179 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from typing import Any, Dict
from gradio import processing_utils
class Serializable(ABC):
@abstractmethod
def serialize(
self, x: Any, load_dir: str = "", encryption_key: bytes | None = None
):
"""
Convert data from human-readable format to serialized format for a browser.
"""
pass
@abstractmethod
def deserialize(
x: Any, save_dir: str | None = None, encryption_key: bytes | None = None
):
"""
Convert data from serialized format for a browser to human-readable format.
"""
pass
class SimpleSerializable(Serializable):
def serialize(
self, x: Any, load_dir: str = "", encryption_key: bytes | None = None
) -> Any:
"""
Convert data from human-readable format to serialized format. For SimpleSerializable components, this is a no-op.
Parameters:
x: Input data to serialize
load_dir: Ignored
encryption_key: Ignored
"""
return x
def deserialize(
self, x: Any, save_dir: str | None = None, encryption_key: bytes | None = None
):
"""
Convert data from serialized format to human-readable format. For SimpleSerializable components, this is a no-op.
Parameters:
x: Input data to deserialize
save_dir: Ignored
encryption_key: Ignored
"""
return x
class ImgSerializable(Serializable):
def serialize(
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
) -> str:
"""
Convert from human-friendly version of a file (string filepath) to a seralized
representation (base64).
Parameters:
x: String path to file to serialize
load_dir: Path to directory containing x
encryption_key: Used to encrypt the file
"""
if x is None or x == "":
return None
return processing_utils.encode_url_or_file_to_base64(
os.path.join(load_dir, x), encryption_key=encryption_key
)
def deserialize(
self, x: str, save_dir: str | None = None, encryption_key: bytes | None = None
) -> str:
"""
Convert from serialized representation of a file (base64) to a human-friendly
version (string filepath). Optionally, save the file to the directory specified by save_dir
Parameters:
x: Base64 representation of image to deserialize into a string filepath
save_dir: Path to directory to save the deserialized image to
encryption_key: Used to decrypt the file
"""
if x is None or x == "":
return None
file = processing_utils.decode_base64_to_file(
x, dir=save_dir, encryption_key=encryption_key
)
return file.name
class FileSerializable(Serializable):
def serialize(
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
) -> Any:
"""
Convert from human-friendly version of a file (string filepath) to a
seralized representation (base64)
Parameters:
x: String path to file to serialize
load_dir: Path to directory containing x
encryption_key: Used to encrypt the file
"""
if x is None or x == "":
return None
filename = os.path.join(load_dir, x)
return {
"name": filename,
"data": processing_utils.encode_url_or_file_to_base64(
filename, encryption_key=encryption_key
),
"is_file": False,
}
def deserialize(
self, x: Dict, save_dir: str | None = None, encryption_key: bytes | None = None
):
"""
Convert from serialized representation of a file (base64) to a human-friendly
version (string filepath). Optionally, save the file to the directory specified by `save_dir`
Parameters:
x: Base64 representation of file to deserialize into a string filepath
save_dir: Path to directory to save the deserialized file to
encryption_key: Used to decrypt the file
"""
if x is None:
return None
if isinstance(x, str):
file = processing_utils.decode_base64_to_file(
x, dir=save_dir, encryption_key=encryption_key
)
elif isinstance(x, dict):
if x.get("is_file", False):
file = processing_utils.create_tmp_copy_of_file(x["name"], dir=save_dir)
else:
file = processing_utils.decode_base64_to_file(
x["data"], dir=save_dir, encryption_key=encryption_key
)
else:
raise ValueError(
f"A FileSerializable component cannot only deserialize a string or a dict, not a: {type(x)}"
)
return file.name
class JSONSerializable(Serializable):
def serialize(
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
) -> str:
"""
Convert from a a human-friendly version (string path to json file) to a
serialized representation (json string)
Parameters:
x: String path to json file to read to get json string
load_dir: Path to directory containing x
encryption_key: Ignored
"""
if x is None or x == "":
return None
return processing_utils.file_to_json(os.path.join(load_dir, x))
def deserialize(
self,
x: str | Dict,
save_dir: str | None = None,
encryption_key: bytes | None = None,
) -> str:
"""
Convert from serialized representation (json string) to a human-friendly
version (string path to json file). Optionally, save the file to the directory specified by `save_dir`
Parameters:
x: Json string
save_dir: Path to save the deserialized json file to
encryption_key: Ignored
"""
if x is None:
return None
return processing_utils.dict_or_str_to_json_file(x, dir=save_dir).name

View File

@ -15,8 +15,19 @@ from contextlib import contextmanager
from copy import deepcopy
from distutils.version import StrictVersion
from enum import Enum
from numbers import Number
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, NewType, Tuple, Type
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
NewType,
Tuple,
Type,
)
import aiohttp
import analytics
@ -292,7 +303,7 @@ def delete_none(_dict):
return _dict
def resolve_singleton(_list):
def resolve_singleton(_list: List[Any] | Any) -> Any:
if len(_list) == 1:
return _list[0]
else:
@ -609,3 +620,41 @@ def set_directory(path: Path):
yield
finally:
os.chdir(origin)
def strip_invalid_filename_characters(filename: str) -> str:
return "".join([char for char in filename if char.isalnum() or char in "._- "])
def sanitize_value_for_csv(value: str | Number) -> str | Number:
"""
Sanitizes a value that is being written to a CSV file to prevent CSV injection attacks.
Reference: https://owasp.org/www-community/attacks/CSV_Injection
"""
if isinstance(value, Number):
return value
unsafe_prefixes = ["=", "+", "-", "@", "\t", "\n"]
unsafe_sequences = [",=", ",+", ",-", ",@", ",\t", ",\n"]
if any(value.startswith(prefix) for prefix in unsafe_prefixes) or any(
sequence in value for sequence in unsafe_sequences
):
value = "'" + value
return value
def sanitize_list_for_csv(
values: List[str | Number] | List[List[str | Number]],
) -> List[str | Number] | List[List[str | Number]]:
"""
Sanitizes a list of values (or a list of list of values) that is being written to a
CSV file to prevent CSV injection attacks.
"""
sanitized_values = []
for value in values:
if isinstance(value, list):
sanitized_value = [sanitize_value_for_csv(v) for v in value]
sanitized_values.append(sanitized_value)
else:
sanitized_value = sanitize_value_for_csv(value)
sanitized_values.append(sanitized_value)
return sanitized_values

View File

@ -1,3 +1,4 @@
import filecmp
import json
import os
import pathlib
@ -46,23 +47,16 @@ def test_raise_warnings():
class TestTextbox(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, postprocess, serialize, save_flagged, restore_flagged, tokenize, generate_sample, get_config
Preprocess, postprocess, serialize, tokenize, generate_sample, get_config
"""
text_input = gr.Textbox()
self.assertEqual(text_input.preprocess("Hello World!"), "Hello World!")
self.assertEqual(text_input.preprocess_example("Hello World!"), "Hello World!")
self.assertEqual(text_input.postprocess("Hello World!"), "Hello World!")
self.assertEqual(text_input.postprocess(None), None)
self.assertEqual(text_input.postprocess("Ali"), "Ali")
self.assertEqual(text_input.postprocess(2), "2")
self.assertEqual(text_input.postprocess(2.14), "2.14")
self.assertEqual(text_input.serialize("Hello World!", True), "Hello World!")
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = text_input.save_flagged(
tmpdirname, "text_input", "Hello World!", None
)
self.assertEqual(to_save, "Hello World!")
restored = text_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, "Hello World!")
with self.assertWarns(Warning):
_ = gr.Textbox(type="number")
@ -117,7 +111,7 @@ class TestTextbox(unittest.TestCase):
Interface, process, interpret,
"""
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
self.assertEqual(await iface.process(["Hello"]), ["olleH"])
self.assertEqual(await iface(["Hello"]), ["olleH"])
iface = gr.Interface(
lambda sentence: max([len(word) for word in sentence.split()]),
gr.Textbox(),
@ -159,9 +153,9 @@ class TestTextbox(unittest.TestCase):
"""
iface = gr.Interface(lambda x: x[-1], "textbox", gr.Textbox())
self.assertEqual(await iface.process(["Hello"]), ["o"])
self.assertEqual(await iface(["Hello"]), ["o"])
iface = gr.Interface(lambda x: x / 2, "number", gr.Textbox())
self.assertEqual(iface.process([10]), ["5.0"])
self.assertEqual(iface([10]), ["5.0"])
def test_static(self):
"""
@ -174,22 +168,17 @@ class TestTextbox(unittest.TestCase):
class TestNumber(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_config
Preprocess, postprocess, serialize, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_config
"""
numeric_input = gr.Number()
self.assertEqual(numeric_input.preprocess(3), 3.0)
self.assertEqual(numeric_input.preprocess(None), None)
self.assertEqual(numeric_input.preprocess_example(3), 3)
self.assertEqual(numeric_input.postprocess(3), 3)
self.assertEqual(numeric_input.postprocess(3), 3.0)
self.assertEqual(numeric_input.postprocess(2.14), 2.14)
self.assertEqual(numeric_input.postprocess(None), None)
self.assertEqual(numeric_input.serialize(3, True), 3)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = numeric_input.save_flagged(tmpdirname, "numeric_input", 3, None)
self.assertEqual(to_save, 3)
restored = numeric_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, 3)
self.assertIsInstance(numeric_input.generate_sample(), float)
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
self.assertEqual(
@ -217,22 +206,17 @@ class TestNumber(unittest.TestCase):
def test_component_functions_integer(self):
"""
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context
Preprocess, postprocess, serialize, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context
"""
numeric_input = gr.Number(precision=0, value=42)
self.assertEqual(numeric_input.preprocess(3), 3)
self.assertEqual(numeric_input.preprocess(None), None)
self.assertEqual(numeric_input.preprocess_example(3), 3)
self.assertEqual(numeric_input.postprocess(3), 3)
self.assertEqual(numeric_input.postprocess(3), 3)
self.assertEqual(numeric_input.postprocess(2.85), 3)
self.assertEqual(numeric_input.postprocess(None), None)
self.assertEqual(numeric_input.serialize(3, True), 3)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = numeric_input.save_flagged(tmpdirname, "numeric_input", 3, None)
self.assertEqual(to_save, 3)
restored = numeric_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, 3)
self.assertIsInstance(numeric_input.generate_sample(), int)
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
self.assertEqual(
@ -269,13 +253,13 @@ class TestNumber(unittest.TestCase):
def test_component_functions_precision(self):
"""
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context
Preprocess, postprocess, serialize, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context
"""
numeric_input = gr.Number(precision=2, value=42.3428)
self.assertEqual(numeric_input.preprocess(3.231241), 3.23)
self.assertEqual(numeric_input.preprocess(None), None)
self.assertEqual(numeric_input.preprocess_example(-42.1241), -42.12)
self.assertEqual(numeric_input.postprocess(-42.1241), -42.12)
self.assertEqual(numeric_input.postprocess(5.6784), 5.68)
self.assertEqual(numeric_input.postprocess(2.1421), 2.14)
self.assertEqual(numeric_input.postprocess(None), None)
@ -285,7 +269,7 @@ class TestNumber(unittest.TestCase):
Interface, process, interpret
"""
iface = gr.Interface(lambda x: x**2, "number", "textbox")
self.assertEqual(iface.process([2]), ["4.0"])
self.assertEqual(iface([2]), ["4.0"])
iface = gr.Interface(
lambda x: x**2, "number", "number", interpretation="default"
)
@ -308,7 +292,7 @@ class TestNumber(unittest.TestCase):
Interface, process, interpret
"""
iface = gr.Interface(lambda x: x**2, gr.Number(precision=0), "textbox")
self.assertEqual(iface.process([2]), ["4"])
self.assertEqual(iface([2]), ["4"])
iface = gr.Interface(
lambda x: x**2, "number", gr.Number(precision=0), interpretation="default"
)
@ -332,7 +316,7 @@ class TestNumber(unittest.TestCase):
Interface, process, interpret
"""
iface = gr.Interface(lambda x: int(x) ** 2, "textbox", "number")
self.assertEqual(iface.process([2]), [4.0])
self.assertEqual(iface([2]), [4.0])
iface = gr.Interface(
lambda x: x**2, "number", "number", interpretation="default"
)
@ -363,19 +347,14 @@ class TestNumber(unittest.TestCase):
class TestSlider(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, get_config
Preprocess, postprocess, serialize, generate_sample, get_config
"""
slider_input = gr.Slider()
self.assertEqual(slider_input.preprocess(3.0), 3.0)
self.assertEqual(slider_input.preprocess_example(3), 3)
self.assertEqual(slider_input.postprocess(3), 3)
self.assertEqual(slider_input.postprocess(3), 3)
self.assertEqual(slider_input.postprocess(None), 0)
self.assertEqual(slider_input.serialize(3, True), 3)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = slider_input.save_flagged(tmpdirname, "slider_input", 3, None)
self.assertEqual(to_save, 3)
restored = slider_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, 3)
self.assertIsInstance(slider_input.generate_sample(), int)
slider_input = gr.Slider(10, 20, value=15, step=1, label="Slide Your Input")
@ -401,7 +380,7 @@ class TestSlider(unittest.TestCase):
Interface, process, interpret
"""
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
self.assertEqual(iface.process([2]), ["4"])
self.assertEqual(iface([2]), ["4"])
iface = gr.Interface(
lambda x: x**2, "slider", "number", interpretation="default"
)
@ -437,14 +416,9 @@ class TestCheckbox(unittest.TestCase):
"""
bool_input = gr.Checkbox()
self.assertEqual(bool_input.preprocess(True), True)
self.assertEqual(bool_input.preprocess_example(True), True)
self.assertEqual(bool_input.postprocess(True), True)
self.assertEqual(bool_input.postprocess(True), True)
self.assertEqual(bool_input.serialize(True, True), True)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = bool_input.save_flagged(tmpdirname, "bool_input", True, None)
self.assertEqual(to_save, True)
restored = bool_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, True)
self.assertIsInstance(bool_input.generate_sample(), bool)
bool_input = gr.Checkbox(value=True, label="Check Your Input")
self.assertEqual(
@ -466,7 +440,7 @@ class TestCheckbox(unittest.TestCase):
Interface, process, interpret
"""
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "number")
self.assertEqual(iface.process([True]), [1])
self.assertEqual(iface([True]), [1])
iface = gr.Interface(
lambda x: 1 if x else 0, "checkbox", "number", interpretation="default"
)
@ -479,19 +453,12 @@ class TestCheckbox(unittest.TestCase):
class TestCheckboxGroup(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, preprocess_example, serialize, save_flagged, restore_flagged, generate_sample, get_config
Preprocess, postprocess, serialize, generate_sample, get_config
"""
checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
self.assertEqual(checkboxes_input.preprocess(["a", "c"]), ["a", "c"])
self.assertEqual(checkboxes_input.preprocess_example(["a", "c"]), ["a", "c"])
self.assertEqual(checkboxes_input.postprocess(["a", "c"]), ["a", "c"])
self.assertEqual(checkboxes_input.serialize(["a", "c"], True), ["a", "c"])
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = checkboxes_input.save_flagged(
tmpdirname, "checkboxes_input", ["a", "c"], None
)
self.assertEqual(to_save, '["a", "c"]')
restored = checkboxes_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, ["a", "c"])
self.assertIsInstance(checkboxes_input.generate_sample(), list)
checkboxes_input = gr.CheckboxGroup(
value=["a", "c"],
@ -522,26 +489,21 @@ class TestCheckboxGroup(unittest.TestCase):
"""
checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
iface = gr.Interface(lambda x: "|".join(x), checkboxes_input, "textbox")
self.assertEqual(await iface.process([["a", "c"]]), ["a|c"])
self.assertEqual(await iface.process([[]]), [""])
self.assertEqual(await iface([["a", "c"]]), ["a|c"])
self.assertEqual(await iface([[]]), [""])
_ = gr.CheckboxGroup(["a", "b", "c"], type="index")
class TestRadio(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, preprocess_example, serialize, save_flagged, generate_sample, get_config
Preprocess, postprocess, serialize, generate_sample, get_config
"""
radio_input = gr.Radio(["a", "b", "c"])
self.assertEqual(radio_input.preprocess("c"), "c")
self.assertEqual(radio_input.preprocess_example("a"), "a")
self.assertEqual(radio_input.postprocess("a"), "a")
self.assertEqual(radio_input.serialize("a", True), "a")
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = radio_input.save_flagged(tmpdirname, "radio_input", "a", None)
self.assertEqual(to_save, "a")
restored = radio_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, "a")
self.assertIsInstance(radio_input.generate_sample(), str)
radio_input = gr.Radio(
choices=["a", "b", "c"], default="a", label="Pick Your One Input"
@ -570,12 +532,12 @@ class TestRadio(unittest.TestCase):
"""
radio_input = gr.Radio(["a", "b", "c"])
iface = gr.Interface(lambda x: 2 * x, radio_input, "textbox")
self.assertEqual(iface.process(["c"]), ["cc"])
self.assertEqual(iface(["c"]), ["cc"])
radio_input = gr.Radio(["a", "b", "c"], type="index")
iface = gr.Interface(
lambda x: 2 * x, radio_input, "number", interpretation="default"
)
self.assertEqual(iface.process(["c"]), [4])
self.assertEqual(iface(["c"]), [4])
scores = (await iface.interpret(["b"]))[0]["interpretation"]
self.assertEqual(scores, [-2.0, None, 2.0])
@ -583,7 +545,7 @@ class TestRadio(unittest.TestCase):
class TestImage(unittest.TestCase):
async def test_component_functions(self):
"""
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, get_config, _segment_by_slic
Preprocess, postprocess, serialize, generate_sample, get_config, _segment_by_slic
type: pil, file, filepath, numpy
"""
img = deepcopy(media_data.BASE64_IMAGE)
@ -593,15 +555,8 @@ class TestImage(unittest.TestCase):
self.assertEqual(image_input.preprocess(img).shape, (25, 25))
image_input = gr.Image(shape=(30, 10), type="pil")
self.assertEqual(image_input.preprocess(img).size, (30, 10))
self.assertEqual(image_input.preprocess_example("test/test_files/bus.png"), img)
self.assertEqual(image_input.postprocess("test/test_files/bus.png"), img)
self.assertEqual(image_input.serialize("test/test_files/bus.png", True), img)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = image_input.save_flagged(tmpdirname, "image_input", img, None)
self.assertEqual("image_input/0.png", to_save)
to_save = image_input.save_flagged(tmpdirname, "image_input", img, None)
self.assertEqual("image_input/1.png", to_save)
restored = image_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, media_data.BASE64_IMAGE)
self.assertIsInstance(image_input.generate_sample(), str)
image_input = gr.Image(
@ -673,15 +628,6 @@ class TestImage(unittest.TestCase):
self.assertTrue(
image_output.postprocess(y_img).startswith("data:image/png;base64,")
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = image_output.save_flagged(
tmpdirname, "image_output", deepcopy(media_data.BASE64_IMAGE), None
)
self.assertEqual("image_output/0.png", to_save)
to_save = image_output.save_flagged(
tmpdirname, "image_output", deepcopy(media_data.BASE64_IMAGE), None
)
self.assertEqual("image_output/1.png", to_save)
async def test_in_interface_as_input(self):
"""
@ -696,7 +642,7 @@ class TestImage(unittest.TestCase):
gr.Image(shape=(30, 10), type="file"),
"image",
)
output = (await iface.process([img]))[0]
output = (await iface([img]))[0]
self.assertEqual(
gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
)
@ -730,9 +676,7 @@ class TestImage(unittest.TestCase):
return np.random.randint(0, 256, (width, height, 3))
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
self.assertTrue(
(await iface.process([10, 20]))[0].startswith("data:image/png;base64")
)
self.assertTrue((await iface([10, 20]))[0].startswith("data:image/png;base64"))
def test_static(self):
"""
@ -756,7 +700,7 @@ class TestPlot(unittest.TestCase):
return fig
iface = gr.Interface(plot, "slider", "plot")
output = (await iface.process([10]))[0]
output = (await iface([10]))[0]
self.assertEqual(output["type"], "matplotlib")
self.assertTrue(output["plot"].startswith("data:image/png;base64"))
@ -776,7 +720,7 @@ class TestPlot(unittest.TestCase):
class TestAudio(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, postprocess serialize, save_flagged, restore_flagged, generate_sample, get_config, deserialize
Preprocess, postprocess serialize, generate_sample, get_config, deserialize
type: filepath, numpy, file
"""
x_wav = deepcopy(media_data.BASE64_AUDIO)
@ -784,19 +728,11 @@ class TestAudio(unittest.TestCase):
output = audio_input.preprocess(x_wav)
self.assertEqual(output[0], 8000)
self.assertEqual(output[1].shape, (8046,))
self.assertEqual(
audio_input.serialize("test/test_files/audio_sample.wav", True)["data"],
x_wav["data"],
assert filecmp.cmp(
"test/test_files/audio_sample.wav",
audio_input.serialize("test/test_files/audio_sample.wav")["name"],
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = audio_input.save_flagged(tmpdirname, "audio_input", x_wav, None)
self.assertEqual("audio_input/0.wav", to_save)
to_save = audio_input.save_flagged(tmpdirname, "audio_input", x_wav, None)
self.assertEqual("audio_input/1.wav", to_save)
restored = audio_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored["file_name"], "audio_input/1.wav")
self.assertIsInstance(audio_input.generate_sample(), dict)
audio_input = gr.Audio(label="Upload Your Audio")
self.assertEqual(
@ -818,20 +754,14 @@ class TestAudio(unittest.TestCase):
x_wav["is_example"] = True
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
self.assertIsNotNone(audio_input.preprocess(x_wav))
with self.assertWarns(UserWarning):
audio_input = gr.Audio(type="file")
audio_input.preprocess(x_wav)
with open("test/test_files/audio_sample.wav") as f:
audio_input.serialize(f, False)
audio_input = gr.Audio(type="filepath")
self.assertIsInstance(audio_input.preprocess(x_wav), str)
with self.assertRaises(ValueError):
audio_input = gr.Audio(type="unknown")
audio_input.preprocess(x_wav)
audio_input.serialize(x_wav, False)
audio_input.serialize(x_wav)
audio_input = gr.Audio(type="numpy")
x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav")
self.assertIsInstance(audio_input.serialize(x_wav, False), dict)
# Output functionalities
y_audio = gr.processing_utils.decode_base64_to_file(
@ -839,9 +769,7 @@ class TestAudio(unittest.TestCase):
)
audio_output = gr.Audio(type="file")
self.assertTrue(
audio_output.postprocess(y_audio.name).startswith(
"data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA"
)
filecmp.cmp(y_audio.name, audio_output.postprocess(y_audio.name)["name"])
)
self.assertEqual(
audio_output.get_config(),
@ -860,18 +788,13 @@ class TestAudio(unittest.TestCase):
)
self.assertTrue(
audio_output.deserialize(
deepcopy(media_data.BASE64_AUDIO)["data"]
{
"name": None,
"data": deepcopy(media_data.BASE64_AUDIO)["data"],
"is_file": False,
}
).endswith(".wav")
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = audio_output.save_flagged(
tmpdirname, "audio_output", deepcopy(media_data.BASE64_AUDIO), None
)
self.assertEqual("audio_output/0.wav", to_save)
to_save = audio_output.save_flagged(
tmpdirname, "audio_output", deepcopy(media_data.BASE64_AUDIO), None
)
self.assertEqual("audio_output/1.wav", to_save)
def test_tokenize(self):
"""
@ -891,16 +814,16 @@ class TestAudio(unittest.TestCase):
return (sr, np.flipud(data))
iface = gr.Interface(reverse_audio, "audio", "audio")
reversed_data = (await iface.process([deepcopy(media_data.BASE64_AUDIO)]))[0]
reversed_data = (await iface([deepcopy(media_data.BASE64_AUDIO)]))[0]
reversed_input = {"name": "fake_name", "data": reversed_data}
self.assertTrue(reversed_data.startswith("data:audio/wav;base64,UklGRgA/"))
self.assertTrue(
(await iface.process([deepcopy(media_data.BASE64_AUDIO)]))[0].startswith(
(await iface([deepcopy(media_data.BASE64_AUDIO)]))[0].startswith(
"data:audio/wav;base64,UklGRgA/"
)
)
self.maxDiff = None
reversed_reversed_data = (await iface.process([reversed_input]))[0]
reversed_reversed_data = (await iface([reversed_input]))[0]
similarity = SequenceMatcher(
a=reversed_reversed_data, b=media_data.BASE64_AUDIO["data"]
).ratio()
@ -915,33 +838,23 @@ class TestAudio(unittest.TestCase):
return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int16)
iface = gr.Interface(generate_noise, "slider", "audio")
self.assertTrue(
(await iface.process([100]))[0].startswith("data:audio/wav;base64")
)
self.assertTrue((await iface([100]))[0].startswith("data:audio/wav;base64"))
class TestFile(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, serialize, save_flagged, restore_flagged, generate_sample, get_config, value
Preprocess, serialize, generate_sample, get_config, value
"""
x_file = deepcopy(media_data.BASE64_FILE)
file_input = gr.File()
output = file_input.preprocess(x_file)
self.assertIsInstance(output, tempfile._TemporaryFileWrapper)
self.assertEqual(
file_input.serialize("test/test_files/sample_file.pdf", True),
assert filecmp.cmp(
file_input.serialize("test/test_files/sample_file.pdf")["name"],
"test/test_files/sample_file.pdf",
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = file_input.save_flagged(tmpdirname, "file_input", [x_file], None)
self.assertEqual("file_input/0", to_save)
to_save = file_input.save_flagged(tmpdirname, "file_input", [x_file], None)
self.assertEqual("file_input/1", to_save)
restored = file_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored["file_name"], "file_input/1")
self.assertIsInstance(file_input.generate_sample(), dict)
file_input = gr.File(label="Upload Your File")
self.assertEqual(
@ -962,12 +875,6 @@ class TestFile(unittest.TestCase):
x_file["is_example"] = True
self.assertIsNotNone(file_input.preprocess(x_file))
file_input = gr.File("test/test_files/sample_file.pdf")
self.assertEqual(
file_input.get_config(),
deepcopy(media_data.FILE_TEMPLATE_CONTEXT),
)
async def test_in_interface_as_input(self):
"""
Interface, process
@ -978,11 +885,11 @@ class TestFile(unittest.TestCase):
return os.path.getsize(file_obj.name)
iface = gr.Interface(get_size_of_file, "file", "number")
self.assertEqual(await iface.process([[x_file]]), [10558])
self.assertEqual(await iface([[x_file]]), [10558])
async def test_as_component_as_output(self):
"""
Interface, process, save_flagged,
Interface, process
"""
def write_file(content):
@ -992,29 +899,19 @@ class TestFile(unittest.TestCase):
iface = gr.Interface(write_file, "text", "file")
self.assertDictEqual(
(await iface.process(["hello world"]))[0],
(await iface(["hello world"]))[0],
{
"name": "test.txt",
"size": 11,
"data": "data:text/plain;base64,aGVsbG8gd29ybGQ=",
},
)
file_output = gr.File()
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = file_output.save_flagged(
tmpdirname, "file_output", [deepcopy(media_data.BASE64_FILE)], None
)
self.assertEqual("file_output/0", to_save)
to_save = file_output.save_flagged(
tmpdirname, "file_output", [deepcopy(media_data.BASE64_FILE)], None
)
self.assertEqual("file_output/1", to_save)
class TestDataframe(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, serialize, save_flagged, restore_flagged, generate_sample, get_config
Preprocess, serialize, generate_sample, get_config
"""
x_data = {
"data": [["Tim", 12, False], ["Jan", 24, True]],
@ -1024,16 +921,7 @@ class TestDataframe(unittest.TestCase):
output = dataframe_input.preprocess(x_data)
self.assertEqual(output["Age"][1], 24)
self.assertEqual(output["Member"][0], False)
self.assertEqual(dataframe_input.preprocess_example(x_data), x_data)
self.assertEqual(dataframe_input.serialize(x_data, True), x_data)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = dataframe_input.save_flagged(
tmpdirname, "dataframe_input", x_data, None
)
self.assertEqual(json.dumps(x_data), to_save)
restored = dataframe_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(x_data, restored)
self.assertEqual(dataframe_input.postprocess(x_data), x_data)
self.assertIsInstance(dataframe_input.generate_sample(), list)
dataframe_input = gr.Dataframe(
@ -1122,26 +1010,6 @@ class TestDataframe(unittest.TestCase):
with self.assertRaises(ValueError):
wrong_type = gr.Dataframe(type="unknown")
wrong_type.postprocess(0)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = dataframe_output.save_flagged(
tmpdirname, "dataframe_output", output, None
)
self.assertEqual(
to_save,
json.dumps(
{
"headers": ["num", "prime"],
"data": [[2, True], [3, True], [4, False]],
}
),
)
self.assertEqual(
dataframe_output.restore_flagged(tmpdirname, to_save, None),
{
"headers": ["num", "prime"],
"data": [[2, True], [3, True], [4, False]],
},
)
# When the headers don't match the data
dataframe_output = gr.Dataframe(headers=["one", "two", "three"])
@ -1169,14 +1037,14 @@ class TestDataframe(unittest.TestCase):
"""
x_data = {"data": [[1, 2, 3], [4, 5, 6]]}
iface = gr.Interface(np.max, "numpy", "number")
self.assertEqual(await iface.process([x_data]), [6])
self.assertEqual(await iface([x_data]), [6])
x_data = {"data": [["Tim"], ["Jon"], ["Sal"]], "headers": [1, 2, 3]}
def get_last(my_list):
return my_list[-1][-1]
iface = gr.Interface(get_last, "list", "text")
self.assertEqual(await iface.process([x_data]), ["Sal"])
self.assertEqual(await iface([x_data]), ["Sal"])
async def test_in_interface_as_output(self):
"""
@ -1188,7 +1056,7 @@ class TestDataframe(unittest.TestCase):
iface = gr.Interface(check_odd, "numpy", "numpy")
self.assertEqual(
(await iface.process([{"data": [[2, 3, 4]]}]))[0],
(await iface([{"data": [[2, 3, 4]]}]))[0],
{"data": [[True, False, True]], "headers": [1, 2, 3]},
)
@ -1196,21 +1064,13 @@ class TestDataframe(unittest.TestCase):
class TestVideo(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, serialize, deserialize, save_flagged, restore_flagged, generate_sample, get_config
Preprocess, serialize, deserialize, generate_sample, get_config
"""
x_video = deepcopy(media_data.BASE64_VIDEO)
video_input = gr.Video()
output = video_input.preprocess(x_video)
self.assertIsInstance(output, str)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = video_input.save_flagged(tmpdirname, "video_input", x_video, None)
self.assertEqual("video_input/0.mp4", to_save)
to_save = video_input.save_flagged(tmpdirname, "video_input", x_video, None)
self.assertEqual("video_input/1.mp4", to_save)
restored = video_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored["file_name"], "video_input/1.mp4")
self.assertIsInstance(video_input.generate_sample(), dict)
video_input = gr.Video(label="Upload Your Video")
self.assertEqual(
@ -1234,30 +1094,23 @@ class TestVideo(unittest.TestCase):
video_input = gr.Video(format="avi")
self.assertEqual(video_input.preprocess(x_video)[-3:], "avi")
self.assertEqual(
video_input.serialize(x_video["name"], True)["data"], x_video["data"]
assert filecmp.cmp(
video_input.serialize(x_video["name"])["name"], x_video["name"]
)
# Output functionalities
y_vid_path = "test/test_files/video_sample.mp4"
video_output = gr.Video()
self.assertTrue(video_output.postprocess(y_vid_path)["name"].endswith("mp4"))
self.assertTrue(
video_output.postprocess(y_vid_path)["data"].startswith(
"data:video/mp4;base64,"
)
video_output.deserialize(
{
"name": None,
"data": deepcopy(media_data.BASE64_VIDEO)["data"],
"is_file": False,
}
).endswith(".mp4")
)
self.assertTrue(
video_output.deserialize(deepcopy(media_data.BASE64_VIDEO)).endswith(".mp4")
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = video_output.save_flagged(
tmpdirname, "video_output", deepcopy(media_data.BASE64_VIDEO), None
)
self.assertEqual("video_output/0.mp4", to_save)
to_save = video_output.save_flagged(
tmpdirname, "video_output", deepcopy(media_data.BASE64_VIDEO), None
)
self.assertEqual("video_output/1.mp4", to_save)
async def test_in_interface(self):
"""
@ -1265,13 +1118,13 @@ class TestVideo(unittest.TestCase):
"""
x_video = deepcopy(media_data.BASE64_VIDEO)
iface = gr.Interface(lambda x: x, "video", "playable_video")
self.assertEqual((await iface.process([x_video]))[0]["data"], x_video["data"])
self.assertEqual((await iface([x_video]))[0]["data"], x_video["data"])
class TestTimeseries(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, postprocess, save_flagged, restore_flagged, generate_sample, get_config,
Preprocess, postprocess, generate_sample, get_config,
"""
timeseries_input = gr.Timeseries(x="time", y=["retail", "food", "other"])
x_timeseries = {
@ -1281,14 +1134,6 @@ class TestTimeseries(unittest.TestCase):
output = timeseries_input.preprocess(x_timeseries)
self.assertIsInstance(output, pd.core.frame.DataFrame)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = timeseries_input.save_flagged(
tmpdirname, "video_input", x_timeseries, None
)
self.assertEqual(json.dumps(x_timeseries), to_save)
restored = timeseries_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(x_timeseries, restored)
self.assertIsInstance(timeseries_input.generate_sample(), dict)
timeseries_input = gr.Timeseries(
x="time", y="retail", label="Upload Your Timeseries"
@ -1353,28 +1198,6 @@ class TestTimeseries(unittest.TestCase):
},
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = timeseries_output.save_flagged(
tmpdirname, "timeseries_output", output, None
)
self.assertEqual(
to_save,
'{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], '
'["jack", 18]]}',
)
self.assertEqual(
timeseries_output.restore_flagged(tmpdirname, to_save, None),
{
"headers": ["Name", "Age"],
"data": [
["Tom", 20],
["nick", 21],
["krish", 19],
["jack", 18],
],
},
)
async def test_in_interface_as_input(self):
"""
Interface, process
@ -1386,7 +1209,7 @@ class TestTimeseries(unittest.TestCase):
}
iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
self.assertEqual(
await iface.process([x_timeseries]),
await iface([x_timeseries]),
[
{
"headers": ["time", "retail", "food", "other"],
@ -1417,7 +1240,7 @@ class TestTimeseries(unittest.TestCase):
)
}
self.assertEqual(
await iface.process([df]),
await iface([df]),
[
{
"headers": ["time", "retail", "food", "other"],
@ -1443,19 +1266,15 @@ class TestNames(unittest.TestCase):
class TestLabel(unittest.TestCase):
def test_component_functions(self):
"""
Process, postprocess, deserialize, save_flagged, restore_flagged
Process, postprocess, deserialize
"""
y = "happy"
label_output = gr.Label()
label = label_output.postprocess(y)
self.assertDictEqual(label, {"label": "happy"})
self.assertEqual(label_output.deserialize(y), y)
self.assertEqual(label_output.deserialize(label), y)
with tempfile.TemporaryDirectory() as tmpdir:
to_save = label_output.save_flagged(tmpdir, "label_output", label, None)
self.assertEqual(to_save, y)
y = {3: 0.7, 1: 0.2, 0: 0.1}
label_output = gr.Label()
self.assertEqual(json.load(open(label_output.deserialize(label))), label)
y = {3: 0.7, 1: 0.2, 0: 0.1}
label = label_output.postprocess(y)
self.assertDictEqual(
label,
@ -1483,20 +1302,6 @@ class TestLabel(unittest.TestCase):
with self.assertRaises(ValueError):
label_output.postprocess([1, 2, 3])
with tempfile.TemporaryDirectory() as tmpdir:
to_save = label_output.save_flagged(tmpdir, "label_output", label, None)
self.assertEqual(to_save, '{"3": 0.7, "1": 0.2}')
self.assertEqual(
label_output.restore_flagged(tmpdir, to_save, None),
{
"label": "3",
"confidences": [
{"label": "3", "confidence": 0.7},
{"label": "1", "confidence": 0.2},
],
},
)
self.assertEqual(
label_output.get_config(),
{
@ -1529,7 +1334,7 @@ class TestLabel(unittest.TestCase):
}
iface = gr.Interface(rgb_distribution, "image", "label")
output = (await iface.process([x_img]))[0]
output = (await iface([x_img]))[0]
self.assertDictEqual(
output,
{
@ -1581,7 +1386,7 @@ class TestHighlightedText(unittest.TestCase):
def test_component_functions(self):
"""
get_config, save_flagged, restore_flagged
get_config
"""
ht_output = gr.HighlightedText(color_map={"pos": "green", "neg": "red"})
self.assertEqual(
@ -1599,14 +1404,6 @@ class TestHighlightedText(unittest.TestCase):
"interactive": None,
},
)
ht = {"pos": "Hello ", "neg": "World"}
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = ht_output.save_flagged(tmpdirname, "ht_output", ht, None)
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
self.assertEqual(
ht_output.restore_flagged(tmpdirname, to_save, None),
{"pos": "Hello ", "neg": "World"},
)
async def test_in_interface(self):
"""
@ -1630,7 +1427,7 @@ class TestHighlightedText(unittest.TestCase):
iface = gr.Interface(highlight_vowels, "text", "highlight")
self.assertListEqual(
(await iface.process(["Helloooo"]))[0],
(await iface(["Helloooo"]))[0],
[("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")],
)
@ -1638,20 +1435,12 @@ class TestHighlightedText(unittest.TestCase):
class TestJSON(unittest.TestCase):
def test_component_functions(self):
"""
Postprocess, save_flagged, restore_flagged
Postprocess
"""
js_output = gr.JSON()
self.assertTrue(
js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"'
)
js = {"pos": "Hello ", "neg": "World"}
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = js_output.save_flagged(tmpdirname, "js_output", js, None)
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
self.assertEqual(
js_output.restore_flagged(tmpdirname, to_save, None),
{"pos": "Hello ", "neg": "World"},
)
self.assertEqual(
js_output.get_config(),
{
@ -1692,7 +1481,7 @@ class TestJSON(unittest.TestCase):
["F", 30],
]
self.assertDictEqual(
(await iface.process([{"data": y_data, "headers": ["gender", "age"]}]))[0],
(await iface([{"data": y_data, "headers": ["gender", "age"]}]))[0],
{"M": 35, "F": 25, "O": 20},
)
@ -1726,7 +1515,7 @@ class TestHTML(unittest.TestCase):
return "<strong>" + text + "</strong>"
iface = gr.Interface(bold_text, "text", "html")
self.assertEqual((await iface.process(["test"]))[0], "<strong>test</strong>")
self.assertEqual((await iface(["test"]))[0], "<strong>test</strong>")
class TestModel3D(unittest.TestCase):
@ -1734,11 +1523,11 @@ class TestModel3D(unittest.TestCase):
"""
get_config
"""
component = gr.components.Model3D("test/test_files/Box.gltf", label="Model")
component = gr.components.Model3D(None, label="Model")
self.assertEqual(
{
"clearColor": [0.2, 0.2, 0.2, 1.0],
"value": media_data.BASE64_MODEL3D,
"value": None,
"label": "Model",
"show_label": True,
"interactive": None,
@ -1756,20 +1545,20 @@ class TestModel3D(unittest.TestCase):
"""
iface = gr.Interface(lambda x: x, "model3d", "model3d")
input_data = gr.media_data.BASE64_MODEL3D["data"]
output_data = (await iface.process([{"name": "Box.gltf", "data": input_data}]))[
0
]["data"]
output_data = (await iface([{"name": "Box.gltf", "data": input_data}]))[0][
"data"
]
self.assertEqual(input_data.split(";")[1], output_data.split(";")[1])
class TestColorPicker(unittest.TestCase):
def test_component_functions(self):
"""
Preprocess, postprocess, serialize, save_flagged, restore_flagged, tokenize, generate_sample, get_config
Preprocess, postprocess, serialize, tokenize, generate_sample, get_config
"""
color_picker_input = gr.ColorPicker()
self.assertEqual(color_picker_input.preprocess("#000000"), "#000000")
self.assertEqual(color_picker_input.preprocess_example("#000000"), "#000000")
self.assertEqual(color_picker_input.postprocess("#000000"), "#000000")
self.assertEqual(color_picker_input.postprocess(None), None)
self.assertEqual(color_picker_input.postprocess("#FFFFFF"), "#FFFFFF")
self.assertEqual(color_picker_input.serialize("#000000", True), "#000000")
@ -1796,7 +1585,7 @@ class TestColorPicker(unittest.TestCase):
Interface, process, interpret,
"""
iface = gr.Interface(lambda x: x, "colorpicker", "colorpicker")
self.assertEqual(await iface.process(["#000000"]), ["#000000"])
self.assertEqual(await iface(["#000000"]), ["#000000"])
async def test_in_interface_as_output(self):
"""
@ -1804,7 +1593,7 @@ class TestColorPicker(unittest.TestCase):
"""
iface = gr.Interface(lambda x: x, "colorpicker", gr.ColorPicker())
self.assertEqual(await iface.process(["#000000"]), ["#000000"])
self.assertEqual(await iface(["#000000"]), ["#000000"])
def test_static(self):
"""
@ -1815,7 +1604,7 @@ class TestColorPicker(unittest.TestCase):
@patch("uuid.uuid4", return_value="my-uuid")
def test_gallery_save_and_restore_flagged(my_uuid, tmp_path):
def test_gallery(mock_uuid):
gallery = gr.Gallery()
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
data = [
@ -1826,13 +1615,12 @@ def test_gallery_save_and_restore_flagged(my_uuid, tmp_path):
pathlib.Path(test_file_dir, "cheetah1.jpg")
),
]
label = "Gallery, 1"
path = gallery.save_flagged(str(tmp_path), label, data, encryption_key=None)
assert path == os.path.join("Gallery 1", "my-uuid")
assert sorted(os.listdir(os.path.join(tmp_path, path))) == ["0.png", "1.jpg"]
data_restored = gallery.restore_flagged(tmp_path, path, encryption_key=None)
assert data == data_restored
with tempfile.TemporaryDirectory() as tmpdir:
path = gallery.deserialize(data, tmpdir)
assert path.endswith("my-uuid")
data_restored = gallery.serialize(path)
assert sorted(data) == sorted(data_restored)
@patch("gradio.Slider.get_random_value", return_value=7)

View File

@ -62,9 +62,9 @@ class TestExamplesDataset:
class TestProcessExamples:
@pytest.mark.asyncio
async def test_process_example(self):
async def test_predict_example(self):
io = gr.Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
prediction = await io.examples_handler.process_example(0)
prediction = await io.examples_handler.predict_example(0)
assert prediction[0] == "Hello World"
@pytest.mark.asyncio
@ -73,7 +73,7 @@ class TestProcessExamples:
return "Hello " + x
io = gr.Interface(coroutine, "text", "text", examples=[["World"]])
prediction = await io.examples_handler.process_example(0)
prediction = await io.examples_handler.predict_example(0)
assert prediction[0] == "Hello World"
@pytest.mark.asyncio

View File

@ -1,3 +1,4 @@
import json
import os
import pathlib
import unittest
@ -179,7 +180,7 @@ class TestLoadInterface(unittest.TestCase):
io = gr.Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english")
try:
output = io("I am happy, I love you")
self.assertGreater(output["POSITIVE"], 0.5)
assert json.load(open(output))["label"] == "POSITIVE"
except TooManyRequestsError:
pass
@ -187,7 +188,7 @@ class TestLoadInterface(unittest.TestCase):
io = gr.Blocks.load(name="models/google/vit-base-patch16-224")
try:
output = io("gradio/test_data/lion.jpg")
self.assertGreater(output["lion"], 0.5)
assert json.load(open(output))["label"] == "lion"
except TooManyRequestsError:
pass
@ -203,7 +204,7 @@ class TestLoadInterface(unittest.TestCase):
io = gr.Interface.load("spaces/abidlabs/titanic-survival")
try:
output = io("male", 77, 10)
self.assertLess(output["Survives"], 0.5)
assert json.load(open(output))["label"] == "Perishes"
except TooManyRequestsError:
pass

View File

@ -1,3 +1,4 @@
import json
import os
import unittest
@ -16,42 +17,38 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestSeries:
@pytest.mark.asyncio
async def test_in_interface(self):
def test_in_interface(self):
io1 = gr.Interface(lambda x: x + " World", "textbox", gr.Textbox())
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.Textbox())
series = mix.Series(io1, io2)
assert await series.process(["Hello"]) == ["Hello World!"]
assert series("Hello") == "Hello World!"
# @pytest.mark.asyncio
# @pytest.mark.flaky
# async def test_with_external(self):
# io1 = gr.Interface.load("spaces/abidlabs/image-identity")
# io2 = gr.Interface.load("spaces/abidlabs/image-classifier")
# series = mix.Series(io1, io2)
# try:
# output = series("gradio/test_data/lion.jpg")
# assert output["lion"] > 0.5
# except TooManyRequestsError:
# pass
@pytest.mark.flaky
def test_with_external(self):
io1 = gr.Interface.load("spaces/abidlabs/image-identity")
io2 = gr.Interface.load("spaces/abidlabs/image-classifier")
series = mix.Series(io1, io2)
try:
output = series("gradio/test_data/lion.jpg")
assert json.load(open(output))["label"] == "lion"
except TooManyRequestsError:
pass
class TestParallel:
@pytest.mark.asyncio
async def test_in_interface(self):
def test_in_interface(self):
io1 = gr.Interface(lambda x: x + " World 1!", "textbox", gr.Textbox())
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.Textbox())
parallel = mix.Parallel(io1, io2)
assert await parallel.process(["Hello"]) == ["Hello World 1!", "Hello World 2!"]
assert parallel("Hello") == ["Hello World 1!", "Hello World 2!"]
@pytest.mark.asyncio
async def test_multiple_return_in_interface(self):
def test_multiple_return_in_interface(self):
io1 = gr.Interface(
lambda x: (x, x + x), "textbox", [gr.Textbox(), gr.Textbox()]
)
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.Textbox())
parallel = mix.Parallel(io1, io2)
assert await parallel.process(["Hello"]) == [
assert parallel("Hello") == [
"Hello",
"HelloHello",
"Hello World 2!",

View File

@ -66,7 +66,6 @@ class TestRoutes(unittest.TestCase):
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
)
output = dict(response.json())
print("output", output)
self.assertEqual(output["data"], ["test", None])
response = client.post(
"/api/predict/",

View File

@ -1,4 +1,5 @@
import copy
import json
import os
import unittest
import unittest.mock as mock
@ -25,9 +26,10 @@ from gradio.utils import (
format_ner_list,
get_local_ip_address,
ipython_check,
json,
launch_analytics,
readme_to_html,
sanitize_list_for_csv,
sanitize_value_for_csv,
version_check,
)
@ -467,5 +469,25 @@ async def test_validate_and_fail_with_function(respx_mock):
assert client_response.exception is not None
class TestSanitizeForCSV:
def test_unsafe_value(self):
assert sanitize_value_for_csv("=OPEN()") == "'=OPEN()"
assert sanitize_value_for_csv("=1+2") == "'=1+2"
assert sanitize_value_for_csv('=1+2";=1+2') == "'=1+2\";=1+2"
def test_safe_value(self):
assert sanitize_value_for_csv(4) == 4
assert sanitize_value_for_csv(-44.44) == -44.44
assert sanitize_value_for_csv("1+1=2") == "1+1=2"
assert sanitize_value_for_csv("1aaa2") == "1aaa2"
def test_list(self):
assert sanitize_list_for_csv([4, "def=", "=gh+ij"]) == [4, "def=", "'=gh+ij"]
assert sanitize_list_for_csv(
[["=abc", "def", "gh,+ij"], ["abc", "=def", "+ghij"]]
) == [["'=abc", "def", "'gh,+ij"], ["abc", "'=def", "'+ghij"]]
assert sanitize_list_for_csv([1, ["ab", "=de"]]) == [1, ["ab", "'=de"]]
if __name__ == "__main__":
unittest.main()

View File

@ -57,10 +57,16 @@
if (!value) return;
let base64_model_content = value["data"];
let raw_content = BABYLON.Tools.DecodeBase64(base64_model_content);
let blob = new Blob([raw_content]);
let url = URL.createObjectURL(blob);
let url: string;
if (value.is_file) {
url = value.data;
} else {
let base64_model_content = value.data;
let raw_content = BABYLON.Tools.DecodeBase64(base64_model_content);
let blob = new Blob([raw_content]);
url = URL.createObjectURL(blob);
}
BABYLON.SceneLoader.Append(
"",
url,

View File

@ -20,7 +20,7 @@
});
afterUpdate(() => {
if (value != null && value.is_example) {
if (value != null && value.is_file) {
addNewModel();
}
});
@ -80,7 +80,7 @@
if (!value) return;
let url: string;
if (value.is_example) {
if (value.is_file) {
url = value.data;
} else {
let base64_model_content = value.data;

View File

@ -2,5 +2,5 @@ export interface FileData {
name: string;
size?: number;
data: string;
is_example?: boolean;
is_file?: boolean;
}

View File

@ -10,7 +10,7 @@ export function normalise_file(
name: "file_data",
data: file
};
} else if (file.is_example) {
} else if (file.is_file) {
file.data = root + "file/" + file.name;
}
return file;

View File

@ -1,4 +1,4 @@
lockfileVersion: 5.3
lockfileVersion: 5.4
importers:
@ -43,7 +43,7 @@ importers:
'@tailwindcss/forms': 0.5.0_tailwindcss@3.1.6
'@testing-library/dom': 8.11.3
'@testing-library/svelte': 3.1.0_svelte@3.49.0
'@testing-library/user-event': 13.5.0_@testing-library+dom@8.11.3
'@testing-library/user-event': 13.5.0_gzufz4q333be4gqfrvipwvqt6a
autoprefixer: 10.4.4_postcss@8.4.6
babylonjs: 5.18.0
babylonjs-loaders: 5.18.0
@ -56,13 +56,13 @@ importers:
postcss: 8.4.6
postcss-nested: 5.0.6_postcss@8.4.6
prettier: 2.6.2
prettier-plugin-svelte: 2.7.0_prettier@2.6.2+svelte@3.49.0
prettier-plugin-svelte: 2.7.0_3cyj5wbackxvw67rnaarcmbw7y
sirv: 2.0.2
sirv-cli: 2.0.2
svelte: 3.49.0
svelte-check: 2.8.0_postcss@8.4.6+svelte@3.49.0
svelte-check: 2.8.0_mgmdnb6x5rpawk37gozc2sbtta
svelte-i18n: 3.3.13_svelte@3.49.0
svelte-preprocess: 4.10.6_62d50a01257de5eec5be08cad9d3ed66
svelte-preprocess: 4.10.6_mlkquajfpxs65rn6bdfntu7nmy
tailwindcss: 3.1.6
tinyspy: 0.3.0
typescript: 4.7.4
@ -131,7 +131,7 @@ importers:
'@gradio/video': link:../video
mime-types: 2.1.34
playwright: 1.22.2
svelte-i18n: 3.3.13_svelte@3.49.0
svelte-i18n: 3.3.13
packages/atoms:
specifiers:
@ -385,13 +385,13 @@ importers:
'@gradio/upload': link:../upload
'@gradio/video': link:../video
devDependencies:
'@sveltejs/adapter-auto': 1.0.0-next.64
'@sveltejs/kit': 1.0.0-next.318_svelte@3.49.0
'@sveltejs/adapter-auto': 1.0.0-next.65
'@sveltejs/kit': 1.0.0-next.318
autoprefixer: 10.4.2_postcss@8.4.6
postcss: 8.4.6
postcss-load-config: 3.1.1
svelte-check: 2.4.1_736abba5ed1eb6f8ecf70b1d49ead14b
svelte-preprocess: 4.10.2_d50790bb30dd88cc44babe7efa52bace
svelte-check: 2.4.1_2y4otvh2n6klv6metqycpfiuzy
svelte-preprocess: 4.10.2_bw7ic75prjd4umr4fb55sbospu
tailwindcss: 3.0.23_autoprefixer@10.4.2
tslib: 2.3.1
typescript: 4.5.5
@ -565,12 +565,12 @@ packages:
estree-walker: 2.0.2
picomatch: 2.3.1
/@sveltejs/adapter-auto/1.0.0-next.64:
resolution: {integrity: sha512-Q8DwcS6wl1GovzS9JJzaD/WL/Lfk1ur4nAF1HtmsUvZDpsPBVDqnK2AhYU4G3oFNiuHstrjAogMy5th8ptSFGw==}
/@sveltejs/adapter-auto/1.0.0-next.65:
resolution: {integrity: sha512-wYEcOeuCrswcmeOdmbaq+WxTp7vWN1fG1yPvmdzqe2LoUchOw6FQb6X/fR8miX6L8MXQXJteA0ntqE3FKqaBsw==}
dependencies:
'@sveltejs/adapter-cloudflare': 1.0.0-next.31
'@sveltejs/adapter-netlify': 1.0.0-next.71
'@sveltejs/adapter-vercel': 1.0.0-next.66
'@sveltejs/adapter-vercel': 1.0.0-next.67
transitivePeerDependencies:
- encoding
- supports-color
@ -593,8 +593,8 @@ packages:
tiny-glob: 0.2.9
dev: true
/@sveltejs/adapter-vercel/1.0.0-next.66:
resolution: {integrity: sha512-s3Hcxu9nCG/rR3C3cFbdQGjTa5W4K2kRcc6S5Xefx7itbrw+4v3KpO8ZPB6qM55XDwVxuG7260NMHVI6MUGmSA==}
/@sveltejs/adapter-vercel/1.0.0-next.67:
resolution: {integrity: sha512-xg85d/vlivbTaZu70zmaPNkrY1YZhDrcxljuwVWO0LCzA4DACIA7CnXI9klUiXM5SPpsB8BhY6dS8sW5cDYWzw==}
dependencies:
'@vercel/nft': 0.21.0
esbuild: 0.14.53
@ -603,16 +603,15 @@ packages:
- supports-color
dev: true
/@sveltejs/kit/1.0.0-next.318_svelte@3.49.0:
/@sveltejs/kit/1.0.0-next.318:
resolution: {integrity: sha512-/M/XNvEqK71KCGro1xLuiUuklsMPe+G5DiVMs39tpfFIFhH4oCzAt+YBaIZDKORogGz3QDaYc5BV+eFv9E5cyw==}
engines: {node: '>=14.13'}
hasBin: true
peerDependencies:
svelte: ^3.44.0
dependencies:
'@sveltejs/vite-plugin-svelte': 1.0.0-next.41_svelte@3.49.0+vite@2.9.5
'@sveltejs/vite-plugin-svelte': 1.0.0-next.41_vite@2.9.5
sade: 1.8.1
svelte: 3.49.0
vite: 2.9.5
transitivePeerDependencies:
- diff-match-patch
@ -622,7 +621,7 @@ packages:
- supports-color
dev: true
/@sveltejs/vite-plugin-svelte/1.0.0-next.41_svelte@3.49.0+vite@2.9.5:
/@sveltejs/vite-plugin-svelte/1.0.0-next.41_vite@2.9.5:
resolution: {integrity: sha512-2kZ49mpi/YW1PIPvKaJNSSwIFgmw9QUf1+yaNa4U8yJD6AsfSHXAU3goscWbi1jfWnSg2PhvwAf+bvLCdp2F9g==}
engines: {node: ^14.13.1 || >= 16}
peerDependencies:
@ -637,8 +636,7 @@ packages:
debug: 4.3.4
kleur: 4.1.4
magic-string: 0.26.1
svelte: 3.49.0
svelte-hmr: 0.14.11_svelte@3.49.0
svelte-hmr: 0.14.11
vite: 2.9.5
transitivePeerDependencies:
- supports-color
@ -714,7 +712,7 @@ packages:
svelte: 3.49.0
dev: false
/@testing-library/user-event/13.5.0_@testing-library+dom@8.11.3:
/@testing-library/user-event/13.5.0_gzufz4q333be4gqfrvipwvqt6a:
resolution: {integrity: sha512-5Kwtbo3Y/NowpkbRuSepbyMFkZmHgD+vPzYB/RJ4oxt5Gj/avFFBYjhw27cqSVPVw/3a67NK1PbiIr9k4Gwmdg==}
engines: {node: '>=10', npm: '>=6'}
peerDependencies:
@ -2883,7 +2881,7 @@ packages:
picocolors: 1.0.0
source-map-js: 1.0.2
/prettier-plugin-svelte/2.7.0_prettier@2.6.2+svelte@3.49.0:
/prettier-plugin-svelte/2.7.0_3cyj5wbackxvw67rnaarcmbw7y:
resolution: {integrity: sha512-fQhhZICprZot2IqEyoiUYLTRdumULGRvw0o4dzl5jt0jfzVWdGqeYW27QTWAeXhoupEZJULmNoH3ueJwUWFLIA==}
peerDependencies:
prettier: ^1.16.4 || ^2.0.0
@ -3287,7 +3285,7 @@ packages:
resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==}
engines: {node: '>= 0.4'}
/svelte-check/2.4.1_736abba5ed1eb6f8ecf70b1d49ead14b:
/svelte-check/2.4.1_2y4otvh2n6klv6metqycpfiuzy:
resolution: {integrity: sha512-xhf3ShP5rnRwBokrgTBJ/0cO9QIc1DAVu1NWNRTfCDsDBNjGmkS3HgitgUadRuoMKj1+irZR/yHJ+Uqobnkbrw==}
hasBin: true
peerDependencies:
@ -3300,8 +3298,7 @@ packages:
picocolors: 1.0.0
sade: 1.8.1
source-map: 0.7.3
svelte: 3.49.0
svelte-preprocess: 4.10.2_d50790bb30dd88cc44babe7efa52bace
svelte-preprocess: 4.10.2_bw7ic75prjd4umr4fb55sbospu
typescript: 4.5.5
transitivePeerDependencies:
- '@babel/core'
@ -3316,7 +3313,7 @@ packages:
- sugarss
dev: true
/svelte-check/2.8.0_postcss@8.4.6+svelte@3.49.0:
/svelte-check/2.8.0_mgmdnb6x5rpawk37gozc2sbtta:
resolution: {integrity: sha512-HRL66BxffMAZusqe5I5k26mRWQ+BobGd9Rxm3onh7ZVu0nTk8YTKJ9vu3LVPjUGLU9IX7zS+jmwPVhJYdXJ8vg==}
hasBin: true
peerDependencies:
@ -3329,7 +3326,7 @@ packages:
picocolors: 1.0.0
sade: 1.8.1
svelte: 3.49.0
svelte-preprocess: 4.10.6_62d50a01257de5eec5be08cad9d3ed66
svelte-preprocess: 4.10.6_mlkquajfpxs65rn6bdfntu7nmy
typescript: 4.7.4
transitivePeerDependencies:
- '@babel/core'
@ -3344,6 +3341,13 @@ packages:
- sugarss
dev: false
/svelte-hmr/0.14.11:
resolution: {integrity: sha512-R9CVfX6DXxW1Kn45Jtmx+yUe+sPhrbYSUp7TkzbW0jI5fVPn6lsNG9NEs5dFg5qRhFNAoVdRw5qQDLALNKhwbQ==}
engines: {node: ^12.20 || ^14.13.1 || >= 16}
peerDependencies:
svelte: '>=3.19.0'
dev: true
/svelte-hmr/0.14.11_svelte@3.49.0:
resolution: {integrity: sha512-R9CVfX6DXxW1Kn45Jtmx+yUe+sPhrbYSUp7TkzbW0jI5fVPn6lsNG9NEs5dFg5qRhFNAoVdRw5qQDLALNKhwbQ==}
engines: {node: ^12.20 || ^14.13.1 || >= 16}
@ -3351,6 +3355,21 @@ packages:
svelte: '>=3.19.0'
dependencies:
svelte: 3.49.0
dev: false
/svelte-i18n/3.3.13:
resolution: {integrity: sha512-RQM+ys4+Y9ztH//tX22H1UL2cniLNmIR+N4xmYygV6QpQ6EyQvloZiENRew8XrVzfvJ8HaE8NU6/yurLkl7z3g==}
engines: {node: '>= 11.15.0'}
hasBin: true
peerDependencies:
svelte: ^3.25.1
dependencies:
deepmerge: 4.2.2
estree-walker: 2.0.2
intl-messageformat: 9.11.4
sade: 1.8.1
tiny-glob: 0.2.9
dev: false
/svelte-i18n/3.3.13_svelte@3.49.0:
resolution: {integrity: sha512-RQM+ys4+Y9ztH//tX22H1UL2cniLNmIR+N4xmYygV6QpQ6EyQvloZiENRew8XrVzfvJ8HaE8NU6/yurLkl7z3g==}
@ -3367,7 +3386,7 @@ packages:
tiny-glob: 0.2.9
dev: false
/svelte-preprocess/4.10.2_d50790bb30dd88cc44babe7efa52bace:
/svelte-preprocess/4.10.2_bw7ic75prjd4umr4fb55sbospu:
resolution: {integrity: sha512-aPpkCreSo8EL/y8kJSa1trhiX0oyAtTjlNNM7BNjRAsMJ8Yy2LtqHt0zyd4pQPXt+D4PzbO3qTjjio3kwOxDlA==}
engines: {node: '>= 9.11.2'}
requiresBuild: true
@ -3416,11 +3435,10 @@ packages:
postcss-load-config: 3.1.1
sorcery: 0.10.0
strip-indent: 3.0.0
svelte: 3.49.0
typescript: 4.5.5
dev: true
/svelte-preprocess/4.10.6_62d50a01257de5eec5be08cad9d3ed66:
/svelte-preprocess/4.10.6_mlkquajfpxs65rn6bdfntu7nmy:
resolution: {integrity: sha512-I2SV1w/AveMvgIQlUF/ZOO3PYVnhxfcpNyGt8pxpUVhPfyfL/CZBkkw/KPfuFix5FJ9TnnNYMhACK3DtSaYVVQ==}
engines: {node: '>= 9.11.2'}
requiresBuild: true
@ -3479,6 +3497,7 @@ packages:
/svelte/3.49.0:
resolution: {integrity: sha512-+lmjic1pApJWDfPCpUUTc1m8azDqYCG1JN9YEngrx/hUyIcFJo6VZhj0A1Ai0wqoHcEIuQy+e9tk+4uDgdtsFA==}
engines: {node: '>= 8'}
dev: false
/sync-request/6.1.0:
resolution: {integrity: sha512-8fjNkrNlNCrVc/av+Jn+xxqfCjYaBoHqCsDz6mt030UMxJGr+GSfCV1dQt2gRtlL63+VPidwDVLr7V2OcTSdRw==}