Merge pull request #560 from LysandreJik/upgrade-black

Upgrade black to version ~=22.0
Abubakar Abid, 2022-02-08 17:37:22 -05:00, committed by GitHub
commit e19eb0dc3c
32 changed files with 182 additions and 176 deletions
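Most of the churn below is mechanical restyling. As orientation, here is a minimal sketch (not part of the commit) of the two patterns that dominate the diff: the power-operator "hugging" introduced in Black 22.x, and Black's usual collapsing of wrapped signatures and calls that fit within the default 88-character limit. The stub is simplified from the Component.save_flagged hunk further down.

```python
from typing import Any

# New in Black 22.x: spaces around ** are dropped when both operands are
# simple (a name, literal, or attribute), so `10 ** power` becomes:
power = 3
step = 10**power  # still 1000; only the formatting changed

# Black's long-standing line filling collapses wrapped parameter lists
# that fit on one line, as in this stub modeled on Component.save_flagged:
def save_flagged(dir: str, label: str, data: Any, encryption_key: bool) -> Any:
    return data

print(step, save_flagged("/tmp", "Textbox", {"x": 1}, False))
```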


@@ -15,7 +15,7 @@ def plot_forecast(final_year, companies, noise, show_legend, point_style):
ax = fig.add_subplot(111)
for i, company in enumerate(companies):
series = np.arange(0, year_count, dtype=float)
series = series ** 2 * (i + 1)
series = series**2 * (i + 1)
series += np.random.rand(year_count) * noise
ax.plot(x, series, plt_format)
if show_legend:


@@ -2,9 +2,11 @@ import gradio as gr
user_db = {"admin": "admin", "foo": "bar"}
def greet(name):
return "Hello " + name + "!!"
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
if __name__ == "__main__":
iface.launch(auth=lambda u, p: user_db.get(u) == p)


@@ -1,8 +1,10 @@
import gradio as gr
def greet(name):
return "Hello " + name + "!!"
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
if __name__ == "__main__":
iface.launch()


@@ -1,8 +1,10 @@
import gradio as gr
def greet(name):
return "Hello " + name + "!"
iface = gr.Interface(
fn=greet,
inputs=gr.inputs.Textbox(lines=2, placeholder="Name Here..."),


@@ -1,11 +1,13 @@
import gradio as gr
def greet(name, is_morning, temperature):
salutation = "Good morning" if is_morning else "Good evening"
greeting = "%s %s. It is %s degrees today" % (salutation, name, temperature)
celsius = (temperature - 32) * 5 / 9
return greeting, round(celsius, 2)
iface = gr.Interface(
fn=greet,
inputs=["text", "checkbox", gr.inputs.Slider(0, 100)],


@@ -86,11 +86,9 @@ iface = gr.Interface(
fn,
inputs=[
gr.inputs.Textbox(default="Lorem ipsum", label="Textbox"),
gr.inputs.Textbox(lines=3, placeholder="Type here..",
label="Textbox 2"),
gr.inputs.Textbox(lines=3, placeholder="Type here..", label="Textbox 2"),
gr.inputs.Number(label="Number", default=42),
gr.inputs.Slider(minimum=10, maximum=20, default=15,
label="Slider: 10 - 20"),
gr.inputs.Slider(minimum=10, maximum=20, default=15, label="Slider: 10 - 20"),
gr.inputs.Slider(maximum=20, step=0.04, label="Slider: step @ 0.04"),
gr.inputs.Checkbox(label="Checkbox"),
gr.inputs.CheckboxGroup(
@@ -99,17 +97,14 @@ iface = gr.Interface(
gr.inputs.Radio(label="Radio", choices=CHOICES, default=CHOICES[2]),
gr.inputs.Dropdown(label="Dropdown", choices=CHOICES),
gr.inputs.Image(label="Image", optional=True),
gr.inputs.Image(label="Image w/ Cropper",
tool="select", optional=True),
gr.inputs.Image(label="Image w/ Cropper", tool="select", optional=True),
gr.inputs.Image(label="Sketchpad", source="canvas", optional=True),
gr.inputs.Image(label="Webcam", source="webcam", optional=True),
gr.inputs.Video(label="Video", optional=True),
gr.inputs.Audio(label="Audio", optional=True),
gr.inputs.Audio(label="Microphone",
source="microphone", optional=True),
gr.inputs.Audio(label="Microphone", source="microphone", optional=True),
gr.inputs.File(label="File", optional=True),
gr.inputs.Dataframe(label="Dataframe", headers=[
"Name", "Age", "Gender"]),
gr.inputs.Dataframe(label="Dataframe", headers=["Name", "Age", "Gender"]),
gr.inputs.Timeseries(x="time", y=["price", "value"], optional=True),
],
outputs=[
@@ -118,8 +113,9 @@ iface = gr.Interface(
gr.outputs.Audio(label="Audio"),
gr.outputs.Image(label="Image"),
gr.outputs.Video(label="Video"),
gr.outputs.HighlightedText(label="HighlightedText", color_map={
"punc": "pink", "test 0": "blue"}),
gr.outputs.HighlightedText(
label="HighlightedText", color_map={"punc": "pink", "test 0": "blue"}
),
gr.outputs.HighlightedText(label="HighlightedText", show_legend=True),
gr.outputs.JSON(label="JSON"),
gr.outputs.HTML(label="HTML"),
@@ -127,8 +123,7 @@ iface = gr.Interface(
gr.outputs.Dataframe(label="Dataframe"),
gr.outputs.Dataframe(label="Numpy", type="numpy"),
gr.outputs.Carousel("image", label="Carousel"),
gr.outputs.Timeseries(
x="time", y=["price", "value"], label="Timeseries"),
gr.outputs.Timeseries(x="time", y=["price", "value"], label="Timeseries"),
],
examples=[
[


@@ -6,11 +6,19 @@ import math
import numpy as np
import torch
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
BertForQuestionAnswering, BertTokenizer)
from pytorch_transformers import (
WEIGHTS_NAME,
BertConfig,
BertForQuestionAnswering,
BertTokenizer,
)
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from utils import (get_answer, input_to_squad_example,
squad_examples_to_features, to_list)
from utils import (
get_answer,
input_to_squad_example,
squad_examples_to_features,
to_list,
)
RawResult = collections.namedtuple(
"RawResult", ["unique_id", "start_logits", "end_logits"]


@@ -5,8 +5,7 @@ import math
import numpy as np
import torch
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
whitespace_tokenize)
from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


@@ -7,6 +7,7 @@ def reverse_audio(audio):
sr, data = audio
return (sr, np.flipud(data))
iface = gr.Interface(reverse_audio, "microphone", "audio", examples="audio")
if __name__ == "__main__":


@@ -13,7 +13,7 @@ def stock_forecast(final_year, companies, noise, show_legend, point_style):
ax = fig.add_subplot(111)
for i, company in enumerate(companies):
series = np.arange(0, year_count, dtype=float)
series = series ** 2 * (i + 1)
series = series**2 * (i + 1)
series += np.random.rand(year_count) * noise
ax.plot(x, series, plt_format)
if show_legend:


@@ -2,7 +2,7 @@ import pkg_resources
from gradio.routes import get_state, set_state
from gradio.flagging import *
from gradio.interface import *
from gradio.mix import *
current_pkg_version = pkg_resources.require("gradio")[0].version


@@ -34,11 +34,7 @@ class Component:
return {}
def save_flagged(
self,
dir: str,
label: str,
data: Any,
encryption_key: bool
self, dir: str, label: str, data: Any, encryption_key: bool
) -> Any:
"""
Saves flagged data from component
@@ -52,14 +48,10 @@ class Component:
return data
def save_flagged_file(
self,
dir: str,
label: str,
data: Any,
encryption_key: bool
self, dir: str, label: str, data: Any, encryption_key: bool
) -> str:
"""
Saved flagged data (e.g. image or audio) as a file and returns filepath
"""
if data is None:
return None
@@ -81,9 +73,9 @@ class Component:
return label + "/" + new_file_name
def restore_flagged_file(
self,
dir: str,
file: str,
encryption_key: bool,
) -> Dict[str, Any]:
"""


@@ -3,18 +3,13 @@ from Crypto.Cipher import AES
from Crypto.Hash import SHA256
def get_key(
password: str
) -> bytes:
def get_key(password: str) -> bytes:
"""Generates an encryption key based on the password provided."""
key = SHA256.new(password.encode()).digest()
return key
def encrypt(
key: bytes,
source: bytes
) -> bytes:
def encrypt(key: bytes, source: bytes) -> bytes:
"""Encrypts source data using the provided encryption key"""
IV = Random.new().read(AES.block_size) # generate IV
encryptor = AES.new(key, AES.MODE_CBC, IV)
@@ -24,10 +19,7 @@ def encrypt(
return data
def decrypt(
key: bytes,
source: bytes
) -> bytes:
def decrypt(key: bytes, source: bytes) -> bytes:
IV = source[: AES.block_size] # extract the IV from the beginning
decryptor = AES.new(key, AES.MODE_CBC, IV)
data = decryptor.decrypt(source[AES.block_size :]) # decrypt
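With the signatures collapsed, a round trip reads naturally at the call site. A hedged sketch (not in the commit; it assumes these helpers are importable as gradio.encryptor and mirrors the equality asserted in the encryptor tests further down):

```python
from gradio import encryptor

key = encryptor.get_key("hunter2")             # SHA-256 digest of the password
blob = encryptor.encrypt(key, b"flagged row")  # a fresh IV is prepended to the ciphertext
assert encryptor.decrypt(key, blob) == b"flagged row"  # IV read back from the first block
```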


@@ -35,12 +35,13 @@ def get_huggingface_interface(model_name, api_key, alias):
content_type = r.headers.get("content-type")
# Case 2: the data prefix is a key in the response
if content_type == "application/json":
try:
content_type = r.json()[0]["content-type"]
base64_repr = r.json()[0]["blob"]
except KeyError:
raise ValueError("Cannot determine content type returned"
"by external API.")
raise ValueError(
"Cannot determine content type returned" "by external API."
)
# Case 3: the data prefix is included in the response headers
else:
pass
@@ -66,7 +67,7 @@ def get_huggingface_interface(model_name, api_key, alias):
"preprocess": lambda i: base64.b64decode(
i["data"].split(",")[1]
), # convert the base64 representation to binary
"postprocess": encode_to_base64,
"postprocess": encode_to_base64,
},
"automatic-speech-recognition": {
# example model: https://hf.co/jonatasgrosman/wav2vec2-large-xlsr-53-english


@@ -111,6 +111,7 @@ class CSVLogger(FlaggingCallback):
The default implementation of the FlaggingCallback abstract class.
Logs the input and output data to a CSV file. Supports encryption.
"""
def setup(self, flagging_dir: str):
self.flagging_dir = flagging_dir
os.makedirs(flagging_dir, exist_ok=True)
@@ -323,9 +324,11 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
# Generate the headers and dataset_infos
if is_new:
headers = []
for i, component in enumerate(interface.input_components):
component_label = interface.config["input_components"][i]["label"] or "Input_{}".format(i)
component_label = interface.config["input_components"][i][
"label"
] or "Input_{}".format(i)
headers.append(component_label)
infos["flagged"]["features"][component_label] = {
"dtype": "string",
@@ -341,7 +344,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
break
for i, component in enumerate(interface.output_components):
component_label = interface.config["output_components"][i]["label"] or "Output_{}".format(i)
component_label = interface.config["output_components"][i][
"label"
] or "Output_{}".format(i)
headers.append(component_label)
infos["flagged"]["features"][component_label] = {
"dtype": "string",
@@ -368,7 +373,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
# Generate the row corresponding to the flagged sample
csv_data = []
for i, component in enumerate(interface.input_components):
label = interface.config["input_components"][i]["label"] or "Input_{}".format(i)
label = interface.config["input_components"][i][
"label"
] or "Input_{}".format(i)
filepath = component.save_flagged(
self.dataset_dir, label, input_data[i], None
)
@@ -378,9 +385,13 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
)
for i, component in enumerate(interface.output_components):
label = interface.config["output_components"][i]["label"] or "Output_{}".format(i)
label = interface.config["output_components"][i][
"label"
] or "Output_{}".format(i)
filepath = (
component.save_flagged(self.dataset_dir, label, output_data[i], None)
component.save_flagged(
self.dataset_dir, label, output_data[i], None
)
if output_data[i] is not None
else ""
)


@@ -221,10 +221,8 @@ class Textbox(InputComponent):
"""
masked_inputs = []
for binary_mask_vector in binary_mask_matrix:
masked_input = np.array(tokens)[np.array(
binary_mask_vector, dtype=bool)]
masked_inputs.append(
self.interpretation_separator.join(masked_input))
masked_input = np.array(tokens)[np.array(binary_mask_vector, dtype=bool)]
masked_inputs.append(self.interpretation_separator.join(masked_input))
return masked_inputs
def get_interpretation_scores(
@@ -308,10 +306,8 @@ class Number(InputComponent):
delta = 1.0 * self.interpretation_delta * x / 100
elif self.interpretation_delta_type == "absolute":
delta = self.interpretation_delta
negatives = (x + np.arange(-self.interpretation_steps, 0)
* delta).tolist()
positives = (
x + np.arange(1, self.interpretation_steps + 1) * delta).tolist()
negatives = (x + np.arange(-self.interpretation_steps, 0) * delta).tolist()
positives = (x + np.arange(1, self.interpretation_steps + 1) * delta).tolist()
return negatives + positives, {}
def get_interpretation_scores(
@@ -357,7 +353,7 @@ class Slider(InputComponent):
if step is None:
difference = maximum - minimum
power = math.floor(math.log10(difference) - 2)
step = 10 ** power
step = 10**power
self.step = step
self.default = minimum if default is None else default
self.test_input = self.default
@@ -406,8 +402,7 @@ class Slider(InputComponent):
def get_interpretation_neighbors(self, x) -> List[float]:
return (
np.linspace(self.minimum, self.maximum,
self.interpretation_steps).tolist(),
np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(),
{},
)
@@ -944,8 +939,7 @@ class Image(InputComponent):
masked_input = np.zeros_like(tokens[0], dtype=int)
for token, b in zip(tokens, binary_mask_vector):
masked_input = masked_input + token * int(b)
masked_inputs.append(
processing_utils.encode_array_to_base64(masked_input))
masked_inputs.append(processing_utils.encode_array_to_base64(masked_input))
return masked_inputs
def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None):
@@ -1042,10 +1036,8 @@ class Video(InputComponent):
file_name = file.name
uploaded_format = file_name.split(".")[-1].lower()
if self.type is not None and uploaded_format != self.type:
output_file_name = file_name[0: file_name.rindex(
".") + 1] + self.type
ff = FFmpeg(inputs={file_name: None},
outputs={output_file_name: None})
output_file_name = file_name[0 : file_name.rindex(".") + 1] + self.type
ff = FFmpeg(inputs={file_name: None}, outputs={output_file_name: None})
ff.run()
return output_file_name
else:
@@ -1200,8 +1192,7 @@ class Audio(InputComponent):
tokens = []
masks = []
duration = data.shape[0]
boundaries = np.linspace(
0, duration, self.interpretation_segments + 1).tolist()
boundaries = np.linspace(0, duration, self.interpretation_segments + 1).tolist()
boundaries = [round(boundary) for boundary in boundaries]
for index in range(len(boundaries) - 1):
start, stop = boundaries[index], boundaries[index + 1]
@@ -1211,8 +1202,7 @@ class Audio(InputComponent):
leave_one_out_data = np.copy(data)
leave_one_out_data[start:stop] = 0
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
processing_utils.audio_to_file(
sample_rate, leave_one_out_data, file.name)
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
out_data = processing_utils.encode_file_to_base64(file.name)
leave_one_out_sets.append(out_data)
file.close()
@@ -1230,8 +1220,9 @@ class Audio(InputComponent):
tokens.append(token_data)
tokens = [{"name": "token.wav", "data": token} for token in tokens]
leave_one_out_sets = [{"name": "loo.wav", "data": loo_set}
for loo_set in leave_one_out_sets]
leave_one_out_sets = [
{"name": "loo.wav", "data": loo_set} for loo_set in leave_one_out_sets
]
return tokens, leave_one_out_sets, masks
def get_masked_inputs(self, tokens, binary_mask_matrix):
@@ -1239,7 +1230,7 @@ class Audio(InputComponent):
x = tokens[0]["data"]
file_obj = processing_utils.decode_base64_to_file(x)
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
zero_input = np.zeros_like(data, dtype='int16')
zero_input = np.zeros_like(data, dtype="int16")
# decode all of the tokens
token_data = []
for token in tokens:
@@ -1253,8 +1244,7 @@ class Audio(InputComponent):
for t, b in zip(token_data, binary_mask_vector):
masked_input = masked_input + t * int(b)
file = tempfile.NamedTemporaryFile(delete=False)
processing_utils.audio_to_file(
sample_rate, masked_input, file.name)
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
masked_data = processing_utils.encode_file_to_base64(file.name)
file.close()
os.unlink(file.name)
@@ -1428,8 +1418,7 @@ class Dataframe(InputComponent):
"date": "02/08/1993",
}
column_dtypes = (
[datatype] *
self.col_count if isinstance(datatype, str) else datatype
[datatype] * self.col_count if isinstance(datatype, str) else datatype
)
self.test_input = [
[sample_values[c] for c in column_dtypes] for _ in range(row_count)


@@ -21,8 +21,14 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from markdown_it import MarkdownIt
from mdit_py_plugins.footnote import footnote_plugin
from gradio import (encryptor, interpretation, networking, # type: ignore
queueing, strings, utils)
from gradio import (
encryptor,
interpretation,
networking, # type: ignore
queueing,
strings,
utils,
)
from gradio.external import load_from_pipeline, load_interface # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.inputs import InputComponent
@@ -243,20 +249,25 @@ class Interface:
self.session = None
self.title = title
CLEANER = re.compile('<.*?>')
CLEANER = re.compile("<.*?>")
def clean_html(raw_html):
cleantext = re.sub(CLEANER, '', raw_html)
cleantext = re.sub(CLEANER, "", raw_html)
return cleantext
md = MarkdownIt("js-default", {
md = MarkdownIt(
"js-default",
{
"linkify": True,
"typographer": True,
"html": True,
}).use(footnote_plugin)
},
).use(footnote_plugin)
simple_description = None
if description is not None:
description = md.render(description)
simple_description = clean_html(description)
self.simple_description = simple_description
self.description = description
@@ -264,7 +275,7 @@ class Interface:
article = utils.readme_to_html(article)
article = md.render(article)
self.article = article
self.thumbnail = thumbnail
theme = theme if theme is not None else os.getenv("GRADIO_THEME", "default")
DEPRECATED_THEME_MAP = {


@@ -442,7 +442,6 @@ class Audio(OutputComponent):
return {
"audio": {},
}
def postprocess(self, y):
"""
@@ -453,7 +452,7 @@
"""
if self.type in ["numpy", "file", "auto"]:
if self.type == "numpy" or (self.type == "auto" and isinstance(y, tuple)):
sample_rate, data = y
file = tempfile.NamedTemporaryFile(
prefix="sample", suffix=".wav", delete=False
)


@@ -18,8 +18,7 @@ CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv")
def process_example(
interface: Interface,
example_id: int
interface: Interface, example_id: int
) -> Tuple[List[Any], List[float]]:
"""Loads an example from the interface and returns its prediction."""
example_set = interface.examples[example_id]
@@ -31,9 +30,7 @@ def process_example(
return prediction, durations
def cache_interface_examples(
interface: Interface
) -> None:
def cache_interface_examples(interface: Interface) -> None:
"""Caches all of the examples from an interface."""
if os.path.exists(CACHE_FILE):
print(
@@ -54,10 +51,7 @@ def cache_interface_examples(
raise e
def load_from_cache(
interface: Interface,
example_id: int
) -> List[Any]:
def load_from_cache(interface: Interface, example_id: int) -> List[Any]:
"""Loads a particular cached example for the interface."""
with open(CACHE_FILE) as cache:
examples = list(csv.reader(cache))


@@ -138,32 +138,37 @@ def audio_to_file(sample_rate, data, filename):
channels=(1 if len(data.shape) == 1 else data.shape[1]),
)
audio.export(filename, format="wav").close()
def convert_to_16_bit_wav(data):
# Based on: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.write.html
if data.dtype==np.float32:
warnings.warn("Audio data is not in 16-bit integer format."
"Trying to convert to 16-bit int format.")
if data.dtype == np.float32:
warnings.warn(
"Audio data is not in 16-bit integer format."
"Trying to convert to 16-bit int format."
)
data = data / np.abs(data).max()
data = data * 32767
data = data.astype(np.int16)
elif data.dtype==np.int32:
warnings.warn("Audio data is not in 16-bit integer format."
"Trying to convert to 16-bit int format.")
elif data.dtype == np.int32:
warnings.warn(
"Audio data is not in 16-bit integer format."
"Trying to convert to 16-bit int format."
)
data = data / 65538
data = data.astype(np.int16)
elif data.dtype==np.int16:
elif data.dtype == np.int16:
pass
elif data.dtype==np.uint8:
warnings.warn("Audio data is not in 16-bit integer format."
"Trying to convert to 16-bit int format.")
elif data.dtype == np.uint8:
warnings.warn(
"Audio data is not in 16-bit integer format."
"Trying to convert to 16-bit int format."
)
data = data * 257 - 32768
data = data.astype(np.int16)
else:
raise ValueError("Audio data cannot be converted to "
"16-bit int format.")
return data
raise ValueError("Audio data cannot be converted to " "16-bit int format.")
return data
##################
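A side note on the uint8 branch of convert_to_16_bit_wav above: `data * 257 - 32768` maps 0..255 linearly onto the full int16 range, since 257 == 65535 // 255. An illustrative check (not in the commit), with an explicit upcast so the intermediate product cannot overflow:

```python
import numpy as np

# Endpoints and midpoint of the uint8 -> int16 mapping:
# 0 -> -32768, 128 -> 128, 255 -> 32767
u8 = np.array([0, 128, 255], dtype=np.uint8)
i16 = (u8.astype(np.int32) * 257 - 32768).astype(np.int16)
print(i16)  # [-32768    128  32767]
```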
@@ -330,7 +335,7 @@ def _convert(image, dtype, force_copy=False, uniform=False):
Output image array. Has the same kind as `a`.
"""
kind = a.dtype.kind
if n > m and a.max() < 2 ** m:
if n > m and a.max() < 2**m:
mnew = int(np.ceil(m / 2) * 2)
if mnew > m:
dtype = "int{}".format(mnew)
@@ -353,11 +358,11 @@ def _convert(image, dtype, force_copy=False, uniform=False):
# exact upscale to a multiple of `n` bits
if copy:
b = np.empty(a.shape, _dtype_bits(kind, m))
np.multiply(a, (2 ** m - 1) // (2 ** n - 1), out=b, dtype=b.dtype)
np.multiply(a, (2**m - 1) // (2**n - 1), out=b, dtype=b.dtype)
return b
else:
a = a.astype(_dtype_bits(kind, m, a.dtype.itemsize), copy=False)
a *= (2 ** m - 1) // (2 ** n - 1)
a *= (2**m - 1) // (2**n - 1)
return a
else:
# upscale to a multiple of `n` bits,
@@ -365,12 +370,12 @@ def _convert(image, dtype, force_copy=False, uniform=False):
o = (m // n + 1) * n
if copy:
b = np.empty(a.shape, _dtype_bits(kind, o))
np.multiply(a, (2 ** o - 1) // (2 ** n - 1), out=b, dtype=b.dtype)
np.multiply(a, (2**o - 1) // (2**n - 1), out=b, dtype=b.dtype)
b //= 2 ** (o - m)
return b
else:
a = a.astype(_dtype_bits(kind, o, a.dtype.itemsize), copy=False)
a *= (2 ** o - 1) // (2 ** n - 1)
a *= (2**o - 1) // (2**n - 1)
a //= 2 ** (o - m)
return a


@@ -10,9 +10,7 @@ import requests
DB_FILE = "gradio_queue.db"
def queue_thread(
path_to_local_server: str
) -> None:
def queue_thread(path_to_local_server: str) -> None:
while True:
try:
next_job = pop()
@@ -108,10 +106,7 @@ def pop() -> Tuple[int, str, Dict, str]:
return result[0], result[1], json.loads(result[2]), result[3]
def push(
input_data: Dict,
action: str
) -> Tuple[str, int]:
def push(input_data: Dict, action: str) -> Tuple[str, int]:
input_data = json.dumps(input_data)
hash = generate_hash()
conn = sqlite3.connect(DB_FILE)
@@ -140,7 +135,7 @@ def push(
"""
)
result = c.fetchone()
if not(result[0] == 0):
if not (result[0] == 0):
queue_position += 1
conn.commit()
return hash, queue_position
@@ -204,7 +199,7 @@ def get_status(hash: str) -> Tuple[str, int]:
"""
)
result = c.fetchone()
if not(result[0] == 0):
if not (result[0] == 0):
queue_position += 1
conn.commit()
return "QUEUED", queue_position
@@ -229,10 +224,7 @@ def start_job(hash: str) -> None:
conn.commit()
def fail_job(
hash: str,
error_message: str
) -> None:
def fail_job(hash: str, error_message: str) -> None:
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute(
@@ -247,10 +239,7 @@ def fail_job(
conn.commit()
def pass_job(
hash: str,
output_data: Dict
) -> None:
def pass_job(hash: str, output_data: Dict) -> None:
output_data = json.dumps(output_data)
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()


@@ -130,6 +130,7 @@ def static_resource(path: str):
return FileResponse(static_file)
raise HTTPException(status_code=404, detail="Static file not found")
@app.get("/build/{path:path}")
def build_resource(path: str):
if app.interface.share:


@@ -22,7 +22,7 @@ attrs==21.4.0
# pytest
backcall==0.2.0
# via ipython
black==21.12b0
black~=22.1
# via ipython
cachetools==5.0.0
# via google-auth
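Note the specifier change: `black==21.12b0` pinned an exact pre-release, while `black~=22.1` is a compatible-release spec equivalent to `>=22.1, <23`, so later 22.x stable releases also satisfy it. A quick sanity check of an installed environment against the new pin (a sketch: it assumes the third-party `packaging` library is available; `black.__version__` is provided by Black itself):

```python
import black
from packaging.specifiers import SpecifierSet

# "~=22.1" keeps the CalVer year fixed while allowing minor updates.
assert black.__version__ in SpecifierSet("~=22.1"), black.__version__
print(f"black {black.__version__} satisfies ~=22.1")
```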


@@ -122,7 +122,10 @@ class TestDemo(unittest.TestCase):
time.sleep(0.2)
total_sleep += 0.2
self.assertEqual(elem.text, "L + e + W - a - n - t ' + s + t - g + o s e e a m a g i c t r i c k ? - ! +")
self.assertEqual(
elem.text,
"L + e + W - a - n - t ' + s + t - g + o s e e a m a g i c t r i c k ? - ! +",
)
golden_img = os.path.join(
current_dir, GOLDEN_PATH.format("diff_texts", "magic_trick")
)


@@ -12,12 +12,12 @@ class TestKeyGenerator(unittest.TestCase):
def test_same_pass(self):
key1 = encryptor.get_key("test")
key2 = encryptor.get_key("test")
self.assertEquals(key1, key2)
def test_diff_pass(self):
key1 = encryptor.get_key("test")
key2 = encryptor.get_key("diff_test")
self.assertNotEquals(key1, key2)
class TestEncryptorDecryptor(unittest.TestCase):
@@ -27,7 +27,7 @@ class TestEncryptorDecryptor(unittest.TestCase):
encrypted_data = encryptor.encrypt(key, data)
decrypted_data = encryptor.decrypt(key, encrypted_data)
self.assertEquals(data, decrypted_data)
if __name__ == "__main__":
unittest.main()


@@ -27,7 +27,6 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)
def test_question_answering(self):
model_type = "question-answering"
interface_info = gr.external.get_huggingface_interface(


@@ -38,7 +38,7 @@ class TestSimpleFlagging(unittest.TestCase):
self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
io.close()
class TestHuggingFaceDatasetSaver(unittest.TestCase):
def test_saver_setup(self):
huggingface_hub.create_repo = MagicMock()
@@ -47,7 +47,7 @@ class TestHuggingFaceDatasetSaver(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
flagger.setup(tmpdirname)
huggingface_hub.create_repo.assert_called_once()
def test_saver_flag(self):
huggingface_hub.create_repo = MagicMock()
huggingface_hub.Repository = MagicMock()
@@ -65,7 +65,7 @@ class TestHuggingFaceDatasetSaver(unittest.TestCase):
self.assertEqual(row_count, 1) # 2 rows written including header
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
self.assertEqual(row_count, 2) # 3 rows written including header
if __name__ == "__main__":
unittest.main()


@@ -147,10 +147,10 @@ class TestNumber(unittest.TestCase):
)
def test_in_interface(self):
iface = gr.Interface(lambda x: x ** 2, "number", "textbox")
iface = gr.Interface(lambda x: x**2, "number", "textbox")
self.assertEqual(iface.process([2])[0], ["4.0"])
iface = gr.Interface(
lambda x: x ** 2, "number", "textbox", interpretation="default"
lambda x: x**2, "number", "textbox", interpretation="default"
)
scores, alternative_outputs = iface.interpret([2])
self.assertEqual(
@@ -211,10 +211,10 @@ class TestSlider(unittest.TestCase):
)
def test_in_interface(self):
iface = gr.Interface(lambda x: x ** 2, "slider", "textbox")
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
self.assertEqual(iface.process([2])[0], ["4"])
iface = gr.Interface(
lambda x: x ** 2, "slider", "textbox", interpretation="default"
lambda x: x**2, "slider", "textbox", interpretation="default"
)
scores, alternative_outputs = iface.interpret([2])
self.assertEqual(
@@ -568,10 +568,10 @@ class TestAudio(unittest.TestCase):
def test_tokenize(self):
x_wav = gr.test_data.BASE64_AUDIO
audio_input = gr.inputs.Audio()
tokens, _, _ = audio_input.tokenize(x_wav)
self.assertEquals(len(tokens), audio_input.interpretation_segments)
x_new = audio_input.get_masked_inputs(tokens, [[1]*len(tokens)])[0]
x_new = audio_input.get_masked_inputs(tokens, [[1] * len(tokens)])[0]
similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
self.assertGreater(similarity, 0.9)


@@ -6,8 +6,7 @@ import numpy as np
import gradio.interpretation
import gradio.test_data
from gradio import Interface
from gradio.processing_utils import (decode_base64_to_image,
encode_array_to_base64)
from gradio.processing_utils import decode_base64_to_image, encode_array_to_base64
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"


@@ -8,14 +8,17 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestProcessExamples(unittest.TestCase):
def test_process_example(self):
io = Interface(lambda x: "Hello " + x, "text", "text",
examples=[["World"]])
io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
prediction, _ = process_examples.process_example(io, 0)
self.assertEquals(prediction[0], "Hello World")
def test_caching(self):
io = Interface(lambda x: "Hello " + x, "text", "text",
examples=[["World"], ["Dunya"], ["Monde"]])
io = Interface(
lambda x: "Hello " + x,
"text",
"text",
examples=[["World"], ["Dunya"], ["Monde"]],
)
io.launch(prevent_thread_lock=True)
process_examples.cache_interface_examples(io)
prediction = process_examples.load_from_cache(io, 1)


@@ -15,16 +15,17 @@ class TestQueuingOpenClose(unittest.TestCase):
def test_init(self):
queueing.init()
self.assertTrue(os.path.exists(queueing.DB_FILE))
os.remove(queueing.DB_FILE)
def test_close(self):
queueing.close()
self.assertFalse(os.path.exists(queueing.DB_FILE))
class TestQueuingActions(unittest.TestCase):
def setUp(self):
queueing.init()
def test_hashing(self):
hash1 = queueing.generate_hash()
hash2 = queueing.generate_hash()
@@ -43,26 +44,27 @@ class TestQueuingActions(unittest.TestCase):
self.assertEquals(hash_popped, hash1)
self.assertEquals(input_data, {"data": "test1"})
self.assertEquals(action, "predict")
def test_jobs(self):
hash1, _ = queueing.push({"data": "test1"}, "predict")
hash2, position = queueing.push({"data": "test1"}, "predict")
self.assertEquals(position, 1)
queueing.start_job(hash1)
_, position = queueing.get_status(hash2)
self.assertEquals(position, 1)
queueing.pass_job(hash1, {"data": "result"})
_, position = queueing.get_status(hash2)
self.assertEquals(position, 0)
queueing.start_job(hash2)
queueing.fail_job(hash2, "failure")
status, _ = queueing.get_status(hash2)
self.assertEquals(status, "FAILED")
def tearDown(self):
queueing.close()
if __name__ == "__main__":
unittest.main()


@@ -58,6 +58,7 @@ for guide in sorted(os.listdir(GRADIO_GUIDES_DIR)):
tags = None
if "tags: " in guide_content:
tags = guide_content.split("tags: ")[1].split("\n")[0].split(", ")
spaces = None
if "related_spaces: " in guide_content:
spaces = guide_content.split("related_spaces: ")[1].split("\n")[0].split(", ")
@@ -83,9 +84,13 @@ def render_guides_main():
template = Template(template_file.read())
output_html = template.render(guides=filtered_guides, navbar_html=navbar_html)
os.makedirs(os.path.join("generated", "guides"), exist_ok=True)
with open(os.path.join("generated", "guides", "index.html"), "w", encoding='utf-8') as generated_template:
with open(
os.path.join("generated", "guides", "index.html"), "w", encoding="utf-8"
) as generated_template:
generated_template.write(output_html)
with open(os.path.join("generated", "guides.html"), "w", encoding='utf-8') as generated_template:
with open(
os.path.join("generated", "guides.html"), "w", encoding="utf-8"
) as generated_template:
generated_template.write(output_html)