mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
Merge pull request #560 from LysandreJik/upgrade-black
Upgrade black to version ~=22.0
This commit is contained in:
commit
e19eb0dc3c
@ -15,7 +15,7 @@ def plot_forecast(final_year, companies, noise, show_legend, point_style):
|
||||
ax = fig.add_subplot(111)
|
||||
for i, company in enumerate(companies):
|
||||
series = np.arange(0, year_count, dtype=float)
|
||||
series = series ** 2 * (i + 1)
|
||||
series = series**2 * (i + 1)
|
||||
series += np.random.rand(year_count) * noise
|
||||
ax.plot(x, series, plt_format)
|
||||
if show_legend:
|
||||
|
@ -2,9 +2,11 @@ import gradio as gr
|
||||
|
||||
user_db = {"admin": "admin", "foo": "bar"}
|
||||
|
||||
|
||||
def greet(name):
|
||||
return "Hello " + name + "!!"
|
||||
|
||||
|
||||
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
|
||||
if __name__ == "__main__":
|
||||
iface.launch(auth=lambda u, p: user_db.get(u) == p)
|
||||
|
@ -1,8 +1,10 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def greet(name):
|
||||
return "Hello " + name + "!!"
|
||||
|
||||
|
||||
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -1,8 +1,10 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def greet(name):
|
||||
return "Hello " + name + "!"
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
fn=greet,
|
||||
inputs=gr.inputs.Textbox(lines=2, placeholder="Name Here..."),
|
||||
|
@ -1,11 +1,13 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def greet(name, is_morning, temperature):
|
||||
salutation = "Good morning" if is_morning else "Good evening"
|
||||
greeting = "%s %s. It is %s degrees today" % (salutation, name, temperature)
|
||||
celsius = (temperature - 32) * 5 / 9
|
||||
return greeting, round(celsius, 2)
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
fn=greet,
|
||||
inputs=["text", "checkbox", gr.inputs.Slider(0, 100)],
|
||||
|
@ -86,11 +86,9 @@ iface = gr.Interface(
|
||||
fn,
|
||||
inputs=[
|
||||
gr.inputs.Textbox(default="Lorem ipsum", label="Textbox"),
|
||||
gr.inputs.Textbox(lines=3, placeholder="Type here..",
|
||||
label="Textbox 2"),
|
||||
gr.inputs.Textbox(lines=3, placeholder="Type here..", label="Textbox 2"),
|
||||
gr.inputs.Number(label="Number", default=42),
|
||||
gr.inputs.Slider(minimum=10, maximum=20, default=15,
|
||||
label="Slider: 10 - 20"),
|
||||
gr.inputs.Slider(minimum=10, maximum=20, default=15, label="Slider: 10 - 20"),
|
||||
gr.inputs.Slider(maximum=20, step=0.04, label="Slider: step @ 0.04"),
|
||||
gr.inputs.Checkbox(label="Checkbox"),
|
||||
gr.inputs.CheckboxGroup(
|
||||
@ -99,17 +97,14 @@ iface = gr.Interface(
|
||||
gr.inputs.Radio(label="Radio", choices=CHOICES, default=CHOICES[2]),
|
||||
gr.inputs.Dropdown(label="Dropdown", choices=CHOICES),
|
||||
gr.inputs.Image(label="Image", optional=True),
|
||||
gr.inputs.Image(label="Image w/ Cropper",
|
||||
tool="select", optional=True),
|
||||
gr.inputs.Image(label="Image w/ Cropper", tool="select", optional=True),
|
||||
gr.inputs.Image(label="Sketchpad", source="canvas", optional=True),
|
||||
gr.inputs.Image(label="Webcam", source="webcam", optional=True),
|
||||
gr.inputs.Video(label="Video", optional=True),
|
||||
gr.inputs.Audio(label="Audio", optional=True),
|
||||
gr.inputs.Audio(label="Microphone",
|
||||
source="microphone", optional=True),
|
||||
gr.inputs.Audio(label="Microphone", source="microphone", optional=True),
|
||||
gr.inputs.File(label="File", optional=True),
|
||||
gr.inputs.Dataframe(label="Dataframe", headers=[
|
||||
"Name", "Age", "Gender"]),
|
||||
gr.inputs.Dataframe(label="Dataframe", headers=["Name", "Age", "Gender"]),
|
||||
gr.inputs.Timeseries(x="time", y=["price", "value"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
@ -118,8 +113,9 @@ iface = gr.Interface(
|
||||
gr.outputs.Audio(label="Audio"),
|
||||
gr.outputs.Image(label="Image"),
|
||||
gr.outputs.Video(label="Video"),
|
||||
gr.outputs.HighlightedText(label="HighlightedText", color_map={
|
||||
"punc": "pink", "test 0": "blue"}),
|
||||
gr.outputs.HighlightedText(
|
||||
label="HighlightedText", color_map={"punc": "pink", "test 0": "blue"}
|
||||
),
|
||||
gr.outputs.HighlightedText(label="HighlightedText", show_legend=True),
|
||||
gr.outputs.JSON(label="JSON"),
|
||||
gr.outputs.HTML(label="HTML"),
|
||||
@ -127,8 +123,7 @@ iface = gr.Interface(
|
||||
gr.outputs.Dataframe(label="Dataframe"),
|
||||
gr.outputs.Dataframe(label="Numpy", type="numpy"),
|
||||
gr.outputs.Carousel("image", label="Carousel"),
|
||||
gr.outputs.Timeseries(
|
||||
x="time", y=["price", "value"], label="Timeseries"),
|
||||
gr.outputs.Timeseries(x="time", y=["price", "value"], label="Timeseries"),
|
||||
],
|
||||
examples=[
|
||||
[
|
||||
|
@ -6,11 +6,19 @@ import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForQuestionAnswering, BertTokenizer)
|
||||
from pytorch_transformers import (
|
||||
WEIGHTS_NAME,
|
||||
BertConfig,
|
||||
BertForQuestionAnswering,
|
||||
BertTokenizer,
|
||||
)
|
||||
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
|
||||
from utils import (get_answer, input_to_squad_example,
|
||||
squad_examples_to_features, to_list)
|
||||
from utils import (
|
||||
get_answer,
|
||||
input_to_squad_example,
|
||||
squad_examples_to_features,
|
||||
to_list,
|
||||
)
|
||||
|
||||
RawResult = collections.namedtuple(
|
||||
"RawResult", ["unique_id", "start_logits", "end_logits"]
|
||||
|
@ -5,8 +5,7 @@ import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
|
||||
whitespace_tokenize)
|
||||
from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
|
||||
|
||||
|
||||
|
@ -7,6 +7,7 @@ def reverse_audio(audio):
|
||||
sr, data = audio
|
||||
return (sr, np.flipud(data))
|
||||
|
||||
|
||||
iface = gr.Interface(reverse_audio, "microphone", "audio", examples="audio")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -13,7 +13,7 @@ def stock_forecast(final_year, companies, noise, show_legend, point_style):
|
||||
ax = fig.add_subplot(111)
|
||||
for i, company in enumerate(companies):
|
||||
series = np.arange(0, year_count, dtype=float)
|
||||
series = series ** 2 * (i + 1)
|
||||
series = series**2 * (i + 1)
|
||||
series += np.random.rand(year_count) * noise
|
||||
ax.plot(x, series, plt_format)
|
||||
if show_legend:
|
||||
|
@ -2,7 +2,7 @@ import pkg_resources
|
||||
|
||||
from gradio.routes import get_state, set_state
|
||||
from gradio.flagging import *
|
||||
from gradio.interface import *
|
||||
from gradio.interface import *
|
||||
from gradio.mix import *
|
||||
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
|
@ -34,11 +34,7 @@ class Component:
|
||||
return {}
|
||||
|
||||
def save_flagged(
|
||||
self,
|
||||
dir: str,
|
||||
label: str,
|
||||
data: Any,
|
||||
encryption_key: bool
|
||||
self, dir: str, label: str, data: Any, encryption_key: bool
|
||||
) -> Any:
|
||||
"""
|
||||
Saves flagged data from component
|
||||
@ -52,14 +48,10 @@ class Component:
|
||||
return data
|
||||
|
||||
def save_flagged_file(
|
||||
self,
|
||||
dir: str,
|
||||
label: str,
|
||||
data: Any,
|
||||
encryption_key: bool
|
||||
self, dir: str, label: str, data: Any, encryption_key: bool
|
||||
) -> str:
|
||||
"""
|
||||
Saved flagged data (e.g. image or audio) as a file and returns filepath
|
||||
Saved flagged data (e.g. image or audio) as a file and returns filepath
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
@ -81,9 +73,9 @@ class Component:
|
||||
return label + "/" + new_file_name
|
||||
|
||||
def restore_flagged_file(
|
||||
self,
|
||||
dir: str,
|
||||
file: str,
|
||||
self,
|
||||
dir: str,
|
||||
file: str,
|
||||
encryption_key: bool,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
|
@ -3,18 +3,13 @@ from Crypto.Cipher import AES
|
||||
from Crypto.Hash import SHA256
|
||||
|
||||
|
||||
def get_key(
|
||||
password: str
|
||||
) -> bytes:
|
||||
def get_key(password: str) -> bytes:
|
||||
"""Generates an encryption key based on the password provided."""
|
||||
key = SHA256.new(password.encode()).digest()
|
||||
return key
|
||||
|
||||
|
||||
def encrypt(
|
||||
key: bytes,
|
||||
source: bytes
|
||||
) -> bytes:
|
||||
def encrypt(key: bytes, source: bytes) -> bytes:
|
||||
"""Encrypts source data using the provided encryption key"""
|
||||
IV = Random.new().read(AES.block_size) # generate IV
|
||||
encryptor = AES.new(key, AES.MODE_CBC, IV)
|
||||
@ -24,10 +19,7 @@ def encrypt(
|
||||
return data
|
||||
|
||||
|
||||
def decrypt(
|
||||
key: bytes,
|
||||
source: bytes
|
||||
) -> bytes:
|
||||
def decrypt(key: bytes, source: bytes) -> bytes:
|
||||
IV = source[: AES.block_size] # extract the IV from the beginning
|
||||
decryptor = AES.new(key, AES.MODE_CBC, IV)
|
||||
data = decryptor.decrypt(source[AES.block_size :]) # decrypt
|
||||
|
@ -35,12 +35,13 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
content_type = r.headers.get("content-type")
|
||||
# Case 2: the data prefix is a key in the response
|
||||
if content_type == "application/json":
|
||||
try:
|
||||
try:
|
||||
content_type = r.json()[0]["content-type"]
|
||||
base64_repr = r.json()[0]["blob"]
|
||||
except KeyError:
|
||||
raise ValueError("Cannot determine content type returned"
|
||||
"by external API.")
|
||||
raise ValueError(
|
||||
"Cannot determine content type returned" "by external API."
|
||||
)
|
||||
# Case 3: the data prefix is included in the response headers
|
||||
else:
|
||||
pass
|
||||
@ -66,7 +67,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
"preprocess": lambda i: base64.b64decode(
|
||||
i["data"].split(",")[1]
|
||||
), # convert the base64 representation to binary
|
||||
"postprocess": encode_to_base64,
|
||||
"postprocess": encode_to_base64,
|
||||
},
|
||||
"automatic-speech-recognition": {
|
||||
# example model: https://hf.co/jonatasgrosman/wav2vec2-large-xlsr-53-english
|
||||
|
@ -111,6 +111,7 @@ class CSVLogger(FlaggingCallback):
|
||||
The default implementation of the FlaggingCallback abstract class.
|
||||
Logs the input and output data to a CSV file. Supports encryption.
|
||||
"""
|
||||
|
||||
def setup(self, flagging_dir: str):
|
||||
self.flagging_dir = flagging_dir
|
||||
os.makedirs(flagging_dir, exist_ok=True)
|
||||
@ -323,9 +324,11 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
# Generate the headers and dataset_infos
|
||||
if is_new:
|
||||
headers = []
|
||||
|
||||
|
||||
for i, component in enumerate(interface.input_components):
|
||||
component_label = interface.config["input_components"][i]["label"] or "Input_{}".format(i)
|
||||
component_label = interface.config["input_components"][i][
|
||||
"label"
|
||||
] or "Input_{}".format(i)
|
||||
headers.append(component_label)
|
||||
infos["flagged"]["features"][component_label] = {
|
||||
"dtype": "string",
|
||||
@ -341,7 +344,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
break
|
||||
|
||||
for i, component in enumerate(interface.output_components):
|
||||
component_label = interface.config["output_components"][i]["label"] or "Output_{}".format(i)
|
||||
component_label = interface.config["output_components"][i][
|
||||
"label"
|
||||
] or "Output_{}".format(i)
|
||||
headers.append(component_label)
|
||||
infos["flagged"]["features"][component_label] = {
|
||||
"dtype": "string",
|
||||
@ -368,7 +373,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
# Generate the row corresponding to the flagged sample
|
||||
csv_data = []
|
||||
for i, component in enumerate(interface.input_components):
|
||||
label = interface.config["input_components"][i]["label"] or "Input_{}".format(i)
|
||||
label = interface.config["input_components"][i][
|
||||
"label"
|
||||
] or "Input_{}".format(i)
|
||||
filepath = component.save_flagged(
|
||||
self.dataset_dir, label, input_data[i], None
|
||||
)
|
||||
@ -378,9 +385,13 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
|
||||
)
|
||||
for i, component in enumerate(interface.output_components):
|
||||
label = interface.config["output_components"][i]["label"] or "Output_{}".format(i)
|
||||
label = interface.config["output_components"][i][
|
||||
"label"
|
||||
] or "Output_{}".format(i)
|
||||
filepath = (
|
||||
component.save_flagged(self.dataset_dir, label, output_data[i], None)
|
||||
component.save_flagged(
|
||||
self.dataset_dir, label, output_data[i], None
|
||||
)
|
||||
if output_data[i] is not None
|
||||
else ""
|
||||
)
|
||||
|
@ -221,10 +221,8 @@ class Textbox(InputComponent):
|
||||
"""
|
||||
masked_inputs = []
|
||||
for binary_mask_vector in binary_mask_matrix:
|
||||
masked_input = np.array(tokens)[np.array(
|
||||
binary_mask_vector, dtype=bool)]
|
||||
masked_inputs.append(
|
||||
self.interpretation_separator.join(masked_input))
|
||||
masked_input = np.array(tokens)[np.array(binary_mask_vector, dtype=bool)]
|
||||
masked_inputs.append(self.interpretation_separator.join(masked_input))
|
||||
return masked_inputs
|
||||
|
||||
def get_interpretation_scores(
|
||||
@ -308,10 +306,8 @@ class Number(InputComponent):
|
||||
delta = 1.0 * self.interpretation_delta * x / 100
|
||||
elif self.interpretation_delta_type == "absolute":
|
||||
delta = self.interpretation_delta
|
||||
negatives = (x + np.arange(-self.interpretation_steps, 0)
|
||||
* delta).tolist()
|
||||
positives = (
|
||||
x + np.arange(1, self.interpretation_steps + 1) * delta).tolist()
|
||||
negatives = (x + np.arange(-self.interpretation_steps, 0) * delta).tolist()
|
||||
positives = (x + np.arange(1, self.interpretation_steps + 1) * delta).tolist()
|
||||
return negatives + positives, {}
|
||||
|
||||
def get_interpretation_scores(
|
||||
@ -357,7 +353,7 @@ class Slider(InputComponent):
|
||||
if step is None:
|
||||
difference = maximum - minimum
|
||||
power = math.floor(math.log10(difference) - 2)
|
||||
step = 10 ** power
|
||||
step = 10**power
|
||||
self.step = step
|
||||
self.default = minimum if default is None else default
|
||||
self.test_input = self.default
|
||||
@ -406,8 +402,7 @@ class Slider(InputComponent):
|
||||
|
||||
def get_interpretation_neighbors(self, x) -> List[float]:
|
||||
return (
|
||||
np.linspace(self.minimum, self.maximum,
|
||||
self.interpretation_steps).tolist(),
|
||||
np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(),
|
||||
{},
|
||||
)
|
||||
|
||||
@ -944,8 +939,7 @@ class Image(InputComponent):
|
||||
masked_input = np.zeros_like(tokens[0], dtype=int)
|
||||
for token, b in zip(tokens, binary_mask_vector):
|
||||
masked_input = masked_input + token * int(b)
|
||||
masked_inputs.append(
|
||||
processing_utils.encode_array_to_base64(masked_input))
|
||||
masked_inputs.append(processing_utils.encode_array_to_base64(masked_input))
|
||||
return masked_inputs
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None):
|
||||
@ -1042,10 +1036,8 @@ class Video(InputComponent):
|
||||
file_name = file.name
|
||||
uploaded_format = file_name.split(".")[-1].lower()
|
||||
if self.type is not None and uploaded_format != self.type:
|
||||
output_file_name = file_name[0: file_name.rindex(
|
||||
".") + 1] + self.type
|
||||
ff = FFmpeg(inputs={file_name: None},
|
||||
outputs={output_file_name: None})
|
||||
output_file_name = file_name[0 : file_name.rindex(".") + 1] + self.type
|
||||
ff = FFmpeg(inputs={file_name: None}, outputs={output_file_name: None})
|
||||
ff.run()
|
||||
return output_file_name
|
||||
else:
|
||||
@ -1200,8 +1192,7 @@ class Audio(InputComponent):
|
||||
tokens = []
|
||||
masks = []
|
||||
duration = data.shape[0]
|
||||
boundaries = np.linspace(
|
||||
0, duration, self.interpretation_segments + 1).tolist()
|
||||
boundaries = np.linspace(0, duration, self.interpretation_segments + 1).tolist()
|
||||
boundaries = [round(boundary) for boundary in boundaries]
|
||||
for index in range(len(boundaries) - 1):
|
||||
start, stop = boundaries[index], boundaries[index + 1]
|
||||
@ -1211,8 +1202,7 @@ class Audio(InputComponent):
|
||||
leave_one_out_data = np.copy(data)
|
||||
leave_one_out_data[start:stop] = 0
|
||||
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||
processing_utils.audio_to_file(
|
||||
sample_rate, leave_one_out_data, file.name)
|
||||
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
|
||||
out_data = processing_utils.encode_file_to_base64(file.name)
|
||||
leave_one_out_sets.append(out_data)
|
||||
file.close()
|
||||
@ -1230,8 +1220,9 @@ class Audio(InputComponent):
|
||||
|
||||
tokens.append(token_data)
|
||||
tokens = [{"name": "token.wav", "data": token} for token in tokens]
|
||||
leave_one_out_sets = [{"name": "loo.wav", "data": loo_set}
|
||||
for loo_set in leave_one_out_sets]
|
||||
leave_one_out_sets = [
|
||||
{"name": "loo.wav", "data": loo_set} for loo_set in leave_one_out_sets
|
||||
]
|
||||
return tokens, leave_one_out_sets, masks
|
||||
|
||||
def get_masked_inputs(self, tokens, binary_mask_matrix):
|
||||
@ -1239,7 +1230,7 @@ class Audio(InputComponent):
|
||||
x = tokens[0]["data"]
|
||||
file_obj = processing_utils.decode_base64_to_file(x)
|
||||
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
|
||||
zero_input = np.zeros_like(data, dtype='int16')
|
||||
zero_input = np.zeros_like(data, dtype="int16")
|
||||
# decode all of the tokens
|
||||
token_data = []
|
||||
for token in tokens:
|
||||
@ -1253,8 +1244,7 @@ class Audio(InputComponent):
|
||||
for t, b in zip(token_data, binary_mask_vector):
|
||||
masked_input = masked_input + t * int(b)
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
processing_utils.audio_to_file(
|
||||
sample_rate, masked_input, file.name)
|
||||
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
|
||||
masked_data = processing_utils.encode_file_to_base64(file.name)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
@ -1428,8 +1418,7 @@ class Dataframe(InputComponent):
|
||||
"date": "02/08/1993",
|
||||
}
|
||||
column_dtypes = (
|
||||
[datatype] *
|
||||
self.col_count if isinstance(datatype, str) else datatype
|
||||
[datatype] * self.col_count if isinstance(datatype, str) else datatype
|
||||
)
|
||||
self.test_input = [
|
||||
[sample_values[c] for c in column_dtypes] for _ in range(row_count)
|
||||
|
@ -21,8 +21,14 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
from markdown_it import MarkdownIt
|
||||
from mdit_py_plugins.footnote import footnote_plugin
|
||||
|
||||
from gradio import (encryptor, interpretation, networking, # type: ignore
|
||||
queueing, strings, utils)
|
||||
from gradio import (
|
||||
encryptor,
|
||||
interpretation,
|
||||
networking, # type: ignore
|
||||
queueing,
|
||||
strings,
|
||||
utils,
|
||||
)
|
||||
from gradio.external import load_from_pipeline, load_interface # type: ignore
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
|
||||
from gradio.inputs import InputComponent
|
||||
@ -243,20 +249,25 @@ class Interface:
|
||||
|
||||
self.session = None
|
||||
self.title = title
|
||||
|
||||
CLEANER = re.compile('<.*?>')
|
||||
|
||||
CLEANER = re.compile("<.*?>")
|
||||
|
||||
def clean_html(raw_html):
|
||||
cleantext = re.sub(CLEANER, '', raw_html)
|
||||
cleantext = re.sub(CLEANER, "", raw_html)
|
||||
return cleantext
|
||||
md = MarkdownIt("js-default", {
|
||||
|
||||
md = MarkdownIt(
|
||||
"js-default",
|
||||
{
|
||||
"linkify": True,
|
||||
"typographer": True,
|
||||
"html": True,
|
||||
}).use(footnote_plugin)
|
||||
|
||||
},
|
||||
).use(footnote_plugin)
|
||||
|
||||
simple_description = None
|
||||
if description is not None:
|
||||
description = md.render(description)
|
||||
description = md.render(description)
|
||||
simple_description = clean_html(description)
|
||||
self.simple_description = simple_description
|
||||
self.description = description
|
||||
@ -264,7 +275,7 @@ class Interface:
|
||||
article = utils.readme_to_html(article)
|
||||
article = md.render(article)
|
||||
self.article = article
|
||||
|
||||
|
||||
self.thumbnail = thumbnail
|
||||
theme = theme if theme is not None else os.getenv("GRADIO_THEME", "default")
|
||||
DEPRECATED_THEME_MAP = {
|
||||
|
@ -442,7 +442,6 @@ class Audio(OutputComponent):
|
||||
return {
|
||||
"audio": {},
|
||||
}
|
||||
|
||||
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
@ -453,7 +452,7 @@ class Audio(OutputComponent):
|
||||
"""
|
||||
if self.type in ["numpy", "file", "auto"]:
|
||||
if self.type == "numpy" or (self.type == "auto" and isinstance(y, tuple)):
|
||||
sample_rate, data = y
|
||||
sample_rate, data = y
|
||||
file = tempfile.NamedTemporaryFile(
|
||||
prefix="sample", suffix=".wav", delete=False
|
||||
)
|
||||
|
@ -18,8 +18,7 @@ CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv")
|
||||
|
||||
|
||||
def process_example(
|
||||
interface: Interface,
|
||||
example_id: int
|
||||
interface: Interface, example_id: int
|
||||
) -> Tuple[List[Any], List[float]]:
|
||||
"""Loads an example from the interface and returns its prediction."""
|
||||
example_set = interface.examples[example_id]
|
||||
@ -31,9 +30,7 @@ def process_example(
|
||||
return prediction, durations
|
||||
|
||||
|
||||
def cache_interface_examples(
|
||||
interface: Interface
|
||||
) -> None:
|
||||
def cache_interface_examples(interface: Interface) -> None:
|
||||
"""Caches all of the examples from an interface."""
|
||||
if os.path.exists(CACHE_FILE):
|
||||
print(
|
||||
@ -54,10 +51,7 @@ def cache_interface_examples(
|
||||
raise e
|
||||
|
||||
|
||||
def load_from_cache(
|
||||
interface: Interface,
|
||||
example_id: int
|
||||
) -> List[Any]:
|
||||
def load_from_cache(interface: Interface, example_id: int) -> List[Any]:
|
||||
"""Loads a particular cached example for the interface."""
|
||||
with open(CACHE_FILE) as cache:
|
||||
examples = list(csv.reader(cache))
|
||||
|
@ -138,32 +138,37 @@ def audio_to_file(sample_rate, data, filename):
|
||||
channels=(1 if len(data.shape) == 1 else data.shape[1]),
|
||||
)
|
||||
audio.export(filename, format="wav").close()
|
||||
|
||||
|
||||
|
||||
|
||||
def convert_to_16_bit_wav(data):
|
||||
# Based on: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.write.html
|
||||
if data.dtype==np.float32:
|
||||
warnings.warn("Audio data is not in 16-bit integer format."
|
||||
"Trying to convert to 16-bit int format.")
|
||||
if data.dtype == np.float32:
|
||||
warnings.warn(
|
||||
"Audio data is not in 16-bit integer format."
|
||||
"Trying to convert to 16-bit int format."
|
||||
)
|
||||
data = data / np.abs(data).max()
|
||||
data = data * 32767
|
||||
data = data.astype(np.int16)
|
||||
elif data.dtype==np.int32:
|
||||
warnings.warn("Audio data is not in 16-bit integer format."
|
||||
"Trying to convert to 16-bit int format.")
|
||||
elif data.dtype == np.int32:
|
||||
warnings.warn(
|
||||
"Audio data is not in 16-bit integer format."
|
||||
"Trying to convert to 16-bit int format."
|
||||
)
|
||||
data = data / 65538
|
||||
data = data.astype(np.int16)
|
||||
elif data.dtype==np.int16:
|
||||
elif data.dtype == np.int16:
|
||||
pass
|
||||
elif data.dtype==np.uint8:
|
||||
warnings.warn("Audio data is not in 16-bit integer format."
|
||||
"Trying to convert to 16-bit int format.")
|
||||
elif data.dtype == np.uint8:
|
||||
warnings.warn(
|
||||
"Audio data is not in 16-bit integer format."
|
||||
"Trying to convert to 16-bit int format."
|
||||
)
|
||||
data = data * 257 - 32768
|
||||
data = data.astype(np.int16)
|
||||
else:
|
||||
raise ValueError("Audio data cannot be converted to "
|
||||
"16-bit int format.")
|
||||
return data
|
||||
raise ValueError("Audio data cannot be converted to " "16-bit int format.")
|
||||
return data
|
||||
|
||||
|
||||
##################
|
||||
@ -330,7 +335,7 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
Output image array. Has the same kind as `a`.
|
||||
"""
|
||||
kind = a.dtype.kind
|
||||
if n > m and a.max() < 2 ** m:
|
||||
if n > m and a.max() < 2**m:
|
||||
mnew = int(np.ceil(m / 2) * 2)
|
||||
if mnew > m:
|
||||
dtype = "int{}".format(mnew)
|
||||
@ -353,11 +358,11 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
# exact upscale to a multiple of `n` bits
|
||||
if copy:
|
||||
b = np.empty(a.shape, _dtype_bits(kind, m))
|
||||
np.multiply(a, (2 ** m - 1) // (2 ** n - 1), out=b, dtype=b.dtype)
|
||||
np.multiply(a, (2**m - 1) // (2**n - 1), out=b, dtype=b.dtype)
|
||||
return b
|
||||
else:
|
||||
a = a.astype(_dtype_bits(kind, m, a.dtype.itemsize), copy=False)
|
||||
a *= (2 ** m - 1) // (2 ** n - 1)
|
||||
a *= (2**m - 1) // (2**n - 1)
|
||||
return a
|
||||
else:
|
||||
# upscale to a multiple of `n` bits,
|
||||
@ -365,12 +370,12 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
o = (m // n + 1) * n
|
||||
if copy:
|
||||
b = np.empty(a.shape, _dtype_bits(kind, o))
|
||||
np.multiply(a, (2 ** o - 1) // (2 ** n - 1), out=b, dtype=b.dtype)
|
||||
np.multiply(a, (2**o - 1) // (2**n - 1), out=b, dtype=b.dtype)
|
||||
b //= 2 ** (o - m)
|
||||
return b
|
||||
else:
|
||||
a = a.astype(_dtype_bits(kind, o, a.dtype.itemsize), copy=False)
|
||||
a *= (2 ** o - 1) // (2 ** n - 1)
|
||||
a *= (2**o - 1) // (2**n - 1)
|
||||
a //= 2 ** (o - m)
|
||||
return a
|
||||
|
||||
|
@ -10,9 +10,7 @@ import requests
|
||||
DB_FILE = "gradio_queue.db"
|
||||
|
||||
|
||||
def queue_thread(
|
||||
path_to_local_server: str
|
||||
) -> None:
|
||||
def queue_thread(path_to_local_server: str) -> None:
|
||||
while True:
|
||||
try:
|
||||
next_job = pop()
|
||||
@ -108,10 +106,7 @@ def pop() -> Tuple[int, str, Dict, str]:
|
||||
return result[0], result[1], json.loads(result[2]), result[3]
|
||||
|
||||
|
||||
def push(
|
||||
input_data: Dict,
|
||||
action: str
|
||||
) -> Tuple[str, int]:
|
||||
def push(input_data: Dict, action: str) -> Tuple[str, int]:
|
||||
input_data = json.dumps(input_data)
|
||||
hash = generate_hash()
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
@ -140,7 +135,7 @@ def push(
|
||||
"""
|
||||
)
|
||||
result = c.fetchone()
|
||||
if not(result[0] == 0):
|
||||
if not (result[0] == 0):
|
||||
queue_position += 1
|
||||
conn.commit()
|
||||
return hash, queue_position
|
||||
@ -204,7 +199,7 @@ def get_status(hash: str) -> Tuple[str, int]:
|
||||
"""
|
||||
)
|
||||
result = c.fetchone()
|
||||
if not(result[0] == 0):
|
||||
if not (result[0] == 0):
|
||||
queue_position += 1
|
||||
conn.commit()
|
||||
return "QUEUED", queue_position
|
||||
@ -229,10 +224,7 @@ def start_job(hash: str) -> None:
|
||||
conn.commit()
|
||||
|
||||
|
||||
def fail_job(
|
||||
hash: str,
|
||||
error_message: str
|
||||
) -> None:
|
||||
def fail_job(hash: str, error_message: str) -> None:
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
@ -247,10 +239,7 @@ def fail_job(
|
||||
conn.commit()
|
||||
|
||||
|
||||
def pass_job(
|
||||
hash: str,
|
||||
output_data: Dict
|
||||
) -> None:
|
||||
def pass_job(hash: str, output_data: Dict) -> None:
|
||||
output_data = json.dumps(output_data)
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
|
@ -130,6 +130,7 @@ def static_resource(path: str):
|
||||
return FileResponse(static_file)
|
||||
raise HTTPException(status_code=404, detail="Static file not found")
|
||||
|
||||
|
||||
@app.get("/build/{path:path}")
|
||||
def build_resource(path: str):
|
||||
if app.interface.share:
|
||||
|
@ -22,7 +22,7 @@ attrs==21.4.0
|
||||
# pytest
|
||||
backcall==0.2.0
|
||||
# via ipython
|
||||
black==21.12b0
|
||||
black~=22.1
|
||||
# via ipython
|
||||
cachetools==5.0.0
|
||||
# via google-auth
|
||||
|
@ -122,7 +122,10 @@ class TestDemo(unittest.TestCase):
|
||||
time.sleep(0.2)
|
||||
total_sleep += 0.2
|
||||
|
||||
self.assertEqual(elem.text, "L + e + W - a - n - t ' + s + t - g + o s e e a m a g i c t r i c k ? - ! +")
|
||||
self.assertEqual(
|
||||
elem.text,
|
||||
"L + e + W - a - n - t ' + s + t - g + o s e e a m a g i c t r i c k ? - ! +",
|
||||
)
|
||||
golden_img = os.path.join(
|
||||
current_dir, GOLDEN_PATH.format("diff_texts", "magic_trick")
|
||||
)
|
||||
|
@ -12,12 +12,12 @@ class TestKeyGenerator(unittest.TestCase):
|
||||
def test_same_pass(self):
|
||||
key1 = encryptor.get_key("test")
|
||||
key2 = encryptor.get_key("test")
|
||||
self.assertEquals(key1, key2)
|
||||
self.assertEquals(key1, key2)
|
||||
|
||||
def test_diff_pass(self):
|
||||
key1 = encryptor.get_key("test")
|
||||
key2 = encryptor.get_key("diff_test")
|
||||
self.assertNotEquals(key1, key2)
|
||||
self.assertNotEquals(key1, key2)
|
||||
|
||||
|
||||
class TestEncryptorDecryptor(unittest.TestCase):
|
||||
@ -27,7 +27,7 @@ class TestEncryptorDecryptor(unittest.TestCase):
|
||||
encrypted_data = encryptor.encrypt(key, data)
|
||||
decrypted_data = encryptor.decrypt(key, encrypted_data)
|
||||
self.assertEquals(data, decrypted_data)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -27,7 +27,6 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)
|
||||
|
||||
|
||||
def test_question_answering(self):
|
||||
model_type = "question-answering"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
|
@ -38,7 +38,7 @@ class TestSimpleFlagging(unittest.TestCase):
|
||||
self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
|
||||
io.close()
|
||||
|
||||
|
||||
|
||||
class TestHuggingFaceDatasetSaver(unittest.TestCase):
|
||||
def test_saver_setup(self):
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
@ -47,7 +47,7 @@ class TestHuggingFaceDatasetSaver(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
flagger.setup(tmpdirname)
|
||||
huggingface_hub.create_repo.assert_called_once()
|
||||
|
||||
|
||||
def test_saver_flag(self):
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
@ -65,7 +65,7 @@ class TestHuggingFaceDatasetSaver(unittest.TestCase):
|
||||
self.assertEqual(row_count, 1) # 2 rows written including header
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -147,10 +147,10 @@ class TestNumber(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x ** 2, "number", "textbox")
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4.0"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x ** 2, "number", "textbox", interpretation="default"
|
||||
lambda x: x**2, "number", "textbox", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
self.assertEqual(
|
||||
@ -211,10 +211,10 @@ class TestSlider(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x ** 2, "slider", "textbox")
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x ** 2, "slider", "textbox", interpretation="default"
|
||||
lambda x: x**2, "slider", "textbox", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
self.assertEqual(
|
||||
@ -568,10 +568,10 @@ class TestAudio(unittest.TestCase):
|
||||
|
||||
def test_tokenize(self):
|
||||
x_wav = gr.test_data.BASE64_AUDIO
|
||||
audio_input = gr.inputs.Audio()
|
||||
audio_input = gr.inputs.Audio()
|
||||
tokens, _, _ = audio_input.tokenize(x_wav)
|
||||
self.assertEquals(len(tokens), audio_input.interpretation_segments)
|
||||
x_new = audio_input.get_masked_inputs(tokens, [[1]*len(tokens)])[0]
|
||||
x_new = audio_input.get_masked_inputs(tokens, [[1] * len(tokens)])[0]
|
||||
similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
|
||||
self.assertGreater(similarity, 0.9)
|
||||
|
||||
|
@ -6,8 +6,7 @@ import numpy as np
|
||||
import gradio.interpretation
|
||||
import gradio.test_data
|
||||
from gradio import Interface
|
||||
from gradio.processing_utils import (decode_base64_to_image,
|
||||
encode_array_to_base64)
|
||||
from gradio.processing_utils import decode_base64_to_image, encode_array_to_base64
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
@ -8,14 +8,17 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
class TestProcessExamples(unittest.TestCase):
|
||||
def test_process_example(self):
|
||||
io = Interface(lambda x: "Hello " + x, "text", "text",
|
||||
examples=[["World"]])
|
||||
io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
|
||||
prediction, _ = process_examples.process_example(io, 0)
|
||||
self.assertEquals(prediction[0], "Hello World")
|
||||
|
||||
def test_caching(self):
|
||||
io = Interface(lambda x: "Hello " + x, "text", "text",
|
||||
examples=[["World"], ["Dunya"], ["Monde"]])
|
||||
io = Interface(
|
||||
lambda x: "Hello " + x,
|
||||
"text",
|
||||
"text",
|
||||
examples=[["World"], ["Dunya"], ["Monde"]],
|
||||
)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
process_examples.cache_interface_examples(io)
|
||||
prediction = process_examples.load_from_cache(io, 1)
|
||||
|
@ -15,16 +15,17 @@ class TestQueuingOpenClose(unittest.TestCase):
|
||||
def test_init(self):
|
||||
queueing.init()
|
||||
self.assertTrue(os.path.exists(queueing.DB_FILE))
|
||||
os.remove(queueing.DB_FILE)
|
||||
|
||||
os.remove(queueing.DB_FILE)
|
||||
|
||||
def test_close(self):
|
||||
queueing.close()
|
||||
self.assertFalse(os.path.exists(queueing.DB_FILE))
|
||||
|
||||
|
||||
|
||||
class TestQueuingActions(unittest.TestCase):
|
||||
def setUp(self):
|
||||
queueing.init()
|
||||
|
||||
|
||||
def test_hashing(self):
|
||||
hash1 = queueing.generate_hash()
|
||||
hash2 = queueing.generate_hash()
|
||||
@ -43,26 +44,27 @@ class TestQueuingActions(unittest.TestCase):
|
||||
self.assertEquals(hash_popped, hash1)
|
||||
self.assertEquals(input_data, {"data": "test1"})
|
||||
self.assertEquals(action, "predict")
|
||||
|
||||
|
||||
def test_jobs(self):
|
||||
hash1, _ = queueing.push({"data": "test1"}, "predict")
|
||||
hash2, position = queueing.push({"data": "test1"}, "predict")
|
||||
self.assertEquals(position, 1)
|
||||
|
||||
|
||||
queueing.start_job(hash1)
|
||||
_, position = queueing.get_status(hash2)
|
||||
self.assertEquals(position, 1)
|
||||
queueing.pass_job(hash1, {"data": "result"})
|
||||
_, position = queueing.get_status(hash2)
|
||||
self.assertEquals(position, 0)
|
||||
|
||||
|
||||
queueing.start_job(hash2)
|
||||
queueing.fail_job(hash2, "failure")
|
||||
status, _ = queueing.get_status(hash2)
|
||||
self.assertEquals(status, "FAILED")
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
queueing.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -58,6 +58,7 @@ for guide in sorted(os.listdir(GRADIO_GUIDES_DIR)):
|
||||
tags = None
|
||||
if "tags: " in guide_content:
|
||||
tags = guide_content.split("tags: ")[1].split("\n")[0].split(", ")
|
||||
|
||||
spaces = None
|
||||
if "related_spaces: " in guide_content:
|
||||
spaces = guide_content.split("related_spaces: ")[1].split("\n")[0].split(", ")
|
||||
@ -83,9 +84,13 @@ def render_guides_main():
|
||||
template = Template(template_file.read())
|
||||
output_html = template.render(guides=filtered_guides, navbar_html=navbar_html)
|
||||
os.makedirs(os.path.join("generated", "guides"), exist_ok=True)
|
||||
with open(os.path.join("generated", "guides", "index.html"), "w", encoding='utf-8') as generated_template:
|
||||
with open(
|
||||
os.path.join("generated", "guides", "index.html"), "w", encoding="utf-8"
|
||||
) as generated_template:
|
||||
generated_template.write(output_html)
|
||||
with open(os.path.join("generated", "guides.html"), "w", encoding='utf-8') as generated_template:
|
||||
with open(
|
||||
os.path.join("generated", "guides.html"), "w", encoding="utf-8"
|
||||
) as generated_template:
|
||||
generated_template.write(output_html)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user