Format The Codebase

- Apply black code formatting
- Apply isort import sorting
Ömer Faruk Özdemir 2022-01-21 16:44:12 +03:00
parent 7fc0c83beb
commit cc0cff893f
82 changed files with 12652 additions and 2636 deletions
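
Every hunk below is the mechanical output of those two tools. As a rough sketch of how the pass can be reproduced (the exact tool versions and flags are an assumption; the commit does not record them), both formatters expose a Python API:

    # Sketch only: applying black and isort programmatically, assuming
    # "pip install black isort". The versions and options used for this
    # commit are not recorded, so library defaults are shown.
    import black
    import isort

    src = "import gradio as gr\nimport random\nx = {'a':1}\n"
    out = black.format_str(src, mode=black.Mode())  # normalizes spacing and quotes: {"a": 1}
    out = isort.code(out)  # reorders imports: stdlib (random) above third-party (gradio)
    print(out)

The command-line equivalent, run from the repository root, would be "black ." followed by "isort .".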

View File

@@ -1,5 +1,6 @@
import gradio as gr
def calculator(num1, operation, num2):
if operation == "add":
return num1 + num2
@@ -10,7 +11,9 @@ def calculator(num1, operation, num2):
elif operation == "divide":
return num1 / num2
iface = gr.Interface(calculator,
iface = gr.Interface(
calculator,
["number", gr.inputs.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
examples=[

View File

@@ -1,5 +1,6 @@
import gradio as gr
def calculator(num1, operation, num2):
if operation == "add":
return num1 + num2
@@ -10,10 +11,12 @@ def calculator(num1, operation, num2):
elif operation == "divide":
return num1 / num2
iface = gr.Interface(calculator,
iface = gr.Interface(
calculator,
["number", gr.inputs.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
live=True
live=True,
)
if __name__ == "__main__":

View File

@@ -1,10 +1,12 @@
import gradio as gr
import random
import gradio as gr
def chat(message, history):
history = history or []
if message.startswith("How many"):
response = random.randint(1,10)
response = random.randint(1, 10)
elif message.startswith("How"):
response = random.choice(["Great", "Good", "Okay", "Bad"])
elif message.startswith("Where"):
@@ -19,11 +21,19 @@ def chat(message, history):
html += "</div>"
return html, history
iface = gr.Interface(chat, ["text", "state"], ["html", "state"], css="""
iface = gr.Interface(
chat,
["text", "state"],
["html", "state"],
css="""
.chatbox {display:flex;flex-direction:column}
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.user_msg {background-color:cornflowerblue;color:white;align-self:start}
.resp_msg {background-color:lightgray;align-self:self-end}
""", allow_screenshot=False, allow_flagging="never")
""",
allow_screenshot=False,
allow_flagging="never",
)
if __name__ == "__main__":
iface.launch()
iface.launch()

View File

@@ -1,20 +1,25 @@
import gradio as gr
from difflib import Differ
import gradio as gr
def diff_texts(text1, text2):
d = Differ()
return [
(token[2:], token[0] if token[0] != " " else None) for token in d.compare(text1, text2)
(token[2:], token[0] if token[0] != " " else None)
for token in d.compare(text1, text2)
]
iface = gr.Interface(
diff_texts,
[
gr.inputs.Textbox(
lines=3, default="The quick brown fox jumped over the lazy dogs."),
gr.inputs.Textbox(
lines=3, default="The fast brown fox jumps over lazy dogs."),
lines=3, default="The quick brown fox jumped over the lazy dogs."
),
gr.inputs.Textbox(lines=3, default="The fast brown fox jumps over lazy dogs."),
],
gr.outputs.HighlightedText())
gr.outputs.HighlightedText(),
)
if __name__ == "__main__":
iface.launch()
iface.launch()

View File

@@ -1,10 +1,14 @@
import os
from urllib.request import urlretrieve
import tensorflow as tf
import gradio
import gradio as gr
from urllib.request import urlretrieve
import os
urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5")
urlretrieve(
"https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5"
)
model = tf.keras.models.load_model("mnist-model.h5")
@@ -13,11 +17,14 @@ def recognize_digit(image):
prediction = model.predict(image).tolist()[0]
return {str(i): prediction[i] for i in range(10)}
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False, source="canvas")
im = gradio.inputs.Image(
shape=(28, 28), image_mode="L", invert_colors=False, source="canvas"
)
iface = gr.Interface(
recognize_digit,
im,
recognize_digit,
im,
gradio.outputs.Label(num_top_classes=3),
live=True,
interpretation="default",
@@ -28,4 +35,3 @@ iface.test_launch()
if __name__ == "__main__":
iface.launch()

View File

@@ -1,9 +1,12 @@
import gradio as gr
import numpy as np
from fpdf import FPDF
import os
import tempfile
import numpy as np
from fpdf import FPDF
import gradio as gr
def disease_report(img, scan_for, generate_report):
results = []
for i, mode in enumerate(["Red", "Green", "Blue"]):
@@ -16,28 +19,30 @@ def disease_report(img, scan_for, generate_report):
pdf = FPDF()
pdf.add_page()
pdf.set_font("Arial", size=15)
pdf.cell(200, 10, txt="Disease Report",
ln=1, align='C')
pdf.cell(200, 10, txt="A Gradio Demo.",
ln=2, align='C')
pdf.cell(200, 10, txt="Disease Report", ln=1, align="C")
pdf.cell(200, 10, txt="A Gradio Demo.", ln=2, align="C")
pdf.output(report)
return results, report if generate_report else None
iface = gr.Interface(disease_report,
iface = gr.Interface(
disease_report,
[
"image",
gr.inputs.CheckboxGroup(["Cancer", "Rash", "Heart Failure", "Stroke", "Diabetes", "Pneumonia"]),
"checkbox"
"image",
gr.inputs.CheckboxGroup(
["Cancer", "Rash", "Heart Failure", "Stroke", "Diabetes", "Pneumonia"]
),
"checkbox",
],
[
gr.outputs.Carousel(["text", "image"], label="Disease"),
gr.outputs.File(label="Report")
gr.outputs.File(label="Report"),
],
title="Disease Report",
description="Upload an Xray and select the diseases to scan for.",
theme="grass",
flagging_options=["good", "bad", "etc"],
allow_flagging="auto"
allow_flagging="auto",
)
if __name__ == "__main__":

View File

@@ -1,19 +1,25 @@
import gradio as gr
def filter_records(records, gender):
return records[records['gender'] == gender]
iface = gr.Interface(filter_records,
[
gr.inputs.Dataframe(headers=["name", "age", "gender"], datatype=["str", "number", "str"], row_count=5),
gr.inputs.Dropdown(["M", "F", "O"])
],
"dataframe",
description="Enter gender as 'M', 'F', or 'O' for other."
def filter_records(records, gender):
return records[records["gender"] == gender]
iface = gr.Interface(
filter_records,
[
gr.inputs.Dataframe(
headers=["name", "age", "gender"],
datatype=["str", "number", "str"],
row_count=5,
),
gr.inputs.Dropdown(["M", "F", "O"]),
],
"dataframe",
description="Enter gender as 'M', 'F', or 'O' for other.",
)
iface.test_launch()
if __name__ == "__main__":
iface.launch()

View File

@@ -1,8 +1,10 @@
import gradio as gr
import random
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
def plot_forecast(final_year, companies, noise, show_legend, point_style):
start_year = 2020
@@ -22,18 +24,19 @@ def plot_forecast(final_year, companies, noise, show_legend, point_style):
return fig
iface = gr.Interface(plot_forecast,
[
gr.inputs.Radio([2025, 2030, 2035, 2040],
label="Project to:"),
gr.inputs.CheckboxGroup(
["Google", "Microsoft", "Gradio"], label="Company Selection"),
gr.inputs.Slider(1, 100, label="Noise Level"),
gr.inputs.Checkbox(label="Show Legend"),
gr.inputs.Dropdown(["cross", "line", "circle"], label="Style"),
],
gr.outputs.Image(plot=True, label="forecast")
)
iface = gr.Interface(
plot_forecast,
[
gr.inputs.Radio([2025, 2030, 2035, 2040], label="Project to:"),
gr.inputs.CheckboxGroup(
["Google", "Microsoft", "Gradio"], label="Company Selection"
),
gr.inputs.Slider(1, 100, label="Noise Level"),
gr.inputs.Checkbox(label="Show Legend"),
gr.inputs.Dropdown(["cross", "line", "circle"], label="Style"),
],
gr.outputs.Image(plot=True, label="forecast"),
)
if __name__ == "__main__":
iface.launch()

View File

@@ -1,34 +1,38 @@
import gradio as gr
import pandas as pd
import random
import pandas as pd
import gradio as gr
def fraud_detector(card_activity, categories, sensitivity):
activity_range = random.randint(0, 100)
drop_columns = [column for column in ["retail", "food", "other"] if column not in categories]
drop_columns = [
column for column in ["retail", "food", "other"] if column not in categories
]
if len(drop_columns):
card_activity.drop(columns=drop_columns, inplace=True)
return card_activity, card_activity, {"fraud": activity_range / 100., "not fraud": 1 - activity_range / 100.}
return (
card_activity,
card_activity,
{"fraud": activity_range / 100.0, "not fraud": 1 - activity_range / 100.0},
)
iface = gr.Interface(fraud_detector,
[
gr.inputs.Timeseries(
x="time",
y=["retail", "food", "other"]
),
gr.inputs.CheckboxGroup(["retail", "food", "other"], default=[
"retail", "food", "other"]),
gr.inputs.Slider(1, 3)
],
[
"dataframe",
gr.outputs.Timeseries(
x="time",
y=["retail", "food", "other"]
),
gr.outputs.Label(label="Fraud Level"),
]
)
iface = gr.Interface(
fraud_detector,
[
gr.inputs.Timeseries(x="time", y=["retail", "food", "other"]),
gr.inputs.CheckboxGroup(
["retail", "food", "other"], default=["retail", "food", "other"]
),
gr.inputs.Slider(1, 3),
],
[
"dataframe",
gr.outputs.Timeseries(x="time", y=["retail", "food", "other"]),
gr.outputs.Label(label="Fraud Level"),
],
)
if __name__ == "__main__":
iface.launch()

View File

@@ -1,29 +1,42 @@
import gradio as gr
import re
import gradio as gr
male_words, female_words = ["he", "his", "him"], ["she", "hers", "her"]
def gender_of_sentence(sentence):
male_count = len([word for word in sentence.split() if word.lower() in male_words])
female_count = len([word for word in sentence.split() if word.lower() in female_words])
total = max(male_count + female_count, 1)
return {"male": male_count / total, "female": female_count / total}
male_count = len([word for word in sentence.split() if word.lower() in male_words])
female_count = len(
[word for word in sentence.split() if word.lower() in female_words]
)
total = max(male_count + female_count, 1)
return {"male": male_count / total, "female": female_count / total}
def interpret_gender(sentence):
result = gender_of_sentence(sentence)
is_male = result["male"] > result["female"]
interpretation = []
for word in re.split('( )', sentence):
score = 0
token = word.lower()
if (is_male and token in male_words) or (not is_male and token in female_words):
score = 1
elif (is_male and token in female_words) or (not is_male and token in male_words):
score = -1
interpretation.append((word, score))
return interpretation
result = gender_of_sentence(sentence)
is_male = result["male"] > result["female"]
interpretation = []
for word in re.split("( )", sentence):
score = 0
token = word.lower()
if (is_male and token in male_words) or (not is_male and token in female_words):
score = 1
elif (is_male and token in female_words) or (
not is_male and token in male_words
):
score = -1
interpretation.append((word, score))
return interpretation
iface = gr.Interface(
fn=gender_of_sentence, inputs=gr.inputs.Textbox(default="She went to his house to get her keys."),
outputs="label", interpretation=interpret_gender, enable_queue=True)
fn=gender_of_sentence,
inputs=gr.inputs.Textbox(default="She went to his house to get her keys."),
outputs="label",
interpretation=interpret_gender,
enable_queue=True,
)
if __name__ == "__main__":
iface.launch()

View File

@@ -1,15 +1,24 @@
import gradio as gr
import re
import gradio as gr
male_words, female_words = ["he", "his", "him"], ["she", "hers", "her"]
def gender_of_sentence(sentence):
male_count = len([word for word in sentence.split() if word.lower() in male_words])
female_count = len([word for word in sentence.split() if word.lower() in female_words])
total = max(male_count + female_count, 1)
return {"male": male_count / total, "female": female_count / total}
male_count = len([word for word in sentence.split() if word.lower() in male_words])
female_count = len(
[word for word in sentence.split() if word.lower() in female_words]
)
total = max(male_count + female_count, 1)
return {"male": male_count / total, "female": female_count / total}
iface = gr.Interface(
fn=gender_of_sentence, inputs=gr.inputs.Textbox(default="She went to his house to get her keys."),
outputs="label", interpretation="default")
fn=gender_of_sentence,
inputs=gr.inputs.Textbox(default="She went to his house to get her keys."),
outputs="label",
interpretation="default",
)
if __name__ == "__main__":
iface.launch()

View File

@@ -1,8 +1,10 @@
import gradio as gr
import numpy as np
import gradio as gr
notes = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
def generate_tone(note, octave, duration):
sr = 48000
a4_freq, tones_from_a4 = 440, 12 * (octave - 4) + (note - 9)
@@ -14,12 +16,14 @@ def generate_tone(note, octave, duration):
iface = gr.Interface(
generate_tone,
generate_tone,
[
gr.inputs.Dropdown(notes, type="index"),
gr.inputs.Slider(4, 6, step=1),
gr.inputs.Textbox(type="number", default=1, label="Duration in seconds")
], "audio")
gr.inputs.Textbox(type="number", default=1, label="Duration in seconds"),
],
"audio",
)
if __name__ == "__main__":
iface.launch()

View File

@@ -3,11 +3,14 @@ import gradio as gr
title = "GPT-J-6B"
examples = [
['The tower is 324 metres (1,063 ft) tall,'],
["The tower is 324 metres (1,063 ft) tall,"],
["The Moon's orbit around Earth has"],
["The smooth Borealis basin in the Northern Hemisphere covers 40%"]
["The smooth Borealis basin in the Northern Hemisphere covers 40%"],
]
gr.Interface.load("huggingface/EleutherAI/gpt-j-6B",
gr.Interface.load(
"huggingface/EleutherAI/gpt-j-6B",
inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
title=title, examples=examples).launch();
title=title,
examples=examples,
).launch()

View File

@@ -1,7 +1,9 @@
import gradio as gr
def greet(name):
return "Hello " + name + "!!"
return "Hello " + name + "!!"
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
if __name__ == "__main__":

View File

@@ -1,11 +1,14 @@
import gradio as gr
def greet(name):
return "Hello " + name + "!"
return "Hello " + name + "!"
iface = gr.Interface(
fn=greet,
inputs=gr.inputs.Textbox(lines=2, placeholder="Name Here..."),
outputs="text")
fn=greet,
inputs=gr.inputs.Textbox(lines=2, placeholder="Name Here..."),
outputs="text",
)
if __name__ == "__main__":
iface.launch()
iface.launch()

View File

@@ -1,15 +1,17 @@
import gradio as gr
def greet(name, is_morning, temperature):
salutation = "Good morning" if is_morning else "Good evening"
greeting = "%s %s. It is %s degrees today" % (
salutation, name, temperature)
celsius = (temperature - 32) * 5 / 9
return greeting, round(celsius, 2)
salutation = "Good morning" if is_morning else "Good evening"
greeting = "%s %s. It is %s degrees today" % (salutation, name, temperature)
celsius = (temperature - 32) * 5 / 9
return greeting, round(celsius, 2)
iface = gr.Interface(
fn=greet,
inputs=["text", "checkbox", gr.inputs.Slider(0, 100)],
outputs=["text", "number"])
fn=greet,
inputs=["text", "checkbox", gr.inputs.Slider(0, 100)],
outputs=["text", "number"],
)
if __name__ == "__main__":
iface.launch()
iface.launch()

View File

@@ -1,21 +1,28 @@
import gradio as gr
import tensorflow as tf
import requests
import tensorflow as tf
inception_net = tf.keras.applications.MobileNetV2() # load the model
import gradio as gr
inception_net = tf.keras.applications.MobileNetV2() # load the model
# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
def classify_image(inp):
inp = inp.reshape((-1, 224, 224, 3))
inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
prediction = inception_net.predict(inp).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)}
inp = inp.reshape((-1, 224, 224, 3))
inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
prediction = inception_net.predict(inp).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)}
image = gr.inputs.Image(shape=(224, 224))
label = gr.outputs.Label(num_top_classes=3)
gr.Interface(fn=classify_image, inputs=image, outputs=label, examples=[
["images/cheetah1.jpg"], ["images/lion.jpg"]]).launch()
gr.Interface(
fn=classify_image,
inputs=image,
outputs=label,
examples=[["images/cheetah1.jpg"], ["images/lion.jpg"]],
).launch()

View File

@@ -1,21 +1,24 @@
import torch
import requests
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
import gradio as gr
model = torch.hub.load("pytorch/vision:v0.6.0", "resnet18", pretrained=True).eval()
# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
def predict(inp):
inp = Image.fromarray(inp.astype('uint8'), 'RGB')
inp = transforms.ToTensor()(inp).unsqueeze(0)
with torch.no_grad():
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
return {labels[i]: float(prediction[i]) for i in range(1000)}
inp = Image.fromarray(inp.astype("uint8"), "RGB")
inp = transforms.ToTensor()(inp).unsqueeze(0)
with torch.no_grad():
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
return {labels[i]: float(prediction[i]) for i in range(1000)}
inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=3)

View File

@@ -1,20 +1,25 @@
import gradio as gr
import tensorflow as tf
import requests
import tensorflow as tf
inception_net = tf.keras.applications.MobileNetV2() # load the model
import gradio as gr
inception_net = tf.keras.applications.MobileNetV2() # load the model
# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
def classify_image(inp):
inp = inp.reshape((-1, 224, 224, 3))
inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
prediction = inception_net.predict(inp).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)}
inp = inp.reshape((-1, 224, 224, 3))
inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
prediction = inception_net.predict(inp).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)}
image = gr.inputs.Image(shape=(224, 224))
label = gr.outputs.Label(num_top_classes=3)
gr.Interface(fn=classify_image, inputs=image, outputs=label, interpretation="default").launch()
gr.Interface(
fn=classify_image, inputs=image, outputs=label, interpretation="default"
).launch()

View File

@@ -1,8 +1,10 @@
import gradio as gr
def image_mod(image):
return image.rotate(45)
iface = gr.Interface(image_mod, gr.inputs.Image(type="pil"), "image")
if __name__ == "__main__":
iface.launch()

View File

@@ -1,34 +1,84 @@
import gradio as gr
import numpy as np
import json
import numpy as np
import gradio as gr
CHOICES = ["foo", "bar", "baz"]
JSONOBJ = """{"items":{"item":[{"id": "0001","type": null,"is_good": false,"ppu": 0.55,"batters":{"batter":[{ "id": "1001", "type": "Regular" },{ "id": "1002", "type": "Chocolate" },{ "id": "1003", "type": "Blueberry" },{ "id": "1004", "type": "Devil's Food" }]},"topping":[{ "id": "5001", "type": "None" },{ "id": "5002", "type": "Glazed" },{ "id": "5005", "type": "Sugar" },{ "id": "5007", "type": "Powdered Sugar" },{ "id": "5006", "type": "Chocolate with Sprinkles" },{ "id": "5003", "type": "Chocolate" },{ "id": "5004", "type": "Maple" }]}]}}"""
def fn(text1, text2, num, slider1, slider2, single_checkbox,
checkboxes, radio, dropdown, im1, im2, im3, im4, video, audio1,
audio2, file, df1, df2):
def fn(
text1,
text2,
num,
slider1,
slider2,
single_checkbox,
checkboxes,
radio,
dropdown,
im1,
im2,
im3,
im4,
video,
audio1,
audio2,
file,
df1,
df2,
):
return (
(text1 if single_checkbox else text2) +
", selected:" + ", ".join(checkboxes), # Text
(text1 if single_checkbox else text2)
+ ", selected:"
+ ", ".join(checkboxes), # Text
{
"positive": num / (num + slider1 + slider2),
"negative": slider1 / (num + slider1 + slider2),
"neutral": slider2 / (num + slider1 + slider2),
}, # Label
(audio1[0], np.flipud(audio1[1])) if audio1 is not None else "files/cantina.wav", # Audio
(audio1[0], np.flipud(audio1[1]))
if audio1 is not None
else "files/cantina.wav", # Audio
np.flipud(im1) if im1 is not None else "files/cheetah1.jpg", # Image
video if video is not None else "files/world.mp4", # Video
[("The", "art"), ("quick brown", "adj"), ("fox", "nn"), ("jumped", "vrb"), ("testing testing testing", None), ("over", "prp"), ("the", "art"), ("testing", None), ("lazy", "adj"), ("dogs", "nn"), (".", "punc")] + [(f"test {x}", f"test {x}") for x in range(10)], # HighlightedText
[
("The", "art"),
("quick brown", "adj"),
("fox", "nn"),
("jumped", "vrb"),
("testing testing testing", None),
("over", "prp"),
("the", "art"),
("testing", None),
("lazy", "adj"),
("dogs", "nn"),
(".", "punc"),
]
+ [(f"test {x}", f"test {x}") for x in range(10)], # HighlightedText
# [("The testing testing testing", None), ("quick brown", 0.2), ("fox", 1), ("jumped", -1), ("testing testing testing", 0), ("over", 0), ("the", 0), ("testing", 0), ("lazy", 1), ("dogs", 0), (".", 1)] + [(f"test {x}", x/10) for x in range(-10, 10)], # HighlightedText
[("The testing testing testing", None), ("over", 0.6), ("the", 0.2), ("testing", None), ("lazy", -.1), ("dogs", 0.4), (".", 0)] + [(f"test", x/10) for x in range(-10, 10)], # HighlightedText
[
("The testing testing testing", None),
("over", 0.6),
("the", 0.2),
("testing", None),
("lazy", -0.1),
("dogs", 0.4),
(".", 0),
]
+ [(f"test", x / 10) for x in range(-10, 10)], # HighlightedText
json.loads(JSONOBJ), # JSON
"<button style='background-color: red'>Click Me: " + radio + "</button>", # HTML
"<button style='background-color: red'>Click Me: "
+ radio
+ "</button>", # HTML
"files/titanic.csv",
df1, # Dataframe
np.random.randint(0, 10, (4,4)), # Dataframe
[im for im in [im1, im2, im3, im4, "files/cheetah1.jpg"] if im is not None], # Carousel
df2 # Timeseries
np.random.randint(0, 10, (4, 4)), # Dataframe
[
im for im in [im1, im2, im3, im4, "files/cheetah1.jpg"] if im is not None
], # Carousel
df2, # Timeseries
)
@@ -41,7 +91,9 @@ iface = gr.Interface(
gr.inputs.Slider(minimum=10, maximum=20, default=15, label="Slider: 10 - 20"),
gr.inputs.Slider(maximum=20, step=0.04, label="Slider: step @ 0.04"),
gr.inputs.Checkbox(label="Checkbox"),
gr.inputs.CheckboxGroup(label="CheckboxGroup", choices=CHOICES, default=CHOICES[0:2]),
gr.inputs.CheckboxGroup(
label="CheckboxGroup", choices=CHOICES, default=CHOICES[0:2]
),
gr.inputs.Radio(label="Radio", choices=CHOICES, default=CHOICES[2]),
gr.inputs.Dropdown(label="Dropdown", choices=CHOICES),
gr.inputs.Image(label="Image", optional=True),
@@ -69,15 +121,36 @@ iface = gr.Interface(
gr.outputs.Dataframe(label="Dataframe"),
gr.outputs.Dataframe(label="Numpy", type="numpy"),
gr.outputs.Carousel("image", label="Carousel"),
gr.outputs.Timeseries(x="time", y=["price", "value"], label="Timeseries")
gr.outputs.Timeseries(x="time", y=["price", "value"], label="Timeseries"),
],
examples=[
["the quick brown fox", "jumps over the lazy dog", 10, 12, 4, True, ["foo", "baz"], "baz", "bar", "files/cheetah1.jpg", "files/cheetah1.jpg", "files/cheetah1.jpg", "files/cheetah1.jpg", "files/world.mp4", "files/cantina.wav", "files/cantina.wav","files/titanic.csv", [[1,2,3],[3,4,5]], "files/time.csv"]
] * 3,
[
"the quick brown fox",
"jumps over the lazy dog",
10,
12,
4,
True,
["foo", "baz"],
"baz",
"bar",
"files/cheetah1.jpg",
"files/cheetah1.jpg",
"files/cheetah1.jpg",
"files/cheetah1.jpg",
"files/world.mp4",
"files/cantina.wav",
"files/cantina.wav",
"files/titanic.csv",
[[1, 2, 3], [3, 4, 5]],
"files/time.csv",
]
]
* 3,
theme="huggingface",
title="Kitchen Sink",
description="Try out all the components!",
article="Learn more about [Gradio](http://gradio.app)"
article="Learn more about [Gradio](http://gradio.app)",
)
if __name__ == "__main__":

View File

@@ -1,14 +1,17 @@
import gradio as gr
def longest_word(text):
words = text.split(" ")
lengths = [len(word) for word in words]
return max(lengths)
ex = "The quick brown fox jumped over the lazy dog."
iface = gr.Interface(longest_word, "textbox", "label",
interpretation="default", examples=[[ex]])
iface = gr.Interface(
longest_word, "textbox", "label", interpretation="default", examples=[[ex]]
)
iface.test_launch()

View File

@@ -1,28 +1,32 @@
import gradio as gr
import numpy as np
from scipy.fftpack import fft
import matplotlib.pyplot as plt
from math import log2, pow
import matplotlib.pyplot as plt
import numpy as np
from scipy.fftpack import fft
import gradio as gr
A4 = 440
C0 = A4*pow(2, -4.75)
C0 = A4 * pow(2, -4.75)
name = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
def get_pitch(freq):
h = round(12*log2(freq/C0))
h = round(12 * log2(freq / C0))
n = h % 12
return name[n]
def main_note(audio):
rate, y = audio
if len(y.shape) == 2:
y = y.T[0]
N = len(y)
T = 1.0 / rate
x = np.linspace(0.0, N*T, N)
x = np.linspace(0.0, N * T, N)
yf = fft(y)
yf2 = 2.0/N * np.abs(yf[0:N//2])
xf = np.linspace(0.0, 1.0/(2.0*T), N//2)
yf2 = 2.0 / N * np.abs(yf[0 : N // 2])
xf = np.linspace(0.0, 1.0 / (2.0 * T), N // 2)
volume_per_pitch = {}
total_volume = np.sum(yf2)
@@ -35,15 +39,17 @@ def main_note(audio):
volume_per_pitch[pitch] += 1.0 * volume / total_volume
return volume_per_pitch
iface = gr.Interface(
main_note,
"audio",
main_note,
"audio",
gr.outputs.Label(num_top_classes=4),
examples=[
["audio/recording1.wav"],
["audio/cantina.wav"],
],
interpretation="default")
interpretation="default",
)
if __name__ == "__main__":
iface.launch()

View File

@@ -1,6 +1,8 @@
import gradio as gr
import numpy as np
import gradio as gr
def transpose(matrix):
return matrix.T
@@ -10,12 +12,12 @@ iface = gr.Interface(
gr.inputs.Dataframe(type="numpy", datatype="number", row_count=5, col_count=3),
"numpy",
examples=[
[np.zeros((3,3)).tolist()],
[np.ones((2,2)).tolist()],
[np.random.randint(0, 10, (3,10)).tolist()],
[np.random.randint(0, 10, (10,3)).tolist()],
[np.random.randint(0, 10, (10,10)).tolist()],
]
[np.zeros((3, 3)).tolist()],
[np.ones((2, 2)).tolist()],
[np.random.randint(0, 10, (3, 10)).tolist()],
[np.random.randint(0, 10, (10, 3)).tolist()],
[np.random.randint(0, 10, (10, 10)).tolist()],
],
)
iface.test_launch()

View File

@@ -1,14 +1,17 @@
import gradio as gr
from math import sqrt
import matplotlib.pyplot as plt
import numpy as np
from math import sqrt
import gradio as gr
def outbreak(r, month, countries, social_distancing):
months = ["January", "February", "March", "April", "May"]
m = months.index(month)
start_day = 30 * m
final_day = 30 * (m + 1)
x = np.arange(start_day, final_day+1)
x = np.arange(start_day, final_day + 1)
day_count = x.shape[0]
pop_count = {"USA": 350, "Canada": 40, "Mexico": 300, "UK": 120}
r = sqrt(r)
@@ -23,13 +26,18 @@ def outbreak(r, month, countries, social_distancing):
plt.legend(countries)
return plt
iface = gr.Interface(outbreak,
iface = gr.Interface(
outbreak,
[
gr.inputs.Slider(1, 4, default=3.2, label="R"),
gr.inputs.Dropdown(["January", "February", "March", "April", "May"], label="Month"),
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries"),
gr.inputs.Checkbox(label="Social Distancing?"),
],
"plot")
gr.inputs.Slider(1, 4, default=3.2, label="R"),
gr.inputs.Dropdown(
["January", "February", "March", "April", "May"], label="Month"
),
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries"),
gr.inputs.Checkbox(label="Social Distancing?"),
],
"plot",
)
if __name__ == "__main__":
iface.launch()

View File

@@ -9,18 +9,16 @@ import torch
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
BertForQuestionAnswering, BertTokenizer)
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from utils import (get_answer, input_to_squad_example,
squad_examples_to_features, to_list)
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
RawResult = collections.namedtuple(
"RawResult", ["unique_id", "start_logits", "end_logits"]
)
class QA:
def __init__(self,model_path: str):
def __init__(self, model_path: str):
self.max_seq_length = 384
self.doc_stride = 128
self.do_lower_case = True
@@ -29,47 +27,71 @@ class QA:
self.max_answer_length = 30
self.model, self.tokenizer = self.load_model(model_path)
if torch.cuda.is_available():
self.device = 'cuda'
self.device = "cuda"
else:
self.device = 'cpu'
self.device = "cpu"
self.model.to(self.device)
self.model.eval()
def load_model(self,model_path: str,do_lower_case=False):
def load_model(self, model_path: str, do_lower_case=False):
config = BertConfig.from_pretrained(model_path + "/bert_config.json")
tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=do_lower_case)
model = BertForQuestionAnswering.from_pretrained(model_path, from_tf=False, config=config)
tokenizer = BertTokenizer.from_pretrained(
model_path, do_lower_case=do_lower_case
)
model = BertForQuestionAnswering.from_pretrained(
model_path, from_tf=False, config=config
)
return model, tokenizer
def predict(self,passage :str,question :str):
example = input_to_squad_example(passage,question)
features = squad_examples_to_features(example,self.tokenizer,self.max_seq_length,self.doc_stride,self.max_query_length)
def predict(self, passage: str, question: str):
example = input_to_squad_example(passage, question)
features = squad_examples_to_features(
example,
self.tokenizer,
self.max_seq_length,
self.doc_stride,
self.max_query_length,
)
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor(
[f.input_mask for f in features], dtype=torch.long
)
all_segment_ids = torch.tensor(
[f.segment_ids for f in features], dtype=torch.long
)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
all_example_index)
dataset = TensorDataset(
all_input_ids, all_input_mask, all_segment_ids, all_example_index
)
eval_sampler = SequentialSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=1)
all_results = []
for batch in eval_dataloader:
batch = tuple(t.to(self.device) for t in batch)
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2]
}
inputs = {
"input_ids": batch[0],
"attention_mask": batch[1],
"token_type_ids": batch[2],
}
example_indices = batch[3]
outputs = self.model(**inputs)
for i, example_index in enumerate(example_indices):
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
result = RawResult(unique_id = unique_id,
start_logits = to_list(outputs[0][i]),
end_logits = to_list(outputs[1][i]))
result = RawResult(
unique_id=unique_id,
start_logits=to_list(outputs[0][i]),
end_logits=to_list(outputs[1][i]),
)
all_results.append(result)
answer = get_answer(example,features,all_results,self.n_best_size,self.max_answer_length,self.do_lower_case)
answer = get_answer(
example,
features,
all_results,
self.n_best_size,
self.max_answer_length,
self.do_lower_case,
)
return answer

View File

@@ -16,13 +16,15 @@ class SquadExample(object):
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
doc_tokens,
orig_answer_text=None,
start_position=None,
end_position=None):
def __init__(
self,
qas_id,
question_text,
doc_tokens,
orig_answer_text=None,
start_position=None,
end_position=None,
):
self.qas_id = qas_id
self.question_text = question_text
self.doc_tokens = doc_tokens
@@ -36,8 +38,7 @@ class SquadExample(object):
def __repr__(self):
s = ""
s += "qas_id: %s" % (self.qas_id)
s += ", question_text: %s" % (
self.question_text)
s += ", question_text: %s" % (self.question_text)
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
@@ -45,22 +46,25 @@ class SquadExample(object):
s += ", end_position: %d" % (self.end_position)
return s
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
unique_id,
example_index,
doc_span_index,
tokens,
token_to_orig_map,
token_is_max_context,
input_ids,
input_mask,
segment_ids,
paragraph_len,
start_position=None,
end_position=None,):
def __init__(
self,
unique_id,
example_index,
doc_span_index,
tokens,
token_to_orig_map,
token_is_max_context,
input_ids,
input_mask,
segment_ids,
paragraph_len,
start_position=None,
end_position=None,
):
self.unique_id = unique_id
self.example_index = example_index
self.doc_span_index = doc_span_index
@@ -74,6 +78,7 @@ class InputFeatures(object):
self.start_position = start_position
self.end_position = end_position
def input_to_squad_example(passage, question):
"""Convert input passage and question into a SquadExample."""
@@ -109,10 +114,12 @@ def input_to_squad_example(passage, question):
doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text,
start_position=start_position,
end_position=end_position)
end_position=end_position,
)
return example
def _check_is_max_context(doc_spans, cur_span_index, position):
"""Check if this is the 'max context' doc span for the token."""
@@ -149,12 +156,23 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
return cur_span_index == best_span_index
def squad_examples_to_features(example, tokenizer, max_seq_length,
doc_stride, max_query_length,cls_token_at_end=False,
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
sequence_a_segment_id=0, sequence_b_segment_id=1,
cls_token_segment_id=0, pad_token_segment_id=0,
mask_padding_with_zero=True):
def squad_examples_to_features(
example,
tokenizer,
max_seq_length,
doc_stride,
max_query_length,
cls_token_at_end=False,
cls_token="[CLS]",
sep_token="[SEP]",
pad_token=0,
sequence_a_segment_id=0,
sequence_b_segment_id=1,
cls_token_segment_id=0,
pad_token_segment_id=0,
mask_padding_with_zero=True,
):
"""Loads a data file into a list of `InputBatch`s."""
unique_id = 1000000000
@@ -188,7 +206,8 @@ def squad_examples_to_features(example, tokenizer, max_seq_length,
# To deal with this we do a sliding window approach, where we take chunks
# of the up to our max length with a stride of `doc_stride`.
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
"DocSpan", ["start", "length"])
"DocSpan", ["start", "length"]
)
doc_spans = []
start_offset = 0
while start_offset < len(all_doc_tokens):
@@ -225,8 +244,9 @@ def squad_examples_to_features(example, tokenizer, max_seq_length,
split_token_index = doc_span.start + i
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
split_token_index)
is_max_context = _check_is_max_context(
doc_spans, doc_span_index, split_token_index
)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(sequence_b_segment_id)
@@ -273,14 +293,18 @@ def squad_examples_to_features(example, tokenizer, max_seq_length,
segment_ids=segment_ids,
paragraph_len=paragraph_len,
start_position=start_position,
end_position=end_position))
end_position=end_position,
)
)
unique_id += 1
return features
def to_list(tensor):
return tensor.detach().cpu().tolist()
def _get_best_indexes(logits, n_best_size):
"""Get the n-best logits from a list."""
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
@@ -292,7 +316,11 @@ def _get_best_indexes(logits, n_best_size):
best_indexes.append(index_and_score[i][0])
return best_indexes
RawResult = collections.namedtuple("RawResult",["unique_id", "start_logits", "end_logits"])
RawResult = collections.namedtuple(
"RawResult", ["unique_id", "start_logits", "end_logits"]
)
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
"""Project the tokenized prediction back to the original text."""
@@ -376,9 +404,10 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
if orig_end_position is None:
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
return output_text
def _compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
@@ -401,17 +430,22 @@ def _compute_softmax(scores):
probs.append(score / total_sum)
return probs
def get_answer(example, features, all_results, n_best_size,
max_answer_length, do_lower_case):
def get_answer(
example, features, all_results, n_best_size, max_answer_length, do_lower_case
):
example_index_to_features = collections.defaultdict(list)
for feature in features:
example_index_to_features[feature.example_index].append(feature)
unique_id_to_result = {}
for result in all_results:
unique_id_to_result[result.unique_id] = result
_PrelimPrediction = collections.namedtuple( "PrelimPrediction",["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
_PrelimPrediction = collections.namedtuple(
"PrelimPrediction",
["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
)
example_index = 0
features = example_index_to_features[example_index]
@@ -448,10 +482,16 @@ def get_answer(example, features, all_results, n_best_size,
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
prelim_predictions = sorted(prelim_predictions,key=lambda x: (x.start_logit + x.end_logit),reverse=True)
_NbestPrediction = collections.namedtuple("NbestPrediction",
["text", "start_logit", "end_logit","start_index","end_index"])
end_logit=result.end_logits[end_index],
)
)
prelim_predictions = sorted(
prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True
)
_NbestPrediction = collections.namedtuple(
"NbestPrediction",
["text", "start_logit", "end_logit", "start_index", "end_index"],
)
seen_predictions = {}
nbest = []
for pred in prelim_predictions:
@@ -461,10 +501,10 @@ def get_answer(example, features, all_results, n_best_size,
orig_doc_start = -1
orig_doc_end = -1
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off.
@@ -476,7 +516,7 @@ def get_answer(example, features, all_results, n_best_size,
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text,do_lower_case)
final_text = get_final_text(tok_text, orig_text, do_lower_case)
if final_text in seen_predictions:
continue
@@ -491,11 +531,20 @@ def get_answer(example, features, all_results, n_best_size,
start_logit=pred.start_logit,
end_logit=pred.end_logit,
start_index=orig_doc_start,
end_index=orig_doc_end))
end_index=orig_doc_end,
)
)
if not nbest:
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0,start_index=-1,
end_index=-1))
nbest.append(
_NbestPrediction(
text="empty",
start_logit=0.0,
end_logit=0.0,
start_index=-1,
end_index=-1,
)
)
assert len(nbest) >= 1
@@ -504,11 +553,12 @@ def get_answer(example, features, all_results, n_best_size,
total_scores.append(entry.start_logit + entry.end_logit)
probs = _compute_softmax(total_scores)
answer = {"answer" : nbest[0].text,
"start" : nbest[0].start_index,
"end" : nbest[0].end_index,
"confidence" : probs[0],
"document" : example.doc_tokens
}
return answer
answer = {
"answer": nbest[0].text,
"start": nbest[0].start_index,
"end": nbest[0].end_index,
"confidence": probs[0],
"document": example.doc_tokens,
}
return answer

View File

@@ -1,13 +1,24 @@
import gradio as gr
examples = [
["The Amazon rainforest is a moist broadleaf forest that covers most of the Amazon basin of South America",
"Which continent is the Amazon rainforest in?"]
[
"The Amazon rainforest is a moist broadleaf forest that covers most of the Amazon basin of South America",
"Which continent is the Amazon rainforest in?",
]
]
gr.Interface.load("huggingface/deepset/roberta-base-squad2",
inputs=[gr.inputs.Textbox(lines=5, label="Context", placeholder="Type a sentence or paragraph here."),
gr.inputs.Textbox(lines=2, label="Question", placeholder="Ask a question based on the context.")],
outputs=[gr.outputs.Textbox(label="Answer"),
gr.outputs.Label(label="Probability")],
examples=examples).launch()
gr.Interface.load(
"huggingface/deepset/roberta-base-squad2",
inputs=[
gr.inputs.Textbox(
lines=5, label="Context", placeholder="Type a sentence or paragraph here."
),
gr.inputs.Textbox(
lines=2,
label="Question",
placeholder="Ask a question based on the context.",
),
],
outputs=[gr.outputs.Textbox(label="Answer"), gr.outputs.Label(label="Probability")],
examples=examples,
).launch()

View File

@@ -1,10 +1,13 @@
import gradio as gr
import numpy as np
import gradio as gr
def reverse_audio(audio):
sr, data = audio
return (sr, np.flipud(data))
iface = gr.Interface(reverse_audio, "microphone", "audio", examples="audio")
if __name__ == "__main__":

View File

@@ -1,32 +1,37 @@
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gradio as gr
def sales_projections(employee_data):
sales_data = employee_data.iloc[:, 1:4].astype("int").to_numpy()
regression_values = np.apply_along_axis(lambda row:
np.array(np.poly1d(np.polyfit([0,1,2], row, 2))), 0, sales_data)
projected_months = np.repeat(np.expand_dims(
np.arange(3,12), 0), len(sales_data), axis=0)
projected_values = np.array([
month * month * regression[0] + month * regression[1] + regression[2]
for month, regression in zip(projected_months, regression_values)])
regression_values = np.apply_along_axis(
lambda row: np.array(np.poly1d(np.polyfit([0, 1, 2], row, 2))), 0, sales_data
)
projected_months = np.repeat(
np.expand_dims(np.arange(3, 12), 0), len(sales_data), axis=0
)
projected_values = np.array(
[
month * month * regression[0] + month * regression[1] + regression[2]
for month, regression in zip(projected_months, regression_values)
]
)
plt.plot(projected_values.T)
plt.legend(employee_data["Name"])
return employee_data, plt.gcf(), regression_values
iface = gr.Interface(sales_projections,
iface = gr.Interface(
sales_projections,
gr.inputs.Dataframe(
headers=["Name", "Jan Sales", "Feb Sales", "Mar Sales"],
default=[["Jon", 12, 14, 18], ["Alice", 14, 17, 2], ["Sana", 8, 9.5, 12]]
default=[["Jon", 12, 14, 18], ["Alice", 14, 17, 2], ["Sana", 8, 9.5, 12]],
),
[
"dataframe",
"plot",
"numpy"
],
description="Enter sales figures for employees to predict sales trajectory over year."
["dataframe", "plot", "numpy"],
description="Enter sales figures for employees to predict sales trajectory over year.",
)
if __name__ == "__main__":
iface.launch()

View File

@@ -1,5 +1,6 @@
import gradio as gr
def sentence_builder(quantity, animal, place, activity_list, morning):
return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}"""

View File

@@ -1,14 +1,18 @@
import gradio as gr
import nltk
from nltk.sentiment.vader import SentimentIntensityAnalyzer
nltk.download('vader_lexicon')
import gradio as gr
nltk.download("vader_lexicon")
sid = SentimentIntensityAnalyzer()
def sentiment_analysis(text):
scores = sid.polarity_scores(text)
del scores["compound"]
return scores
iface = gr.Interface(sentiment_analysis, "textbox", "label", interpretation="default")
iface.test_launch()

View File

@@ -1,15 +1,18 @@
import gradio as gr
import numpy as np
import gradio as gr
def sepia(input_img):
sepia_filter = np.array([[.393, .769, .189],
[.349, .686, .168],
[.272, .534, .131]])
sepia_img = input_img.dot(sepia_filter.T)
sepia_img /= sepia_img.max()
return sepia_img
sepia_filter = np.array(
[[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
)
sepia_img = input_img.dot(sepia_filter.T)
sepia_img /= sepia_img.max()
return sepia_img
iface = gr.Interface(sepia, gr.inputs.Image(shape=(200, 200)), "image")
if __name__ == "__main__":
iface.launch()
iface.launch()

View File

@@ -1,14 +1,17 @@
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal
import gradio as gr
def spectrogram(audio):
sr, data = audio
if len(data.shape) == 2:
data = np.mean(data, axis=0)
frequencies, times, spectrogram_data = signal.spectrogram(data, sr, window="hamming")
frequencies, times, spectrogram_data = signal.spectrogram(
data, sr, window="hamming"
)
plt.pcolormesh(times, frequencies, np.log10(spectrogram_data))
return plt

View File

@@ -1,7 +1,8 @@
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
def stock_forecast(final_year, companies, noise, show_legend, point_style):
start_year = 2020
@@ -28,8 +29,10 @@ iface = gr.Interface(
gr.inputs.CheckboxGroup(["Google", "Microsoft", "Gradio"]),
gr.inputs.Slider(1, 100),
"checkbox",
gr.inputs.Dropdown(["cross", "line", "circle"], label="Style")],
gr.outputs.Image(plot=True, label="forecast"))
gr.inputs.Dropdown(["cross", "line", "circle"], label="Style"),
],
gr.outputs.Image(plot=True, label="forecast"),
)
iface.test_launch()
if __name__ == "__main__":

View File

@@ -1,13 +1,8 @@
import gradio as gr
def tax_calculator(income, marital_status, assets):
tax_brackets = [
(10, 0),
(25, 8),
(60, 12),
(120, 20),
(250, 30)
]
tax_brackets = [(10, 0), (25, 8), (60, 12), (120, 20), (250, 30)]
total_deductible = sum(assets[assets["Deduct"]]["Cost"])
taxable_income = income - total_deductible
@@ -15,31 +10,32 @@ def tax_calculator(income, marital_status, assets):
for bracket, rate in tax_brackets:
if taxable_income > bracket:
total_tax += (taxable_income - bracket) * rate / 100
if marital_status == "Married":
total_tax *= 0.75
elif marital_status == "Divorced":
total_tax *= 0.8
return round(total_tax)
iface = gr.Interface(
tax_calculator,
tax_calculator,
[
"number",
gr.inputs.Radio(["Single", "Married", "Divorced"]),
gr.inputs.Dataframe(
headers=["Item", "Cost", "Deduct"],
headers=["Item", "Cost", "Deduct"],
datatype=["str", "number", "bool"],
label="Assets Purchased this Year"
)
label="Assets Purchased this Year",
),
],
"number",
# interpretation="default", # Removed interpretation for dataframes
examples=[
[10000, "Married", [["Car", 5000, False], ["Laptop", 800, True]]],
[80000, "Single", [["Suit", 800, True], ["Watch", 1800, False]]],
]
],
)
if __name__ == "__main__":

View File

@@ -1,5 +1,6 @@
import spacy
from spacy import displacy
import gradio as gr
nlp = spacy.load("en_core_web_sm")
@@ -8,7 +9,11 @@ nlp = spacy.load("en_core_web_sm")
def text_analysis(text):
doc = nlp(text)
html = displacy.render(doc, style="dep", page=True)
html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
html = (
"<div style='max-width:100%; max-height:360px; overflow:auto'>"
+ html
+ "</div>"
)
pos_count = {
"char_count": len(text),
"token_count": 0,
@@ -17,20 +22,18 @@ def text_analysis(text):
for token in doc:
pos_tokens.extend([(token.text, token.pos_), (" ", None)])
return pos_tokens, pos_count, html
iface = gr.Interface(
text_analysis,
gr.inputs.Textbox(placeholder="Enter sentence here..."),
[
"highlight", "key_values", "html"
],
["highlight", "key_values", "html"],
examples=[
["What a beautiful morning for a walk!"],
["It was the best of times, it was the worst of times."],
]
],
)
iface.test_launch()

View File

@@ -1,15 +1,18 @@
import pandas as pd
import numpy as np
import sklearn
import gradio as gr
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import os
import numpy as np
import pandas as pd
import sklearn
from sklearn import preprocessing
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import gradio as gr
current_dir = os.path.dirname(os.path.realpath(__file__))
data = pd.read_csv(os.path.join(current_dir, 'files/titanic.csv'))
data = pd.read_csv(os.path.join(current_dir, "files/titanic.csv"))
def encode_age(df):
df.Age = df.Age.fillna(-0.5)
@@ -18,6 +21,7 @@ def encode_age(df):
df.Age = categories
return df
def encode_fare(df):
df.Fare = df.Fare.fillna(-0.5)
bins = (-1, 0, 8, 15, 31, 1000)
@@ -25,46 +29,67 @@ def encode_fare(df):
df.Fare = categories
return df
def encode_df(df):
df = encode_age(df)
df = encode_fare(df)
sex_mapping = {"male": 0, "female": 1}
df = df.replace({'Sex': sex_mapping})
df = df.replace({"Sex": sex_mapping})
embark_mapping = {"S": 1, "C": 2, "Q": 3}
df = df.replace({'Embarked': embark_mapping})
df = df.replace({"Embarked": embark_mapping})
df.Embarked = df.Embarked.fillna(0)
df["Company"] = 0
df.loc[(df["SibSp"] > 0), "Company"] = 1
df.loc[(df["Parch"] > 0), "Company"] = 2
df.loc[(df["SibSp"] > 0) & (df["Parch"] > 0), "Company"] = 3
df = df[["PassengerId", "Pclass", "Sex", "Age", "Fare", "Embarked", "Company", "Survived"]]
df = df[
[
"PassengerId",
"Pclass",
"Sex",
"Age",
"Fare",
"Embarked",
"Company",
"Survived",
]
]
return df
train = encode_df(data)
X_all = train.drop(['Survived', 'PassengerId'], axis=1)
y_all = train['Survived']
X_all = train.drop(["Survived", "PassengerId"], axis=1)
y_all = train["Survived"]
num_test = 0.20
X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=num_test, random_state=23)
X_train, X_test, y_train, y_test = train_test_split(
X_all, y_all, test_size=num_test, random_state=23
)
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)
def predict_survival(passenger_class, is_male, age, company, fare, embark_point):
df = pd.DataFrame.from_dict({
'Pclass': [passenger_class + 1],
'Sex': [0 if is_male else 1],
'Age': [age],
'Company': [(1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)],
'Fare': [fare],
'Embarked': [embark_point + 1]
})
df = pd.DataFrame.from_dict(
{
"Pclass": [passenger_class + 1],
"Sex": [0 if is_male else 1],
"Age": [age],
"Company": [
(1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)
],
"Fare": [fare],
"Embarked": [embark_point + 1],
}
)
df = encode_age(df)
df = encode_fare(df)
pred = clf.predict_proba(df)[0]
return {'Perishes': pred[0], 'Survives': pred[1]}
return {"Perishes": pred[0], "Survives": pred[1]}
iface = gr.Interface(
predict_survival,
@@ -72,7 +97,9 @@ iface = gr.Interface(
gr.inputs.Dropdown(["first", "second", "third"], type="index"),
"checkbox",
gr.inputs.Slider(0, 80),
gr.inputs.CheckboxGroup(["Sibling", "Child"], label="Travelling with (select all)"),
gr.inputs.CheckboxGroup(
["Sibling", "Child"], label="Travelling with (select all)"
),
gr.inputs.Number(),
gr.inputs.Radio(["S", "C", "Q"], type="index"),
],

View File

@@ -1,8 +1,10 @@
import gradio as gr
def video_flip(video):
return video
iface = gr.Interface(video_flip, gr.inputs.Video(source="webcam"), "playable_video")
if __name__ == "__main__":

View File

@@ -1,10 +1,12 @@
import gradio as gr
import numpy as np
import gradio as gr
def snap(image):
return np.flipud(image)
iface = gr.Interface(snap, gr.inputs.Image(source="webcam", tool=None), "image")
if __name__ == "__main__":
iface.launch()

View File

@@ -1,15 +1,19 @@
import gradio as gr
from zipfile import ZipFile
import gradio as gr
def zip_to_json(file_obj):
files = []
with ZipFile(file_obj.name) as zfile:
for zinfo in zfile.infolist():
files.append({
"name": zinfo.filename,
"file_size": zinfo.file_size,
"compressed_size": zinfo.compress_size,
})
files.append(
{
"name": zinfo.filename,
"file_size": zinfo.file_size,
"compressed_size": zinfo.compress_size,
}
)
return files

View File

@@ -1,19 +1,22 @@
import gradio as gr
from zipfile import ZipFile
import gradio as gr
def zip_two_files(file1, file2):
with ZipFile('tmp.zip', 'w') as zipObj:
with ZipFile("tmp.zip", "w") as zipObj:
zipObj.write(file1.name, "file1")
zipObj.write(file2.name, "file2")
return "tmp.zip"
iface = gr.Interface(
zip_two_files,
["file", "file"],
zip_two_files,
["file", "file"],
"file",
examples=[
["files/titanic.csv", "files/titanic.csv"],
]
],
)
if __name__ == "__main__":

View File

@@ -1,4 +1,4 @@
Metadata-Version: 1.0
Metadata-Version: 2.1
Name: gradio
Version: 2.7.1
Summary: Python library for easily interacting with trained machine learning models
@@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq
Author-email: team@gradio.app
License: Apache License 2.0
Description: UNKNOWN
Keywords: machine learning,visualization,reproducibility
Platform: UNKNOWN
License-File: LICENSE
UNKNOWN

View File

@@ -1,5 +1,5 @@
aiohttp
analytics-python
aiohttp
fastapi
ffmpy
markdown2
@@ -9,7 +9,7 @@ pandas
paramiko
pillow
pycryptodome
pydub
python-multipart
pydub
requests
uvicorn

View File

@@ -1,8 +1,9 @@
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.
from gradio.app import get_state, set_state
from gradio.mix import *
from gradio.flagging import *
import pkg_resources
from gradio.app import get_state, set_state
from gradio.flagging import *
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.
from gradio.mix import *
current_pkg_version = pkg_resources.require("gradio")[0].version
__version__ = current_pkg_version
__version__ = current_pkg_version
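
The reordering in this file is isort at work: pkg_resources (third-party) now precedes the first-party gradio imports, and each group is alphabetized. A minimal illustration of the default section order it enforces (assuming the default isort profile; no isort configuration is visible in this diff):

    # Illustration of isort's default section order (an assumed default
    # profile; this commit's isort configuration is not shown in the diff).
    from __future__ import annotations  # 1. __future__ imports

    import os  # 2. standard library, alphabetized
    import secrets

    import pkg_resources  # 3. third-party packages
    import uvicorn

    from gradio import queueing, utils  # 4. first-party / local package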

View File

@@ -1,34 +1,36 @@
"""Implements a FastAPI server to run the gradio interface."""
from __future__ import annotations
from fastapi import FastAPI, Request, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
import inspect
import os
import posixpath
import pkg_resources
import secrets
from starlette.responses import RedirectResponse
import traceback
from typing import List, Optional, Type
import urllib
import uvicorn
from typing import List, Optional, Type
from gradio import utils, queueing
import pkg_resources
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from starlette.responses import RedirectResponse
from gradio import queueing, utils
from gradio.process_examples import load_from_cache, process_example
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
STATIC_PATH_LIB = pkg_resources.resource_filename(
"gradio", "templates/frontend/static")
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/static")
VERSION_FILE = pkg_resources.resource_filename("gradio", "version.txt")
with open(VERSION_FILE) as version_file:
VERSION = version_file.read()
GRADIO_STATIC_ROOT = "https://gradio.s3-us-west-2.amazonaws.com/{}/static/".format(VERSION)
GRADIO_STATIC_ROOT = "https://gradio.s3-us-west-2.amazonaws.com/{}/static/".format(
VERSION
)
app = FastAPI()
app.add_middleware(
@@ -43,40 +45,43 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
###########
# Auth
# Auth
###########
@app.get('/user')
@app.get('/user/')
@app.get("/user")
@app.get("/user/")
def get_current_user(request: Request) -> Optional[str]:
token = request.cookies.get('access-token')
token = request.cookies.get("access-token")
return app.tokens.get(token)
@app.get('/login_check')
@app.get('/login_check/')
@app.get("/login_check")
@app.get("/login_check/")
def login_check(user: str = Depends(get_current_user)):
if app.auth is None or not(user is None):
if app.auth is None or not (user is None):
return
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
)
@app.get('/token')
@app.get('/token/')
@app.get("/token")
@app.get("/token/")
def get_token(request: Request) -> Optional[str]:
token = request.cookies.get('access-token')
token = request.cookies.get("access-token")
return {"token": token, "user": app.tokens.get(token)}
@app.post('/login')
@app.post('/login/')
@app.post("/login")
@app.post("/login/")
def login(form_data: OAuth2PasswordRequestForm = Depends()):
username, password = form_data.username, form_data.password
if ((not callable(app.auth) and username in app.auth
and app.auth[username] == password)
or (callable(app.auth) and app.auth.__call__(username, password))):
if (
not callable(app.auth)
and username in app.auth
and app.auth[username] == password
) or (callable(app.auth) and app.auth.__call__(username, password)):
token = secrets.token_urlsafe(16)
app.tokens[token] = username
response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
@@ -91,21 +96,19 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
###############
@app.head('/', response_class=HTMLResponse)
@app.get('/', response_class=HTMLResponse)
@app.head("/", response_class=HTMLResponse)
@app.get("/", response_class=HTMLResponse)
def main(request: Request, user: str = Depends(get_current_user)):
if app.auth is None or not(user is None):
if app.auth is None or not (user is None):
config = app.interface.config
else:
config = {"auth_required": True,
"auth_message": app.interface.auth_message}
config = {"auth_required": True, "auth_message": app.interface.auth_message}
return templates.TemplateResponse(
"frontend/index.html",
{"request": request, "config": config}
"frontend/index.html", {"request": request, "config": config}
)
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config():
@ -135,8 +138,9 @@ def api_docs(request: Request):
if app.interface.examples is not None:
sample_inputs = app.interface.examples[0]
else:
sample_inputs = [inp.generate_sample()
for inp in app.interface.input_components]
sample_inputs = [
inp.generate_sample() for inp in app.interface.input_components
]
docs = {
"inputs": input_names,
"outputs": output_names,
@ -150,90 +154,90 @@ def api_docs(request: Request):
"output_types_doc": output_types_doc,
"sample_inputs": sample_inputs,
"auth": app.interface.auth,
"local_login_url": urllib.parse.urljoin(
app.interface.local_url, "login"),
"local_api_url": urllib.parse.urljoin(
app.interface.local_url, "api/predict")
"local_login_url": urllib.parse.urljoin(app.interface.local_url, "login"),
"local_api_url": urllib.parse.urljoin(app.interface.local_url, "api/predict"),
}
return templates.TemplateResponse(
"api_docs.html",
{"request": request, **docs}
)
return templates.TemplateResponse("api_docs.html", {"request": request, **docs})
@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(
request: Request,
username: str = Depends(get_current_user)
):
async def predict(request: Request, username: str = Depends(get_current_user)):
body = await request.json()
flag_index = None
if body.get("example_id") != None:
example_id = body["example_id"]
if app.interface.cache_examples:
prediction = await run_in_threadpool(
load_from_cache, app.interface, example_id)
load_from_cache, app.interface, example_id
)
durations = None
else:
prediction, durations = await run_in_threadpool(
process_example, app.interface, example_id)
process_example, app.interface, example_id
)
else:
raw_input = body["data"]
if app.interface.show_error:
try:
prediction, durations = await run_in_threadpool(
app.interface.process, raw_input)
app.interface.process, raw_input
)
except BaseException as error:
traceback.print_exc()
return JSONResponse(content={"error": str(error)},
status_code=500)
return JSONResponse(content={"error": str(error)}, status_code=500)
else:
prediction, durations = await run_in_threadpool(
app.interface.process, raw_input)
app.interface.process, raw_input
)
if app.interface.allow_flagging == "auto":
flag_index = await run_in_threadpool(
app.interface.flagging_callback.flag,
app.interface, raw_input, prediction,
flag_option="" if app.interface.flagging_options else None,
username=username)
app.interface,
raw_input,
prediction,
flag_option="" if app.interface.flagging_options else None,
username=username,
)
output = {
"data": prediction,
"durations": durations,
"data": prediction,
"durations": durations,
"avg_durations": app.interface.config.get("avg_durations"),
"flag_index": flag_index
}
"flag_index": flag_index,
}
return output
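A hedged client sketch for this endpoint (the URL is an assumption; the {"data": [...]} payload shape mirrors what get_spaces_interface posts later in this commit, and the response keys are the ones assembled just above):

import requests

resp = requests.post(
    "http://127.0.0.1:7860/api/predict/",
    json={"data": ["some input for the first component"]},
)
result = resp.json()
print(result["data"], result["durations"])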
@app.post("/api/flag/", dependencies=[Depends(login_check)])
async def flag(
request: Request,
username: str = Depends(get_current_user)
):
async def flag(request: Request, username: str = Depends(get_current_user)):
if app.interface.analytics_enabled:
await utils.log_feature_analytics(app.interface.ip_address, 'flag')
await utils.log_feature_analytics(app.interface.ip_address, "flag")
body = await request.json()
data = body['data']
data = body["data"]
await run_in_threadpool(
app.interface.flagging_callback.flag,
app.interface, data['input_data'], data['output_data'],
flag_option=data.get("flag_option"), flag_index=data.get("flag_index"),
username=username)
return {'success': True}
app.interface,
data["input_data"],
data["output_data"],
flag_option=data.get("flag_option"),
flag_index=data.get("flag_index"),
username=username,
)
return {"success": True}
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
async def interpret(request: Request):
if app.interface.analytics_enabled:
await utils.log_feature_analytics(app.interface.ip_address, 'interpret')
await utils.log_feature_analytics(app.interface.ip_address, "interpret")
body = await request.json()
raw_input = body["data"]
interpretation_scores, alternative_outputs = await run_in_threadpool(
app.interface.interpret, raw_input)
app.interface.interpret, raw_input
)
return {
"interpretation_scores": interpretation_scores,
"alternative_outputs": alternative_outputs
"alternative_outputs": alternative_outputs,
}
@ -249,7 +253,7 @@ async def queue_push(request: Request):
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
async def queue_status(request: Request):
body = await request.json()
hash = body['hash']
hash = body["hash"]
status, data = queueing.get_status(hash)
return {"status": status, "data": data}
@ -277,7 +281,7 @@ def safe_join(directory: str, path: str) -> Optional[str]:
):
return None
return posixpath.join(directory, filename)
return posixpath.join(directory, filename)
def get_types(cls_set: List[Type], component: str):
@ -302,26 +306,30 @@ def get_state():
raise DeprecationWarning(
"This function is deprecated. To create stateful demos, pass 'state'"
"as both an input and output component. Please see the getting started"
"guide for more information.")
"guide for more information."
)
def set_state(*args):
raise DeprecationWarning(
"This function is deprecated. To create stateful demos, pass 'state'"
"as both an input and output component. Please see the getting started"
"guide for more information.")
"guide for more information."
)
if __name__ == '__main__': # Run directly for debugging: python app.py
from gradio import Interface
app.interface = Interface(lambda x: "Hello, " + x, "text", "text",
analytics_enabled=False)
if __name__ == "__main__": # Run directly for debugging: python app.py
from gradio import Interface
app.interface = Interface(
lambda x: "Hello, " + x, "text", "text", analytics_enabled=False
)
app.interface.favicon_path = None
app.interface.config = app.interface.get_config_file()
app.interface.show_error = True
app.interface.flagging_callback.setup(app.interface.flagging_dir)
app.tokens = {}
auth = True
if auth:
app.interface.auth = ("a", "b")
View File
@ -1,8 +1,10 @@
import os
import shutil
from gradio import processing_utils
class Component():
class Component:
"""
A class for defining the methods that all gradio input and output components should have.
"""
@ -15,16 +17,13 @@ class Component():
return self.__repr__()
def __repr__(self):
return "{}(label=\"{}\")".format(type(self).__name__, self.label)
return '{}(label="{}")'.format(type(self).__name__, self.label)
def get_template_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {
"name": self.__class__.__name__.lower(),
"label": self.label
}
return {"name": self.__class__.__name__.lower(), "label": self.label}
@classmethod
def get_shortcut_implementations(cls):
@ -60,13 +59,15 @@ class Component():
new_file_name = str(file_index)
if "." in old_file_name:
uploaded_format = old_file_name.split(".")[-1].lower()
new_file_name += "." + uploaded_format
new_file_name += "." + uploaded_format
file.close()
shutil.move(old_file_name, os.path.join(dir, label, new_file_name))
return label + "/" + new_file_name
def restore_flagged_file(self, dir, file, encryption_key):
data = processing_utils.encode_file_to_base64(os.path.join(dir, file), encryption_key=encryption_key)
data = processing_utils.encode_file_to_base64(
os.path.join(dir, file), encryption_key=encryption_key
)
return {"name": file, "data": data}
@classmethod
View File
@ -1,11 +1,13 @@
from Crypto import Random
from Crypto.Cipher import AES
from Crypto.Hash import SHA256
from Crypto import Random
def get_key(password):
key = SHA256.new(password.encode()).digest()
return key
def encrypt(key, source):
IV = Random.new().read(AES.block_size) # generate IV
encryptor = AES.new(key, AES.MODE_CBC, IV)
@ -14,11 +16,14 @@ def encrypt(key, source):
data = IV + encryptor.encrypt(source) # store the IV at the beginning and encrypt
return data
def decrypt(key, source):
IV = source[:AES.block_size] # extract the IV from the beginning
IV = source[: AES.block_size] # extract the IV from the beginning
decryptor = AES.new(key, AES.MODE_CBC, IV)
data = decryptor.decrypt(source[AES.block_size:]) # decrypt
data = decryptor.decrypt(source[AES.block_size :]) # decrypt
padding = data[-1] # pick the padding value from the end; Python 2.x: ord(data[-1])
if data[-padding:] != bytes([padding]) * padding: # Python 2.x: chr(padding) * padding
if (
data[-padding:] != bytes([padding]) * padding
): # Python 2.x: chr(padding) * padding
raise ValueError("Invalid padding...")
return data[:-padding] # remove the padding
return data[:-padding] # remove the padding
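A round-trip sketch for these helpers (not from the commit). encrypt() assumes its input is already block-aligned; the hunk where gradio pads the plaintext is elided here, so the pad() helper below is a hypothetical stand-in that matches the padding check in decrypt():

from Crypto.Cipher import AES

from gradio import encryptor

def pad(plaintext: bytes) -> bytes:
    # PKCS#7-style padding: append n bytes, each of value n
    n = AES.block_size - len(plaintext) % AES.block_size
    return plaintext + bytes([n]) * n

key = encryptor.get_key("my-password")  # 32-byte SHA-256 digest -> AES-256 key
ciphertext = encryptor.encrypt(key, pad(b"flagged,data,row"))
assert encryptor.decrypt(key, ciphertext) == b"flagged,data,row"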
View File
@ -1,9 +1,11 @@
import json
import tempfile
import requests
from gradio import inputs, outputs
import re
import base64
import json
import re
import tempfile
import requests
from gradio import inputs, outputs
def get_huggingface_interface(model_name, api_key, alias):
@ -17,56 +19,63 @@ def get_huggingface_interface(model_name, api_key, alias):
headers = {}
# Checking if model exists, and if so, it gets the pipeline
response = requests.request("GET", api_url, headers=headers)
response = requests.request("GET", api_url, headers=headers)
assert response.status_code == 200, "Invalid model name or src"
p = response.json().get('pipeline_tag')
p = response.json().get("pipeline_tag")
def encode_to_base64(r: requests.Response) -> str:
base64_repr = base64.b64encode(r.content).decode('utf-8')
base64_repr = base64.b64encode(r.content).decode("utf-8")
data_prefix = ";base64,"
if data_prefix in base64_repr:
return base64_repr
else:
content_type = r.headers.get('content-type')
content_type = r.headers.get("content-type")
return "data:{};base64,".format(content_type) + base64_repr
pipelines = {
'audio-classification': {
"audio-classification": {
# example model: https://hf.co/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
'inputs': inputs.Audio(label="Input", source="upload",
type="filepath"),
'outputs': outputs.Label(label="Class", type="confidences"),
'preprocess': lambda i: base64.b64decode(i['data'].split(",")[1]), # convert the base64 representation to binary
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()}
"inputs": inputs.Audio(label="Input", source="upload", type="filepath"),
"outputs": outputs.Label(label="Class", type="confidences"),
"preprocess": lambda i: base64.b64decode(
i["data"].split(",")[1]
), # convert the base64 representation to binary
"postprocess": lambda r: {
i["label"].split(", ")[0]: i["score"] for i in r.json()
},
},
'automatic-speech-recognition': {
"automatic-speech-recognition": {
# example model: https://hf.co/jonatasgrosman/wav2vec2-large-xlsr-53-english
'inputs': inputs.Audio(label="Input", source="upload",
type="filepath"),
'outputs': outputs.Textbox(label="Output"),
'preprocess': lambda i: base64.b64decode(i['data'].split(",")[1]), # convert the base64 representation to binary
'postprocess': lambda r: r.json()["text"]
"inputs": inputs.Audio(label="Input", source="upload", type="filepath"),
"outputs": outputs.Textbox(label="Output"),
"preprocess": lambda i: base64.b64decode(
i["data"].split(",")[1]
), # convert the base64 representation to binary
"postprocess": lambda r: r.json()["text"],
},
'feature-extraction': {
"feature-extraction": {
# example model: hf.co/julien-c/distilbert-feature-extraction
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Dataframe(label="Output"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r.json()[0],
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Dataframe(label="Output"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0],
},
'fill-mask': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r.json()}
"fill-mask": {
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Label(label="Classification", type="confidences"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r.json()},
},
'image-classification': {
"image-classification": {
# Example: https://huggingface.co/google/vit-base-patch16-224
'inputs': inputs.Image(label="Input Image", type="filepath"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda i: base64.b64decode(i.split(",")[1]), # convert the base64 representation to binary
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()}
"inputs": inputs.Image(label="Input Image", type="filepath"),
"outputs": outputs.Label(label="Classification", type="confidences"),
"preprocess": lambda i: base64.b64decode(
i.split(",")[1]
), # convert the base64 representation to binary
"postprocess": lambda r: {
i["label"].split(", ")[0]: i["score"] for i in r.json()
},
},
# TODO: support image segmentation pipeline -- should we add a new output component type?
# 'image-segmentation': {
@ -75,165 +84,220 @@ def get_huggingface_interface(model_name, api_key, alias):
# 'outputs': outputs.Image(label="Segmentation"),
# 'preprocess': lambda i: base64.b64decode(i.split(",")[1]), # convert the base64 representation to binary
# 'postprocess': lambda x: base64.b64encode(x.json()[0]["mask"]).decode('utf-8'),
# },
# },
# TODO: also: support NER pipeline, object detection, table question answering
'question-answering': {
'inputs': [inputs.Textbox(label="Context", lines=7), inputs.Textbox(label="Question")],
'outputs': [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
'preprocess': lambda c, q: {"inputs": {"context": c, "question": q}},
'postprocess': lambda r: (r.json()["answer"], r.json()["score"]),
},
'summarization': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Summary"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r.json()[0]["summary_text"]
},
'text-classification': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
},
'text-generation': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Output"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r.json()[0]["generated_text"],
},
'text2text-generation': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Generated Text"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r.json()[0]["generated_text"]
},
'translation': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Translation"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r.json()[0]["translation_text"]
},
'zero-shot-classification': {
'inputs': [inputs.Textbox(label="Input"),
inputs.Textbox(label="Possible class names ("
"comma-separated)"),
inputs.Checkbox(label="Allow multiple true classes")],
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda i, c, m: {"inputs": i, "parameters":
{"candidate_labels": c, "multi_class": m}},
'postprocess': lambda r: {r.json()["labels"][i]: r.json()["scores"][i] for i in
range(len(r.json()["labels"]))}
},
'sentence-similarity': {
# example model: hf.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens
'inputs': [
inputs.Textbox(label="Source Sentence", default="That is a happy person"),
inputs.Textbox(lines=7, label="Sentences to compare to", placeholder="Separate each sentence by a newline"),
"question-answering": {
"inputs": [
inputs.Textbox(label="Context", lines=7),
inputs.Textbox(label="Question"),
],
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda src, sentences: {"inputs": {
"source_sentence": src,
"sentences": [s for s in sentences.splitlines() if s != ""],
}},
'postprocess': lambda r: { f"sentence {i}": v for i, v in enumerate(r.json()) },
"outputs": [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
"preprocess": lambda c, q: {"inputs": {"context": c, "question": q}},
"postprocess": lambda r: (r.json()["answer"], r.json()["score"]),
},
'text-to-speech': {
"summarization": {
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Textbox(label="Summary"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["summary_text"],
},
"text-classification": {
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Label(label="Classification", type="confidences"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: {
i["label"].split(", ")[0]: i["score"] for i in r.json()[0]
},
},
"text-generation": {
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Textbox(label="Output"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["generated_text"],
},
"text2text-generation": {
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Textbox(label="Generated Text"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["generated_text"],
},
"translation": {
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Textbox(label="Translation"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r.json()[0]["translation_text"],
},
"zero-shot-classification": {
"inputs": [
inputs.Textbox(label="Input"),
inputs.Textbox(label="Possible class names (" "comma-separated)"),
inputs.Checkbox(label="Allow multiple true classes"),
],
"outputs": outputs.Label(label="Classification", type="confidences"),
"preprocess": lambda i, c, m: {
"inputs": i,
"parameters": {"candidate_labels": c, "multi_class": m},
},
"postprocess": lambda r: {
r.json()["labels"][i]: r.json()["scores"][i]
for i in range(len(r.json()["labels"]))
},
},
"sentence-similarity": {
# example model: hf.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens
"inputs": [
inputs.Textbox(
label="Source Sentence", default="That is a happy person"
),
inputs.Textbox(
lines=7,
label="Sentences to compare to",
placeholder="Separate each sentence by a newline",
),
],
"outputs": outputs.Label(label="Classification", type="confidences"),
"preprocess": lambda src, sentences: {
"inputs": {
"source_sentence": src,
"sentences": [s for s in sentences.splitlines() if s != ""],
}
},
"postprocess": lambda r: {
f"sentence {i}": v for i, v in enumerate(r.json())
},
},
"text-to-speech": {
# example model: hf.co/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Audio(label="Audio"),
'preprocess': lambda x: {"inputs": x},
'postprocess': encode_to_base64,
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Audio(label="Audio"),
"preprocess": lambda x: {"inputs": x},
"postprocess": encode_to_base64,
},
'text-to-image': {
"text-to-image": {
# example model: hf.co/osanseviero/BigGAN-deep-128
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Image(label="Output"),
'preprocess': lambda x: {"inputs": x},
'postprocess': encode_to_base64,
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Image(label="Output"),
"preprocess": lambda x: {"inputs": x},
"postprocess": encode_to_base64,
},
}
if p is None or not(p in pipelines):
if p is None or not (p in pipelines):
raise ValueError("Unsupported pipeline type: {}".format(type(p)))
pipeline = pipelines[p]
def query_huggingface_api(*params):
# Convert to a list of input components
data = pipeline['preprocess'](*params)
if isinstance(data, dict): # HF doesn't allow additional parameters for binary files (e.g. images or audio files)
data.update({'options': {'wait_for_model': True}})
data = pipeline["preprocess"](*params)
if isinstance(
data, dict
): # HF doesn't allow additional parameters for binary files (e.g. images or audio files)
data.update({"options": {"wait_for_model": True}})
data = json.dumps(data)
response = requests.request("POST", api_url, headers=headers, data=data)
if not(response.status_code == 200):
raise ValueError("Could not complete request to HuggingFace API, Error {}".format(response.status_code))
output = pipeline['postprocess'](response)
response = requests.request("POST", api_url, headers=headers, data=data)
if not (response.status_code == 200):
raise ValueError(
"Could not complete request to HuggingFace API, Error {}".format(
response.status_code
)
)
output = pipeline["postprocess"](response)
return output
if alias is None:
query_huggingface_api.__name__ = model_name
else:
query_huggingface_api.__name__ = alias
interface_info = {
'fn': query_huggingface_api,
'inputs': pipeline['inputs'],
'outputs': pipeline['outputs'],
'title': model_name,
"fn": query_huggingface_api,
"inputs": pipeline["inputs"],
"outputs": pipeline["outputs"],
"title": model_name,
}
return interface_info
def load_interface(name, src=None, api_key=None, alias=None):
if src is None:
tokens = name.split("/") # Separate the source (e.g. "huggingface") from the repo name (e.g. "google/vit-base-patch16-224")
assert len(tokens) > 1, "Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
tokens = name.split(
"/"
) # Separate the source (e.g. "huggingface") from the repo name (e.g. "google/vit-base-patch16-224")
assert (
len(tokens) > 1
), "Either `src` parameter must be provided, or `name` must be formatted as \{src\}/\{repo name\}"
src = tokens[0]
name = "/".join(tokens[1:])
assert src.lower() in repos, "parameter: src must be one of {}".format(repos.keys())
interface_info = repos[src](name, api_key, alias)
return interface_info
def interface_params_from_config(config_dict):
## instantiate input component and output component
config_dict["inputs"] = [inputs.get_input_instance(component) for component in config_dict["input_components"]]
config_dict["outputs"] = [outputs.get_output_instance(component) for component in config_dict["output_components"]]
config_dict["inputs"] = [
inputs.get_input_instance(component)
for component in config_dict["input_components"]
]
config_dict["outputs"] = [
outputs.get_output_instance(component)
for component in config_dict["output_components"]
]
parameters = {
"allow_flagging", "allow_screenshot", "article", "description", "flagging_options", "inputs", "outputs",
"show_input", "show_output", "theme", "title"
"allow_flagging",
"allow_screenshot",
"article",
"description",
"flagging_options",
"inputs",
"outputs",
"show_input",
"show_output",
"theme",
"title",
}
config_dict = {k: config_dict[k] for k in parameters}
return config_dict
def get_spaces_interface(model_name, api_key, alias):
space_url = "https://huggingface.co/spaces/{}".format(model_name)
print("Fetching interface from: {}".format(space_url))
iframe_url = "https://huggingface.co/gradioiframe/{}/+".format(model_name)
api_url = "https://huggingface.co/gradioiframe/{}/api/predict/".format(model_name)
headers = {'Content-Type': 'application/json'}
headers = {"Content-Type": "application/json"}
r = requests.get(iframe_url)
result = re.search('window.gradio_config = (.*?);</script>', r.text) # some basic regex to extract the config
result = re.search(
"window.gradio_config = (.*?);</script>", r.text
) # some basic regex to extract the config
config = json.loads(result.group(1))
interface_info = interface_params_from_config(config)
# The function should call the API with preprocessed data
def fn(*data):
data = json.dumps({"data": data})
response = requests.post(api_url, headers=headers, data=data)
result = json.loads(response.content.decode("utf-8"))
output = result["data"]
if len(interface_info["outputs"])==1: # if the fn is supposed to return a single value, pop it
if (
len(interface_info["outputs"]) == 1
): # if the fn is supposed to return a single value, pop it
output = output[0]
if len(interface_info["outputs"])==1 and isinstance(output, list): # Needed to support Output.Image() returning bounding boxes as well (TODO: handle different versions of gradio since they have slightly different APIs)
if len(interface_info["outputs"]) == 1 and isinstance(
output, list
): # Needed to support Output.Image() returning bounding boxes as well (TODO: handle different versions of gradio since they have slightly different APIs)
output = output[0]
return output
fn.__name__ = alias if (alias is not None) else model_name
interface_info["fn"] = fn
return interface_info
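Loading a Space goes through the same repos table defined below; a usage sketch (the Space name is hypothetical):

import gradio as gr

# Fetches the Space's gradio config, rebuilds the components locally,
# and proxies predictions to the hosted app via its /api/predict/ URL.
demo = gr.Interface.load("spaces/some-user/some-space")
demo.launch()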
repos = {
# for each repo, we have a method that returns the Interface given the model name & optionally an api_key
"huggingface": get_huggingface_interface,
@ -241,6 +305,7 @@ repos = {
"spaces": get_spaces_interface,
}
def load_from_pipeline(pipeline):
"""
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
@ -251,128 +316,163 @@ def load_from_pipeline(pipeline):
try:
import transformers
except ImportError:
raise ImportError("transformers not installed. Please try `pip install transformers`")
raise ImportError(
"transformers not installed. Please try `pip install transformers`"
)
if not isinstance(pipeline, transformers.Pipeline):
raise ValueError("pipeline must be a transformers.Pipeline")
# Handle the different pipelines. The hasattr() checks make sure the pipeline exists in the
# version of the transformers library that the user has installed.
if hasattr(transformers, 'AudioClassificationPipeline') and isinstance(pipeline, transformers.AudioClassificationPipeline):
if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
pipeline, transformers.AudioClassificationPipeline
):
pipeline_info = {
'inputs': inputs.Audio(label="Input", source="microphone",
type="filepath"),
'outputs': outputs.Label(label="Class", type="confidences"),
'preprocess': lambda i: {"inputs": i},
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}
"inputs": inputs.Audio(label="Input", source="microphone", type="filepath"),
"outputs": outputs.Label(label="Class", type="confidences"),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, 'AutomaticSpeechRecognitionPipeline') and isinstance(pipeline, transformers.AutomaticSpeechRecognitionPipeline):
elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
pipeline, transformers.AutomaticSpeechRecognitionPipeline
):
pipeline_info = {
'inputs': inputs.Audio(label="Input", source="microphone",
type="filepath"),
'outputs': outputs.Textbox(label="Output"),
'preprocess': lambda i: {"inputs": i},
'postprocess': lambda r: r["text"]
"inputs": inputs.Audio(label="Input", source="microphone", type="filepath"),
"outputs": outputs.Textbox(label="Output"),
"preprocess": lambda i: {"inputs": i},
"postprocess": lambda r: r["text"],
}
elif hasattr(transformers, 'FeatureExtractionPipeline') and isinstance(pipeline, transformers.FeatureExtractionPipeline):
elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
pipeline, transformers.FeatureExtractionPipeline
):
pipeline_info = {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Dataframe(label="Output"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r[0],
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Dataframe(label="Output"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r[0],
}
elif hasattr(transformers, 'FillMaskPipeline') and isinstance(pipeline, transformers.FillMaskPipeline):
elif hasattr(transformers, "FillMaskPipeline") and isinstance(
pipeline, transformers.FillMaskPipeline
):
pipeline_info = {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r}
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Label(label="Classification", type="confidences"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
}
elif hasattr(transformers, 'ImageClassificationPipeline') and isinstance(pipeline, transformers.ImageClassificationPipeline):
elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
pipeline, transformers.ImageClassificationPipeline
):
pipeline_info = {
'inputs': inputs.Image(label="Input Image", type="filepath"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda i: {"images": i},
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}
"inputs": inputs.Image(label="Input Image", type="filepath"),
"outputs": outputs.Label(label="Classification", type="confidences"),
"preprocess": lambda i: {"images": i},
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, 'QuestionAnsweringPipeline') and isinstance(pipeline, transformers.QuestionAnsweringPipeline):
elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
pipeline, transformers.QuestionAnsweringPipeline
):
pipeline_info = {
'inputs': [inputs.Textbox(label="Context", lines=7), inputs.Textbox(label="Question")],
'outputs': [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
'preprocess': lambda c, q: {"context": c, "question": q},
'postprocess': lambda r: (r["answer"], r["score"]),
"inputs": [
inputs.Textbox(label="Context", lines=7),
inputs.Textbox(label="Question"),
],
"outputs": [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
"preprocess": lambda c, q: {"context": c, "question": q},
"postprocess": lambda r: (r["answer"], r["score"]),
}
elif hasattr(transformers, 'SummarizationPipeline') and isinstance(pipeline, transformers.SummarizationPipeline):
elif hasattr(transformers, "SummarizationPipeline") and isinstance(
pipeline, transformers.SummarizationPipeline
):
pipeline_info = {
'inputs': inputs.Textbox(label="Input", lines=7),
'outputs': outputs.Textbox(label="Summary"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: r[0]["summary_text"]
"inputs": inputs.Textbox(label="Input", lines=7),
"outputs": outputs.Textbox(label="Summary"),
"preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: r[0]["summary_text"],
}
elif hasattr(transformers, 'TextClassificationPipeline') and isinstance(pipeline, transformers.TextClassificationPipeline):
elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
pipeline, transformers.TextClassificationPipeline
):
pipeline_info = {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda x: [x],
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Label(label="Classification", type="confidences"),
"preprocess": lambda x: [x],
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
}
elif hasattr(transformers, 'TextGenerationPipeline') and isinstance(pipeline, transformers.TextGenerationPipeline):
elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
pipeline, transformers.TextGenerationPipeline
):
pipeline_info = {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Output"),
'preprocess': lambda x: {"text_inputs": x},
'postprocess': lambda r: r[0]["generated_text"],
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Textbox(label="Output"),
"preprocess": lambda x: {"text_inputs": x},
"postprocess": lambda r: r[0]["generated_text"],
}
elif hasattr(transformers, 'TranslationPipeline') and isinstance(pipeline, transformers.TranslationPipeline):
elif hasattr(transformers, "TranslationPipeline") and isinstance(
pipeline, transformers.TranslationPipeline
):
pipeline_info = {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Translation"),
'preprocess': lambda x: [x],
'postprocess': lambda r: r[0]["translation_text"]
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Textbox(label="Translation"),
"preprocess": lambda x: [x],
"postprocess": lambda r: r[0]["translation_text"],
}
elif hasattr(transformers, 'Text2TextGenerationPipeline') and isinstance(pipeline, transformers.Text2TextGenerationPipeline):
elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
pipeline, transformers.Text2TextGenerationPipeline
):
pipeline_info = {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Textbox(label="Generated Text"),
'preprocess': lambda x: [x],
'postprocess': lambda r: r[0]["generated_text"]
"inputs": inputs.Textbox(label="Input"),
"outputs": outputs.Textbox(label="Generated Text"),
"preprocess": lambda x: [x],
"postprocess": lambda r: r[0]["generated_text"],
}
elif hasattr(transformers, 'ZeroShotClassificationPipeline') and isinstance(pipeline, transformers.ZeroShotClassificationPipeline):
elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
pipeline, transformers.ZeroShotClassificationPipeline
):
pipeline_info = {
'inputs': [inputs.Textbox(label="Input"),
inputs.Textbox(label="Possible class names ("
"comma-separated)"),
inputs.Checkbox(label="Allow multiple true classes")],
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda i, c, m: {"sequences": i,
"candidate_labels": c, "multi_label": m},
'postprocess': lambda r: {r["labels"][i]: r["scores"][i] for i in
range(len(r["labels"]))}
"inputs": [
inputs.Textbox(label="Input"),
inputs.Textbox(label="Possible class names (" "comma-separated)"),
inputs.Checkbox(label="Allow multiple true classes"),
],
"outputs": outputs.Label(label="Classification", type="confidences"),
"preprocess": lambda i, c, m: {
"sequences": i,
"candidate_labels": c,
"multi_label": m,
},
"postprocess": lambda r: {
r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
},
}
else:
raise ValueError("Unsupported pipeline type: {}".format(type(pipeline)))
# define the function that will be called by the Interface
def fn(*params):
data = pipeline_info["preprocess"](*params)
# special cases that needs to be handled differently
if isinstance(pipeline, (transformers.TextClassificationPipeline,
transformers.Text2TextGenerationPipeline,
transformers.TranslationPipeline)):
if isinstance(
pipeline,
(
transformers.TextClassificationPipeline,
transformers.Text2TextGenerationPipeline,
transformers.TranslationPipeline,
),
):
data = pipeline(*data)
else:
data = pipeline(**data)
# print("Before postprocessing", data)
output = pipeline_info["postprocess"](data)
return output
interface_info = pipeline_info.copy()
interface_info["fn"] = fn
del interface_info["preprocess"]
del interface_info["postprocess"]
# define the title/description of the Interface
interface_info["title"] = pipeline.model.__class__.__name__
return interface_info
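A usage sketch for this entry point (the pipeline task is an assumption; any pipeline type handled above works):

import transformers

import gradio as gr

# transformers picks a default model for the task when none is given
pipe = transformers.pipeline("text-classification")
gr.Interface.from_pipeline(pipe).launch()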
View File
@ -1,10 +1,11 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import csv
import datetime
import io
import json
import os
from abc import ABC, abstractmethod
from typing import Any, List, Optional
import gradio as gr
@ -17,9 +18,7 @@ class FlaggingCallback(ABC):
"""
@abstractmethod
def setup(
self,
flagging_dir: str):
def setup(self, flagging_dir: str):
"""
This method should be overridden and ensure that everything is set up correctly for flag().
This method gets called once at the beginning of the Interface.launch() method.
@ -30,20 +29,21 @@ class FlaggingCallback(ABC):
@abstractmethod
def flag(
self,
interface: gr.Interface,
input_data: List[Any],
output_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None) -> int:
self,
interface: gr.Interface,
input_data: List[Any],
output_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None,
) -> int:
"""
This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
This gets called every time the <flag> button is pressed.
Parameters:
interface: The Interface object that is being used to launch the flagging interface.
input_data: The input data to be flagged.
output_data: The output data to be flagged.
output_data: The output data to be flagged.
flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
flag_index (optional): The index of the sample that is being flagged.
username (optional): The username of the user that is flagging the data, if logged in.
@ -55,41 +55,52 @@ class FlaggingCallback(ABC):
class SimpleCSVLogger(FlaggingCallback):
"""
A simple example implementation of the FlaggingCallback abstract class
A simple example implementation of the FlaggingCallback abstract class
provided for illustrative purposes.
"""
def setup(
self,
flagging_dir: str
):
def setup(self, flagging_dir: str):
self.flagging_dir = flagging_dir
os.makedirs(flagging_dir, exist_ok=True)
def flag(
self,
interface: gr.Interface,
input_data: List[Any],
output_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None
self,
interface: gr.Interface,
input_data: List[Any],
output_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None,
) -> int:
flagging_dir = self.flagging_dir
log_filepath = "{}/log.csv".format(flagging_dir)
csv_data = []
for i, input in enumerate(interface.input_components):
csv_data.append(input.save_flagged(
flagging_dir, interface.config["input_components"][i]["label"], input_data[i], None))
csv_data.append(
input.save_flagged(
flagging_dir,
interface.config["input_components"][i]["label"],
input_data[i],
None,
)
)
for i, output in enumerate(interface.output_components):
csv_data.append(output.save_flagged(
flagging_dir, interface.config["output_components"][i]["label"], output_data[i], None) if
output_data[i] is not None else "")
csv_data.append(
output.save_flagged(
flagging_dir,
interface.config["output_components"][i]["label"],
output_data[i],
None,
)
if output_data[i] is not None
else ""
)
with open(log_filepath, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(csv_data)
with open(log_filepath, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count
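Any callback implementing the same two-method contract can be passed as an Interface's flagging_callback; a minimal illustrative sketch (not part of the commit):

from typing import Any, List, Optional

import gradio as gr
from gradio.flagging import FlaggingCallback

class PrintLogger(FlaggingCallback):
    """Toy callback that just prints flagged samples."""

    def setup(self, flagging_dir: str):
        self.flagging_dir = flagging_dir

    def flag(
        self,
        interface: gr.Interface,
        input_data: List[Any],
        output_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        print("flagged:", input_data, "->", output_data)
        return 0  # index of the flagged sample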
@ -97,24 +108,22 @@ class SimpleCSVLogger(FlaggingCallback):
class CSVLogger(FlaggingCallback):
"""
The default implementation of the FlaggingCallback abstract class.
The default implementation of the FlaggingCallback abstract class.
Logs the input and output data to a CSV file.
"""
def setup(
self,
flagging_dir: str
):
def setup(self, flagging_dir: str):
self.flagging_dir = flagging_dir
os.makedirs(flagging_dir, exist_ok=True)
def flag(
self,
interface: gr.Interface,
input_data: List[Any],
output_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None
self,
interface: gr.Interface,
input_data: List[Any],
output_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None,
) -> int:
flagging_dir = self.flagging_dir
log_fp = "{}/log.csv".format(flagging_dir)
@ -126,12 +135,25 @@ class CSVLogger(FlaggingCallback):
csv_data = []
if not output_only_mode:
for i, input in enumerate(interface.input_components):
csv_data.append(input.save_flagged(
flagging_dir, interface.config["input_components"][i]["label"], input_data[i], encryption_key))
csv_data.append(
input.save_flagged(
flagging_dir,
interface.config["input_components"][i]["label"],
input_data[i],
encryption_key,
)
)
for i, output in enumerate(interface.output_components):
csv_data.append(output.save_flagged(
flagging_dir, interface.config["output_components"][i]["label"], output_data[i], encryption_key) if
output_data[i] is not None else "")
csv_data.append(
output.save_flagged(
flagging_dir,
interface.config["output_components"][i]["label"],
output_data[i],
encryption_key,
)
if output_data[i] is not None
else ""
)
if not output_only_mode:
if flag_option is not None:
csv_data.append(flag_option)
@ -141,10 +163,14 @@ class CSVLogger(FlaggingCallback):
if is_new:
headers = []
if not output_only_mode:
headers += [interface["label"]
for interface in interface.config["input_components"]]
headers += [interface["label"]
for interface in interface.config["output_components"]]
headers += [
interface["label"]
for interface in interface.config["input_components"]
]
headers += [
interface["label"]
for interface in interface.config["output_components"]
]
if not output_only_mode:
if interface.flagging_options is not None:
headers.append("flag")
@ -169,7 +195,8 @@ class CSVLogger(FlaggingCallback):
with open(log_fp, "rb") as csvfile:
encrypted_csv = csvfile.read()
decrypted_csv = encryptor.decrypt(
interface.encryption_key, encrypted_csv)
interface.encryption_key, encrypted_csv
)
file_content = decrypted_csv.decode()
if flag_index is not None:
file_content = replace_flag_at_index(file_content)
@ -180,8 +207,11 @@ class CSVLogger(FlaggingCallback):
writer.writerow(headers)
writer.writerow(csv_data)
with open(log_fp, "wb") as csvfile:
csvfile.write(encryptor.encrypt(
interface.encryption_key, output.getvalue().encode()))
csvfile.write(
encryptor.encrypt(
interface.encryption_key, output.getvalue().encode()
)
)
else:
if flag_index is None:
with open(log_fp, "a", newline="") as csvfile:
@ -193,7 +223,9 @@ class CSVLogger(FlaggingCallback):
with open(log_fp) as csvfile:
file_content = csvfile.read()
file_content = replace_flag_at_index(file_content)
with open(log_fp, "w", newline="") as csvfile: # newline parameter needed for Windows
with open(
log_fp, "w", newline=""
) as csvfile: # newline parameter needed for Windows
csvfile.write(file_content)
with open(log_fp, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
@ -204,24 +236,26 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
"""
A FlaggingCallback that saves flagged data to a HuggingFace dataset.
"""
def __init__(
self,
hf_token: str,
dataset_name: str,
organization: Optional[str] = None,
private: bool = False,
verbose: bool = True):
self,
hf_token: str,
dataset_name: str,
organization: Optional[str] = None,
private: bool = False,
verbose: bool = True,
):
"""
Params:
hf_token (str): The token to use to access the huggingface API.
dataset_name (str): The name of the dataset to save the data to, e.g.
dataset_name (str): The name of the dataset to save the data to, e.g.
"image-classifier-1"
organization (str): The name of the organization to which to attach
organization (str): The name of the organization to which to attach
the datasets. If None, the dataset attaches to the user only.
private (bool): If the dataset does not already exist, whether it
should be created as a private dataset or public. Private datasets
private (bool): If the dataset does not already exist, whether it
should be created as a private dataset or public. Private datasets
may require paid huggingface.co accounts
verbose (bool): Whether to print out the status of the dataset
verbose (bool): Whether to print out the status of the dataset
creation.
"""
self.hf_token = hf_token
@ -230,118 +264,129 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
self.dataset_private = private
self.verbose = verbose
def setup(
self,
flagging_dir: str):
def setup(self, flagging_dir: str):
"""
Params:
flagging_dir (str): local directory where the dataset is cloned,
flagging_dir (str): local directory where the dataset is cloned,
updated, and pushed from.
"""
try:
import huggingface_hub
import huggingface_hub
except (ImportError, ModuleNotFoundError):
raise ImportError("Package `huggingface_hub` not found is needed "
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'.")
raise ImportError(
"Package `huggingface_hub` not found is needed "
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
)
path_to_dataset_repo = huggingface_hub.create_repo(
name=self.dataset_name, token=self.hf_token,
private=self.dataset_private, repo_type="dataset", exist_ok=True)
name=self.dataset_name,
token=self.hf_token,
private=self.dataset_private,
repo_type="dataset",
exist_ok=True,
)
self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
self.flagging_dir = flagging_dir
self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
self.repo = huggingface_hub.Repository(
local_dir=self.dataset_dir, clone_from=path_to_dataset_repo,
use_auth_token=self.hf_token)
local_dir=self.dataset_dir,
clone_from=path_to_dataset_repo,
use_auth_token=self.hf_token,
)
self.repo.git_pull()
#Should filename be user-specified?
self.log_file = os.path.join(self.dataset_dir, "data.csv")
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
# Should filename be user-specified?
self.log_file = os.path.join(self.dataset_dir, "data.csv")
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
def flag(
self,
interface: gr.Interface,
input_data: List[Any],
output_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None
) -> int:
self,
interface: gr.Interface,
input_data: List[Any],
output_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None,
) -> int:
is_new = not os.path.exists(self.log_file)
infos = {"flagged": {"features": {}}}
with open(self.log_file, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
# File previews for certain input and output types
file_preview_types = {
gr.inputs.Audio: "Audio",
gr.inputs.Audio: "Audio",
gr.outputs.Audio: "Audio",
gr.inputs.Image: "Image",
gr.inputs.Image: "Image",
gr.outputs.Image: "Image",
}
# Generate the headers and dataset_infos
if is_new:
headers = []
for i, component in enumerate(interface.input_components +
interface.output_components):
component_label = (interface.config["input_components"] +
interface.config["output_components"])[i]["label"]
for i, component in enumerate(
interface.input_components + interface.output_components
):
component_label = (
interface.config["input_components"]
+ interface.config["output_components"]
)[i]["label"]
headers.append(component_label)
infos["flagged"]["features"][component_label] = {
"dtype": "string",
"_type": "Value"
"dtype": "string",
"_type": "Value",
}
if isinstance(component, tuple(file_preview_types)):
headers.append(component_label + " file")
for _component, _type in file_preview_types.items():
if isinstance(component, _component):
infos["flagged"]["features"][component_label +
" file"] = {
"_type": _type
}
infos["flagged"]["features"][
component_label + " file"
] = {"_type": _type}
break
if interface.flagging_options is not None:
headers.append("flag")
infos["flagged"]["features"]["flag"] = {
"dtype": "string",
"_type": "Value"
}
"dtype": "string",
"_type": "Value",
}
writer.writerow(headers)
# Generate the row corresponding to the flagged sample
csv_data = []
for i, input in enumerate(interface.input_components):
label = interface.config["input_components"][i]["label"]
filepath = input.save_flagged(
self.dataset_dir, label, input_data[i], None)
self.dataset_dir, label, input_data[i], None
)
csv_data.append(filepath)
if isinstance(input, tuple(file_preview_types)):
csv_data.append("{}/resolve/main/{}".format(
self.path_to_dataset_repo, filepath))
csv_data.append(
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
)
for i, output in enumerate(interface.output_components):
label = interface.config["output_components"][i]["label"]
filepath = (output.save_flagged(
self.dataset_dir, label, output_data[i], None) if
output_data[i] is not None else "")
filepath = (
output.save_flagged(self.dataset_dir, label, output_data[i], None)
if output_data[i] is not None
else ""
)
csv_data.append(filepath)
if isinstance(output, tuple(file_preview_types)):
csv_data.append("{}/resolve/main/{}".format(
self.path_to_dataset_repo, filepath))
csv_data.append(
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
)
if flag_option is not None:
csv_data.append(flag_option)
writer.writerow(csv_data)
if is_new:
json.dump(infos, open(self.infos_file, "w"))
with open(self.log_file, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
self.repo.push_to_hub(
commit_message="Flagged sample #{}".format(line_count))
return line_count
self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
return line_count
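Wiring this callback into an Interface (the token and prediction function are placeholders; "image-classifier-1" is the example name from the docstring above):

import gradio as gr
from gradio.flagging import HuggingFaceDatasetSaver

saver = HuggingFaceDatasetSaver(
    hf_token="hf_...",  # placeholder API token
    dataset_name="image-classifier-1",
)
gr.Interface(
    classify,  # hypothetical prediction function
    "image",
    "label",
    flagging_callback=saver,
).launch()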
File diff suppressed because it is too large
View File
@ -4,40 +4,47 @@ including various methods for constructing an interface and then launching it.
"""
from __future__ import annotations
import copy
import getpass
from logging import warning
import markdown2 # type: ignore
import os
import random
import sys
import time
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
import warnings
import webbrowser
import weakref
import webbrowser
from logging import warning
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from gradio import encryptor, interpretation, networking, queueing, strings, utils # type: ignore
from gradio.external import load_interface, load_from_pipeline # type: ignore
from gradio.flagging import FlaggingCallback, CSVLogger # type: ignore
from gradio.inputs import get_input_instance, InputComponent, State as i_State # type: ignore
from gradio.outputs import get_output_instance, OutputComponent, State as o_State # type: ignore
import markdown2 # type: ignore
from gradio import (encryptor, interpretation, networking, # type: ignore
queueing, strings, utils)
from gradio.external import load_from_pipeline, load_interface # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.inputs import InputComponent
from gradio.inputs import State as i_State # type: ignore
from gradio.inputs import get_input_instance
from gradio.outputs import OutputComponent
from gradio.outputs import State as o_State # type: ignore
from gradio.outputs import get_output_instance
from gradio.process_examples import cache_interface_examples
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import transformers
import flask
import transformers
class Interface:
"""
Gradio interfaces are created by constructing an `Interface` object
with a locally-defined function, or with `Interface.load()` with the path
with a locally-defined function, or with `Interface.load()` with the path
to a repo or by `Interface.from_pipeline()` with a Transformers Pipeline.
"""
# stores references to all currently existing Interface instances
instances: weakref.WeakSet = weakref.WeakSet()
instances: weakref.WeakSet = weakref.WeakSet()
@classmethod
def get_instances(cls) -> List[Interface]:
@ -47,15 +54,17 @@ class Interface:
return list(Interface.instances)
@classmethod
def load(cls,
name: str,
src: Optional[str] = None,
api_key: Optional[str] = None,
alias: Optional[str] = None,
**kwargs) -> Interface:
def load(
cls,
name: str,
src: Optional[str] = None,
api_key: Optional[str] = None,
alias: Optional[str] = None,
**kwargs,
) -> Interface:
"""
Class method to construct an Interface from an external source repository, such as huggingface.
Parameters:
Parameters:
name (str): the name of the model (e.g. "gpt2"), can include the `src` as prefix (e.g. "huggingface/gpt2")
src (str): the source of the model: `huggingface` or `gradio` (or empty if source is provided as a prefix in `name`)
api_key (str): optional api key for use with Hugging Face Model Hub
@ -70,14 +79,11 @@ class Interface:
return interface
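For example, per the parameter docs above, both forms are equivalent:

import gradio as gr

gr.Interface.load("huggingface/gpt2").launch()
# or, with the source split out:
gr.Interface.load("gpt2", src="huggingface").launch()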
@classmethod
def from_pipeline(
cls,
pipeline: transformers.Pipeline,
**kwargs) -> Interface:
def from_pipeline(cls, pipeline: transformers.Pipeline, **kwargs) -> Interface:
"""
Construct an Interface from a Hugging Face transformers.Pipeline.
Parameters:
pipeline (transformers.Pipeline):
Parameters:
pipeline (transformers.Pipeline):
Returns:
(gradio.Interface): a Gradio Interface object from the given Pipeline
"""
@ -87,41 +93,42 @@ class Interface:
return interface
def __init__(
self,
fn: Callable | List[Callable],
inputs: str | InputComponent | List[str | InputComponent] = None,
outputs: str | OutputComponent | List[str | OutputComponent] = None,
verbose: bool = False,
self,
fn: Callable | List[Callable],
inputs: str | InputComponent | List[str | InputComponent] = None,
outputs: str | OutputComponent | List[str | OutputComponent] = None,
verbose: bool = False,
examples: Optional[List[Any] | List[List[Any]] | str] = None,
examples_per_page: int = 10,
live: bool = False,
layout: str = "unaligned",
show_input: bool = True,
examples_per_page: int = 10,
live: bool = False,
layout: str = "unaligned",
show_input: bool = True,
show_output: bool = True,
capture_session: Optional[bool] = None,
interpretation: Optional[Callable | str] = None,
num_shap: float = 2.0,
theme: Optional[str] = None,
capture_session: Optional[bool] = None,
interpretation: Optional[Callable | str] = None,
num_shap: float = 2.0,
theme: Optional[str] = None,
repeat_outputs_per_model: bool = True,
title: Optional[str] = None,
description: Optional[str] = None,
article: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
article: Optional[str] = None,
thumbnail: Optional[str] = None,
css: Optional[str] = None,
height=None,
width=None,
allow_screenshot: bool = True,
allow_flagging: Optional[str] = None,
flagging_options: List[str]=None,
encrypt=None,
show_tips=None,
flagging_dir: str = "flagged",
analytics_enabled: Optional[bool] = None,
css: Optional[str] = None,
height=None,
width=None,
allow_screenshot: bool = True,
allow_flagging: Optional[str] = None,
flagging_options: List[str] = None,
encrypt=None,
show_tips=None,
flagging_dir: str = "flagged",
analytics_enabled: Optional[bool] = None,
server_name=None,
server_port=None,
enable_queue=None,
enable_queue=None,
api_mode=None,
flagging_callback: FlaggingCallback = CSVLogger()):
flagging_callback: FlaggingCallback = CSVLogger(),
):
"""
Parameters:
fn (Union[Callable, List[Callable]]): the function to wrap an interface around.
@ -133,7 +140,7 @@ class Interface:
live (bool): whether the interface should automatically reload on change.
layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically.
capture_session (bool): DEPRECATED. If True, captures the default graph and session (needed for Tensorflow 1.x)
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function.
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function.
num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. Only applies if interpretation is "shap".
title (str): a title for the interface; if provided, appears above the input and output components.
description (str): a description for the interface; if provided, appears above the input and output components.
@ -147,7 +154,7 @@ class Interface:
encrypt (bool): DEPRECATED. If True, flagged data will be encrypted by key provided by creator at launch
flagging_dir (str): what to name the dir where flagged data is stored.
show_tips (bool): DEPRECATED. if True, will occasionally show tips about new Gradio features
enable_queue (bool): DEPRECATED. if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
enable_queue (bool): DEPRECATED. if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
api_mode (bool): DEPRECATED. If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
server_name (str): DEPRECATED. Name of the server to use for serving the interface - pass in launch() instead.
server_port (int): DEPRECATED. Port of the server to use for serving the interface - pass in launch() instead.
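A minimal sketch of the pattern these deprecation notes point to (the echo function is illustrative): server options are passed to launch(), not to the Interface constructor.

import gradio as gr

def echo(text):
    return text

iface = gr.Interface(echo, "text", "text")
iface.launch(server_name="0.0.0.0", server_port=7860)  # not Interface(...) arguments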
@ -163,27 +170,33 @@ class Interface:
self.output_components = [get_output_instance(o) for o in outputs]
if repeat_outputs_per_model:
self.output_components *= len(fn)
if sum(isinstance(i, i_State) for i in self.input_components) > 1:
raise ValueError("Only one input component can be State.")
if sum(isinstance(o, o_State) for o in self.output_components) > 1:
raise ValueError("Only one output component can be State.")
if sum(isinstance(i, i_State) for i in self.input_components) == 1:
if len(fn) > 1:
raise ValueError(
"State cannot be used with multiple functions.")
state_param_index = [isinstance(i, i_State)
for i in self.input_components].index(True)
raise ValueError("State cannot be used with multiple functions.")
state_param_index = [
isinstance(i, i_State) for i in self.input_components
].index(True)
state: i_State = self.input_components[state_param_index]
if state.default is None:
default = utils.get_default_args(fn[0])[state_param_index]
state.default = default
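A hedged sketch of the constraint enforced above: at most one "state" input and one "state" output per Interface, with the state's default value recoverable from the function signature.

import gradio as gr

def count(message, history=None):
    # `history` doubles as the session state; None on the first call.
    history = (history or []) + [message]
    return len(history), history

iface = gr.Interface(count, ["text", "state"], ["number", "state"])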
if interpretation is None or isinstance(interpretation, list) or callable(interpretation):
if (
interpretation is None
or isinstance(interpretation, list)
or callable(interpretation)
):
self.interpretation = interpretation
elif isinstance(interpretation, str):
self.interpretation = [interpretation.lower() for _ in self.input_components]
self.interpretation = [
interpretation.lower() for _ in self.input_components
]
else:
raise ValueError("Invalid value for parameter: interpretation")
@ -193,8 +206,10 @@ class Interface:
self.__name__ = ", ".join(self.function_names)
if verbose:
warnings.warn("The `verbose` parameter in the `Interface`"
"is deprecated and has no effect.")
warnings.warn(
"The `verbose` parameter in the `Interface`"
"is deprecated and has no effect."
)
self.status = "OFF"
self.live = live
@ -205,12 +220,14 @@ class Interface:
self.capture_session = capture_session
if capture_session is not None:
warnings.warn("The `capture_session` parameter in the `Interface`"
" is deprecated and may be removed in the future.")
warnings.warn(
"The `capture_session` parameter in the `Interface`"
" is deprecated and may be removed in the future."
)
try:
import tensorflow as tf
self.session = tf.get_default_graph(), \
tf.keras.backend.get_session()
self.session = tf.get_default_graph(), tf.keras.backend.get_session()
except (ImportError, AttributeError):
# If they are using TF >= 2.0 or don't have TF,
# just ignore this parameter.
@ -219,80 +236,120 @@ class Interface:
if server_name is not None or server_port is not None:
raise DeprecationWarning(
"The `server_name` and `server_port` parameters in `Interface`"
"are deprecated. Please pass into launch() instead.")
"are deprecated. Please pass into launch() instead."
)
self.session = None
self.title = title
self.description = description
if article is not None:
article = utils.readme_to_html(article)
article = markdown2.markdown(
article, extras=["fenced-code-blocks"])
article = markdown2.markdown(article, extras=["fenced-code-blocks"])
self.article = article
self.thumbnail = thumbnail
theme = theme if theme is not None else os.getenv("GRADIO_THEME", "default")
DEPRECATED_THEME_MAP = {"darkdefault": "default", "darkhuggingface": "dark-huggingface", "darkpeach": "dark-peach", "darkgrass": "dark-grass"}
VALID_THEME_SET = ("default", "huggingface", "seafoam", "grass", "peach", "dark", "dark-huggingface", "dark-seafoam", "dark-grass", "dark-peach")
DEPRECATED_THEME_MAP = {
"darkdefault": "default",
"darkhuggingface": "dark-huggingface",
"darkpeach": "dark-peach",
"darkgrass": "dark-grass",
}
VALID_THEME_SET = (
"default",
"huggingface",
"seafoam",
"grass",
"peach",
"dark",
"dark-huggingface",
"dark-seafoam",
"dark-grass",
"dark-peach",
)
if theme in DEPRECATED_THEME_MAP:
warnings.warn(f"'{theme}' theme name is deprecated, using {DEPRECATED_THEME_MAP[theme]} instead.")
warnings.warn(
f"'{theme}' theme name is deprecated, using {DEPRECATED_THEME_MAP[theme]} instead."
)
theme = DEPRECATED_THEME_MAP[theme]
elif theme not in VALID_THEME_SET:
raise ValueError(f"Invalid theme name, theme must be one of: {', '.join(VALID_THEME_SET)}")
raise ValueError(
f"Invalid theme name, theme must be one of: {', '.join(VALID_THEME_SET)}"
)
self.theme = theme
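For example (a hedged sketch), a deprecated theme name passes through the map above with a warning:

import gradio as gr

iface = gr.Interface(lambda x: x, "text", "text", theme="darkpeach")
assert iface.theme == "dark-peach"  # warned and remapped by DEPRECATED_THEME_MAP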
self.height = height
self.width = width
if self.height is not None or self.width is not None:
warnings.warn("The `height` and `width` parameters in `Interface` "
"are deprecated and should be passed into launch().")
warnings.warn(
"The `height` and `width` parameters in `Interface` "
"are deprecated and should be passed into launch()."
)
if css is not None and os.path.exists(css):
with open(css) as css_file:
self.css = css_file.read()
else:
self.css = css
if examples is None or isinstance(examples, str) or (isinstance(
examples, list) and (len(examples) == 0 or isinstance(
examples[0], list))):
if (
examples is None
or isinstance(examples, str)
or (
isinstance(examples, list)
and (len(examples) == 0 or isinstance(examples[0], list))
)
):
self.examples = examples
elif isinstance(examples, list) and len(self.input_components) == 1: # If there is only one input component, examples can be provided as a regular list instead of a list of lists
elif (
isinstance(examples, list) and len(self.input_components) == 1
): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
self.examples = [[e] for e in examples]
else:
raise ValueError(
"Examples argument must either be a directory or a nested "
"list, where each sublist represents a set of inputs.")
"list, where each sublist represents a set of inputs."
)
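A hedged sketch of the two accepted forms validated above; with exactly one input component a flat list is wrapped into a list of lists:

import gradio as gr

def square(x):
    return x ** 2

gr.Interface(square, "number", "number", examples=[1, 2, 3])        # flat form
gr.Interface(square, "number", "number", examples=[[1], [2], [3]])  # canonical form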
self.num_shap = num_shap
self.examples_per_page = examples_per_page
self.simple_server = None
self.allow_screenshot = allow_screenshot
# For analytics_enabled and allow_flagging: (1) first check for
# For analytics_enabled and allow_flagging: (1) first check for
# parameter, (2) check for env variable, (3) default to True/"manual"
self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True")=="True"
self.analytics_enabled = (
analytics_enabled
if analytics_enabled is not None
else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
)
if allow_flagging is None:
allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual")
if allow_flagging==True:
warnings.warn("The `allow_flagging` parameter in `Interface` now"
"takes a string value ('auto', 'manual', or 'never')"
", not a boolean. Setting parameter to: 'manual'.")
if allow_flagging == True:
warnings.warn(
"The `allow_flagging` parameter in `Interface` now"
"takes a string value ('auto', 'manual', or 'never')"
", not a boolean. Setting parameter to: 'manual'."
)
self.allow_flagging = "manual"
elif allow_flagging=="manual":
elif allow_flagging == "manual":
self.allow_flagging = "manual"
elif allow_flagging==False:
warnings.warn("The `allow_flagging` parameter in `Interface` now"
"takes a string value ('auto', 'manual', or 'never')"
", not a boolean. Setting parameter to: 'never'.")
elif allow_flagging == False:
warnings.warn(
"The `allow_flagging` parameter in `Interface` now"
"takes a string value ('auto', 'manual', or 'never')"
", not a boolean. Setting parameter to: 'never'."
)
self.allow_flagging = "never"
elif allow_flagging=="never":
elif allow_flagging == "never":
self.allow_flagging = "never"
elif allow_flagging=="auto":
elif allow_flagging == "auto":
self.allow_flagging = "auto"
else:
raise ValueError("Invalid value for `allow_flagging` parameter."
"Must be: 'auto', 'manual', or 'never'.")
raise ValueError(
"Invalid value for `allow_flagging` parameter."
"Must be: 'auto', 'manual', or 'never'."
)
self.flagging_options = flagging_options
self.flagging_callback = flagging_callback
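A hedged sketch of the coercion above: boolean values still work but warn, and strings are the supported form.

import gradio as gr

gr.Interface(str, "text", "text", allow_flagging="never")  # preferred
gr.Interface(str, "text", "text", allow_flagging=False)    # warns, coerced to "never"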
@ -305,17 +362,22 @@ class Interface:
self.ip_address = utils.get_local_ip_address()
if show_tips is not None:
warnings.warn("The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead")
warnings.warn(
"The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead"
)
self.requires_permissions = any(
[component.requires_permissions for component in self.input_components])
[component.requires_permissions for component in self.input_components]
)
self.enable_queue = enable_queue
if self.enable_queue is not None:
warnings.warn("The `enable_queue` parameter in the `Interface`"
"will be deprecated and may not work properly. "
"Please use the `enable_queue` parameter in "
"`launch()` instead")
warnings.warn(
"The `enable_queue` parameter in the `Interface`"
"will be deprecated and may not work properly. "
"Please use the `enable_queue` parameter in "
"`launch()` instead"
)
self.favicon_path = None
self.height = height
@ -324,31 +386,34 @@ class Interface:
warnings.warn(
"The `width` and `height` parameters in the `Interface` class"
"will be deprecated. Please provide these parameters"
"in `launch()` instead")
"in `launch()` instead"
)
self.encrypt = encrypt
if self.encrypt is not None:
warnings.warn(
"The `encrypt` parameter in the `Interface` class"
"will be deprecated. Please provide this parameter"
"in `launch()` instead")
"in `launch()` instead"
)
if api_mode is not None:
warnings.warn("The `api_mode` parameter in the `Interface` is deprecated.")
self.api_mode = False
data = {'fn': fn,
'inputs': inputs,
'outputs': outputs,
'live': live,
'capture_session': capture_session,
'ip_address': self.ip_address,
'interpretation': interpretation,
'allow_flagging': allow_flagging,
'allow_screenshot': allow_screenshot,
'custom_css': self.css is not None,
'theme': self.theme
}
data = {
"fn": fn,
"inputs": inputs,
"outputs": outputs,
"live": live,
"capture_session": capture_session,
"ip_address": self.ip_address,
"interpretation": interpretation,
"allow_flagging": allow_flagging,
"allow_screenshot": allow_screenshot,
"custom_css": self.css is not None,
"theme": self.theme,
}
if self.analytics_enabled:
utils.initiated_analytics(data)
@ -358,7 +423,9 @@ class Interface:
Interface.instances.add(self)
def __call__(self, *params):
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
if (
self.api_mode
): # skip the preprocessing/postprocessing if sending to a remote API
output = self.run_prediction(params, called_directly=True)
else:
output, _ = self.process(params)
@ -369,7 +436,8 @@ class Interface:
def __repr__(self):
repr = "Gradio Interface for: {}".format(
", ".join(fn.__name__ for fn in self.predict))
", ".join(fn.__name__ for fn in self.predict)
)
repr += "\n" + "-" * len(repr)
repr += "\ninputs:"
for component in self.input_components:
@ -381,28 +449,30 @@ class Interface:
def get_config_file(self):
return utils.get_config_file(self)
def run_prediction(
self,
processed_input: List[Any],
return_duration: bool = False,
called_directly: bool = False
self,
processed_input: List[Any],
return_duration: bool = False,
called_directly: bool = False,
) -> List[Any] | Tuple[List[Any], List[float]]:
"""
Runs the prediction function with the given (already processed) inputs.
Parameters:
processed_input (list): A list of processed inputs.
return_duration (bool): Whether to return the duration of the prediction.
called_directly (bool): Whether the prediction is being called
called_directly (bool): Whether the prediction is being called
directly (i.e. as a function, not through the GUI).
Returns:
predictions (list): A list of predictions (not post-processed).
durations (list): A list of durations for each prediction
durations (list): A list of durations for each prediction
(only returned if `return_duration` is True).
"""
if self.api_mode: # Serialize the input
processed_input = [input_component.serialize(processed_input[i], called_directly)
for i, input_component in enumerate(self.input_components)]
processed_input = [
input_component.serialize(processed_input[i], called_directly)
for i, input_component in enumerate(self.input_components)
]
predictions = []
durations = []
output_component_counter = 0
@ -423,8 +493,16 @@ class Interface:
if self.api_mode: # Deserialize the output
prediction_ = copy.deepcopy(prediction)
prediction = []
for pred in prediction_: # Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
prediction.append(self.output_components[output_component_counter].deserialize(pred))
for (
pred
) in (
prediction_
): # Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
prediction.append(
self.output_components[output_component_counter].deserialize(
pred
)
)
output_component_counter += 1
durations.append(duration)
@ -435,12 +513,9 @@ class Interface:
else:
return predictions
def process(
self,
raw_input: List[Any]
) -> Tuple[List[Any], List[float]]:
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
"""
First preprocesses the input, then runs prediction using
First preprocesses the input, then runs prediction using
self.run_prediction(), then postprocesses the output.
Parameters:
raw_input: a list of raw inputs to process and apply the prediction(s) on.
@ -448,33 +523,37 @@ class Interface:
processed output: a list of processed outputs to return as the prediction(s).
duration: a list of time deltas measuring inference time for each prediction fn.
"""
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(
self.input_components)]
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)
]
predictions, durations = self.run_prediction(
processed_input, return_duration=True)
processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
for i, output_component in enumerate(self.output_components)]
processed_input, return_duration=True
)
processed_output = [
output_component.postprocess(predictions[i])
if predictions[i] is not None
else None
for i, output_component in enumerate(self.output_components)
]
avg_durations = []
for i, duration in enumerate(durations):
self.predict_durations[i][0] += duration
self.predict_durations[i][1] += 1
avg_durations.append(self.predict_durations[i][0]
/ self.predict_durations[i][1])
avg_durations.append(
self.predict_durations[i][0] / self.predict_durations[i][1]
)
if hasattr(self, "config"):
self.config["avg_durations"] = avg_durations
return processed_output, durations
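A hedged usage sketch of the pipeline above, assuming a one-input, one-output `iface`:

# preprocess -> run_prediction -> postprocess; `durations` has one entry per fn:
processed_output, durations = iface.process(["hello world"])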
def interpret(
self,
raw_input: List[Any]
) -> List[Any]:
def interpret(self, raw_input: List[Any]) -> List[Any]:
return interpretation.run_interpret(self, raw_input)
def block_thread(
self,
self,
) -> None:
"""Block main thread until interrupted by user."""
try:
@ -488,10 +567,10 @@ class Interface:
def test_launch(self) -> None:
for predict_fn in self.predict:
print("Test launch: {}()...".format(predict_fn.__name__), end=' ')
print("Test launch: {}()...".format(predict_fn.__name__), end=" ")
raw_input = []
for input_component in self.input_components:
if input_component.test_input is None:
if input_component.test_input is None:
print("SKIPPED")
break
else:
@ -502,22 +581,22 @@ class Interface:
continue
def launch(
self,
inline: bool = None,
inbrowser: bool = None,
share: bool = False,
self,
inline: bool = None,
inbrowser: bool = None,
share: bool = False,
debug: bool = False,
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
auth_message: Optional[str] = None,
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
auth_message: Optional[str] = None,
private_endpoint: Optional[str] = None,
prevent_thread_lock: bool = False,
show_error: bool = True,
prevent_thread_lock: bool = False,
show_error: bool = True,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
show_tips: bool = False,
server_port: Optional[int] = None,
show_tips: bool = False,
enable_queue: bool = False,
height: int = 500,
width: int = 900,
height: int = 500,
width: int = 900,
encrypt: bool = False,
cache_examples: bool = False,
favicon_path: Optional[str] = None,
@ -537,7 +616,7 @@ class Interface:
server_port (int): will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
server_name (str): to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
width (int): The width in pixels of the <iframe> element containing the interface (used if inline=True)
height (int): The height in pixels of the <iframe> element containing the interface (used if inline=True)
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
@ -548,10 +627,14 @@ class Interface:
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
self.config = self.get_config_file()
self.config = self.get_config_file()
self.cache_examples = cache_examples
if auth and not callable(auth) and not isinstance(
auth[0], tuple) and not isinstance(auth[0], list):
if (
auth
and not callable(auth)
and not isinstance(auth[0], tuple)
and not isinstance(auth[0], list)
):
auth = [auth]
self.auth = auth
self.auth_message = auth_message
@ -560,12 +643,13 @@ class Interface:
self.height = self.height or height
self.width = self.width or width
self.favicon_path = favicon_path
if self.encrypt is None:
self.encrypt = encrypt
self.encrypt = encrypt
if self.encrypt:
self.encryption_key = encryptor.get_key(
getpass.getpass("Enter key for encryption: "))
getpass.getpass("Enter key for encryption: ")
)
if self.enable_queue is None:
self.enable_queue = enable_queue
@ -579,8 +663,9 @@ class Interface:
cache_interface_examples(self)
server_port, path_to_local_server, app, server = networking.start_server(
self, server_name, server_port)
self, server_name, server_port
)
self.local_url = path_to_local_server
self.server_port = server_port
self.status = "RUNNING"
@ -589,7 +674,7 @@ class Interface:
utils.launch_counter()
# If running in a colab or not able to access localhost,
# If running in a colab or not able to access localhost,
# automatically create a shareable link.
is_colab = utils.colab_check()
if is_colab or not (networking.url_ok(path_to_local_server)):
@ -607,11 +692,10 @@ class Interface:
if private_endpoint is not None:
share = True
self.share = share
if share:
try:
share_url = networking.setup_tunnel(
server_port, private_endpoint)
share_url = networking.setup_tunnel(server_port, private_endpoint)
self.share_url = share_url
print(strings.en["SHARE_LINK_DISPLAY"].format(share_url))
if private_endpoint:
@ -620,8 +704,7 @@ class Interface:
print(strings.en["SHARE_LINK_MESSAGE"])
except RuntimeError:
if self.analytics_enabled:
utils.error_analytics(self.ip_address,
"Not able to set up tunnel")
utils.error_analytics(self.ip_address, "Not able to set up tunnel")
share_url = None
else:
print(strings.en["PUBLIC_SHARE_TRUE"])
@ -636,31 +719,37 @@ class Interface:
inline = utils.ipython_check() and (auth is None)
if inline:
if auth is not None:
print("Warning: authentication is not supported inline. Please"
"click the link to access the interface in a new tab.")
print(
"Warning: authentication is not supported inline. Please"
"click the link to access the interface in a new tab."
)
try:
from IPython.display import IFrame, display # type: ignore
if share:
while not networking.url_ok(share_url):
time.sleep(1)
display(IFrame(share_url, width=self.width, height=self.height))
else:
display(IFrame(path_to_local_server,
width=self.width, height=self.height))
display(
IFrame(
path_to_local_server, width=self.width, height=self.height
)
)
except ImportError:
pass
data = {
'launch_method': 'browser' if inbrowser else 'inline',
'is_google_colab': is_colab,
'is_sharing_on': share,
'share_url': share_url,
'ip_address': self.ip_address,
'enable_queue': self.enable_queue,
'show_tips': self.show_tips,
'api_mode': self.api_mode,
'server_name': server_name,
'server_port': server_port,
"launch_method": "browser" if inbrowser else "inline",
"is_google_colab": is_colab,
"is_sharing_on": share,
"share_url": share_url,
"ip_address": self.ip_address,
"enable_queue": self.enable_queue,
"show_tips": self.show_tips,
"api_mode": self.api_mode,
"server_name": server_name,
"server_port": server_port,
}
if self.analytics_enabled:
utils.launch_analytics(data)
@ -668,46 +757,36 @@ class Interface:
utils.show_tip(self)
# Block main thread if debug==True
if debug or int(os.getenv('GRADIO_DEBUG', 0)) == 1:
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
self.block_thread()
# Block main thread if running in a script to stop script from exiting
is_in_interactive_mode = bool(
getattr(sys, 'ps1', sys.flags.interactive))
# Block main thread if running in a script to stop script from exiting
is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
if not prevent_thread_lock and not is_in_interactive_mode:
self.block_thread()
return app, path_to_local_server, share_url
def close(
self,
verbose: bool = True
) -> None:
def close(self, verbose: bool = True) -> None:
"""
Closes the Interface that was launched and frees the port.
"""
try:
self.server.close()
if verbose:
print("Closing server running on port: {}".format(
self.server_port))
print("Closing server running on port: {}".format(self.server_port))
except (AttributeError, OSError): # can't close if not running
pass
def integrate(
self,
comet_ml=None,
wandb=None,
mlflow=None
) -> None:
def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None:
"""
A catch-all method for integrating with other libraries.
A catch-all method for integrating with other libraries.
Should be run after launch()
Parameters:
comet_ml (Experiment): If a comet_ml Experiment object is provided,
comet_ml (Experiment): If a comet_ml Experiment object is provided,
will integrate with the experiment and appear on Comet dashboard
wandb (module): If the wandb module is provided, will integrate
wandb (module): If the wandb module is provided, will integrate
with it and appear on WandB dashboard
mlflow (module): If the mlflow module is provided, will integrate
mlflow (module): If the mlflow module is provided, will integrate
with the experiment and appear on ML Flow dashboard
"""
analytics_integration = ""
@ -723,24 +802,32 @@ class Interface:
if wandb is not None:
analytics_integration = "WandB"
if self.share_url is not None:
wandb.log({"Gradio panel": wandb.Html(
'<iframe src="' + self.share_url + '" width="' +
str(self.width) + '" height="' + str(self.height) +
'" frameBorder="0"></iframe>')})
wandb.log(
{
"Gradio panel": wandb.Html(
'<iframe src="'
+ self.share_url
+ '" width="'
+ str(self.width)
+ '" height="'
+ str(self.height)
+ '" frameBorder="0"></iframe>'
)
}
)
else:
print(
"The WandB integration requires you to "
"`launch(share=True)` first.")
"`launch(share=True)` first."
)
if mlflow is not None:
analytics_integration = "MLFlow"
if self.share_url is not None:
mlflow.log_param("Gradio Interface Share Link",
self.share_url)
mlflow.log_param("Gradio Interface Share Link", self.share_url)
else:
mlflow.log_param("Gradio Interface Local Link",
self.local_url)
mlflow.log_param("Gradio Interface Local Link", self.local_url)
if self.analytics_enabled and analytics_integration:
data = {'integration': analytics_integration}
data = {"integration": analytics_integration}
utils.integration_analytics(data)
@ -750,6 +837,8 @@ def close_all(verbose: bool = True) -> None:
def reset_all() -> None:
warnings.warn("The `reset_all()` method has been renamed to `close_all()` "
"and will be deprecated. Please use `close_all()` instead.")
warnings.warn(
"The `reset_all()` method has been renamed to `close_all()` "
"and will be deprecated. Please use `close_all()` instead."
)
close_all()

View File

@ -1,5 +1,6 @@
import copy
import math
import numpy as np
from gradio.outputs import Label, Textbox
@ -13,8 +14,10 @@ def run_interpret(interface, raw_input):
raw_input: a list of raw inputs to apply the interpretation(s) on.
"""
if isinstance(interface.interpretation, list): # Either "default" or "shap"
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)]
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)
]
original_output = interface.run_prediction(processed_input)
scores, alternative_outputs = [], []
for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
@ -22,104 +25,147 @@ def run_interpret(interface, raw_input):
input_component = interface.input_components[i]
neighbor_raw_input = list(raw_input)
if input_component.interpret_by_tokens:
tokens, neighbor_values, masks = input_component.tokenize(
x)
tokens, neighbor_values, masks = input_component.tokenize(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(interface.input_components)]
processed_neighbor_input = [
input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(
interface.input_components
)
]
neighbor_output = interface.run_prediction(
processed_neighbor_input)
processed_neighbor_output = [output_component.postprocess(
neighbor_output[i]) for i, output_component in enumerate(interface.output_components)]
processed_neighbor_input
)
processed_neighbor_output = [
output_component.postprocess(neighbor_output[i])
for i, output_component in enumerate(
interface.output_components
)
]
alternative_output.append(
processed_neighbor_output)
interface_scores.append(quantify_difference_in_label(
interface, original_output, neighbor_output))
alternative_output.append(processed_neighbor_output)
interface_scores.append(
quantify_difference_in_label(
interface, original_output, neighbor_output
)
)
alternative_outputs.append(alternative_output)
scores.append(
input_component.get_interpretation_scores(
raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
raw_input[i],
neighbor_values,
interface_scores,
masks=masks,
tokens=tokens,
)
)
else:
neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(
x)
(
neighbor_values,
interpret_kwargs,
) = input_component.get_interpretation_neighbors(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(interface.input_components)]
processed_neighbor_input = [
input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(
interface.input_components
)
]
neighbor_output = interface.run_prediction(
processed_neighbor_input)
processed_neighbor_output = [output_component.postprocess(
neighbor_output[i]) for i, output_component in enumerate(interface.output_components)]
processed_neighbor_input
)
processed_neighbor_output = [
output_component.postprocess(neighbor_output[i])
for i, output_component in enumerate(
interface.output_components
)
]
alternative_output.append(
processed_neighbor_output)
interface_scores.append(quantify_difference_in_label(
interface, original_output, neighbor_output))
alternative_output.append(processed_neighbor_output)
interface_scores.append(
quantify_difference_in_label(
interface, original_output, neighbor_output
)
)
alternative_outputs.append(alternative_output)
interface_scores = [-score for score in interface_scores]
scores.append(
input_component.get_interpretation_scores(
raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
raw_input[i],
neighbor_values,
interface_scores,
**interpret_kwargs
)
)
elif interp == "shap" or interp == "shapley":
try:
import shap # type: ignore
except (ImportError, ModuleNotFoundError):
raise ValueError(
"The package `shap` is required for this interpretation method. Try: `pip install shap`")
"The package `shap` is required for this interpretation method. Try: `pip install shap`"
)
input_component = interface.input_components[i]
if not (input_component.interpret_by_tokens):
raise ValueError(
"Input component {} does not support `shap` interpretation".format(input_component))
"Input component {} does not support `shap` interpretation".format(
input_component
)
)
tokens, _, masks = input_component.tokenize(x)
# construct a masked version of the input
def get_masked_prediction(binary_mask):
masked_xs = input_component.get_masked_inputs(
tokens, binary_mask)
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
preds = []
for masked_x in masked_xs:
processed_masked_input = copy.deepcopy(
processed_input)
processed_masked_input[i] = input_component.preprocess(
masked_x)
new_output = interface.run_prediction(
processed_masked_input)
processed_masked_input = copy.deepcopy(processed_input)
processed_masked_input[i] = input_component.preprocess(masked_x)
new_output = interface.run_prediction(processed_masked_input)
pred = get_regression_or_classification_value(
interface, original_output, new_output)
interface, original_output, new_output
)
preds.append(pred)
return np.array(preds)
num_total_segments = len(tokens)
explainer = shap.KernelExplainer(
get_masked_prediction, np.zeros((1, num_total_segments)))
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(
interface.num_shap * num_total_segments), silent=True)
scores.append(input_component.get_interpretation_scores(
raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
get_masked_prediction, np.zeros((1, num_total_segments))
)
shap_values = explainer.shap_values(
np.ones((1, num_total_segments)),
nsamples=int(interface.num_shap * num_total_segments),
silent=True,
)
scores.append(
input_component.get_interpretation_scores(
raw_input[i], None, shap_values[0], masks=masks, tokens=tokens
)
)
alternative_outputs.append([])
elif interp is None:
scores.append(None)
alternative_outputs.append([])
else:
raise ValueError(
"Uknown intepretation method: {}".format(interp))
raise ValueError("Uknown intepretation method: {}".format(interp))
return scores, alternative_outputs
else: # custom interpretation function
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)]
processed_input = [
input_component.preprocess(raw_input[i])
for i, input_component in enumerate(interface.input_components)
]
interpreter = interface.interpretation
if interface.capture_session and interface.session is not None:
graph, sess = interface.session
with graph.as_default(), sess.as_default():
interpretation = interpreter(*processed_input)
else:
else:
interpretation = interpreter(*processed_input)
if len(raw_input) == 1:
interpretation = [interpretation]
@ -130,7 +176,7 @@ def diff(original, perturbed):
try: # try computing numerical difference
score = float(original) - float(perturbed)
except ValueError: # otherwise, look at strict difference in label
score = int(not(original == perturbed))
score = int(not (original == perturbed))
return score
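Concretely (a hedged sketch of the helper above), numeric values are scored by their difference and anything else by 0/1 label mismatch:

diff(3, 1)           # 2.0 -- float(3) - float(1)
diff("cat", "dog")   # 1   -- labels differ
diff("cat", "cat")   # 0   -- labels match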
@ -155,12 +201,18 @@ def quantify_difference_in_label(interface, original_output, perturbed_output):
elif isinstance(output_component, Textbox):
score = diff(post_original_output, post_perturbed_output)
return score
else:
raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
raise ValueError(
"This interpretation method doesn't support the Output component: {}".format(
output_component
)
)
def get_regression_or_classification_value(interface, original_output, perturbed_output):
def get_regression_or_classification_value(
interface, original_output, perturbed_output
):
"""Used to combine regression/classification for Shap interpretation method."""
output_component = interface.output_components[0]
post_original_output = output_component.postprocess(original_output[0])
@ -176,9 +228,14 @@ def get_regression_or_classification_value(interface, original_output, perturbed
return 0
return perturbed_output[0][original_label]
else:
score = diff(perturbed_label, original_label) # Intentionally inverted order of arguments.
score = diff(
perturbed_label, original_label
) # Intentionally inverted order of arguments.
return score
else:
raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
raise ValueError(
"This interpretation method doesn't support the Output component: {}".format(
output_component
)
)
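An end-to-end hedged sketch of the interpretation entry points in this file (the toy scorer is illustrative; "shap" additionally requires `pip install shap`):

import gradio as gr

def sentiment(text):
    score = min(len(text) / 20, 1.0)  # toy scorer, illustrative only
    return {"positive": score, "negative": 1 - score}

iface = gr.Interface(
    sentiment,
    "text",
    "label",
    interpretation="shap",  # or "default", or a custom callable
    num_shap=2.0,           # samples scale with num_shap * token count
)
iface.launch()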

View File

@ -3,19 +3,21 @@ Ways to transform interfaces to produce new interfaces
"""
import gradio
class Parallel(gradio.Interface):
"""
Creates a new Interface consisting of multiple models in parallel
Parameters:
interfaces: any number of Interface objects that are to be compared in parallel
options: additional kwargs that are passed into the new Interface object to customize it
Parameters:
interfaces: any number of Interface objects that are to be compared in parallel
options: additional kwargs that are passed into the new Interface object to customize it
Returns:
(Interface): an Interface object comparing the given models
"""
def __init__(self, *interfaces, **options):
fns = []
outputs = []
for io in interfaces:
fns.extend(io.predict)
outputs.extend(io.output_components)
@ -27,27 +29,35 @@ class Parallel(gradio.Interface):
"repeat_outputs_per_model": False,
}
kwargs.update(options)
super().__init__(**kwargs)
self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute
super().__init__(**kwargs)
self.api_mode = (
interfaces[0].api_mode,
) # TODO(abidlabs): make api_mode a per-function attribute
class Series(gradio.Interface):
"""
Creates a new Interface from multiple models in series (the output of one is fed as the input to the next)
Parameters:
interfaces: any number of Interface objects that are to be connected in series
options: additional kwargs that are passed into the new Interface object to customize it
Parameters:
interfaces: any number of Interface objects that are to be connected in series
options: additional kwargs that are passed into the new Interface object to customize it
Returns:
(Interface): an Interface object connecting the given models
"""
def __init__(self, *interfaces, **options):
fns = [io.predict for io in interfaces]
def connected_fn(*data): # Run each function with the appropriate preprocessing and postprocessing
def connected_fn(
*data,
): # Run each function with the appropriate preprocessing and postprocessing
for idx, io in enumerate(interfaces):
# skip preprocessing for first interface since the Series interface will include it
if idx > 0 and not(io.api_mode):
data = [input_component.preprocess(data[i]) for i, input_component in enumerate(io.input_components)]
if idx > 0 and not (io.api_mode):
data = [
input_component.preprocess(data[i])
for i, input_component in enumerate(io.input_components)
]
# run all of predictions sequentially
predictions = []
@ -56,11 +66,14 @@ class Series(gradio.Interface):
predictions.append(prediction)
data = predictions
# skip postprocessing for final interface since the Series interface will include it
if idx < len(interfaces) - 1 and not(io.api_mode):
data = [output_component.postprocess(data[i]) for i, output_component in enumerate(io.output_components)]
if idx < len(interfaces) - 1 and not (io.api_mode):
data = [
output_component.postprocess(data[i])
for i, output_component in enumerate(io.output_components)
]
return data[0]
connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns])
kwargs = {
@ -69,6 +82,7 @@ class Series(gradio.Interface):
"outputs": interfaces[-1].output_components,
}
kwargs.update(options)
super().__init__(**kwargs)
self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute
super().__init__(**kwargs)
self.api_mode = (
interfaces[0].api_mode,
) # TODO(abidlabs): make api_mode a per-function attribute
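A hedged usage sketch of the two transforms above (assuming the usual gradio.mix import path):

import gradio as gr
from gradio.mix import Parallel, Series  # assumed module path

upper_io = gr.Interface(lambda s: s.upper(), "text", "text")
exclaim_io = gr.Interface(lambda s: s + "!", "text", "text")

Parallel(upper_io, exclaim_io).launch()  # same input, both outputs side by side
# Series(upper_io, exclaim_io)           # or: feed one output into the next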

View File

@ -3,32 +3,34 @@ Defines helper methods useful for setting up ports, launching servers, and
creating tunnels.
"""
from __future__ import annotations
import fastapi
import http
import json
import os
import requests
import socket
import threading
import time
from typing import Optional, Tuple, TYPE_CHECKING
import urllib.parse
import urllib.request
from typing import TYPE_CHECKING, Optional, Tuple
import fastapi
import requests
import uvicorn
from gradio import queueing
from gradio.tunneling import create_tunnel
from gradio.app import app
from gradio.tunneling import create_tunnel
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio import Interface
# By default, the local server will try to open on localhost, port 7860.
# By default, the local server will try to open on localhost, port 7860.
# If that is not available, then it will try 7861, 7862, ... 7959.
INITIAL_PORT_VALUE = int(os.getenv('GRADIO_SERVER_PORT', "7860"))
TRY_NUM_PORTS = int(os.getenv('GRADIO_NUM_PORTS', "100"))
LOCALHOST_NAME = os.getenv('GRADIO_SERVER_NAME', "127.0.0.1")
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
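A hedged sketch: the constants above are read at import time, so the environment must be set before gradio is imported.

import os

os.environ["GRADIO_SERVER_PORT"] = "8080"     # first port to try
os.environ["GRADIO_NUM_PORTS"] = "10"         # how many consecutive ports to scan
os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"  # listen on all interfaces

import gradio as gr  # picks up the values above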
@ -47,9 +49,7 @@ class Server(uvicorn.Server):
self.thread.join()
def get_first_available_port(
initial: int,
final: int) -> int:
def get_first_available_port(initial: int, final: int) -> int:
"""
Gets the first open port in a specified range of port numbers
Parameters:
@ -81,8 +81,8 @@ def queue_thread(path_to_local_server, test_mode=False):
_, hash, input_data, task_type = next_job
queueing.start_job(hash)
response = requests.post(
path_to_local_server + "api/" + task_type + "/",
json=input_data)
path_to_local_server + "api/" + task_type + "/", json=input_data
)
if response.status_code == 200:
queueing.pass_job(hash, response.json())
else:
@ -97,9 +97,9 @@ def queue_thread(path_to_local_server, test_mode=False):
def start_server(
interface: Interface,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
interface: Interface,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
) -> Tuple[int, str, fastapi.FastAPI, Server]:
"""Launches a local server running the provided Interface
Parameters:
@ -107,20 +107,24 @@ def start_server(
server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
"""
"""
server_name = server_name or LOCALHOST_NAME
# if port is not specified, search for first available port
if server_port is None:
if server_port is None:
port = get_first_available_port(
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
)
else:
try:
s = socket.socket()
s.bind((LOCALHOST_NAME, server_port))
s.bind((LOCALHOST_NAME, server_port))
s.close()
except OSError:
raise OSError("Port {} is in use. If a gradio.Interface is running on the port, you can close() it or gradio.close_all().".format(server_port))
raise OSError(
"Port {} is in use. If a gradio.Interface is running on the port, you can close() it or gradio.close_all().".format(
server_port
)
)
port = server_port
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
@ -137,27 +141,28 @@ def start_server(
app.cwd = os.getcwd()
app.favicon_path = interface.favicon_path
app.tokens = {}
if app.interface.enable_queue:
if auth is not None or app.interface.encrypt:
raise ValueError("Cannot queue with encryption or authentication enabled.")
queueing.init()
app.queue_thread = threading.Thread(
target=queue_thread, args=(path_to_local_server,))
target=queue_thread, args=(path_to_local_server,)
)
app.queue_thread.start()
if interface.save_to is not None: # Used for selenium tests
interface.save_to["port"] = port
config = uvicorn.Config(app=app, port=port, host=server_name,
log_level="warning")
config = uvicorn.Config(app=app, port=port, host=server_name, log_level="warning")
server = Server(config=config)
server.run_in_thread()
return port, path_to_local_server, app, server
return port, path_to_local_server, app, server
def setup_tunnel(local_server_port: int, endpoint: str) -> str:
response = url_request(
endpoint + '/v1/tunnel-request' if endpoint is not None else GRADIO_API_SERVER)
endpoint + "/v1/tunnel-request" if endpoint is not None else GRADIO_API_SERVER
)
if response and response.code == 200:
try:
payload = json.loads(response.read().decode("utf-8"))[0]
@ -180,7 +185,7 @@ def url_request(url: str) -> Optional[http.client.HTTPResponse]:
def url_ok(url: str) -> bool:
try:
for _ in range(5):
time.sleep(.500)
time.sleep(0.500)
r = requests.head(url, timeout=3)
if r.status_code in (200, 401, 302): # 401 or 302 if auth is set
return True

View File

@ -5,18 +5,20 @@ automatically added to a registry, which allows them to be easily referenced in
"""
from __future__ import annotations
from ffmpy import FFmpeg
import json
from numbers import Number
import numpy as np
import operator
import os
import tempfile
import warnings
from numbers import Number
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import PIL
import tempfile
from types import ModuleType
from typing import Callable, Any, List, Optional, Tuple, Dict, TYPE_CHECKING
import warnings
from ffmpy import FFmpeg
from gradio import processing_utils
from gradio.component import Component
@ -38,34 +40,29 @@ class OutputComponent(Component):
def deserialize(self, x):
"""
Convert from serialized output (e.g. base64 representation) from a call() to the interface to a human-readable version of the output (path of an image, etc.)
Convert from serialized output (e.g. base64 representation) from a call() to the interface to a human-readable version of the output (path of an image, etc.)
"""
return x
class Textbox(OutputComponent):
'''
"""
Component creates a textbox to render output text or number.
Output type: Union[str, float, int]
Demos: hello_world, sentence_builder
'''
"""
def __init__(
self,
type: str = "auto",
label: Optional[str] = None):
'''
def __init__(self, type: str = "auto", label: Optional[str] = None):
"""
Parameters:
type (str): Type of value to be passed to component. "str" expects a string, "number" expects a float value, "auto" detects return type.
label (str): component name in interface.
'''
"""
self.type = type
super().__init__(label)
def get_template_context(self):
return {
**super().get_template_context()
}
return {**super().get_template_context()}
@classmethod
def get_shortcut_implementations(cls):
@ -87,30 +84,32 @@ class Textbox(OutputComponent):
elif self.type == "number":
return y
else:
raise ValueError("Unknown type: " + self.type +
". Please choose from: 'str', 'number'")
raise ValueError(
"Unknown type: " + self.type + ". Please choose from: 'str', 'number'"
)
class Label(OutputComponent):
'''
"""
Component outputs a classification label, along with confidence scores of top categories if provided. Confidence scores are represented as a dictionary mapping labels to scores between 0 and 1.
Output type: Union[Dict[str, float], str, int, float]
Demos: image_classifier, main_note, titanic_survival
'''
"""
CONFIDENCES_KEY = "confidences"
def __init__(
self,
num_top_classes: Optional[int] = None,
type: str = "auto",
label: Optional[str] = None):
'''
self,
num_top_classes: Optional[int] = None,
type: str = "auto",
label: Optional[str] = None,
):
"""
Parameters:
num_top_classes (int): number of most confident classes to show.
type (str): Type of value to be passed to component. "value" expects a single out label, "confidences" expects a dictionary mapping labels to confidence scores, "auto" detects return type.
label (str): component name in interface.
'''
"""
self.num_top_classes = num_top_classes
self.type = type
super().__init__(label)
@ -118,43 +117,50 @@ class Label(OutputComponent):
def postprocess(self, y):
"""
Parameters:
y (Dict[str, float]): dictionary mapping label to confidence value
y (Dict[str, float]): dictionary mapping label to confidence value
Returns:
(Dict[label: str, confidences: List[Dict[label: str, confidence: number]]]): Object with key 'label' representing primary label, and key 'confidences' representing a list of label-confidence pairs
"""
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
if self.type == "label" or (
self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))
):
return {"label": str(y)}
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
sorted_pred = sorted(
y.items(),
key=operator.itemgetter(1),
reverse=True
)
elif self.type == "confidences" or (
self.type == "auto" and isinstance(y, dict)
):
sorted_pred = sorted(y.items(), key=operator.itemgetter(1), reverse=True)
if self.num_top_classes is not None:
sorted_pred = sorted_pred[:self.num_top_classes]
sorted_pred = sorted_pred[: self.num_top_classes]
return {
"label": sorted_pred[0][0],
"confidences": [
{
"label": pred[0],
"confidence": pred[1]
} for pred in sorted_pred
]
{"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
],
}
else:
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
"float label, or a dictionary whose keys are labels and values are confidences.")
raise ValueError(
"The `Label` output interface expects one of: a string label, or an int label, a "
"float label, or a dictionary whose keys are labels and values are confidences."
)
def deserialize(self, y):
# 5 cases: {'label': 'lion'}, {'label': 'lion', 'confidences': ...}, {'lion': 0.46, ...}, 'lion', '0.46'
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, int) or isinstance(y, float) or ('label' in y and not('confidences' in y.keys())))):
if self.type == "label" or (
self.type == "auto"
and (
isinstance(y, str)
or isinstance(y, int)
or isinstance(y, float)
or ("label" in y and not ("confidences" in y.keys()))
)
):
if isinstance(y, str) or isinstance(y, int) or isinstance(y, float):
return y
else:
return y['label']
return y["label"]
elif self.type == "confidences" or self.type == "auto":
if ('confidences' in y.keys()) and isinstance(y['confidences'], list):
return {k['label']: k['confidence'] for k in y['confidences']}
if ("confidences" in y.keys()) and isinstance(y["confidences"], list):
return {k["label"]: k["confidence"] for k in y["confidences"]}
else:
return y
raise ValueError("Unable to deserialize output: {}".format(y))
@ -170,7 +176,12 @@ class Label(OutputComponent):
Returns: (Union[str, Dict[str, number]]): Either a string representing the main category label, or a dictionary with category keys mapping to confidence levels.
"""
if "confidences" in data:
return json.dumps({example["label"]: example["confidence"] for example in data["confidences"]})
return json.dumps(
{
example["label"]: example["confidence"]
for example in data["confidences"]
}
)
else:
return data["label"]
@ -183,26 +194,26 @@ class Label(OutputComponent):
class Image(OutputComponent):
'''
Component displays an output image.
"""
Component displays an output image.
Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]
Demos: image_mod, webcam
'''
"""
def __init__(
self,
type: str = "auto",
plot: bool = False,
label: Optional[str] = None):
'''
self, type: str = "auto", plot: bool = False, label: Optional[str] = None
):
"""
Parameters:
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
label (str): component name in interface.
'''
"""
if plot:
warnings.warn(
"The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.", DeprecationWarning)
"The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.",
DeprecationWarning,
)
self.type = "plot"
else:
self.type = type
@ -210,16 +221,12 @@ class Image(OutputComponent):
@classmethod
def get_shortcut_implementations(cls):
return {
"image": {},
"plot": {"type": "plot"},
"pil": {"type": "pil"}
}
return {"image": {}, "plot": {"type": "plot"}, "pil": {"type": "pil"}}
def postprocess(self, y):
"""
Parameters:
y (Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]): image in specified format
y (Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]): image in specified format
Returns:
(str): base64 url data
"""
@ -234,7 +241,8 @@ class Image(OutputComponent):
dtype = "plot"
else:
raise ValueError(
"Unknown type. Please choose from: 'numpy', 'pil', 'file', 'plot'.")
"Unknown type. Please choose from: 'numpy', 'pil', 'file', 'plot'."
)
else:
dtype = self.type
if dtype in ["numpy", "pil"]:
@ -246,8 +254,11 @@ class Image(OutputComponent):
elif dtype == "plot":
out_y = processing_utils.encode_plot_to_base64(y)
else:
raise ValueError("Unknown type: " + dtype +
". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
raise ValueError(
"Unknown type: "
+ dtype
+ ". Please choose from: 'numpy', 'pil', 'file', 'plot'."
)
return out_y
def deserialize(self, x):
@ -262,83 +273,71 @@ class Image(OutputComponent):
class Video(OutputComponent):
'''
Used for video output.
"""
Used for video output.
Output type: filepath
Demos: video_flip
'''
"""
def __init__(
self,
type: Optional[str] = None,
label: Optional[str] = None):
'''
def __init__(self, type: Optional[str] = None, label: Optional[str] = None):
"""
Parameters:
type (str): Type of video format to be passed to component, such as 'avi' or 'mp4'. Use 'mp4' to ensure browser playability. If set to None, the video keeps its returned format.
label (str): component name in interface.
'''
"""
self.type = type
super().__init__(label)
@classmethod
def get_shortcut_implementations(cls):
return {
"video": {},
"playable_video": {"type": "mp4"}
}
return {"video": {}, "playable_video": {"type": "mp4"}}
def postprocess(self, y):
"""
Parameters:
y (str): path to video
y (str): path to video
Returns:
(str): base64 url data
"""
returned_format = y.split(".")[-1].lower()
if self.type is not None and returned_format != self.type:
output_file_name = y[0: y.rindex(
".") + 1] + self.type
ff = FFmpeg(
inputs={y: None},
outputs={output_file_name: None}
)
output_file_name = y[0 : y.rindex(".") + 1] + self.type
ff = FFmpeg(inputs={y: None}, outputs={output_file_name: None})
ff.run()
y = output_file_name
return {
"name": os.path.basename(y),
"data": processing_utils.encode_file_to_base64(y)
"data": processing_utils.encode_file_to_base64(y),
}
def deserialize(self, x):
return processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
return self.save_flagged_file(dir, label, data['data'], encryption_key)
return self.save_flagged_file(dir, label, data["data"], encryption_key)
def restore_flagged(self, dir, data, encryption_key):
return self.restore_flagged_file(dir, data, encryption_key)
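A hedged usage sketch of the conversion path above (the pass-through function is illustrative; re-encoding needs ffmpeg on the PATH):

import gradio as gr

def passthrough(video_path):  # illustrative stand-in for real processing
    return video_path

# type="mp4" re-encodes whatever container fn returns, for browser playback:
iface = gr.Interface(passthrough, gr.inputs.Video(), gr.outputs.Video(type="mp4"))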
class KeyValues(OutputComponent):
'''
Component displays a table representing values for multiple fields.
"""
Component displays a table representing values for multiple fields.
Output type: Union[Dict, List[Tuple[str, Union[str, int, float]]]]
Demos: text_analysis
'''
"""
def __init__(
self,
label: Optional[str] = None):
'''
def __init__(self, label: Optional[str] = None):
"""
Parameters:
label (str): component name in interface.
'''
"""
super().__init__(label)
def postprocess(self, y):
"""
Parameters:
y (Union[Dict, List[Tuple[str, Union[str, int, float]]]]): dictionary or tuple list representing key value pairs
y (Union[Dict, List[Tuple[str, Union[str, int, float]]]]): dictionary or tuple list representing key value pairs
Returns:
(List[Tuple[str, Union[str, number]]]): list of key value pairs
"""
@ -347,8 +346,10 @@ class KeyValues(OutputComponent):
elif isinstance(y, list):
return y
else:
raise ValueError("The `KeyValues` output interface expects an output that is a dictionary whose keys are "
"labels and values are corresponding values.")
raise ValueError(
"The `KeyValues` output interface expects an output that is a dictionary whose keys are "
"labels and values are corresponding values."
)
@classmethod
def get_shortcut_implementations(cls):
@ -364,24 +365,25 @@ class KeyValues(OutputComponent):
class HighlightedText(OutputComponent):
'''
"""
Component creates text that contains spans that are highlighted by category or numerical value.
Output is represent as a list of Tuple pairs, where the first element represents the span of text represented by the tuple, and the second element represents the category or value of the text.
Output type: List[Tuple[str, Union[float, str]]]
Demos: diff_texts, text_analysis
'''
"""
def __init__(
self,
color_map: Dict[str, str] = None,
self,
color_map: Dict[str, str] = None,
label: Optional[str] = None,
show_legend: bool = False):
'''
show_legend: bool = False,
):
"""
Parameters:
color_map (Dict[str, str]): Map between category and respective colors
label (str): component name in interface.
show_legend (bool): whether to show span categories in a separate legend or inline.
'''
"""
self.color_map = color_map
self.show_legend = show_legend
super().__init__(label)
@ -390,7 +392,7 @@ class HighlightedText(OutputComponent):
return {
"color_map": self.color_map,
"show_legend": self.show_legend,
**super().get_template_context()
**super().get_template_context(),
}
@classmethod
@ -402,7 +404,7 @@ class HighlightedText(OutputComponent):
def postprocess(self, y):
"""
Parameters:
y (Union[Dict, List[Tuple[str, Union[str, int, float]]]]): dictionary or tuple list representing key value pairs
y (Union[Dict, List[Tuple[str, Union[str, int, float]]]]): dictionary or tuple list representing key value pairs
Returns:
(List[Tuple[str, Union[str, number]]]): list of key value pairs
@ -417,28 +419,23 @@ class HighlightedText(OutputComponent):
class Audio(OutputComponent):
'''
"""
Creates an audio player that plays the output audio.
Output type: Union[Tuple[int, numpy.array], str]
Demos: generate_tone, reverse_audio
'''
"""
def __init__(
self,
type: str = "auto",
label: Optional[str] = None):
'''
def __init__(self, type: str = "auto", label: Optional[str] = None):
"""
Parameters:
type (str): Type of value to be passed to component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file path to the saved wav audio file, "auto" detects return type.
label (str): component name in interface.
'''
"""
self.type = type
super().__init__(label)
def get_template_context(self):
return {
**super().get_template_context()
}
return {**super().get_template_context()}
@classmethod
def get_shortcut_implementations(cls):
@ -457,13 +454,15 @@ class Audio(OutputComponent):
if self.type == "numpy" or (self.type == "auto" and isinstance(y, tuple)):
sample_rate, data = y
file = tempfile.NamedTemporaryFile(
prefix="sample", suffix=".wav", delete=False)
prefix="sample", suffix=".wav", delete=False
)
processing_utils.audio_to_file(sample_rate, data, file.name)
y = file.name
return processing_utils.encode_url_or_file_to_base64(y)
else:
raise ValueError("Unknown type: " + self.type +
". Please choose from: 'numpy', 'file'.")
raise ValueError(
"Unknown type: " + self.type + ". Please choose from: 'numpy', 'file'."
)
def deserialize(self, x):
return processing_utils.decode_base64_to_file(x).name
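A minimal sketch of the "numpy" path above; the tone generator itself is hypothetical, what matters is returning the (sample_rate, data) tuple that postprocess writes to a temporary wav file:

import numpy as np

import gradio as gr

def one_second_tone(frequency):
    sample_rate = 44100
    t = np.linspace(0, 1, sample_rate)
    # Mono int16 samples; a 1-D array is treated as a single channel downstream.
    data = (np.sin(2 * np.pi * frequency * t) * 32767).astype(np.int16)
    return sample_rate, data

iface = gr.Interface(one_second_tone, "number", gr.outputs.Audio(type="auto"))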
@ -476,19 +475,17 @@ class Audio(OutputComponent):
class JSON(OutputComponent):
'''
Used for JSON output. Expects a JSON string or a Python object that is JSON serializable.
"""
Used for JSON output. Expects a JSON string or a Python object that is JSON serializable.
Output type: Union[str, Any]
Demos: zip_to_json
'''
"""
def __init__(
self,
label: Optional[str] = None):
'''
def __init__(self, label: Optional[str] = None):
"""
Parameters:
label (str): component name in interface.
'''
"""
super().__init__(label)
def postprocess(self, y):
@ -517,19 +514,17 @@ class JSON(OutputComponent):
class HTML(OutputComponent):
'''
Used for HTML output. Expects an HTML valid string.
"""
Used for HTML output. Expects an HTML valid string.
Output type: str
Demos: text_analysis
'''
"""
def __init__(
self,
label: Optional[str] = None):
'''
def __init__(self, label: Optional[str] = None):
"""
Parameters:
label (str): component name in interface.
'''
"""
super().__init__(label)
def postprocess(self, x):
@ -549,19 +544,17 @@ class HTML(OutputComponent):
class File(OutputComponent):
'''
Used for file output.
"""
Used for file output.
Output type: Union[file-like, str]
Demos: zip_two_files
'''
"""
def __init__(
self,
label: Optional[str] = None):
'''
def __init__(self, label: Optional[str] = None):
"""
Parameters:
label (str): component name in interface.
'''
"""
super().__init__(label)
@classmethod
@ -575,12 +568,12 @@ class File(OutputComponent):
Parameters:
y (str): file path
Returns:
(Dict[name: str, size: number, data: str]): JSON object with key 'name' for filename, 'data' for base64 url, and 'size' for filesize in bytes
(Dict[name: str, size: number, data: str]): JSON object with key 'name' for filename, 'data' for base64 url, and 'size' for filesize in bytes
"""
return {
"name": os.path.basename(y),
"size": os.path.getsize(y),
"data": processing_utils.encode_file_to_base64(y)
"data": processing_utils.encode_file_to_base64(y),
}
def save_flagged(self, dir, label, data, encryption_key):
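For reference, a sketch of a function feeding this component (the filename is hypothetical); postprocess then wraps the returned path into the {'name', 'size', 'data'} object shown above:

import gradio as gr

def save_report(text):
    path = "report.txt"  # hypothetical output path
    with open(path, "w") as f:
        f.write(text)
    return path  # File.postprocess base64-encodes this file

iface = gr.Interface(save_report, "text", gr.outputs.File())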
@ -598,22 +591,23 @@ class Dataframe(OutputComponent):
"""
def __init__(
self,
headers: Optional[List[str]] = None,
max_rows: Optional[int] = 20,
max_cols: Optional[int] = None,
overflow_row_behaviour: str = "paginate",
type: str = "auto",
label: Optional[str] = None):
'''
self,
headers: Optional[List[str]] = None,
max_rows: Optional[int] = 20,
max_cols: Optional[int] = None,
overflow_row_behaviour: str = "paginate",
type: str = "auto",
label: Optional[str] = None,
):
"""
Parameters:
headers (List[str]): Header names for the dataframe. Only applicable if type is "numpy" or "array".
max_rows (int): Maximum number of rows to display at once. Set to None for infinite.
max_rows (int): Maximum number of rows to display at once. Set to None for infinite.
max_cols (int): Maximum number of columns to display at once. Set to None for infinite.
overflow_row_behaviour (str): If set to "paginate", will create pages for overflow rows. If set to "show_ends", will show initial and final rows and truncate middle rows.
overflow_row_behaviour (str): If set to "paginate", will create pages for overflow rows. If set to "show_ends", will show initial and final rows and truncate middle rows.
type (str): Type of value to be passed to component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for Python array, "auto" detects return type.
label (str): component name in interface.
'''
"""
self.headers = headers
self.max_rows = max_rows
self.max_cols = max_cols
@ -627,7 +621,7 @@ class Dataframe(OutputComponent):
"max_rows": self.max_rows,
"max_cols": self.max_cols,
"overflow_row_behaviour": self.overflow_row_behaviour,
**super().get_template_context()
**super().get_template_context(),
}
@classmethod
@ -664,8 +658,11 @@ class Dataframe(OutputComponent):
y = [y]
return {"data": y}
else:
raise ValueError("Unknown type: " + self.type +
". Please choose from: 'pandas', 'numpy', 'array'.")
raise ValueError(
"Unknown type: "
+ self.type
+ ". Please choose from: 'pandas', 'numpy', 'array'."
)
def save_flagged(self, dir, label, data, encryption_key):
return json.dumps(data["data"])
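A small sketch of the "array" return type, with hypothetical headers; rows beyond max_rows are handled per overflow_row_behaviour:

import gradio as gr

def squares(n):
    # A plain list of rows; "auto" detects this as the "array" type.
    return [[i, i ** 2] for i in range(int(n))]

iface = gr.Interface(
    squares,
    "number",
    gr.outputs.Dataframe(headers=["n", "n squared"], max_rows=10),
)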
@ -682,24 +679,26 @@ class Carousel(OutputComponent):
"""
def __init__(
self,
components: OutputComponent | List[OutputComponent],
label: Optional[str] = None):
'''
self,
components: OutputComponent | List[OutputComponent],
label: Optional[str] = None,
):
"""
Parameters:
components (Union[List[OutputComponent], OutputComponent]): Classes of component(s) that will be scrolled through.
label (str): component name in interface.
'''
"""
if not isinstance(components, list):
components = [components]
self.components = [get_output_instance(
component) for component in components]
self.components = [get_output_instance(component) for component in components]
super().__init__(label)
def get_template_context(self):
return {
"components": [component.get_template_context() for component in self.components],
**super().get_template_context()
"components": [
component.get_template_context() for component in self.components
],
**super().get_template_context(),
}
def postprocess(self, y):
@ -720,23 +719,29 @@ class Carousel(OutputComponent):
output.append(output_row)
return output
else:
raise ValueError(
"Unknown type. Please provide a list for the Carousel.")
raise ValueError("Unknown type. Please provide a list for the Carousel.")
def save_flagged(self, dir, label, data, encryption_key):
return json.dumps([
return json.dumps(
[
component.save_flagged(
dir, f"{label}_{j}", data[i][j], encryption_key)
for j, component in enumerate(self.components)
] for i, _ in enumerate(data)])
[
component.save_flagged(
dir, f"{label}_{j}", data[i][j], encryption_key
)
for j, component in enumerate(self.components)
]
for i, _ in enumerate(data)
]
)
def restore_flagged(self, dir, data, encryption_key):
return [
[
component.restore_flagged(dir, sample, encryption_key)
for component, sample in zip(self.components, sample_set)
] for sample_set in json.loads(data)]
]
for sample_set in json.loads(data)
]
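A sketch of how postprocess above consumes output, assuming two text components per slide and that string shortcuts resolve through get_output_instance as they do for interface outputs (the slide contents are placeholders):

import gradio as gr

def three_slides(topic):
    # One inner list per slide, matching the two components below in order.
    return [[f"{topic} headline {i}", f"{topic} detail {i}"] for i in range(3)]

iface = gr.Interface(
    three_slides,
    "text",
    gr.outputs.Carousel(["text", "text"]),
)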
class Timeseries(OutputComponent):
@ -747,10 +752,8 @@ class Timeseries(OutputComponent):
"""
def __init__(
self,
x: str = None,
y: str | List[str] = None,
label: Optional[str] = None):
self, x: str = None, y: str | List[str] = None, label: Optional[str] = None
):
"""
Parameters:
x (str): Column name of x (time) series. None if csv has no headers, in which case first column is x series.
@ -764,11 +767,7 @@ class Timeseries(OutputComponent):
super().__init__(label)
def get_template_context(self):
return {
"x": self.x,
"y": self.y,
**super().get_template_context()
}
return {"x": self.x, "y": self.y, **super().get_template_context()}
@classmethod
def get_shortcut_implementations(cls):
@ -783,11 +782,7 @@ class Timeseries(OutputComponent):
Returns:
(Dict[headers: List[str], data: List[List[Union[str, number]]]]): JSON object with key 'headers' for list of header names, 'data' for 2D array of string or numeric data
"""
return {
"headers": y.columns.values.tolist(),
"data": y.values.tolist()
}
return {"headers": y.columns.values.tolist(), "data": y.values.tolist()}
def save_flagged(self, dir, label, data, encryption_key):
"""
@ -803,13 +798,12 @@ class State(OutputComponent):
"""
Special hidden component that stores state across runs of the interface.
"""
def __init__(
self,
label: Optional[str] = None):
def __init__(self, label: Optional[str] = None):
"""
Parameters:
label (str): component name in interface (not used).
"""
"""
super().__init__(label)
@classmethod
@ -825,7 +819,7 @@ def get_output_instance(iface: Interface):
return shortcut[0](**shortcut[1])
# a dict with `name` as the output component type and other keys as parameters
elif isinstance(iface, dict):
name = iface.pop('name')
name = iface.pop("name")
for component in OutputComponent.__subclasses__():
if component.__name__.lower() == name:
break

View File

@ -1,22 +1,33 @@
import os, shutil
from gradio.flagging import CSVLogger
from typing import Any, List
import csv
import os
import shutil
from typing import Any, List
from gradio.flagging import CSVLogger
CACHED_FOLDER = "gradio_cached_examples"
CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv")
def process_example(interface, example_id: int):
example_set = interface.examples[example_id]
raw_input = [interface.input_components[i].preprocess_example(example) for i, example in enumerate(example_set)]
raw_input = [
interface.input_components[i].preprocess_example(example)
for i, example in enumerate(example_set)
]
prediction, durations = interface.process(raw_input)
return prediction, durations
def cache_interface_examples(interface) -> None:
if os.path.exists(CACHE_FILE):
print(f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache.")
print(
f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache."
)
else:
print(f"Cache at {os.path.abspath(CACHE_FILE)} not found. Caching now in '{CACHED_FOLDER}/' directory.")
print(
f"Cache at {os.path.abspath(CACHE_FILE)} not found. Caching now in '{CACHED_FOLDER}/' directory."
)
cache_logger = CSVLogger()
cache_logger.setup(CACHED_FOLDER)
for example_id, _ in enumerate(interface.examples):
@ -27,12 +38,18 @@ def cache_interface_examples(interface) -> None:
shutil.rmtree(CACHED_FOLDER)
raise e
def load_from_cache(interface, example_id: int) -> List[Any]:
with open(CACHE_FILE) as cache:
examples = list(csv.reader(cache))
example = examples[example_id + 1] # +1 to adjust for header
example = examples[example_id + 1] # +1 to adjust for header
output = []
for component, cell in zip(interface.output_components, example):
output.append(component.restore_flagged(
CACHED_FOLDER, cell, interface.encryption_key if interface.encrypt else None))
output.append(
component.restore_flagged(
CACHED_FOLDER,
cell,
interface.encryption_key if interface.encrypt else None,
)
)
return output
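Putting the two halves together, a rough usage sketch assuming an `iface` with examples already defined:

# First call computes and logs every example; later calls hit the CSV cache.
cache_interface_examples(iface)
# Re-read example 0; the +1 header offset above is handled internally.
first_example_outputs = load_from_cache(iface, 0)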

View File

@ -1,24 +1,27 @@
from PIL import Image, ImageOps
from io import BytesIO
import base64
import requests
import tempfile
import shutil
import os
import numpy as np
from gradio import encryptor
import warnings
import mimetypes
import os
import shutil
import tempfile
import warnings
from io import BytesIO
import numpy as np
import requests
from PIL import Image, ImageOps
from gradio import encryptor
with warnings.catch_warnings():
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
from pydub import AudioSegment
#########################
# IMAGE PRE-PROCESSING
#########################
def decode_base64_to_image(encoding):
content = encoding.split(';')[1]
image_encoded = content.split(',')[1]
content = encoding.split(";")[1]
image_encoded = content.split(",")[1]
return Image.open(BytesIO(base64.b64decode(image_encoded)))
@ -26,7 +29,7 @@ def get_url_or_file_as_bytes(path):
try:
return requests.get(path).content
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
with open(path, "rb") as f:
with open(path, "rb") as f:
return f.read()
@ -37,14 +40,14 @@ def encode_url_or_file_to_base64(path):
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
return encode_file_to_base64(path)
def get_mimetype(filename):
mimetype = mimetypes.guess_type(filename)[0]
if mimetype is not None:
mimetype = mimetype.replace(
"x-wav", "wav").replace(
"x-flac", "flac")
mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac")
return mimetype
def get_extension(encoding):
encoding = encoding.replace("audio/wav", "audio/x-wav")
type = mimetypes.guess_type(encoding)[0]
@ -55,40 +58,49 @@ def get_extension(encoding):
extension = extension[1:]
return extension
def encode_file_to_base64(f, encryption_key=None):
with open(f, "rb") as file:
encoded_string = base64.b64encode(file.read())
if encryption_key:
encoded_string = encryptor.decrypt(encryption_key, encoded_string)
base64_str = str(encoded_string, 'utf-8')
base64_str = str(encoded_string, "utf-8")
mimetype = get_mimetype(f)
return "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
return (
"data:"
+ (mimetype if mimetype is not None else "")
+ ";base64,"
+ base64_str
)
def encode_url_to_base64(url):
encoded_string = base64.b64encode(requests.get(url).content)
base64_str = str(encoded_string, 'utf-8')
base64_str = str(encoded_string, "utf-8")
mimetype = get_mimetype(url)
return "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
return (
"data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
)
def encode_plot_to_base64(plt):
with BytesIO() as output_bytes:
plt.savefig(output_bytes, format="png")
bytes_data = output_bytes.getvalue()
base64_str = str(base64.b64encode(bytes_data), 'utf-8')
base64_str = str(base64.b64encode(bytes_data), "utf-8")
return "data:image/png;base64," + base64_str
def encode_array_to_base64(image_array):
with BytesIO() as output_bytes:
PIL_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
PIL_image.save(output_bytes, 'PNG')
PIL_image.save(output_bytes, "PNG")
bytes_data = output_bytes.getvalue()
base64_str = str(base64.b64encode(bytes_data), 'utf-8')
base64_str = str(base64.b64encode(bytes_data), "utf-8")
return "data:image/png;base64," + base64_str
def resize_and_crop(img, size, crop_type='center'):
def resize_and_crop(img, size, crop_type="center"):
"""
Resize and crop an image to fit the specified size.
args:
@ -105,32 +117,36 @@ def resize_and_crop(img, size, crop_type='center'):
center = (0.5, 0.5)
else:
raise ValueError
return ImageOps.fit(img, size, centering=center)
return ImageOps.fit(img, size, centering=center)
##################
# Audio
##################
def audio_from_file(filename, crop_min=0, crop_max=100):
audio = AudioSegment.from_file(filename)
if crop_min != 0 or crop_max != 100:
audio_start = len(audio) * crop_min / 100
audio_end = len(audio) * crop_max / 100
audio = audio[audio_start : audio_end]
audio = audio[audio_start:audio_end]
data = np.array(audio.get_array_of_samples())
if (audio.channels > 1):
if audio.channels > 1:
data = data.reshape(-1, audio.channels)
return audio.frame_rate, data
def audio_to_file(sample_rate, data, filename):
audio = AudioSegment(
data.tobytes(),
data.tobytes(),
frame_rate=sample_rate,
sample_width=data.dtype.itemsize,
channels=(1 if len(data.shape) == 1 else data.shape[1])
sample_width=data.dtype.itemsize,
channels=(1 if len(data.shape) == 1 else data.shape[1]),
)
audio.export(filename, format="wav")
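A round-trip sketch of the two helpers, using a synthetic mono tone (values are illustrative):

import numpy as np

sr = 22050
tone = (np.sin(np.linspace(0, 2 * np.pi * 440, sr)) * 32767).astype(np.int16)
audio_to_file(sr, tone, "tone.wav")  # mono: 1-D array -> channels=1
rate, data = audio_from_file("tone.wav", crop_min=25, crop_max=75)  # middle half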
##################
# OUTPUT
##################
@ -139,7 +155,8 @@ def audio_to_file(sample_rate, data, filename):
def decode_base64_to_binary(encoding):
extension = get_extension(encoding)
data = encoding.split(",")[1]
return base64.b64decode(data), extension
return base64.b64decode(data), extension
def decode_base64_to_file(encoding, encryption_key=None, file_path=None):
data, extension = decode_base64_to_binary(encoding)
@ -148,35 +165,41 @@ def decode_base64_to_file(encoding, encryption_key=None, file_path=None):
filename = os.path.basename(file_path)
prefix = filename
if "." in filename:
prefix = filename[0: filename.index(".")]
extension = filename[filename.index(".") + 1:]
prefix = filename[0 : filename.index(".")]
extension = filename[filename.index(".") + 1 :]
if extension is None:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix)
else:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, suffix="."+extension)
file_obj = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix, suffix="." + extension
)
if encryption_key is not None:
data = encryptor.encrypt(encryption_key, data)
file_obj.write(data)
file_obj.flush()
return file_obj
def create_tmp_copy_of_file(file_path):
file_name = os.path.basename(file_path)
prefix, extension = file_name, None
if "." in file_name:
prefix = file_name[0: file_name.index(".")]
extension = file_name[file_name.index(".") + 1:]
prefix = file_name[0 : file_name.index(".")]
extension = file_name[file_name.index(".") + 1 :]
if extension is None:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix)
else:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, suffix="."+extension)
file_obj = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix, suffix="." + extension
)
shutil.copy2(file_path, file_obj.name)
return file_obj
def _convert(image, dtype, force_copy=False, uniform=False):
"""
Adapted from: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/dtype.py#L510-L531
Convert an image to the requested data-type.
Warnings are issued in case of precision loss, or when negative values
are clipped during conversion to unsigned integer types (sign loss).
@ -216,14 +239,16 @@ def _convert(image, dtype, force_copy=False, uniform=False):
.. [4] Dirty Pixels. J. Blinn. In "Jim Blinn's corner: Dirty Pixels",
pp 47-57. Morgan Kaufmann, 1998.
"""
dtype_range = {bool: (False, True),
np.bool_: (False, True),
np.bool8: (False, True),
float: (-1, 1),
np.float_: (-1, 1),
np.float16: (-1, 1),
np.float32: (-1, 1),
np.float64: (-1, 1)}
dtype_range = {
bool: (False, True),
np.bool_: (False, True),
np.bool8: (False, True),
float: (-1, 1),
np.float_: (-1, 1),
np.float16: (-1, 1),
np.float32: (-1, 1),
np.float64: (-1, 1),
}
def _dtype_itemsize(itemsize, *dtypes):
"""Return first of `dtypes` with itemsize greater than `itemsize`
@ -258,12 +283,14 @@ def _convert(image, dtype, force_copy=False, uniform=False):
Data type of `kind` that can store a `bits` wide unsigned int
"""
s = next(i for i in (itemsize, ) + (2, 4, 8) if
bits < (i * 8) or (bits == (i * 8) and kind == 'u'))
s = next(
i
for i in (itemsize,) + (2, 4, 8)
if bits < (i * 8) or (bits == (i * 8) and kind == "u")
)
return np.dtype(kind + str(s))
def _scale(a, n, m, copy=True):
"""Scale an array of unsigned/positive integers from `n` to `m` bits.
Numbers can be represented exactly only if `m` is a multiple of `n`.
@ -298,21 +325,20 @@ def _convert(image, dtype, force_copy=False, uniform=False):
# downscale with precision loss
if copy:
b = np.empty(a.shape, _dtype_bits(kind, m))
np.floor_divide(a, 2**(n - m), out=b, dtype=a.dtype,
casting='unsafe')
np.floor_divide(a, 2 ** (n - m), out=b, dtype=a.dtype, casting="unsafe")
return b
else:
a //= 2**(n - m)
a //= 2 ** (n - m)
return a
elif m % n == 0:
# exact upscale to a multiple of `n` bits
if copy:
b = np.empty(a.shape, _dtype_bits(kind, m))
np.multiply(a, (2**m - 1) // (2**n - 1), out=b, dtype=b.dtype)
np.multiply(a, (2 ** m - 1) // (2 ** n - 1), out=b, dtype=b.dtype)
return b
else:
a = a.astype(_dtype_bits(kind, m, a.dtype.itemsize), copy=False)
a *= (2**m - 1) // (2**n - 1)
a *= (2 ** m - 1) // (2 ** n - 1)
return a
else:
# upscale to a multiple of `n` bits,
@ -320,19 +346,19 @@ def _convert(image, dtype, force_copy=False, uniform=False):
o = (m // n + 1) * n
if copy:
b = np.empty(a.shape, _dtype_bits(kind, o))
np.multiply(a, (2**o - 1) // (2**n - 1), out=b, dtype=b.dtype)
b //= 2**(o - m)
np.multiply(a, (2 ** o - 1) // (2 ** n - 1), out=b, dtype=b.dtype)
b //= 2 ** (o - m)
return b
else:
a = a.astype(_dtype_bits(kind, o, a.dtype.itemsize), copy=False)
a *= (2**o - 1) // (2**n - 1)
a //= 2**(o - m)
a *= (2 ** o - 1) // (2 ** n - 1)
a //= 2 ** (o - m)
return a
image = np.asarray(image)
dtypeobj_in = image.dtype
if dtype is np.floating:
dtypeobj_out = np.dtype('float64')
dtypeobj_out = np.dtype("float64")
else:
dtypeobj_out = np.dtype(dtype)
dtype_in = dtypeobj_in.type
@ -356,28 +382,27 @@ def _convert(image, dtype, force_copy=False, uniform=False):
image = image.copy()
return image
if kind_in in 'ui':
if kind_in in "ui":
imin_in = np.iinfo(dtype_in).min
imax_in = np.iinfo(dtype_in).max
if kind_out in 'ui':
if kind_out in "ui":
imin_out = np.iinfo(dtype_out).min
imax_out = np.iinfo(dtype_out).max
# any -> binary
if kind_out == 'b':
if kind_out == "b":
return image > dtype_in(dtype_range[dtype_in][1] / 2)
# binary -> any
if kind_in == 'b':
if kind_in == "b":
result = image.astype(dtype_out)
if kind_out != 'f':
if kind_out != "f":
result *= dtype_out(dtype_range[dtype_out][1])
return result
# float -> any
if kind_in == 'f':
if kind_out == 'f':
if kind_in == "f":
if kind_out == "f":
# float -> float
return image.astype(dtype_out)
@ -385,41 +410,42 @@ def _convert(image, dtype, force_copy=False, uniform=False):
raise ValueError("Images of type float must be between -1 and 1.")
# floating point -> integer
# use float type that can represent output integer type
computation_type = _dtype_itemsize(itemsize_out, dtype_in,
np.float32, np.float64)
computation_type = _dtype_itemsize(
itemsize_out, dtype_in, np.float32, np.float64
)
if not uniform:
if kind_out == 'u':
image_out = np.multiply(image, imax_out,
dtype=computation_type)
if kind_out == "u":
image_out = np.multiply(image, imax_out, dtype=computation_type)
else:
image_out = np.multiply(image, (imax_out - imin_out) / 2,
dtype=computation_type)
image_out -= 1.0 / 2.
image_out = np.multiply(
image, (imax_out - imin_out) / 2, dtype=computation_type
)
image_out -= 1.0 / 2.0
np.rint(image_out, out=image_out)
np.clip(image_out, imin_out, imax_out, out=image_out)
elif kind_out == 'u':
image_out = np.multiply(image, imax_out + 1,
dtype=computation_type)
elif kind_out == "u":
image_out = np.multiply(image, imax_out + 1, dtype=computation_type)
np.clip(image_out, 0, imax_out, out=image_out)
else:
image_out = np.multiply(image, (imax_out - imin_out + 1.0) / 2.0,
dtype=computation_type)
image_out = np.multiply(
image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type
)
np.floor(image_out, out=image_out)
np.clip(image_out, imin_out, imax_out, out=image_out)
return image_out.astype(dtype_out)
# signed/unsigned int -> float
if kind_out == 'f':
if kind_out == "f":
# use float type that can exactly represent input integers
computation_type = _dtype_itemsize(itemsize_in, dtype_out,
np.float32, np.float64)
computation_type = _dtype_itemsize(
itemsize_in, dtype_out, np.float32, np.float64
)
if kind_in == 'u':
if kind_in == "u":
# using np.divide or np.multiply doesn't copy the data
# until the computation time
image = np.multiply(image, 1. / imax_in,
dtype=computation_type)
image = np.multiply(image, 1.0 / imax_in, dtype=computation_type)
# DirectX uses this conversion also for signed ints
# if imin_in:
# np.maximum(image, -1.0, out=image)
@ -430,8 +456,8 @@ def _convert(image, dtype, force_copy=False, uniform=False):
return np.asarray(image, dtype_out)
# unsigned int -> signed/unsigned int
if kind_in == 'u':
if kind_out == 'i':
if kind_in == "u":
if kind_out == "i":
# unsigned int -> signed int
image = _scale(image, 8 * itemsize_in, 8 * itemsize_out - 1)
return image.view(dtype_out)
@ -440,17 +466,17 @@ def _convert(image, dtype, force_copy=False, uniform=False):
return _scale(image, 8 * itemsize_in, 8 * itemsize_out)
# signed int -> unsigned int
if kind_out == 'u':
if kind_out == "u":
image = _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out)
result = np.empty(image.shape, dtype_out)
np.maximum(image, 0, out=result, dtype=image.dtype, casting='unsafe')
np.maximum(image, 0, out=result, dtype=image.dtype, casting="unsafe")
return result
# signed int -> signed int
if itemsize_in > itemsize_out:
return _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out - 1)
image = image.astype(_dtype_bits('i', itemsize_out * 8))
image = image.astype(_dtype_bits("i", itemsize_out * 8))
image -= imin_in
image = _scale(image, 8 * itemsize_in, 8 * itemsize_out, copy=False)
image += imin_out
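A worked instance of the exact-upscale branch in _scale above: for n=8 to m=16 bits, the multiplier (2**16 - 1) // (2**8 - 1) is 257, so full range maps to full range with no precision loss, which is why the m % n == 0 case is "exact":

import numpy as np

a = np.array([0, 128, 255], dtype=np.uint8)
b = a.astype(np.uint16) * 257  # equivalent to _scale(a, 8, 16)
# b == [0, 32896, 65535]: 255 * 257 == 65535, the uint16 maximum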

View File

@ -1,66 +1,82 @@
import sqlite3
import os
import uuid
import json
import os
import sqlite3
import uuid
DB_FILE = "gradio_queue.db"
def generate_hash():
generate = True
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
while generate:
hash = uuid.uuid4().hex
c.execute("""
c.execute(
"""
SELECT hash FROM queue
WHERE hash = ?;
""", (hash,))
""",
(hash,),
)
generate = c.fetchone() is not None
conn.commit()
return hash
def init():
if os.path.exists(DB_FILE):
os.remove(DB_FILE)
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("""CREATE TABLE queue (
c.execute(
"""CREATE TABLE queue (
queue_index integer PRIMARY KEY,
hash text,
input_data text,
action text,
popped integer DEFAULT 0
);""")
c.execute("""
);"""
)
c.execute(
"""
CREATE TABLE jobs (
hash text PRIMARY KEY,
status text,
output_data text,
error_message text
);
""")
"""
)
conn.commit()
def close():
if os.path.exists(DB_FILE):
os.remove(DB_FILE)
def pop():
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("BEGIN EXCLUSIVE")
c.execute("""
c.execute(
"""
SELECT queue_index, hash, input_data, action FROM queue
WHERE popped = 0 ORDER BY queue_index ASC LIMIT 1;
""")
"""
)
result = c.fetchone()
if result is None:
conn.commit()
return None
queue_index = result[0]
c.execute("""
c.execute(
"""
UPDATE queue SET popped = 1, input_data = '' WHERE queue_index = ?;
""", (queue_index,))
""",
(queue_index,),
)
conn.commit()
return result[0], result[1], json.loads(result[2]), result[3]
@ -70,42 +86,57 @@ def push(input_data, action):
hash = generate_hash()
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("""
c.execute(
"""
INSERT INTO queue (hash, input_data, action)
VALUES (?, ?, ?);
""", (hash, input_data, action))
""",
(hash, input_data, action),
)
queue_index = c.lastrowid
c.execute("""
c.execute(
"""
SELECT COUNT(*) FROM queue WHERE queue_index < ? and popped = 0;
""", (queue_index,))
""",
(queue_index,),
)
queue_position = c.fetchone()[0]
if queue_position is None:
conn.commit()
raise ValueError("Hash not found.")
elif queue_position == 0:
c.execute("""
c.execute(
"""
SELECT COUNT(*) FROM jobs WHERE status = "PENDING";
""")
"""
)
result = c.fetchone()
if result[0] == 0:
queue_position -= 1
conn.commit()
return hash, queue_position
def get_status(hash):
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("""
c.execute(
"""
SELECT queue_index, popped FROM queue WHERE hash = ?;
""", (hash,))
""",
(hash,),
)
result = c.fetchone()
if result is None:
conn.commit()
raise ValueError("Hash not found.")
if result[1] == 1: # in jobs
c.execute("""
if result[1] == 1: # in jobs
c.execute(
"""
SELECT status, output_data, error_message FROM jobs WHERE hash = ?;
""", (hash,))
""",
(hash,),
)
result = c.fetchone()
if result is None:
conn.commit()
@ -119,55 +150,83 @@ def get_status(hash):
conn.commit()
return "FAILED", error_message
elif status == "COMPLETE":
c.execute("""
c.execute(
"""
UPDATE jobs SET output_data = '' WHERE hash = ?;
""", (hash,))
""",
(hash,),
)
conn.commit()
output_data = json.loads(output_data)
return "COMPLETE", output_data
else: # in queue
else: # in queue
queue_index = result[0]
c.execute("""
c.execute(
"""
SELECT COUNT(*) FROM queue WHERE queue_index < ? and popped = 0;
""", (queue_index,))
""",
(queue_index,),
)
result = c.fetchone()
queue_position = result[0]
if queue_position == 0:
c.execute("""
c.execute(
"""
SELECT COUNT(*) FROM jobs WHERE status = "PENDING";
""")
"""
)
result = c.fetchone()
if result[0] == 0:
queue_position -= 1
conn.commit()
return "QUEUED", queue_position
def start_job(hash):
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("BEGIN EXCLUSIVE")
c.execute("""
c.execute(
"""
UPDATE queue SET popped = 1 WHERE hash = ?;
""", (hash,))
c.execute("""
""",
(hash,),
)
c.execute(
"""
INSERT INTO jobs (hash, status) VALUES (?, 'PENDING');
""", (hash,))
""",
(hash,),
)
conn.commit()
def fail_job(hash, error_message):
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("""
c.execute(
"""
UPDATE jobs SET status = 'FAILED', error_message = ? WHERE hash = ?;
""", (error_message, hash,))
""",
(
error_message,
hash,
),
)
conn.commit()
def pass_job(hash, output_data):
output_data = json.dumps(output_data)
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("""
c.execute(
"""
UPDATE jobs SET status = 'COMPLETE', output_data = ? WHERE hash = ?;
""", (output_data, hash,))
""",
(
output_data,
hash,
),
)
conn.commit()
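Reading the functions above together, a plausible lifecycle looks like this sketch (payloads are hypothetical; input_data must already be a JSON string, since pop() json.loads it):

init()                                    # fresh queue + jobs tables
h, position = push('{"data": [1, 2]}', "predict")
job = pop()                               # worker claims the oldest pending row
start_job(h)                              # recorded as PENDING in jobs
pass_job(h, {"prediction": 3})            # or fail_job(h, "error message")
status, output = get_status(h)            # -> ("COMPLETE", {"prediction": 3})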

View File

@ -1,10 +1,10 @@
import requests
import json
import requests
MESSAGING_API_ENDPOINT = "https://api.gradio.app/gradio-messaging/en"
en = {
en = {
"RUNNING_LOCALLY": "Running on local URL: {}",
"SHARE_LINK_DISPLAY": "Running on public URL: {}",
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
@ -12,11 +12,11 @@ en = {
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} (may take up to a minute for link to be usable)",
"GENERATING_PUBLIC_LINK": "Generating public link (may take a few seconds...):",
"TF1_ERROR": "It looks like you might be using tensorflow < 2.0. Please pass capture_session=True in Interface() to"
" avoid the 'Tensor is not an element of this graph.' error.",
" avoid the 'Tensor is not an element of this graph.' error.",
"BETA_INVITE": "\nWe want to invite you to become a beta user.\nYou'll get early access to new and premium "
"features (persistent links, hosting, and more).\nIf you're interested please email: beta@gradio.app\n",
"features (persistent links, hosting, and more).\nIf you're interested please email: beta@gradio.app\n",
"COLAB_DEBUG_TRUE": "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. "
"To turn off, set debug=False in launch().",
"To turn off, set debug=False in launch().",
"COLAB_DEBUG_FALSE": "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()",
"SHARE_LINK_MESSAGE": "\nThis share link expires in 72 hours. For free permanent hosting, check out Spaces (https://huggingface.co/spaces)",
"PRIVATE_LINK_MESSAGE": "Since this is a private endpoint, this share link will never expire.",
@ -27,12 +27,16 @@ en = {
"Let users specify why they flagged input with the flagging_options= kwarg; for example: gr.Interface(..., flagging_options=['too slow', 'incorrect output', 'other'])",
"You can show or hide the buttons for flagging, screenshots, and interpretation with the allow_*= kwargs; for example: gr.Interface(..., allow_screenshot=True, allow_flagging=False)",
"The inputs and outputs flagged by the users are stored in the flagging directory, specified by the flagging_dir= kwarg. You can view this data through the interface by setting the examples= kwarg to the flagging directory; for example gr.Interface(..., examples='flagged')",
"You can add a title and description to your interface using the title= and description= kwargs. The article= kwarg can be used to add markdown or HTML under the interface; for example gr.Interface(..., title='My app', description='Lorem ipsum')"
]
"You can add a title and description to your interface using the title= and description= kwargs. The article= kwarg can be used to add markdown or HTML under the interface; for example gr.Interface(..., title='My app', description='Lorem ipsum')",
],
}
try:
updated_messaging = requests.get(MESSAGING_API_ENDPOINT, timeout=3).json()
en.update(updated_messaging)
except (requests.ConnectionError, requests.exceptions.ReadTimeout, json.decoder.JSONDecodeError): # Use default messaging
except (
requests.ConnectionError,
requests.exceptions.ReadTimeout,
json.decoder.JSONDecodeError,
): # Use default messaging
pass

File diff suppressed because one or more lines are too long

View File

@ -7,8 +7,9 @@ import select
import socket
import sys
import threading
from io import StringIO
import warnings
from io import StringIO
import paramiko
@ -21,9 +22,9 @@ def handler(chan, host, port):
return
verbose(
"Connected! Tunnel open {} -> {} -> {}".format(chan.origin_addr,
chan.getpeername(),
(host, port))
"Connected! Tunnel open {} -> {} -> {}".format(
chan.origin_addr, chan.getpeername(), (host, port)
)
)
while True:
r, w, x = select.select([sock, chan], [], [])
@ -39,7 +40,11 @@ def handler(chan, host, port):
sock.send(data)
chan.close()
sock.close()
verbose("Tunnel closed from {}".format(chan.origin_addr,))
verbose(
"Tunnel closed from {}".format(
chan.origin_addr,
)
)
def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
@ -64,8 +69,7 @@ def create_tunnel(payload, local_server, local_server_port):
client.set_missing_host_key_policy(paramiko.WarningPolicy())
verbose(
"Connecting to ssh host {}:{} ...".format(payload["host"], int(payload[
"port"]))
"Connecting to ssh host {}:{} ...".format(payload["host"], int(payload["port"]))
)
try:
with warnings.catch_warnings():
@ -78,16 +82,16 @@ def create_tunnel(payload, local_server, local_server_port):
)
except Exception as e:
print(
"*** Failed to connect to {}:{}: {}}".format(payload["host"],
int(payload["port"]), e)
"*** Failed to connect to {}:{}: {}}".format(
payload["host"], int(payload["port"]), e
)
)
sys.exit(1)
verbose(
"Now forwarding remote port {} to {}:{} ...".format(int(payload[
"remote_port"]),
local_server,
local_server_port)
"Now forwarding remote port {} to {}:{} ...".format(
int(payload["remote_port"]), local_server, local_server_port
)
)
thread = threading.Thread(

View File

@ -1,19 +1,21 @@
""" Handy utility functions."""
from __future__ import annotations
import aiohttp
import analytics
import csv
from distutils.version import StrictVersion
import inspect
import json
import json.decoder
import os
import pkg_resources
import random
import requests
from typing import Callable, Any, Dict, TYPE_CHECKING
import warnings
from distutils.version import StrictVersion
from typing import TYPE_CHECKING, Any, Callable, Dict
import aiohttp
import analytics
import pkg_resources
import requests
import gradio
@ -21,7 +23,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio import Interface
analytics_url = 'https://api.gradio.app/'
analytics_url = "https://api.gradio.app/"
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
@ -32,13 +34,18 @@ def version_check():
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
print(
"IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version
)
)
print("--------")
except pkg_resources.DistributionNotFound:
warnings.warn("gradio is not setup or installed properly. Unable to get version info.")
warnings.warn(
"gradio is not setup or installed properly. Unable to get version info."
)
except json.decoder.JSONDecodeError:
warnings.warn("unable to parse version details from package URL.")
except KeyError:
@ -49,34 +56,36 @@ def version_check():
def get_local_ip_address() -> str:
try:
ip_address = requests.get('https://api.ipify.org', timeout=3).text
ip_address = requests.get("https://api.ipify.org", timeout=3).text
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
ip_address = "No internet connection"
return ip_address
def initiated_analytics(data: Dict[str: Any]) -> None:
def initiated_analytics(data: Dict[str, Any]) -> None:
try:
requests.post(analytics_url + 'gradio-initiated-analytics/',
data=data, timeout=3)
requests.post(
analytics_url + "gradio-initiated-analytics/", data=data, timeout=3
)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
def launch_analytics(data: Dict[str, Any]) -> None:
try:
requests.post(analytics_url + 'gradio-launched-analytics/',
data=data, timeout=3)
requests.post(
analytics_url + "gradio-launched-analytics/", data=data, timeout=3
)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
def integration_analytics(data: Dict[str, Any]) -> None:
try:
requests.post(analytics_url + 'gradio-integration-analytics/',
data=data, timeout=3)
except (
requests.ConnectionError, requests.exceptions.ReadTimeout):
requests.post(
analytics_url + "gradio-integration-analytics/", data=data, timeout=3
)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
@ -85,23 +94,23 @@ def error_analytics(ip_address: str, message: str) -> None:
Send error analytics if there is network
:param type: RuntimeError or NameError
"""
data = {'ip_address': ip_address, 'error': message}
data = {"ip_address": ip_address, "error": message}
try:
requests.post(analytics_url + 'gradio-error-analytics/',
data=data, timeout=3)
requests.post(analytics_url + "gradio-error-analytics/", data=data, timeout=3)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
async def log_feature_analytics(ip_address: str, feature: str) -> None:
data={'ip_address': ip_address, 'feature': feature}
async with aiohttp.ClientSession() as session:
try:
async with session.post(
analytics_url + 'gradio-feature-analytics/', data=data):
pass
except (aiohttp.ClientError):
pass # do not push analytics if no network
data = {"ip_address": ip_address, "feature": feature}
async with aiohttp.ClientSession() as session:
try:
async with session.post(
analytics_url + "gradio-feature-analytics/", data=data
):
pass
except (aiohttp.ClientError):
pass # do not push analytics if no network
def colab_check() -> bool:
@ -112,6 +121,7 @@ def colab_check() -> bool:
is_colab = False
try: # Check if running interactively using ipython.
from IPython import get_ipython
from_ipynb = get_ipython()
if "google.colab" in str(from_ipynb):
is_colab = True
@ -128,6 +138,7 @@ def ipython_check() -> bool:
is_ipython = False
try: # Check if running interactively using ipython.
from IPython import get_ipython
if get_ipython() is not None:
is_ipython = True
except (ImportError, NameError):
@ -138,7 +149,7 @@ def ipython_check() -> bool:
def readme_to_html(article: str) -> str:
try:
response = requests.get(article, timeout=3)
if response.status_code == requests.codes.ok: #pylint: disable=no-member
if response.status_code == requests.codes.ok: # pylint: disable=no-member
article = response.text
except requests.exceptions.RequestException:
pass
@ -172,11 +183,11 @@ def launch_counter() -> None:
def get_config_file(interface: Interface) -> Dict[str, Any]:
config = {
"input_components": [
iface.get_template_context()
for iface in interface.input_components],
iface.get_template_context() for iface in interface.input_components
],
"output_components": [
iface.get_template_context()
for iface in interface.output_components],
iface.get_template_context() for iface in interface.output_components
],
"function_count": len(interface.predict),
"live": interface.live,
"examples_per_page": interface.examples_per_page,
@ -194,9 +205,11 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
"flagging_options": interface.flagging_options,
"allow_interpretation": interface.interpretation is not None,
"queue": interface.enable_queue,
"cached_examples": interface.cache_examples if hasattr(interface, "cache_examples") else False,
"cached_examples": interface.cache_examples
if hasattr(interface, "cache_examples")
else False,
"version": pkg_resources.require("gradio")[0].version,
"favicon_path": interface.favicon_path
"favicon_path": interface.favicon_path,
}
try:
param_names = inspect.getfullargspec(interface.predict[0])[0]
@ -205,16 +218,23 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
iface["label"] = param.replace("_", " ")
for i, iface in enumerate(config["output_components"]):
outputs_per_function = int(
len(interface.output_components) / len(interface.predict))
len(interface.output_components) / len(interface.predict)
)
function_index = i // outputs_per_function
component_index = i - function_index * outputs_per_function
ret_name = "Output " + \
str(component_index + 1) if outputs_per_function > 1 else "Output"
ret_name = (
"Output " + str(component_index + 1)
if outputs_per_function > 1
else "Output"
)
if iface["label"] is None:
iface["label"] = ret_name
if len(interface.predict) > 1:
iface["label"] = interface.function_names[function_index].replace(
"_", " ") + ": " + iface["label"]
iface["label"] = (
interface.function_names[function_index].replace("_", " ")
+ ": "
+ iface["label"]
)
except ValueError:
pass
@ -222,23 +242,36 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
if isinstance(interface.examples, str):
if not os.path.exists(interface.examples):
raise FileNotFoundError(
"Could not find examples directory: " + interface.examples)
"Could not find examples directory: " + interface.examples
)
log_file = os.path.join(interface.examples, "log.csv")
if not os.path.exists(log_file):
if len(interface.input_components) == 1:
examples = [[os.path.join(interface.examples, item)]
for item in os.listdir(interface.examples)]
examples = [
[os.path.join(interface.examples, item)]
for item in os.listdir(interface.examples)
]
else:
raise FileNotFoundError(
"Could not find log file (required for multiple inputs): " + log_file)
"Could not find log file (required for multiple inputs): "
+ log_file
)
else:
with open(log_file) as logs:
examples = list(csv.reader(logs))
examples = examples[1:] # remove header
for i, example in enumerate(examples):
for j, (component, cell) in enumerate(zip(interface.input_components + interface.output_components, example)):
for j, (component, cell) in enumerate(
zip(
interface.input_components + interface.output_components,
example,
)
):
examples[i][j] = component.restore_flagged(
interface.flagging_dir, cell, interface.encryption_key if interface.encrypt else None)
interface.flagging_dir,
cell,
interface.encryption_key if interface.encrypt else None,
)
config["examples"] = examples
config["examples_dir"] = interface.examples
else:
@ -249,8 +282,6 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
def get_default_args(func: Callable) -> Dict[str, Any]:
signature = inspect.signature(func)
return [
v.default
if v.default is not inspect.Parameter.empty
else None
v.default if v.default is not inspect.Parameter.empty else None
for v in signature.parameters.values()
]
]
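For example (hypothetical function), and note the helper returns a list positionally despite its Dict[str, Any] annotation:

def example_fn(image, threshold=0.5, invert=False):
    ...

get_default_args(example_fn)  # -> [None, 0.5, False]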

View File

@ -1,6 +1,7 @@
import re
from os.path import join, exists, getmtime
from jinja2 import Environment, BaseLoader, TemplateNotFound
from os.path import exists, getmtime, join
from jinja2 import BaseLoader, Environment, TemplateNotFound
README_TEMPLATE = "readme_template.md"
GETTING_STARTED_TEMPLATE = "getting_started.md"
@ -15,11 +16,16 @@ code, demos = {}, {}
for code_src in code_tags:
with open(join("demo", code_src, "run.py")) as code_file:
python_code = code_file.read()
python_code = python_code.replace('if __name__ == "__main__":\n iface.launch()', "iface.launch()")
python_code = python_code.replace(
'if __name__ == "__main__":\n iface.launch()', "iface.launch()"
)
code[code_src] = "```python\n" + python_code + "\n```"
for demo_src in demo_tags:
demos[demo_src] = "![" + demo_src + " interface](demo/" + demo_src + "/screenshot.gif)"
demos[demo_src] = (
"![" + demo_src + " interface](demo/" + demo_src + "/screenshot.gif)"
)
class GuidesLoader(BaseLoader):
def __init__(self, path):
@ -34,8 +40,11 @@ class GuidesLoader(BaseLoader):
source = f.read()
return source, path, lambda: mtime == getmtime(path)
readme_template = Environment(loader=GuidesLoader("guides")).get_template(README_TEMPLATE)
readme_template = Environment(loader=GuidesLoader("guides")).get_template(
README_TEMPLATE
)
output_readme = readme_template.render(code=code, demos=demos)
with open("README.md", "w") as readme_md:
readme_md.write(output_readme)
readme_md.write(output_readme)

View File

@ -4,31 +4,31 @@ except ImportError:
from distutils.core import setup
setup(
name='gradio',
version='2.7.1',
name="gradio",
version="2.7.1",
include_package_data=True,
description='Python library for easily interacting with trained machine learning models',
author='Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq',
author_email='team@gradio.app',
url='https://github.com/gradio-app/gradio-UI',
packages=['gradio'],
license='Apache License 2.0',
keywords=['machine learning', 'visualization', 'reproducibility'],
description="Python library for easily interacting with trained machine learning models",
author="Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq",
author_email="team@gradio.app",
url="https://github.com/gradio-app/gradio-UI",
packages=["gradio"],
license="Apache License 2.0",
keywords=["machine learning", "visualization", "reproducibility"],
install_requires=[
'analytics-python',
'aiohttp',
'fastapi',
'ffmpy',
'markdown2',
'matplotlib',
'numpy',
'pandas',
'paramiko',
'pillow',
'pycryptodome',
'python-multipart',
'pydub',
'requests',
'uvicorn',
"analytics-python",
"aiohttp",
"fastapi",
"ffmpy",
"markdown2",
"matplotlib",
"numpy",
"pandas",
"paramiko",
"pillow",
"pycryptodome",
"python-multipart",
"pydub",
"requests",
"uvicorn",
],
)

View File

@ -1,14 +1,15 @@
import unittest
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import multiprocessing
import os
import random
import time
import unittest
import requests
from matplotlib.testing.compare import compare_images
import random
import os
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
current_dir = os.getcwd()
@ -36,12 +37,14 @@ def wait_for_url(url):
def diff_texts_thread(return_dict):
from demo.diff_texts.run import iface
iface.save_to = return_dict
iface.launch()
def image_mod_thread(return_dict):
from demo.image_mod.run import iface
iface.examples = None
iface.save_to = return_dict
iface.launch()
@ -49,12 +52,14 @@ def image_mod_thread(return_dict):
def longest_word_thread(return_dict):
from demo.longest_word.run import iface
iface.save_to = return_dict
iface.launch()
def sentence_builder_thread(return_dict):
from demo.sentence_builder.run import iface
iface.save_to = return_dict
iface.launch()
@ -63,8 +68,7 @@ class TestDemo(unittest.TestCase):
def start_test(self, target):
manager = multiprocessing.Manager()
return_dict = manager.dict()
self.i_thread = multiprocessing.Process(target=target,
args=(return_dict,))
self.i_thread = multiprocessing.Process(target=target, args=(return_dict,))
self.i_thread.start()
total_sleep = 0
while not return_dict and total_sleep < TIMEOUT:
@ -81,25 +85,36 @@ class TestDemo(unittest.TestCase):
def test_diff_texts(self):
driver = self.start_test(target=diff_texts_thread)
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR,
".panel:nth-child(1) .component:nth-child(1) .input_text textarea"))
EC.presence_of_element_located(
(
By.CSS_SELECTOR,
".panel:nth-child(1) .component:nth-child(1) .input_text textarea",
)
)
)
elem.clear()
elem.send_keys("Want to see a magic trick?")
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR,
".panel:nth-child(1) .component:nth-child(2) .input_text textarea"))
EC.presence_of_element_located(
(
By.CSS_SELECTOR,
".panel:nth-child(1) .component:nth-child(2) .input_text textarea",
)
)
)
elem.clear()
elem.send_keys("Let's go see a magic trick!")
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR,
".submit"))
EC.presence_of_element_located((By.CSS_SELECTOR, ".submit"))
)
elem.click()
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR,
".panel:nth-child(2) .component:nth-child(2) .textfield"))
EC.presence_of_element_located(
(
By.CSS_SELECTOR,
".panel:nth-child(2) .component:nth-child(2) .textfield",
)
)
)
total_sleep = 0
@ -108,10 +123,12 @@ class TestDemo(unittest.TestCase):
total_sleep += 0.2
self.assertEqual(elem.text, "L+e+W-a-n-t'+s+ t-g+o see a magic trick?-!+")
golden_img = os.path.join(current_dir, GOLDEN_PATH.format(
"diff_texts", "magic_trick"))
tmp = os.path.join(current_dir, "test/tmp/{}.png".format(
random.getrandbits(32)))
golden_img = os.path.join(
current_dir, GOLDEN_PATH.format("diff_texts", "magic_trick")
)
tmp = os.path.join(
current_dir, "test/tmp/{}.png".format(random.getrandbits(32))
)
time.sleep(GAP_TO_SCREENSHOT)
driver.save_screenshot(tmp)
driver.close()
@ -122,23 +139,32 @@ class TestDemo(unittest.TestCase):
driver = self.start_test(target=image_mod_thread)
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located(
(By.CSS_SELECTOR, ".panel:nth-child(1) .component:nth-child(1) .hidden_upload"))
(
By.CSS_SELECTOR,
".panel:nth-child(1) .component:nth-child(1) .hidden_upload",
)
)
)
cwd = os.getcwd()
rel = "test/test_files/cheetah1.jpg"
elem.send_keys(os.path.join(cwd, rel))
golden_img = os.path.join(current_dir, GOLDEN_PATH.format(
"image_mod", "cheetah1"))
tmp = os.path.join(current_dir, "test/tmp/{}.png".format(
random.getrandbits(32)))
golden_img = os.path.join(
current_dir, GOLDEN_PATH.format("image_mod", "cheetah1")
)
tmp = os.path.join(
current_dir, "test/tmp/{}.png".format(random.getrandbits(32))
)
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR,
".submit"))
EC.presence_of_element_located((By.CSS_SELECTOR, ".submit"))
)
elem.click()
WebDriverWait(driver, TIMEOUT).until(
EC.visibility_of_element_located(
(By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_image"))
(
By.CSS_SELECTOR,
".panel:nth-child(2) .component:nth-child(2) .output_image",
)
)
)
time.sleep(GAP_TO_SCREENSHOT)
@ -150,18 +176,25 @@ class TestDemo(unittest.TestCase):
def test_longest_word(self):
driver = self.start_test(target=longest_word_thread)
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR,
".panel:nth-child(1) .component:nth-child(1) .input_text textarea"))
EC.presence_of_element_located(
(
By.CSS_SELECTOR,
".panel:nth-child(1) .component:nth-child(1) .input_text textarea",
)
)
)
elem.send_keys("This is the most wonderful machine learning "
"library.")
elem.send_keys("This is the most wonderful machine learning " "library.")
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR,
".submit"))
EC.presence_of_element_located((By.CSS_SELECTOR, ".submit"))
)
elem.click()
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_class_without_confidences"))
EC.presence_of_element_located(
(
By.CSS_SELECTOR,
".panel:nth-child(2) .component:nth-child(2) .output_class_without_confidences",
)
)
)
total_sleep = 0
@ -169,10 +202,12 @@ class TestDemo(unittest.TestCase):
time.sleep(0.2)
total_sleep += 0.2
golden_img = os.path.join(current_dir, GOLDEN_PATH.format(
"longest_word", "wonderful"))
tmp = os.path.join(current_dir, "test/tmp/{}.png".format(
random.getrandbits(32)))
golden_img = os.path.join(
current_dir, GOLDEN_PATH.format("longest_word", "wonderful")
)
tmp = os.path.join(
current_dir, "test/tmp/{}.png".format(random.getrandbits(32))
)
time.sleep(GAP_TO_SCREENSHOT)
driver.save_screenshot(tmp)
driver.close()
@ -182,12 +217,16 @@ class TestDemo(unittest.TestCase):
def test_sentence_builder(self):
driver = self.start_test(target=sentence_builder_thread)
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR,
".submit"))
EC.presence_of_element_located((By.CSS_SELECTOR, ".submit"))
)
elem.click()
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_text"))
EC.presence_of_element_located(
(
By.CSS_SELECTOR,
".panel:nth-child(2) .component:nth-child(2) .output_text",
)
)
)
total_sleep = 0
@ -196,11 +235,14 @@ class TestDemo(unittest.TestCase):
total_sleep += 0.2
self.assertEqual(
elem.text, "The 2 cats went to the park where they until the night")
golden_img = os.path.join(current_dir, GOLDEN_PATH.format(
"sentence_builder", "two_cats"))
tmp = os.path.join(current_dir, "test/tmp/{}.png".format(
random.getrandbits(32)))
elem.text, "The 2 cats went to the park where they until the night"
)
golden_img = os.path.join(
current_dir, GOLDEN_PATH.format("sentence_builder", "two_cats")
)
tmp = os.path.join(
current_dir, "test/tmp/{}.png".format(random.getrandbits(32))
)
time.sleep(GAP_TO_SCREENSHOT)
driver.save_screenshot(tmp)
self.assertIsNone(compare_images(tmp, golden_img, TOLERANCE))
@ -212,5 +254,5 @@ class TestDemo(unittest.TestCase):
self.i_thread.join()
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -1,8 +1,9 @@
import os
import pathlib
import transformers
import unittest
import transformers
import gradio as gr
"""
@ -18,7 +19,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_question_answering(self):
model_type = "question-answering"
interface_info = gr.external.get_huggingface_interface(
"deepset/roberta-base-squad2", api_key=None, alias=model_type)
"deepset/roberta-base-squad2", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
@ -28,7 +30,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_text_generation(self):
model_type = "text_generation"
interface_info = gr.external.get_huggingface_interface(
"gpt2", api_key=None, alias=model_type)
"gpt2", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
@ -36,7 +39,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_summarization(self):
model_type = "summarization"
interface_info = gr.external.get_huggingface_interface(
"facebook/bart-large-cnn", api_key=None, alias=model_type)
"facebook/bart-large-cnn", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
@ -44,7 +48,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_translation(self):
model_type = "translation"
interface_info = gr.external.get_huggingface_interface(
"facebook/bart-large-cnn", api_key=None, alias=model_type)
"facebook/bart-large-cnn", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
@ -52,7 +57,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_text2text_generation(self):
model_type = "text2text-generation"
interface_info = gr.external.get_huggingface_interface(
"sshleifer/tiny-mbart", api_key=None, alias=model_type)
"sshleifer/tiny-mbart", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
@ -61,7 +67,9 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
model_type = "text-classification"
interface_info = gr.external.get_huggingface_interface(
"distilbert-base-uncased-finetuned-sst-2-english",
api_key=None, alias=model_type)
api_key=None,
alias=model_type,
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
@ -69,8 +77,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_fill_mask(self):
model_type = "fill-mask"
interface_info = gr.external.get_huggingface_interface(
"bert-base-uncased",
api_key=None, alias=model_type)
"bert-base-uncased", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
@ -78,8 +86,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_zero_shot_classification(self):
model_type = "zero-shot-classification"
interface_info = gr.external.get_huggingface_interface(
"facebook/bart-large-mnli",
api_key=None, alias=model_type)
"facebook/bart-large-mnli", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
@ -89,8 +97,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_automatic_speech_recognition(self):
model_type = "automatic-speech-recognition"
interface_info = gr.external.get_huggingface_interface(
"facebook/wav2vec2-base-960h",
api_key=None, alias=model_type)
"facebook/wav2vec2-base-960h", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
@ -98,8 +106,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_image_classification(self):
model_type = "image-classification"
interface_info = gr.external.get_huggingface_interface(
"google/vit-base-patch16-224",
api_key=None, alias=model_type)
"google/vit-base-patch16-224", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Image)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
@ -108,7 +116,9 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
model_type = "feature-extraction"
interface_info = gr.external.get_huggingface_interface(
"sentence-transformers/distilbert-base-nli-mean-tokens",
api_key=None, alias=model_type)
api_key=None,
alias=model_type,
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Dataframe)
@ -117,7 +127,9 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
model_type = "text-to-speech"
interface_info = gr.external.get_huggingface_interface(
"julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
api_key=None, alias=model_type)
api_key=None,
alias=model_type,
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)
@ -126,7 +138,9 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
model_type = "text-to-speech"
interface_info = gr.external.get_huggingface_interface(
"julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
api_key=None, alias=model_type)
api_key=None,
alias=model_type,
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)
@ -134,60 +148,74 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
def test_text_to_image(self):
model_type = "text-to-image"
interface_info = gr.external.get_huggingface_interface(
"osanseviero/BigGAN-deep-128",
api_key=None, alias=model_type)
"osanseviero/BigGAN-deep-128", api_key=None, alias=model_type
)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)
def test_english_to_spanish(self):
interface_info = gr.external.get_spaces_interface("abidlabs/english_to_spanish", api_key=None, alias=None)
interface_info = gr.external.get_spaces_interface(
"abidlabs/english_to_spanish", api_key=None, alias=None
)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
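
# For context, a minimal sketch of the helper exercised by the class above
# (gradio 2.x, as targeted by this commit). The returned dict appears to hold
# Interface keyword arguments ("fn", "inputs", "outputs"), so it can be
# splatted into gr.Interface; the model name below is one the tests use.
import gradio as gr

interface_info = gr.external.get_huggingface_interface(
    "distilbert-base-uncased-finetuned-sst-2-english",
    api_key=None,
    alias="text-classification",
)
demo = gr.Interface(**interface_info)  # assumes the keys match Interface's signature
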
class TestLoadInterface(unittest.TestCase):
def test_english_to_spanish(self):
interface_info = gr.external.load_interface("spaces/abidlabs/english_to_spanish")
interface_info = gr.external.load_interface(
"spaces/abidlabs/english_to_spanish"
)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
def test_sentiment_model(self):
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
interface_info = gr.external.load_interface(
"models/distilbert-base-uncased-finetuned-sst-2-english",
alias="sentiment_classifier",
)
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("I am happy, I love you.")
self.assertGreater(output['POSITIVE'], 0.5)
self.assertGreater(output["POSITIVE"], 0.5)
def test_image_classification_model(self):
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
interface_info = gr.external.load_interface(
"models/google/vit-base-patch16-224"
)
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("test/test_data/lion.jpg")
self.assertGreater(output['lion'], 0.5)
self.assertGreater(output["lion"], 0.5)
def test_translation_model(self):
interface_info = gr.external.load_interface("models/t5-base")
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("My name is Sarah and I live in London")
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
self.assertEquals(output, "Mein Name ist Sarah und ich lebe in London")
def test_numerical_to_label_space(self):
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("male", 77, 10)
self.assertLess(output['Survives'], 0.5)
self.assertLess(output["Survives"], 0.5)
def test_speech_recognition_model(self):
interface_info = gr.external.load_interface("models/jonatasgrosman/wav2vec2-large-xlsr-53-english")
interface_info = gr.external.load_interface(
"models/jonatasgrosman/wav2vec2-large-xlsr-53-english"
)
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("test/test_data/test_audio.wav")
self.assertIsNotNone(output)
def test_text_to_image_model(self):
interface_info = gr.external.load_interface("models/osanseviero/BigGAN-deep-128")
interface_info = gr.external.load_interface(
"models/osanseviero/BigGAN-deep-128"
)
io = gr.Interface(**interface_info)
io.api_mode = True
filename = io("chest")
@ -207,11 +235,14 @@ class TestLoadInterface(unittest.TestCase):
class TestLoadFromPipeline(unittest.TestCase):
def test_question_answering(self):
p = transformers.pipeline("question-answering")
p = transformers.pipeline("question-answering")
io = gr.Interface.from_pipeline(p)
output = io("My name is Sylvain and I work at Hugging Face in Brooklyn", "Where do I work?")
output = io(
"My name is Sylvain and I work at Hugging Face in Brooklyn",
"Where do I work?",
)
self.assertIsNotNone(output)
if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()
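
# A minimal usage sketch of the loading pattern the tests above rely on:
# load_interface fetches a remote model or Space, the resulting kwargs are
# splatted into an Interface, and api_mode makes calls return raw predictions
# instead of serving a UI (behavior inferred from the tests).
import gradio as gr

interface_info = gr.external.load_interface(
    "models/distilbert-base-uncased-finetuned-sst-2-english",
    alias="sentiment_classifier",
)
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("I am happy, I love you.")  # e.g. {"POSITIVE": 0.98, ...}
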


@ -1,9 +1,10 @@
import gradio as gr
from gradio import flagging
import tempfile
import unittest
import unittest.mock as mock
import gradio as gr
from gradio import flagging
class TestFlagging(unittest.TestCase):
def test_default_flagging_handler(self):
@ -18,7 +19,13 @@ class TestFlagging(unittest.TestCase):
def test_simple_csv_flagging_handler(self):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname, flagging_callback=flagging.SimpleCSVLogger())
io = gr.Interface(
lambda x: x,
"text",
"text",
flagging_dir=tmpdirname,
flagging_callback=flagging.SimpleCSVLogger(),
)
io.launch(prevent_thread_lock=True)
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
self.assertEqual(row_count, 0) # no header
@ -27,5 +34,5 @@ class TestFlagging(unittest.TestCase):
io.close()
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
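
# A short sketch of the flagging setup tested above: SimpleCSVLogger writes
# flagged samples into flagging_dir, and flag() returns the row count of the
# CSV so far (0 for the first row, since SimpleCSVLogger writes no header).
import tempfile

import gradio as gr
from gradio import flagging

with tempfile.TemporaryDirectory() as tmpdirname:
    io = gr.Interface(
        lambda x: x,
        "text",
        "text",
        flagging_dir=tmpdirname,
        flagging_callback=flagging.SimpleCSVLogger(),
    )
    io.launch(prevent_thread_lock=True)
    row_count = io.flagging_callback.flag(io, ["test"], ["test"])  # -> 0
    io.close()
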


@ -1,14 +1,15 @@
from re import sub
import unittest
import gradio as gr
import PIL
import numpy as np
import pandas
from pydub import AudioSegment
import json
import os
import tempfile
import json
import unittest
from re import sub
import numpy as np
import pandas
import PIL
from pydub import AudioSegment
import gradio as gr
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -32,7 +33,9 @@ class TestTextbox(unittest.TestCase):
self.assertEqual(text_input.preprocess_example("Hello World!"), "Hello World!")
self.assertEqual(text_input.serialize("Hello World!", True), "Hello World!")
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = text_input.save_flagged(tmpdirname, "text_input", "Hello World!", None)
to_save = text_input.save_flagged(
tmpdirname, "text_input", "Hello World!", None
)
self.assertEqual(to_save, "Hello World!")
restored = text_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, "Hello World!")
@ -44,29 +47,79 @@ class TestTextbox(unittest.TestCase):
wrong_type = gr.inputs.Textbox(type="unknown")
wrong_type.preprocess(0)
self.assertEqual(text_input.tokenize("Hello World! Gradio speaking."), (
['Hello', 'World!', 'Gradio', 'speaking.'],
['World! Gradio speaking.', 'Hello Gradio speaking.', 'Hello World! speaking.', 'Hello World! Gradio'],
None))
self.assertEqual(
text_input.tokenize("Hello World! Gradio speaking."),
(
["Hello", "World!", "Gradio", "speaking."],
[
"World! Gradio speaking.",
"Hello Gradio speaking.",
"Hello World! speaking.",
"Hello World! Gradio",
],
None,
),
)
text_input.interpretation_replacement = "unknown"
self.assertEqual(text_input.tokenize("Hello World! Gradio speaking."), (
['Hello', 'World!', 'Gradio', 'speaking.'],
['unknown World! Gradio speaking.', 'Hello unknown Gradio speaking.', 'Hello World! unknown speaking.',
'Hello World! Gradio unknown'], None))
self.assertEqual(
text_input.tokenize("Hello World! Gradio speaking."),
(
["Hello", "World!", "Gradio", "speaking."],
[
"unknown World! Gradio speaking.",
"Hello unknown Gradio speaking.",
"Hello World! unknown speaking.",
"Hello World! Gradio unknown",
],
None,
),
)
self.assertIsInstance(text_input.generate_sample(), str)
def test_in_interface(self):
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
iface = gr.Interface(lambda sentence: max([len(word) for word in sentence.split()]), gr.inputs.Textbox(),
gr.outputs.Textbox(), interpretation="default")
scores, alternative_outputs = iface.interpret(["Return the length of the longest word in this sentence"])
self.assertEqual(scores, [[('Return', 0.0), (' ', 0), ('the', 0.0), (' ', 0), ('length', 0.0), (' ', 0),
('of', 0.0), (' ', 0), ('the', 0.0), (' ', 0), ('longest', 0.0), (' ', 0),
('word', 0.0), (' ', 0), ('in', 0.0), (' ', 0), ('this', 0.0), (' ', 0),
('sentence', 1.0), (' ', 0)]])
self.assertEqual(alternative_outputs, [[['8'], ['8'], ['8'], ['8'], ['8'], ['8'], ['8'], ['8'], ['8'], ['7']]])
iface = gr.Interface(
lambda sentence: max([len(word) for word in sentence.split()]),
gr.inputs.Textbox(),
gr.outputs.Textbox(),
interpretation="default",
)
scores, alternative_outputs = iface.interpret(
["Return the length of the longest word in this sentence"]
)
self.assertEqual(
scores,
[
[
("Return", 0.0),
(" ", 0),
("the", 0.0),
(" ", 0),
("length", 0.0),
(" ", 0),
("of", 0.0),
(" ", 0),
("the", 0.0),
(" ", 0),
("longest", 0.0),
(" ", 0),
("word", 0.0),
(" ", 0),
("in", 0.0),
(" ", 0),
("this", 0.0),
(" ", 0),
("sentence", 1.0),
(" ", 0),
]
],
)
self.assertEqual(
alternative_outputs,
[[["8"], ["8"], ["8"], ["8"], ["8"], ["8"], ["8"], ["8"], ["8"], ["7"]]],
)
class TestNumber(unittest.TestCase):
@ -82,20 +135,50 @@ class TestNumber(unittest.TestCase):
self.assertEqual(restored, 3)
self.assertIsInstance(numeric_input.generate_sample(), float)
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}))
self.assertEqual(
numeric_input.get_interpretation_neighbors(1),
([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}),
)
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent")
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}))
self.assertEqual(
numeric_input.get_interpretation_neighbors(1),
([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}),
)
def test_in_interface(self):
iface = gr.Interface(lambda x: x**2, "number", "textbox")
self.assertEqual(iface.process([2])[0], ['4.0'])
iface = gr.Interface(lambda x: x**2, "number", "textbox", interpretation="default")
iface = gr.Interface(lambda x: x ** 2, "number", "textbox")
self.assertEqual(iface.process([2])[0], ["4.0"])
iface = gr.Interface(
lambda x: x ** 2, "number", "textbox", interpretation="default"
)
scores, alternative_outputs = iface.interpret([2])
self.assertEqual(scores, [[(1.94, -0.23640000000000017), (1.96, -0.15840000000000032),
(1.98, -0.07960000000000012), [2, None], (2.02, 0.08040000000000003),
(2.04, 0.16159999999999997), (2.06, 0.24359999999999982)]])
self.assertEqual(alternative_outputs, [[['3.7636'], ['3.8415999999999997'], ['3.9204'], ['4.0804'], ['4.1616'],
['4.2436']]])
self.assertEqual(
scores,
[
[
(1.94, -0.23640000000000017),
(1.96, -0.15840000000000032),
(1.98, -0.07960000000000012),
[2, None],
(2.02, 0.08040000000000003),
(2.04, 0.16159999999999997),
(2.06, 0.24359999999999982),
]
],
)
self.assertEqual(
alternative_outputs,
[
[
["3.7636"],
["3.8415999999999997"],
["3.9204"],
["4.0804"],
["4.1616"],
["4.2436"],
]
],
)
class TestSlider(unittest.TestCase):
@ -111,26 +194,58 @@ class TestSlider(unittest.TestCase):
self.assertEqual(restored, 3)
self.assertIsInstance(slider_input.generate_sample(), int)
slider_input = gr.inputs.Slider(minimum=10, maximum=20, step=1, default=15, label="Slide Your Input")
self.assertEqual(slider_input.get_template_context(), {
'minimum': 10,
'maximum': 20,
'step': 1,
'default': 15,
'name': 'slider',
'label': 'Slide Your Input'
})
slider_input = gr.inputs.Slider(
minimum=10, maximum=20, step=1, default=15, label="Slide Your Input"
)
self.assertEqual(
slider_input.get_template_context(),
{
"minimum": 10,
"maximum": 20,
"step": 1,
"default": 15,
"name": "slider",
"label": "Slide Your Input",
},
)
def test_in_interface(self):
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
self.assertEqual(iface.process([2])[0], ['4'])
iface = gr.Interface(lambda x: x**2, "slider", "textbox", interpretation="default")
iface = gr.Interface(lambda x: x ** 2, "slider", "textbox")
self.assertEqual(iface.process([2])[0], ["4"])
iface = gr.Interface(
lambda x: x ** 2, "slider", "textbox", interpretation="default"
)
scores, alternative_outputs = iface.interpret([2])
self.assertEqual(scores, [[-4.0, 200.08163265306123, 812.3265306122449, 1832.7346938775513, 3261.3061224489797,
5098.040816326531, 7342.938775510205, 9996.0]])
self.assertEqual(alternative_outputs, [[['0.0'], ['204.08163265306123'], ['816.3265306122449'],
['1836.7346938775513'], ['3265.3061224489797'], ['5102.040816326531'],
['7346.938775510205'], ['10000.0']]])
self.assertEqual(
scores,
[
[
-4.0,
200.08163265306123,
812.3265306122449,
1832.7346938775513,
3261.3061224489797,
5098.040816326531,
7342.938775510205,
9996.0,
]
],
)
self.assertEqual(
alternative_outputs,
[
[
["0.0"],
["204.08163265306123"],
["816.3265306122449"],
["1836.7346938775513"],
["3265.3061224489797"],
["5102.040816326531"],
["7346.938775510205"],
["10000.0"],
]
],
)
class TestCheckbox(unittest.TestCase):
@ -146,22 +261,23 @@ class TestCheckbox(unittest.TestCase):
self.assertEqual(restored, True)
self.assertIsInstance(bool_input.generate_sample(), bool)
bool_input = gr.inputs.Checkbox(default=True, label="Check Your Input")
self.assertEqual(bool_input.get_template_context(), {
'default': True,
'name': 'checkbox',
'label': 'Check Your Input'
})
self.assertEqual(
bool_input.get_template_context(),
{"default": True, "name": "checkbox", "label": "Check Your Input"},
)
def test_in_interface(self):
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "textbox")
self.assertEqual(iface.process([True])[0], ['1'])
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "textbox", interpretation="default")
self.assertEqual(iface.process([True])[0], ["1"])
iface = gr.Interface(
lambda x: 1 if x else 0, "checkbox", "textbox", interpretation="default"
)
scores, alternative_outputs = iface.interpret([False])
self.assertEqual(scores, [(None, 1.0)])
self.assertEqual(alternative_outputs, [[['1']]])
self.assertEqual(alternative_outputs, [[["1"]]])
scores, alternative_outputs = iface.interpret([True])
self.assertEqual(scores, [(-1.0, None)])
self.assertEqual(alternative_outputs, [[['0']]])
self.assertEqual(alternative_outputs, [[["0"]]])
class TestCheckboxGroup(unittest.TestCase):
@ -171,19 +287,25 @@ class TestCheckboxGroup(unittest.TestCase):
self.assertEqual(checkboxes_input.preprocess_example(["a", "c"]), ["a", "c"])
self.assertEqual(checkboxes_input.serialize(["a", "c"], True), ["a", "c"])
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = checkboxes_input.save_flagged(tmpdirname, "checkboxes_input", ["a", "c"], None)
to_save = checkboxes_input.save_flagged(
tmpdirname, "checkboxes_input", ["a", "c"], None
)
self.assertEqual(to_save, '["a", "c"]')
restored = checkboxes_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, ["a", "c"])
self.assertIsInstance(checkboxes_input.generate_sample(), list)
checkboxes_input = gr.inputs.CheckboxGroup(choices=["a", "b", "c"], default=["a", "c"],
label="Check Your Inputs")
self.assertEqual(checkboxes_input.get_template_context(), {
'choices': ['a', 'b', 'c'],
'default': ['a', 'c'],
'name': 'checkboxgroup',
'label': 'Check Your Inputs'
})
checkboxes_input = gr.inputs.CheckboxGroup(
choices=["a", "b", "c"], default=["a", "c"], label="Check Your Inputs"
)
self.assertEqual(
checkboxes_input.get_template_context(),
{
"choices": ["a", "b", "c"],
"default": ["a", "c"],
"name": "checkboxgroup",
"label": "Check Your Inputs",
},
)
with self.assertRaises(ValueError):
wrong_type = gr.inputs.CheckboxGroup(["a"], type="unknown")
wrong_type.preprocess(0)
@ -194,11 +316,16 @@ class TestCheckboxGroup(unittest.TestCase):
self.assertEqual(iface.process([["a", "c"]])[0], ["a|c"])
self.assertEqual(iface.process([[]])[0], [""])
checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index")
iface = gr.Interface(lambda x: "|".join(map(str, x)), checkboxes_input, "textbox", interpretation="default")
iface = gr.Interface(
lambda x: "|".join(map(str, x)),
checkboxes_input,
"textbox",
interpretation="default",
)
self.assertEqual(iface.process([["a", "c"]])[0], ["0|2"])
scores, alternative_outputs = iface.interpret([["a", "c"]])
self.assertEqual(scores, [[[-1, None], [None, -1], [-1, None]]])
self.assertEqual(alternative_outputs, [[['2'], ['0|2|1'], ['0']]])
self.assertEqual(alternative_outputs, [[["2"], ["0|2|1"], ["0"]]])
class TestRadio(unittest.TestCase):
@ -209,20 +336,24 @@ class TestRadio(unittest.TestCase):
self.assertEqual(radio_input.serialize("a", True), "a")
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = radio_input.save_flagged(tmpdirname, "radio_input", "a", None)
self.assertEqual(to_save, 'a')
self.assertEqual(to_save, "a")
restored = radio_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, "a")
self.assertIsInstance(radio_input.generate_sample(), str)
radio_input = gr.inputs.Radio(choices=["a", "b", "c"], default="a",
label="Pick Your One Input")
self.assertEqual(radio_input.get_template_context(), {
'choices': ['a', 'b', 'c'],
'default': 'a',
'name': 'radio',
'label': 'Pick Your One Input'
})
radio_input = gr.inputs.Radio(
choices=["a", "b", "c"], default="a", label="Pick Your One Input"
)
self.assertEqual(
radio_input.get_template_context(),
{
"choices": ["a", "b", "c"],
"default": "a",
"name": "radio",
"label": "Pick Your One Input",
},
)
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Radio(["a","b"], type="unknown")
wrong_type = gr.inputs.Radio(["a", "b"], type="unknown")
wrong_type.preprocess(0)
def test_in_interface(self):
@ -230,7 +361,9 @@ class TestRadio(unittest.TestCase):
iface = gr.Interface(lambda x: 2 * x, radio_input, "textbox")
self.assertEqual(iface.process(["c"])[0], ["cc"])
radio_input = gr.inputs.Radio(["a", "b", "c"], type="index")
iface = gr.Interface(lambda x: 2 * x, radio_input, "number", interpretation="default")
iface = gr.Interface(
lambda x: 2 * x, radio_input, "number", interpretation="default"
)
self.assertEqual(iface.process(["c"])[0], [4])
scores, alternative_outputs = iface.interpret(["b"])
self.assertEqual(scores, [[-2.0, None, 2.0]])
@ -244,19 +377,25 @@ class TestDropdown(unittest.TestCase):
self.assertEqual(dropdown_input.preprocess_example("a"), "a")
self.assertEqual(dropdown_input.serialize("a", True), "a")
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = dropdown_input.save_flagged(tmpdirname, "dropdown_input", "a", None)
self.assertEqual(to_save, 'a')
to_save = dropdown_input.save_flagged(
tmpdirname, "dropdown_input", "a", None
)
self.assertEqual(to_save, "a")
restored = dropdown_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, "a")
self.assertIsInstance(dropdown_input.generate_sample(), str)
dropdown_input = gr.inputs.Dropdown(choices=["a", "b", "c"], default="a",
label="Drop Your Input")
self.assertEqual(dropdown_input.get_template_context(), {
'choices': ['a', 'b', 'c'],
'default': 'a',
'name': 'dropdown',
'label': 'Drop Your Input'
})
dropdown_input = gr.inputs.Dropdown(
choices=["a", "b", "c"], default="a", label="Drop Your Input"
)
self.assertEqual(
dropdown_input.get_template_context(),
{
"choices": ["a", "b", "c"],
"default": "a",
"name": "dropdown",
"label": "Drop Your Input",
},
)
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Dropdown(["a"], type="unknown")
wrong_type.preprocess(0)
@ -266,7 +405,9 @@ class TestDropdown(unittest.TestCase):
iface = gr.Interface(lambda x: 2 * x, dropdown_input, "textbox")
self.assertEqual(iface.process(["c"])[0], ["cc"])
dropdown = gr.inputs.Dropdown(["a", "b", "c"], type="index")
iface = gr.Interface(lambda x: 2 * x, dropdown, "number", interpretation="default")
iface = gr.Interface(
lambda x: 2 * x, dropdown, "number", interpretation="default"
)
self.assertEqual(iface.process(["c"])[0], [4])
scores, alternative_outputs = iface.interpret(["b"])
self.assertEqual(scores, [[-2.0, None, 2.0]])
@ -293,16 +434,21 @@ class TestImage(unittest.TestCase):
self.assertEqual(restored, "image_input/1.png")
self.assertIsInstance(image_input.generate_sample(), str)
image_input = gr.inputs.Image(source="upload", tool="editor", type="pil", label="Upload Your Image")
self.assertEqual(image_input.get_template_context(), {
'image_mode': 'RGB',
'shape': None,
'source': 'upload',
'tool': 'editor',
'optional': False,
'name': 'image',
'label': 'Upload Your Image'
})
image_input = gr.inputs.Image(
source="upload", tool="editor", type="pil", label="Upload Your Image"
)
self.assertEqual(
image_input.get_template_context(),
{
"image_mode": "RGB",
"shape": None,
"source": "upload",
"tool": "editor",
"optional": False,
"name": "image",
"label": "Upload Your Image",
},
)
self.assertIsNone(image_input.preprocess(None))
image_input = gr.inputs.Image(invert_colors=True)
self.assertIsNotNone(image_input.preprocess(img))
@ -318,7 +464,7 @@ class TestImage(unittest.TestCase):
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Image(type="unknown")
wrong_type.serialize("test/test_files/bus.png", False)
img_pil = PIL.Image.open('test/test_files/bus.png')
img_pil = PIL.Image.open("test/test_files/bus.png")
image_input = gr.inputs.Image(type="numpy")
self.assertIsInstance(image_input.serialize(img_pil, False), str)
image_input = gr.inputs.Image(type="pil")
@ -332,21 +478,40 @@ class TestImage(unittest.TestCase):
def test_in_interface(self):
img = gr.test_data.BASE64_IMAGE
image_input = gr.inputs.Image()
iface = gr.Interface(lambda x: PIL.Image.open(x).rotate(90, expand=True),
gr.inputs.Image(shape=(30, 10), type="file"), "image")
iface = gr.Interface(
lambda x: PIL.Image.open(x).rotate(90, expand=True),
gr.inputs.Image(shape=(30, 10), type="file"),
"image",
)
output = iface.process([img])[0][0]
self.assertEqual(gr.processing_utils.decode_base64_to_image(output).size, (10, 30))
iface = gr.Interface(lambda x: np.sum(x), image_input, "textbox", interpretation="default")
self.assertEqual(
gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
)
iface = gr.Interface(
lambda x: np.sum(x), image_input, "textbox", interpretation="default"
)
scores, alternative_outputs = iface.interpret([img])
self.assertEqual(scores, gr.test_data.SUM_PIXELS_INTERPRETATION["scores"])
self.assertEqual(alternative_outputs, gr.test_data.SUM_PIXELS_INTERPRETATION["alternative_outputs"])
iface = gr.Interface(lambda x: np.sum(x), image_input, "label", interpretation="shap")
self.assertEqual(
alternative_outputs,
gr.test_data.SUM_PIXELS_INTERPRETATION["alternative_outputs"],
)
iface = gr.Interface(
lambda x: np.sum(x), image_input, "label", interpretation="shap"
)
scores, alternative_outputs = iface.interpret([img])
self.assertEqual(len(scores[0]), len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0]))
self.assertEqual(len(alternative_outputs[0]),
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["alternative_outputs"][0]))
self.assertEqual(
len(scores[0]),
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0]),
)
self.assertEqual(
len(alternative_outputs[0]),
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["alternative_outputs"][0]),
)
image_input = gr.inputs.Image(shape=(30, 10))
iface = gr.Interface(lambda x: np.sum(x), image_input, "textbox", interpretation="default")
iface = gr.Interface(
lambda x: np.sum(x), image_input, "textbox", interpretation="default"
)
self.assertIsNotNone(iface.interpret([img]))
@ -357,8 +522,11 @@ class TestAudio(unittest.TestCase):
output = audio_input.preprocess(x_wav)
self.assertEqual(output[0], 8000)
self.assertEqual(output[1].shape, (8046,))
self.assertEqual(audio_input.serialize("test/test_files/audio_sample.wav", True)["data"], x_wav["data"])
self.assertEqual(
audio_input.serialize("test/test_files/audio_sample.wav", True)["data"],
x_wav["data"],
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = audio_input.save_flagged(tmpdirname, "audio_input", x_wav, None)
self.assertEqual("audio_input/0.wav", to_save)
@ -369,12 +537,15 @@ class TestAudio(unittest.TestCase):
self.assertIsInstance(audio_input.generate_sample(), dict)
audio_input = gr.inputs.Audio(label="Upload Your Audio")
self.assertEqual(audio_input.get_template_context(), {
'source': 'upload',
'optional': False,
'name': 'audio',
'label': 'Upload Your Audio'
})
self.assertEqual(
audio_input.get_template_context(),
{
"source": "upload",
"optional": False,
"name": "audio",
"label": "Upload Your Audio",
},
)
self.assertIsNone(audio_input.preprocess(None))
x_wav["is_example"] = True
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
@ -394,7 +565,6 @@ class TestAudio(unittest.TestCase):
x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav")
self.assertIsInstance(audio_input.serialize(x_wav, False), dict)
# def test_in_interface(self):
# x_wav = gr.test_data.BASE64_AUDIO
# def max_amplitude_from_wav_file(wav_file):
@ -405,7 +575,7 @@ class TestAudio(unittest.TestCase):
# max_amplitude_from_wav_file,
# gr.inputs.Audio(type="file"),
# "number", interpretation="default")
# # TODO(aliabd): investigate why this sometimes fails (returns 5239 or 576)
# # TODO(aliabd): investigate why this sometimes fails (returns 5239 or 576)
# self.assertEqual(iface.process([x_wav])[0], [576])
# scores, alternative_outputs = iface.interpret([x_wav])
# self.assertEqual(scores, ... )
@ -418,8 +588,11 @@ class TestFile(unittest.TestCase):
file_input = gr.inputs.File()
output = file_input.preprocess(x_file)
self.assertIsInstance(output, tempfile._TemporaryFileWrapper)
self.assertEqual(file_input.serialize("test/test_files/sample_file.pdf", True), 'test/test_files/sample_file.pdf')
self.assertEqual(
file_input.serialize("test/test_files/sample_file.pdf", True),
"test/test_files/sample_file.pdf",
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = file_input.save_flagged(tmpdirname, "file_input", [x_file], None)
self.assertEqual("file_input/0", to_save)
@ -430,12 +603,15 @@ class TestFile(unittest.TestCase):
self.assertIsInstance(file_input.generate_sample(), dict)
file_input = gr.inputs.File(label="Upload Your File")
self.assertEqual(file_input.get_template_context(), {
'file_count': 'single',
'optional': False,
'name': 'file',
'label': 'Upload Your File'
})
self.assertEqual(
file_input.get_template_context(),
{
"file_count": "single",
"optional": False,
"name": "file",
"label": "Upload Your File",
},
)
self.assertIsNone(file_input.preprocess(None))
x_file["is_example"] = True
self.assertIsNotNone(file_input.preprocess(x_file))
@ -445,39 +621,46 @@ class TestFile(unittest.TestCase):
def get_size_of_file(file_obj):
return os.path.getsize(file_obj.name)
iface = gr.Interface(
get_size_of_file, "file", "number")
iface = gr.Interface(get_size_of_file, "file", "number")
self.assertEqual(iface.process([[x_file]])[0], [10558])
class TestDataframe(unittest.TestCase):
def test_as_component(self):
x_data = [["Tim", 12, False], ["Jan", 24, True]]
dataframe_input = gr.inputs.Dataframe(headers=["Name","Age","Member"])
dataframe_input = gr.inputs.Dataframe(headers=["Name", "Age", "Member"])
output = dataframe_input.preprocess(x_data)
self.assertEqual(output["Age"][1], 24)
self.assertEqual(output["Member"][0], False)
self.assertEqual(dataframe_input.preprocess_example(x_data), x_data)
self.assertEqual(dataframe_input.serialize(x_data, True), x_data)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = dataframe_input.save_flagged(tmpdirname, "dataframe_input", x_data, None)
to_save = dataframe_input.save_flagged(
tmpdirname, "dataframe_input", x_data, None
)
self.assertEqual(json.dumps(x_data), to_save)
restored = dataframe_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(x_data, restored)
self.assertIsInstance(dataframe_input.generate_sample(), list)
dataframe_input = gr.inputs.Dataframe(headers=["Name", "Age", "Member"], label="Dataframe Input")
self.assertEqual(dataframe_input.get_template_context(), {
'headers': ['Name', 'Age', 'Member'],
'datatype': 'str',
'row_count': 3,
'col_count': 3,
'col_width': None,
'default': [[None, None, None], [None, None, None], [None, None, None]],
'name': 'dataframe',
'label': 'Dataframe Input'
})
dataframe_input = gr.inputs.Dataframe(
headers=["Name", "Age", "Member"], label="Dataframe Input"
)
self.assertEqual(
dataframe_input.get_template_context(),
{
"headers": ["Name", "Age", "Member"],
"datatype": "str",
"row_count": 3,
"col_count": 3,
"col_width": None,
"default": [[None, None, None], [None, None, None], [None, None, None]],
"name": "dataframe",
"label": "Dataframe Input",
},
)
dataframe_input = gr.inputs.Dataframe()
output = dataframe_input.preprocess(x_data)
self.assertEqual(output[1][1], 24)
@ -493,6 +676,7 @@ class TestDataframe(unittest.TestCase):
def get_last(l):
return l[-1]
iface = gr.Interface(get_last, "list", "text")
self.assertEqual(iface.process([x_data])[0], ["Sal"])
@ -503,7 +687,7 @@ class TestVideo(unittest.TestCase):
video_input = gr.inputs.Video()
output = video_input.preprocess(x_video)
self.assertIsInstance(output, str)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = video_input.save_flagged(tmpdirname, "video_input", x_video, None)
self.assertEqual("video_input/0.mp4", to_save)
@ -511,15 +695,18 @@ class TestVideo(unittest.TestCase):
self.assertEqual("video_input/1.mp4", to_save)
restored = video_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(restored, "video_input/1.mp4")
self.assertIsInstance(video_input.generate_sample(), dict)
video_input = gr.inputs.Video(label="Upload Your Video")
self.assertEqual(video_input.get_template_context(), {
'source': 'upload',
'optional': False,
'name': 'video',
'label': 'Upload Your Video'
})
self.assertEqual(
video_input.get_template_context(),
{
"source": "upload",
"optional": False,
"name": "video",
"label": "Upload Your Video",
},
)
self.assertIsNone(video_input.preprocess(None))
x_video["is_example"] = True
self.assertIsNotNone(video_input.preprocess(x_video))
@ -528,72 +715,73 @@ class TestVideo(unittest.TestCase):
with self.assertRaises(NotImplementedError):
video_input.serialize(x_video, True)
def test_in_interface(self):
x_video = gr.test_data.BASE64_VIDEO
iface = gr.Interface(
lambda x:x,
"video",
"playable_video")
iface = gr.Interface(lambda x: x, "video", "playable_video")
self.assertEqual(iface.process([x_video])[0][0]["data"], x_video["data"])
class TestTimeseries(unittest.TestCase):
def test_as_component(self):
timeseries_input = gr.inputs.Timeseries(
x="time",
y=["retail", "food", "other"]
)
x_timeseries = {"data": [[1] + [2] * len(timeseries_input.y)] * 4, "headers": [timeseries_input.x] +
timeseries_input.y}
timeseries_input = gr.inputs.Timeseries(x="time", y=["retail", "food", "other"])
x_timeseries = {
"data": [[1] + [2] * len(timeseries_input.y)] * 4,
"headers": [timeseries_input.x] + timeseries_input.y,
}
output = timeseries_input.preprocess(x_timeseries)
self.assertIsInstance(output, pandas.core.frame.DataFrame)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = timeseries_input.save_flagged(tmpdirname, "video_input", x_timeseries, None)
to_save = timeseries_input.save_flagged(
tmpdirname, "video_input", x_timeseries, None
)
self.assertEqual(json.dumps(x_timeseries), to_save)
restored = timeseries_input.restore_flagged(tmpdirname, to_save, None)
self.assertEqual(x_timeseries, restored)
self.assertIsInstance(timeseries_input.generate_sample(), dict)
timeseries_input = gr.inputs.Timeseries(
x="time",
y="retail", label="Upload Your Timeseries"
x="time", y="retail", label="Upload Your Timeseries"
)
self.assertEqual(
timeseries_input.get_template_context(),
{
"x": "time",
"y": ["retail"],
"optional": False,
"name": "timeseries",
"label": "Upload Your Timeseries",
},
)
self.assertEqual(timeseries_input.get_template_context(), {
'x': 'time',
'y': ['retail'],
'optional': False,
'name': 'timeseries',
'label': 'Upload Your Timeseries'
})
self.assertIsNone(timeseries_input.preprocess(None))
x_timeseries["range"] = (0, 1)
self.assertIsNotNone(timeseries_input.preprocess(x_timeseries))
def test_in_interface(self):
timeseries_input = gr.inputs.Timeseries(
x="time",
y=["retail", "food", "other"]
timeseries_input = gr.inputs.Timeseries(x="time", y=["retail", "food", "other"])
x_timeseries = {
"data": [[1] + [2] * len(timeseries_input.y)] * 4,
"headers": [timeseries_input.x] + timeseries_input.y,
}
iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
self.assertEqual(
iface.process([x_timeseries])[0],
[
{
"headers": ["time", "retail", "food", "other"],
"data": [[1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2]],
}
],
)
x_timeseries = {"data": [[1] + [2] * len(timeseries_input.y)] * 4, "headers": [timeseries_input.x] +
timeseries_input.y}
iface = gr.Interface(
lambda x: x,
timeseries_input,
"dataframe")
self.assertEqual(iface.process([x_timeseries])[0], [{'headers': ['time', 'retail', 'food', 'other'],
'data': [[1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2],
[1, 2, 2, 2]]}])
class TestNames(unittest.TestCase):
# this ensures that `inputs.get_input_instance()` works correctly when instantiating from components
def test_no_duplicate_uncased_names(self):
def test_no_duplicate_uncased_names(self):
subclasses = gr.inputs.InputComponent.__subclasses__()
unique_subclasses_uncased = set([s.__name__.lower() for s in subclasses])
self.assertEqual(len(subclasses), len(unique_subclasses_uncased))
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
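
# For reference, the interpretation API that several input tests above
# exercise: with interpretation="default", interpret() perturbs each input and
# returns per-neighbor scores plus the alternative outputs those neighbors
# produce (shapes as asserted in TestNumber and TestSlider).
import gradio as gr

iface = gr.Interface(lambda x: x ** 2, "number", "textbox", interpretation="default")
scores, alternative_outputs = iface.interpret([2])
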


@ -1,19 +1,22 @@
import io
import socket
import sys
import tempfile
from gradio.interface import *
import threading
import unittest
import unittest.mock as mock
import requests
import sys
from contextlib import contextmanager
import io
import threading
from comet_ml import Experiment
import mlflow
import requests
import wandb
import socket
from comet_ml import Experiment
from gradio.interface import *
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@contextmanager
def captured_output():
new_out, new_err = io.StringIO(), io.StringIO()
@ -24,6 +27,7 @@ def captured_output():
finally:
sys.stdout, sys.stderr = old_out, old_err
class TestInterface(unittest.TestCase):
def test_close(self):
io = Interface(lambda input: None, "textbox", "label")
@ -33,32 +37,34 @@ class TestInterface(unittest.TestCase):
io.close()
with self.assertRaises(Exception):
response = requests.get(local_url)
def test_close_all(self):
interface = Interface(lambda input: None, "textbox", "label")
interface.close = mock.MagicMock()
close_all()
interface.close.assert_called()
def test_examples_invalid_input(self):
with self.assertRaises(ValueError):
Interface(lambda x: x, examples=1234)
def test_examples_valid_path(self):
path = os.path.join(os.path.dirname(__file__), 'test_data/flagged_with_log')
path = os.path.join(os.path.dirname(__file__), "test_data/flagged_with_log")
interface = Interface(lambda x: 3 * x, "number", "number", examples=path)
self.assertEqual(len(interface.get_config_file()['examples']), 2)
path = os.path.join(os.path.dirname(__file__), 'test_data/flagged_no_log')
self.assertEqual(len(interface.get_config_file()["examples"]), 2)
path = os.path.join(os.path.dirname(__file__), "test_data/flagged_no_log")
interface = Interface(lambda x: 3 * x, "number", "number", examples=path)
self.assertEqual(len(interface.get_config_file()['examples']), 3)
self.assertEqual(len(interface.get_config_file()["examples"]), 3)
def test_examples_not_valid_path(self):
with self.assertRaises(FileNotFoundError):
interface = Interface(lambda x: x, "textbox", "label", examples='invalid-path')
interface = Interface(
lambda x: x, "textbox", "label", examples="invalid-path"
)
interface.launch(prevent_thread_lock=True)
interface.close()
def test_test_launch(self):
with captured_output() as (out, err):
prediction_fn = lambda x: x
@ -66,8 +72,8 @@ class TestInterface(unittest.TestCase):
interface = Interface(prediction_fn, "textbox", "label")
interface.test_launch()
output = out.getvalue().strip()
self.assertEqual(output, 'Test launch: prediction_fn()... PASSED')
self.assertEqual(output, "Test launch: prediction_fn()... PASSED")
@mock.patch("time.sleep")
def test_block_thread(self, mock_sleep):
with self.assertRaises(KeyboardInterrupt):
@ -76,19 +82,20 @@ class TestInterface(unittest.TestCase):
interface = Interface(lambda x: x, "textbox", "label")
interface.launch(prevent_thread_lock=False)
output = out.getvalue().strip()
self.assertEqual(output, 'Keyboard interruption in main thread... closing server.')
self.assertEqual(
output, "Keyboard interruption in main thread... closing server."
)
@mock.patch('gradio.utils.colab_check')
@mock.patch("gradio.utils.colab_check")
def test_launch_colab_share(self, mock_colab_check):
mock_colab_check.return_value = True
interface = Interface(lambda x: x, "textbox", "label")
_, _, share_url = interface.launch(prevent_thread_lock=True)
self.assertIsNotNone(share_url)
interface.close()
@mock.patch('gradio.utils.colab_check')
@mock.patch('gradio.networking.setup_tunnel')
@mock.patch("gradio.utils.colab_check")
@mock.patch("gradio.networking.setup_tunnel")
def test_launch_colab_share_error(self, mock_setup_tunnel, mock_colab_check):
mock_setup_tunnel.side_effect = RuntimeError()
mock_colab_check.return_value = True
@ -96,40 +103,43 @@ class TestInterface(unittest.TestCase):
_, _, share_url = interface.launch(prevent_thread_lock=True)
self.assertIsNone(share_url)
interface.close()
def test_interface_representation(self):
prediction_fn = lambda x: x
prediction_fn.__name__ = "prediction_fn"
repr = str(Interface(prediction_fn, "textbox", "label")).split('\n')
repr = str(Interface(prediction_fn, "textbox", "label")).split("\n")
self.assertTrue(prediction_fn.__name__ in repr[0])
self.assertEqual(len(repr[0]), len(repr[1]))
def test_interface_load(self):
io = Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
io = Interface.load(
"models/distilbert-base-uncased-finetuned-sst-2-english",
alias="sentiment_classifier",
)
output = io("I am happy, I love you.")
self.assertGreater(output['POSITIVE'], 0.5)
self.assertGreater(output["POSITIVE"], 0.5)
def test_interface_none_interp(self):
interface = Interface(lambda x: x, "textbox", "label", interpretation=[None])
scores, alternative_outputs = interface.interpret(["quickest brown fox"])
self.assertIsNone(scores[0])
@mock.patch('webbrowser.open')
@mock.patch("webbrowser.open")
def test_interface_browser(self, mock_browser):
interface = Interface(lambda x: x, "textbox", "label")
interface.launch(inbrowser=True, prevent_thread_lock=True)
mock_browser.assert_called_once()
interface.close()
def test_examples_list(self):
examples = ['test1', 'test2']
examples = ["test1", "test2"]
interface = Interface(lambda x: x, "textbox", "label", examples=examples)
interface.launch(prevent_thread_lock=True)
self.assertEqual(len(interface.examples), 2)
self.assertEqual(len(interface.examples[0]), 1)
interface.close()
@mock.patch('IPython.display.display')
@mock.patch("IPython.display.display")
def test_inline_display(self, mock_display):
interface = Interface(lambda x: x, "textbox", "label")
interface.launch(inline=True, prevent_thread_lock=True)
@ -137,8 +147,8 @@ class TestInterface(unittest.TestCase):
interface.launch(inline=True, share=True, prevent_thread_lock=True)
self.assertEqual(mock_display.call_count, 2)
interface.close()
@mock.patch('comet_ml.Experiment')
@mock.patch("comet_ml.Experiment")
def test_integration_comet(self, mock_experiment):
experiment = mock_experiment()
experiment.log_text = mock.MagicMock()
@ -146,44 +156,52 @@ class TestInterface(unittest.TestCase):
interface = Interface(lambda x: x, "textbox", "label")
interface.launch(prevent_thread_lock=True)
interface.integrate(comet_ml=experiment)
experiment.log_text.assert_called_with('gradio: ' + interface.local_url)
interface.share_url = 'tmp' # used to avoid creating real share links.
experiment.log_text.assert_called_with("gradio: " + interface.local_url)
interface.share_url = "tmp" # used to avoid creating real share links.
interface.integrate(comet_ml=experiment)
experiment.log_text.assert_called_with('gradio: ' + interface.share_url)
experiment.log_text.assert_called_with("gradio: " + interface.share_url)
self.assertEqual(experiment.log_other.call_count, 2)
interface.share_url = None
interface.close()
def test_integration_mlflow(self):
mlflow.log_param = mock.MagicMock()
interface = Interface(lambda x: x, "textbox", "label")
interface.launch(prevent_thread_lock=True)
interface.integrate(mlflow=mlflow)
mlflow.log_param.assert_called_with("Gradio Interface Local Link", interface.local_url)
interface.share_url = 'tmp' # used to avoid creating real share links.
mlflow.log_param.assert_called_with(
"Gradio Interface Local Link", interface.local_url
)
interface.share_url = "tmp" # used to avoid creating real share links.
interface.integrate(mlflow=mlflow)
mlflow.log_param.assert_called_with("Gradio Interface Share Link", interface.share_url)
mlflow.log_param.assert_called_with(
"Gradio Interface Share Link", interface.share_url
)
interface.share_url = None
interface.close()
def test_integration_wandb(self):
with captured_output() as (out, err):
wandb.log = mock.MagicMock()
wandb.Html = mock.MagicMock()
interface = Interface(lambda x: x, "textbox", "label")
interface.integrate(wandb=wandb)
self.assertEqual(out.getvalue().strip(), "The WandB integration requires you to `launch(share=True)` first.")
interface.share_url = 'tmp'
self.assertEqual(
out.getvalue().strip(),
"The WandB integration requires you to `launch(share=True)` first.",
)
interface.share_url = "tmp"
interface.integrate(wandb=wandb)
wandb.log.assert_called_once()
@mock.patch('requests.post')
@mock.patch("requests.post")
def test_integration_analytics(self, mock_post):
mlflow.log_param = mock.MagicMock()
interface = Interface(lambda x: x, "textbox", "label")
interface.analytics_enabled = True
interface.integrate(mlflow=mlflow)
mock_post.assert_called_once()
if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()
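
# Sketch of the launch/integrate flow covered above: launch() returns the app,
# the local URL, and a share URL (None unless sharing), and integrate() logs
# those links to a tracking backend -- mlflow here, as in the tests.
import mlflow

from gradio.interface import Interface

interface = Interface(lambda x: x, "textbox", "label")
app, local_url, share_url = interface.launch(prevent_thread_lock=True)
interface.integrate(mlflow=mlflow)  # logs "Gradio Interface Local Link"
interface.close()
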


@ -1,45 +1,71 @@
import os
import unittest
import numpy as np
import gradio.interpretation
import gradio.test_data
from gradio.processing_utils import decode_base64_to_image, encode_array_to_base64
from gradio import Interface
import numpy as np
import os
from gradio.processing_utils import (decode_base64_to_image,
encode_array_to_base64)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestDefault(unittest.TestCase):
def test_default_text(self):
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
text_interface = Interface(max_word_len, "textbox", "label", interpretation="default")
text_interface = Interface(
max_word_len, "textbox", "label", interpretation="default"
)
interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
self.assertGreater(interpretation[0][1], 0) # Checks to see if the first word has >0 score.
self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score.
self.assertGreater(
interpretation[0][1], 0
) # Checks to see if the first word has >0 score.
self.assertEqual(
interpretation[-1][1], 0
) # Checks to see if the last word has 0 score.
class TestShapley(unittest.TestCase):
def test_shapley_text(self):
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
text_interface = Interface(max_word_len, "textbox", "label", interpretation="shapley")
text_interface = Interface(
max_word_len, "textbox", "label", interpretation="shapley"
)
interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
self.assertGreater(interpretation[0][1], 0) # Checks to see if the first word has >0 score.
self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score.
self.assertGreater(
interpretation[0][1], 0
) # Checks to see if the first word has >0 score.
self.assertEqual(
interpretation[-1][1], 0
) # Checks to see if the last word has 0 score.
class TestCustom(unittest.TestCase):
def test_custom_text(self):
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
custom = lambda text: [(char, 1) for char in text]
text_interface = Interface(max_word_len, "textbox", "label", interpretation=custom)
text_interface = Interface(
max_word_len, "textbox", "label", interpretation=custom
)
result = text_interface.interpret(["quickest brown fox"])[0][0]
self.assertEqual(result[0][1], 1) # Checks to see if the first letter has score of 1.
self.assertEqual(
result[0][1], 1
) # Checks to see if the first letter has score of 1.
def test_custom_img(self):
max_pixel_value = lambda img: img.max()
custom = lambda img: img.tolist()
img_interface = Interface(max_pixel_value, "image", "label", interpretation=custom)
img_interface = Interface(
max_pixel_value, "image", "label", interpretation=custom
)
result = img_interface.interpret([gradio.test_data.BASE64_IMAGE])[0][0]
expected_result = np.asarray(decode_base64_to_image(gradio.test_data.BASE64_IMAGE).convert('RGB')).tolist()
expected_result = np.asarray(
decode_base64_to_image(gradio.test_data.BASE64_IMAGE).convert("RGB")
).tolist()
self.assertEqual(result, expected_result)
class TestHelperMethods(unittest.TestCase):
def test_diff(self):
@ -52,9 +78,13 @@ class TestHelperMethods(unittest.TestCase):
def test_quantify_difference_with_textbox(self):
iface = Interface(lambda text: text, ["textbox"], ["textbox"])
diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test"])
diff = gradio.interpretation.quantify_difference_in_label(
iface, ["test"], ["test"]
)
self.assertEquals(diff, 0)
diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test_diff"])
diff = gradio.interpretation.quantify_difference_in_label(
iface, ["test"], ["test_diff"]
)
self.assertEquals(diff, 1)
def test_quantify_difference_with_label(self):
@ -66,48 +96,43 @@ class TestHelperMethods(unittest.TestCase):
def test_quantify_difference_with_confidences(self):
iface = Interface(lambda text: len(text), ["textbox"], ["label"])
output_1 = {
"cat": 0.9,
"dog": 0.1
}
output_2 = {
"cat": 0.6,
"dog": 0.4
}
output_3 = {
"cat": 0.1,
"dog": 0.6
}
diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_2])
output_1 = {"cat": 0.9, "dog": 0.1}
output_2 = {"cat": 0.6, "dog": 0.4}
output_3 = {"cat": 0.1, "dog": 0.6}
diff = gradio.interpretation.quantify_difference_in_label(
iface, [output_1], [output_2]
)
self.assertAlmostEquals(diff, 0.3)
diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_3])
diff = gradio.interpretation.quantify_difference_in_label(
iface, [output_1], [output_3]
)
self.assertAlmostEquals(diff, 0.8)
def test_get_regression_value(self):
iface = Interface(lambda text: text, ["textbox"], ["label"])
output_1 = {
"cat": 0.9,
"dog": 0.1
}
output_2 = {
"cat": float("nan"),
"dog": 0.4
}
output_3 = {
"cat": 0.1,
"dog": 0.6
}
diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_2])
output_1 = {"cat": 0.9, "dog": 0.1}
output_2 = {"cat": float("nan"), "dog": 0.4}
output_3 = {"cat": 0.1, "dog": 0.6}
diff = gradio.interpretation.get_regression_or_classification_value(
iface, [output_1], [output_2]
)
self.assertEquals(diff, 0)
diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_3])
diff = gradio.interpretation.get_regression_or_classification_value(
iface, [output_1], [output_3]
)
self.assertAlmostEquals(diff, 0.1)
def test_get_classification_value(self):
iface = Interface(lambda text: text, ["textbox"], ["label"])
diff = gradio.interpretation.get_regression_or_classification_value(iface, ["cat"], ["test"])
diff = gradio.interpretation.get_regression_or_classification_value(
iface, ["cat"], ["test"]
)
self.assertEquals(diff, 1)
diff = gradio.interpretation.get_regression_or_classification_value(iface, ["test"], ["test"])
diff = gradio.interpretation.get_regression_or_classification_value(
iface, ["test"], ["test"]
)
self.assertEquals(diff, 0)
if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()
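
# A minimal sketch of the custom interpretation hook tested above: any
# callable that maps the raw input to (token, score) pairs can be passed as
# interpretation=, and interpret() returns those pairs per input.
from gradio import Interface

max_word_len = lambda text: max(len(word) for word in text.split(" "))
custom = lambda text: [(char, 1) for char in text]
text_interface = Interface(max_word_len, "textbox", "label", interpretation=custom)
result = text_interface.interpret(["quickest brown fox"])[0][0]  # [("q", 1), ...]
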


@ -1,8 +1,8 @@
import os
import unittest
import gradio as gr
from gradio import mix
import os
"""
WARNING: Some of these tests have an external dependency: namely that Hugging Face's Hub and Space APIs do not change, and that they keep their most famous models up. So if, e.g., Spaces is down, then these tests will not pass.
@ -14,8 +14,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestSeries(unittest.TestCase):
def test_in_interface(self):
io1 = gr.Interface(lambda x: x + " World", "textbox",
gr.outputs.Textbox())
io1 = gr.Interface(lambda x: x + " World", "textbox", gr.outputs.Textbox())
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.outputs.Textbox())
series = mix.Series(io1, io2)
self.assertEqual(series.process(["Hello"])[0], ["Hello World!"])
@ -25,18 +24,18 @@ class TestSeries(unittest.TestCase):
io2 = gr.Interface.load("spaces/abidlabs/image-classifier")
series = mix.Series(io1, io2)
output = series("test/test_data/lion.jpg")
self.assertGreater(output['lion'], 0.5)
self.assertGreater(output["lion"], 0.5)
class TestParallel(unittest.TestCase):
def test_in_interface(self):
io1 = gr.Interface(lambda x: x + " World 1!", "textbox",
gr.outputs.Textbox())
io2 = gr.Interface(lambda x: x + " World 2!", "textbox",
gr.outputs.Textbox())
io1 = gr.Interface(lambda x: x + " World 1!", "textbox", gr.outputs.Textbox())
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.outputs.Textbox())
parallel = mix.Parallel(io1, io2)
self.assertEqual(parallel.process(["Hello"])[0], ["Hello World 1!",
"Hello World 2!"])
self.assertEqual(
parallel.process(["Hello"])[0], ["Hello World 1!", "Hello World 2!"]
)
def test_with_external(self):
io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish")
io2 = gr.Interface.load("spaces/abidlabs/english2german")
@ -46,5 +45,5 @@ class TestParallel(unittest.TestCase):
self.assertIn("hallo", hello_de.lower())
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
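
# Sketch of the two combinators tested above: Series pipes one interface's
# output into the next, while Parallel runs every interface on the same input.
import gradio as gr
from gradio import mix

io1 = gr.Interface(lambda x: x + " World", "textbox", gr.outputs.Textbox())
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.outputs.Textbox())
series = mix.Series(io1, io2)      # "Hello" -> "Hello World!"
parallel = mix.Parallel(io1, io2)  # "Hello" -> ["Hello World", "Hello!"]
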


@ -1,15 +1,15 @@
"""Contains tests for networking.py and app.py"""
import aiohttp
from fastapi.testclient import TestClient
import os
import unittest
import unittest.mock as mock
import urllib.request
import warnings
from gradio import flagging, Interface, networking, reset_all, utils
import aiohttp
from fastapi.testclient import TestClient
from gradio import Interface, flagging, networking, reset_all, utils
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -37,46 +37,46 @@ class TestPort(unittest.TestCase):
class TestRoutes(unittest.TestCase):
def setUp(self) -> None:
self.io = Interface(lambda x: x, "text", "text")
self.io = Interface(lambda x: x, "text", "text")
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
self.client = TestClient(self.app)
def test_get_main_route(self):
response = self.client.get('/')
response = self.client.get("/")
self.assertEqual(response.status_code, 200)
def test_get_api_route(self):
response = self.client.get('/api/')
response = self.client.get("/api/")
self.assertEqual(response.status_code, 200)
def test_static_files_served_safely(self):
# Make sure things outside the static folder are not accessible
response = self.client.get(r'/static/..%2findex.html')
response = self.client.get(r"/static/..%2findex.html")
self.assertEqual(response.status_code, 404)
response = self.client.get(r'/static/..%2f..%2fapi_docs.html')
response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
self.assertEqual(response.status_code, 404)
def test_get_config_route(self):
response = self.client.get('/config/')
response = self.client.get("/config/")
self.assertEqual(response.status_code, 200)
def test_predict_route(self):
response = self.client.post('/api/predict/', json={"data": ["test"]})
self.assertEqual(response.status_code, 200)
response = self.client.post("/api/predict/", json={"data": ["test"]})
self.assertEqual(response.status_code, 200)
output = dict(response.json())
self.assertEqual(output["data"], ["test"])
self.assertTrue("durations" in output)
self.assertTrue("avg_durations" in output)
self.assertEqual(output["data"], ["test"])
self.assertTrue("durations" in output)
self.assertTrue("avg_durations" in output)
# def test_queue_push_route(self):
# networking.queue.push = mock.MagicMock(return_value=(None, None))
# response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"})
# self.assertEqual(response.status_code, 200)
# self.assertEqual(response.status_code, 200)
# def test_queue_status_route(self):
# networking.queue.get_status = mock.MagicMock(return_value=(None, None))
# response = self.client.post('/api/queue/status/', json={"hash": "test"})
# self.assertEqual(response.status_code, 200)
# self.assertEqual(response.status_code, 200)
def tearDown(self) -> None:
self.io.close()
@ -85,15 +85,21 @@ class TestRoutes(unittest.TestCase):
class TestAuthenticatedRoutes(unittest.TestCase):
def setUp(self) -> None:
self.io = Interface(lambda x: x, "text", "text")
self.app, _, _ = self.io.launch(auth=("test", "correct_password"), prevent_thread_lock=True)
self.io = Interface(lambda x: x, "text", "text")
self.app, _, _ = self.io.launch(
auth=("test", "correct_password"), prevent_thread_lock=True
)
self.client = TestClient(self.app)
def test_post_login(self):
response = self.client.post('/login', data=dict(username="test", password="correct_password"))
response = self.client.post(
"/login", data=dict(username="test", password="correct_password")
)
self.assertEqual(response.status_code, 302)
response = self.client.post('/login', data=dict(username="test", password="incorrect_password"))
self.assertEqual(response.status_code, 400)
response = self.client.post(
"/login", data=dict(username="test", password="incorrect_password")
)
self.assertEqual(response.status_code, 400)
def tearDown(self) -> None:
self.io.close()
@ -102,10 +108,10 @@ class TestAuthenticatedRoutes(unittest.TestCase):
class TestInterfaceCustomParameters(unittest.TestCase):
def test_show_error(self):
io = Interface(lambda x: 1/x, "number", "number")
io = Interface(lambda x: 1 / x, "number", "number")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = TestClient(app)
response = client.post('/api/predict/', json={"data": [0]})
response = client.post("/api/predict/", json={"data": [0]})
self.assertEqual(response.status_code, 500)
self.assertTrue("error" in response.json())
io.close()
@ -119,13 +125,18 @@ class TestFlagging(unittest.TestCase):
aiohttp.ClientSession.post.__aenter__ = None
aiohttp.ClientSession.post.__aexit__ = None
io = Interface(
lambda x: x, "text", "text",
analytics_enabled=True, flagging_callback=callback)
lambda x: x,
"text",
"text",
analytics_enabled=True,
flagging_callback=callback,
)
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = TestClient(app)
response = client.post(
'/api/flag/',
json={"data": {"input_data": ["test"], "output_data": ["test"]}})
"/api/flag/",
json={"data": {"input_data": ["test"], "output_data": ["test"]}},
)
aiohttp.ClientSession.post.assert_called()
callback.flag.assert_called_once()
self.assertEqual(response.status_code, 200)
@@ -135,16 +146,19 @@ class TestFlagging(unittest.TestCase):
class TestInterpretation(unittest.TestCase):
def test_interpretation(self):
io = Interface(
lambda x: len(x), "text", "label",
interpretation="default", analytics_enabled=True)
lambda x: len(x),
"text",
"label",
interpretation="default",
analytics_enabled=True,
)
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
aiohttp.ClientSession.post = mock.MagicMock()
aiohttp.ClientSession.post.__aenter__ = None
aiohttp.ClientSession.post.__aexit__ = None
aiohttp.ClientSession.post.__aexit__ = None
io.interpret = mock.MagicMock(return_value=(None, None))
response = client.post(
'/api/interpret/', json={"data": ["test test"]})
response = client.post("/api/interpret/", json={"data": ["test test"]})
aiohttp.ClientSession.post.assert_called()
self.assertEqual(response.status_code, 200)
io.close()
@@ -187,5 +201,5 @@ class TestURLs(unittest.TestCase):
# networking.queue.fail_job.assert_called_once()
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@@ -1,11 +1,12 @@
import os
import tempfile
import unittest
import matplotlib.pyplot as plt
import gradio as gr
import numpy as np
import pandas as pd
import tempfile
import os
import gradio as gr
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@@ -26,13 +27,15 @@ class TestTextbox(unittest.TestCase):
def test_in_interface(self):
iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox())
self.assertEqual(iface.process(["Hello"])[0], ["o"])
iface = gr.Interface(lambda x: x / 2, "number", gr.outputs.Textbox(type="number"))
iface = gr.Interface(
lambda x: x / 2, "number", gr.outputs.Textbox(type="number")
)
self.assertEqual(iface.process([10])[0], [5])
class TestLabel(unittest.TestCase):
def test_as_component(self):
y = 'happy'
y = "happy"
label_output = gr.outputs.Label()
label = label_output.postprocess(y)
self.assertDictEqual(label, {"label": "happy"})
@@ -41,45 +44,55 @@ class TestLabel(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdir:
to_save = label_output.save_flagged(tmpdir, "label_output", label, None)
self.assertEqual(to_save, y)
y = {
3: 0.7,
1: 0.2,
0: 0.1
}
y = {3: 0.7, 1: 0.2, 0: 0.1}
label_output = gr.outputs.Label()
label = label_output.postprocess(y)
self.assertDictEqual(label, {
"label": 3,
"confidences": [
{"label": 3, "confidence": 0.7},
{"label": 1, "confidence": 0.2},
{"label": 0, "confidence": 0.1},
]
})
self.assertDictEqual(
label,
{
"label": 3,
"confidences": [
{"label": 3, "confidence": 0.7},
{"label": 1, "confidence": 0.2},
{"label": 0, "confidence": 0.1},
],
},
)
label_output = gr.outputs.Label(num_top_classes=2)
label = label_output.postprocess(y)
self.assertDictEqual(label, {
"label": 3,
"confidences": [
{"label": 3, "confidence": 0.7},
{"label": 1, "confidence": 0.2},
]
})
self.assertDictEqual(
label,
{
"label": 3,
"confidences": [
{"label": 3, "confidence": 0.7},
{"label": 1, "confidence": 0.2},
],
},
)
with self.assertRaises(ValueError):
label_output.postprocess([1, 2, 3])
with tempfile.TemporaryDirectory() as tmpdir:
to_save = label_output.save_flagged(tmpdir, "label_output", label, None)
self.assertEqual(to_save, '{"3": 0.7, "1": 0.2}')
self.assertEqual(label_output.restore_flagged(tmpdir, to_save, None),
{'label': '3', 'confidences': [{"label": "3", "confidence": 0.7}, {"label": "1", "confidence": 0.2}]})
self.assertEqual(
label_output.restore_flagged(tmpdir, to_save, None),
{
"label": "3",
"confidences": [
{"label": "3", "confidence": 0.7},
{"label": "1", "confidence": 0.2},
],
},
)
with self.assertRaises(ValueError):
label_output = gr.outputs.Label(type="unknown")
label_output.deserialize([1, 2, 3])
def test_in_interface(self):
x_img = gr.test_data.BASE64_IMAGE
def rgb_distribution(img):
rgb_dist = np.mean(img, axis=(0, 1))
rgb_dist /= np.sum(rgb_dist)
@@ -92,22 +105,33 @@ class TestLabel(unittest.TestCase):
iface = gr.Interface(rgb_distribution, "image", "label")
output = iface.process([x_img])[0][0]
self.assertDictEqual(output, {
'label': 'red',
'confidences': [
{'label': 'red', 'confidence': 0.44},
{'label': 'green', 'confidence': 0.28},
{'label': 'blue', 'confidence': 0.28}
]
})
self.assertDictEqual(
output,
{
"label": "red",
"confidences": [
{"label": "red", "confidence": 0.44},
{"label": "green", "confidence": 0.28},
{"label": "blue", "confidence": 0.28},
],
},
)
class TestImage(unittest.TestCase):
def test_as_component(self):
y_img = gr.processing_utils.decode_base64_to_image(gr.test_data.BASE64_IMAGE)
image_output = gr.outputs.Image()
self.assertTrue(image_output.postprocess(y_img).startswith("data:image/png;base64,iVBORw0KGgoAAA"))
self.assertTrue(image_output.postprocess(np.array(y_img)).startswith("data:image/png;base64,iVBORw0KGgoAAA"))
self.assertTrue(
image_output.postprocess(y_img).startswith(
"data:image/png;base64,iVBORw0KGgoAAA"
)
)
self.assertTrue(
image_output.postprocess(np.array(y_img)).startswith(
"data:image/png;base64,iVBORw0KGgoAAA"
)
)
with self.assertWarns(DeprecationWarning):
plot_output = gr.outputs.Image(plot=True)
@@ -115,15 +139,23 @@ class TestImage(unittest.TestCase):
ypoints = np.array([0, 250])
fig = plt.figure()
p = plt.plot(xpoints, ypoints)
self.assertTrue(plot_output.postprocess(fig).startswith("data:image/png;base64,"))
self.assertTrue(
plot_output.postprocess(fig).startswith("data:image/png;base64,")
)
with self.assertRaises(ValueError):
image_output.postprocess([1, 2, 3])
image_output = gr.outputs.Image(type="numpy")
self.assertTrue(image_output.postprocess(y_img).startswith("data:image/png;base64,"))
self.assertTrue(
image_output.postprocess(y_img).startswith("data:image/png;base64,")
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = image_output.save_flagged(tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None)
to_save = image_output.save_flagged(
tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None
)
self.assertEqual("image_output/0.png", to_save)
to_save = image_output.save_flagged(tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None)
to_save = image_output.save_flagged(
tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None
)
self.assertEqual("image_output/1.png", to_save)
def test_in_interface(self):
@@ -131,19 +163,29 @@ class TestImage(unittest.TestCase):
return np.random.randint(0, 256, (width, height, 3))
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
self.assertTrue(iface.process([10, 20])[0][0].startswith("data:image/png;base64"))
self.assertTrue(
iface.process([10, 20])[0][0].startswith("data:image/png;base64")
)
class TestVideo(unittest.TestCase):
def test_as_component(self):
y_vid = "test/test_files/video_sample.mp4"
video_output = gr.outputs.Video()
self.assertTrue(video_output.postprocess(y_vid)["data"].startswith("data:video/mp4;base64,"))
self.assertTrue(video_output.deserialize(gr.test_data.BASE64_VIDEO["data"]).endswith(".mp4"))
self.assertTrue(
video_output.postprocess(y_vid)["data"].startswith("data:video/mp4;base64,")
)
self.assertTrue(
video_output.deserialize(gr.test_data.BASE64_VIDEO["data"]).endswith(".mp4")
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = video_output.save_flagged(tmpdirname, "video_output", gr.test_data.BASE64_VIDEO, None)
to_save = video_output.save_flagged(
tmpdirname, "video_output", gr.test_data.BASE64_VIDEO, None
)
self.assertEqual("video_output/0.mp4", to_save)
to_save = video_output.save_flagged(tmpdirname, "video_output", gr.test_data.BASE64_VIDEO, None)
to_save = video_output.save_flagged(
tmpdirname, "video_output", gr.test_data.BASE64_VIDEO, None
)
self.assertEqual("video_output/1.mp4", to_save)
@@ -159,7 +201,10 @@ class TestKeyValues(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = kv_output.save_flagged(tmpdirname, "kv_output", kv_list, None)
self.assertEqual(to_save, '[["a", 1], ["b", 2]]')
self.assertEqual(kv_output.restore_flagged(tmpdirname, to_save, None), [["a", 1], ["b", 2]])
self.assertEqual(
kv_output.restore_flagged(tmpdirname, to_save, None),
[["a", 1], ["b", 2]],
)
def test_in_interface(self):
def letter_distribution(word):
@@ -169,27 +214,31 @@ class TestKeyValues(unittest.TestCase):
return dist
iface = gr.Interface(letter_distribution, "text", "key_values")
self.assertListEqual(iface.process(["alpaca"])[0][0], [
("a", 3), ("l", 1), ("p", 1), ("c", 1)])
self.assertListEqual(
iface.process(["alpaca"])[0][0], [("a", 3), ("l", 1), ("p", 1), ("c", 1)]
)
class TestHighlightedText(unittest.TestCase):
def test_as_component(self):
ht_output = gr.outputs.HighlightedText(color_map={"pos": "green", "neg": "red"})
self.assertEqual(ht_output.get_template_context(), {
'color_map': {'pos': 'green', 'neg': 'red'},
'name': 'highlightedtext',
'label': None,
'show_legend': False
})
ht = {
"pos": "Hello ",
"neg": "World"
}
self.assertEqual(
ht_output.get_template_context(),
{
"color_map": {"pos": "green", "neg": "red"},
"name": "highlightedtext",
"label": None,
"show_legend": False,
},
)
ht = {"pos": "Hello ", "neg": "World"}
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = ht_output.save_flagged(tmpdirname, "ht_output", ht, None)
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
self.assertEqual(ht_output.restore_flagged(tmpdirname, to_save, None), {"pos": "Hello ", "neg": "World"})
self.assertEqual(
ht_output.restore_flagged(tmpdirname, to_save, None),
{"pos": "Hello ", "neg": "World"},
)
def test_in_interface(self):
def highlight_vowels(sentence):
@@ -206,29 +255,42 @@ class TestHighlightedText(unittest.TestCase):
cur_phrase += letter
phrases.append((cur_phrase, mode))
return phrases
iface = gr.Interface(highlight_vowels, "text", "highlight")
self.assertListEqual(iface.process(["Helloooo"])[0][0], [
("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")])
self.assertListEqual(
iface.process(["Helloooo"])[0][0],
[("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")],
)
class TestAudio(unittest.TestCase):
def test_as_component(self):
y_audio = gr.processing_utils.decode_base64_to_file(gr.test_data.BASE64_AUDIO["data"])
y_audio = gr.processing_utils.decode_base64_to_file(
gr.test_data.BASE64_AUDIO["data"]
)
audio_output = gr.outputs.Audio(type="file")
self.assertTrue(audio_output.postprocess(y_audio.name).startswith("data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA"))
self.assertEqual(audio_output.get_template_context(), {
'name': 'audio',
'label': None
})
self.assertTrue(
audio_output.postprocess(y_audio.name).startswith(
"data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA"
)
)
self.assertEqual(
audio_output.get_template_context(), {"name": "audio", "label": None}
)
with self.assertRaises(ValueError):
wrong_type = gr.outputs.Audio(type="unknown")
wrong_type.postprocess(y_audio.name)
self.assertTrue(audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav"))
self.assertTrue(
audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav")
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = audio_output.save_flagged(tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None)
to_save = audio_output.save_flagged(
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
)
self.assertEqual("audio_output/0.wav", to_save)
to_save = audio_output.save_flagged(tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None)
to_save = audio_output.save_flagged(
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
)
self.assertEqual("audio_output/1.wav", to_save)
def test_in_interface(self):
@@ -242,16 +304,18 @@ class TestAudio(unittest.TestCase):
class TestJSON(unittest.TestCase):
def test_as_component(self):
js_output = gr.outputs.JSON()
self.assertTrue(js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"')
js = {
"pos": "Hello ",
"neg": "World"
}
self.assertTrue(
js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"'
)
js = {"pos": "Hello ", "neg": "World"}
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = js_output.save_flagged(tmpdirname, "js_output", js, None)
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
self.assertEqual(js_output.restore_flagged(tmpdirname, to_save, None), {"pos": "Hello ", "neg": "World"})
self.assertEqual(
js_output.restore_flagged(tmpdirname, to_save, None),
{"pos": "Hello ", "neg": "World"},
)
def test_in_interface(self):
def get_avg_age_per_gender(data):
return {
@@ -263,7 +327,8 @@ class TestJSON(unittest.TestCase):
iface = gr.Interface(
get_avg_age_per_gender,
gr.inputs.Dataframe(headers=["gender", "age"]),
"json")
"json",
)
y_data = [
["M", 30],
["F", 20],
@@ -271,9 +336,7 @@ class TestJSON(unittest.TestCase):
["O", 20],
["F", 30],
]
self.assertDictEqual(iface.process([y_data])[0][0], {
"M": 35, "F": 25, "O": 20
})
self.assertDictEqual(iface.process([y_data])[0][0], {"M": 35, "F": 25, "O": 20})
class TestHTML(unittest.TestCase):
@@ -293,80 +356,116 @@ class TestFile(unittest.TestCase):
return "test.txt"
iface = gr.Interface(write_file, "text", "file")
self.assertDictEqual(iface.process(["hello world"])[0][0], {
'name': 'test.txt', 'size': 11, 'data': 'data:text/plain;base64,aGVsbG8gd29ybGQ='
})
self.assertDictEqual(
iface.process(["hello world"])[0][0],
{
"name": "test.txt",
"size": 11,
"data": "data:text/plain;base64,aGVsbG8gd29ybGQ=",
},
)
file_output = gr.outputs.File()
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = file_output.save_flagged(tmpdirname, "file_output", gr.test_data.BASE64_FILE, None)
to_save = file_output.save_flagged(
tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
)
self.assertEqual("file_output/0", to_save)
to_save = file_output.save_flagged(tmpdirname, "file_output", gr.test_data.BASE64_FILE, None)
to_save = file_output.save_flagged(
tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
)
self.assertEqual("file_output/1", to_save)
class TestDataframe(unittest.TestCase):
def test_as_component(self):
dataframe_output = gr.outputs.Dataframe()
output = dataframe_output.postprocess(np.zeros((2,2)))
self.assertDictEqual(output, {"data": [[0,0],[0,0]]})
output = dataframe_output.postprocess([[1,3,5]])
output = dataframe_output.postprocess(np.zeros((2, 2)))
self.assertDictEqual(output, {"data": [[0, 0], [0, 0]]})
output = dataframe_output.postprocess([[1, 3, 5]])
self.assertDictEqual(output, {"data": [[1, 3, 5]]})
output = dataframe_output.postprocess(pd.DataFrame(
[[2, True], [3, True], [4, False]], columns=["num", "prime"]))
self.assertDictEqual(output,
{"headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]]})
self.assertEqual(dataframe_output.get_template_context(), {
'headers': None,
'max_rows': 20,
'max_cols': None,
'overflow_row_behaviour': 'paginate',
'name': 'dataframe',
'label': None
})
output = dataframe_output.postprocess(
pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"])
)
self.assertDictEqual(
output,
{"headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]]},
)
self.assertEqual(
dataframe_output.get_template_context(),
{
"headers": None,
"max_rows": 20,
"max_cols": None,
"overflow_row_behaviour": "paginate",
"name": "dataframe",
"label": None,
},
)
with self.assertRaises(ValueError):
wrong_type = gr.outputs.Dataframe(type="unknown")
wrong_type.postprocess(0)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = dataframe_output.save_flagged(tmpdirname, "dataframe_output", output, None)
self.assertEqual(to_save, '[[2, true], [3, true], [4, false]]')
self.assertEqual(dataframe_output.restore_flagged(tmpdirname, to_save, None), {"data": [[2, True], [3, True], [4, False]]})
to_save = dataframe_output.save_flagged(
tmpdirname, "dataframe_output", output, None
)
self.assertEqual(to_save, "[[2, true], [3, true], [4, false]]")
self.assertEqual(
dataframe_output.restore_flagged(tmpdirname, to_save, None),
{"data": [[2, True], [3, True], [4, False]]},
)
def test_in_interface(self):
def check_odd(array):
return array % 2 == 0
iface = gr.Interface(check_odd, "numpy", "numpy")
self.assertEqual(
iface.process([[2, 3, 4]])[0][0],
{"data": [[True, False, True]]})
iface.process([[2, 3, 4]])[0][0], {"data": [[True, False, True]]}
)
class TestCarousel(unittest.TestCase):
def test_as_component(self):
carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease")
output = carousel_output.postprocess([["Hello World", "test/test_files/bus.png"],
["Bye World", "test/test_files/bus.png"]])
self.assertEqual(output, [['Hello World', gr.test_data.BASE64_IMAGE],
['Bye World', gr.test_data.BASE64_IMAGE]])
output = carousel_output.postprocess(
[
["Hello World", "test/test_files/bus.png"],
["Bye World", "test/test_files/bus.png"],
]
)
self.assertEqual(
output,
[
["Hello World", gr.test_data.BASE64_IMAGE],
["Bye World", gr.test_data.BASE64_IMAGE],
],
)
carousel_output = gr.outputs.Carousel("text", label="Disease")
output = carousel_output.postprocess([["Hello World"], ["Bye World"]])
self.assertEqual(output, [['Hello World'], ['Bye World']])
self.assertEqual(carousel_output.get_template_context(), {
'components': [{'name': 'textbox', 'label': None}],
'name': 'carousel',
'label': 'Disease'
})
self.assertEqual(output, [["Hello World"], ["Bye World"]])
self.assertEqual(
carousel_output.get_template_context(),
{
"components": [{"name": "textbox", "label": None}],
"name": "carousel",
"label": "Disease",
},
)
output = carousel_output.postprocess(["Hello World", "Bye World"])
self.assertEqual(output, [['Hello World'], ['Bye World']])
self.assertEqual(output, [["Hello World"], ["Bye World"]])
with self.assertRaises(ValueError):
carousel_output.postprocess('Hello World!')
carousel_output.postprocess("Hello World!")
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = carousel_output.save_flagged(tmpdirname, "carousel_output", output, None)
to_save = carousel_output.save_flagged(
tmpdirname, "carousel_output", output, None
)
self.assertEqual(to_save, '[["Hello World"], ["Bye World"]]')
def test_in_interface(self):
carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease")
def report(img):
results = []
for i, mode in enumerate(["Red", "Green", "Blue"]):
@@ -374,48 +473,82 @@ class TestCarousel(unittest.TestCase):
color_filter[i] = 1
results.append([mode, img * color_filter])
return results
iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output)
self.assertEqual(
iface.process([gr.test_data.BASE64_IMAGE])[0], [[['Red',
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAD0AAABECAIAAAC9Laq3AAAFFElEQVR4nO3aT2gcVRzA8U+MTZq6xZBKdYvFFYyCtFq0UO3FehEUe1E8+AeaUw+C2pPiyS14UQ9tvXlroZ5EqVgrKmq8aAQjVXvQNuIWYxtLIlsSog0tehgnndmdmZ3ZXdMU8j2Et+/NvPfN2/f7vTeT9PzjquSaKy3QJiveS8uK99Ky4r20rHgvLSveS8uK99JylXlf5CKuLu8pvmUOXHuFXfJRZyI0Dlju3nNMUG+qX77ef1NjKqV1OXpfZJLJMAQTWXbeU0xkGgcso3xSZ4yfkqTnOcaLkZplMd9pwRdwjFH+ildeYe/s4MMkHyXVx9bJKLUuSmVykRpjKdKnOMw8p1Juvzzfx3kQ7KJKpauWDUxSSwm+Gd7lR7CtaXkscnm+62HhELcy8v/M/TRj6RljntdC6WxS80nX7esc5wR/J7V+wTy/p09wAy3i8hBH2MMeBvM7xskOvjE+4k9uLtJn6/x9nr1UqKanqjSygw8HeJs/C3Yr/77Thv0kYynLbCb8OZFzeDAQKRfbL3PaT6UH3zyHqTJWcJqHeCbysZ19vqX9TynBN0aVb5BbepgBHmMvd0Xq2z+ftLFy3sudLgKGGOb1cGOJctl7C9cX6TSgpf0pDvADCkrvYF1662XvQfa3pS5ifyRSOcMB3mSCySK93cbzPJ55TWydjFDjlQ7s90Q+Hi6YLjDMS7zAcKsrG9f3INUO7E9HyoWkh0LXnLtPo3eNWsf2hRjgYV4qeFej9yd8whnE7bvOAMOh8SOsKXh7o3cZnI3UDFLlV3a1L5lAkIwfyUwaGTR63085qa8KB7tkP8TzuXVLbOKmpvpG7xvYmf7QUOnMfjNPszdHuggo8T5P8FbTabSd/bJS3H4I7Oa+IgMd5VVG2d90okz2rjHdqtNKbvttBXUXORApfxYWgieGZO+v+DJf15V0+yFuoxo/x+Xnc+rsYh8oMchWSqAn8f8hxhnnoYJPxzXqbGG0LdEGXuH78MzTQzWejpPnexMlvuJCjgEO8gGosKV9z0am4r0txFuTvfvZzhxf5xhggbP83K5fIr2cDMvHwSp+DB+UZOSTCrdzkvFWY2xC03x0SC+oMUoVbGWBGr8h+jz/Pfvib3x2MMM4F9iePsZ2Ku1ue4nG/fSGsxY8MdxDmT4qrEV0vu9OemfyKGVO8DGzScNcYJoN9HdsfA1rWBNO9r2RpmepsDmUjnkvhEf1QzxHjQv0s5NNnOZdxuP2ZzjKe62EekKVjAtWc138st2UGeQtRpq+z//y4BnOMstRSuwMm9dRpp8zjIfnrRJrmWWOPu7njnino5HyKj5ljsdTslMfffQkNa1jY8rv/J/3Jf7gHJdS7g/spznNNAv0sYHbk1bIoncPb/AheJLd8ctW0Z9ivJYKfUlNMW9F7Fuy6D3Gy2G5xLGw515Wp+SyATZG1nEasfeDvWzgxhT7GWaK2OMd8ADHOU8v/7A65asPvsCceSnhdw7sN1NOGmCGE2HUZvMX37GLUUbAqqbgWxyxzJ1Fkmnq+9iWc19nPevTu/gFofEgUhZGRvBl0OI9cob9Jc5yLt0++jxfD89xUVoGXwa5/i7Vnv1saFznIFvjxuUcwdepd0B++2Cv3ghGGOQ8D6Bg8GWQfP5uSXbGDDjJU2G5zDHWs6Gt4Zpp0zugpf1uvqPEEXYUD74MOvIOyLCf5RzbuKXjURrogndAs33nwZdB17wDLvEbs10Kvgy67L1k/Asi+GhgiYdDNAAAAABJRU5ErkJggg=='],
['Green',
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAD0AAABECAIAAAC9Laq3AAAFvElEQVR4nNXaT2gexx3G8U9iIyNbQS+oKDhE2G6DQkDCAhWKcrGwe4xooKU15BAFQ6CX4NxNqxLRXlqi5pZCiHJo4ubSgnwKldHJIgeBgnTJe6mEWpWKCCIUR0jEqIfVvt7V/nl3913J7XOwxjO7M1+tfs/M7G/2KYf+H/X0iY+wyEds19zrUyf4vJvMRohf5hX66un7ZLibzNFMa6qJvm7ube7xoN1lHdPXx73HPHNlbumAvibuee7xbaV7K9F3zP0Ff6ljuihJ3wF3jvkqqzB9Je6C5qusAvQluSuYr7Jy6ctw32O+qvkqK4O+GPcic/Wv1SWUoG/HfRLmq6wIfTb3Np+yfHpUhbTPr/i+sylte8wxf+pMbfUlS7yRyv1EzNdWmyzGDBbhfuLmS9VXLPLv49VnYZ1ZNk+dKV+7LGVOCWdhjSkGGeWZ0wPL1D6rrHCQeUkkTpo0/wfov2QxjzhQwpdPkH6TBb5Ja7rABBO8dlSRNg86dfoM8x3pJhNciNVlcAc6Bfpc88EVbqZUx7mvscZ6/JITog/Mt5TROsR1PmAovT3CfZUFMMvUCdOvsJRhvn5u8SNw/3h4tBTJ+zTCwiRrfMilxOVNPmGB3arEa3ycPWNc4N0QOlfZ+arJuuk3meOzjBkjcN6VzAd8TLm+xCSvMsMMO/GmIHJGGeJcbif55rvOTfr5RyHiQAXygw2mWOPX9CZal/iEJfbT7t0PL8iCnuYt+gvzhiqc12xk0x9k0K+ElUn1h/9mTBfpevi42C5OjqnBFLfTIiegX2GYHpayV75bXOc9tsoMvcUHlbkDNdrRZ+k6t0Ln9RfjXuUhd48nETrIfzdy4z5Vt4pOF0faYpXXUjIfEe5lvi7TaaBGO/ohpsMpuRT0XN4fJMK9w+1K6CL0P4lU9jPNNENcKdPbKndi0ZxUPE4+4jJTHdDPRP77VsnpAqu8zR1W21yYiO8dftMB/eVIuRT0VshabPVJcPfS2zF9KQXTxdvlbkpwv8AL9CBOX7seshoS342tKUWU4A62StGdaot+tjpkij4Hd0uuPqES3BvspuV91nmjJvot7hTGHeFvTB6vTnDv0Uxs/VrqkP5z3uPN9tPFkUYYZ4YG47GWSutlBfrg6f6O+2UGuswMC8wwEmvJ4O6lu12nxenvl8QNNBK+NwZaCwuN4MdhvDnQEC+VGeMSH3IolpX+E9NV9taPGfA674IpxllujZLKfZHRwrulli5xtSplUgP0Rp7FVFg+5DArTrbYZ4AzBQa4yiBY54t6mI+idCRSE3+XzeB+xAbneL7AGGd5prYPHY40wEZYXgbbfPz4fe9puMZsfGuBHb7ie1xsN8Z/UOwvU1wvgjUWwtV6kG9YaJ2btoJmOe3+lxgt8NR76uMe5A7vMxCvfJ8/868j2ni+KqkmP+BZzrGRlqw5Q1fGq2RZDfDz0CoBaytUfsp4pCl2nrbLz/gMjPFi5NjkeZ7lO7bYjtD3MMD53HdK9NGXm7zsY4KxeOUe7/CId3iTrlhjyL3EEns84Hyki9ahTw/PhfutfQ7o4hzf8c/cU6HgEQQbnk8Trd38mBsZy9wgLx8njnMfsJJ9NNGi76bBec7wiF22eZQN3cdA6JwbPGAv/iv9IoP4IuN5CdT4uWtB+uIaidi9m4EwWgaZzJg6+xjjuTYdp50X10XfAr3GMjtMsBw3X1Q9/DCjqRB3XfSDNHmdWW7zR36b8Yy7GGY4PZRTlZ2v6mKU4Qz6MUaY5+/xkE1qEuEmLhU623w5apdny6Hv5hVuZNOfj5S/TvNfO/PlqMx3MzmRs5dGH8TJXxnnMlf4ZYR4tL35auIOVIR+gAm2maWXButMMlbOfLVyB8p3baANpsNyH79nmNFKwyXU2feDben/QJNuZnm1tPlyVMf3mvmRc4aJtMOtzlTf97FJ+o7Nl6O6v0c+4AGb9ZgvRyf53fpJ6r8Fs9GodiVMlAAAAABJRU5ErkJggg=='],
['Blue',
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAD0AAABECAIAAAC9Laq3AAAFGklEQVR4nNXaMWwbVRzH8U8RUjMU1RIDaZeeGKADKJbaIWJozRYJWrKB6NBUKoIxS8VWUlaQCFurIpEOgMoCIgUhpIqwkA61SBUGmslZAkhUSlVEG5YyHOee7bvnu/MlgZ+s5Pm987uvn/+//7t79/bw0P9Qj23/KZa5wp16O92zneO9xkKK+AVe5slaut4m7jUWWctqqoe+du47XOPHYYeNSl8j932us1jmI9Xp6+K+zjX+qvTZKvSjc9/iah3pohz9KNwB81VWUfpq3AXNV1nD6ctyVzBfZYXoS3Ff43pV81VWNn1B7mUWa5+ry6iffij3dpivsh7RB7jv8DkrO4hVRFuc5+nHs9rus8j1nUYartu0OZPJvSvmG6oNltMGS3Pvuvky9QfL/NpXG3Ovs8DGjjOFdY92XkqIuTvM8QxHeGLnwHK1xc+s8nfeEek4WWPtP0B/m+UAcaxBX+4i/QZL/JnVtI9TnOJY/D4zD9px+mzzJXqTU30YedyxdoA+ZD7wDG8N1vZxH6fDem/lNtHH5mvntB7hJO9xNLM5zT3BElhgbpvpV2nnmO8A53gRfJV3rvS6TyMpzNDhYw4NHL/GZyxxrypxh0/zM8Y+ribQIQXWq2bqpt9gke9yMsbr7OPZgj9m2JeYYZp55rnb2xRHzhGeY2+wk7D5TvAWB7ldhDjWUG40mGM2h77NKs/n0IfNh8t5zgur+Lpmgzk6vMP+3qa/afMZbbZS9atJ5aAOJH9LQT8Ky7LrsY1i9LfzzbePC3zDCQ6WOfUG5ytzx2oMo/8hx3wn+IaTKAx9k3u8x0tJmma09e9GPn2ezpXM/Ru0OcanfQ1p7hU2y3QaqzGM/giXaaEk9Cf5Vyw93HeZrYQuRf9KqvIAl/mIozxbprebnOX9wBF9cXKFiLkR6OdTb98tn+PavMobwdRJVnzf5cII9FGqXAp6I2EttFaT58sR6UvpHhd5tdRnwvkkTV+74sk/Jr6UkzdzVSQPdukXquDl6ntwKZA0Aiqev9c5UxP9BmcL4zb5kpm+2rLzzoj033Oel4ami0RNWszTSGaAR3qYnj/L6BAf83Dg1dVPqdeJSqeYTpVnk8ISD0eZ54uP/VeVHlE0ewe0kxQa8b/K451Weuy7+prLySVrBR0Gp/kAzNFipXuWWrhjHWKipq4wzv7UWMylo7HI/U5xrQ+sAlTWGGimajbTzTuwj6OaxvktKa+ADvPd5x8x93EWei8tdl0R6LCUzNYRm3zJt3qf79zq/V12SxFTPMWl1JBHnKbBdPyV+tardlfjTKWWa6IU9xTT6WFNc2/SSnJLk4ilmi4GGzRSCTjzgNbAwLX4BczxZuLUf9WNkyWW2GKFsVQXt0ambxLxAHFo9mqMSSZzVo6aTPUR93E/4AY3khP0qTJ9g/Fk2CZZ6e0/xsokjphOLVn2q++5a+30hxNojDGeREuMlXkfHUd5FO4383lxXfRd0OOscDcJ2amstVJJlDcL9Bx6zj06fUSH0ywwy4fM5oxxN8ozQjlTgXl+jBaTOfQTHA5+sa5mkERqJnQzz3wBDb0+CdDv5Xj+F9OLsplFFoXNF1CpfTOByNnKaoro8AUtIg6kbtqjpLKiKuyvKkI/Tiu5nNhPg3WmmShlvoAq72cLuzbW71xMyg3eZnLwNrGaRtw/OJT+Ch3GknWScuYLqJb9muHIafBaTsKurhr3xw7SRyOaL6Da9yM/4Fs6tZgvoG3dt76N+gfaDbBaHMV3YgAAAABJRU5ErkJggg==']]])
iface.process([gr.test_data.BASE64_IMAGE])[0],
[
[
[
"Red",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAD0AAABECAIAAAC9Laq3AAAFFElEQVR4nO3aT2gcVRzA8U+MTZq6xZBKdYvFFYyCtFq0UO3FehEUe1E8+AeaUw+C2pPiyS14UQ9tvXlroZ5EqVgrKmq8aAQjVXvQNuIWYxtLIlsSog0tehgnndmdmZ3ZXdMU8j2Et+/NvPfN2/f7vTeT9PzjquSaKy3QJiveS8uK99Ky4r20rHgvLSveS8uK99JylXlf5CKuLu8pvmUOXHuFXfJRZyI0Dlju3nNMUG+qX77ef1NjKqV1OXpfZJLJMAQTWXbeU0xkGgcso3xSZ4yfkqTnOcaLkZplMd9pwRdwjFH+ildeYe/s4MMkHyXVx9bJKLUuSmVykRpjKdKnOMw8p1Juvzzfx3kQ7KJKpauWDUxSSwm+Gd7lR7CtaXkscnm+62HhELcy8v/M/TRj6RljntdC6WxS80nX7esc5wR/J7V+wTy/p09wAy3i8hBH2MMeBvM7xskOvjE+4k9uLtJn6/x9nr1UqKanqjSygw8HeJs/C3Yr/77Thv0kYynLbCb8OZFzeDAQKRfbL3PaT6UH3zyHqTJWcJqHeCbysZ19vqX9TynBN0aVb5BbepgBHmMvd0Xq2z+ftLFy3sudLgKGGOb1cGOJctl7C9cX6TSgpf0pDvADCkrvYF1662XvQfa3pS5ifyRSOcMB3mSCySK93cbzPJ55TWydjFDjlQ7s90Q+Hi6YLjDMS7zAcKsrG9f3INUO7E9HyoWkh0LXnLtPo3eNWsf2hRjgYV4qeFej9yd8whnE7bvOAMOh8SOsKXh7o3cZnI3UDFLlV3a1L5lAkIwfyUwaGTR63085qa8KB7tkP8TzuXVLbOKmpvpG7xvYmf7QUOnMfjNPszdHuggo8T5P8FbTabSd/bJS3H4I7Oa+IgMd5VVG2d90okz2rjHdqtNKbvttBXUXORApfxYWgieGZO+v+DJf15V0+yFuoxo/x+Xnc+rsYh8oMchWSqAn8f8hxhnnoYJPxzXqbGG0LdEGXuH78MzTQzWejpPnexMlvuJCjgEO8gGosKV9z0am4r0txFuTvfvZzhxf5xhggbP83K5fIr2cDMvHwSp+DB+UZOSTCrdzkvFWY2xC03x0SC+oMUoVbGWBGr8h+jz/Pfvib3x2MMM4F9iePsZ2Ku1ue4nG/fSGsxY8MdxDmT4qrEV0vu9OemfyKGVO8DGzScNcYJoN9HdsfA1rWBNO9r2RpmepsDmUjnkvhEf1QzxHjQv0s5NNnOZdxuP2ZzjKe62EekKVjAtWc138st2UGeQtRpq+z//y4BnOMstRSuwMm9dRpp8zjIfnrRJrmWWOPu7njnino5HyKj5ljsdTslMfffQkNa1jY8rv/J/3Jf7gHJdS7g/spznNNAv0sYHbk1bIoncPb/AheJLd8ctW0Z9ivJYKfUlNMW9F7Fuy6D3Gy2G5xLGw515Wp+SyATZG1nEasfeDvWzgxhT7GWaK2OMd8ADHOU8v/7A65asPvsCceSnhdw7sN1NOGmCGE2HUZvMX37GLUUbAqqbgWxyxzJ1Fkmnq+9iWc19nPevTu/gFofEgUhZGRvBl0OI9cob9Jc5yLt0++jxfD89xUVoGXwa5/i7Vnv1saFznIFvjxuUcwdepd0B++2Cv3ghGGOQ8D6Bg8GWQfP5uSXbGDDjJU2G5zDHWs6Gt4Zpp0zugpf1uvqPEEXYUD74MOvIOyLCf5RzbuKXjURrogndAs33nwZdB17wDLvEbs10Kvgy67L1k/Asi+GhgiYdDNAAAAABJRU5ErkJggg==",
],
[
"Green",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAD0AAABECAIAAAC9Laq3AAAFvElEQVR4nNXaT2gexx3G8U9iIyNbQS+oKDhE2G6DQkDCAhWKcrGwe4xooKU15BAFQ6CX4NxNqxLRXlqi5pZCiHJo4ubSgnwKldHJIgeBgnTJe6mEWpWKCCIUR0jEqIfVvt7V/nl3913J7XOwxjO7M1+tfs/M7G/2KYf+H/X0iY+wyEds19zrUyf4vJvMRohf5hX66un7ZLibzNFMa6qJvm7ube7xoN1lHdPXx73HPHNlbumAvibuee7xbaV7K9F3zP0Ff6ljuihJ3wF3jvkqqzB9Je6C5qusAvQluSuYr7Jy6ctw32O+qvkqK4O+GPcic/Wv1SWUoG/HfRLmq6wIfTb3Np+yfHpUhbTPr/i+sylte8wxf+pMbfUlS7yRyv1EzNdWmyzGDBbhfuLmS9VXLPLv49VnYZ1ZNk+dKV+7LGVOCWdhjSkGGeWZ0wPL1D6rrHCQeUkkTpo0/wfov2QxjzhQwpdPkH6TBb5Ja7rABBO8dlSRNg86dfoM8x3pJhNciNVlcAc6Bfpc88EVbqZUx7mvscZ6/JITog/Mt5TROsR1PmAovT3CfZUFMMvUCdOvsJRhvn5u8SNw/3h4tBTJ+zTCwiRrfMilxOVNPmGB3arEa3ycPWNc4N0QOlfZ+arJuuk3meOzjBkjcN6VzAd8TLm+xCSvMsMMO/GmIHJGGeJcbif55rvOTfr5RyHiQAXygw2mWOPX9CZal/iEJfbT7t0PL8iCnuYt+gvzhiqc12xk0x9k0K+ElUn1h/9mTBfpevi42C5OjqnBFLfTIiegX2GYHpayV75bXOc9tsoMvcUHlbkDNdrRZ+k6t0Ln9RfjXuUhd48nETrIfzdy4z5Vt4pOF0faYpXXUjIfEe5lvi7TaaBGO/ohpsMpuRT0XN4fJMK9w+1K6CL0P4lU9jPNNENcKdPbKndi0ZxUPE4+4jJTHdDPRP77VsnpAqu8zR1W21yYiO8dftMB/eVIuRT0VshabPVJcPfS2zF9KQXTxdvlbkpwv8AL9CBOX7seshoS342tKUWU4A62StGdaot+tjpkij4Hd0uuPqES3BvspuV91nmjJvot7hTGHeFvTB6vTnDv0Uxs/VrqkP5z3uPN9tPFkUYYZ4YG47GWSutlBfrg6f6O+2UGuswMC8wwEmvJ4O6lu12nxenvl8QNNBK+NwZaCwuN4MdhvDnQEC+VGeMSH3IolpX+E9NV9taPGfA674IpxllujZLKfZHRwrulli5xtSplUgP0Rp7FVFg+5DArTrbYZ4AzBQa4yiBY54t6mI+idCRSE3+XzeB+xAbneL7AGGd5prYPHY40wEZYXgbbfPz4fe9puMZsfGuBHb7ie1xsN8Z/UOwvU1wvgjUWwtV6kG9YaJ2btoJmOe3+lxgt8NR76uMe5A7vMxCvfJ8/868j2ni+KqkmP+BZzrGRlqw5Q1fGq2RZDfDz0CoBaytUfsp4pCl2nrbLz/gMjPFi5NjkeZ7lO7bYjtD3MMD53HdK9NGXm7zsY4KxeOUe7/CId3iTrlhjyL3EEns84Hyki9ahTw/PhfutfQ7o4hzf8c/cU6HgEQQbnk8Trd38mBsZy9wgLx8njnMfsJJ9NNGi76bBec7wiF22eZQN3cdA6JwbPGAv/iv9IoP4IuN5CdT4uWtB+uIaidi9m4EwWgaZzJg6+xjjuTYdp50X10XfAr3GMjtMsBw3X1Q9/DCjqRB3XfSDNHmdWW7zR36b8Yy7GGY4PZRTlZ2v6mKU4Qz6MUaY5+/xkE1qEuEmLhU623w5apdny6Hv5hVuZNOfj5S/TvNfO/PlqMx3MzmRs5dGH8TJXxnnMlf4ZYR4tL35auIOVIR+gAm2maWXButMMlbOfLVyB8p3baANpsNyH79nmNFKwyXU2feDben/QJNuZnm1tPlyVMf3mvmRc4aJtMOtzlTf97FJ+o7Nl6O6v0c+4AGb9ZgvRyf53fpJ6r8Fs9GodiVMlAAAAABJRU5ErkJggg==",
],
[
"Blue",
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAD0AAABECAIAAAC9Laq3AAAFGklEQVR4nNXaMWwbVRzH8U8RUjMU1RIDaZeeGKADKJbaIWJozRYJWrKB6NBUKoIxS8VWUlaQCFurIpEOgMoCIgUhpIqwkA61SBUGmslZAkhUSlVEG5YyHOee7bvnu/MlgZ+s5Pm987uvn/+//7t79/bw0P9Qj23/KZa5wp16O92zneO9xkKK+AVe5slaut4m7jUWWctqqoe+du47XOPHYYeNSl8j932us1jmI9Xp6+K+zjX+qvTZKvSjc9/iah3pohz9KNwB81VWUfpq3AXNV1nD6ctyVzBfZYXoS3Ff43pV81VWNn1B7mUWa5+ry6iffij3dpivsh7RB7jv8DkrO4hVRFuc5+nHs9rus8j1nUYartu0OZPJvSvmG6oNltMGS3Pvuvky9QfL/NpXG3Ovs8DGjjOFdY92XkqIuTvM8QxHeGLnwHK1xc+s8nfeEek4WWPtP0B/m+UAcaxBX+4i/QZL/JnVtI9TnOJY/D4zD9px+mzzJXqTU30YedyxdoA+ZD7wDG8N1vZxH6fDem/lNtHH5mvntB7hJO9xNLM5zT3BElhgbpvpV2nnmO8A53gRfJV3rvS6TyMpzNDhYw4NHL/GZyxxrypxh0/zM8Y+ribQIQXWq2bqpt9gke9yMsbr7OPZgj9m2JeYYZp55rnb2xRHzhGeY2+wk7D5TvAWB7ldhDjWUG40mGM2h77NKs/n0IfNh8t5zgur+Lpmgzk6vMP+3qa/afMZbbZS9atJ5aAOJH9LQT8Ky7LrsY1i9LfzzbePC3zDCQ6WOfUG5ytzx2oMo/8hx3wn+IaTKAx9k3u8x0tJmma09e9GPn2ezpXM/Ru0OcanfQ1p7hU2y3QaqzGM/giXaaEk9Cf5Vyw93HeZrYQuRf9KqvIAl/mIozxbprebnOX9wBF9cXKFiLkR6OdTb98tn+PavMobwdRJVnzf5cII9FGqXAp6I2EttFaT58sR6UvpHhd5tdRnwvkkTV+74sk/Jr6UkzdzVSQPdukXquDl6ntwKZA0Aiqev9c5UxP9BmcL4zb5kpm+2rLzzoj033Oel4ami0RNWszTSGaAR3qYnj/L6BAf83Dg1dVPqdeJSqeYTpVnk8ISD0eZ54uP/VeVHlE0ewe0kxQa8b/K451Weuy7+prLySVrBR0Gp/kAzNFipXuWWrhjHWKipq4wzv7UWMylo7HI/U5xrQ+sAlTWGGimajbTzTuwj6OaxvktKa+ADvPd5x8x93EWei8tdl0R6LCUzNYRm3zJt3qf79zq/V12SxFTPMWl1JBHnKbBdPyV+tardlfjTKWWa6IU9xTT6WFNc2/SSnJLk4ilmi4GGzRSCTjzgNbAwLX4BczxZuLUf9WNkyWW2GKFsVQXt0ambxLxAHFo9mqMSSZzVo6aTPUR93E/4AY3khP0qTJ9g/Fk2CZZ6e0/xsokjphOLVn2q++5a+30hxNojDGeREuMlXkfHUd5FO4383lxXfRd0OOscDcJ2amstVJJlDcL9Bx6zj06fUSH0ywwy4fM5oxxN8ozQjlTgXl+jBaTOfQTHA5+sa5mkERqJnQzz3wBDb0+CdDv5Xj+F9OLsplFFoXNF1CpfTOByNnKaoro8AUtIg6kbtqjpLKiKuyvKkI/Tiu5nNhPg3WmmShlvoAq72cLuzbW71xMyg3eZnLwNrGaRtw/OJT+Ch3GknWScuYLqJb9muHIafBaTsKurhr3xw7SRyOaL6Da9yM/4Fs6tZgvoG3dt76N+gfaDbBaHMV3YgAAAABJRU5ErkJggg==",
],
]
],
)
class TestTimeseries(unittest.TestCase):
def test_as_component(self):
timeseries_output = gr.outputs.Timeseries(label="Disease")
self.assertEqual(timeseries_output.get_template_context(), {
'x': None, 'y': None, 'name': 'timeseries', 'label': 'Disease'
})
data = {'Name': ['Tom', 'nick', 'krish', 'jack'], 'Age': [20, 21, 19, 18]}
self.assertEqual(
timeseries_output.get_template_context(),
{"x": None, "y": None, "name": "timeseries", "label": "Disease"},
)
data = {"Name": ["Tom", "nick", "krish", "jack"], "Age": [20, 21, 19, 18]}
df = pd.DataFrame(data)
self.assertEqual(timeseries_output.postprocess(df),{'headers': ['Name', 'Age'],
'data': [['Tom', 20], ['nick', 21], ['krish', 19],
['jack', 18]]})
self.assertEqual(
timeseries_output.postprocess(df),
{
"headers": ["Name", "Age"],
"data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
},
)
timeseries_output = gr.outputs.Timeseries(y="Age", label="Disease")
output = timeseries_output.postprocess(df)
self.assertEqual(output, {'headers': ['Name', 'Age'],
'data': [['Tom', 20], ['nick', 21], ['krish', 19],
['jack', 18]]})
self.assertEqual(
output,
{
"headers": ["Name", "Age"],
"data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
},
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = timeseries_output.save_flagged(tmpdirname, "timeseries_output", output, None)
self.assertEqual(to_save, '{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], '
'["jack", 18]]}')
self.assertEqual(timeseries_output.restore_flagged(tmpdirname, to_save, None), {"headers": ["Name", "Age"],
"data": [["Tom", 20], ["nick", 21],
["krish", 19], ["jack", 18]]})
to_save = timeseries_output.save_flagged(
tmpdirname, "timeseries_output", output, None
)
self.assertEqual(
to_save,
'{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], '
'["jack", 18]]}',
)
self.assertEqual(
timeseries_output.restore_flagged(tmpdirname, to_save, None),
{
"headers": ["Name", "Age"],
"data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
},
)
class TestNames(unittest.TestCase):
def test_no_duplicate_uncased_names(self): # this ensures that get_input_instance() works correctly when instantiating from components
def test_no_duplicate_uncased_names(
self,
): # this ensures that get_input_instance() works correctly when instantiating from components
subclasses = gr.outputs.OutputComponent.__subclasses__()
unique_subclasses_uncased = set([s.__name__.lower() for s in subclasses])
self.assertEqual(len(subclasses), len(unique_subclasses_uncased))
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
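
One detail worth calling out from the TestLabel hunks in this file: black collapses a bracketed literal onto a single line when it fits and the source has no trailing comma, while a trailing comma after the last element (the "magic trailing comma") forces it to stay exploded. A sketch of both behaviors:

    import black

    # no trailing comma: collapses, as y = {3: 0.7, 1: 0.2, 0: 0.1} did above
    collapsible = "y = {\n    3: 0.7,\n    1: 0.2,\n    0: 0.1\n}\n"
    print(black.format_str(collapsible, mode=black.Mode()), end="")

    # trailing comma after the last item: kept at one item per line
    pinned = "y = {\n    3: 0.7,\n    1: 0.2,\n    0: 0.1,\n}\n"
    print(black.format_str(pinned, mode=black.Mode()), end="")
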

View File

@@ -1,12 +1,13 @@
import unittest
import os
import pathlib
import gradio as gr
from PIL import Image
import tempfile
import unittest
import matplotlib.pyplot as plt
import numpy as np
import os
import tempfile
from PIL import Image
import gradio as gr
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@@ -14,23 +15,27 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class ImagePreprocessing(unittest.TestCase):
def test_decode_base64_to_image(self):
output_image = gr.processing_utils.decode_base64_to_image(
gr.test_data.BASE64_IMAGE)
gr.test_data.BASE64_IMAGE
)
self.assertIsInstance(output_image, Image.Image)
def test_encode_url_or_file_to_base64(self):
output_base64 = gr.processing_utils.encode_url_or_file_to_base64(
"test/test_data/test_image.png")
"test/test_data/test_image.png"
)
self.assertEquals(output_base64, gr.test_data.BASE64_IMAGE)
def test_encode_file_to_base64(self):
output_base64 = gr.processing_utils.encode_file_to_base64(
"test/test_data/test_image.png")
"test/test_data/test_image.png"
)
self.assertEquals(output_base64, gr.test_data.BASE64_IMAGE)
def test_encode_url_to_base64(self):
output_base64 = gr.processing_utils.encode_url_to_base64(
"https://raw.githubusercontent.com/gradio-app/gradio/master/test"
"/test_data/test_image.png")
"/test_data/test_image.png"
)
self.assertEqual(output_base64, gr.test_data.BASE64_IMAGE)
# def test_encode_plot_to_base64(self): # Commented out because this is throwing errors on Windows. Possibly due to different matplotlib behavior on Windows?
@@ -49,19 +54,21 @@ class ImagePreprocessing(unittest.TestCase):
img = Image.open("test/test_data/test_image.png")
new_img = gr.processing_utils.resize_and_crop(img, (20, 20))
self.assertEqual(new_img.size, (20, 20))
self.assertRaises(ValueError, gr.processing_utils.resize_and_crop,
**{'img': img, 'size': (20, 20), 'crop_type': 'test'})
self.assertRaises(
ValueError,
gr.processing_utils.resize_and_crop,
**{"img": img, "size": (20, 20), "crop_type": "test"}
)
class AudioPreprocessing(unittest.TestCase):
def test_audio_from_file(self):
audio = gr.processing_utils.audio_from_file(
"test/test_data/test_audio.wav")
audio = gr.processing_utils.audio_from_file("test/test_data/test_audio.wav")
self.assertEqual(audio[0], 22050)
self.assertIsInstance(audio[1], np.ndarray)
def test_audio_to_file(self):
audio = gr.processing_utils.audio_from_file(
"test/test_data/test_audio.wav")
audio = gr.processing_utils.audio_from_file("test/test_data/test_audio.wav")
gr.processing_utils.audio_to_file(audio[0], audio[1], "test_audio_to_file")
self.assertTrue(os.path.exists("test_audio_to_file"))
os.remove("test_audio_to_file")
@@ -69,31 +76,39 @@ class AudioPreprocessing(unittest.TestCase):
class OutputPreprocessing(unittest.TestCase):
def test_decode_base64_to_binary(self):
binary = gr.processing_utils.decode_base64_to_binary(
gr.test_data.BASE64_IMAGE)
binary = gr.processing_utils.decode_base64_to_binary(gr.test_data.BASE64_IMAGE)
self.assertEqual(gr.test_data.BINARY_IMAGE, binary)
def test_decode_base64_to_file(self):
temp_file = gr.processing_utils.decode_base64_to_file(
gr.test_data.BASE64_IMAGE)
temp_file = gr.processing_utils.decode_base64_to_file(gr.test_data.BASE64_IMAGE)
self.assertIsInstance(temp_file, tempfile._TemporaryFileWrapper)
def test_create_tmp_copy_of_file(self):
temp_file = gr.processing_utils.create_tmp_copy_of_file(
"test.txt")
temp_file = gr.processing_utils.create_tmp_copy_of_file("test.txt")
self.assertIsInstance(temp_file, tempfile._TemporaryFileWrapper)
float_dtype_list = [float, float, np.double, np.single, np.float32,
np.float64, 'float32', 'float64']
float_dtype_list = [
float,
float,
np.double,
np.single,
np.float32,
np.float64,
"float32",
"float64",
]
def test_float_conversion_dtype(self):
"""Test any convertion from a float dtype to an other."""
x = np.array([-1, 1])
# Test all combinations of dtypes conversions
dtype_combin = np.array(np.meshgrid(
OutputPreprocessing.float_dtype_list,
OutputPreprocessing.float_dtype_list)).T.reshape(-1, 2)
dtype_combin = np.array(
np.meshgrid(
OutputPreprocessing.float_dtype_list,
OutputPreprocessing.float_dtype_list,
)
).T.reshape(-1, 2)
for dtype_in, dtype_out in dtype_combin:
x = x.astype(dtype_in)
@@ -108,5 +123,6 @@ class OutputPreprocessing(unittest.TestCase):
y = gr.processing_utils._convert(x, np.floating)
assert y.dtype == x.dtype
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@@ -1,13 +1,14 @@
import io
import sys
import json
import os
import sys
import threading
import unittest
import unittest.mock as mock
from gradio import tunneling, networking, Interface
import threading
import paramiko
import os
import paramiko
from gradio import Interface, networking, tunneling
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
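
The import hunk above is isort's work rather than black's: imports are regrouped into standard-library, third-party, and first-party sections separated by blank lines, each section alphabetized, and the names inside a from-import sorted as well. A rough reproduction via isort's Python API (a sketch assuming isort 5; known_first_party stands in for whatever project configuration the repository actually uses):

    import isort

    src = (
        "import io\n"
        "import sys\n"
        "import json\n"
        "import os\n"
        "import threading\n"
        "import unittest\n"
        "import unittest.mock as mock\n"
        "from gradio import tunneling, networking, Interface\n"
        "import paramiko\n"
    )
    # prints the stdlib / paramiko / gradio grouping seen in the hunk
    print(isort.code(src, known_first_party=["gradio"]), end="")
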
@@ -27,13 +28,13 @@ class TestTunneling(unittest.TestCase):
io.close()
class TestVerbose(unittest.TestCase):
"""Not absolutely needed but just including them for the sake of completion."""
class TestVerbose(unittest.TestCase):
"""Not absolutely needed but just including them for the sake of completion."""
def setUp(self):
self.message = "print test"
self.capturedOutput = io.StringIO() # Create StringIO object
sys.stdout = self.capturedOutput # and redirect stdout.
self.capturedOutput = io.StringIO() # Create StringIO object
sys.stdout = self.capturedOutput # and redirect stdout.
def test_verbose_debug_true(self):
tunneling.verbose(self.message, debug_mode=True)
@@ -41,10 +42,11 @@ class TestVerbose(unittest.TestCase):
def test_verbose_debug_false(self):
tunneling.verbose(self.message, debug_mode=False)
self.assertEqual(self.capturedOutput.getvalue().strip(), '')
self.assertEqual(self.capturedOutput.getvalue().strip(), "")
def tearDown(self):
sys.stdout = sys.__stdout__
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
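
The only change black makes to setUp above is comment spacing: inline comments are padded to at least two spaces before the # and one space after it:

    import black

    src = "sys.stdout = self.capturedOutput # and redirect stdout.\n"
    print(black.format_str(src, mode=black.Mode()), end="")
    # sys.stdout = self.capturedOutput  # and redirect stdout.
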

View File

@@ -1,14 +1,15 @@
import ipaddress
import os
import pkg_resources
import requests
import unittest
import unittest.mock as mock
import warnings
import pkg_resources
import requests
import gradio
from gradio.utils import *
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@@ -21,7 +22,10 @@ class TestUtils(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
version_check()
self.assertEqual(str(w[-1].message), "gradio is not setup or installed properly. Unable to get version info.")
self.assertEqual(
str(w[-1].message),
"gradio is not setup or installed properly. Unable to get version info.",
)
@mock.patch("requests.get")
def test_should_warn_with_unable_to_parse(self, mock_get):
@@ -31,7 +35,9 @@ class TestUtils(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
version_check()
self.assertEqual(str(w[-1].message), "unable to parse version details from package URL.")
self.assertEqual(
str(w[-1].message), "unable to parse version details from package URL."
)
@mock.patch("requests.Response.json")
def test_should_warn_url_not_having_version(self, mock_json):
@@ -41,17 +47,18 @@ class TestUtils(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
version_check()
self.assertEqual(str(w[-1].message), "package URL does not contain version info.")
self.assertEqual(
str(w[-1].message), "package URL does not contain version info."
)
@mock.patch("requests.post")
def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post):
mock_post.side_effect = requests.ConnectionError()
error_analytics("placeholder", "placeholder")
mock_post.assert_called()
@mock.patch("requests.post")
@mock.patch("requests.post")
def test_error_analytics_successful(self, mock_post):
error_analytics("placeholder", "placeholder")
mock_post.assert_called()
@@ -61,29 +68,29 @@ class TestUtils(unittest.TestCase):
mock_post.side_effect = requests.ConnectionError()
launch_analytics(data={})
mock_post.assert_called()
@mock.patch("IPython.get_ipython")
def test_colab_check_no_ipython(self, mock_get_ipython):
mock_get_ipython.return_value = None
assert colab_check() is False
@mock.patch("IPython.get_ipython")
def test_ipython_check_import_fail(self, mock_get_ipython):
mock_get_ipython.side_effect = ImportError()
assert ipython_check() is False
@mock.patch("IPython.get_ipython")
def test_ipython_check_no_ipython(self, mock_get_ipython):
mock_get_ipython.return_value = None
assert ipython_check() is False
@mock.patch("requests.get")
def test_readme_to_html_doesnt_crash_on_connection_error(self, mock_get):
mock_get.side_effect = requests.ConnectionError()
readme_to_html("placeholder")
def test_readme_to_html_correct_parse(self):
readme_to_html("https://github.com/gradio-app/gradio/blob/master/README.md")
readme_to_html("https://github.com/gradio-app/gradio/blob/master/README.md")
class TestIPAddress(unittest.TestCase):
@@ -101,6 +108,5 @@ class TestIPAddress(unittest.TestCase):
self.assertEqual(ip, "No internet connection")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@@ -1,8 +1,10 @@
import os, sys
import subprocess
from jinja2 import Template
import json
import os
import re
import subprocess
import sys
from jinja2 import Template
GRADIO_DIR = os.path.join(os.getcwd(), os.pardir, os.pardir)
GRADIO_DEMO_DIR = os.path.join(GRADIO_DIR, "demo")
@@ -22,17 +24,19 @@ for demo_name in demos_to_run:
demo_folder = os.path.join(GRADIO_DEMO_DIR, demo_name)
requirements_file = os.path.join(demo_folder, "requirements.txt")
if os.path.exists(requirements_file):
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", requirements_file])
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", requirements_file]
)
setup_file = os.path.join(demo_folder, "setup.sh")
if os.path.exists(setup_file):
subprocess.check_call(["sh", setup_file])
subprocess.check_call(["sh", setup_file])
demo_port_sets.append((demo_name, port))
port += 1
with open("nginx_template.conf") as nginx_template_conf:
with open("nginx_template.conf") as nginx_template_conf:
template = Template(nginx_template_conf.read())
output_nginx_conf = template.render(demo_port_sets=demo_port_sets)
with open("nginx.conf", "w") as nginx_conf:
nginx_conf.write(output_nginx_conf)
with open("demos.json", "w") as demos_file:
json.dump(demo_port_sets, demos_file)
json.dump(demo_port_sets, demos_file)

View File

@@ -1,11 +1,12 @@
import time
import os
import requests
import json
import importlib
import json
import os
import subprocess
import sys
import threading
import subprocess
import time
import requests
LAUNCH_PERIOD = 60
GRADIO_DEMO_DIR = "../../demo"
@@ -14,18 +15,18 @@ sys.path.insert(0, GRADIO_DEMO_DIR)
with open("demos.json") as demos_file:
demo_port_sets = json.load(demos_file)
def launch_demo(demo_folder):
subprocess.call(f"cd {demo_folder} && python run.py", shell=True)
for demo_name, port in demo_port_sets:
demo_folder = os.path.join(GRADIO_DEMO_DIR, demo_name)
demo_file = os.path.join(demo_folder, "run.py")
with open(demo_file, 'r') as file:
with open(demo_file, "r") as file:
filedata = file.read()
filedata = filedata.replace(
f'iface.launch()',
f'iface.launch(server_port={port})')
with open(demo_file, 'w') as file:
filedata = filedata.replace(f"iface.launch()", f"iface.launch(server_port={port})")
with open(demo_file, "w") as file:
file.write(filedata)
demo_thread = threading.Thread(target=launch_demo, args=(demo_folder,))
demo_thread.start()

View File

@@ -1,81 +1,108 @@
import os
import json
from jinja2 import Template
import requests
import markdown2
import re
from gradio.inputs import InputComponent
from gradio.outputs import OutputComponent
from gradio.interface import Interface
import inspect
import json
import os
import re
import markdown2
import requests
from jinja2 import Template
from gradio.inputs import InputComponent
from gradio.interface import Interface
from gradio.outputs import OutputComponent
GRADIO_DIR = "../../"
GRADIO_GUIDES_DIR = os.path.join(GRADIO_DIR, "guides")
GRADIO_DEMO_DIR = os.path.join(GRADIO_DIR, "demo")
guide_names = [] # used for dropdown in navbar
guide_names = [] # used for dropdown in navbar
for guide in sorted(os.listdir(GRADIO_GUIDES_DIR)):
if "template" in guide or "getting_started" in guide:
continue
guide_name = guide[:-3]
pretty_guide_name = " ".join([word.capitalize().replace("Ml", "ML")
for word in guide_name.split("_")])
pretty_guide_name = " ".join(
[word.capitalize().replace("Ml", "ML") for word in guide_name.split("_")]
)
guide_names.append((guide_name, pretty_guide_name))
def render_index():
os.makedirs("generated", exist_ok=True)
with open("src/tweets.json", encoding='utf-8') as tweets_file:
with open("src/tweets.json", encoding="utf-8") as tweets_file:
tweets = json.load(tweets_file)
star_count = "{:,}".format(requests.get("https://api.github.com/repos/gradio-app/gradio"
).json()["stargazers_count"])
with open("src/index_template.html", encoding='utf-8') as template_file:
star_count = "{:,}".format(
requests.get("https://api.github.com/repos/gradio-app/gradio").json()[
"stargazers_count"
]
)
with open("src/index_template.html", encoding="utf-8") as template_file:
template = Template(template_file.read())
output_html = template.render(tweets=tweets, star_count=star_count, guide_names=guide_names)
with open(os.path.join("generated", "index.html"), "w", encoding='utf-8') as generated_template:
output_html = template.render(
tweets=tweets, star_count=star_count, guide_names=guide_names
)
with open(
os.path.join("generated", "index.html"), "w", encoding="utf-8"
) as generated_template:
generated_template.write(output_html)
def render_guides():
guides = []
for guide in os.listdir(GRADIO_GUIDES_DIR):
if "template" in guide:
continue
with open(os.path.join(GRADIO_GUIDES_DIR, guide), encoding='utf-8') as guide_file:
with open(
os.path.join(GRADIO_GUIDES_DIR, guide), encoding="utf-8"
) as guide_file:
guide_text = guide_file.read()
code_tags = re.findall(r'\{\{ code\["([^\s]*)"\] \}\}', guide_text)
demo_names = re.findall(r'\{\{ demos\["([^\s]*)"\] \}\}', guide_text)
code, demos = {}, {}
guide_text = guide_text.replace(
"website/src/assets", "/assets").replace(
"```python\n", "<pre><code class='lang-python'>").replace(
"```bash\n", "<pre><code class='lang-bash'>").replace(
"```directory\n", "<pre><code class='lang-bash'>").replace(
"```csv\n", "<pre><code class='lang-bash'>").replace(
"```", "</code></pre>")
guide_text = (
guide_text.replace("website/src/assets", "/assets")
.replace("```python\n", "<pre><code class='lang-python'>")
.replace("```bash\n", "<pre><code class='lang-bash'>")
.replace("```directory\n", "<pre><code class='lang-bash'>")
.replace("```csv\n", "<pre><code class='lang-bash'>")
.replace("```", "</code></pre>")
)
for code_src in code_tags:
with open(os.path.join(GRADIO_DEMO_DIR, code_src, "run.py")) as code_file:
python_code = code_file.read().replace(
'if __name__ == "__main__":\n iface.launch()', "iface.launch()")
code[code_src] = "<pre><code class='lang-python'>" + \
python_code + "</code></pre>"
'if __name__ == "__main__":\n iface.launch()', "iface.launch()"
)
code[code_src] = (
"<pre><code class='lang-python'>" + python_code + "</code></pre>"
)
for demo_name in demo_names:
demos[demo_name] = "<div id='interface_" + demo_name + "'></div>"
guide_template = Template(guide_text)
guide_output = guide_template.render(code=code, demos=demos)
output_html = markdown2.markdown(guide_output)
output_html = output_html.replace("<a ", "<a target='blank' ")
for match in re.findall(r'<h3>([A-Za-z0-9 ]*)<\/h3>', output_html):
for match in re.findall(r"<h3>([A-Za-z0-9 ]*)<\/h3>", output_html):
output_html = output_html.replace(
f"<h3>{match}</h3>", f"<h3 id={match.lower().replace(' ', '_')}>{match}</h3>")
f"<h3>{match}</h3>",
f"<h3 id={match.lower().replace(' ', '_')}>{match}</h3>",
)
os.makedirs("generated", exist_ok=True)
guide = guide[:-3]
os.makedirs(os.path.join(
"generated", guide), exist_ok=True)
with open("src/guides_template.html", encoding='utf-8') as general_template_file:
os.makedirs(os.path.join("generated", guide), exist_ok=True)
with open(
"src/guides_template.html", encoding="utf-8"
) as general_template_file:
general_template = Template(general_template_file.read())
with open(os.path.join("generated", guide, "index.html"), "w", encoding='utf-8') as generated_template:
output_html = general_template.render(template_html=output_html, demo_names=demo_names, guide_names=guide_names)
with open(
os.path.join("generated", guide, "index.html"), "w", encoding="utf-8"
) as generated_template:
output_html = general_template.render(
template_html=output_html,
demo_names=demo_names,
guide_names=guide_names,
)
generated_template.write(output_html)
def render_docs():
if os.path.exists("generated/colab_links.json"):
with open("generated/colab_links.json") as demo_links_file:
@@ -110,10 +137,15 @@ def render_docs():
name = line[:space_index]
documented_params.add(name)
params_doc.append(
(name, line[space_index+2:colon_index-1], line[colon_index+2:]))
(
name,
line[space_index + 2 : colon_index - 1],
line[colon_index + 2 :],
)
)
elif mode == "out":
colon_index = line.index(":")
return_doc.append((line[1:colon_index-1], line[colon_index+2:]))
return_doc.append((line[1 : colon_index - 1], line[colon_index + 2 :]))
params = inspect.getfullargspec(func)
param_set = []
for i in range(len(params.args)):
@@ -139,13 +171,20 @@ def render_docs():
inp["doc"] = "\n".join(doc_lines[:-2])
inp["type"] = doc_lines[-2].split("type: ")[-1]
inp["demos"] = doc_lines[-1][7:].split(", ")
_, inp["params"], inp["params_doc"], _ = get_function_documentation(cls.__init__)
_, inp["params"], inp["params_doc"], _ = get_function_documentation(
cls.__init__
)
inp["shortcuts"] = list(cls.get_shortcut_implementations().items())
if "interpret" in cls.__dict__:
inp["interpret"], inp["interpret_params"], inp["interpret_params_doc"], _ = get_function_documentation(
cls.interpret)
(
inp["interpret"],
inp["interpret_params"],
inp["interpret_params_doc"],
_,
) = get_function_documentation(cls.interpret)
_, _, _, inp["interpret_returns_doc"] = get_function_documentation(
cls.get_interpretation_scores)
cls.get_interpretation_scores
)
return inp
@@ -178,24 +217,32 @@ def render_docs():
os.makedirs("generated", exist_ok=True)
with open("src/docs_template.html") as template_file:
template = Template(template_file.read())
output_html = template.render(docs=docs, demo_links=demo_links, guide_names=guide_names)
output_html = template.render(
docs=docs, demo_links=demo_links, guide_names=guide_names
)
os.makedirs(os.path.join("generated", "docs"), exist_ok=True)
with open(os.path.join("generated", "docs", "index.html"), "w") as generated_template:
with open(
os.path.join("generated", "docs", "index.html"), "w"
) as generated_template:
generated_template.write(output_html)
def render_other():
os.makedirs("generated", exist_ok=True)
for template_filename in os.listdir("src/other_templates"):
with open(os.path.join("src/other_templates", template_filename)) as template_file:
with open(
os.path.join("src/other_templates", template_filename)
) as template_file:
template = Template(template_file.read())
output_html = template.render(guide_names=guide_names)
folder_name = template_filename[:-14]
os.makedirs(os.path.join("generated", folder_name), exist_ok=True)
with open(os.path.join("generated", folder_name, "index.html"), "w", encoding='utf-8') as generated_template:
with open(
os.path.join("generated", folder_name, "index.html"), "w", encoding="utf-8"
) as generated_template:
generated_template.write(output_html)
if __name__ == "__main__":
render_index()
render_guides()

View File

@@ -8,11 +8,15 @@ for folder in ("generated", "dist"):
for root, _, files in os.walk(folder):
for file in files:
if file.endswith(".html"):
with open(os.path.join(root, file), encoding='utf-8') as old_file:
with open(os.path.join(root, file), encoding="utf-8") as old_file:
content = old_file.read()
for old_name, new_name in style_map.items():
content = content.replace(old_name, new_name)
with open(os.path.join(root, file), "w", encoding='utf-8') as new_file:
with open(os.path.join(root, file), "w", encoding="utf-8") as new_file:
new_file.write(content)
elif file.startswith("style.") and file.endswith(".css") and file not in list(style_map.values()):
elif (
file.startswith("style.")
and file.endswith(".css")
and file not in list(style_map.values())
):
os.remove(os.path.join(root, file))

View File

@@ -1,5 +1,6 @@
import os
import json
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
@@ -23,17 +24,24 @@ GRADIO_DEMO_DIR = "../../demo/"
NOTEBOOK_TYPE = "application/vnd.google.colaboratory"
GOOGLE_FOLDER_TYPE = "application/vnd.google-apps.folder"
def run():
drive = GoogleDrive(gauth)
demo_links = {}
with open("colab_template.ipynb") as notebook_template_file:
notebook_template = notebook_template_file.read()
file_list = drive.ListFile({'q': f"title='Demos' and mimeType='{GOOGLE_FOLDER_TYPE}' and 'root' in parents and trashed=false"}).GetList()
file_list = drive.ListFile(
{
"q": f"title='Demos' and mimeType='{GOOGLE_FOLDER_TYPE}' and 'root' in parents and trashed=false"
}
).GetList()
if len(file_list) > 0:
demo_folder = file_list[0].metadata["id"]
else:
demo_folder_file = drive.CreateFile({'title' : "Demos", 'mimeType' : GOOGLE_FOLDER_TYPE})
demo_folder_file = drive.CreateFile(
{"title": "Demos", "mimeType": GOOGLE_FOLDER_TYPE}
)
demo_folder_file.Upload()
demo_folder = demo_folder_file.metadata["id"]
for demo_name in os.listdir(GRADIO_DEMO_DIR):
@@ -41,29 +49,40 @@ def run():
print("--- " + demo_name + " ---")
with open(os.path.join(GRADIO_DEMO_DIR, demo_name, "run.py")) as demo_file:
demo_content = demo_file.read()
demo_content = demo_content.replace('if __name__ == "__main__":\n iface.launch()', "iface.launch()")
demo_content = demo_content.replace(
'if __name__ == "__main__":\n iface.launch()', "iface.launch()"
)
lines = demo_content.split("/n")
demo_content = [line + "\n" if i != len(lines) - 1 else line for i, line in enumerate(lines)]
demo_content = [
line + "\n" if i != len(lines) - 1 else line
for i, line in enumerate(lines)
]
notebook = json.loads(notebook_template)
notebook["cells"][1]["source"] = demo_content
file_list = drive.ListFile({'q': f"title='{notebook_title}' and mimeType='{NOTEBOOK_TYPE}' and 'root' in parents and trashed=false"}).GetList()
file_list = drive.ListFile(
{
"q": f"title='{notebook_title}' and mimeType='{NOTEBOOK_TYPE}' and 'root' in parents and trashed=false"
}
).GetList()
if len(file_list) > 0:
drive_file = file_list[0]
else:
drive_file = drive.CreateFile({
'title': notebook_title,
'mimeType': NOTEBOOK_TYPE,
'parents': [{"id": demo_folder}]
})
drive_file = drive.CreateFile(
{
"title": notebook_title,
"mimeType": NOTEBOOK_TYPE,
"parents": [{"id": demo_folder}],
}
)
drive_file.SetContentString(json.dumps(notebook))
drive_file.Upload()
drive_file.InsertPermission({
'type': 'anyone',
'value': 'anyone',
'role': 'reader'})
demo_links[demo_name] = drive_file['alternateLink']
drive_file.InsertPermission(
{"type": "anyone", "value": "anyone", "role": "reader"}
)
demo_links[demo_name] = drive_file["alternateLink"]
with open("../.env", "w") as env_file:
env_file.write(f"COLAB_NOTEBOOK_LINKS='{json.dumps(demo_links)}'")
if __name__ == "__main__":
run()
run()
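
Taken together, every hunk in this commit should be reproducible by running the two formatters over the whole repository. A sketch of that invocation from the repo root (assuming both tools are installed; the exact versions and configuration used for this commit are not recorded in the diff):

    import subprocess

    # same pattern the demo setup script above uses for pip installs
    subprocess.check_call(["python", "-m", "black", "."])
    # isort is often run with --profile=black so the two tools agree on
    # wrapping; whether this repository configures that is an assumption
    subprocess.check_call(["python", "-m", "isort", "."])
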