mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Format The Codebase
- black formatting - isort formatting
This commit is contained in:
parent
7fc0c83beb
commit
cc0cff893f
@ -1,5 +1,6 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def calculator(num1, operation, num2):
|
||||
if operation == "add":
|
||||
return num1 + num2
|
||||
@ -10,7 +11,9 @@ def calculator(num1, operation, num2):
|
||||
elif operation == "divide":
|
||||
return num1 / num2
|
||||
|
||||
iface = gr.Interface(calculator,
|
||||
|
||||
iface = gr.Interface(
|
||||
calculator,
|
||||
["number", gr.inputs.Radio(["add", "subtract", "multiply", "divide"]), "number"],
|
||||
"number",
|
||||
examples=[
|
||||
|
@ -1,5 +1,6 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def calculator(num1, operation, num2):
|
||||
if operation == "add":
|
||||
return num1 + num2
|
||||
@ -10,10 +11,12 @@ def calculator(num1, operation, num2):
|
||||
elif operation == "divide":
|
||||
return num1 / num2
|
||||
|
||||
iface = gr.Interface(calculator,
|
||||
|
||||
iface = gr.Interface(
|
||||
calculator,
|
||||
["number", gr.inputs.Radio(["add", "subtract", "multiply", "divide"]), "number"],
|
||||
"number",
|
||||
live=True
|
||||
live=True,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,10 +1,12 @@
|
||||
import gradio as gr
|
||||
import random
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def chat(message, history):
|
||||
history = history or []
|
||||
if message.startswith("How many"):
|
||||
response = random.randint(1,10)
|
||||
response = random.randint(1, 10)
|
||||
elif message.startswith("How"):
|
||||
response = random.choice(["Great", "Good", "Okay", "Bad"])
|
||||
elif message.startswith("Where"):
|
||||
@ -19,11 +21,19 @@ def chat(message, history):
|
||||
html += "</div>"
|
||||
return html, history
|
||||
|
||||
iface = gr.Interface(chat, ["text", "state"], ["html", "state"], css="""
|
||||
|
||||
iface = gr.Interface(
|
||||
chat,
|
||||
["text", "state"],
|
||||
["html", "state"],
|
||||
css="""
|
||||
.chatbox {display:flex;flex-direction:column}
|
||||
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
|
||||
.user_msg {background-color:cornflowerblue;color:white;align-self:start}
|
||||
.resp_msg {background-color:lightgray;align-self:self-end}
|
||||
""", allow_screenshot=False, allow_flagging="never")
|
||||
""",
|
||||
allow_screenshot=False,
|
||||
allow_flagging="never",
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
iface.launch()
|
||||
|
@ -1,20 +1,25 @@
|
||||
import gradio as gr
|
||||
from difflib import Differ
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def diff_texts(text1, text2):
|
||||
d = Differ()
|
||||
return [
|
||||
(token[2:], token[0] if token[0] != " " else None) for token in d.compare(text1, text2)
|
||||
(token[2:], token[0] if token[0] != " " else None)
|
||||
for token in d.compare(text1, text2)
|
||||
]
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
diff_texts,
|
||||
[
|
||||
gr.inputs.Textbox(
|
||||
lines=3, default="The quick brown fox jumped over the lazy dogs."),
|
||||
gr.inputs.Textbox(
|
||||
lines=3, default="The fast brown fox jumps over lazy dogs."),
|
||||
lines=3, default="The quick brown fox jumped over the lazy dogs."
|
||||
),
|
||||
gr.inputs.Textbox(lines=3, default="The fast brown fox jumps over lazy dogs."),
|
||||
],
|
||||
gr.outputs.HighlightedText())
|
||||
gr.outputs.HighlightedText(),
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
iface.launch()
|
||||
|
@ -1,10 +1,14 @@
|
||||
import os
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import gradio
|
||||
import gradio as gr
|
||||
from urllib.request import urlretrieve
|
||||
import os
|
||||
|
||||
urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5")
|
||||
urlretrieve(
|
||||
"https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5"
|
||||
)
|
||||
model = tf.keras.models.load_model("mnist-model.h5")
|
||||
|
||||
|
||||
@ -13,11 +17,14 @@ def recognize_digit(image):
|
||||
prediction = model.predict(image).tolist()[0]
|
||||
return {str(i): prediction[i] for i in range(10)}
|
||||
|
||||
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False, source="canvas")
|
||||
|
||||
im = gradio.inputs.Image(
|
||||
shape=(28, 28), image_mode="L", invert_colors=False, source="canvas"
|
||||
)
|
||||
|
||||
iface = gr.Interface(
|
||||
recognize_digit,
|
||||
im,
|
||||
recognize_digit,
|
||||
im,
|
||||
gradio.outputs.Label(num_top_classes=3),
|
||||
live=True,
|
||||
interpretation="default",
|
||||
@ -28,4 +35,3 @@ iface.test_launch()
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from fpdf import FPDF
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
from fpdf import FPDF
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def disease_report(img, scan_for, generate_report):
|
||||
results = []
|
||||
for i, mode in enumerate(["Red", "Green", "Blue"]):
|
||||
@ -16,28 +19,30 @@ def disease_report(img, scan_for, generate_report):
|
||||
pdf = FPDF()
|
||||
pdf.add_page()
|
||||
pdf.set_font("Arial", size=15)
|
||||
pdf.cell(200, 10, txt="Disease Report",
|
||||
ln=1, align='C')
|
||||
pdf.cell(200, 10, txt="A Gradio Demo.",
|
||||
ln=2, align='C')
|
||||
pdf.cell(200, 10, txt="Disease Report", ln=1, align="C")
|
||||
pdf.cell(200, 10, txt="A Gradio Demo.", ln=2, align="C")
|
||||
pdf.output(report)
|
||||
return results, report if generate_report else None
|
||||
|
||||
iface = gr.Interface(disease_report,
|
||||
|
||||
iface = gr.Interface(
|
||||
disease_report,
|
||||
[
|
||||
"image",
|
||||
gr.inputs.CheckboxGroup(["Cancer", "Rash", "Heart Failure", "Stroke", "Diabetes", "Pneumonia"]),
|
||||
"checkbox"
|
||||
"image",
|
||||
gr.inputs.CheckboxGroup(
|
||||
["Cancer", "Rash", "Heart Failure", "Stroke", "Diabetes", "Pneumonia"]
|
||||
),
|
||||
"checkbox",
|
||||
],
|
||||
[
|
||||
gr.outputs.Carousel(["text", "image"], label="Disease"),
|
||||
gr.outputs.File(label="Report")
|
||||
gr.outputs.File(label="Report"),
|
||||
],
|
||||
title="Disease Report",
|
||||
description="Upload an Xray and select the diseases to scan for.",
|
||||
theme="grass",
|
||||
flagging_options=["good", "bad", "etc"],
|
||||
allow_flagging="auto"
|
||||
allow_flagging="auto",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,19 +1,25 @@
|
||||
import gradio as gr
|
||||
|
||||
def filter_records(records, gender):
|
||||
return records[records['gender'] == gender]
|
||||
|
||||
iface = gr.Interface(filter_records,
|
||||
[
|
||||
gr.inputs.Dataframe(headers=["name", "age", "gender"], datatype=["str", "number", "str"], row_count=5),
|
||||
gr.inputs.Dropdown(["M", "F", "O"])
|
||||
],
|
||||
"dataframe",
|
||||
description="Enter gender as 'M', 'F', or 'O' for other."
|
||||
def filter_records(records, gender):
|
||||
return records[records["gender"] == gender]
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
filter_records,
|
||||
[
|
||||
gr.inputs.Dataframe(
|
||||
headers=["name", "age", "gender"],
|
||||
datatype=["str", "number", "str"],
|
||||
row_count=5,
|
||||
),
|
||||
gr.inputs.Dropdown(["M", "F", "O"]),
|
||||
],
|
||||
"dataframe",
|
||||
description="Enter gender as 'M', 'F', or 'O' for other.",
|
||||
)
|
||||
|
||||
iface.test_launch()
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
import gradio as gr
|
||||
import random
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def plot_forecast(final_year, companies, noise, show_legend, point_style):
|
||||
start_year = 2020
|
||||
@ -22,18 +24,19 @@ def plot_forecast(final_year, companies, noise, show_legend, point_style):
|
||||
return fig
|
||||
|
||||
|
||||
iface = gr.Interface(plot_forecast,
|
||||
[
|
||||
gr.inputs.Radio([2025, 2030, 2035, 2040],
|
||||
label="Project to:"),
|
||||
gr.inputs.CheckboxGroup(
|
||||
["Google", "Microsoft", "Gradio"], label="Company Selection"),
|
||||
gr.inputs.Slider(1, 100, label="Noise Level"),
|
||||
gr.inputs.Checkbox(label="Show Legend"),
|
||||
gr.inputs.Dropdown(["cross", "line", "circle"], label="Style"),
|
||||
],
|
||||
gr.outputs.Image(plot=True, label="forecast")
|
||||
)
|
||||
iface = gr.Interface(
|
||||
plot_forecast,
|
||||
[
|
||||
gr.inputs.Radio([2025, 2030, 2035, 2040], label="Project to:"),
|
||||
gr.inputs.CheckboxGroup(
|
||||
["Google", "Microsoft", "Gradio"], label="Company Selection"
|
||||
),
|
||||
gr.inputs.Slider(1, 100, label="Noise Level"),
|
||||
gr.inputs.Checkbox(label="Show Legend"),
|
||||
gr.inputs.Dropdown(["cross", "line", "circle"], label="Style"),
|
||||
],
|
||||
gr.outputs.Image(plot=True, label="forecast"),
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -1,34 +1,38 @@
|
||||
import gradio as gr
|
||||
import pandas as pd
|
||||
import random
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def fraud_detector(card_activity, categories, sensitivity):
|
||||
activity_range = random.randint(0, 100)
|
||||
drop_columns = [column for column in ["retail", "food", "other"] if column not in categories]
|
||||
drop_columns = [
|
||||
column for column in ["retail", "food", "other"] if column not in categories
|
||||
]
|
||||
if len(drop_columns):
|
||||
card_activity.drop(columns=drop_columns, inplace=True)
|
||||
return card_activity, card_activity, {"fraud": activity_range / 100., "not fraud": 1 - activity_range / 100.}
|
||||
return (
|
||||
card_activity,
|
||||
card_activity,
|
||||
{"fraud": activity_range / 100.0, "not fraud": 1 - activity_range / 100.0},
|
||||
)
|
||||
|
||||
|
||||
iface = gr.Interface(fraud_detector,
|
||||
[
|
||||
gr.inputs.Timeseries(
|
||||
x="time",
|
||||
y=["retail", "food", "other"]
|
||||
),
|
||||
gr.inputs.CheckboxGroup(["retail", "food", "other"], default=[
|
||||
"retail", "food", "other"]),
|
||||
gr.inputs.Slider(1, 3)
|
||||
],
|
||||
[
|
||||
"dataframe",
|
||||
gr.outputs.Timeseries(
|
||||
x="time",
|
||||
y=["retail", "food", "other"]
|
||||
),
|
||||
gr.outputs.Label(label="Fraud Level"),
|
||||
]
|
||||
)
|
||||
iface = gr.Interface(
|
||||
fraud_detector,
|
||||
[
|
||||
gr.inputs.Timeseries(x="time", y=["retail", "food", "other"]),
|
||||
gr.inputs.CheckboxGroup(
|
||||
["retail", "food", "other"], default=["retail", "food", "other"]
|
||||
),
|
||||
gr.inputs.Slider(1, 3),
|
||||
],
|
||||
[
|
||||
"dataframe",
|
||||
gr.outputs.Timeseries(x="time", y=["retail", "food", "other"]),
|
||||
gr.outputs.Label(label="Fraud Level"),
|
||||
],
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -1,29 +1,42 @@
|
||||
import gradio as gr
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
|
||||
male_words, female_words = ["he", "his", "him"], ["she", "hers", "her"]
|
||||
|
||||
|
||||
def gender_of_sentence(sentence):
|
||||
male_count = len([word for word in sentence.split() if word.lower() in male_words])
|
||||
female_count = len([word for word in sentence.split() if word.lower() in female_words])
|
||||
total = max(male_count + female_count, 1)
|
||||
return {"male": male_count / total, "female": female_count / total}
|
||||
male_count = len([word for word in sentence.split() if word.lower() in male_words])
|
||||
female_count = len(
|
||||
[word for word in sentence.split() if word.lower() in female_words]
|
||||
)
|
||||
total = max(male_count + female_count, 1)
|
||||
return {"male": male_count / total, "female": female_count / total}
|
||||
|
||||
|
||||
def interpret_gender(sentence):
|
||||
result = gender_of_sentence(sentence)
|
||||
is_male = result["male"] > result["female"]
|
||||
interpretation = []
|
||||
for word in re.split('( )', sentence):
|
||||
score = 0
|
||||
token = word.lower()
|
||||
if (is_male and token in male_words) or (not is_male and token in female_words):
|
||||
score = 1
|
||||
elif (is_male and token in female_words) or (not is_male and token in male_words):
|
||||
score = -1
|
||||
interpretation.append((word, score))
|
||||
return interpretation
|
||||
result = gender_of_sentence(sentence)
|
||||
is_male = result["male"] > result["female"]
|
||||
interpretation = []
|
||||
for word in re.split("( )", sentence):
|
||||
score = 0
|
||||
token = word.lower()
|
||||
if (is_male and token in male_words) or (not is_male and token in female_words):
|
||||
score = 1
|
||||
elif (is_male and token in female_words) or (
|
||||
not is_male and token in male_words
|
||||
):
|
||||
score = -1
|
||||
interpretation.append((word, score))
|
||||
return interpretation
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
fn=gender_of_sentence, inputs=gr.inputs.Textbox(default="She went to his house to get her keys."),
|
||||
outputs="label", interpretation=interpret_gender, enable_queue=True)
|
||||
fn=gender_of_sentence,
|
||||
inputs=gr.inputs.Textbox(default="She went to his house to get her keys."),
|
||||
outputs="label",
|
||||
interpretation=interpret_gender,
|
||||
enable_queue=True,
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -1,15 +1,24 @@
|
||||
import gradio as gr
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
|
||||
male_words, female_words = ["he", "his", "him"], ["she", "hers", "her"]
|
||||
|
||||
|
||||
def gender_of_sentence(sentence):
|
||||
male_count = len([word for word in sentence.split() if word.lower() in male_words])
|
||||
female_count = len([word for word in sentence.split() if word.lower() in female_words])
|
||||
total = max(male_count + female_count, 1)
|
||||
return {"male": male_count / total, "female": female_count / total}
|
||||
male_count = len([word for word in sentence.split() if word.lower() in male_words])
|
||||
female_count = len(
|
||||
[word for word in sentence.split() if word.lower() in female_words]
|
||||
)
|
||||
total = max(male_count + female_count, 1)
|
||||
return {"male": male_count / total, "female": female_count / total}
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
fn=gender_of_sentence, inputs=gr.inputs.Textbox(default="She went to his house to get her keys."),
|
||||
outputs="label", interpretation="default")
|
||||
fn=gender_of_sentence,
|
||||
inputs=gr.inputs.Textbox(default="She went to his house to get her keys."),
|
||||
outputs="label",
|
||||
interpretation="default",
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -1,8 +1,10 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
import gradio as gr
|
||||
|
||||
notes = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
|
||||
|
||||
|
||||
def generate_tone(note, octave, duration):
|
||||
sr = 48000
|
||||
a4_freq, tones_from_a4 = 440, 12 * (octave - 4) + (note - 9)
|
||||
@ -14,12 +16,14 @@ def generate_tone(note, octave, duration):
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
generate_tone,
|
||||
generate_tone,
|
||||
[
|
||||
gr.inputs.Dropdown(notes, type="index"),
|
||||
gr.inputs.Slider(4, 6, step=1),
|
||||
gr.inputs.Textbox(type="number", default=1, label="Duration in seconds")
|
||||
], "audio")
|
||||
gr.inputs.Textbox(type="number", default=1, label="Duration in seconds"),
|
||||
],
|
||||
"audio",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -3,11 +3,14 @@ import gradio as gr
|
||||
title = "GPT-J-6B"
|
||||
|
||||
examples = [
|
||||
['The tower is 324 metres (1,063 ft) tall,'],
|
||||
["The tower is 324 metres (1,063 ft) tall,"],
|
||||
["The Moon's orbit around Earth has"],
|
||||
["The smooth Borealis basin in the Northern Hemisphere covers 40%"]
|
||||
["The smooth Borealis basin in the Northern Hemisphere covers 40%"],
|
||||
]
|
||||
|
||||
gr.Interface.load("huggingface/EleutherAI/gpt-j-6B",
|
||||
gr.Interface.load(
|
||||
"huggingface/EleutherAI/gpt-j-6B",
|
||||
inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
|
||||
title=title, examples=examples).launch();
|
||||
title=title,
|
||||
examples=examples,
|
||||
).launch()
|
||||
|
@ -1,7 +1,9 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def greet(name):
|
||||
return "Hello " + name + "!!"
|
||||
return "Hello " + name + "!!"
|
||||
|
||||
|
||||
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,11 +1,14 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def greet(name):
|
||||
return "Hello " + name + "!"
|
||||
return "Hello " + name + "!"
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
fn=greet,
|
||||
inputs=gr.inputs.Textbox(lines=2, placeholder="Name Here..."),
|
||||
outputs="text")
|
||||
fn=greet,
|
||||
inputs=gr.inputs.Textbox(lines=2, placeholder="Name Here..."),
|
||||
outputs="text",
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
iface.launch()
|
||||
|
@ -1,15 +1,17 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def greet(name, is_morning, temperature):
|
||||
salutation = "Good morning" if is_morning else "Good evening"
|
||||
greeting = "%s %s. It is %s degrees today" % (
|
||||
salutation, name, temperature)
|
||||
celsius = (temperature - 32) * 5 / 9
|
||||
return greeting, round(celsius, 2)
|
||||
salutation = "Good morning" if is_morning else "Good evening"
|
||||
greeting = "%s %s. It is %s degrees today" % (salutation, name, temperature)
|
||||
celsius = (temperature - 32) * 5 / 9
|
||||
return greeting, round(celsius, 2)
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
fn=greet,
|
||||
inputs=["text", "checkbox", gr.inputs.Slider(0, 100)],
|
||||
outputs=["text", "number"])
|
||||
fn=greet,
|
||||
inputs=["text", "checkbox", gr.inputs.Slider(0, 100)],
|
||||
outputs=["text", "number"],
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
iface.launch()
|
||||
|
@ -1,21 +1,28 @@
|
||||
import gradio as gr
|
||||
import tensorflow as tf
|
||||
import requests
|
||||
import tensorflow as tf
|
||||
|
||||
inception_net = tf.keras.applications.MobileNetV2() # load the model
|
||||
import gradio as gr
|
||||
|
||||
inception_net = tf.keras.applications.MobileNetV2() # load the model
|
||||
|
||||
# Download human-readable labels for ImageNet.
|
||||
response = requests.get("https://git.io/JJkYN")
|
||||
labels = response.text.split("\n")
|
||||
|
||||
|
||||
def classify_image(inp):
|
||||
inp = inp.reshape((-1, 224, 224, 3))
|
||||
inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
|
||||
prediction = inception_net.predict(inp).flatten()
|
||||
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||
inp = inp.reshape((-1, 224, 224, 3))
|
||||
inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
|
||||
prediction = inception_net.predict(inp).flatten()
|
||||
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||
|
||||
|
||||
image = gr.inputs.Image(shape=(224, 224))
|
||||
label = gr.outputs.Label(num_top_classes=3)
|
||||
|
||||
gr.Interface(fn=classify_image, inputs=image, outputs=label, examples=[
|
||||
["images/cheetah1.jpg"], ["images/lion.jpg"]]).launch()
|
||||
gr.Interface(
|
||||
fn=classify_image,
|
||||
inputs=image,
|
||||
outputs=label,
|
||||
examples=[["images/cheetah1.jpg"], ["images/lion.jpg"]],
|
||||
).launch()
|
||||
|
@ -1,21 +1,24 @@
|
||||
import torch
|
||||
import requests
|
||||
import gradio as gr
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
|
||||
import gradio as gr
|
||||
|
||||
model = torch.hub.load("pytorch/vision:v0.6.0", "resnet18", pretrained=True).eval()
|
||||
|
||||
# Download human-readable labels for ImageNet.
|
||||
response = requests.get("https://git.io/JJkYN")
|
||||
labels = response.text.split("\n")
|
||||
|
||||
|
||||
def predict(inp):
|
||||
inp = Image.fromarray(inp.astype('uint8'), 'RGB')
|
||||
inp = transforms.ToTensor()(inp).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
|
||||
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||
inp = Image.fromarray(inp.astype("uint8"), "RGB")
|
||||
inp = transforms.ToTensor()(inp).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
|
||||
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||
|
||||
|
||||
inputs = gr.inputs.Image()
|
||||
outputs = gr.outputs.Label(num_top_classes=3)
|
||||
|
@ -1,20 +1,25 @@
|
||||
import gradio as gr
|
||||
import tensorflow as tf
|
||||
import requests
|
||||
import tensorflow as tf
|
||||
|
||||
inception_net = tf.keras.applications.MobileNetV2() # load the model
|
||||
import gradio as gr
|
||||
|
||||
inception_net = tf.keras.applications.MobileNetV2() # load the model
|
||||
|
||||
# Download human-readable labels for ImageNet.
|
||||
response = requests.get("https://git.io/JJkYN")
|
||||
labels = response.text.split("\n")
|
||||
|
||||
|
||||
def classify_image(inp):
|
||||
inp = inp.reshape((-1, 224, 224, 3))
|
||||
inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
|
||||
prediction = inception_net.predict(inp).flatten()
|
||||
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||
inp = inp.reshape((-1, 224, 224, 3))
|
||||
inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
|
||||
prediction = inception_net.predict(inp).flatten()
|
||||
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||
|
||||
|
||||
image = gr.inputs.Image(shape=(224, 224))
|
||||
label = gr.outputs.Label(num_top_classes=3)
|
||||
|
||||
gr.Interface(fn=classify_image, inputs=image, outputs=label, interpretation="default").launch()
|
||||
gr.Interface(
|
||||
fn=classify_image, inputs=image, outputs=label, interpretation="default"
|
||||
).launch()
|
||||
|
@ -1,8 +1,10 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def image_mod(image):
|
||||
return image.rotate(45)
|
||||
|
||||
|
||||
iface = gr.Interface(image_mod, gr.inputs.Image(type="pil"), "image")
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -1,34 +1,84 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gradio as gr
|
||||
|
||||
CHOICES = ["foo", "bar", "baz"]
|
||||
JSONOBJ = """{"items":{"item":[{"id": "0001","type": null,"is_good": false,"ppu": 0.55,"batters":{"batter":[{ "id": "1001", "type": "Regular" },{ "id": "1002", "type": "Chocolate" },{ "id": "1003", "type": "Blueberry" },{ "id": "1004", "type": "Devil's Food" }]},"topping":[{ "id": "5001", "type": "None" },{ "id": "5002", "type": "Glazed" },{ "id": "5005", "type": "Sugar" },{ "id": "5007", "type": "Powdered Sugar" },{ "id": "5006", "type": "Chocolate with Sprinkles" },{ "id": "5003", "type": "Chocolate" },{ "id": "5004", "type": "Maple" }]}]}}"""
|
||||
|
||||
def fn(text1, text2, num, slider1, slider2, single_checkbox,
|
||||
checkboxes, radio, dropdown, im1, im2, im3, im4, video, audio1,
|
||||
audio2, file, df1, df2):
|
||||
|
||||
def fn(
|
||||
text1,
|
||||
text2,
|
||||
num,
|
||||
slider1,
|
||||
slider2,
|
||||
single_checkbox,
|
||||
checkboxes,
|
||||
radio,
|
||||
dropdown,
|
||||
im1,
|
||||
im2,
|
||||
im3,
|
||||
im4,
|
||||
video,
|
||||
audio1,
|
||||
audio2,
|
||||
file,
|
||||
df1,
|
||||
df2,
|
||||
):
|
||||
return (
|
||||
(text1 if single_checkbox else text2) +
|
||||
", selected:" + ", ".join(checkboxes), # Text
|
||||
(text1 if single_checkbox else text2)
|
||||
+ ", selected:"
|
||||
+ ", ".join(checkboxes), # Text
|
||||
{
|
||||
"positive": num / (num + slider1 + slider2),
|
||||
"negative": slider1 / (num + slider1 + slider2),
|
||||
"neutral": slider2 / (num + slider1 + slider2),
|
||||
}, # Label
|
||||
(audio1[0], np.flipud(audio1[1])) if audio1 is not None else "files/cantina.wav", # Audio
|
||||
(audio1[0], np.flipud(audio1[1]))
|
||||
if audio1 is not None
|
||||
else "files/cantina.wav", # Audio
|
||||
np.flipud(im1) if im1 is not None else "files/cheetah1.jpg", # Image
|
||||
video if video is not None else "files/world.mp4", # Video
|
||||
[("The", "art"), ("quick brown", "adj"), ("fox", "nn"), ("jumped", "vrb"), ("testing testing testing", None), ("over", "prp"), ("the", "art"), ("testing", None), ("lazy", "adj"), ("dogs", "nn"), (".", "punc")] + [(f"test {x}", f"test {x}") for x in range(10)], # HighlightedText
|
||||
[
|
||||
("The", "art"),
|
||||
("quick brown", "adj"),
|
||||
("fox", "nn"),
|
||||
("jumped", "vrb"),
|
||||
("testing testing testing", None),
|
||||
("over", "prp"),
|
||||
("the", "art"),
|
||||
("testing", None),
|
||||
("lazy", "adj"),
|
||||
("dogs", "nn"),
|
||||
(".", "punc"),
|
||||
]
|
||||
+ [(f"test {x}", f"test {x}") for x in range(10)], # HighlightedText
|
||||
# [("The testing testing testing", None), ("quick brown", 0.2), ("fox", 1), ("jumped", -1), ("testing testing testing", 0), ("over", 0), ("the", 0), ("testing", 0), ("lazy", 1), ("dogs", 0), (".", 1)] + [(f"test {x}", x/10) for x in range(-10, 10)], # HighlightedText
|
||||
[("The testing testing testing", None), ("over", 0.6), ("the", 0.2), ("testing", None), ("lazy", -.1), ("dogs", 0.4), (".", 0)] + [(f"test", x/10) for x in range(-10, 10)], # HighlightedText
|
||||
[
|
||||
("The testing testing testing", None),
|
||||
("over", 0.6),
|
||||
("the", 0.2),
|
||||
("testing", None),
|
||||
("lazy", -0.1),
|
||||
("dogs", 0.4),
|
||||
(".", 0),
|
||||
]
|
||||
+ [(f"test", x / 10) for x in range(-10, 10)], # HighlightedText
|
||||
json.loads(JSONOBJ), # JSON
|
||||
"<button style='background-color: red'>Click Me: " + radio + "</button>", # HTML
|
||||
"<button style='background-color: red'>Click Me: "
|
||||
+ radio
|
||||
+ "</button>", # HTML
|
||||
"files/titanic.csv",
|
||||
df1, # Dataframe
|
||||
np.random.randint(0, 10, (4,4)), # Dataframe
|
||||
[im for im in [im1, im2, im3, im4, "files/cheetah1.jpg"] if im is not None], # Carousel
|
||||
df2 # Timeseries
|
||||
np.random.randint(0, 10, (4, 4)), # Dataframe
|
||||
[
|
||||
im for im in [im1, im2, im3, im4, "files/cheetah1.jpg"] if im is not None
|
||||
], # Carousel
|
||||
df2, # Timeseries
|
||||
)
|
||||
|
||||
|
||||
@ -41,7 +91,9 @@ iface = gr.Interface(
|
||||
gr.inputs.Slider(minimum=10, maximum=20, default=15, label="Slider: 10 - 20"),
|
||||
gr.inputs.Slider(maximum=20, step=0.04, label="Slider: step @ 0.04"),
|
||||
gr.inputs.Checkbox(label="Checkbox"),
|
||||
gr.inputs.CheckboxGroup(label="CheckboxGroup", choices=CHOICES, default=CHOICES[0:2]),
|
||||
gr.inputs.CheckboxGroup(
|
||||
label="CheckboxGroup", choices=CHOICES, default=CHOICES[0:2]
|
||||
),
|
||||
gr.inputs.Radio(label="Radio", choices=CHOICES, default=CHOICES[2]),
|
||||
gr.inputs.Dropdown(label="Dropdown", choices=CHOICES),
|
||||
gr.inputs.Image(label="Image", optional=True),
|
||||
@ -69,15 +121,36 @@ iface = gr.Interface(
|
||||
gr.outputs.Dataframe(label="Dataframe"),
|
||||
gr.outputs.Dataframe(label="Numpy", type="numpy"),
|
||||
gr.outputs.Carousel("image", label="Carousel"),
|
||||
gr.outputs.Timeseries(x="time", y=["price", "value"], label="Timeseries")
|
||||
gr.outputs.Timeseries(x="time", y=["price", "value"], label="Timeseries"),
|
||||
],
|
||||
examples=[
|
||||
["the quick brown fox", "jumps over the lazy dog", 10, 12, 4, True, ["foo", "baz"], "baz", "bar", "files/cheetah1.jpg", "files/cheetah1.jpg", "files/cheetah1.jpg", "files/cheetah1.jpg", "files/world.mp4", "files/cantina.wav", "files/cantina.wav","files/titanic.csv", [[1,2,3],[3,4,5]], "files/time.csv"]
|
||||
] * 3,
|
||||
[
|
||||
"the quick brown fox",
|
||||
"jumps over the lazy dog",
|
||||
10,
|
||||
12,
|
||||
4,
|
||||
True,
|
||||
["foo", "baz"],
|
||||
"baz",
|
||||
"bar",
|
||||
"files/cheetah1.jpg",
|
||||
"files/cheetah1.jpg",
|
||||
"files/cheetah1.jpg",
|
||||
"files/cheetah1.jpg",
|
||||
"files/world.mp4",
|
||||
"files/cantina.wav",
|
||||
"files/cantina.wav",
|
||||
"files/titanic.csv",
|
||||
[[1, 2, 3], [3, 4, 5]],
|
||||
"files/time.csv",
|
||||
]
|
||||
]
|
||||
* 3,
|
||||
theme="huggingface",
|
||||
title="Kitchen Sink",
|
||||
description="Try out all the components!",
|
||||
article="Learn more about [Gradio](http://gradio.app)"
|
||||
article="Learn more about [Gradio](http://gradio.app)",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,14 +1,17 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def longest_word(text):
|
||||
words = text.split(" ")
|
||||
lengths = [len(word) for word in words]
|
||||
return max(lengths)
|
||||
|
||||
|
||||
ex = "The quick brown fox jumped over the lazy dog."
|
||||
|
||||
iface = gr.Interface(longest_word, "textbox", "label",
|
||||
interpretation="default", examples=[[ex]])
|
||||
iface = gr.Interface(
|
||||
longest_word, "textbox", "label", interpretation="default", examples=[[ex]]
|
||||
)
|
||||
|
||||
iface.test_launch()
|
||||
|
||||
|
@ -1,28 +1,32 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from scipy.fftpack import fft
|
||||
import matplotlib.pyplot as plt
|
||||
from math import log2, pow
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from scipy.fftpack import fft
|
||||
|
||||
import gradio as gr
|
||||
|
||||
A4 = 440
|
||||
C0 = A4*pow(2, -4.75)
|
||||
C0 = A4 * pow(2, -4.75)
|
||||
name = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
|
||||
|
||||
|
||||
|
||||
def get_pitch(freq):
|
||||
h = round(12*log2(freq/C0))
|
||||
h = round(12 * log2(freq / C0))
|
||||
n = h % 12
|
||||
return name[n]
|
||||
|
||||
|
||||
def main_note(audio):
|
||||
rate, y = audio
|
||||
if len(y.shape) == 2:
|
||||
y = y.T[0]
|
||||
N = len(y)
|
||||
T = 1.0 / rate
|
||||
x = np.linspace(0.0, N*T, N)
|
||||
x = np.linspace(0.0, N * T, N)
|
||||
yf = fft(y)
|
||||
yf2 = 2.0/N * np.abs(yf[0:N//2])
|
||||
xf = np.linspace(0.0, 1.0/(2.0*T), N//2)
|
||||
yf2 = 2.0 / N * np.abs(yf[0 : N // 2])
|
||||
xf = np.linspace(0.0, 1.0 / (2.0 * T), N // 2)
|
||||
|
||||
volume_per_pitch = {}
|
||||
total_volume = np.sum(yf2)
|
||||
@ -35,15 +39,17 @@ def main_note(audio):
|
||||
volume_per_pitch[pitch] += 1.0 * volume / total_volume
|
||||
return volume_per_pitch
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
main_note,
|
||||
"audio",
|
||||
main_note,
|
||||
"audio",
|
||||
gr.outputs.Label(num_top_classes=4),
|
||||
examples=[
|
||||
["audio/recording1.wav"],
|
||||
["audio/cantina.wav"],
|
||||
],
|
||||
interpretation="default")
|
||||
interpretation="default",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -1,6 +1,8 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def transpose(matrix):
|
||||
return matrix.T
|
||||
|
||||
@ -10,12 +12,12 @@ iface = gr.Interface(
|
||||
gr.inputs.Dataframe(type="numpy", datatype="number", row_count=5, col_count=3),
|
||||
"numpy",
|
||||
examples=[
|
||||
[np.zeros((3,3)).tolist()],
|
||||
[np.ones((2,2)).tolist()],
|
||||
[np.random.randint(0, 10, (3,10)).tolist()],
|
||||
[np.random.randint(0, 10, (10,3)).tolist()],
|
||||
[np.random.randint(0, 10, (10,10)).tolist()],
|
||||
]
|
||||
[np.zeros((3, 3)).tolist()],
|
||||
[np.ones((2, 2)).tolist()],
|
||||
[np.random.randint(0, 10, (3, 10)).tolist()],
|
||||
[np.random.randint(0, 10, (10, 3)).tolist()],
|
||||
[np.random.randint(0, 10, (10, 10)).tolist()],
|
||||
],
|
||||
)
|
||||
|
||||
iface.test_launch()
|
||||
|
@ -1,14 +1,17 @@
|
||||
import gradio as gr
|
||||
from math import sqrt
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from math import sqrt
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def outbreak(r, month, countries, social_distancing):
|
||||
months = ["January", "February", "March", "April", "May"]
|
||||
m = months.index(month)
|
||||
start_day = 30 * m
|
||||
final_day = 30 * (m + 1)
|
||||
x = np.arange(start_day, final_day+1)
|
||||
x = np.arange(start_day, final_day + 1)
|
||||
day_count = x.shape[0]
|
||||
pop_count = {"USA": 350, "Canada": 40, "Mexico": 300, "UK": 120}
|
||||
r = sqrt(r)
|
||||
@ -23,13 +26,18 @@ def outbreak(r, month, countries, social_distancing):
|
||||
plt.legend(countries)
|
||||
return plt
|
||||
|
||||
iface = gr.Interface(outbreak,
|
||||
|
||||
iface = gr.Interface(
|
||||
outbreak,
|
||||
[
|
||||
gr.inputs.Slider(1, 4, default=3.2, label="R"),
|
||||
gr.inputs.Dropdown(["January", "February", "March", "April", "May"], label="Month"),
|
||||
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries"),
|
||||
gr.inputs.Checkbox(label="Social Distancing?"),
|
||||
],
|
||||
"plot")
|
||||
gr.inputs.Slider(1, 4, default=3.2, label="R"),
|
||||
gr.inputs.Dropdown(
|
||||
["January", "February", "March", "April", "May"], label="Month"
|
||||
),
|
||||
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries"),
|
||||
gr.inputs.Checkbox(label="Social Distancing?"),
|
||||
],
|
||||
"plot",
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -9,18 +9,16 @@ import torch
|
||||
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForQuestionAnswering, BertTokenizer)
|
||||
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
|
||||
|
||||
from utils import (get_answer, input_to_squad_example,
|
||||
squad_examples_to_features, to_list)
|
||||
|
||||
RawResult = collections.namedtuple("RawResult",
|
||||
["unique_id", "start_logits", "end_logits"])
|
||||
|
||||
RawResult = collections.namedtuple(
|
||||
"RawResult", ["unique_id", "start_logits", "end_logits"]
|
||||
)
|
||||
|
||||
|
||||
class QA:
|
||||
|
||||
def __init__(self,model_path: str):
|
||||
def __init__(self, model_path: str):
|
||||
self.max_seq_length = 384
|
||||
self.doc_stride = 128
|
||||
self.do_lower_case = True
|
||||
@ -29,47 +27,71 @@ class QA:
|
||||
self.max_answer_length = 30
|
||||
self.model, self.tokenizer = self.load_model(model_path)
|
||||
if torch.cuda.is_available():
|
||||
self.device = 'cuda'
|
||||
self.device = "cuda"
|
||||
else:
|
||||
self.device = 'cpu'
|
||||
self.device = "cpu"
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
|
||||
def load_model(self,model_path: str,do_lower_case=False):
|
||||
def load_model(self, model_path: str, do_lower_case=False):
|
||||
config = BertConfig.from_pretrained(model_path + "/bert_config.json")
|
||||
tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=do_lower_case)
|
||||
model = BertForQuestionAnswering.from_pretrained(model_path, from_tf=False, config=config)
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
model_path, do_lower_case=do_lower_case
|
||||
)
|
||||
model = BertForQuestionAnswering.from_pretrained(
|
||||
model_path, from_tf=False, config=config
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
def predict(self,passage :str,question :str):
|
||||
example = input_to_squad_example(passage,question)
|
||||
features = squad_examples_to_features(example,self.tokenizer,self.max_seq_length,self.doc_stride,self.max_query_length)
|
||||
|
||||
def predict(self, passage: str, question: str):
|
||||
example = input_to_squad_example(passage, question)
|
||||
features = squad_examples_to_features(
|
||||
example,
|
||||
self.tokenizer,
|
||||
self.max_seq_length,
|
||||
self.doc_stride,
|
||||
self.max_query_length,
|
||||
)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor(
|
||||
[f.input_mask for f in features], dtype=torch.long
|
||||
)
|
||||
all_segment_ids = torch.tensor(
|
||||
[f.segment_ids for f in features], dtype=torch.long
|
||||
)
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_example_index)
|
||||
dataset = TensorDataset(
|
||||
all_input_ids, all_input_mask, all_segment_ids, all_example_index
|
||||
)
|
||||
eval_sampler = SequentialSampler(dataset)
|
||||
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=1)
|
||||
all_results = []
|
||||
for batch in eval_dataloader:
|
||||
batch = tuple(t.to(self.device) for t in batch)
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2]
|
||||
}
|
||||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"token_type_ids": batch[2],
|
||||
}
|
||||
example_indices = batch[3]
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
result = RawResult(unique_id = unique_id,
|
||||
start_logits = to_list(outputs[0][i]),
|
||||
end_logits = to_list(outputs[1][i]))
|
||||
result = RawResult(
|
||||
unique_id=unique_id,
|
||||
start_logits=to_list(outputs[0][i]),
|
||||
end_logits=to_list(outputs[1][i]),
|
||||
)
|
||||
all_results.append(result)
|
||||
answer = get_answer(example,features,all_results,self.n_best_size,self.max_answer_length,self.do_lower_case)
|
||||
answer = get_answer(
|
||||
example,
|
||||
features,
|
||||
all_results,
|
||||
self.n_best_size,
|
||||
self.max_answer_length,
|
||||
self.do_lower_case,
|
||||
)
|
||||
return answer
|
||||
|
@ -16,13 +16,15 @@ class SquadExample(object):
|
||||
For examples without an answer, the start and end position are -1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
qas_id,
|
||||
question_text,
|
||||
doc_tokens,
|
||||
orig_answer_text=None,
|
||||
start_position=None,
|
||||
end_position=None):
|
||||
def __init__(
|
||||
self,
|
||||
qas_id,
|
||||
question_text,
|
||||
doc_tokens,
|
||||
orig_answer_text=None,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
):
|
||||
self.qas_id = qas_id
|
||||
self.question_text = question_text
|
||||
self.doc_tokens = doc_tokens
|
||||
@ -36,8 +38,7 @@ class SquadExample(object):
|
||||
def __repr__(self):
|
||||
s = ""
|
||||
s += "qas_id: %s" % (self.qas_id)
|
||||
s += ", question_text: %s" % (
|
||||
self.question_text)
|
||||
s += ", question_text: %s" % (self.question_text)
|
||||
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
|
||||
if self.start_position:
|
||||
s += ", start_position: %d" % (self.start_position)
|
||||
@ -45,22 +46,25 @@ class SquadExample(object):
|
||||
s += ", end_position: %d" % (self.end_position)
|
||||
return s
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
"""A single set of features of data."""
|
||||
|
||||
def __init__(self,
|
||||
unique_id,
|
||||
example_index,
|
||||
doc_span_index,
|
||||
tokens,
|
||||
token_to_orig_map,
|
||||
token_is_max_context,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
paragraph_len,
|
||||
start_position=None,
|
||||
end_position=None,):
|
||||
def __init__(
|
||||
self,
|
||||
unique_id,
|
||||
example_index,
|
||||
doc_span_index,
|
||||
tokens,
|
||||
token_to_orig_map,
|
||||
token_is_max_context,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
paragraph_len,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
):
|
||||
self.unique_id = unique_id
|
||||
self.example_index = example_index
|
||||
self.doc_span_index = doc_span_index
|
||||
@ -74,6 +78,7 @@ class InputFeatures(object):
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
|
||||
|
||||
def input_to_squad_example(passage, question):
|
||||
"""Convert input passage and question into a SquadExample."""
|
||||
|
||||
@ -109,10 +114,12 @@ def input_to_squad_example(passage, question):
|
||||
doc_tokens=doc_tokens,
|
||||
orig_answer_text=orig_answer_text,
|
||||
start_position=start_position,
|
||||
end_position=end_position)
|
||||
|
||||
end_position=end_position,
|
||||
)
|
||||
|
||||
return example
|
||||
|
||||
|
||||
def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
"""Check if this is the 'max context' doc span for the token."""
|
||||
|
||||
@ -149,12 +156,23 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
|
||||
return cur_span_index == best_span_index
|
||||
|
||||
def squad_examples_to_features(example, tokenizer, max_seq_length,
|
||||
doc_stride, max_query_length,cls_token_at_end=False,
|
||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||
cls_token_segment_id=0, pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True):
|
||||
|
||||
def squad_examples_to_features(
|
||||
example,
|
||||
tokenizer,
|
||||
max_seq_length,
|
||||
doc_stride,
|
||||
max_query_length,
|
||||
cls_token_at_end=False,
|
||||
cls_token="[CLS]",
|
||||
sep_token="[SEP]",
|
||||
pad_token=0,
|
||||
sequence_a_segment_id=0,
|
||||
sequence_b_segment_id=1,
|
||||
cls_token_segment_id=0,
|
||||
pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True,
|
||||
):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
unique_id = 1000000000
|
||||
@ -188,7 +206,8 @@ def squad_examples_to_features(example, tokenizer, max_seq_length,
|
||||
# To deal with this we do a sliding window approach, where we take chunks
|
||||
# of the up to our max length with a stride of `doc_stride`.
|
||||
_DocSpan = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"DocSpan", ["start", "length"])
|
||||
"DocSpan", ["start", "length"]
|
||||
)
|
||||
doc_spans = []
|
||||
start_offset = 0
|
||||
while start_offset < len(all_doc_tokens):
|
||||
@ -225,8 +244,9 @@ def squad_examples_to_features(example, tokenizer, max_seq_length,
|
||||
split_token_index = doc_span.start + i
|
||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||
|
||||
is_max_context = _check_is_max_context(doc_spans, doc_span_index,
|
||||
split_token_index)
|
||||
is_max_context = _check_is_max_context(
|
||||
doc_spans, doc_span_index, split_token_index
|
||||
)
|
||||
token_is_max_context[len(tokens)] = is_max_context
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
@ -273,14 +293,18 @@ def squad_examples_to_features(example, tokenizer, max_seq_length,
|
||||
segment_ids=segment_ids,
|
||||
paragraph_len=paragraph_len,
|
||||
start_position=start_position,
|
||||
end_position=end_position))
|
||||
end_position=end_position,
|
||||
)
|
||||
)
|
||||
unique_id += 1
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def to_list(tensor):
|
||||
return tensor.detach().cpu().tolist()
|
||||
|
||||
|
||||
def _get_best_indexes(logits, n_best_size):
|
||||
"""Get the n-best logits from a list."""
|
||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||
@ -292,7 +316,11 @@ def _get_best_indexes(logits, n_best_size):
|
||||
best_indexes.append(index_and_score[i][0])
|
||||
return best_indexes
|
||||
|
||||
RawResult = collections.namedtuple("RawResult",["unique_id", "start_logits", "end_logits"])
|
||||
|
||||
RawResult = collections.namedtuple(
|
||||
"RawResult", ["unique_id", "start_logits", "end_logits"]
|
||||
)
|
||||
|
||||
|
||||
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
"""Project the tokenized prediction back to the original text."""
|
||||
@ -376,9 +404,10 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
if orig_end_position is None:
|
||||
return orig_text
|
||||
|
||||
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
||||
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
||||
return output_text
|
||||
|
||||
|
||||
def _compute_softmax(scores):
|
||||
"""Compute softmax probability over raw logits."""
|
||||
if not scores:
|
||||
@ -401,17 +430,22 @@ def _compute_softmax(scores):
|
||||
probs.append(score / total_sum)
|
||||
return probs
|
||||
|
||||
def get_answer(example, features, all_results, n_best_size,
|
||||
max_answer_length, do_lower_case):
|
||||
|
||||
def get_answer(
|
||||
example, features, all_results, n_best_size, max_answer_length, do_lower_case
|
||||
):
|
||||
example_index_to_features = collections.defaultdict(list)
|
||||
for feature in features:
|
||||
example_index_to_features[feature.example_index].append(feature)
|
||||
|
||||
|
||||
unique_id_to_result = {}
|
||||
for result in all_results:
|
||||
unique_id_to_result[result.unique_id] = result
|
||||
|
||||
_PrelimPrediction = collections.namedtuple( "PrelimPrediction",["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
|
||||
|
||||
_PrelimPrediction = collections.namedtuple(
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
|
||||
)
|
||||
|
||||
example_index = 0
|
||||
features = example_index_to_features[example_index]
|
||||
@ -448,10 +482,16 @@ def get_answer(example, features, all_results, n_best_size,
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_logit=result.start_logits[start_index],
|
||||
end_logit=result.end_logits[end_index]))
|
||||
prelim_predictions = sorted(prelim_predictions,key=lambda x: (x.start_logit + x.end_logit),reverse=True)
|
||||
_NbestPrediction = collections.namedtuple("NbestPrediction",
|
||||
["text", "start_logit", "end_logit","start_index","end_index"])
|
||||
end_logit=result.end_logits[end_index],
|
||||
)
|
||||
)
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True
|
||||
)
|
||||
_NbestPrediction = collections.namedtuple(
|
||||
"NbestPrediction",
|
||||
["text", "start_logit", "end_logit", "start_index", "end_index"],
|
||||
)
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
for pred in prelim_predictions:
|
||||
@ -461,10 +501,10 @@ def get_answer(example, features, all_results, n_best_size,
|
||||
orig_doc_start = -1
|
||||
orig_doc_end = -1
|
||||
if pred.start_index > 0: # this is a non-null prediction
|
||||
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
||||
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
||||
tok_text = " ".join(tok_tokens)
|
||||
|
||||
# De-tokenize WordPieces that have been split off.
|
||||
@ -476,7 +516,7 @@ def get_answer(example, features, all_results, n_best_size,
|
||||
tok_text = " ".join(tok_text.split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
|
||||
final_text = get_final_text(tok_text, orig_text,do_lower_case)
|
||||
final_text = get_final_text(tok_text, orig_text, do_lower_case)
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
|
||||
@ -491,11 +531,20 @@ def get_answer(example, features, all_results, n_best_size,
|
||||
start_logit=pred.start_logit,
|
||||
end_logit=pred.end_logit,
|
||||
start_index=orig_doc_start,
|
||||
end_index=orig_doc_end))
|
||||
end_index=orig_doc_end,
|
||||
)
|
||||
)
|
||||
|
||||
if not nbest:
|
||||
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0,start_index=-1,
|
||||
end_index=-1))
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text="empty",
|
||||
start_logit=0.0,
|
||||
end_logit=0.0,
|
||||
start_index=-1,
|
||||
end_index=-1,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(nbest) >= 1
|
||||
|
||||
@ -504,11 +553,12 @@ def get_answer(example, features, all_results, n_best_size,
|
||||
total_scores.append(entry.start_logit + entry.end_logit)
|
||||
|
||||
probs = _compute_softmax(total_scores)
|
||||
|
||||
answer = {"answer" : nbest[0].text,
|
||||
"start" : nbest[0].start_index,
|
||||
"end" : nbest[0].end_index,
|
||||
"confidence" : probs[0],
|
||||
"document" : example.doc_tokens
|
||||
}
|
||||
return answer
|
||||
|
||||
answer = {
|
||||
"answer": nbest[0].text,
|
||||
"start": nbest[0].start_index,
|
||||
"end": nbest[0].end_index,
|
||||
"confidence": probs[0],
|
||||
"document": example.doc_tokens,
|
||||
}
|
||||
return answer
|
||||
|
@ -1,13 +1,24 @@
|
||||
import gradio as gr
|
||||
|
||||
examples = [
|
||||
["The Amazon rainforest is a moist broadleaf forest that covers most of the Amazon basin of South America",
|
||||
"Which continent is the Amazon rainforest in?"]
|
||||
[
|
||||
"The Amazon rainforest is a moist broadleaf forest that covers most of the Amazon basin of South America",
|
||||
"Which continent is the Amazon rainforest in?",
|
||||
]
|
||||
]
|
||||
|
||||
gr.Interface.load("huggingface/deepset/roberta-base-squad2",
|
||||
inputs=[gr.inputs.Textbox(lines=5, label="Context", placeholder="Type a sentence or paragraph here."),
|
||||
gr.inputs.Textbox(lines=2, label="Question", placeholder="Ask a question based on the context.")],
|
||||
outputs=[gr.outputs.Textbox(label="Answer"),
|
||||
gr.outputs.Label(label="Probability")],
|
||||
examples=examples).launch()
|
||||
gr.Interface.load(
|
||||
"huggingface/deepset/roberta-base-squad2",
|
||||
inputs=[
|
||||
gr.inputs.Textbox(
|
||||
lines=5, label="Context", placeholder="Type a sentence or paragraph here."
|
||||
),
|
||||
gr.inputs.Textbox(
|
||||
lines=2,
|
||||
label="Question",
|
||||
placeholder="Ask a question based on the context.",
|
||||
),
|
||||
],
|
||||
outputs=[gr.outputs.Textbox(label="Answer"), gr.outputs.Label(label="Probability")],
|
||||
examples=examples,
|
||||
).launch()
|
||||
|
@ -1,10 +1,13 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def reverse_audio(audio):
|
||||
sr, data = audio
|
||||
return (sr, np.flipud(data))
|
||||
|
||||
|
||||
iface = gr.Interface(reverse_audio, "microphone", "audio", examples="audio")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,32 +1,37 @@
|
||||
import gradio as gr
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def sales_projections(employee_data):
|
||||
sales_data = employee_data.iloc[:, 1:4].astype("int").to_numpy()
|
||||
regression_values = np.apply_along_axis(lambda row:
|
||||
np.array(np.poly1d(np.polyfit([0,1,2], row, 2))), 0, sales_data)
|
||||
projected_months = np.repeat(np.expand_dims(
|
||||
np.arange(3,12), 0), len(sales_data), axis=0)
|
||||
projected_values = np.array([
|
||||
month * month * regression[0] + month * regression[1] + regression[2]
|
||||
for month, regression in zip(projected_months, regression_values)])
|
||||
regression_values = np.apply_along_axis(
|
||||
lambda row: np.array(np.poly1d(np.polyfit([0, 1, 2], row, 2))), 0, sales_data
|
||||
)
|
||||
projected_months = np.repeat(
|
||||
np.expand_dims(np.arange(3, 12), 0), len(sales_data), axis=0
|
||||
)
|
||||
projected_values = np.array(
|
||||
[
|
||||
month * month * regression[0] + month * regression[1] + regression[2]
|
||||
for month, regression in zip(projected_months, regression_values)
|
||||
]
|
||||
)
|
||||
plt.plot(projected_values.T)
|
||||
plt.legend(employee_data["Name"])
|
||||
return employee_data, plt.gcf(), regression_values
|
||||
|
||||
iface = gr.Interface(sales_projections,
|
||||
|
||||
iface = gr.Interface(
|
||||
sales_projections,
|
||||
gr.inputs.Dataframe(
|
||||
headers=["Name", "Jan Sales", "Feb Sales", "Mar Sales"],
|
||||
default=[["Jon", 12, 14, 18], ["Alice", 14, 17, 2], ["Sana", 8, 9.5, 12]]
|
||||
default=[["Jon", 12, 14, 18], ["Alice", 14, 17, 2], ["Sana", 8, 9.5, 12]],
|
||||
),
|
||||
[
|
||||
"dataframe",
|
||||
"plot",
|
||||
"numpy"
|
||||
],
|
||||
description="Enter sales figures for employees to predict sales trajectory over year."
|
||||
["dataframe", "plot", "numpy"],
|
||||
description="Enter sales figures for employees to predict sales trajectory over year.",
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -1,5 +1,6 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def sentence_builder(quantity, animal, place, activity_list, morning):
|
||||
return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}"""
|
||||
|
||||
|
@ -1,14 +1,18 @@
|
||||
import gradio as gr
|
||||
import nltk
|
||||
from nltk.sentiment.vader import SentimentIntensityAnalyzer
|
||||
nltk.download('vader_lexicon')
|
||||
|
||||
import gradio as gr
|
||||
|
||||
nltk.download("vader_lexicon")
|
||||
sid = SentimentIntensityAnalyzer()
|
||||
|
||||
|
||||
def sentiment_analysis(text):
|
||||
scores = sid.polarity_scores(text)
|
||||
del scores["compound"]
|
||||
return scores
|
||||
|
||||
|
||||
iface = gr.Interface(sentiment_analysis, "textbox", "label", interpretation="default")
|
||||
|
||||
iface.test_launch()
|
||||
|
@ -1,15 +1,18 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def sepia(input_img):
|
||||
sepia_filter = np.array([[.393, .769, .189],
|
||||
[.349, .686, .168],
|
||||
[.272, .534, .131]])
|
||||
sepia_img = input_img.dot(sepia_filter.T)
|
||||
sepia_img /= sepia_img.max()
|
||||
return sepia_img
|
||||
sepia_filter = np.array(
|
||||
[[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
|
||||
)
|
||||
sepia_img = input_img.dot(sepia_filter.T)
|
||||
sepia_img /= sepia_img.max()
|
||||
return sepia_img
|
||||
|
||||
|
||||
iface = gr.Interface(sepia, gr.inputs.Image(shape=(200, 200)), "image")
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
iface.launch()
|
||||
|
@ -1,14 +1,17 @@
|
||||
import gradio as gr
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def spectrogram(audio):
|
||||
sr, data = audio
|
||||
if len(data.shape) == 2:
|
||||
data = np.mean(data, axis=0)
|
||||
frequencies, times, spectrogram_data = signal.spectrogram(data, sr, window="hamming")
|
||||
frequencies, times, spectrogram_data = signal.spectrogram(
|
||||
data, sr, window="hamming"
|
||||
)
|
||||
plt.pcolormesh(times, frequencies, np.log10(spectrogram_data))
|
||||
return plt
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
import gradio as gr
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def stock_forecast(final_year, companies, noise, show_legend, point_style):
|
||||
start_year = 2020
|
||||
@ -28,8 +29,10 @@ iface = gr.Interface(
|
||||
gr.inputs.CheckboxGroup(["Google", "Microsoft", "Gradio"]),
|
||||
gr.inputs.Slider(1, 100),
|
||||
"checkbox",
|
||||
gr.inputs.Dropdown(["cross", "line", "circle"], label="Style")],
|
||||
gr.outputs.Image(plot=True, label="forecast"))
|
||||
gr.inputs.Dropdown(["cross", "line", "circle"], label="Style"),
|
||||
],
|
||||
gr.outputs.Image(plot=True, label="forecast"),
|
||||
)
|
||||
|
||||
iface.test_launch()
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,13 +1,8 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def tax_calculator(income, marital_status, assets):
|
||||
tax_brackets = [
|
||||
(10, 0),
|
||||
(25, 8),
|
||||
(60, 12),
|
||||
(120, 20),
|
||||
(250, 30)
|
||||
]
|
||||
tax_brackets = [(10, 0), (25, 8), (60, 12), (120, 20), (250, 30)]
|
||||
total_deductible = sum(assets[assets["Deduct"]]["Cost"])
|
||||
taxable_income = income - total_deductible
|
||||
|
||||
@ -15,31 +10,32 @@ def tax_calculator(income, marital_status, assets):
|
||||
for bracket, rate in tax_brackets:
|
||||
if taxable_income > bracket:
|
||||
total_tax += (taxable_income - bracket) * rate / 100
|
||||
|
||||
|
||||
if marital_status == "Married":
|
||||
total_tax *= 0.75
|
||||
elif marital_status == "Divorced":
|
||||
total_tax *= 0.8
|
||||
|
||||
|
||||
return round(total_tax)
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
tax_calculator,
|
||||
tax_calculator,
|
||||
[
|
||||
"number",
|
||||
gr.inputs.Radio(["Single", "Married", "Divorced"]),
|
||||
gr.inputs.Dataframe(
|
||||
headers=["Item", "Cost", "Deduct"],
|
||||
headers=["Item", "Cost", "Deduct"],
|
||||
datatype=["str", "number", "bool"],
|
||||
label="Assets Purchased this Year"
|
||||
)
|
||||
label="Assets Purchased this Year",
|
||||
),
|
||||
],
|
||||
"number",
|
||||
# interpretation="default", # Removed interpretation for dataframes
|
||||
examples=[
|
||||
[10000, "Married", [["Car", 5000, False], ["Laptop", 800, True]]],
|
||||
[80000, "Single", [["Suit", 800, True], ["Watch", 1800, False]]],
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,5 +1,6 @@
|
||||
import spacy
|
||||
from spacy import displacy
|
||||
|
||||
import gradio as gr
|
||||
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
@ -8,7 +9,11 @@ nlp = spacy.load("en_core_web_sm")
|
||||
def text_analysis(text):
|
||||
doc = nlp(text)
|
||||
html = displacy.render(doc, style="dep", page=True)
|
||||
html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
|
||||
html = (
|
||||
"<div style='max-width:100%; max-height:360px; overflow:auto'>"
|
||||
+ html
|
||||
+ "</div>"
|
||||
)
|
||||
pos_count = {
|
||||
"char_count": len(text),
|
||||
"token_count": 0,
|
||||
@ -17,20 +22,18 @@ def text_analysis(text):
|
||||
|
||||
for token in doc:
|
||||
pos_tokens.extend([(token.text, token.pos_), (" ", None)])
|
||||
|
||||
|
||||
return pos_tokens, pos_count, html
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
text_analysis,
|
||||
gr.inputs.Textbox(placeholder="Enter sentence here..."),
|
||||
[
|
||||
"highlight", "key_values", "html"
|
||||
],
|
||||
["highlight", "key_values", "html"],
|
||||
examples=[
|
||||
["What a beautiful morning for a walk!"],
|
||||
["It was the best of times, it was the worst of times."],
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
iface.test_launch()
|
||||
|
@ -1,15 +1,18 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import sklearn
|
||||
import gradio as gr
|
||||
from sklearn import preprocessing
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import accuracy_score
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sklearn
|
||||
from sklearn import preprocessing
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
import gradio as gr
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
data = pd.read_csv(os.path.join(current_dir, 'files/titanic.csv'))
|
||||
data = pd.read_csv(os.path.join(current_dir, "files/titanic.csv"))
|
||||
|
||||
|
||||
def encode_age(df):
|
||||
df.Age = df.Age.fillna(-0.5)
|
||||
@ -18,6 +21,7 @@ def encode_age(df):
|
||||
df.Age = categories
|
||||
return df
|
||||
|
||||
|
||||
def encode_fare(df):
|
||||
df.Fare = df.Fare.fillna(-0.5)
|
||||
bins = (-1, 0, 8, 15, 31, 1000)
|
||||
@ -25,46 +29,67 @@ def encode_fare(df):
|
||||
df.Fare = categories
|
||||
return df
|
||||
|
||||
|
||||
def encode_df(df):
|
||||
df = encode_age(df)
|
||||
df = encode_fare(df)
|
||||
sex_mapping = {"male": 0, "female": 1}
|
||||
df = df.replace({'Sex': sex_mapping})
|
||||
df = df.replace({"Sex": sex_mapping})
|
||||
embark_mapping = {"S": 1, "C": 2, "Q": 3}
|
||||
df = df.replace({'Embarked': embark_mapping})
|
||||
df = df.replace({"Embarked": embark_mapping})
|
||||
df.Embarked = df.Embarked.fillna(0)
|
||||
df["Company"] = 0
|
||||
df.loc[(df["SibSp"] > 0), "Company"] = 1
|
||||
df.loc[(df["Parch"] > 0), "Company"] = 2
|
||||
df.loc[(df["SibSp"] > 0) & (df["Parch"] > 0), "Company"] = 3
|
||||
df = df[["PassengerId", "Pclass", "Sex", "Age", "Fare", "Embarked", "Company", "Survived"]]
|
||||
df = df[
|
||||
[
|
||||
"PassengerId",
|
||||
"Pclass",
|
||||
"Sex",
|
||||
"Age",
|
||||
"Fare",
|
||||
"Embarked",
|
||||
"Company",
|
||||
"Survived",
|
||||
]
|
||||
]
|
||||
return df
|
||||
|
||||
|
||||
train = encode_df(data)
|
||||
|
||||
X_all = train.drop(['Survived', 'PassengerId'], axis=1)
|
||||
y_all = train['Survived']
|
||||
X_all = train.drop(["Survived", "PassengerId"], axis=1)
|
||||
y_all = train["Survived"]
|
||||
|
||||
num_test = 0.20
|
||||
X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=num_test, random_state=23)
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X_all, y_all, test_size=num_test, random_state=23
|
||||
)
|
||||
|
||||
clf = RandomForestClassifier()
|
||||
clf.fit(X_train, y_train)
|
||||
predictions = clf.predict(X_test)
|
||||
|
||||
|
||||
def predict_survival(passenger_class, is_male, age, company, fare, embark_point):
|
||||
df = pd.DataFrame.from_dict({
|
||||
'Pclass': [passenger_class + 1],
|
||||
'Sex': [0 if is_male else 1],
|
||||
'Age': [age],
|
||||
'Company': [(1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)],
|
||||
'Fare': [fare],
|
||||
'Embarked': [embark_point + 1]
|
||||
})
|
||||
df = pd.DataFrame.from_dict(
|
||||
{
|
||||
"Pclass": [passenger_class + 1],
|
||||
"Sex": [0 if is_male else 1],
|
||||
"Age": [age],
|
||||
"Company": [
|
||||
(1 if "Sibling" in company else 0) + (2 if "Child" in company else 0)
|
||||
],
|
||||
"Fare": [fare],
|
||||
"Embarked": [embark_point + 1],
|
||||
}
|
||||
)
|
||||
df = encode_age(df)
|
||||
df = encode_fare(df)
|
||||
pred = clf.predict_proba(df)[0]
|
||||
return {'Perishes': pred[0], 'Survives': pred[1]}
|
||||
return {"Perishes": pred[0], "Survives": pred[1]}
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
predict_survival,
|
||||
@ -72,7 +97,9 @@ iface = gr.Interface(
|
||||
gr.inputs.Dropdown(["first", "second", "third"], type="index"),
|
||||
"checkbox",
|
||||
gr.inputs.Slider(0, 80),
|
||||
gr.inputs.CheckboxGroup(["Sibling", "Child"], label="Travelling with (select all)"),
|
||||
gr.inputs.CheckboxGroup(
|
||||
["Sibling", "Child"], label="Travelling with (select all)"
|
||||
),
|
||||
gr.inputs.Number(),
|
||||
gr.inputs.Radio(["S", "C", "Q"], type="index"),
|
||||
],
|
||||
|
@ -1,8 +1,10 @@
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def video_flip(video):
|
||||
return video
|
||||
|
||||
|
||||
iface = gr.Interface(video_flip, gr.inputs.Video(source="webcam"), "playable_video")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,10 +1,12 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def snap(image):
|
||||
return np.flipud(image)
|
||||
|
||||
|
||||
iface = gr.Interface(snap, gr.inputs.Image(source="webcam", tool=None), "image")
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -1,15 +1,19 @@
|
||||
import gradio as gr
|
||||
from zipfile import ZipFile
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def zip_to_json(file_obj):
|
||||
files = []
|
||||
with ZipFile(file_obj.name) as zfile:
|
||||
for zinfo in zfile.infolist():
|
||||
files.append({
|
||||
"name": zinfo.filename,
|
||||
"file_size": zinfo.file_size,
|
||||
"compressed_size": zinfo.compress_size,
|
||||
})
|
||||
files.append(
|
||||
{
|
||||
"name": zinfo.filename,
|
||||
"file_size": zinfo.file_size,
|
||||
"compressed_size": zinfo.compress_size,
|
||||
}
|
||||
)
|
||||
return files
|
||||
|
||||
|
||||
|
@ -1,19 +1,22 @@
|
||||
import gradio as gr
|
||||
from zipfile import ZipFile
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def zip_two_files(file1, file2):
|
||||
with ZipFile('tmp.zip', 'w') as zipObj:
|
||||
with ZipFile("tmp.zip", "w") as zipObj:
|
||||
zipObj.write(file1.name, "file1")
|
||||
zipObj.write(file2.name, "file2")
|
||||
return "tmp.zip"
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
zip_two_files,
|
||||
["file", "file"],
|
||||
zip_two_files,
|
||||
["file", "file"],
|
||||
"file",
|
||||
examples=[
|
||||
["files/titanic.csv", "files/titanic.csv"],
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,4 +1,4 @@
|
||||
Metadata-Version: 1.0
|
||||
Metadata-Version: 2.1
|
||||
Name: gradio
|
||||
Version: 2.7.1
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI
|
||||
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq
|
||||
Author-email: team@gradio.app
|
||||
License: Apache License 2.0
|
||||
Description: UNKNOWN
|
||||
Keywords: machine learning,visualization,reproducibility
|
||||
Platform: UNKNOWN
|
||||
License-File: LICENSE
|
||||
|
||||
UNKNOWN
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
aiohttp
|
||||
analytics-python
|
||||
aiohttp
|
||||
fastapi
|
||||
ffmpy
|
||||
markdown2
|
||||
@ -9,7 +9,7 @@ pandas
|
||||
paramiko
|
||||
pillow
|
||||
pycryptodome
|
||||
pydub
|
||||
python-multipart
|
||||
pydub
|
||||
requests
|
||||
uvicorn
|
||||
|
@ -1,8 +1,9 @@
|
||||
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.
|
||||
from gradio.app import get_state, set_state
|
||||
from gradio.mix import *
|
||||
from gradio.flagging import *
|
||||
import pkg_resources
|
||||
|
||||
from gradio.app import get_state, set_state
|
||||
from gradio.flagging import *
|
||||
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.
|
||||
from gradio.mix import *
|
||||
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
__version__ = current_pkg_version
|
||||
__version__ = current_pkg_version
|
||||
|
192
gradio/app.py
192
gradio/app.py
@ -1,34 +1,36 @@
|
||||
"""Implements a FastAPI server to run the gradio interface."""
|
||||
|
||||
from __future__ import annotations
|
||||
from fastapi import FastAPI, Request, Depends, HTTPException, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import posixpath
|
||||
import pkg_resources
|
||||
import secrets
|
||||
from starlette.responses import RedirectResponse
|
||||
import traceback
|
||||
from typing import List, Optional, Type
|
||||
import urllib
|
||||
import uvicorn
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from gradio import utils, queueing
|
||||
import pkg_resources
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from gradio import queueing, utils
|
||||
from gradio.process_examples import load_from_cache, process_example
|
||||
|
||||
|
||||
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename(
|
||||
"gradio", "templates/frontend/static")
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/static")
|
||||
VERSION_FILE = pkg_resources.resource_filename("gradio", "version.txt")
|
||||
with open(VERSION_FILE) as version_file:
|
||||
VERSION = version_file.read()
|
||||
GRADIO_STATIC_ROOT = "https://gradio.s3-us-west-2.amazonaws.com/{}/static/".format(VERSION)
|
||||
GRADIO_STATIC_ROOT = "https://gradio.s3-us-west-2.amazonaws.com/{}/static/".format(
|
||||
VERSION
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
@ -43,40 +45,43 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
|
||||
|
||||
###########
|
||||
# Auth
|
||||
# Auth
|
||||
###########
|
||||
|
||||
|
||||
@app.get('/user')
|
||||
@app.get('/user/')
|
||||
@app.get("/user")
|
||||
@app.get("/user/")
|
||||
def get_current_user(request: Request) -> Optional[str]:
|
||||
token = request.cookies.get('access-token')
|
||||
token = request.cookies.get("access-token")
|
||||
return app.tokens.get(token)
|
||||
|
||||
|
||||
@app.get('/login_check')
|
||||
@app.get('/login_check/')
|
||||
@app.get("/login_check")
|
||||
@app.get("/login_check/")
|
||||
def login_check(user: str = Depends(get_current_user)):
|
||||
if app.auth is None or not(user is None):
|
||||
if app.auth is None or not (user is None):
|
||||
return
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
|
||||
)
|
||||
|
||||
|
||||
@app.get('/token')
|
||||
@app.get('/token/')
|
||||
@app.get("/token")
|
||||
@app.get("/token/")
|
||||
def get_token(request: Request) -> Optional[str]:
|
||||
token = request.cookies.get('access-token')
|
||||
token = request.cookies.get("access-token")
|
||||
return {"token": token, "user": app.tokens.get(token)}
|
||||
|
||||
|
||||
@app.post('/login')
|
||||
@app.post('/login/')
|
||||
@app.post("/login")
|
||||
@app.post("/login/")
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
username, password = form_data.username, form_data.password
|
||||
if ((not callable(app.auth) and username in app.auth
|
||||
and app.auth[username] == password)
|
||||
or (callable(app.auth) and app.auth.__call__(username, password))):
|
||||
if (
|
||||
not callable(app.auth)
|
||||
and username in app.auth
|
||||
and app.auth[username] == password
|
||||
) or (callable(app.auth) and app.auth.__call__(username, password)):
|
||||
token = secrets.token_urlsafe(16)
|
||||
app.tokens[token] = username
|
||||
response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
|
||||
@ -91,21 +96,19 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
###############
|
||||
|
||||
|
||||
@app.head('/', response_class=HTMLResponse)
|
||||
@app.get('/', response_class=HTMLResponse)
|
||||
@app.head("/", response_class=HTMLResponse)
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def main(request: Request, user: str = Depends(get_current_user)):
|
||||
if app.auth is None or not(user is None):
|
||||
if app.auth is None or not (user is None):
|
||||
config = app.interface.config
|
||||
else:
|
||||
config = {"auth_required": True,
|
||||
"auth_message": app.interface.auth_message}
|
||||
|
||||
config = {"auth_required": True, "auth_message": app.interface.auth_message}
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"frontend/index.html",
|
||||
{"request": request, "config": config}
|
||||
"frontend/index.html", {"request": request, "config": config}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@app.get("/config/", dependencies=[Depends(login_check)])
|
||||
@app.get("/config", dependencies=[Depends(login_check)])
|
||||
def get_config():
|
||||
@ -135,8 +138,9 @@ def api_docs(request: Request):
|
||||
if app.interface.examples is not None:
|
||||
sample_inputs = app.interface.examples[0]
|
||||
else:
|
||||
sample_inputs = [inp.generate_sample()
|
||||
for inp in app.interface.input_components]
|
||||
sample_inputs = [
|
||||
inp.generate_sample() for inp in app.interface.input_components
|
||||
]
|
||||
docs = {
|
||||
"inputs": input_names,
|
||||
"outputs": output_names,
|
||||
@ -150,90 +154,90 @@ def api_docs(request: Request):
|
||||
"output_types_doc": output_types_doc,
|
||||
"sample_inputs": sample_inputs,
|
||||
"auth": app.interface.auth,
|
||||
"local_login_url": urllib.parse.urljoin(
|
||||
app.interface.local_url, "login"),
|
||||
"local_api_url": urllib.parse.urljoin(
|
||||
app.interface.local_url, "api/predict")
|
||||
"local_login_url": urllib.parse.urljoin(app.interface.local_url, "login"),
|
||||
"local_api_url": urllib.parse.urljoin(app.interface.local_url, "api/predict"),
|
||||
}
|
||||
return templates.TemplateResponse(
|
||||
"api_docs.html",
|
||||
{"request": request, **docs}
|
||||
)
|
||||
return templates.TemplateResponse("api_docs.html", {"request": request, **docs})
|
||||
|
||||
|
||||
@app.post("/api/predict/", dependencies=[Depends(login_check)])
|
||||
async def predict(
|
||||
request: Request,
|
||||
username: str = Depends(get_current_user)
|
||||
):
|
||||
async def predict(request: Request, username: str = Depends(get_current_user)):
|
||||
body = await request.json()
|
||||
flag_index = None
|
||||
|
||||
|
||||
if body.get("example_id") != None:
|
||||
example_id = body["example_id"]
|
||||
if app.interface.cache_examples:
|
||||
prediction = await run_in_threadpool(
|
||||
load_from_cache, app.interface, example_id)
|
||||
load_from_cache, app.interface, example_id
|
||||
)
|
||||
durations = None
|
||||
else:
|
||||
prediction, durations = await run_in_threadpool(
|
||||
process_example, app.interface, example_id)
|
||||
process_example, app.interface, example_id
|
||||
)
|
||||
else:
|
||||
raw_input = body["data"]
|
||||
if app.interface.show_error:
|
||||
try:
|
||||
prediction, durations = await run_in_threadpool(
|
||||
app.interface.process, raw_input)
|
||||
app.interface.process, raw_input
|
||||
)
|
||||
except BaseException as error:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(content={"error": str(error)},
|
||||
status_code=500)
|
||||
return JSONResponse(content={"error": str(error)}, status_code=500)
|
||||
else:
|
||||
prediction, durations = await run_in_threadpool(
|
||||
app.interface.process, raw_input)
|
||||
app.interface.process, raw_input
|
||||
)
|
||||
if app.interface.allow_flagging == "auto":
|
||||
flag_index = await run_in_threadpool(
|
||||
app.interface.flagging_callback.flag,
|
||||
app.interface, raw_input, prediction,
|
||||
flag_option="" if app.interface.flagging_options else None,
|
||||
username=username)
|
||||
app.interface,
|
||||
raw_input,
|
||||
prediction,
|
||||
flag_option="" if app.interface.flagging_options else None,
|
||||
username=username,
|
||||
)
|
||||
output = {
|
||||
"data": prediction,
|
||||
"durations": durations,
|
||||
"data": prediction,
|
||||
"durations": durations,
|
||||
"avg_durations": app.interface.config.get("avg_durations"),
|
||||
"flag_index": flag_index
|
||||
}
|
||||
"flag_index": flag_index,
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
@app.post("/api/flag/", dependencies=[Depends(login_check)])
|
||||
async def flag(
|
||||
request: Request,
|
||||
username: str = Depends(get_current_user)
|
||||
):
|
||||
async def flag(request: Request, username: str = Depends(get_current_user)):
|
||||
if app.interface.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.interface.ip_address, 'flag')
|
||||
await utils.log_feature_analytics(app.interface.ip_address, "flag")
|
||||
body = await request.json()
|
||||
data = body['data']
|
||||
data = body["data"]
|
||||
await run_in_threadpool(
|
||||
app.interface.flagging_callback.flag,
|
||||
app.interface, data['input_data'], data['output_data'],
|
||||
flag_option=data.get("flag_option"), flag_index=data.get("flag_index"),
|
||||
username=username)
|
||||
return {'success': True}
|
||||
app.interface,
|
||||
data["input_data"],
|
||||
data["output_data"],
|
||||
flag_option=data.get("flag_option"),
|
||||
flag_index=data.get("flag_index"),
|
||||
username=username,
|
||||
)
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
|
||||
async def interpret(request: Request):
|
||||
if app.interface.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.interface.ip_address, 'interpret')
|
||||
await utils.log_feature_analytics(app.interface.ip_address, "interpret")
|
||||
body = await request.json()
|
||||
raw_input = body["data"]
|
||||
interpretation_scores, alternative_outputs = await run_in_threadpool(
|
||||
app.interface.interpret, raw_input)
|
||||
app.interface.interpret, raw_input
|
||||
)
|
||||
return {
|
||||
"interpretation_scores": interpretation_scores,
|
||||
"alternative_outputs": alternative_outputs
|
||||
"alternative_outputs": alternative_outputs,
|
||||
}
|
||||
|
||||
|
||||
@ -249,7 +253,7 @@ async def queue_push(request: Request):
|
||||
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
|
||||
async def queue_status(request: Request):
|
||||
body = await request.json()
|
||||
hash = body['hash']
|
||||
hash = body["hash"]
|
||||
status, data = queueing.get_status(hash)
|
||||
return {"status": status, "data": data}
|
||||
|
||||
@ -277,7 +281,7 @@ def safe_join(directory: str, path: str) -> Optional[str]:
|
||||
):
|
||||
return None
|
||||
|
||||
return posixpath.join(directory, filename)
|
||||
return posixpath.join(directory, filename)
|
||||
|
||||
|
||||
def get_types(cls_set: List[Type], component: str):
|
||||
@ -302,26 +306,30 @@ def get_state():
|
||||
raise DeprecationWarning(
|
||||
"This function is deprecated. To create stateful demos, pass 'state'"
|
||||
"as both an input and output component. Please see the getting started"
|
||||
"guide for more information.")
|
||||
"guide for more information."
|
||||
)
|
||||
|
||||
|
||||
def set_state(*args):
|
||||
raise DeprecationWarning(
|
||||
"This function is deprecated. To create stateful demos, pass 'state'"
|
||||
"as both an input and output component. Please see the getting started"
|
||||
"guide for more information.")
|
||||
"guide for more information."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__': # Run directly for debugging: python app.py
|
||||
from gradio import Interface
|
||||
app.interface = Interface(lambda x: "Hello, " + x, "text", "text",
|
||||
analytics_enabled=False)
|
||||
|
||||
if __name__ == "__main__": # Run directly for debugging: python app.py
|
||||
from gradio import Interface
|
||||
|
||||
app.interface = Interface(
|
||||
lambda x: "Hello, " + x, "text", "text", analytics_enabled=False
|
||||
)
|
||||
app.interface.favicon_path = None
|
||||
app.interface.config = app.interface.get_config_file()
|
||||
app.interface.show_error = True
|
||||
app.interface.flagging_callback.setup(app.interface.flagging_dir)
|
||||
app.tokens = {}
|
||||
|
||||
|
||||
auth = True
|
||||
if auth:
|
||||
app.interface.auth = ("a", "b")
|
||||
|
@ -1,8 +1,10 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from gradio import processing_utils
|
||||
|
||||
class Component():
|
||||
|
||||
class Component:
|
||||
"""
|
||||
A class for defining the methods that all gradio input and output components should have.
|
||||
"""
|
||||
@ -15,16 +17,13 @@ class Component():
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
return "{}(label=\"{}\")".format(type(self).__name__, self.label)
|
||||
return '{}(label="{}")'.format(type(self).__name__, self.label)
|
||||
|
||||
def get_template_context(self):
|
||||
"""
|
||||
:return: a dictionary with context variables for the javascript file associated with the context
|
||||
"""
|
||||
return {
|
||||
"name": self.__class__.__name__.lower(),
|
||||
"label": self.label
|
||||
}
|
||||
return {"name": self.__class__.__name__.lower(), "label": self.label}
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
@ -60,13 +59,15 @@ class Component():
|
||||
new_file_name = str(file_index)
|
||||
if "." in old_file_name:
|
||||
uploaded_format = old_file_name.split(".")[-1].lower()
|
||||
new_file_name += "." + uploaded_format
|
||||
new_file_name += "." + uploaded_format
|
||||
file.close()
|
||||
shutil.move(old_file_name, os.path.join(dir, label, new_file_name))
|
||||
return label + "/" + new_file_name
|
||||
|
||||
def restore_flagged_file(self, dir, file, encryption_key):
|
||||
data = processing_utils.encode_file_to_base64(os.path.join(dir, file), encryption_key=encryption_key)
|
||||
data = processing_utils.encode_file_to_base64(
|
||||
os.path.join(dir, file), encryption_key=encryption_key
|
||||
)
|
||||
return {"name": file, "data": data}
|
||||
|
||||
@classmethod
|
||||
|
@ -1,11 +1,13 @@
|
||||
from Crypto import Random
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Hash import SHA256
|
||||
from Crypto import Random
|
||||
|
||||
|
||||
def get_key(password):
|
||||
key = SHA256.new(password.encode()).digest()
|
||||
return key
|
||||
|
||||
|
||||
def encrypt(key, source):
|
||||
IV = Random.new().read(AES.block_size) # generate IV
|
||||
encryptor = AES.new(key, AES.MODE_CBC, IV)
|
||||
@ -14,11 +16,14 @@ def encrypt(key, source):
|
||||
data = IV + encryptor.encrypt(source) # store the IV at the beginning and encrypt
|
||||
return data
|
||||
|
||||
|
||||
def decrypt(key, source):
|
||||
IV = source[:AES.block_size] # extract the IV from the beginning
|
||||
IV = source[: AES.block_size] # extract the IV from the beginning
|
||||
decryptor = AES.new(key, AES.MODE_CBC, IV)
|
||||
data = decryptor.decrypt(source[AES.block_size:]) # decrypt
|
||||
data = decryptor.decrypt(source[AES.block_size :]) # decrypt
|
||||
padding = data[-1] # pick the padding value from the end; Python 2.x: ord(data[-1])
|
||||
if data[-padding:] != bytes([padding]) * padding: # Python 2.x: chr(padding) * padding
|
||||
if (
|
||||
data[-padding:] != bytes([padding]) * padding
|
||||
): # Python 2.x: chr(padding) * padding
|
||||
raise ValueError("Invalid padding...")
|
||||
return data[:-padding] # remove the padding
|
||||
return data[:-padding] # remove the padding
|
||||
|
@ -1,9 +1,11 @@
|
||||
import json
|
||||
import tempfile
|
||||
import requests
|
||||
from gradio import inputs, outputs
|
||||
import re
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
import requests
|
||||
|
||||
from gradio import inputs, outputs
|
||||
|
||||
|
||||
def get_huggingface_interface(model_name, api_key, alias):
|
||||
@ -17,56 +19,63 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
headers = {}
|
||||
|
||||
# Checking if model exists, and if so, it gets the pipeline
|
||||
response = requests.request("GET", api_url, headers=headers)
|
||||
response = requests.request("GET", api_url, headers=headers)
|
||||
assert response.status_code == 200, "Invalid model name or src"
|
||||
p = response.json().get('pipeline_tag')
|
||||
p = response.json().get("pipeline_tag")
|
||||
|
||||
def encode_to_base64(r: requests.Response) -> str:
|
||||
base64_repr = base64.b64encode(r.content).decode('utf-8')
|
||||
base64_repr = base64.b64encode(r.content).decode("utf-8")
|
||||
data_prefix = ";base64,"
|
||||
if data_prefix in base64_repr:
|
||||
return base64_repr
|
||||
else:
|
||||
content_type = r.headers.get('content-type')
|
||||
content_type = r.headers.get("content-type")
|
||||
return "data:{};base64,".format(content_type) + base64_repr
|
||||
|
||||
|
||||
pipelines = {
|
||||
'audio-classification': {
|
||||
"audio-classification": {
|
||||
# example model: https://hf.co/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition
|
||||
'inputs': inputs.Audio(label="Input", source="upload",
|
||||
type="filepath"),
|
||||
'outputs': outputs.Label(label="Class", type="confidences"),
|
||||
'preprocess': lambda i: base64.b64decode(i['data'].split(",")[1]), # convert the base64 representation to binary
|
||||
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()}
|
||||
"inputs": inputs.Audio(label="Input", source="upload", type="filepath"),
|
||||
"outputs": outputs.Label(label="Class", type="confidences"),
|
||||
"preprocess": lambda i: base64.b64decode(
|
||||
i["data"].split(",")[1]
|
||||
), # convert the base64 representation to binary
|
||||
"postprocess": lambda r: {
|
||||
i["label"].split(", ")[0]: i["score"] for i in r.json()
|
||||
},
|
||||
},
|
||||
'automatic-speech-recognition': {
|
||||
"automatic-speech-recognition": {
|
||||
# example model: https://hf.co/jonatasgrosman/wav2vec2-large-xlsr-53-english
|
||||
'inputs': inputs.Audio(label="Input", source="upload",
|
||||
type="filepath"),
|
||||
'outputs': outputs.Textbox(label="Output"),
|
||||
'preprocess': lambda i: base64.b64decode(i['data'].split(",")[1]), # convert the base64 representation to binary
|
||||
'postprocess': lambda r: r.json()["text"]
|
||||
"inputs": inputs.Audio(label="Input", source="upload", type="filepath"),
|
||||
"outputs": outputs.Textbox(label="Output"),
|
||||
"preprocess": lambda i: base64.b64decode(
|
||||
i["data"].split(",")[1]
|
||||
), # convert the base64 representation to binary
|
||||
"postprocess": lambda r: r.json()["text"],
|
||||
},
|
||||
'feature-extraction': {
|
||||
"feature-extraction": {
|
||||
# example model: hf.co/julien-c/distilbert-feature-extraction
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Dataframe(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r.json()[0],
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Dataframe(label="Output"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r.json()[0],
|
||||
},
|
||||
'fill-mask': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r.json()}
|
||||
"fill-mask": {
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Label(label="Classification", type="confidences"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r.json()},
|
||||
},
|
||||
'image-classification': {
|
||||
"image-classification": {
|
||||
# Example: https://huggingface.co/google/vit-base-patch16-224
|
||||
'inputs': inputs.Image(label="Input Image", type="filepath"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda i: base64.b64decode(i.split(",")[1]), # convert the base64 representation to binary
|
||||
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()}
|
||||
"inputs": inputs.Image(label="Input Image", type="filepath"),
|
||||
"outputs": outputs.Label(label="Classification", type="confidences"),
|
||||
"preprocess": lambda i: base64.b64decode(
|
||||
i.split(",")[1]
|
||||
), # convert the base64 representation to binary
|
||||
"postprocess": lambda r: {
|
||||
i["label"].split(", ")[0]: i["score"] for i in r.json()
|
||||
},
|
||||
},
|
||||
# TODO: support image segmentation pipeline -- should we add a new output component type?
|
||||
# 'image-segmentation': {
|
||||
@ -75,165 +84,220 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
# 'outputs': outputs.Image(label="Segmentation"),
|
||||
# 'preprocess': lambda i: base64.b64decode(i.split(",")[1]), # convert the base64 representation to binary
|
||||
# 'postprocess': lambda x: base64.b64encode(x.json()[0]["mask"]).decode('utf-8'),
|
||||
# },
|
||||
# },
|
||||
# TODO: also: support NER pipeline, object detection, table question answering
|
||||
'question-answering': {
|
||||
'inputs': [inputs.Textbox(label="Context", lines=7), inputs.Textbox(label="Question")],
|
||||
'outputs': [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
|
||||
'preprocess': lambda c, q: {"inputs": {"context": c, "question": q}},
|
||||
'postprocess': lambda r: (r.json()["answer"], r.json()["score"]),
|
||||
},
|
||||
'summarization': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Summary"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r.json()[0]["summary_text"]
|
||||
},
|
||||
'text-classification': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
|
||||
},
|
||||
'text-generation': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r.json()[0]["generated_text"],
|
||||
},
|
||||
'text2text-generation': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Generated Text"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r.json()[0]["generated_text"]
|
||||
},
|
||||
'translation': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Translation"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r.json()[0]["translation_text"]
|
||||
},
|
||||
'zero-shot-classification': {
|
||||
'inputs': [inputs.Textbox(label="Input"),
|
||||
inputs.Textbox(label="Possible class names ("
|
||||
"comma-separated)"),
|
||||
inputs.Checkbox(label="Allow multiple true classes")],
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda i, c, m: {"inputs": i, "parameters":
|
||||
{"candidate_labels": c, "multi_class": m}},
|
||||
'postprocess': lambda r: {r.json()["labels"][i]: r.json()["scores"][i] for i in
|
||||
range(len(r.json()["labels"]))}
|
||||
},
|
||||
'sentence-similarity': {
|
||||
# example model: hf.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens
|
||||
'inputs': [
|
||||
inputs.Textbox(label="Source Sentence", default="That is a happy person"),
|
||||
inputs.Textbox(lines=7, label="Sentences to compare to", placeholder="Separate each sentence by a newline"),
|
||||
"question-answering": {
|
||||
"inputs": [
|
||||
inputs.Textbox(label="Context", lines=7),
|
||||
inputs.Textbox(label="Question"),
|
||||
],
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda src, sentences: {"inputs": {
|
||||
"source_sentence": src,
|
||||
"sentences": [s for s in sentences.splitlines() if s != ""],
|
||||
}},
|
||||
'postprocess': lambda r: { f"sentence {i}": v for i, v in enumerate(r.json()) },
|
||||
"outputs": [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
|
||||
"preprocess": lambda c, q: {"inputs": {"context": c, "question": q}},
|
||||
"postprocess": lambda r: (r.json()["answer"], r.json()["score"]),
|
||||
},
|
||||
'text-to-speech': {
|
||||
"summarization": {
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Textbox(label="Summary"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r.json()[0]["summary_text"],
|
||||
},
|
||||
"text-classification": {
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Label(label="Classification", type="confidences"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: {
|
||||
i["label"].split(", ")[0]: i["score"] for i in r.json()[0]
|
||||
},
|
||||
},
|
||||
"text-generation": {
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Textbox(label="Output"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r.json()[0]["generated_text"],
|
||||
},
|
||||
"text2text-generation": {
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Textbox(label="Generated Text"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r.json()[0]["generated_text"],
|
||||
},
|
||||
"translation": {
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Textbox(label="Translation"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r.json()[0]["translation_text"],
|
||||
},
|
||||
"zero-shot-classification": {
|
||||
"inputs": [
|
||||
inputs.Textbox(label="Input"),
|
||||
inputs.Textbox(label="Possible class names (" "comma-separated)"),
|
||||
inputs.Checkbox(label="Allow multiple true classes"),
|
||||
],
|
||||
"outputs": outputs.Label(label="Classification", type="confidences"),
|
||||
"preprocess": lambda i, c, m: {
|
||||
"inputs": i,
|
||||
"parameters": {"candidate_labels": c, "multi_class": m},
|
||||
},
|
||||
"postprocess": lambda r: {
|
||||
r.json()["labels"][i]: r.json()["scores"][i]
|
||||
for i in range(len(r.json()["labels"]))
|
||||
},
|
||||
},
|
||||
"sentence-similarity": {
|
||||
# example model: hf.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens
|
||||
"inputs": [
|
||||
inputs.Textbox(
|
||||
label="Source Sentence", default="That is a happy person"
|
||||
),
|
||||
inputs.Textbox(
|
||||
lines=7,
|
||||
label="Sentences to compare to",
|
||||
placeholder="Separate each sentence by a newline",
|
||||
),
|
||||
],
|
||||
"outputs": outputs.Label(label="Classification", type="confidences"),
|
||||
"preprocess": lambda src, sentences: {
|
||||
"inputs": {
|
||||
"source_sentence": src,
|
||||
"sentences": [s for s in sentences.splitlines() if s != ""],
|
||||
}
|
||||
},
|
||||
"postprocess": lambda r: {
|
||||
f"sentence {i}": v for i, v in enumerate(r.json())
|
||||
},
|
||||
},
|
||||
"text-to-speech": {
|
||||
# example model: hf.co/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Audio(label="Audio"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': encode_to_base64,
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Audio(label="Audio"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": encode_to_base64,
|
||||
},
|
||||
'text-to-image': {
|
||||
"text-to-image": {
|
||||
# example model: hf.co/osanseviero/BigGAN-deep-128
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Image(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': encode_to_base64,
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Image(label="Output"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": encode_to_base64,
|
||||
},
|
||||
}
|
||||
|
||||
if p is None or not(p in pipelines):
|
||||
if p is None or not (p in pipelines):
|
||||
raise ValueError("Unsupported pipeline type: {}".format(type(p)))
|
||||
|
||||
|
||||
pipeline = pipelines[p]
|
||||
|
||||
def query_huggingface_api(*params):
|
||||
# Convert to a list of input components
|
||||
data = pipeline['preprocess'](*params)
|
||||
if isinstance(data, dict): # HF doesn't allow additional parameters for binary files (e.g. images or audio files)
|
||||
data.update({'options': {'wait_for_model': True}})
|
||||
data = pipeline["preprocess"](*params)
|
||||
if isinstance(
|
||||
data, dict
|
||||
): # HF doesn't allow additional parameters for binary files (e.g. images or audio files)
|
||||
data.update({"options": {"wait_for_model": True}})
|
||||
data = json.dumps(data)
|
||||
response = requests.request("POST", api_url, headers=headers, data=data)
|
||||
if not(response.status_code == 200):
|
||||
raise ValueError("Could not complete request to HuggingFace API, Error {}".format(response.status_code))
|
||||
output = pipeline['postprocess'](response)
|
||||
response = requests.request("POST", api_url, headers=headers, data=data)
|
||||
if not (response.status_code == 200):
|
||||
raise ValueError(
|
||||
"Could not complete request to HuggingFace API, Error {}".format(
|
||||
response.status_code
|
||||
)
|
||||
)
|
||||
output = pipeline["postprocess"](response)
|
||||
return output
|
||||
|
||||
|
||||
if alias is None:
|
||||
query_huggingface_api.__name__ = model_name
|
||||
else:
|
||||
query_huggingface_api.__name__ = alias
|
||||
|
||||
interface_info = {
|
||||
'fn': query_huggingface_api,
|
||||
'inputs': pipeline['inputs'],
|
||||
'outputs': pipeline['outputs'],
|
||||
'title': model_name,
|
||||
"fn": query_huggingface_api,
|
||||
"inputs": pipeline["inputs"],
|
||||
"outputs": pipeline["outputs"],
|
||||
"title": model_name,
|
||||
}
|
||||
|
||||
return interface_info
|
||||
|
||||
|
||||
def load_interface(name, src=None, api_key=None, alias=None):
|
||||
if src is None:
|
||||
tokens = name.split("/") # Separate the source (e.g. "huggingface") from the repo name (e.g. "google/vit-base-patch16-224")
|
||||
assert len(tokens) > 1, "Either `src` parameter must be provided, or `name` must be formatted as \{src\}/\{repo name\}"
|
||||
tokens = name.split(
|
||||
"/"
|
||||
) # Separate the source (e.g. "huggingface") from the repo name (e.g. "google/vit-base-patch16-224")
|
||||
assert (
|
||||
len(tokens) > 1
|
||||
), "Either `src` parameter must be provided, or `name` must be formatted as \{src\}/\{repo name\}"
|
||||
src = tokens[0]
|
||||
name = "/".join(tokens[1:])
|
||||
assert src.lower() in repos, "parameter: src must be one of {}".format(repos.keys())
|
||||
interface_info = repos[src](name, api_key, alias)
|
||||
return interface_info
|
||||
|
||||
|
||||
def interface_params_from_config(config_dict):
|
||||
## instantiate input component and output component
|
||||
config_dict["inputs"] = [inputs.get_input_instance(component) for component in config_dict["input_components"]]
|
||||
config_dict["outputs"] = [outputs.get_output_instance(component) for component in config_dict["output_components"]]
|
||||
config_dict["inputs"] = [
|
||||
inputs.get_input_instance(component)
|
||||
for component in config_dict["input_components"]
|
||||
]
|
||||
config_dict["outputs"] = [
|
||||
outputs.get_output_instance(component)
|
||||
for component in config_dict["output_components"]
|
||||
]
|
||||
parameters = {
|
||||
"allow_flagging", "allow_screenshot", "article", "description", "flagging_options", "inputs", "outputs",
|
||||
"show_input", "show_output", "theme", "title"
|
||||
"allow_flagging",
|
||||
"allow_screenshot",
|
||||
"article",
|
||||
"description",
|
||||
"flagging_options",
|
||||
"inputs",
|
||||
"outputs",
|
||||
"show_input",
|
||||
"show_output",
|
||||
"theme",
|
||||
"title",
|
||||
}
|
||||
config_dict = {k: config_dict[k] for k in parameters}
|
||||
return config_dict
|
||||
|
||||
|
||||
def get_spaces_interface(model_name, api_key, alias):
|
||||
space_url = "https://huggingface.co/spaces/{}".format(model_name)
|
||||
print("Fetching interface from: {}".format(space_url))
|
||||
iframe_url = "https://huggingface.co/gradioiframe/{}/+".format(model_name)
|
||||
api_url = "https://huggingface.co/gradioiframe/{}/api/predict/".format(model_name)
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
r = requests.get(iframe_url)
|
||||
result = re.search('window.gradio_config = (.*?);</script>', r.text) # some basic regex to extract the config
|
||||
result = re.search(
|
||||
"window.gradio_config = (.*?);</script>", r.text
|
||||
) # some basic regex to extract the config
|
||||
config = json.loads(result.group(1))
|
||||
interface_info = interface_params_from_config(config)
|
||||
|
||||
|
||||
# The function should call the API with preprocessed data
|
||||
def fn(*data):
|
||||
data = json.dumps({"data": data})
|
||||
response = requests.post(api_url, headers=headers, data=data)
|
||||
result = json.loads(response.content.decode("utf-8"))
|
||||
output = result["data"]
|
||||
if len(interface_info["outputs"])==1: # if the fn is supposed to return a single value, pop it
|
||||
if (
|
||||
len(interface_info["outputs"]) == 1
|
||||
): # if the fn is supposed to return a single value, pop it
|
||||
output = output[0]
|
||||
if len(interface_info["outputs"])==1 and isinstance(output, list): # Needed to support Output.Image() returning bounding boxes as well (TODO: handle different versions of gradio since they have slightly different APIs)
|
||||
if len(interface_info["outputs"]) == 1 and isinstance(
|
||||
output, list
|
||||
): # Needed to support Output.Image() returning bounding boxes as well (TODO: handle different versions of gradio since they have slightly different APIs)
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
|
||||
fn.__name__ = alias if (alias is not None) else model_name
|
||||
interface_info["fn"] = fn
|
||||
|
||||
|
||||
return interface_info
|
||||
|
||||
|
||||
repos = {
|
||||
# for each repo, we have a method that returns the Interface given the model name & optionally an api_key
|
||||
"huggingface": get_huggingface_interface,
|
||||
@ -241,6 +305,7 @@ repos = {
|
||||
"spaces": get_spaces_interface,
|
||||
}
|
||||
|
||||
|
||||
def load_from_pipeline(pipeline):
|
||||
"""
|
||||
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
|
||||
@ -251,128 +316,163 @@ def load_from_pipeline(pipeline):
|
||||
try:
|
||||
import transformers
|
||||
except ImportError:
|
||||
raise ImportError("transformers not installed. Please try `pip install transformers`")
|
||||
raise ImportError(
|
||||
"transformers not installed. Please try `pip install transformers`"
|
||||
)
|
||||
if not isinstance(pipeline, transformers.Pipeline):
|
||||
raise ValueError("pipeline must be a transformers.Pipeline")
|
||||
|
||||
|
||||
# Handle the different pipelines. The has_attr() checks to make sure the pipeline exists in the
|
||||
# version of the transformers library that the user has installed.
|
||||
if hasattr(transformers, 'AudioClassificationPipeline') and isinstance(pipeline, transformers.AudioClassificationPipeline):
|
||||
if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.AudioClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Audio(label="Input", source="microphone",
|
||||
type="filepath"),
|
||||
'outputs': outputs.Label(label="Class", type="confidences"),
|
||||
'preprocess': lambda i: {"inputs": i},
|
||||
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}
|
||||
"inputs": inputs.Audio(label="Input", source="microphone", type="filepath"),
|
||||
"outputs": outputs.Label(label="Class", type="confidences"),
|
||||
"preprocess": lambda i: {"inputs": i},
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, 'AutomaticSpeechRecognitionPipeline') and isinstance(pipeline, transformers.AutomaticSpeechRecognitionPipeline):
|
||||
elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
|
||||
pipeline, transformers.AutomaticSpeechRecognitionPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Audio(label="Input", source="microphone",
|
||||
type="filepath"),
|
||||
'outputs': outputs.Textbox(label="Output"),
|
||||
'preprocess': lambda i: {"inputs": i},
|
||||
'postprocess': lambda r: r["text"]
|
||||
"inputs": inputs.Audio(label="Input", source="microphone", type="filepath"),
|
||||
"outputs": outputs.Textbox(label="Output"),
|
||||
"preprocess": lambda i: {"inputs": i},
|
||||
"postprocess": lambda r: r["text"],
|
||||
}
|
||||
elif hasattr(transformers, 'FeatureExtractionPipeline') and isinstance(pipeline, transformers.FeatureExtractionPipeline):
|
||||
elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
|
||||
pipeline, transformers.FeatureExtractionPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Dataframe(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0],
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Dataframe(label="Output"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r[0],
|
||||
}
|
||||
elif hasattr(transformers, 'FillMaskPipeline') and isinstance(pipeline, transformers.FillMaskPipeline):
|
||||
elif hasattr(transformers, "FillMaskPipeline") and isinstance(
|
||||
pipeline, transformers.FillMaskPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r}
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Label(label="Classification", type="confidences"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, 'ImageClassificationPipeline') and isinstance(pipeline, transformers.ImageClassificationPipeline):
|
||||
elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.ImageClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Image(label="Input Image", type="filepath"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda i: {"images": i},
|
||||
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}
|
||||
"inputs": inputs.Image(label="Input Image", type="filepath"),
|
||||
"outputs": outputs.Label(label="Classification", type="confidences"),
|
||||
"preprocess": lambda i: {"images": i},
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, 'QuestionAnsweringPipeline') and isinstance(pipeline, transformers.QuestionAnsweringPipeline):
|
||||
elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
|
||||
pipeline, transformers.QuestionAnsweringPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': [inputs.Textbox(label="Context", lines=7), inputs.Textbox(label="Question")],
|
||||
'outputs': [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
|
||||
'preprocess': lambda c, q: {"context": c, "question": q},
|
||||
'postprocess': lambda r: (r["answer"], r["score"]),
|
||||
"inputs": [
|
||||
inputs.Textbox(label="Context", lines=7),
|
||||
inputs.Textbox(label="Question"),
|
||||
],
|
||||
"outputs": [outputs.Textbox(label="Answer"), outputs.Label(label="Score")],
|
||||
"preprocess": lambda c, q: {"context": c, "question": q},
|
||||
"postprocess": lambda r: (r["answer"], r["score"]),
|
||||
}
|
||||
elif hasattr(transformers, 'SummarizationPipeline') and isinstance(pipeline, transformers.SummarizationPipeline):
|
||||
elif hasattr(transformers, "SummarizationPipeline") and isinstance(
|
||||
pipeline, transformers.SummarizationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Textbox(label="Input", lines=7),
|
||||
'outputs': outputs.Textbox(label="Summary"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0]["summary_text"]
|
||||
"inputs": inputs.Textbox(label="Input", lines=7),
|
||||
"outputs": outputs.Textbox(label="Summary"),
|
||||
"preprocess": lambda x: {"inputs": x},
|
||||
"postprocess": lambda r: r[0]["summary_text"],
|
||||
}
|
||||
elif hasattr(transformers, 'TextClassificationPipeline') and isinstance(pipeline, transformers.TextClassificationPipeline):
|
||||
elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.TextClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda x: [x],
|
||||
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Label(label="Classification", type="confidences"),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
|
||||
}
|
||||
elif hasattr(transformers, 'TextGenerationPipeline') and isinstance(pipeline, transformers.TextGenerationPipeline):
|
||||
elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
|
||||
pipeline, transformers.TextGenerationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Output"),
|
||||
'preprocess': lambda x: {"text_inputs": x},
|
||||
'postprocess': lambda r: r[0]["generated_text"],
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Textbox(label="Output"),
|
||||
"preprocess": lambda x: {"text_inputs": x},
|
||||
"postprocess": lambda r: r[0]["generated_text"],
|
||||
}
|
||||
elif hasattr(transformers, 'TranslationPipeline') and isinstance(pipeline, transformers.TranslationPipeline):
|
||||
elif hasattr(transformers, "TranslationPipeline") and isinstance(
|
||||
pipeline, transformers.TranslationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Translation"),
|
||||
'preprocess': lambda x: [x],
|
||||
'postprocess': lambda r: r[0]["translation_text"]
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Textbox(label="Translation"),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: r[0]["translation_text"],
|
||||
}
|
||||
elif hasattr(transformers, 'Text2TextGenerationPipeline') and isinstance(pipeline, transformers.Text2TextGenerationPipeline):
|
||||
elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
|
||||
pipeline, transformers.Text2TextGenerationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Generated Text"),
|
||||
'preprocess': lambda x: [x],
|
||||
'postprocess': lambda r: r[0]["generated_text"]
|
||||
"inputs": inputs.Textbox(label="Input"),
|
||||
"outputs": outputs.Textbox(label="Generated Text"),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: r[0]["generated_text"],
|
||||
}
|
||||
elif hasattr(transformers, 'ZeroShotClassificationPipeline') and isinstance(pipeline, transformers.ZeroShotClassificationPipeline):
|
||||
elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
|
||||
pipeline, transformers.ZeroShotClassificationPipeline
|
||||
):
|
||||
pipeline_info = {
|
||||
'inputs': [inputs.Textbox(label="Input"),
|
||||
inputs.Textbox(label="Possible class names ("
|
||||
"comma-separated)"),
|
||||
inputs.Checkbox(label="Allow multiple true classes")],
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda i, c, m: {"sequences": i,
|
||||
"candidate_labels": c, "multi_label": m},
|
||||
'postprocess': lambda r: {r["labels"][i]: r["scores"][i] for i in
|
||||
range(len(r["labels"]))}
|
||||
"inputs": [
|
||||
inputs.Textbox(label="Input"),
|
||||
inputs.Textbox(label="Possible class names (" "comma-separated)"),
|
||||
inputs.Checkbox(label="Allow multiple true classes"),
|
||||
],
|
||||
"outputs": outputs.Label(label="Classification", type="confidences"),
|
||||
"preprocess": lambda i, c, m: {
|
||||
"sequences": i,
|
||||
"candidate_labels": c,
|
||||
"multi_label": m,
|
||||
},
|
||||
"postprocess": lambda r: {
|
||||
r["labels"][i]: r["scores"][i] for i in range(len(r["labels"]))
|
||||
},
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unsupported pipeline type: {}".format(type(pipeline)))
|
||||
|
||||
|
||||
# define the function that will be called by the Interface
|
||||
def fn(*params):
|
||||
data = pipeline_info["preprocess"](*params)
|
||||
# special cases that needs to be handled differently
|
||||
if isinstance(pipeline, (transformers.TextClassificationPipeline,
|
||||
transformers.Text2TextGenerationPipeline,
|
||||
transformers.TranslationPipeline)):
|
||||
if isinstance(
|
||||
pipeline,
|
||||
(
|
||||
transformers.TextClassificationPipeline,
|
||||
transformers.Text2TextGenerationPipeline,
|
||||
transformers.TranslationPipeline,
|
||||
),
|
||||
):
|
||||
data = pipeline(*data)
|
||||
else:
|
||||
data = pipeline(**data)
|
||||
# print("Before postprocessing", data)
|
||||
output = pipeline_info["postprocess"](data)
|
||||
return output
|
||||
|
||||
|
||||
interface_info = pipeline_info.copy()
|
||||
interface_info["fn"] = fn
|
||||
del interface_info["preprocess"]
|
||||
del interface_info["postprocess"]
|
||||
|
||||
|
||||
# define the title/description of the Interface
|
||||
interface_info["title"] = pipeline.model.__class__.__name__
|
||||
|
||||
return interface_info
|
||||
|
||||
|
||||
|
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import csv
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import gradio as gr
|
||||
@ -17,9 +18,7 @@ class FlaggingCallback(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def setup(
|
||||
self,
|
||||
flagging_dir: str):
|
||||
def setup(self, flagging_dir: str):
|
||||
"""
|
||||
This method should be overridden and ensure that everything is set up correctly for flag().
|
||||
This method gets called once at the beginning of the Interface.launch() method.
|
||||
@ -30,20 +29,21 @@ class FlaggingCallback(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def flag(
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None) -> int:
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
|
||||
This gets called every time the <flag> button is pressed.
|
||||
Parameters:
|
||||
interface: The Interface object that is being used to launch the flagging interface.
|
||||
input_data: The input data to be flagged.
|
||||
output_data: The output data to be flagged.
|
||||
output_data: The output data to be flagged.
|
||||
flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
|
||||
flag_index (optional): The index of the sample that is being flagged.
|
||||
username (optional): The username of the user that is flagging the data, if logged in.
|
||||
@ -55,41 +55,52 @@ class FlaggingCallback(ABC):
|
||||
|
||||
class SimpleCSVLogger(FlaggingCallback):
|
||||
"""
|
||||
A simple example implementation of the FlaggingCallback abstract class
|
||||
A simple example implementation of the FlaggingCallback abstract class
|
||||
provided for illustrative purposes.
|
||||
"""
|
||||
def setup(
|
||||
self,
|
||||
flagging_dir: str
|
||||
):
|
||||
|
||||
def setup(self, flagging_dir: str):
|
||||
self.flagging_dir = flagging_dir
|
||||
os.makedirs(flagging_dir, exist_ok=True)
|
||||
|
||||
def flag(
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> int:
|
||||
flagging_dir = self.flagging_dir
|
||||
log_filepath = "{}/log.csv".format(flagging_dir)
|
||||
|
||||
csv_data = []
|
||||
for i, input in enumerate(interface.input_components):
|
||||
csv_data.append(input.save_flagged(
|
||||
flagging_dir, interface.config["input_components"][i]["label"], input_data[i], None))
|
||||
csv_data.append(
|
||||
input.save_flagged(
|
||||
flagging_dir,
|
||||
interface.config["input_components"][i]["label"],
|
||||
input_data[i],
|
||||
None,
|
||||
)
|
||||
)
|
||||
for i, output in enumerate(interface.output_components):
|
||||
csv_data.append(output.save_flagged(
|
||||
flagging_dir, interface.config["output_components"][i]["label"], output_data[i], None) if
|
||||
output_data[i] is not None else "")
|
||||
|
||||
csv_data.append(
|
||||
output.save_flagged(
|
||||
flagging_dir,
|
||||
interface.config["output_components"][i]["label"],
|
||||
output_data[i],
|
||||
None,
|
||||
)
|
||||
if output_data[i] is not None
|
||||
else ""
|
||||
)
|
||||
|
||||
with open(log_filepath, "a", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(csv_data)
|
||||
|
||||
|
||||
with open(log_filepath, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
return line_count
|
||||
@ -97,24 +108,22 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
|
||||
class CSVLogger(FlaggingCallback):
|
||||
"""
|
||||
The default implementation of the FlaggingCallback abstract class.
|
||||
The default implementation of the FlaggingCallback abstract class.
|
||||
Logs the input and output data to a CSV file.
|
||||
"""
|
||||
def setup(
|
||||
self,
|
||||
flagging_dir: str
|
||||
):
|
||||
|
||||
def setup(self, flagging_dir: str):
|
||||
self.flagging_dir = flagging_dir
|
||||
os.makedirs(flagging_dir, exist_ok=True)
|
||||
|
||||
def flag(
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> int:
|
||||
flagging_dir = self.flagging_dir
|
||||
log_fp = "{}/log.csv".format(flagging_dir)
|
||||
@ -126,12 +135,25 @@ class CSVLogger(FlaggingCallback):
|
||||
csv_data = []
|
||||
if not output_only_mode:
|
||||
for i, input in enumerate(interface.input_components):
|
||||
csv_data.append(input.save_flagged(
|
||||
flagging_dir, interface.config["input_components"][i]["label"], input_data[i], encryption_key))
|
||||
csv_data.append(
|
||||
input.save_flagged(
|
||||
flagging_dir,
|
||||
interface.config["input_components"][i]["label"],
|
||||
input_data[i],
|
||||
encryption_key,
|
||||
)
|
||||
)
|
||||
for i, output in enumerate(interface.output_components):
|
||||
csv_data.append(output.save_flagged(
|
||||
flagging_dir, interface.config["output_components"][i]["label"], output_data[i], encryption_key) if
|
||||
output_data[i] is not None else "")
|
||||
csv_data.append(
|
||||
output.save_flagged(
|
||||
flagging_dir,
|
||||
interface.config["output_components"][i]["label"],
|
||||
output_data[i],
|
||||
encryption_key,
|
||||
)
|
||||
if output_data[i] is not None
|
||||
else ""
|
||||
)
|
||||
if not output_only_mode:
|
||||
if flag_option is not None:
|
||||
csv_data.append(flag_option)
|
||||
@ -141,10 +163,14 @@ class CSVLogger(FlaggingCallback):
|
||||
if is_new:
|
||||
headers = []
|
||||
if not output_only_mode:
|
||||
headers += [interface["label"]
|
||||
for interface in interface.config["input_components"]]
|
||||
headers += [interface["label"]
|
||||
for interface in interface.config["output_components"]]
|
||||
headers += [
|
||||
interface["label"]
|
||||
for interface in interface.config["input_components"]
|
||||
]
|
||||
headers += [
|
||||
interface["label"]
|
||||
for interface in interface.config["output_components"]
|
||||
]
|
||||
if not output_only_mode:
|
||||
if interface.flagging_options is not None:
|
||||
headers.append("flag")
|
||||
@ -169,7 +195,8 @@ class CSVLogger(FlaggingCallback):
|
||||
with open(log_fp, "rb") as csvfile:
|
||||
encrypted_csv = csvfile.read()
|
||||
decrypted_csv = encryptor.decrypt(
|
||||
interface.encryption_key, encrypted_csv)
|
||||
interface.encryption_key, encrypted_csv
|
||||
)
|
||||
file_content = decrypted_csv.decode()
|
||||
if flag_index is not None:
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
@ -180,8 +207,11 @@ class CSVLogger(FlaggingCallback):
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
with open(log_fp, "wb") as csvfile:
|
||||
csvfile.write(encryptor.encrypt(
|
||||
interface.encryption_key, output.getvalue().encode()))
|
||||
csvfile.write(
|
||||
encryptor.encrypt(
|
||||
interface.encryption_key, output.getvalue().encode()
|
||||
)
|
||||
)
|
||||
else:
|
||||
if flag_index is None:
|
||||
with open(log_fp, "a", newline="") as csvfile:
|
||||
@ -193,7 +223,9 @@ class CSVLogger(FlaggingCallback):
|
||||
with open(log_fp) as csvfile:
|
||||
file_content = csvfile.read()
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
with open(log_fp, "w", newline="") as csvfile: # newline parameter needed for Windows
|
||||
with open(
|
||||
log_fp, "w", newline=""
|
||||
) as csvfile: # newline parameter needed for Windows
|
||||
csvfile.write(file_content)
|
||||
with open(log_fp, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
@ -204,24 +236,26 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
"""
|
||||
A FlaggingCallback that saves flagged data to a HuggingFace dataset.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hf_foken: str,
|
||||
dataset_name: str,
|
||||
organization: Optional[str] = None,
|
||||
private: bool = False,
|
||||
verbose: bool = True):
|
||||
self,
|
||||
hf_foken: str,
|
||||
dataset_name: str,
|
||||
organization: Optional[str] = None,
|
||||
private: bool = False,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""
|
||||
Params:
|
||||
hf_token (str): The token to use to access the huggingface API.
|
||||
dataset_name (str): The name of the dataset to save the data to, e.g.
|
||||
dataset_name (str): The name of the dataset to save the data to, e.g.
|
||||
"image-classifier-1"
|
||||
organization (str): The name of the organization to which to attach
|
||||
organization (str): The name of the organization to which to attach
|
||||
the datasets. If None, the dataset attaches to the user only.
|
||||
private (bool): If the dataset does not already exist, whether it
|
||||
should be created as a private dataset or public. Private datasets
|
||||
private (bool): If the dataset does not already exist, whether it
|
||||
should be created as a private dataset or public. Private datasets
|
||||
may require paid huggingface.co accounts
|
||||
verbose (bool): Whether to print out the status of the dataset
|
||||
verbose (bool): Whether to print out the status of the dataset
|
||||
creation.
|
||||
"""
|
||||
self.hf_foken = hf_foken
|
||||
@ -230,118 +264,129 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
self.dataset_private = private
|
||||
self.verbose = verbose
|
||||
|
||||
def setup(
|
||||
self,
|
||||
flagging_dir: str):
|
||||
def setup(self, flagging_dir: str):
|
||||
"""
|
||||
Params:
|
||||
flagging_dir (str): local directory where the dataset is cloned,
|
||||
flagging_dir (str): local directory where the dataset is cloned,
|
||||
updated, and pushed from.
|
||||
"""
|
||||
try:
|
||||
import huggingface_hub
|
||||
import huggingface_hub
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ImportError("Package `huggingface_hub` not found is needed "
|
||||
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'.")
|
||||
raise ImportError(
|
||||
"Package `huggingface_hub` not found is needed "
|
||||
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
|
||||
)
|
||||
path_to_dataset_repo = huggingface_hub.create_repo(
|
||||
name=self.dataset_name, token=self.hf_foken,
|
||||
private=self.dataset_private, repo_type="dataset", exist_ok=True)
|
||||
name=self.dataset_name,
|
||||
token=self.hf_foken,
|
||||
private=self.dataset_private,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
|
||||
self.flagging_dir = flagging_dir
|
||||
self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
|
||||
self.repo = huggingface_hub.Repository(
|
||||
local_dir=self.dataset_dir, clone_from=path_to_dataset_repo,
|
||||
use_auth_token=self.hf_foken)
|
||||
local_dir=self.dataset_dir,
|
||||
clone_from=path_to_dataset_repo,
|
||||
use_auth_token=self.hf_foken,
|
||||
)
|
||||
self.repo.git_pull()
|
||||
|
||||
#Should filename be user-specified?
|
||||
self.log_file = os.path.join(self.dataset_dir, "data.csv")
|
||||
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
|
||||
|
||||
# Should filename be user-specified?
|
||||
self.log_file = os.path.join(self.dataset_dir, "data.csv")
|
||||
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
|
||||
|
||||
def flag(
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None
|
||||
) -> int:
|
||||
self,
|
||||
interface: gr.Interface,
|
||||
input_data: List[Any],
|
||||
output_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> int:
|
||||
is_new = not os.path.exists(self.log_file)
|
||||
infos = {"flagged": {"features": {}}}
|
||||
|
||||
|
||||
with open(self.log_file, "a", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
|
||||
|
||||
# File previews for certain input and output types
|
||||
file_preview_types = {
|
||||
gr.inputs.Audio: "Audio",
|
||||
gr.inputs.Audio: "Audio",
|
||||
gr.outputs.Audio: "Audio",
|
||||
gr.inputs.Image: "Image",
|
||||
gr.inputs.Image: "Image",
|
||||
gr.outputs.Image: "Image",
|
||||
}
|
||||
|
||||
|
||||
# Generate the headers and dataset_infos
|
||||
if is_new:
|
||||
headers = []
|
||||
for i, component in enumerate(interface.input_components +
|
||||
interface.output_components):
|
||||
component_label = (interface.config["input_components"] +
|
||||
interface.config["output_components"])[i]["label"]
|
||||
for i, component in enumerate(
|
||||
interface.input_components + interface.output_components
|
||||
):
|
||||
component_label = (
|
||||
interface.config["input_components"]
|
||||
+ interface.config["output_components"]
|
||||
)[i]["label"]
|
||||
headers.append(component_label)
|
||||
infos["flagged"]["features"][component_label] = {
|
||||
"dtype": "string",
|
||||
"_type": "Value"
|
||||
"dtype": "string",
|
||||
"_type": "Value",
|
||||
}
|
||||
if isinstance(component, tuple(file_preview_types)):
|
||||
headers.append(component_label + " file")
|
||||
for _component, _type in file_preview_types.items():
|
||||
if isinstance(component, _component):
|
||||
infos["flagged"]["features"][component_label +
|
||||
" file"] = {
|
||||
"_type": _type
|
||||
}
|
||||
infos["flagged"]["features"][
|
||||
component_label + " file"
|
||||
] = {"_type": _type}
|
||||
break
|
||||
if interface.flagging_options is not None:
|
||||
headers.append("flag")
|
||||
infos["flagged"]["features"]["flag"] = {
|
||||
"dtype": "string",
|
||||
"_type": "Value"
|
||||
}
|
||||
"dtype": "string",
|
||||
"_type": "Value",
|
||||
}
|
||||
writer.writerow(headers)
|
||||
|
||||
|
||||
# Generate the row corresponding to the flagged sample
|
||||
csv_data = []
|
||||
for i, input in enumerate(interface.input_components):
|
||||
label = interface.config["input_components"][i]["label"]
|
||||
filepath = input.save_flagged(
|
||||
self.dataset_dir, label, input_data[i], None)
|
||||
self.dataset_dir, label, input_data[i], None
|
||||
)
|
||||
csv_data.append(filepath)
|
||||
if isinstance(component, tuple(file_preview_types)):
|
||||
csv_data.append("{}/resolve/main/{}".format(
|
||||
self.path_to_dataset_repo, filepath))
|
||||
csv_data.append(
|
||||
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
|
||||
)
|
||||
for i, output in enumerate(interface.output_components):
|
||||
label = interface.config["output_components"][i]["label"]
|
||||
filepath = (output.save_flagged(
|
||||
self.dataset_dir, label, output_data[i], None) if
|
||||
output_data[i] is not None else "")
|
||||
filepath = (
|
||||
output.save_flagged(self.dataset_dir, label, output_data[i], None)
|
||||
if output_data[i] is not None
|
||||
else ""
|
||||
)
|
||||
csv_data.append(filepath)
|
||||
if isinstance(component, tuple(file_preview_types)):
|
||||
csv_data.append("{}/resolve/main/{}".format(
|
||||
self.path_to_dataset_repo, filepath))
|
||||
csv_data.append(
|
||||
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
|
||||
)
|
||||
if flag_option is not None:
|
||||
csv_data.append(flag_option)
|
||||
|
||||
|
||||
writer.writerow(csv_data)
|
||||
|
||||
if is_new:
|
||||
json.dump(infos, open(self.infos_file, "w"))
|
||||
|
||||
|
||||
with open(self.log_file, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
|
||||
self.repo.push_to_hub(
|
||||
commit_message="Flagged sample #{}".format(line_count))
|
||||
|
||||
return line_count
|
||||
self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
|
||||
|
||||
return line_count
|
||||
|
625
gradio/inputs.py
625
gradio/inputs.py
File diff suppressed because it is too large
Load Diff
@ -4,40 +4,47 @@ including various methods for constructing an interface and then launching it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import getpass
|
||||
from logging import warning
|
||||
import markdown2 # type: ignore
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
import warnings
|
||||
import webbrowser
|
||||
import weakref
|
||||
import webbrowser
|
||||
from logging import warning
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
from gradio import encryptor, interpretation, networking, queueing, strings, utils # type: ignore
|
||||
from gradio.external import load_interface, load_from_pipeline # type: ignore
|
||||
from gradio.flagging import FlaggingCallback, CSVLogger # type: ignore
|
||||
from gradio.inputs import get_input_instance, InputComponent, State as i_State # type: ignore
|
||||
from gradio.outputs import get_output_instance, OutputComponent, State as o_State # type: ignore
|
||||
import markdown2 # type: ignore
|
||||
|
||||
from gradio import (encryptor, interpretation, networking, # type: ignore
|
||||
queueing, strings, utils)
|
||||
from gradio.external import load_from_pipeline, load_interface # type: ignore
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.inputs import State as i_State # type: ignore
|
||||
from gradio.inputs import get_input_instance
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio.outputs import State as o_State # type: ignore
|
||||
from gradio.outputs import get_output_instance
|
||||
from gradio.process_examples import cache_interface_examples
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import transformers
|
||||
import flask
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
class Interface:
|
||||
"""
|
||||
Gradio interfaces are created by constructing a `Interface` object
|
||||
with a locally-defined function, or with `Interface.load()` with the path
|
||||
with a locally-defined function, or with `Interface.load()` with the path
|
||||
to a repo or by `Interface.from_pipeline()` with a Transformers Pipeline.
|
||||
"""
|
||||
|
||||
# stores references to all currently existing Interface instances
|
||||
instances: weakref.WeakSet = weakref.WeakSet()
|
||||
instances: weakref.WeakSet = weakref.WeakSet()
|
||||
|
||||
@classmethod
|
||||
def get_instances(cls) -> List[Interface]:
|
||||
@ -47,15 +54,17 @@ class Interface:
|
||||
return list(Interface.instances)
|
||||
|
||||
@classmethod
|
||||
def load(cls,
|
||||
name: str,
|
||||
src: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
alias: Optional[str] = None,
|
||||
**kwargs) -> Interface:
|
||||
def load(
|
||||
cls,
|
||||
name: str,
|
||||
src: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
alias: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Interface:
|
||||
"""
|
||||
Class method to construct an Interface from an external source repository, such as huggingface.
|
||||
Parameters:
|
||||
Parameters:
|
||||
name (str): the name of the model (e.g. "gpt2"), can include the `src` as prefix (e.g. "huggingface/gpt2")
|
||||
src (str): the source of the model: `huggingface` or `gradio` (or empty if source is provided as a prefix in `name`)
|
||||
api_key (str): optional api key for use with Hugging Face Model Hub
|
||||
@ -70,14 +79,11 @@ class Interface:
|
||||
return interface
|
||||
|
||||
@classmethod
|
||||
def from_pipeline(
|
||||
cls,
|
||||
pipeline: transformers.Pipeline,
|
||||
**kwargs) -> Interface:
|
||||
def from_pipeline(cls, pipeline: transformers.Pipeline, **kwargs) -> Interface:
|
||||
"""
|
||||
Construct an Interface from a Hugging Face transformers.Pipeline.
|
||||
Parameters:
|
||||
pipeline (transformers.Pipeline):
|
||||
Parameters:
|
||||
pipeline (transformers.Pipeline):
|
||||
Returns:
|
||||
(gradio.Interface): a Gradio Interface object from the given Pipeline
|
||||
"""
|
||||
@ -87,41 +93,42 @@ class Interface:
|
||||
return interface
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn: Callable | List[Callable],
|
||||
inputs: str | InputComponent | List[str | InputComponent] = None,
|
||||
outputs: str | OutputComponent | List[str | OutputComponent] = None,
|
||||
verbose: bool = False,
|
||||
self,
|
||||
fn: Callable | List[Callable],
|
||||
inputs: str | InputComponent | List[str | InputComponent] = None,
|
||||
outputs: str | OutputComponent | List[str | OutputComponent] = None,
|
||||
verbose: bool = False,
|
||||
examples: Optional[List[Any] | List[List[Any]] | str] = None,
|
||||
examples_per_page: int = 10,
|
||||
live: bool = False,
|
||||
layout: str = "unaligned",
|
||||
show_input: bool = True,
|
||||
examples_per_page: int = 10,
|
||||
live: bool = False,
|
||||
layout: str = "unaligned",
|
||||
show_input: bool = True,
|
||||
show_output: bool = True,
|
||||
capture_session: Optional[bool] = None,
|
||||
interpretation: Optional[Callable | str] = None,
|
||||
num_shap: float = 2.0,
|
||||
theme: Optional[str] = None,
|
||||
capture_session: Optional[bool] = None,
|
||||
interpretation: Optional[Callable | str] = None,
|
||||
num_shap: float = 2.0,
|
||||
theme: Optional[str] = None,
|
||||
repeat_outputs_per_model: bool = True,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
article: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
article: Optional[str] = None,
|
||||
thumbnail: Optional[str] = None,
|
||||
css: Optional[str] = None,
|
||||
height=None,
|
||||
width=None,
|
||||
allow_screenshot: bool = True,
|
||||
allow_flagging: Optional[str] = None,
|
||||
flagging_options: List[str]=None,
|
||||
encrypt=None,
|
||||
show_tips=None,
|
||||
flagging_dir: str = "flagged",
|
||||
analytics_enabled: Optional[bool] = None,
|
||||
css: Optional[str] = None,
|
||||
height=None,
|
||||
width=None,
|
||||
allow_screenshot: bool = True,
|
||||
allow_flagging: Optional[str] = None,
|
||||
flagging_options: List[str] = None,
|
||||
encrypt=None,
|
||||
show_tips=None,
|
||||
flagging_dir: str = "flagged",
|
||||
analytics_enabled: Optional[bool] = None,
|
||||
server_name=None,
|
||||
server_port=None,
|
||||
enable_queue=None,
|
||||
enable_queue=None,
|
||||
api_mode=None,
|
||||
flagging_callback: FlaggingCallback = CSVLogger()):
|
||||
flagging_callback: FlaggingCallback = CSVLogger(),
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn (Union[Callable, List[Callable]]): the function to wrap an interface around.
|
||||
@ -133,7 +140,7 @@ class Interface:
|
||||
live (bool): whether the interface should automatically reload on change.
|
||||
layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically.
|
||||
capture_session (bool): DEPRECATED. If True, captures the default graph and session (needed for Tensorflow 1.x)
|
||||
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function.
|
||||
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function.
|
||||
num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. Only applies if interpretation is "shap".
|
||||
title (str): a title for the interface; if provided, appears above the input and output components.
|
||||
description (str): a description for the interface; if provided, appears above the input and output components.
|
||||
@ -147,7 +154,7 @@ class Interface:
|
||||
encrypt (bool): DEPRECATED. If True, flagged data will be encrypted by key provided by creator at launch
|
||||
flagging_dir (str): what to name the dir where flagged data is stored.
|
||||
show_tips (bool): DEPRECATED. if True, will occasionally show tips about new Gradio features
|
||||
enable_queue (bool): DEPRECATED. if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
enable_queue (bool): DEPRECATED. if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
api_mode (bool): DEPRECATED. If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
|
||||
server_name (str): DEPRECATED. Name of the server to use for serving the interface - pass in launch() instead.
|
||||
server_port (int): DEPRECATED. Port of the server to use for serving the interface - pass in launch() instead.
|
||||
@ -163,27 +170,33 @@ class Interface:
|
||||
self.output_components = [get_output_instance(o) for o in outputs]
|
||||
if repeat_outputs_per_model:
|
||||
self.output_components *= len(fn)
|
||||
|
||||
|
||||
if sum(isinstance(i, i_State) for i in self.input_components) > 1:
|
||||
raise ValueError("Only one input component can be State.")
|
||||
if sum(isinstance(o, o_State) for o in self.output_components) > 1:
|
||||
raise ValueError("Only one output component can be State.")
|
||||
|
||||
|
||||
if sum(isinstance(i, i_State) for i in self.input_components) == 1:
|
||||
if len(fn) > 1:
|
||||
raise ValueError(
|
||||
"State cannot be used with multiple functions.")
|
||||
state_param_index = [isinstance(i, i_State)
|
||||
for i in self.input_components].index(True)
|
||||
raise ValueError("State cannot be used with multiple functions.")
|
||||
state_param_index = [
|
||||
isinstance(i, i_State) for i in self.input_components
|
||||
].index(True)
|
||||
state: i_State = self.input_components[state_param_index]
|
||||
if state.default is None:
|
||||
default = utils.get_default_args(fn[0])[state_param_index]
|
||||
state.default = default
|
||||
|
||||
if interpretation is None or isinstance(interpretation, list) or callable(interpretation):
|
||||
if (
|
||||
interpretation is None
|
||||
or isinstance(interpretation, list)
|
||||
or callable(interpretation)
|
||||
):
|
||||
self.interpretation = interpretation
|
||||
elif isinstance(interpretation, str):
|
||||
self.interpretation = [interpretation.lower() for _ in self.input_components]
|
||||
self.interpretation = [
|
||||
interpretation.lower() for _ in self.input_components
|
||||
]
|
||||
else:
|
||||
raise ValueError("Invalid value for parameter: interpretation")
|
||||
|
||||
@ -193,8 +206,10 @@ class Interface:
|
||||
self.__name__ = ", ".join(self.function_names)
|
||||
|
||||
if verbose:
|
||||
warnings.warn("The `verbose` parameter in the `Interface`"
|
||||
"is deprecated and has no effect.")
|
||||
warnings.warn(
|
||||
"The `verbose` parameter in the `Interface`"
|
||||
"is deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.status = "OFF"
|
||||
self.live = live
|
||||
@ -205,12 +220,14 @@ class Interface:
|
||||
self.capture_session = capture_session
|
||||
|
||||
if capture_session is not None:
|
||||
warnings.warn("The `capture_session` parameter in the `Interface`"
|
||||
" is deprecated and may be removed in the future.")
|
||||
warnings.warn(
|
||||
"The `capture_session` parameter in the `Interface`"
|
||||
" is deprecated and may be removed in the future."
|
||||
)
|
||||
try:
|
||||
import tensorflow as tf
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
|
||||
self.session = tf.get_default_graph(), tf.keras.backend.get_session()
|
||||
except (ImportError, AttributeError):
|
||||
# If they are using TF >= 2.0 or don't have TF,
|
||||
# just ignore this parameter.
|
||||
@ -219,80 +236,120 @@ class Interface:
|
||||
if server_name is not None or server_port is not None:
|
||||
raise DeprecationWarning(
|
||||
"The `server_name` and `server_port` parameters in `Interface`"
|
||||
"are deprecated. Please pass into launch() instead.")
|
||||
"are deprecated. Please pass into launch() instead."
|
||||
)
|
||||
|
||||
self.session = None
|
||||
self.title = title
|
||||
self.description = description
|
||||
if article is not None:
|
||||
article = utils.readme_to_html(article)
|
||||
article = markdown2.markdown(
|
||||
article, extras=["fenced-code-blocks"])
|
||||
article = markdown2.markdown(article, extras=["fenced-code-blocks"])
|
||||
|
||||
self.article = article
|
||||
self.thumbnail = thumbnail
|
||||
|
||||
|
||||
theme = theme if theme is not None else os.getenv("GRADIO_THEME", "default")
|
||||
DEPRECATED_THEME_MAP = {"darkdefault": "default", "darkhuggingface": "dark-huggingface", "darkpeach": "dark-peach", "darkgrass": "dark-grass"}
|
||||
VALID_THEME_SET = ("default", "huggingface", "seafoam", "grass", "peach", "dark", "dark-huggingface", "dark-seafoam", "dark-grass", "dark-peach")
|
||||
DEPRECATED_THEME_MAP = {
|
||||
"darkdefault": "default",
|
||||
"darkhuggingface": "dark-huggingface",
|
||||
"darkpeach": "dark-peach",
|
||||
"darkgrass": "dark-grass",
|
||||
}
|
||||
VALID_THEME_SET = (
|
||||
"default",
|
||||
"huggingface",
|
||||
"seafoam",
|
||||
"grass",
|
||||
"peach",
|
||||
"dark",
|
||||
"dark-huggingface",
|
||||
"dark-seafoam",
|
||||
"dark-grass",
|
||||
"dark-peach",
|
||||
)
|
||||
if theme in DEPRECATED_THEME_MAP:
|
||||
warnings.warn(f"'{theme}' theme name is deprecated, using {DEPRECATED_THEME_MAP[theme]} instead.")
|
||||
warnings.warn(
|
||||
f"'{theme}' theme name is deprecated, using {DEPRECATED_THEME_MAP[theme]} instead."
|
||||
)
|
||||
theme = DEPRECATED_THEME_MAP[theme]
|
||||
elif theme not in VALID_THEME_SET:
|
||||
raise ValueError(f"Invalid theme name, theme must be one of: {', '.join(VALID_THEME_SET)}")
|
||||
raise ValueError(
|
||||
f"Invalid theme name, theme must be one of: {', '.join(VALID_THEME_SET)}"
|
||||
)
|
||||
self.theme = theme
|
||||
|
||||
self.height = height
|
||||
self.width = width
|
||||
if self.height is not None or self.width is not None:
|
||||
warnings.warn("The `height` and `width` parameters in `Interface` "
|
||||
"are deprecated and should be passed into launch().")
|
||||
warnings.warn(
|
||||
"The `height` and `width` parameters in `Interface` "
|
||||
"are deprecated and should be passed into launch()."
|
||||
)
|
||||
|
||||
if css is not None and os.path.exists(css):
|
||||
with open(css) as css_file:
|
||||
self.css = css_file.read()
|
||||
else:
|
||||
self.css = css
|
||||
if examples is None or isinstance(examples, str) or (isinstance(
|
||||
examples, list) and (len(examples) == 0 or isinstance(
|
||||
examples[0], list))):
|
||||
if (
|
||||
examples is None
|
||||
or isinstance(examples, str)
|
||||
or (
|
||||
isinstance(examples, list)
|
||||
and (len(examples) == 0 or isinstance(examples[0], list))
|
||||
)
|
||||
):
|
||||
self.examples = examples
|
||||
elif isinstance(examples, list) and len(self.input_components) == 1: # If there is only one input component, examples can be provided as a regular list instead of a list of lists
|
||||
elif (
|
||||
isinstance(examples, list) and len(self.input_components) == 1
|
||||
): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
|
||||
self.examples = [[e] for e in examples]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Examples argument must either be a directory or a nested "
|
||||
"list, where each sublist represents a set of inputs.")
|
||||
"list, where each sublist represents a set of inputs."
|
||||
)
|
||||
self.num_shap = num_shap
|
||||
self.examples_per_page = examples_per_page
|
||||
|
||||
self.simple_server = None
|
||||
self.allow_screenshot = allow_screenshot
|
||||
|
||||
# For analytics_enabled and allow_flagging: (1) first check for
|
||||
|
||||
# For analytics_enabled and allow_flagging: (1) first check for
|
||||
# parameter, (2) check for env variable, (3) default to True/"manual"
|
||||
self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True")=="True"
|
||||
self.analytics_enabled = (
|
||||
analytics_enabled
|
||||
if analytics_enabled is not None
|
||||
else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
|
||||
)
|
||||
if allow_flagging is None:
|
||||
allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual")
|
||||
if allow_flagging==True:
|
||||
warnings.warn("The `allow_flagging` parameter in `Interface` now"
|
||||
"takes a string value ('auto', 'manual', or 'never')"
|
||||
", not a boolean. Setting parameter to: 'manual'.")
|
||||
if allow_flagging == True:
|
||||
warnings.warn(
|
||||
"The `allow_flagging` parameter in `Interface` now"
|
||||
"takes a string value ('auto', 'manual', or 'never')"
|
||||
", not a boolean. Setting parameter to: 'manual'."
|
||||
)
|
||||
self.allow_flagging = "manual"
|
||||
elif allow_flagging=="manual":
|
||||
elif allow_flagging == "manual":
|
||||
self.allow_flagging = "manual"
|
||||
elif allow_flagging==False:
|
||||
warnings.warn("The `allow_flagging` parameter in `Interface` now"
|
||||
"takes a string value ('auto', 'manual', or 'never')"
|
||||
", not a boolean. Setting parameter to: 'never'.")
|
||||
elif allow_flagging == False:
|
||||
warnings.warn(
|
||||
"The `allow_flagging` parameter in `Interface` now"
|
||||
"takes a string value ('auto', 'manual', or 'never')"
|
||||
", not a boolean. Setting parameter to: 'never'."
|
||||
)
|
||||
self.allow_flagging = "never"
|
||||
elif allow_flagging=="never":
|
||||
elif allow_flagging == "never":
|
||||
self.allow_flagging = "never"
|
||||
elif allow_flagging=="auto":
|
||||
elif allow_flagging == "auto":
|
||||
self.allow_flagging = "auto"
|
||||
else:
|
||||
raise ValueError("Invalid value for `allow_flagging` parameter."
|
||||
"Must be: 'auto', 'manual', or 'never'.")
|
||||
raise ValueError(
|
||||
"Invalid value for `allow_flagging` parameter."
|
||||
"Must be: 'auto', 'manual', or 'never'."
|
||||
)
|
||||
|
||||
self.flagging_options = flagging_options
|
||||
self.flagging_callback = flagging_callback
|
||||
@ -305,17 +362,22 @@ class Interface:
|
||||
self.ip_address = utils.get_local_ip_address()
|
||||
|
||||
if show_tips is not None:
|
||||
warnings.warn("The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead")
|
||||
warnings.warn(
|
||||
"The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead"
|
||||
)
|
||||
|
||||
self.requires_permissions = any(
|
||||
[component.requires_permissions for component in self.input_components])
|
||||
[component.requires_permissions for component in self.input_components]
|
||||
)
|
||||
|
||||
self.enable_queue = enable_queue
|
||||
if self.enable_queue is not None:
|
||||
warnings.warn("The `enable_queue` parameter in the `Interface`"
|
||||
"will be deprecated and may not work properly. "
|
||||
"Please use the `enable_queue` parameter in "
|
||||
"`launch()` instead")
|
||||
warnings.warn(
|
||||
"The `enable_queue` parameter in the `Interface`"
|
||||
"will be deprecated and may not work properly. "
|
||||
"Please use the `enable_queue` parameter in "
|
||||
"`launch()` instead"
|
||||
)
|
||||
|
||||
self.favicon_path = None
|
||||
self.height = height
|
||||
@ -324,31 +386,34 @@ class Interface:
|
||||
warnings.warn(
|
||||
"The `width` and `height` parameters in the `Interface` class"
|
||||
"will be deprecated. Please provide these parameters"
|
||||
"in `launch()` instead")
|
||||
"in `launch()` instead"
|
||||
)
|
||||
|
||||
self.encrypt = encrypt
|
||||
if self.encrypt is not None:
|
||||
warnings.warn(
|
||||
"The `encrypt` parameter in the `Interface` class"
|
||||
"will be deprecated. Please provide this parameter"
|
||||
"in `launch()` instead")
|
||||
|
||||
"in `launch()` instead"
|
||||
)
|
||||
|
||||
if api_mode is not None:
|
||||
warnings.warn("The `api_mode` parameter in the `Interface` is deprecated.")
|
||||
self.api_mode = False
|
||||
|
||||
data = {'fn': fn,
|
||||
'inputs': inputs,
|
||||
'outputs': outputs,
|
||||
'live': live,
|
||||
'capture_session': capture_session,
|
||||
'ip_address': self.ip_address,
|
||||
'interpretation': interpretation,
|
||||
'allow_flagging': allow_flagging,
|
||||
'allow_screenshot': allow_screenshot,
|
||||
'custom_css': self.css is not None,
|
||||
'theme': self.theme
|
||||
}
|
||||
data = {
|
||||
"fn": fn,
|
||||
"inputs": inputs,
|
||||
"outputs": outputs,
|
||||
"live": live,
|
||||
"capture_session": capture_session,
|
||||
"ip_address": self.ip_address,
|
||||
"interpretation": interpretation,
|
||||
"allow_flagging": allow_flagging,
|
||||
"allow_screenshot": allow_screenshot,
|
||||
"custom_css": self.css is not None,
|
||||
"theme": self.theme,
|
||||
}
|
||||
|
||||
if self.analytics_enabled:
|
||||
utils.initiated_analytics(data)
|
||||
@ -358,7 +423,9 @@ class Interface:
|
||||
Interface.instances.add(self)
|
||||
|
||||
def __call__(self, *params):
|
||||
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
|
||||
if (
|
||||
self.api_mode
|
||||
): # skip the preprocessing/postprocessing if sending to a remote API
|
||||
output = self.run_prediction(params, called_directly=True)
|
||||
else:
|
||||
output, _ = self.process(params)
|
||||
@ -369,7 +436,8 @@ class Interface:
|
||||
|
||||
def __repr__(self):
|
||||
repr = "Gradio Interface for: {}".format(
|
||||
", ".join(fn.__name__ for fn in self.predict))
|
||||
", ".join(fn.__name__ for fn in self.predict)
|
||||
)
|
||||
repr += "\n" + "-" * len(repr)
|
||||
repr += "\ninputs:"
|
||||
for component in self.input_components:
|
||||
@ -381,28 +449,30 @@ class Interface:
|
||||
|
||||
def get_config_file(self):
|
||||
return utils.get_config_file(self)
|
||||
|
||||
|
||||
def run_prediction(
|
||||
self,
|
||||
processed_input: List[Any],
|
||||
return_duration: bool = False,
|
||||
called_directly: bool = False
|
||||
self,
|
||||
processed_input: List[Any],
|
||||
return_duration: bool = False,
|
||||
called_directly: bool = False,
|
||||
) -> List[Any] | Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
Runs the prediction function with the given (already processed) inputs.
|
||||
Parameters:
|
||||
processed_input (list): A list of processed inputs.
|
||||
return_duration (bool): Whether to return the duration of the prediction.
|
||||
called_directly (bool): Whether the prediction is being called
|
||||
called_directly (bool): Whether the prediction is being called
|
||||
directly (i.e. as a function, not through the GUI).
|
||||
Returns:
|
||||
predictions (list): A list of predictions (not post-processed).
|
||||
durations (list): A list of durations for each prediction
|
||||
durations (list): A list of durations for each prediction
|
||||
(only returned if `return_duration` is True).
|
||||
"""
|
||||
if self.api_mode: # Serialize the input
|
||||
processed_input = [input_component.serialize(processed_input[i], called_directly)
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
processed_input = [
|
||||
input_component.serialize(processed_input[i], called_directly)
|
||||
for i, input_component in enumerate(self.input_components)
|
||||
]
|
||||
predictions = []
|
||||
durations = []
|
||||
output_component_counter = 0
|
||||
@ -423,8 +493,16 @@ class Interface:
|
||||
if self.api_mode: # Serialize the input
|
||||
prediction_ = copy.deepcopy(prediction)
|
||||
prediction = []
|
||||
for pred in prediction_: # Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
|
||||
prediction.append(self.output_components[output_component_counter].deserialize(pred))
|
||||
for (
|
||||
pred
|
||||
) in (
|
||||
prediction_
|
||||
): # Done this way to handle both single interfaces with multiple outputs and Parallel() interfaces
|
||||
prediction.append(
|
||||
self.output_components[output_component_counter].deserialize(
|
||||
pred
|
||||
)
|
||||
)
|
||||
output_component_counter += 1
|
||||
|
||||
durations.append(duration)
|
||||
@ -435,12 +513,9 @@ class Interface:
|
||||
else:
|
||||
return predictions
|
||||
|
||||
def process(
|
||||
self,
|
||||
raw_input: List[Any]
|
||||
) -> Tuple[List[Any], List[float]]:
|
||||
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
First preprocesses the input, then runs prediction using
|
||||
First preprocesses the input, then runs prediction using
|
||||
self.run_prediction(), then postprocesses the output.
|
||||
Parameters:
|
||||
raw_input: a list of raw inputs to process and apply the prediction(s) on.
|
||||
@ -448,33 +523,37 @@ class Interface:
|
||||
processed output: a list of processed outputs to return as the prediction(s).
|
||||
duration: a list of time deltas measuring inference time for each prediction fn.
|
||||
"""
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(
|
||||
self.input_components)]
|
||||
processed_input = [
|
||||
input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)
|
||||
]
|
||||
predictions, durations = self.run_prediction(
|
||||
processed_input, return_duration=True)
|
||||
processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
|
||||
for i, output_component in enumerate(self.output_components)]
|
||||
processed_input, return_duration=True
|
||||
)
|
||||
processed_output = [
|
||||
output_component.postprocess(predictions[i])
|
||||
if predictions[i] is not None
|
||||
else None
|
||||
for i, output_component in enumerate(self.output_components)
|
||||
]
|
||||
|
||||
avg_durations = []
|
||||
for i, duration in enumerate(durations):
|
||||
self.predict_durations[i][0] += duration
|
||||
self.predict_durations[i][1] += 1
|
||||
avg_durations.append(self.predict_durations[i][0]
|
||||
/ self.predict_durations[i][1])
|
||||
avg_durations.append(
|
||||
self.predict_durations[i][0] / self.predict_durations[i][1]
|
||||
)
|
||||
if hasattr(self, "config"):
|
||||
self.config["avg_durations"] = avg_durations
|
||||
|
||||
|
||||
return processed_output, durations
|
||||
|
||||
def interpret(
|
||||
self,
|
||||
raw_input: List[Any]
|
||||
) -> List[Any]:
|
||||
|
||||
def interpret(self, raw_input: List[Any]) -> List[Any]:
|
||||
return interpretation.run_interpret(self, raw_input)
|
||||
|
||||
def block_thread(
|
||||
self,
|
||||
self,
|
||||
) -> None:
|
||||
"""Block main thread until interrupted by user."""
|
||||
try:
|
||||
@ -488,10 +567,10 @@ class Interface:
|
||||
|
||||
def test_launch(self) -> None:
|
||||
for predict_fn in self.predict:
|
||||
print("Test launch: {}()...".format(predict_fn.__name__), end=' ')
|
||||
print("Test launch: {}()...".format(predict_fn.__name__), end=" ")
|
||||
raw_input = []
|
||||
for input_component in self.input_components:
|
||||
if input_component.test_input is None:
|
||||
if input_component.test_input is None:
|
||||
print("SKIPPED")
|
||||
break
|
||||
else:
|
||||
@ -502,22 +581,22 @@ class Interface:
|
||||
continue
|
||||
|
||||
def launch(
|
||||
self,
|
||||
inline: bool = None,
|
||||
inbrowser: bool = None,
|
||||
share: bool = False,
|
||||
self,
|
||||
inline: bool = None,
|
||||
inbrowser: bool = None,
|
||||
share: bool = False,
|
||||
debug: bool = False,
|
||||
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
|
||||
auth_message: Optional[str] = None,
|
||||
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
|
||||
auth_message: Optional[str] = None,
|
||||
private_endpoint: Optional[str] = None,
|
||||
prevent_thread_lock: bool = False,
|
||||
show_error: bool = True,
|
||||
prevent_thread_lock: bool = False,
|
||||
show_error: bool = True,
|
||||
server_name: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
show_tips: bool = False,
|
||||
server_port: Optional[int] = None,
|
||||
show_tips: bool = False,
|
||||
enable_queue: bool = False,
|
||||
height: int = 500,
|
||||
width: int = 900,
|
||||
height: int = 500,
|
||||
width: int = 900,
|
||||
encrypt: bool = False,
|
||||
cache_examples: bool = False,
|
||||
favicon_path: Optional[str] = None,
|
||||
@ -537,7 +616,7 @@ class Interface:
|
||||
server_port (int): will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
|
||||
server_name (str): to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
|
||||
show_tips (bool): if True, will occasionally show tips about new Gradio features
|
||||
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
width (int): The width in pixels of the <iframe> element containing the interface (used if inline=True)
|
||||
height (int): The height in pixels of the <iframe> element containing the interface (used if inline=True)
|
||||
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
|
||||
@ -548,10 +627,14 @@ class Interface:
|
||||
path_to_local_server (str): Locally accessible link
|
||||
share_url (str): Publicly accessible link (if share=True)
|
||||
"""
|
||||
self.config = self.get_config_file()
|
||||
self.config = self.get_config_file()
|
||||
self.cache_examples = cache_examples
|
||||
if auth and not callable(auth) and not isinstance(
|
||||
auth[0], tuple) and not isinstance(auth[0], list):
|
||||
if (
|
||||
auth
|
||||
and not callable(auth)
|
||||
and not isinstance(auth[0], tuple)
|
||||
and not isinstance(auth[0], list)
|
||||
):
|
||||
auth = [auth]
|
||||
self.auth = auth
|
||||
self.auth_message = auth_message
|
||||
@ -560,12 +643,13 @@ class Interface:
|
||||
self.height = self.height or height
|
||||
self.width = self.width or width
|
||||
self.favicon_path = favicon_path
|
||||
|
||||
|
||||
if self.encrypt is None:
|
||||
self.encrypt = encrypt
|
||||
self.encrypt = encrypt
|
||||
if self.encrypt:
|
||||
self.encryption_key = encryptor.get_key(
|
||||
getpass.getpass("Enter key for encryption: "))
|
||||
getpass.getpass("Enter key for encryption: ")
|
||||
)
|
||||
|
||||
if self.enable_queue is None:
|
||||
self.enable_queue = enable_queue
|
||||
@ -579,8 +663,9 @@ class Interface:
|
||||
cache_interface_examples(self)
|
||||
|
||||
server_port, path_to_local_server, app, server = networking.start_server(
|
||||
self, server_name, server_port)
|
||||
|
||||
self, server_name, server_port
|
||||
)
|
||||
|
||||
self.local_url = path_to_local_server
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
@ -589,7 +674,7 @@ class Interface:
|
||||
|
||||
utils.launch_counter()
|
||||
|
||||
# If running in a colab or not able to access localhost,
|
||||
# If running in a colab or not able to access localhost,
|
||||
# automatically create a shareable link.
|
||||
is_colab = utils.colab_check()
|
||||
if is_colab or not (networking.url_ok(path_to_local_server)):
|
||||
@ -607,11 +692,10 @@ class Interface:
|
||||
if private_endpoint is not None:
|
||||
share = True
|
||||
self.share = share
|
||||
|
||||
|
||||
if share:
|
||||
try:
|
||||
share_url = networking.setup_tunnel(
|
||||
server_port, private_endpoint)
|
||||
share_url = networking.setup_tunnel(server_port, private_endpoint)
|
||||
self.share_url = share_url
|
||||
print(strings.en["SHARE_LINK_DISPLAY"].format(share_url))
|
||||
if private_endpoint:
|
||||
@ -620,8 +704,7 @@ class Interface:
|
||||
print(strings.en["SHARE_LINK_MESSAGE"])
|
||||
except RuntimeError:
|
||||
if self.analytics_enabled:
|
||||
utils.error_analytics(self.ip_address,
|
||||
"Not able to set up tunnel")
|
||||
utils.error_analytics(self.ip_address, "Not able to set up tunnel")
|
||||
share_url = None
|
||||
else:
|
||||
print(strings.en["PUBLIC_SHARE_TRUE"])
|
||||
@ -636,31 +719,37 @@ class Interface:
|
||||
inline = utils.ipython_check() and (auth is None)
|
||||
if inline:
|
||||
if auth is not None:
|
||||
print("Warning: authentication is not supported inline. Please"
|
||||
"click the link to access the interface in a new tab.")
|
||||
print(
|
||||
"Warning: authentication is not supported inline. Please"
|
||||
"click the link to access the interface in a new tab."
|
||||
)
|
||||
try:
|
||||
from IPython.display import IFrame, display # type: ignore
|
||||
|
||||
if share:
|
||||
while not networking.url_ok(share_url):
|
||||
time.sleep(1)
|
||||
display(IFrame(share_url, width=self.width, height=self.height))
|
||||
else:
|
||||
display(IFrame(path_to_local_server,
|
||||
width=self.width, height=self.height))
|
||||
display(
|
||||
IFrame(
|
||||
path_to_local_server, width=self.width, height=self.height
|
||||
)
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
data = {
|
||||
'launch_method': 'browser' if inbrowser else 'inline',
|
||||
'is_google_colab': is_colab,
|
||||
'is_sharing_on': share,
|
||||
'share_url': share_url,
|
||||
'ip_address': self.ip_address,
|
||||
'enable_queue': self.enable_queue,
|
||||
'show_tips': self.show_tips,
|
||||
'api_mode': self.api_mode,
|
||||
'server_name': server_name,
|
||||
'server_port': server_port,
|
||||
"launch_method": "browser" if inbrowser else "inline",
|
||||
"is_google_colab": is_colab,
|
||||
"is_sharing_on": share,
|
||||
"share_url": share_url,
|
||||
"ip_address": self.ip_address,
|
||||
"enable_queue": self.enable_queue,
|
||||
"show_tips": self.show_tips,
|
||||
"api_mode": self.api_mode,
|
||||
"server_name": server_name,
|
||||
"server_port": server_port,
|
||||
}
|
||||
if self.analytics_enabled:
|
||||
utils.launch_analytics(data)
|
||||
@ -668,46 +757,36 @@ class Interface:
|
||||
utils.show_tip(self)
|
||||
|
||||
# Block main thread if debug==True
|
||||
if debug or int(os.getenv('GRADIO_DEBUG', 0)) == 1:
|
||||
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
|
||||
self.block_thread()
|
||||
# Block main thread if running in a script to stop script from exiting
|
||||
is_in_interactive_mode = bool(
|
||||
getattr(sys, 'ps1', sys.flags.interactive))
|
||||
# Block main thread if running in a script to stop script from exiting
|
||||
is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
|
||||
if not prevent_thread_lock and not is_in_interactive_mode:
|
||||
self.block_thread()
|
||||
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
def close(
|
||||
self,
|
||||
verbose: bool = True
|
||||
) -> None:
|
||||
def close(self, verbose: bool = True) -> None:
|
||||
"""
|
||||
Closes the Interface that was launched and frees the port.
|
||||
"""
|
||||
try:
|
||||
self.server.close()
|
||||
if verbose:
|
||||
print("Closing server running on port: {}".format(
|
||||
self.server_port))
|
||||
print("Closing server running on port: {}".format(self.server_port))
|
||||
except (AttributeError, OSError): # can't close if not running
|
||||
pass
|
||||
|
||||
def integrate(
|
||||
self,
|
||||
comet_ml=None,
|
||||
wandb=None,
|
||||
mlflow=None
|
||||
) -> None:
|
||||
def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None:
|
||||
"""
|
||||
A catch-all method for integrating with other libraries.
|
||||
A catch-all method for integrating with other libraries.
|
||||
Should be run after launch()
|
||||
Parameters:
|
||||
comet_ml (Experiment): If a comet_ml Experiment object is provided,
|
||||
comet_ml (Experiment): If a comet_ml Experiment object is provided,
|
||||
will integrate with the experiment and appear on Comet dashboard
|
||||
wandb (module): If the wandb module is provided, will integrate
|
||||
wandb (module): If the wandb module is provided, will integrate
|
||||
with it and appear on WandB dashboard
|
||||
mlflow (module): If the mlflow module is provided, will integrate
|
||||
mlflow (module): If the mlflow module is provided, will integrate
|
||||
with the experiment and appear on ML Flow dashboard
|
||||
"""
|
||||
analytics_integration = ""
|
||||
@ -723,24 +802,32 @@ class Interface:
|
||||
if wandb is not None:
|
||||
analytics_integration = "WandB"
|
||||
if self.share_url is not None:
|
||||
wandb.log({"Gradio panel": wandb.Html(
|
||||
'<iframe src="' + self.share_url + '" width="' +
|
||||
str(self.width) + '" height="' + str(self.height) +
|
||||
'" frameBorder="0"></iframe>')})
|
||||
wandb.log(
|
||||
{
|
||||
"Gradio panel": wandb.Html(
|
||||
'<iframe src="'
|
||||
+ self.share_url
|
||||
+ '" width="'
|
||||
+ str(self.width)
|
||||
+ '" height="'
|
||||
+ str(self.height)
|
||||
+ '" frameBorder="0"></iframe>'
|
||||
)
|
||||
}
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"The WandB integration requires you to "
|
||||
"`launch(share=True)` first.")
|
||||
"`launch(share=True)` first."
|
||||
)
|
||||
if mlflow is not None:
|
||||
analytics_integration = "MLFlow"
|
||||
if self.share_url is not None:
|
||||
mlflow.log_param("Gradio Interface Share Link",
|
||||
self.share_url)
|
||||
mlflow.log_param("Gradio Interface Share Link", self.share_url)
|
||||
else:
|
||||
mlflow.log_param("Gradio Interface Local Link",
|
||||
self.local_url)
|
||||
mlflow.log_param("Gradio Interface Local Link", self.local_url)
|
||||
if self.analytics_enabled and analytics_integration:
|
||||
data = {'integration': analytics_integration}
|
||||
data = {"integration": analytics_integration}
|
||||
utils.integration_analytics(data)
|
||||
|
||||
|
||||
@ -750,6 +837,8 @@ def close_all(verbose: bool = True) -> None:
|
||||
|
||||
|
||||
def reset_all() -> None:
|
||||
warnings.warn("The `reset_all()` method has been renamed to `close_all()` "
|
||||
"and will be deprecated. Please use `close_all()` instead.")
|
||||
warnings.warn(
|
||||
"The `reset_all()` method has been renamed to `close_all()` "
|
||||
"and will be deprecated. Please use `close_all()` instead."
|
||||
)
|
||||
close_all()
|
||||
|
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gradio.outputs import Label, Textbox
|
||||
@ -13,8 +14,10 @@ def run_interpret(interface, raw_input):
|
||||
raw_input: a list of raw inputs to apply the interpretation(s) on.
|
||||
"""
|
||||
if isinstance(interface.interpretation, list): # Either "default" or "shap"
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(interface.input_components)]
|
||||
processed_input = [
|
||||
input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(interface.input_components)
|
||||
]
|
||||
original_output = interface.run_prediction(processed_input)
|
||||
scores, alternative_outputs = [], []
|
||||
for i, (x, interp) in enumerate(zip(raw_input, interface.interpretation)):
|
||||
@ -22,104 +25,147 @@ def run_interpret(interface, raw_input):
|
||||
input_component = interface.input_components[i]
|
||||
neighbor_raw_input = list(raw_input)
|
||||
if input_component.interpret_by_tokens:
|
||||
tokens, neighbor_values, masks = input_component.tokenize(
|
||||
x)
|
||||
tokens, neighbor_values, masks = input_component.tokenize(x)
|
||||
interface_scores = []
|
||||
alternative_output = []
|
||||
for neighbor_input in neighbor_values:
|
||||
neighbor_raw_input[i] = neighbor_input
|
||||
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
|
||||
for i, input_component in enumerate(interface.input_components)]
|
||||
processed_neighbor_input = [
|
||||
input_component.preprocess(neighbor_raw_input[i])
|
||||
for i, input_component in enumerate(
|
||||
interface.input_components
|
||||
)
|
||||
]
|
||||
neighbor_output = interface.run_prediction(
|
||||
processed_neighbor_input)
|
||||
processed_neighbor_output = [output_component.postprocess(
|
||||
neighbor_output[i]) for i, output_component in enumerate(interface.output_components)]
|
||||
processed_neighbor_input
|
||||
)
|
||||
processed_neighbor_output = [
|
||||
output_component.postprocess(neighbor_output[i])
|
||||
for i, output_component in enumerate(
|
||||
interface.output_components
|
||||
)
|
||||
]
|
||||
|
||||
alternative_output.append(
|
||||
processed_neighbor_output)
|
||||
interface_scores.append(quantify_difference_in_label(
|
||||
interface, original_output, neighbor_output))
|
||||
alternative_output.append(processed_neighbor_output)
|
||||
interface_scores.append(
|
||||
quantify_difference_in_label(
|
||||
interface, original_output, neighbor_output
|
||||
)
|
||||
)
|
||||
alternative_outputs.append(alternative_output)
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
|
||||
raw_input[i],
|
||||
neighbor_values,
|
||||
interface_scores,
|
||||
masks=masks,
|
||||
tokens=tokens,
|
||||
)
|
||||
)
|
||||
else:
|
||||
neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(
|
||||
x)
|
||||
(
|
||||
neighbor_values,
|
||||
interpret_kwargs,
|
||||
) = input_component.get_interpretation_neighbors(x)
|
||||
interface_scores = []
|
||||
alternative_output = []
|
||||
for neighbor_input in neighbor_values:
|
||||
neighbor_raw_input[i] = neighbor_input
|
||||
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
|
||||
for i, input_component in enumerate(interface.input_components)]
|
||||
processed_neighbor_input = [
|
||||
input_component.preprocess(neighbor_raw_input[i])
|
||||
for i, input_component in enumerate(
|
||||
interface.input_components
|
||||
)
|
||||
]
|
||||
neighbor_output = interface.run_prediction(
|
||||
processed_neighbor_input)
|
||||
processed_neighbor_output = [output_component.postprocess(
|
||||
neighbor_output[i]) for i, output_component in enumerate(interface.output_components)]
|
||||
processed_neighbor_input
|
||||
)
|
||||
processed_neighbor_output = [
|
||||
output_component.postprocess(neighbor_output[i])
|
||||
for i, output_component in enumerate(
|
||||
interface.output_components
|
||||
)
|
||||
]
|
||||
|
||||
alternative_output.append(
|
||||
processed_neighbor_output)
|
||||
interface_scores.append(quantify_difference_in_label(
|
||||
interface, original_output, neighbor_output))
|
||||
alternative_output.append(processed_neighbor_output)
|
||||
interface_scores.append(
|
||||
quantify_difference_in_label(
|
||||
interface, original_output, neighbor_output
|
||||
)
|
||||
)
|
||||
alternative_outputs.append(alternative_output)
|
||||
interface_scores = [-score for score in interface_scores]
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
|
||||
raw_input[i],
|
||||
neighbor_values,
|
||||
interface_scores,
|
||||
**interpret_kwargs
|
||||
)
|
||||
)
|
||||
elif interp == "shap" or interp == "shapley":
|
||||
try:
|
||||
import shap # type: ignore
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ValueError(
|
||||
"The package `shap` is required for this interpretation method. Try: `pip install shap`")
|
||||
"The package `shap` is required for this interpretation method. Try: `pip install shap`"
|
||||
)
|
||||
input_component = interface.input_components[i]
|
||||
if not (input_component.interpret_by_tokens):
|
||||
raise ValueError(
|
||||
"Input component {} does not support `shap` interpretation".format(input_component))
|
||||
"Input component {} does not support `shap` interpretation".format(
|
||||
input_component
|
||||
)
|
||||
)
|
||||
|
||||
tokens, _, masks = input_component.tokenize(x)
|
||||
|
||||
# construct a masked version of the input
|
||||
def get_masked_prediction(binary_mask):
|
||||
masked_xs = input_component.get_masked_inputs(
|
||||
tokens, binary_mask)
|
||||
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
|
||||
preds = []
|
||||
for masked_x in masked_xs:
|
||||
processed_masked_input = copy.deepcopy(
|
||||
processed_input)
|
||||
processed_masked_input[i] = input_component.preprocess(
|
||||
masked_x)
|
||||
new_output = interface.run_prediction(
|
||||
processed_masked_input)
|
||||
processed_masked_input = copy.deepcopy(processed_input)
|
||||
processed_masked_input[i] = input_component.preprocess(masked_x)
|
||||
new_output = interface.run_prediction(processed_masked_input)
|
||||
pred = get_regression_or_classification_value(
|
||||
interface, original_output, new_output)
|
||||
interface, original_output, new_output
|
||||
)
|
||||
preds.append(pred)
|
||||
return np.array(preds)
|
||||
|
||||
num_total_segments = len(tokens)
|
||||
explainer = shap.KernelExplainer(
|
||||
get_masked_prediction, np.zeros((1, num_total_segments)))
|
||||
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(
|
||||
interface.num_shap * num_total_segments), silent=True)
|
||||
scores.append(input_component.get_interpretation_scores(
|
||||
raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
|
||||
get_masked_prediction, np.zeros((1, num_total_segments))
|
||||
)
|
||||
shap_values = explainer.shap_values(
|
||||
np.ones((1, num_total_segments)),
|
||||
nsamples=int(interface.num_shap * num_total_segments),
|
||||
silent=True,
|
||||
)
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], None, shap_values[0], masks=masks, tokens=tokens
|
||||
)
|
||||
)
|
||||
alternative_outputs.append([])
|
||||
elif interp is None:
|
||||
scores.append(None)
|
||||
alternative_outputs.append([])
|
||||
else:
|
||||
raise ValueError(
|
||||
"Uknown intepretation method: {}".format(interp))
|
||||
raise ValueError("Uknown intepretation method: {}".format(interp))
|
||||
return scores, alternative_outputs
|
||||
else: # custom interpretation function
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(interface.input_components)]
|
||||
processed_input = [
|
||||
input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(interface.input_components)
|
||||
]
|
||||
interpreter = interface.interpretation
|
||||
if interface.capture_session and interface.session is not None:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default(), sess.as_default():
|
||||
interpretation = interpreter(*processed_input)
|
||||
else:
|
||||
else:
|
||||
interpretation = interpreter(*processed_input)
|
||||
if len(raw_input) == 1:
|
||||
interpretation = [interpretation]
|
||||
@ -130,7 +176,7 @@ def diff(original, perturbed):
|
||||
try: # try computing numerical difference
|
||||
score = float(original) - float(perturbed)
|
||||
except ValueError: # otherwise, look at strict difference in label
|
||||
score = int(not(original == perturbed))
|
||||
score = int(not (original == perturbed))
|
||||
return score
|
||||
|
||||
|
||||
@ -155,12 +201,18 @@ def quantify_difference_in_label(interface, original_output, perturbed_output):
|
||||
elif isinstance(output_component, Textbox):
|
||||
score = diff(post_original_output, post_perturbed_output)
|
||||
return score
|
||||
|
||||
|
||||
else:
|
||||
raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
|
||||
raise ValueError(
|
||||
"This interpretation method doesn't support the Output component: {}".format(
|
||||
output_component
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_regression_or_classification_value(interface, original_output, perturbed_output):
|
||||
def get_regression_or_classification_value(
|
||||
interface, original_output, perturbed_output
|
||||
):
|
||||
"""Used to combine regression/classification for Shap interpretation method."""
|
||||
output_component = interface.output_components[0]
|
||||
post_original_output = output_component.postprocess(original_output[0])
|
||||
@ -176,9 +228,14 @@ def get_regression_or_classification_value(interface, original_output, perturbed
|
||||
return 0
|
||||
return perturbed_output[0][original_label]
|
||||
else:
|
||||
score = diff(perturbed_label, original_label) # Intentionally inverted order of arguments.
|
||||
score = diff(
|
||||
perturbed_label, original_label
|
||||
) # Intentionally inverted order of arguments.
|
||||
return score
|
||||
|
||||
else:
|
||||
raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
|
||||
|
||||
raise ValueError(
|
||||
"This interpretation method doesn't support the Output component: {}".format(
|
||||
output_component
|
||||
)
|
||||
)
|
||||
|
@ -3,19 +3,21 @@ Ways to transform interfaces to produce new interfaces
|
||||
"""
|
||||
import gradio
|
||||
|
||||
|
||||
class Parallel(gradio.Interface):
|
||||
"""
|
||||
Creates a new Interface consisting of multiple models in parallel
|
||||
Parameters:
|
||||
interfaces: any number of Interface objects that are to be compared in parallel
|
||||
options: additional kwargs that are passed into the new Interface object to customize it
|
||||
Parameters:
|
||||
interfaces: any number of Interface objects that are to be compared in parallel
|
||||
options: additional kwargs that are passed into the new Interface object to customize it
|
||||
Returns:
|
||||
(Interface): an Interface object comparing the given models
|
||||
"""
|
||||
|
||||
def __init__(self, *interfaces, **options):
|
||||
fns = []
|
||||
outputs = []
|
||||
|
||||
|
||||
for io in interfaces:
|
||||
fns.extend(io.predict)
|
||||
outputs.extend(io.output_components)
|
||||
@ -27,27 +29,35 @@ class Parallel(gradio.Interface):
|
||||
"repeat_outputs_per_model": False,
|
||||
}
|
||||
kwargs.update(options)
|
||||
super().__init__(**kwargs)
|
||||
self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute
|
||||
super().__init__(**kwargs)
|
||||
self.api_mode = (
|
||||
interfaces[0].api_mode,
|
||||
) # TODO(abidlabs): make api_mode a per-function attribute
|
||||
|
||||
|
||||
class Series(gradio.Interface):
|
||||
"""
|
||||
Creates a new Interface from multiple models in series (the output of one is fed as the input to the next)
|
||||
Parameters:
|
||||
interfaces: any number of Interface objects that are to be connected in series
|
||||
options: additional kwargs that are passed into the new Interface object to customize it
|
||||
Parameters:
|
||||
interfaces: any number of Interface objects that are to be connected in series
|
||||
options: additional kwargs that are passed into the new Interface object to customize it
|
||||
Returns:
|
||||
(Interface): an Interface object connecting the given models
|
||||
"""
|
||||
|
||||
def __init__(self, *interfaces, **options):
|
||||
fns = [io.predict for io in interfaces]
|
||||
|
||||
def connected_fn(*data): # Run each function with the appropriate preprocessing and postprocessing
|
||||
|
||||
def connected_fn(
|
||||
*data,
|
||||
): # Run each function with the appropriate preprocessing and postprocessing
|
||||
for idx, io in enumerate(interfaces):
|
||||
# skip preprocessing for first interface since the Series interface will include it
|
||||
if idx > 0 and not(io.api_mode):
|
||||
data = [input_component.preprocess(data[i]) for i, input_component in enumerate(io.input_components)]
|
||||
if idx > 0 and not (io.api_mode):
|
||||
data = [
|
||||
input_component.preprocess(data[i])
|
||||
for i, input_component in enumerate(io.input_components)
|
||||
]
|
||||
|
||||
# run all of predictions sequentially
|
||||
predictions = []
|
||||
@ -56,11 +66,14 @@ class Series(gradio.Interface):
|
||||
predictions.append(prediction)
|
||||
data = predictions
|
||||
# skip postprocessing for final interface since the Series interface will include it
|
||||
if idx < len(interfaces) - 1 and not(io.api_mode):
|
||||
data = [output_component.postprocess(data[i]) for i, output_component in enumerate(io.output_components)]
|
||||
if idx < len(interfaces) - 1 and not (io.api_mode):
|
||||
data = [
|
||||
output_component.postprocess(data[i])
|
||||
for i, output_component in enumerate(io.output_components)
|
||||
]
|
||||
|
||||
return data[0]
|
||||
|
||||
|
||||
connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns])
|
||||
|
||||
kwargs = {
|
||||
@ -69,6 +82,7 @@ class Series(gradio.Interface):
|
||||
"outputs": interfaces[-1].output_components,
|
||||
}
|
||||
kwargs.update(options)
|
||||
super().__init__(**kwargs)
|
||||
self.api_mode = interfaces[0].api_mode, # TODO(abidlabs): make api_mode a per-function attribute
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.api_mode = (
|
||||
interfaces[0].api_mode,
|
||||
) # TODO(abidlabs): make api_mode a per-function attribute
|
||||
|
@ -3,32 +3,34 @@ Defines helper methods useful for setting up ports, launching servers, and
|
||||
creating tunnels.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import fastapi
|
||||
|
||||
import http
|
||||
import json
|
||||
import os
|
||||
import requests
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import fastapi
|
||||
import requests
|
||||
import uvicorn
|
||||
|
||||
from gradio import queueing
|
||||
from gradio.tunneling import create_tunnel
|
||||
from gradio.app import app
|
||||
from gradio.tunneling import create_tunnel
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio import Interface
|
||||
|
||||
|
||||
# By default, the local server will try to open on localhost, port 7860.
|
||||
# By default, the local server will try to open on localhost, port 7860.
|
||||
# If that is not available, then it will try 7861, 7862, ... 7959.
|
||||
INITIAL_PORT_VALUE = int(os.getenv('GRADIO_SERVER_PORT', "7860"))
|
||||
TRY_NUM_PORTS = int(os.getenv('GRADIO_NUM_PORTS', "100"))
|
||||
LOCALHOST_NAME = os.getenv('GRADIO_SERVER_NAME', "127.0.0.1")
|
||||
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
|
||||
TRY_NUM_PORTS = int(os.getenv("GRADIO_NUM_PORTS", "100"))
|
||||
LOCALHOST_NAME = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")
|
||||
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
|
||||
|
||||
|
||||
@ -47,9 +49,7 @@ class Server(uvicorn.Server):
|
||||
self.thread.join()
|
||||
|
||||
|
||||
def get_first_available_port(
|
||||
initial: int,
|
||||
final: int) -> int:
|
||||
def get_first_available_port(initial: int, final: int) -> int:
|
||||
"""
|
||||
Gets the first open port in a specified range of port numbers
|
||||
Parameters:
|
||||
@ -81,8 +81,8 @@ def queue_thread(path_to_local_server, test_mode=False):
|
||||
_, hash, input_data, task_type = next_job
|
||||
queueing.start_job(hash)
|
||||
response = requests.post(
|
||||
path_to_local_server + "api/" + task_type + "/",
|
||||
json=input_data)
|
||||
path_to_local_server + "api/" + task_type + "/", json=input_data
|
||||
)
|
||||
if response.status_code == 200:
|
||||
queueing.pass_job(hash, response.json())
|
||||
else:
|
||||
@ -97,9 +97,9 @@ def queue_thread(path_to_local_server, test_mode=False):
|
||||
|
||||
|
||||
def start_server(
|
||||
interface: Interface,
|
||||
server_name: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
interface: Interface,
|
||||
server_name: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
) -> Tuple[int, str, fastapi.FastAPI, threading.Thread, None]:
|
||||
"""Launches a local server running the provided Interface
|
||||
Parameters:
|
||||
@ -107,20 +107,24 @@ def start_server(
|
||||
server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
|
||||
server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
|
||||
auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
|
||||
"""
|
||||
"""
|
||||
server_name = server_name or LOCALHOST_NAME
|
||||
# if port is not specified, search for first available port
|
||||
if server_port is None:
|
||||
if server_port is None:
|
||||
port = get_first_available_port(
|
||||
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
|
||||
)
|
||||
else:
|
||||
try:
|
||||
s = socket.socket()
|
||||
s.bind((LOCALHOST_NAME, server_port))
|
||||
s.bind((LOCALHOST_NAME, server_port))
|
||||
s.close()
|
||||
except OSError:
|
||||
raise OSError("Port {} is in use. If a gradio.Interface is running on the port, you can close() it or gradio.close_all().".format(server_port))
|
||||
raise OSError(
|
||||
"Port {} is in use. If a gradio.Interface is running on the port, you can close() it or gradio.close_all().".format(
|
||||
server_port
|
||||
)
|
||||
)
|
||||
port = server_port
|
||||
|
||||
url_host_name = "localhost" if server_name == "0.0.0.0" else server_name
|
||||
@ -137,27 +141,28 @@ def start_server(
|
||||
app.cwd = os.getcwd()
|
||||
app.favicon_path = interface.favicon_path
|
||||
app.tokens = {}
|
||||
|
||||
|
||||
if app.interface.enable_queue:
|
||||
if auth is not None or app.interface.encrypt:
|
||||
raise ValueError("Cannot queue with encryption or authentication enabled.")
|
||||
queueing.init()
|
||||
app.queue_thread = threading.Thread(
|
||||
target=queue_thread, args=(path_to_local_server,))
|
||||
target=queue_thread, args=(path_to_local_server,)
|
||||
)
|
||||
app.queue_thread.start()
|
||||
if interface.save_to is not None: # Used for selenium tests
|
||||
interface.save_to["port"] = port
|
||||
|
||||
config = uvicorn.Config(app=app, port=port, host=server_name,
|
||||
log_level="warning")
|
||||
|
||||
config = uvicorn.Config(app=app, port=port, host=server_name, log_level="warning")
|
||||
server = Server(config=config)
|
||||
server.run_in_thread()
|
||||
return port, path_to_local_server, app, server
|
||||
return port, path_to_local_server, app, server
|
||||
|
||||
|
||||
def setup_tunnel(local_server_port: int, endpoint: str) -> str:
|
||||
response = url_request(
|
||||
endpoint + '/v1/tunnel-request' if endpoint is not None else GRADIO_API_SERVER)
|
||||
endpoint + "/v1/tunnel-request" if endpoint is not None else GRADIO_API_SERVER
|
||||
)
|
||||
if response and response.code == 200:
|
||||
try:
|
||||
payload = json.loads(response.read().decode("utf-8"))[0]
|
||||
@ -180,7 +185,7 @@ def url_request(url: str) -> Optional[http.client.HTTPResponse]:
|
||||
def url_ok(url: str) -> bool:
|
||||
try:
|
||||
for _ in range(5):
|
||||
time.sleep(.500)
|
||||
time.sleep(0.500)
|
||||
r = requests.head(url, timeout=3)
|
||||
if r.status_code in (200, 401, 302): # 401 or 302 if auth is set
|
||||
return True
|
||||
|
@ -5,18 +5,20 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from ffmpy import FFmpeg
|
||||
|
||||
import json
|
||||
from numbers import Number
|
||||
import numpy as np
|
||||
import operator
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from numbers import Number
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import tempfile
|
||||
from types import ModuleType
|
||||
from typing import Callable, Any, List, Optional, Tuple, Dict, TYPE_CHECKING
|
||||
import warnings
|
||||
from ffmpy import FFmpeg
|
||||
|
||||
from gradio import processing_utils
|
||||
from gradio.component import Component
|
||||
@ -38,34 +40,29 @@ class OutputComponent(Component):
|
||||
|
||||
def deserialize(self, x):
|
||||
"""
|
||||
Convert from serialized output (e.g. base64 representation) from a call() to the interface to a human-readable version of the output (path of an image, etc.)
|
||||
Convert from serialized output (e.g. base64 representation) from a call() to the interface to a human-readable version of the output (path of an image, etc.)
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
class Textbox(OutputComponent):
|
||||
'''
|
||||
"""
|
||||
Component creates a textbox to render output text or number.
|
||||
Output type: Union[str, float, int]
|
||||
Demos: hello_world, sentence_builder
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str = "auto",
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
def __init__(self, type: str = "auto", label: Optional[str] = None):
|
||||
"""
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "str" expects a string, "number" expects a float value, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
self.type = type
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
**super().get_template_context()
|
||||
}
|
||||
return {**super().get_template_context()}
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
@ -87,30 +84,32 @@ class Textbox(OutputComponent):
|
||||
elif self.type == "number":
|
||||
return y
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type +
|
||||
". Please choose from: 'str', 'number'")
|
||||
raise ValueError(
|
||||
"Unknown type: " + self.type + ". Please choose from: 'str', 'number'"
|
||||
)
|
||||
|
||||
|
||||
class Label(OutputComponent):
|
||||
'''
|
||||
"""
|
||||
Component outputs a classification label, along with confidence scores of top categories if provided. Confidence scores are represented as a dictionary mapping labels to scores between 0 and 1.
|
||||
Output type: Union[Dict[str, float], str, int, float]
|
||||
Demos: image_classifier, main_note, titanic_survival
|
||||
'''
|
||||
"""
|
||||
|
||||
CONFIDENCES_KEY = "confidences"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_top_classes: Optional[int] = None,
|
||||
type: str = "auto",
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
self,
|
||||
num_top_classes: Optional[int] = None,
|
||||
type: str = "auto",
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
num_top_classes (int): number of most confident classes to show.
|
||||
type (str): Type of value to be passed to component. "value" expects a single out label, "confidences" expects a dictionary mapping labels to confidence scores, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
self.num_top_classes = num_top_classes
|
||||
self.type = type
|
||||
super().__init__(label)
|
||||
@ -118,43 +117,50 @@ class Label(OutputComponent):
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
Parameters:
|
||||
y (Dict[str, float]): dictionary mapping label to confidence value
|
||||
y (Dict[str, float]): dictionary mapping label to confidence value
|
||||
Returns:
|
||||
(Dict[label: str, confidences: List[Dict[label: str, confidence: number]]]): Object with key 'label' representing primary label, and key 'confidences' representing a list of label-confidence pairs
|
||||
"""
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
|
||||
if self.type == "label" or (
|
||||
self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))
|
||||
):
|
||||
return {"label": str(y)}
|
||||
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
|
||||
sorted_pred = sorted(
|
||||
y.items(),
|
||||
key=operator.itemgetter(1),
|
||||
reverse=True
|
||||
)
|
||||
elif self.type == "confidences" or (
|
||||
self.type == "auto" and isinstance(y, dict)
|
||||
):
|
||||
sorted_pred = sorted(y.items(), key=operator.itemgetter(1), reverse=True)
|
||||
if self.num_top_classes is not None:
|
||||
sorted_pred = sorted_pred[:self.num_top_classes]
|
||||
sorted_pred = sorted_pred[: self.num_top_classes]
|
||||
return {
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{
|
||||
"label": pred[0],
|
||||
"confidence": pred[1]
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
{"label": pred[0], "confidence": pred[1]} for pred in sorted_pred
|
||||
],
|
||||
}
|
||||
else:
|
||||
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
|
||||
"float label, or a dictionary whose keys are labels and values are confidences.")
|
||||
raise ValueError(
|
||||
"The `Label` output interface expects one of: a string label, or an int label, a "
|
||||
"float label, or a dictionary whose keys are labels and values are confidences."
|
||||
)
|
||||
|
||||
def deserialize(self, y):
|
||||
# 5 cases: (1): {'label': 'lion'}, {'label': 'lion', 'confidences':...}, {'lion': 0.46, ...}, 'lion', '0.46'
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, int) or isinstance(y, float) or ('label' in y and not('confidences' in y.keys())))):
|
||||
if self.type == "label" or (
|
||||
self.type == "auto"
|
||||
and (
|
||||
isinstance(y, str)
|
||||
or isinstance(y, int)
|
||||
or isinstance(y, float)
|
||||
or ("label" in y and not ("confidences" in y.keys()))
|
||||
)
|
||||
):
|
||||
if isinstance(y, str) or isinstance(y, int) or isinstance(y, float):
|
||||
return y
|
||||
else:
|
||||
return y['label']
|
||||
return y["label"]
|
||||
elif self.type == "confidences" or self.type == "auto":
|
||||
if ('confidences' in y.keys()) and isinstance(y['confidences'], list):
|
||||
return {k['label']: k['confidence'] for k in y['confidences']}
|
||||
if ("confidences" in y.keys()) and isinstance(y["confidences"], list):
|
||||
return {k["label"]: k["confidence"] for k in y["confidences"]}
|
||||
else:
|
||||
return y
|
||||
raise ValueError("Unable to deserialize output: {}".format(y))
|
||||
@ -170,7 +176,12 @@ class Label(OutputComponent):
|
||||
Returns: (Union[str, Dict[str, number]]): Either a string representing the main category label, or a dictionary with category keys mapping to confidence levels.
|
||||
"""
|
||||
if "confidences" in data:
|
||||
return json.dumps({example["label"]: example["confidence"] for example in data["confidences"]})
|
||||
return json.dumps(
|
||||
{
|
||||
example["label"]: example["confidence"]
|
||||
for example in data["confidences"]
|
||||
}
|
||||
)
|
||||
else:
|
||||
return data["label"]
|
||||
|
||||
@ -183,26 +194,26 @@ class Label(OutputComponent):
|
||||
|
||||
|
||||
class Image(OutputComponent):
|
||||
'''
|
||||
Component displays an output image.
|
||||
"""
|
||||
Component displays an output image.
|
||||
Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]
|
||||
Demos: image_mod, webcam
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str = "auto",
|
||||
plot: bool = False,
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
self, type: str = "auto", plot: bool = False, label: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
|
||||
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
if plot:
|
||||
warnings.warn(
|
||||
"The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.", DeprecationWarning)
|
||||
"The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.type = "plot"
|
||||
else:
|
||||
self.type = type
|
||||
@ -210,16 +221,12 @@ class Image(OutputComponent):
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
"image": {},
|
||||
"plot": {"type": "plot"},
|
||||
"pil": {"type": "pil"}
|
||||
}
|
||||
return {"image": {}, "plot": {"type": "plot"}, "pil": {"type": "pil"}}
|
||||
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
Parameters:
|
||||
y (Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]): image in specified format
|
||||
y (Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]): image in specified format
|
||||
Returns:
|
||||
(str): base64 url data
|
||||
"""
|
||||
@ -234,7 +241,8 @@ class Image(OutputComponent):
|
||||
dtype = "plot"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type. Please choose from: 'numpy', 'pil', 'file', 'plot'.")
|
||||
"Unknown type. Please choose from: 'numpy', 'pil', 'file', 'plot'."
|
||||
)
|
||||
else:
|
||||
dtype = self.type
|
||||
if dtype in ["numpy", "pil"]:
|
||||
@ -246,8 +254,11 @@ class Image(OutputComponent):
|
||||
elif dtype == "plot":
|
||||
out_y = processing_utils.encode_plot_to_base64(y)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + dtype +
|
||||
". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
|
||||
raise ValueError(
|
||||
"Unknown type: "
|
||||
+ dtype
|
||||
+ ". Please choose from: 'numpy', 'pil', 'file', 'plot'."
|
||||
)
|
||||
return out_y
|
||||
|
||||
def deserialize(self, x):
|
||||
@ -262,83 +273,71 @@ class Image(OutputComponent):
|
||||
|
||||
|
||||
class Video(OutputComponent):
|
||||
'''
|
||||
Used for video output.
|
||||
"""
|
||||
Used for video output.
|
||||
Output type: filepath
|
||||
Demos: video_flip
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[str] = None,
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
def __init__(self, type: Optional[str] = None, label: Optional[str] = None):
|
||||
"""
|
||||
Parameters:
|
||||
type (str): Type of video format to be passed to component, such as 'avi' or 'mp4'. Use 'mp4' to ensure browser playability. If set to None, video will keep returned format.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
self.type = type
|
||||
super().__init__(label)
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
"video": {},
|
||||
"playable_video": {"type": "mp4"}
|
||||
}
|
||||
return {"video": {}, "playable_video": {"type": "mp4"}}
|
||||
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
Parameters:
|
||||
y (str): path to video
|
||||
y (str): path to video
|
||||
Returns:
|
||||
(str): base64 url data
|
||||
"""
|
||||
returned_format = y.split(".")[-1].lower()
|
||||
if self.type is not None and returned_format != self.type:
|
||||
output_file_name = y[0: y.rindex(
|
||||
".") + 1] + self.type
|
||||
ff = FFmpeg(
|
||||
inputs={y: None},
|
||||
outputs={output_file_name: None}
|
||||
)
|
||||
output_file_name = y[0 : y.rindex(".") + 1] + self.type
|
||||
ff = FFmpeg(inputs={y: None}, outputs={output_file_name: None})
|
||||
ff.run()
|
||||
y = output_file_name
|
||||
return {
|
||||
"name": os.path.basename(y),
|
||||
"data": processing_utils.encode_file_to_base64(y)
|
||||
"data": processing_utils.encode_file_to_base64(y),
|
||||
}
|
||||
|
||||
def deserialize(self, x):
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
return self.save_flagged_file(dir, label, data['data'], encryption_key)
|
||||
return self.save_flagged_file(dir, label, data["data"], encryption_key)
|
||||
|
||||
def restore_flagged(self, dir, data, encryption_key):
|
||||
return self.restore_flagged_file(dir, data, encryption_key)
|
||||
|
||||
|
||||
class KeyValues(OutputComponent):
|
||||
'''
|
||||
Component displays a table representing values for multiple fields.
|
||||
"""
|
||||
Component displays a table representing values for multiple fields.
|
||||
Output type: Union[Dict, List[Tuple[str, Union[str, int, float]]]]
|
||||
Demos: text_analysis
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
def __init__(self, label: Optional[str] = None):
|
||||
"""
|
||||
Parameters:
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
super().__init__(label)
|
||||
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
Parameters:
|
||||
y (Union[Dict, List[Tuple[str, Union[str, int, float]]]]): dictionary or tuple list representing key value pairs
|
||||
y (Union[Dict, List[Tuple[str, Union[str, int, float]]]]): dictionary or tuple list representing key value pairs
|
||||
Returns:
|
||||
(List[Tuple[str, Union[str, number]]]): list of key value pairs
|
||||
"""
|
||||
@ -347,8 +346,10 @@ class KeyValues(OutputComponent):
|
||||
elif isinstance(y, list):
|
||||
return y
|
||||
else:
|
||||
raise ValueError("The `KeyValues` output interface expects an output that is a dictionary whose keys are "
|
||||
"labels and values are corresponding values.")
|
||||
raise ValueError(
|
||||
"The `KeyValues` output interface expects an output that is a dictionary whose keys are "
|
||||
"labels and values are corresponding values."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
@ -364,24 +365,25 @@ class KeyValues(OutputComponent):
|
||||
|
||||
|
||||
class HighlightedText(OutputComponent):
|
||||
'''
|
||||
"""
|
||||
Component creates text that contains spans that are highlighted by category or numerical value.
|
||||
Output is represent as a list of Tuple pairs, where the first element represents the span of text represented by the tuple, and the second element represents the category or value of the text.
|
||||
Output type: List[Tuple[str, Union[float, str]]]
|
||||
Demos: diff_texts, text_analysis
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
color_map: Dict[str, str] = None,
|
||||
self,
|
||||
color_map: Dict[str, str] = None,
|
||||
label: Optional[str] = None,
|
||||
show_legend: bool = False):
|
||||
'''
|
||||
show_legend: bool = False,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
color_map (Dict[str, str]): Map between category and respective colors
|
||||
label (str): component name in interface.
|
||||
show_legend (bool): whether to show span categories in a separate legend or inline.
|
||||
'''
|
||||
"""
|
||||
self.color_map = color_map
|
||||
self.show_legend = show_legend
|
||||
super().__init__(label)
|
||||
@ -390,7 +392,7 @@ class HighlightedText(OutputComponent):
|
||||
return {
|
||||
"color_map": self.color_map,
|
||||
"show_legend": self.show_legend,
|
||||
**super().get_template_context()
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -402,7 +404,7 @@ class HighlightedText(OutputComponent):
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
Parameters:
|
||||
y (Union[Dict, List[Tuple[str, Union[str, int, float]]]]): dictionary or tuple list representing key value pairs
|
||||
y (Union[Dict, List[Tuple[str, Union[str, int, float]]]]): dictionary or tuple list representing key value pairs
|
||||
Returns:
|
||||
(List[Tuple[str, Union[str, number]]]): list of key value pairs
|
||||
|
||||
@ -417,28 +419,23 @@ class HighlightedText(OutputComponent):
|
||||
|
||||
|
||||
class Audio(OutputComponent):
|
||||
'''
|
||||
"""
|
||||
Creates an audio player that plays the output audio.
|
||||
Output type: Union[Tuple[int, numpy.array], str]
|
||||
Demos: generate_tone, reverse_audio
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str = "auto",
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
def __init__(self, type: str = "auto", label: Optional[str] = None):
|
||||
"""
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "numpy" returns a 2-set tuple with an integer sample_rate and the data numpy.array of shape (samples, 2), "file" returns a temporary file path to the saved wav audio file, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
self.type = type
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
**super().get_template_context()
|
||||
}
|
||||
return {**super().get_template_context()}
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
@ -457,13 +454,15 @@ class Audio(OutputComponent):
|
||||
if self.type == "numpy" or (self.type == "auto" and isinstance(y, tuple)):
|
||||
sample_rate, data = y
|
||||
file = tempfile.NamedTemporaryFile(
|
||||
prefix="sample", suffix=".wav", delete=False)
|
||||
prefix="sample", suffix=".wav", delete=False
|
||||
)
|
||||
processing_utils.audio_to_file(sample_rate, data, file.name)
|
||||
y = file.name
|
||||
return processing_utils.encode_url_or_file_to_base64(y)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type +
|
||||
". Please choose from: 'numpy', 'file'.")
|
||||
raise ValueError(
|
||||
"Unknown type: " + self.type + ". Please choose from: 'numpy', 'file'."
|
||||
)
|
||||
|
||||
def deserialize(self, x):
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
@ -476,19 +475,17 @@ class Audio(OutputComponent):
|
||||
|
||||
|
||||
class JSON(OutputComponent):
|
||||
'''
|
||||
Used for JSON output. Expects a JSON string or a Python object that is JSON serializable.
|
||||
"""
|
||||
Used for JSON output. Expects a JSON string or a Python object that is JSON serializable.
|
||||
Output type: Union[str, Any]
|
||||
Demos: zip_to_json
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
def __init__(self, label: Optional[str] = None):
|
||||
"""
|
||||
Parameters:
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
super().__init__(label)
|
||||
|
||||
def postprocess(self, y):
|
||||
@ -517,19 +514,17 @@ class JSON(OutputComponent):
|
||||
|
||||
|
||||
class HTML(OutputComponent):
|
||||
'''
|
||||
Used for HTML output. Expects an HTML valid string.
|
||||
"""
|
||||
Used for HTML output. Expects an HTML valid string.
|
||||
Output type: str
|
||||
Demos: text_analysis
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
def __init__(self, label: Optional[str] = None):
|
||||
"""
|
||||
Parameters:
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
super().__init__(label)
|
||||
|
||||
def postprocess(self, x):
|
||||
@ -549,19 +544,17 @@ class HTML(OutputComponent):
|
||||
|
||||
|
||||
class File(OutputComponent):
|
||||
'''
|
||||
Used for file output.
|
||||
"""
|
||||
Used for file output.
|
||||
Output type: Union[file-like, str]
|
||||
Demos: zip_two_files
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
def __init__(self, label: Optional[str] = None):
|
||||
"""
|
||||
Parameters:
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
super().__init__(label)
|
||||
|
||||
@classmethod
|
||||
@ -575,12 +568,12 @@ class File(OutputComponent):
|
||||
Parameters:
|
||||
y (str): file path
|
||||
Returns:
|
||||
(Dict[name: str, size: number, data: str]): JSON object with key 'name' for filename, 'data' for base64 url, and 'size' for filesize in bytes
|
||||
(Dict[name: str, size: number, data: str]): JSON object with key 'name' for filename, 'data' for base64 url, and 'size' for filesize in bytes
|
||||
"""
|
||||
return {
|
||||
"name": os.path.basename(y),
|
||||
"size": os.path.getsize(y),
|
||||
"data": processing_utils.encode_file_to_base64(y)
|
||||
"data": processing_utils.encode_file_to_base64(y),
|
||||
}
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
@ -598,22 +591,23 @@ class Dataframe(OutputComponent):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: Optional[List[str]] = None,
|
||||
max_rows: Optional[int] = 20,
|
||||
max_cols: Optional[int] = None,
|
||||
overflow_row_behaviour: str = "paginate",
|
||||
type: str = "auto",
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
self,
|
||||
headers: Optional[List[str]] = None,
|
||||
max_rows: Optional[int] = 20,
|
||||
max_cols: Optional[int] = None,
|
||||
overflow_row_behaviour: str = "paginate",
|
||||
type: str = "auto",
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
headers (List[str]): Header names to dataframe. Only applicable if type is "numpy" or "array".
|
||||
max_rows (int): Maximum number of rows to display at once. Set to None for infinite.
|
||||
max_rows (int): Maximum number of rows to display at once. Set to None for infinite.
|
||||
max_cols (int): Maximum number of columns to display at once. Set to None for infinite.
|
||||
overflow_row_behaviour (str): If set to "paginate", will create pages for overflow rows. If set to "show_ends", will show initial and final rows and truncate middle rows.
|
||||
overflow_row_behaviour (str): If set to "paginate", will create pages for overflow rows. If set to "show_ends", will show initial and final rows and truncate middle rows.
|
||||
type (str): Type of value to be passed to component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for Python array, "auto" detects return type.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
self.headers = headers
|
||||
self.max_rows = max_rows
|
||||
self.max_cols = max_cols
|
||||
@ -627,7 +621,7 @@ class Dataframe(OutputComponent):
|
||||
"max_rows": self.max_rows,
|
||||
"max_cols": self.max_cols,
|
||||
"overflow_row_behaviour": self.overflow_row_behaviour,
|
||||
**super().get_template_context()
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -664,8 +658,11 @@ class Dataframe(OutputComponent):
|
||||
y = [y]
|
||||
return {"data": y}
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type +
|
||||
". Please choose from: 'pandas', 'numpy', 'array'.")
|
||||
raise ValueError(
|
||||
"Unknown type: "
|
||||
+ self.type
|
||||
+ ". Please choose from: 'pandas', 'numpy', 'array'."
|
||||
)
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
return json.dumps(data["data"])
|
||||
@ -682,24 +679,26 @@ class Carousel(OutputComponent):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
components: OutputComponent | List[OutputComponent],
|
||||
label: Optional[str] = None):
|
||||
'''
|
||||
self,
|
||||
components: OutputComponent | List[OutputComponent],
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
components (Union[List[OutputComponent], OutputComponent]): Classes of component(s) that will be scrolled through.
|
||||
label (str): component name in interface.
|
||||
'''
|
||||
"""
|
||||
if not isinstance(components, list):
|
||||
components = [components]
|
||||
self.components = [get_output_instance(
|
||||
component) for component in components]
|
||||
self.components = [get_output_instance(component) for component in components]
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
"components": [component.get_template_context() for component in self.components],
|
||||
**super().get_template_context()
|
||||
"components": [
|
||||
component.get_template_context() for component in self.components
|
||||
],
|
||||
**super().get_template_context(),
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
@ -720,23 +719,29 @@ class Carousel(OutputComponent):
|
||||
output.append(output_row)
|
||||
return output
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type. Please provide a list for the Carousel.")
|
||||
raise ValueError("Unknown type. Please provide a list for the Carousel.")
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
return json.dumps([
|
||||
return json.dumps(
|
||||
[
|
||||
component.save_flagged(
|
||||
dir, f"{label}_{j}", data[i][j], encryption_key)
|
||||
for j, component in enumerate(self.components)
|
||||
] for i, _ in enumerate(data)])
|
||||
[
|
||||
component.save_flagged(
|
||||
dir, f"{label}_{j}", data[i][j], encryption_key
|
||||
)
|
||||
for j, component in enumerate(self.components)
|
||||
]
|
||||
for i, _ in enumerate(data)
|
||||
]
|
||||
)
|
||||
|
||||
def restore_flagged(self, dir, data, encryption_key):
|
||||
return [
|
||||
[
|
||||
component.restore_flagged(dir, sample, encryption_key)
|
||||
for component, sample in zip(self.components, sample_set)
|
||||
] for sample_set in json.loads(data)]
|
||||
]
|
||||
for sample_set in json.loads(data)
|
||||
]
|
||||
|
||||
|
||||
class Timeseries(OutputComponent):
|
||||
@ -747,10 +752,8 @@ class Timeseries(OutputComponent):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x: str = None,
|
||||
y: str | List[str] = None,
|
||||
label: Optional[str] = None):
|
||||
self, x: str = None, y: str | List[str] = None, label: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
x (str): Column name of x (time) series. None if csv has no headers, in which case first column is x series.
|
||||
@ -764,11 +767,7 @@ class Timeseries(OutputComponent):
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
"x": self.x,
|
||||
"y": self.y,
|
||||
**super().get_template_context()
|
||||
}
|
||||
return {"x": self.x, "y": self.y, **super().get_template_context()}
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
@ -783,11 +782,7 @@ class Timeseries(OutputComponent):
|
||||
Returns:
|
||||
(Dict[headers: List[str], data: List[List[Union[str, number]]]]): JSON object with key 'headers' for list of header names, 'data' for 2D array of string or numeric data
|
||||
"""
|
||||
return {
|
||||
"headers": y.columns.values.tolist(),
|
||||
"data": y.values.tolist()
|
||||
|
||||
}
|
||||
return {"headers": y.columns.values.tolist(), "data": y.values.tolist()}
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
@ -803,13 +798,12 @@ class State(OutputComponent):
|
||||
"""
|
||||
Special hidden component that stores state across runs of the interface.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
label: Optional[str] = None):
|
||||
|
||||
def __init__(self, label: Optional[str] = None):
|
||||
"""
|
||||
Parameters:
|
||||
label (str): component name in interface (not used).
|
||||
"""
|
||||
"""
|
||||
super().__init__(label)
|
||||
|
||||
@classmethod
|
||||
@ -825,7 +819,7 @@ def get_output_instance(iface: Interface):
|
||||
return shortcut[0](**shortcut[1])
|
||||
# a dict with `name` as the output component type and other keys as parameters
|
||||
elif isinstance(iface, dict):
|
||||
name = iface.pop('name')
|
||||
name = iface.pop("name")
|
||||
for component in OutputComponent.__subclasses__():
|
||||
if component.__name__.lower() == name:
|
||||
break
|
||||
|
@ -1,22 +1,33 @@
|
||||
import os, shutil
|
||||
from gradio.flagging import CSVLogger
|
||||
from typing import Any, List
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, List
|
||||
|
||||
from gradio.flagging import CSVLogger
|
||||
|
||||
CACHED_FOLDER = "gradio_cached_examples"
|
||||
CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv")
|
||||
|
||||
|
||||
def process_example(interface, example_id: int):
|
||||
example_set = interface.examples[example_id]
|
||||
raw_input = [interface.input_components[i].preprocess_example(example) for i, example in enumerate(example_set)]
|
||||
raw_input = [
|
||||
interface.input_components[i].preprocess_example(example)
|
||||
for i, example in enumerate(example_set)
|
||||
]
|
||||
prediction, durations = interface.process(raw_input)
|
||||
return prediction, durations
|
||||
|
||||
|
||||
def cache_interface_examples(interface) -> None:
|
||||
if os.path.exists(CACHE_FILE):
|
||||
print(f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache.")
|
||||
print(
|
||||
f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache."
|
||||
)
|
||||
else:
|
||||
print(f"Cache at {os.path.abspath(CACHE_FILE)} not found. Caching now in '{CACHED_FOLDER}/' directory.")
|
||||
print(
|
||||
f"Cache at {os.path.abspath(CACHE_FILE)} not found. Caching now in '{CACHED_FOLDER}/' directory."
|
||||
)
|
||||
cache_logger = CSVLogger()
|
||||
cache_logger.setup(CACHED_FOLDER)
|
||||
for example_id, _ in enumerate(interface.examples):
|
||||
@ -27,12 +38,18 @@ def cache_interface_examples(interface) -> None:
|
||||
shutil.rmtree(CACHED_FOLDER)
|
||||
raise e
|
||||
|
||||
|
||||
def load_from_cache(interface, example_id: int) -> List[Any]:
|
||||
with open(CACHE_FILE) as cache:
|
||||
examples = list(csv.reader(cache))
|
||||
example = examples[example_id + 1] # +1 to adjust for header
|
||||
example = examples[example_id + 1] # +1 to adjust for header
|
||||
output = []
|
||||
for component, cell in zip(interface.output_components, example):
|
||||
output.append(component.restore_flagged(
|
||||
CACHED_FOLDER, cell, interface.encryption_key if interface.encrypt else None))
|
||||
output.append(
|
||||
component.restore_flagged(
|
||||
CACHED_FOLDER,
|
||||
cell,
|
||||
interface.encryption_key if interface.encrypt else None,
|
||||
)
|
||||
)
|
||||
return output
|
||||
|
@ -1,24 +1,27 @@
|
||||
from PIL import Image, ImageOps
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import requests
|
||||
import tempfile
|
||||
import shutil
|
||||
import os
|
||||
import numpy as np
|
||||
from gradio import encryptor
|
||||
import warnings
|
||||
import mimetypes
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from gradio import encryptor
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
|
||||
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
|
||||
from pydub import AudioSegment
|
||||
|
||||
#########################
|
||||
# IMAGE PRE-PROCESSING
|
||||
#########################
|
||||
def decode_base64_to_image(encoding):
|
||||
content = encoding.split(';')[1]
|
||||
image_encoded = content.split(',')[1]
|
||||
content = encoding.split(";")[1]
|
||||
image_encoded = content.split(",")[1]
|
||||
return Image.open(BytesIO(base64.b64decode(image_encoded)))
|
||||
|
||||
|
||||
@ -26,7 +29,7 @@ def get_url_or_file_as_bytes(path):
|
||||
try:
|
||||
return requests.get(path).content
|
||||
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
|
||||
with open(path, "rb") as f:
|
||||
with open(path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
@ -37,14 +40,14 @@ def encode_url_or_file_to_base64(path):
|
||||
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
|
||||
return encode_file_to_base64(path)
|
||||
|
||||
|
||||
def get_mimetype(filename):
|
||||
mimetype = mimetypes.guess_type(filename)[0]
|
||||
if mimetype is not None:
|
||||
mimetype = mimetype.replace(
|
||||
"x-wav", "wav").replace(
|
||||
"x-flac", "flac")
|
||||
mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac")
|
||||
return mimetype
|
||||
|
||||
|
||||
def get_extension(encoding):
|
||||
encoding = encoding.replace("audio/wav", "audio/x-wav")
|
||||
type = mimetypes.guess_type(encoding)[0]
|
||||
@ -55,40 +58,49 @@ def get_extension(encoding):
|
||||
extension = extension[1:]
|
||||
return extension
|
||||
|
||||
|
||||
def encode_file_to_base64(f, encryption_key=None):
|
||||
with open(f, "rb") as file:
|
||||
encoded_string = base64.b64encode(file.read())
|
||||
if encryption_key:
|
||||
encoded_string = encryptor.decrypt(encryption_key, encoded_string)
|
||||
base64_str = str(encoded_string, 'utf-8')
|
||||
base64_str = str(encoded_string, "utf-8")
|
||||
mimetype = get_mimetype(f)
|
||||
return "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
|
||||
return (
|
||||
"data:"
|
||||
+ (mimetype if mimetype is not None else "")
|
||||
+ ";base64,"
|
||||
+ base64_str
|
||||
)
|
||||
|
||||
|
||||
def encode_url_to_base64(url):
|
||||
encoded_string = base64.b64encode(requests.get(url).content)
|
||||
base64_str = str(encoded_string, 'utf-8')
|
||||
base64_str = str(encoded_string, "utf-8")
|
||||
mimetype = get_mimetype(url)
|
||||
return "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
|
||||
return (
|
||||
"data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
|
||||
)
|
||||
|
||||
|
||||
def encode_plot_to_base64(plt):
|
||||
with BytesIO() as output_bytes:
|
||||
plt.savefig(output_bytes, format="png")
|
||||
bytes_data = output_bytes.getvalue()
|
||||
base64_str = str(base64.b64encode(bytes_data), 'utf-8')
|
||||
base64_str = str(base64.b64encode(bytes_data), "utf-8")
|
||||
return "data:image/png;base64," + base64_str
|
||||
|
||||
|
||||
def encode_array_to_base64(image_array):
|
||||
with BytesIO() as output_bytes:
|
||||
PIL_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
|
||||
PIL_image.save(output_bytes, 'PNG')
|
||||
PIL_image.save(output_bytes, "PNG")
|
||||
bytes_data = output_bytes.getvalue()
|
||||
base64_str = str(base64.b64encode(bytes_data), 'utf-8')
|
||||
base64_str = str(base64.b64encode(bytes_data), "utf-8")
|
||||
return "data:image/png;base64," + base64_str
|
||||
|
||||
|
||||
def resize_and_crop(img, size, crop_type='center'):
|
||||
def resize_and_crop(img, size, crop_type="center"):
|
||||
"""
|
||||
Resize and crop an image to fit the specified size.
|
||||
args:
|
||||
@ -105,32 +117,36 @@ def resize_and_crop(img, size, crop_type='center'):
|
||||
center = (0.5, 0.5)
|
||||
else:
|
||||
raise ValueError
|
||||
return ImageOps.fit(img, size, centering=center)
|
||||
return ImageOps.fit(img, size, centering=center)
|
||||
|
||||
|
||||
##################
|
||||
# Audio
|
||||
##################
|
||||
|
||||
|
||||
def audio_from_file(filename, crop_min=0, crop_max=100):
|
||||
audio = AudioSegment.from_file(filename)
|
||||
if crop_min != 0 or crop_max != 100:
|
||||
audio_start = len(audio) * crop_min / 100
|
||||
audio_end = len(audio) * crop_max / 100
|
||||
audio = audio[audio_start : audio_end]
|
||||
audio = audio[audio_start:audio_end]
|
||||
data = np.array(audio.get_array_of_samples())
|
||||
if (audio.channels > 1):
|
||||
if audio.channels > 1:
|
||||
data = data.reshape(-1, audio.channels)
|
||||
return audio.frame_rate, data
|
||||
|
||||
|
||||
def audio_to_file(sample_rate, data, filename):
|
||||
audio = AudioSegment(
|
||||
data.tobytes(),
|
||||
data.tobytes(),
|
||||
frame_rate=sample_rate,
|
||||
sample_width=data.dtype.itemsize,
|
||||
channels=(1 if len(data.shape) == 1 else data.shape[1])
|
||||
sample_width=data.dtype.itemsize,
|
||||
channels=(1 if len(data.shape) == 1 else data.shape[1]),
|
||||
)
|
||||
audio.export(filename, format="wav")
|
||||
|
||||
|
||||
##################
|
||||
# OUTPUT
|
||||
##################
|
||||
@ -139,7 +155,8 @@ def audio_to_file(sample_rate, data, filename):
|
||||
def decode_base64_to_binary(encoding):
|
||||
extension = get_extension(encoding)
|
||||
data = encoding.split(",")[1]
|
||||
return base64.b64decode(data), extension
|
||||
return base64.b64decode(data), extension
|
||||
|
||||
|
||||
def decode_base64_to_file(encoding, encryption_key=None, file_path=None):
|
||||
data, extension = decode_base64_to_binary(encoding)
|
||||
@ -148,35 +165,41 @@ def decode_base64_to_file(encoding, encryption_key=None, file_path=None):
|
||||
filename = os.path.basename(file_path)
|
||||
prefix = filename
|
||||
if "." in filename:
|
||||
prefix = filename[0: filename.index(".")]
|
||||
extension = filename[filename.index(".") + 1:]
|
||||
prefix = filename[0 : filename.index(".")]
|
||||
extension = filename[filename.index(".") + 1 :]
|
||||
if extension is None:
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix)
|
||||
else:
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, suffix="."+extension)
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False, prefix=prefix, suffix="." + extension
|
||||
)
|
||||
if encryption_key is not None:
|
||||
data = encryptor.encrypt(encryption_key, data)
|
||||
file_obj.write(data)
|
||||
file_obj.flush()
|
||||
return file_obj
|
||||
|
||||
|
||||
def create_tmp_copy_of_file(file_path):
|
||||
file_name = os.path.basename(file_path)
|
||||
prefix, extension = file_name, None
|
||||
if "." in file_name:
|
||||
prefix = file_name[0: file_name.index(".")]
|
||||
extension = file_name[file_name.index(".") + 1:]
|
||||
prefix = file_name[0 : file_name.index(".")]
|
||||
extension = file_name[file_name.index(".") + 1 :]
|
||||
if extension is None:
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix)
|
||||
else:
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, suffix="."+extension)
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False, prefix=prefix, suffix="." + extension
|
||||
)
|
||||
shutil.copy2(file_path, file_obj.name)
|
||||
return file_obj
|
||||
|
||||
|
||||
def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
"""
|
||||
Adapted from: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/dtype.py#L510-L531
|
||||
|
||||
|
||||
Convert an image to the requested data-type.
|
||||
Warnings are issued in case of precision loss, or when negative values
|
||||
are clipped during conversion to unsigned integer types (sign loss).
|
||||
@ -216,14 +239,16 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
.. [4] Dirty Pixels. J. Blinn. In "Jim Blinn's corner: Dirty Pixels",
|
||||
pp 47-57. Morgan Kaufmann, 1998.
|
||||
"""
|
||||
dtype_range = {bool: (False, True),
|
||||
np.bool_: (False, True),
|
||||
np.bool8: (False, True),
|
||||
float: (-1, 1),
|
||||
np.float_: (-1, 1),
|
||||
np.float16: (-1, 1),
|
||||
np.float32: (-1, 1),
|
||||
np.float64: (-1, 1)}
|
||||
dtype_range = {
|
||||
bool: (False, True),
|
||||
np.bool_: (False, True),
|
||||
np.bool8: (False, True),
|
||||
float: (-1, 1),
|
||||
np.float_: (-1, 1),
|
||||
np.float16: (-1, 1),
|
||||
np.float32: (-1, 1),
|
||||
np.float64: (-1, 1),
|
||||
}
|
||||
|
||||
def _dtype_itemsize(itemsize, *dtypes):
|
||||
"""Return first of `dtypes` with itemsize greater than `itemsize`
|
||||
@ -258,12 +283,14 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
Data type of `kind` that can store a `bits` wide unsigned int
|
||||
"""
|
||||
|
||||
s = next(i for i in (itemsize, ) + (2, 4, 8) if
|
||||
bits < (i * 8) or (bits == (i * 8) and kind == 'u'))
|
||||
s = next(
|
||||
i
|
||||
for i in (itemsize,) + (2, 4, 8)
|
||||
if bits < (i * 8) or (bits == (i * 8) and kind == "u")
|
||||
)
|
||||
|
||||
return np.dtype(kind + str(s))
|
||||
|
||||
|
||||
def _scale(a, n, m, copy=True):
|
||||
"""Scale an array of unsigned/positive integers from `n` to `m` bits.
|
||||
Numbers can be represented exactly only if `m` is a multiple of `n`.
|
||||
@ -298,21 +325,20 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
# downscale with precision loss
|
||||
if copy:
|
||||
b = np.empty(a.shape, _dtype_bits(kind, m))
|
||||
np.floor_divide(a, 2**(n - m), out=b, dtype=a.dtype,
|
||||
casting='unsafe')
|
||||
np.floor_divide(a, 2 ** (n - m), out=b, dtype=a.dtype, casting="unsafe")
|
||||
return b
|
||||
else:
|
||||
a //= 2**(n - m)
|
||||
a //= 2 ** (n - m)
|
||||
return a
|
||||
elif m % n == 0:
|
||||
# exact upscale to a multiple of `n` bits
|
||||
if copy:
|
||||
b = np.empty(a.shape, _dtype_bits(kind, m))
|
||||
np.multiply(a, (2**m - 1) // (2**n - 1), out=b, dtype=b.dtype)
|
||||
np.multiply(a, (2 ** m - 1) // (2 ** n - 1), out=b, dtype=b.dtype)
|
||||
return b
|
||||
else:
|
||||
a = a.astype(_dtype_bits(kind, m, a.dtype.itemsize), copy=False)
|
||||
a *= (2**m - 1) // (2**n - 1)
|
||||
a *= (2 ** m - 1) // (2 ** n - 1)
|
||||
return a
|
||||
else:
|
||||
# upscale to a multiple of `n` bits,
|
||||
@ -320,19 +346,19 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
o = (m // n + 1) * n
|
||||
if copy:
|
||||
b = np.empty(a.shape, _dtype_bits(kind, o))
|
||||
np.multiply(a, (2**o - 1) // (2**n - 1), out=b, dtype=b.dtype)
|
||||
b //= 2**(o - m)
|
||||
np.multiply(a, (2 ** o - 1) // (2 ** n - 1), out=b, dtype=b.dtype)
|
||||
b //= 2 ** (o - m)
|
||||
return b
|
||||
else:
|
||||
a = a.astype(_dtype_bits(kind, o, a.dtype.itemsize), copy=False)
|
||||
a *= (2**o - 1) // (2**n - 1)
|
||||
a //= 2**(o - m)
|
||||
a *= (2 ** o - 1) // (2 ** n - 1)
|
||||
a //= 2 ** (o - m)
|
||||
return a
|
||||
|
||||
image = np.asarray(image)
|
||||
dtypeobj_in = image.dtype
|
||||
if dtype is np.floating:
|
||||
dtypeobj_out = np.dtype('float64')
|
||||
dtypeobj_out = np.dtype("float64")
|
||||
else:
|
||||
dtypeobj_out = np.dtype(dtype)
|
||||
dtype_in = dtypeobj_in.type
|
||||
@ -356,28 +382,27 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
image = image.copy()
|
||||
return image
|
||||
|
||||
if kind_in in 'ui':
|
||||
if kind_in in "ui":
|
||||
imin_in = np.iinfo(dtype_in).min
|
||||
imax_in = np.iinfo(dtype_in).max
|
||||
if kind_out in 'ui':
|
||||
if kind_out in "ui":
|
||||
imin_out = np.iinfo(dtype_out).min
|
||||
imax_out = np.iinfo(dtype_out).max
|
||||
|
||||
# any -> binary
|
||||
if kind_out == 'b':
|
||||
if kind_out == "b":
|
||||
return image > dtype_in(dtype_range[dtype_in][1] / 2)
|
||||
|
||||
# binary -> any
|
||||
if kind_in == 'b':
|
||||
if kind_in == "b":
|
||||
result = image.astype(dtype_out)
|
||||
if kind_out != 'f':
|
||||
if kind_out != "f":
|
||||
result *= dtype_out(dtype_range[dtype_out][1])
|
||||
return result
|
||||
|
||||
|
||||
# float -> any
|
||||
if kind_in == 'f':
|
||||
if kind_out == 'f':
|
||||
if kind_in == "f":
|
||||
if kind_out == "f":
|
||||
# float -> float
|
||||
return image.astype(dtype_out)
|
||||
|
||||
@ -385,41 +410,42 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
raise ValueError("Images of type float must be between -1 and 1.")
|
||||
# floating point -> integer
|
||||
# use float type that can represent output integer type
|
||||
computation_type = _dtype_itemsize(itemsize_out, dtype_in,
|
||||
np.float32, np.float64)
|
||||
computation_type = _dtype_itemsize(
|
||||
itemsize_out, dtype_in, np.float32, np.float64
|
||||
)
|
||||
|
||||
if not uniform:
|
||||
if kind_out == 'u':
|
||||
image_out = np.multiply(image, imax_out,
|
||||
dtype=computation_type)
|
||||
if kind_out == "u":
|
||||
image_out = np.multiply(image, imax_out, dtype=computation_type)
|
||||
else:
|
||||
image_out = np.multiply(image, (imax_out - imin_out) / 2,
|
||||
dtype=computation_type)
|
||||
image_out -= 1.0 / 2.
|
||||
image_out = np.multiply(
|
||||
image, (imax_out - imin_out) / 2, dtype=computation_type
|
||||
)
|
||||
image_out -= 1.0 / 2.0
|
||||
np.rint(image_out, out=image_out)
|
||||
np.clip(image_out, imin_out, imax_out, out=image_out)
|
||||
elif kind_out == 'u':
|
||||
image_out = np.multiply(image, imax_out + 1,
|
||||
dtype=computation_type)
|
||||
elif kind_out == "u":
|
||||
image_out = np.multiply(image, imax_out + 1, dtype=computation_type)
|
||||
np.clip(image_out, 0, imax_out, out=image_out)
|
||||
else:
|
||||
image_out = np.multiply(image, (imax_out - imin_out + 1.0) / 2.0,
|
||||
dtype=computation_type)
|
||||
image_out = np.multiply(
|
||||
image, (imax_out - imin_out + 1.0) / 2.0, dtype=computation_type
|
||||
)
|
||||
np.floor(image_out, out=image_out)
|
||||
np.clip(image_out, imin_out, imax_out, out=image_out)
|
||||
return image_out.astype(dtype_out)
|
||||
|
||||
# signed/unsigned int -> float
|
||||
if kind_out == 'f':
|
||||
if kind_out == "f":
|
||||
# use float type that can exactly represent input integers
|
||||
computation_type = _dtype_itemsize(itemsize_in, dtype_out,
|
||||
np.float32, np.float64)
|
||||
computation_type = _dtype_itemsize(
|
||||
itemsize_in, dtype_out, np.float32, np.float64
|
||||
)
|
||||
|
||||
if kind_in == 'u':
|
||||
if kind_in == "u":
|
||||
# using np.divide or np.multiply doesn't copy the data
|
||||
# until the computation time
|
||||
image = np.multiply(image, 1. / imax_in,
|
||||
dtype=computation_type)
|
||||
image = np.multiply(image, 1.0 / imax_in, dtype=computation_type)
|
||||
# DirectX uses this conversion also for signed ints
|
||||
# if imin_in:
|
||||
# np.maximum(image, -1.0, out=image)
|
||||
@ -430,8 +456,8 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
return np.asarray(image, dtype_out)
|
||||
|
||||
# unsigned int -> signed/unsigned int
|
||||
if kind_in == 'u':
|
||||
if kind_out == 'i':
|
||||
if kind_in == "u":
|
||||
if kind_out == "i":
|
||||
# unsigned int -> signed int
|
||||
image = _scale(image, 8 * itemsize_in, 8 * itemsize_out - 1)
|
||||
return image.view(dtype_out)
|
||||
@ -440,17 +466,17 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
return _scale(image, 8 * itemsize_in, 8 * itemsize_out)
|
||||
|
||||
# signed int -> unsigned int
|
||||
if kind_out == 'u':
|
||||
if kind_out == "u":
|
||||
image = _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out)
|
||||
result = np.empty(image.shape, dtype_out)
|
||||
np.maximum(image, 0, out=result, dtype=image.dtype, casting='unsafe')
|
||||
np.maximum(image, 0, out=result, dtype=image.dtype, casting="unsafe")
|
||||
return result
|
||||
|
||||
# signed int -> signed int
|
||||
if itemsize_in > itemsize_out:
|
||||
return _scale(image, 8 * itemsize_in - 1, 8 * itemsize_out - 1)
|
||||
|
||||
image = image.astype(_dtype_bits('i', itemsize_out * 8))
|
||||
image = image.astype(_dtype_bits("i", itemsize_out * 8))
|
||||
image -= imin_in
|
||||
image = _scale(image, 8 * itemsize_in, 8 * itemsize_out, copy=False)
|
||||
image += imin_out
|
||||
|
@ -1,66 +1,82 @@
|
||||
import sqlite3
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import uuid
|
||||
|
||||
DB_FILE = "gradio_queue.db"
|
||||
|
||||
|
||||
def generate_hash():
|
||||
generate = True
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
while generate:
|
||||
hash = uuid.uuid4().hex
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
SELECT hash FROM queue
|
||||
WHERE hash = ?;
|
||||
""", (hash,))
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
generate = c.fetchone() is not None
|
||||
conn.commit()
|
||||
return hash
|
||||
|
||||
|
||||
def init():
|
||||
if os.path.exists(DB_FILE):
|
||||
os.remove(DB_FILE)
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("""CREATE TABLE queue (
|
||||
c.execute(
|
||||
"""CREATE TABLE queue (
|
||||
queue_index integer PRIMARY KEY,
|
||||
hash text,
|
||||
input_data text,
|
||||
action text,
|
||||
popped integer DEFAULT 0
|
||||
);""")
|
||||
c.execute("""
|
||||
);"""
|
||||
)
|
||||
c.execute(
|
||||
"""
|
||||
CREATE TABLE jobs (
|
||||
hash text PRIMARY KEY,
|
||||
status text,
|
||||
output_data text,
|
||||
error_message text
|
||||
);
|
||||
""")
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def close():
|
||||
if os.path.exists(DB_FILE):
|
||||
os.remove(DB_FILE)
|
||||
|
||||
|
||||
def pop():
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
SELECT queue_index, hash, input_data, action FROM queue
|
||||
WHERE popped = 0 ORDER BY queue_index ASC LIMIT 1;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
result = c.fetchone()
|
||||
if result is None:
|
||||
conn.commit()
|
||||
return None
|
||||
queue_index = result[0]
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
UPDATE queue SET popped = 1, input_data = '' WHERE queue_index = ?;
|
||||
""", (queue_index,))
|
||||
""",
|
||||
(queue_index,),
|
||||
)
|
||||
conn.commit()
|
||||
return result[0], result[1], json.loads(result[2]), result[3]
|
||||
|
||||
@ -70,42 +86,57 @@ def push(input_data, action):
|
||||
hash = generate_hash()
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
INSERT INTO queue (hash, input_data, action)
|
||||
VALUES (?, ?, ?);
|
||||
""", (hash, input_data, action))
|
||||
""",
|
||||
(hash, input_data, action),
|
||||
)
|
||||
queue_index = c.lastrowid
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM queue WHERE queue_index < ? and popped = 0;
|
||||
""", (queue_index,))
|
||||
""",
|
||||
(queue_index,),
|
||||
)
|
||||
queue_position = c.fetchone()[0]
|
||||
if queue_position is None:
|
||||
conn.commit()
|
||||
raise ValueError("Hash not found.")
|
||||
elif queue_position == 0:
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM jobs WHERE status = "PENDING";
|
||||
""")
|
||||
"""
|
||||
)
|
||||
result = c.fetchone()
|
||||
if result[0] == 0:
|
||||
queue_position -= 1
|
||||
conn.commit()
|
||||
return hash, queue_position
|
||||
|
||||
|
||||
def get_status(hash):
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
SELECT queue_index, popped FROM queue WHERE hash = ?;
|
||||
""", (hash,))
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
result = c.fetchone()
|
||||
if result is None:
|
||||
conn.commit()
|
||||
raise ValueError("Hash not found.")
|
||||
if result[1] == 1: # in jobs
|
||||
c.execute("""
|
||||
if result[1] == 1: # in jobs
|
||||
c.execute(
|
||||
"""
|
||||
SELECT status, output_data, error_message FROM jobs WHERE hash = ?;
|
||||
""", (hash,))
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
result = c.fetchone()
|
||||
if result is None:
|
||||
conn.commit()
|
||||
@ -119,55 +150,83 @@ def get_status(hash):
|
||||
conn.commit()
|
||||
return "FAILED", error_message
|
||||
elif status == "COMPLETE":
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
UPDATE jobs SET output_data = '' WHERE hash = ?;
|
||||
""", (hash,))
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
conn.commit()
|
||||
output_data = json.loads(output_data)
|
||||
return "COMPLETE", output_data
|
||||
else: # in queue
|
||||
else: # in queue
|
||||
queue_index = result[0]
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM queue WHERE queue_index < ? and popped = 0;
|
||||
""", (queue_index,))
|
||||
""",
|
||||
(queue_index,),
|
||||
)
|
||||
result = c.fetchone()
|
||||
queue_position = result[0]
|
||||
if queue_position == 0:
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM jobs WHERE status = "PENDING";
|
||||
""")
|
||||
"""
|
||||
)
|
||||
result = c.fetchone()
|
||||
if result[0] == 0:
|
||||
queue_position -= 1
|
||||
conn.commit()
|
||||
return "QUEUED", queue_position
|
||||
|
||||
|
||||
def start_job(hash):
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
UPDATE queue SET popped = 1 WHERE hash = ?;
|
||||
""", (hash,))
|
||||
c.execute("""
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
c.execute(
|
||||
"""
|
||||
INSERT INTO jobs (hash, status) VALUES (?, 'PENDING');
|
||||
""", (hash,))
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def fail_job(hash, error_message):
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
UPDATE jobs SET status = 'FAILED', error_message = ? WHERE hash = ?;
|
||||
""", (error_message, hash,))
|
||||
""",
|
||||
(
|
||||
error_message,
|
||||
hash,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def pass_job(hash, output_data):
|
||||
output_data = json.dumps(output_data)
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("""
|
||||
c.execute(
|
||||
"""
|
||||
UPDATE jobs SET status = 'COMPLETE', output_data = ? WHERE hash = ?;
|
||||
""", (output_data, hash,))
|
||||
""",
|
||||
(
|
||||
output_data,
|
||||
hash,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import requests
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
MESSAGING_API_ENDPOINT = "https://api.gradio.app/gradio-messaging/en"
|
||||
|
||||
en = {
|
||||
en = {
|
||||
"RUNNING_LOCALLY": "Running on local URL: {}",
|
||||
"SHARE_LINK_DISPLAY": "Running on public URL: {}",
|
||||
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
|
||||
@ -12,11 +12,11 @@ en = {
|
||||
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} (may take up to a minute for link to be usable)",
|
||||
"GENERATING_PUBLIC_LINK": "Generating public link (may take a few seconds...):",
|
||||
"TF1_ERROR": "It looks like you might be using tensorflow < 2.0. Please pass capture_session=True in Interface() to"
|
||||
" avoid the 'Tensor is not an element of this graph.' error.",
|
||||
" avoid the 'Tensor is not an element of this graph.' error.",
|
||||
"BETA_INVITE": "\nWe want to invite you to become a beta user.\nYou'll get early access to new and premium "
|
||||
"features (persistent links, hosting, and more).\nIf you're interested please email: beta@gradio.app\n",
|
||||
"features (persistent links, hosting, and more).\nIf you're interested please email: beta@gradio.app\n",
|
||||
"COLAB_DEBUG_TRUE": "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. "
|
||||
"To turn off, set debug=False in launch().",
|
||||
"To turn off, set debug=False in launch().",
|
||||
"COLAB_DEBUG_FALSE": "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()",
|
||||
"SHARE_LINK_MESSAGE": "\nThis share link expires in 72 hours. For free permanent hosting, check out Spaces (https://huggingface.co/spaces)",
|
||||
"PRIVATE_LINK_MESSAGE": "Since this is a private endpoint, this share link will never expire.",
|
||||
@ -27,12 +27,16 @@ en = {
|
||||
"Let users specify why they flagged input with the flagging_options= kwarg; for example: gr.Interface(..., flagging_options=['too slow', 'incorrect output', 'other'])",
|
||||
"You can show or hide the buttons for flagging, screenshots, and interpretation with the allow_*= kwargs; for example: gr.Interface(..., allow_screenshot=True, allow_flagging=False)",
|
||||
"The inputs and outputs flagged by the users are stored in the flagging directory, specified by the flagging_dir= kwarg. You can view this data through the interface by setting the examples= kwarg to the flagging directory; for example gr.Interface(..., examples='flagged')",
|
||||
"You can add a title and description to your interface using the title= and description= kwargs. The article= kwarg can be used to add markdown or HTML under the interface; for example gr.Interface(..., title='My app', description='Lorem ipsum')"
|
||||
]
|
||||
"You can add a title and description to your interface using the title= and description= kwargs. The article= kwarg can be used to add markdown or HTML under the interface; for example gr.Interface(..., title='My app', description='Lorem ipsum')",
|
||||
],
|
||||
}
|
||||
|
||||
try:
|
||||
updated_messaging = requests.get(MESSAGING_API_ENDPOINT, timeout=3).json()
|
||||
en.update(updated_messaging)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout, json.decoder.JSONDecodeError): # Use default messaging
|
||||
except (
|
||||
requests.ConnectionError,
|
||||
requests.exceptions.ReadTimeout,
|
||||
json.decoder.JSONDecodeError,
|
||||
): # Use default messaging
|
||||
pass
|
||||
|
8637
gradio/test_data.py
8637
gradio/test_data.py
File diff suppressed because one or more lines are too long
@ -7,8 +7,9 @@ import select
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
from io import StringIO
|
||||
import warnings
|
||||
from io import StringIO
|
||||
|
||||
import paramiko
|
||||
|
||||
|
||||
@ -21,9 +22,9 @@ def handler(chan, host, port):
|
||||
return
|
||||
|
||||
verbose(
|
||||
"Connected! Tunnel open {} -> {} -> {}".format(chan.origin_addr,
|
||||
chan.getpeername(),
|
||||
(host, port))
|
||||
"Connected! Tunnel open {} -> {} -> {}".format(
|
||||
chan.origin_addr, chan.getpeername(), (host, port)
|
||||
)
|
||||
)
|
||||
while True:
|
||||
r, w, x = select.select([sock, chan], [], [])
|
||||
@ -39,7 +40,11 @@ def handler(chan, host, port):
|
||||
sock.send(data)
|
||||
chan.close()
|
||||
sock.close()
|
||||
verbose("Tunnel closed from {}".format(chan.origin_addr,))
|
||||
verbose(
|
||||
"Tunnel closed from {}".format(
|
||||
chan.origin_addr,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
|
||||
@ -64,8 +69,7 @@ def create_tunnel(payload, local_server, local_server_port):
|
||||
client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||
|
||||
verbose(
|
||||
"Connecting to ssh host {}:{} ...".format(payload["host"], int(payload[
|
||||
"port"]))
|
||||
"Connecting to ssh host {}:{} ...".format(payload["host"], int(payload["port"]))
|
||||
)
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
@ -78,16 +82,16 @@ def create_tunnel(payload, local_server, local_server_port):
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"*** Failed to connect to {}:{}: {}}".format(payload["host"],
|
||||
int(payload["port"]), e)
|
||||
"*** Failed to connect to {}:{}: {}}".format(
|
||||
payload["host"], int(payload["port"]), e
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
verbose(
|
||||
"Now forwarding remote port {} to {}:{} ...".format(int(payload[
|
||||
"remote_port"]),
|
||||
local_server,
|
||||
local_server_port)
|
||||
"Now forwarding remote port {} to {}:{} ...".format(
|
||||
int(payload["remote_port"]), local_server, local_server_port
|
||||
)
|
||||
)
|
||||
|
||||
thread = threading.Thread(
|
||||
|
143
gradio/utils.py
143
gradio/utils.py
@ -1,19 +1,21 @@
|
||||
""" Handy utility functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import analytics
|
||||
|
||||
import csv
|
||||
from distutils.version import StrictVersion
|
||||
import inspect
|
||||
import json
|
||||
import json.decoder
|
||||
import os
|
||||
import pkg_resources
|
||||
import random
|
||||
import requests
|
||||
from typing import Callable, Any, Dict, TYPE_CHECKING
|
||||
import warnings
|
||||
from distutils.version import StrictVersion
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict
|
||||
|
||||
import aiohttp
|
||||
import analytics
|
||||
import pkg_resources
|
||||
import requests
|
||||
|
||||
import gradio
|
||||
|
||||
@ -21,7 +23,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from gradio import Interface
|
||||
|
||||
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
analytics_url = "https://api.gradio.app/"
|
||||
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
|
||||
@ -32,13 +34,18 @@ def version_check():
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
print(
|
||||
"IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version
|
||||
)
|
||||
)
|
||||
print("--------")
|
||||
except pkg_resources.DistributionNotFound:
|
||||
warnings.warn("gradio is not setup or installed properly. Unable to get version info.")
|
||||
warnings.warn(
|
||||
"gradio is not setup or installed properly. Unable to get version info."
|
||||
)
|
||||
except json.decoder.JSONDecodeError:
|
||||
warnings.warn("unable to parse version details from package URL.")
|
||||
except KeyError:
|
||||
@ -49,34 +56,36 @@ def version_check():
|
||||
|
||||
def get_local_ip_address() -> str:
|
||||
try:
|
||||
ip_address = requests.get('https://api.ipify.org', timeout=3).text
|
||||
ip_address = requests.get("https://api.ipify.org", timeout=3).text
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
ip_address = "No internet connection"
|
||||
return ip_address
|
||||
|
||||
|
||||
def initiated_analytics(data: Dict[str: Any]) -> None:
|
||||
def initiated_analytics(data: Dict[str:Any]) -> None:
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-initiated-analytics/',
|
||||
data=data, timeout=3)
|
||||
requests.post(
|
||||
analytics_url + "gradio-initiated-analytics/", data=data, timeout=3
|
||||
)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
def launch_analytics(data: Dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-launched-analytics/',
|
||||
data=data, timeout=3)
|
||||
requests.post(
|
||||
analytics_url + "gradio-launched-analytics/", data=data, timeout=3
|
||||
)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
def integration_analytics(data: Dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-integration-analytics/',
|
||||
data=data, timeout=3)
|
||||
except (
|
||||
requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
requests.post(
|
||||
analytics_url + "gradio-integration-analytics/", data=data, timeout=3
|
||||
)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
@ -85,23 +94,23 @@ def error_analytics(ip_address: str, message: str) -> None:
|
||||
Send error analytics if there is network
|
||||
:param type: RuntimeError or NameError
|
||||
"""
|
||||
data = {'ip_address': ip_address, 'error': message}
|
||||
data = {"ip_address": ip_address, "error": message}
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-error-analytics/',
|
||||
data=data, timeout=3)
|
||||
requests.post(analytics_url + "gradio-error-analytics/", data=data, timeout=3)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
async def log_feature_analytics(ip_address: str, feature: str) -> None:
|
||||
data={'ip_address': ip_address, 'feature': feature}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(
|
||||
analytics_url + 'gradio-feature-analytics/', data=data):
|
||||
pass
|
||||
except (aiohttp.ClientError):
|
||||
pass # do not push analytics if no network
|
||||
data = {"ip_address": ip_address, "feature": feature}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(
|
||||
analytics_url + "gradio-feature-analytics/", data=data
|
||||
):
|
||||
pass
|
||||
except (aiohttp.ClientError):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
def colab_check() -> bool:
|
||||
@ -112,6 +121,7 @@ def colab_check() -> bool:
|
||||
is_colab = False
|
||||
try: # Check if running interactively using ipython.
|
||||
from IPython import get_ipython
|
||||
|
||||
from_ipynb = get_ipython()
|
||||
if "google.colab" in str(from_ipynb):
|
||||
is_colab = True
|
||||
@ -128,6 +138,7 @@ def ipython_check() -> bool:
|
||||
is_ipython = False
|
||||
try: # Check if running interactively using ipython.
|
||||
from IPython import get_ipython
|
||||
|
||||
if get_ipython() is not None:
|
||||
is_ipython = True
|
||||
except (ImportError, NameError):
|
||||
@ -138,7 +149,7 @@ def ipython_check() -> bool:
|
||||
def readme_to_html(article: str) -> str:
|
||||
try:
|
||||
response = requests.get(article, timeout=3)
|
||||
if response.status_code == requests.codes.ok: #pylint: disable=no-member
|
||||
if response.status_code == requests.codes.ok: # pylint: disable=no-member
|
||||
article = response.text
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
@ -172,11 +183,11 @@ def launch_counter() -> None:
|
||||
def get_config_file(interface: Interface) -> Dict[str, Any]:
|
||||
config = {
|
||||
"input_components": [
|
||||
iface.get_template_context()
|
||||
for iface in interface.input_components],
|
||||
iface.get_template_context() for iface in interface.input_components
|
||||
],
|
||||
"output_components": [
|
||||
iface.get_template_context()
|
||||
for iface in interface.output_components],
|
||||
iface.get_template_context() for iface in interface.output_components
|
||||
],
|
||||
"function_count": len(interface.predict),
|
||||
"live": interface.live,
|
||||
"examples_per_page": interface.examples_per_page,
|
||||
@ -194,9 +205,11 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
|
||||
"flagging_options": interface.flagging_options,
|
||||
"allow_interpretation": interface.interpretation is not None,
|
||||
"queue": interface.enable_queue,
|
||||
"cached_examples": interface.cache_examples if hasattr(interface, "cache_examples") else False,
|
||||
"cached_examples": interface.cache_examples
|
||||
if hasattr(interface, "cache_examples")
|
||||
else False,
|
||||
"version": pkg_resources.require("gradio")[0].version,
|
||||
"favicon_path": interface.favicon_path
|
||||
"favicon_path": interface.favicon_path,
|
||||
}
|
||||
try:
|
||||
param_names = inspect.getfullargspec(interface.predict[0])[0]
|
||||
@ -205,16 +218,23 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
|
||||
iface["label"] = param.replace("_", " ")
|
||||
for i, iface in enumerate(config["output_components"]):
|
||||
outputs_per_function = int(
|
||||
len(interface.output_components) / len(interface.predict))
|
||||
len(interface.output_components) / len(interface.predict)
|
||||
)
|
||||
function_index = i // outputs_per_function
|
||||
component_index = i - function_index * outputs_per_function
|
||||
ret_name = "Output " + \
|
||||
str(component_index + 1) if outputs_per_function > 1 else "Output"
|
||||
ret_name = (
|
||||
"Output " + str(component_index + 1)
|
||||
if outputs_per_function > 1
|
||||
else "Output"
|
||||
)
|
||||
if iface["label"] is None:
|
||||
iface["label"] = ret_name
|
||||
if len(interface.predict) > 1:
|
||||
iface["label"] = interface.function_names[function_index].replace(
|
||||
"_", " ") + ": " + iface["label"]
|
||||
iface["label"] = (
|
||||
interface.function_names[function_index].replace("_", " ")
|
||||
+ ": "
|
||||
+ iface["label"]
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
@ -222,23 +242,36 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
|
||||
if isinstance(interface.examples, str):
|
||||
if not os.path.exists(interface.examples):
|
||||
raise FileNotFoundError(
|
||||
"Could not find examples directory: " + interface.examples)
|
||||
"Could not find examples directory: " + interface.examples
|
||||
)
|
||||
log_file = os.path.join(interface.examples, "log.csv")
|
||||
if not os.path.exists(log_file):
|
||||
if len(interface.input_components) == 1:
|
||||
examples = [[os.path.join(interface.examples, item)]
|
||||
for item in os.listdir(interface.examples)]
|
||||
examples = [
|
||||
[os.path.join(interface.examples, item)]
|
||||
for item in os.listdir(interface.examples)
|
||||
]
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
"Could not find log file (required for multiple inputs): " + log_file)
|
||||
"Could not find log file (required for multiple inputs): "
|
||||
+ log_file
|
||||
)
|
||||
else:
|
||||
with open(log_file) as logs:
|
||||
examples = list(csv.reader(logs))
|
||||
examples = examples[1:] # remove header
|
||||
for i, example in enumerate(examples):
|
||||
for j, (component, cell) in enumerate(zip(interface.input_components + interface.output_components, example)):
|
||||
for j, (component, cell) in enumerate(
|
||||
zip(
|
||||
interface.input_components + interface.output_components,
|
||||
example,
|
||||
)
|
||||
):
|
||||
examples[i][j] = component.restore_flagged(
|
||||
interface.flagging_dir, cell, interface.encryption_key if interface.encrypt else None)
|
||||
interface.flagging_dir,
|
||||
cell,
|
||||
interface.encryption_key if interface.encrypt else None,
|
||||
)
|
||||
config["examples"] = examples
|
||||
config["examples_dir"] = interface.examples
|
||||
else:
|
||||
@ -249,8 +282,6 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
|
||||
def get_default_args(func: Callable) -> Dict[str, Any]:
|
||||
signature = inspect.signature(func)
|
||||
return [
|
||||
v.default
|
||||
if v.default is not inspect.Parameter.empty
|
||||
else None
|
||||
v.default if v.default is not inspect.Parameter.empty else None
|
||||
for v in signature.parameters.values()
|
||||
]
|
||||
]
|
||||
|
@ -1,6 +1,7 @@
|
||||
import re
|
||||
from os.path import join, exists, getmtime
|
||||
from jinja2 import Environment, BaseLoader, TemplateNotFound
|
||||
from os.path import exists, getmtime, join
|
||||
|
||||
from jinja2 import BaseLoader, Environment, TemplateNotFound
|
||||
|
||||
README_TEMPLATE = "readme_template.md"
|
||||
GETTING_STARTED_TEMPLATE = "getting_started.md"
|
||||
@ -15,11 +16,16 @@ code, demos = {}, {}
|
||||
for code_src in code_tags:
|
||||
with open(join("demo", code_src, "run.py")) as code_file:
|
||||
python_code = code_file.read()
|
||||
python_code = python_code.replace('if __name__ == "__main__":\n iface.launch()', "iface.launch()")
|
||||
python_code = python_code.replace(
|
||||
'if __name__ == "__main__":\n iface.launch()', "iface.launch()"
|
||||
)
|
||||
code[code_src] = "```python\n" + python_code + "\n```"
|
||||
|
||||
for demo_src in demo_tags:
|
||||
demos[demo_src] = ""
|
||||
demos[demo_src] = (
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
class GuidesLoader(BaseLoader):
|
||||
def __init__(self, path):
|
||||
@ -34,8 +40,11 @@ class GuidesLoader(BaseLoader):
|
||||
source = f.read()
|
||||
return source, path, lambda: mtime == getmtime(path)
|
||||
|
||||
readme_template = Environment(loader=GuidesLoader("guides")).get_template(README_TEMPLATE)
|
||||
|
||||
readme_template = Environment(loader=GuidesLoader("guides")).get_template(
|
||||
README_TEMPLATE
|
||||
)
|
||||
output_readme = readme_template.render(code=code, demos=demos)
|
||||
|
||||
with open("README.md", "w") as readme_md:
|
||||
readme_md.write(output_readme)
|
||||
readme_md.write(output_readme)
|
||||
|
48
setup.py
48
setup.py
@ -4,31 +4,31 @@ except ImportError:
|
||||
from distutils.core import setup
|
||||
|
||||
setup(
|
||||
name='gradio',
|
||||
version='2.7.1',
|
||||
name="gradio",
|
||||
version="2.7.1",
|
||||
include_package_data=True,
|
||||
description='Python library for easily interacting with trained machine learning models',
|
||||
author='Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq',
|
||||
author_email='team@gradio.app',
|
||||
url='https://github.com/gradio-app/gradio-UI',
|
||||
packages=['gradio'],
|
||||
license='Apache License 2.0',
|
||||
keywords=['machine learning', 'visualization', 'reproducibility'],
|
||||
description="Python library for easily interacting with trained machine learning models",
|
||||
author="Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq",
|
||||
author_email="team@gradio.app",
|
||||
url="https://github.com/gradio-app/gradio-UI",
|
||||
packages=["gradio"],
|
||||
license="Apache License 2.0",
|
||||
keywords=["machine learning", "visualization", "reproducibility"],
|
||||
install_requires=[
|
||||
'analytics-python',
|
||||
'aiohttp',
|
||||
'fastapi',
|
||||
'ffmpy',
|
||||
'markdown2',
|
||||
'matplotlib',
|
||||
'numpy',
|
||||
'pandas',
|
||||
'paramiko',
|
||||
'pillow',
|
||||
'pycryptodome',
|
||||
'python-multipart',
|
||||
'pydub',
|
||||
'requests',
|
||||
'uvicorn',
|
||||
"analytics-python",
|
||||
"aiohttp",
|
||||
"fastapi",
|
||||
"ffmpy",
|
||||
"markdown2",
|
||||
"matplotlib",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"paramiko",
|
||||
"pillow",
|
||||
"pycryptodome",
|
||||
"python-multipart",
|
||||
"pydub",
|
||||
"requests",
|
||||
"uvicorn",
|
||||
],
|
||||
)
|
||||
|
@ -1,14 +1,15 @@
|
||||
import unittest
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
from matplotlib.testing.compare import compare_images
|
||||
import random
|
||||
import os
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
|
||||
current_dir = os.getcwd()
|
||||
|
||||
@ -36,12 +37,14 @@ def wait_for_url(url):
|
||||
|
||||
def diff_texts_thread(return_dict):
|
||||
from demo.diff_texts.run import iface
|
||||
|
||||
iface.save_to = return_dict
|
||||
iface.launch()
|
||||
|
||||
|
||||
def image_mod_thread(return_dict):
|
||||
from demo.image_mod.run import iface
|
||||
|
||||
iface.examples = None
|
||||
iface.save_to = return_dict
|
||||
iface.launch()
|
||||
@ -49,12 +52,14 @@ def image_mod_thread(return_dict):
|
||||
|
||||
def longest_word_thread(return_dict):
|
||||
from demo.longest_word.run import iface
|
||||
|
||||
iface.save_to = return_dict
|
||||
iface.launch()
|
||||
|
||||
|
||||
def sentence_builder_thread(return_dict):
|
||||
from demo.sentence_builder.run import iface
|
||||
|
||||
iface.save_to = return_dict
|
||||
iface.launch()
|
||||
|
||||
@ -63,8 +68,7 @@ class TestDemo(unittest.TestCase):
|
||||
def start_test(self, target):
|
||||
manager = multiprocessing.Manager()
|
||||
return_dict = manager.dict()
|
||||
self.i_thread = multiprocessing.Process(target=target,
|
||||
args=(return_dict,))
|
||||
self.i_thread = multiprocessing.Process(target=target, args=(return_dict,))
|
||||
self.i_thread.start()
|
||||
total_sleep = 0
|
||||
while not return_dict and total_sleep < TIMEOUT:
|
||||
@ -81,25 +85,36 @@ class TestDemo(unittest.TestCase):
|
||||
def test_diff_texts(self):
|
||||
driver = self.start_test(target=diff_texts_thread)
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR,
|
||||
".panel:nth-child(1) .component:nth-child(1) .input_text textarea"))
|
||||
EC.presence_of_element_located(
|
||||
(
|
||||
By.CSS_SELECTOR,
|
||||
".panel:nth-child(1) .component:nth-child(1) .input_text textarea",
|
||||
)
|
||||
)
|
||||
)
|
||||
elem.clear()
|
||||
elem.send_keys("Want to see a magic trick?")
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR,
|
||||
".panel:nth-child(1) .component:nth-child(2) .input_text textarea"))
|
||||
EC.presence_of_element_located(
|
||||
(
|
||||
By.CSS_SELECTOR,
|
||||
".panel:nth-child(1) .component:nth-child(2) .input_text textarea",
|
||||
)
|
||||
)
|
||||
)
|
||||
elem.clear()
|
||||
elem.send_keys("Let's go see a magic trick!")
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR,
|
||||
".submit"))
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, ".submit"))
|
||||
)
|
||||
elem.click()
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR,
|
||||
".panel:nth-child(2) .component:nth-child(2) .textfield"))
|
||||
EC.presence_of_element_located(
|
||||
(
|
||||
By.CSS_SELECTOR,
|
||||
".panel:nth-child(2) .component:nth-child(2) .textfield",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
total_sleep = 0
|
||||
@ -108,10 +123,12 @@ class TestDemo(unittest.TestCase):
|
||||
total_sleep += 0.2
|
||||
|
||||
self.assertEqual(elem.text, "L+e+W-a-n-t'+s+ t-g+o see a magic trick?-!+")
|
||||
golden_img = os.path.join(current_dir, GOLDEN_PATH.format(
|
||||
"diff_texts", "magic_trick"))
|
||||
tmp = os.path.join(current_dir, "test/tmp/{}.png".format(
|
||||
random.getrandbits(32)))
|
||||
golden_img = os.path.join(
|
||||
current_dir, GOLDEN_PATH.format("diff_texts", "magic_trick")
|
||||
)
|
||||
tmp = os.path.join(
|
||||
current_dir, "test/tmp/{}.png".format(random.getrandbits(32))
|
||||
)
|
||||
time.sleep(GAP_TO_SCREENSHOT)
|
||||
driver.save_screenshot(tmp)
|
||||
driver.close()
|
||||
@ -122,23 +139,32 @@ class TestDemo(unittest.TestCase):
|
||||
driver = self.start_test(target=image_mod_thread)
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located(
|
||||
(By.CSS_SELECTOR, ".panel:nth-child(1) .component:nth-child(1) .hidden_upload"))
|
||||
(
|
||||
By.CSS_SELECTOR,
|
||||
".panel:nth-child(1) .component:nth-child(1) .hidden_upload",
|
||||
)
|
||||
)
|
||||
)
|
||||
cwd = os.getcwd()
|
||||
rel = "test/test_files/cheetah1.jpg"
|
||||
elem.send_keys(os.path.join(cwd, rel))
|
||||
golden_img = os.path.join(current_dir, GOLDEN_PATH.format(
|
||||
"image_mod", "cheetah1"))
|
||||
tmp = os.path.join(current_dir, "test/tmp/{}.png".format(
|
||||
random.getrandbits(32)))
|
||||
golden_img = os.path.join(
|
||||
current_dir, GOLDEN_PATH.format("image_mod", "cheetah1")
|
||||
)
|
||||
tmp = os.path.join(
|
||||
current_dir, "test/tmp/{}.png".format(random.getrandbits(32))
|
||||
)
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR,
|
||||
".submit"))
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, ".submit"))
|
||||
)
|
||||
elem.click()
|
||||
WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.visibility_of_element_located(
|
||||
(By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_image"))
|
||||
(
|
||||
By.CSS_SELECTOR,
|
||||
".panel:nth-child(2) .component:nth-child(2) .output_image",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
time.sleep(GAP_TO_SCREENSHOT)
|
||||
@ -150,18 +176,25 @@ class TestDemo(unittest.TestCase):
|
||||
def test_longest_word(self):
|
||||
driver = self.start_test(target=longest_word_thread)
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR,
|
||||
".panel:nth-child(1) .component:nth-child(1) .input_text textarea"))
|
||||
EC.presence_of_element_located(
|
||||
(
|
||||
By.CSS_SELECTOR,
|
||||
".panel:nth-child(1) .component:nth-child(1) .input_text textarea",
|
||||
)
|
||||
)
|
||||
)
|
||||
elem.send_keys("This is the most wonderful machine learning "
|
||||
"library.")
|
||||
elem.send_keys("This is the most wonderful machine learning " "library.")
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR,
|
||||
".submit"))
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, ".submit"))
|
||||
)
|
||||
elem.click()
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_class_without_confidences"))
|
||||
EC.presence_of_element_located(
|
||||
(
|
||||
By.CSS_SELECTOR,
|
||||
".panel:nth-child(2) .component:nth-child(2) .output_class_without_confidences",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
total_sleep = 0
|
||||
@ -169,10 +202,12 @@ class TestDemo(unittest.TestCase):
|
||||
time.sleep(0.2)
|
||||
total_sleep += 0.2
|
||||
|
||||
golden_img = os.path.join(current_dir, GOLDEN_PATH.format(
|
||||
"longest_word", "wonderful"))
|
||||
tmp = os.path.join(current_dir, "test/tmp/{}.png".format(
|
||||
random.getrandbits(32)))
|
||||
golden_img = os.path.join(
|
||||
current_dir, GOLDEN_PATH.format("longest_word", "wonderful")
|
||||
)
|
||||
tmp = os.path.join(
|
||||
current_dir, "test/tmp/{}.png".format(random.getrandbits(32))
|
||||
)
|
||||
time.sleep(GAP_TO_SCREENSHOT)
|
||||
driver.save_screenshot(tmp)
|
||||
driver.close()
|
||||
@ -182,12 +217,16 @@ class TestDemo(unittest.TestCase):
|
||||
def test_sentence_builder(self):
|
||||
driver = self.start_test(target=sentence_builder_thread)
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR,
|
||||
".submit"))
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, ".submit"))
|
||||
)
|
||||
elem.click()
|
||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_text"))
|
||||
EC.presence_of_element_located(
|
||||
(
|
||||
By.CSS_SELECTOR,
|
||||
".panel:nth-child(2) .component:nth-child(2) .output_text",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
total_sleep = 0
|
||||
@ -196,11 +235,14 @@ class TestDemo(unittest.TestCase):
|
||||
total_sleep += 0.2
|
||||
|
||||
self.assertEqual(
|
||||
elem.text, "The 2 cats went to the park where they until the night")
|
||||
golden_img = os.path.join(current_dir, GOLDEN_PATH.format(
|
||||
"sentence_builder", "two_cats"))
|
||||
tmp = os.path.join(current_dir, "test/tmp/{}.png".format(
|
||||
random.getrandbits(32)))
|
||||
elem.text, "The 2 cats went to the park where they until the night"
|
||||
)
|
||||
golden_img = os.path.join(
|
||||
current_dir, GOLDEN_PATH.format("sentence_builder", "two_cats")
|
||||
)
|
||||
tmp = os.path.join(
|
||||
current_dir, "test/tmp/{}.png".format(random.getrandbits(32))
|
||||
)
|
||||
time.sleep(GAP_TO_SCREENSHOT)
|
||||
driver.save_screenshot(tmp)
|
||||
self.assertIsNone(compare_images(tmp, golden_img, TOLERANCE))
|
||||
@ -212,5 +254,5 @@ class TestDemo(unittest.TestCase):
|
||||
self.i_thread.join()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,8 +1,9 @@
|
||||
import os
|
||||
import pathlib
|
||||
import transformers
|
||||
import unittest
|
||||
|
||||
import transformers
|
||||
|
||||
import gradio as gr
|
||||
|
||||
"""
|
||||
@ -18,7 +19,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_question_answering(self):
|
||||
model_type = "question-answering"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"deepset/roberta-base-squad2", api_key=None, alias=model_type)
|
||||
"deepset/roberta-base-squad2", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
|
||||
@ -28,7 +30,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_text_generation(self):
|
||||
model_type = "text_generation"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"gpt2", api_key=None, alias=model_type)
|
||||
"gpt2", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
@ -36,7 +39,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_summarization(self):
|
||||
model_type = "summarization"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"facebook/bart-large-cnn", api_key=None, alias=model_type)
|
||||
"facebook/bart-large-cnn", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
@ -44,7 +48,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_translation(self):
|
||||
model_type = "translation"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"facebook/bart-large-cnn", api_key=None, alias=model_type)
|
||||
"facebook/bart-large-cnn", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
@ -52,7 +57,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_text2text_generation(self):
|
||||
model_type = "text2text-generation"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"sshleifer/tiny-mbart", api_key=None, alias=model_type)
|
||||
"sshleifer/tiny-mbart", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
@ -61,7 +67,9 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
model_type = "text-classification"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"distilbert-base-uncased-finetuned-sst-2-english",
|
||||
api_key=None, alias=model_type)
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
@ -69,8 +77,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_fill_mask(self):
|
||||
model_type = "fill-mask"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"bert-base-uncased",
|
||||
api_key=None, alias=model_type)
|
||||
"bert-base-uncased", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
@ -78,8 +86,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_zero_shot_classification(self):
|
||||
model_type = "zero-shot-classification"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"facebook/bart-large-mnli",
|
||||
api_key=None, alias=model_type)
|
||||
"facebook/bart-large-mnli", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
|
||||
@ -89,8 +97,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_automatic_speech_recognition(self):
|
||||
model_type = "automatic-speech-recognition"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"facebook/wav2vec2-base-960h",
|
||||
api_key=None, alias=model_type)
|
||||
"facebook/wav2vec2-base-960h", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
@ -98,8 +106,8 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_image_classification(self):
|
||||
model_type = "image-classification"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"google/vit-base-patch16-224",
|
||||
api_key=None, alias=model_type)
|
||||
"google/vit-base-patch16-224", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Image)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
@ -108,7 +116,9 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
model_type = "feature-extraction"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"sentence-transformers/distilbert-base-nli-mean-tokens",
|
||||
api_key=None, alias=model_type)
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Dataframe)
|
||||
@ -117,7 +127,9 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
model_type = "text-to-speech"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
|
||||
api_key=None, alias=model_type)
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)
|
||||
@ -126,7 +138,9 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
model_type = "text-to-speech"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
|
||||
api_key=None, alias=model_type)
|
||||
api_key=None,
|
||||
alias=model_type,
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)
|
||||
@ -134,60 +148,74 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_text_to_image(self):
|
||||
model_type = "text-to-image"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"osanseviero/BigGAN-deep-128",
|
||||
api_key=None, alias=model_type)
|
||||
"osanseviero/BigGAN-deep-128", api_key=None, alias=model_type
|
||||
)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)
|
||||
|
||||
def test_english_to_spanish(self):
|
||||
interface_info = gr.external.get_spaces_interface("abidlabs/english_to_spanish", api_key=None, alias=None)
|
||||
interface_info = gr.external.get_spaces_interface(
|
||||
"abidlabs/english_to_spanish", api_key=None, alias=None
|
||||
)
|
||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
|
||||
|
||||
|
||||
class TestLoadInterface(unittest.TestCase):
|
||||
def test_english_to_spanish(self):
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/english_to_spanish")
|
||||
interface_info = gr.external.load_interface(
|
||||
"spaces/abidlabs/english_to_spanish"
|
||||
)
|
||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
|
||||
|
||||
def test_sentiment_model(self):
|
||||
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
||||
interface_info = gr.external.load_interface(
|
||||
"models/distilbert-base-uncased-finetuned-sst-2-english",
|
||||
alias="sentiment_classifier",
|
||||
)
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("I am happy, I love you.")
|
||||
self.assertGreater(output['POSITIVE'], 0.5)
|
||||
self.assertGreater(output["POSITIVE"], 0.5)
|
||||
|
||||
def test_image_classification_model(self):
|
||||
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
|
||||
interface_info = gr.external.load_interface(
|
||||
"models/google/vit-base-patch16-224"
|
||||
)
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("test/test_data/lion.jpg")
|
||||
self.assertGreater(output['lion'], 0.5)
|
||||
self.assertGreater(output["lion"], 0.5)
|
||||
|
||||
def test_translation_model(self):
|
||||
interface_info = gr.external.load_interface("models/t5-base")
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("My name is Sarah and I live in London")
|
||||
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
|
||||
self.assertEquals(output, "Mein Name ist Sarah und ich lebe in London")
|
||||
|
||||
def test_numerical_to_label_space(self):
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("male", 77, 10)
|
||||
self.assertLess(output['Survives'], 0.5)
|
||||
self.assertLess(output["Survives"], 0.5)
|
||||
|
||||
def test_speech_recognition_model(self):
|
||||
interface_info = gr.external.load_interface("models/jonatasgrosman/wav2vec2-large-xlsr-53-english")
|
||||
interface_info = gr.external.load_interface(
|
||||
"models/jonatasgrosman/wav2vec2-large-xlsr-53-english"
|
||||
)
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
output = io("test/test_data/test_audio.wav")
|
||||
self.assertIsNotNone(output)
|
||||
|
||||
def test_text_to_image_model(self):
|
||||
interface_info = gr.external.load_interface("models/osanseviero/BigGAN-deep-128")
|
||||
interface_info = gr.external.load_interface(
|
||||
"models/osanseviero/BigGAN-deep-128"
|
||||
)
|
||||
io = gr.Interface(**interface_info)
|
||||
io.api_mode = True
|
||||
filename = io("chest")
|
||||
@ -207,11 +235,14 @@ class TestLoadInterface(unittest.TestCase):
|
||||
|
||||
class TestLoadFromPipeline(unittest.TestCase):
|
||||
def test_question_answering(self):
|
||||
p = transformers.pipeline("question-answering")
|
||||
p = transformers.pipeline("question-answering")
|
||||
io = gr.Interface.from_pipeline(p)
|
||||
output = io("My name is Sylvain and I work at Hugging Face in Brooklyn", "Where do I work?")
|
||||
output = io(
|
||||
"My name is Sylvain and I work at Hugging Face in Brooklyn",
|
||||
"Where do I work?",
|
||||
)
|
||||
self.assertIsNotNone(output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,9 +1,10 @@
|
||||
import gradio as gr
|
||||
from gradio import flagging
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
import gradio as gr
|
||||
from gradio import flagging
|
||||
|
||||
|
||||
class TestFlagging(unittest.TestCase):
|
||||
def test_default_flagging_handler(self):
|
||||
@ -18,7 +19,13 @@ class TestFlagging(unittest.TestCase):
|
||||
|
||||
def test_simple_csv_flagging_handler(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname, flagging_callback=flagging.SimpleCSVLogger())
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
"text",
|
||||
"text",
|
||||
flagging_dir=tmpdirname,
|
||||
flagging_callback=flagging.SimpleCSVLogger(),
|
||||
)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 0) # no header
|
||||
@ -27,5 +34,5 @@ class TestFlagging(unittest.TestCase):
|
||||
io.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,14 +1,15 @@
|
||||
from re import sub
|
||||
import unittest
|
||||
import gradio as gr
|
||||
import PIL
|
||||
import numpy as np
|
||||
import pandas
|
||||
from pydub import AudioSegment
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import json
|
||||
import unittest
|
||||
from re import sub
|
||||
|
||||
import numpy as np
|
||||
import pandas
|
||||
import PIL
|
||||
from pydub import AudioSegment
|
||||
|
||||
import gradio as gr
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -32,7 +33,9 @@ class TestTextbox(unittest.TestCase):
|
||||
self.assertEqual(text_input.preprocess_example("Hello World!"), "Hello World!")
|
||||
self.assertEqual(text_input.serialize("Hello World!", True), "Hello World!")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = text_input.save_flagged(tmpdirname, "text_input", "Hello World!", None)
|
||||
to_save = text_input.save_flagged(
|
||||
tmpdirname, "text_input", "Hello World!", None
|
||||
)
|
||||
self.assertEqual(to_save, "Hello World!")
|
||||
restored = text_input.restore_flagged(tmpdirname, to_save, None)
|
||||
self.assertEqual(restored, "Hello World!")
|
||||
@ -44,29 +47,79 @@ class TestTextbox(unittest.TestCase):
|
||||
wrong_type = gr.inputs.Textbox(type="unknown")
|
||||
wrong_type.preprocess(0)
|
||||
|
||||
self.assertEqual(text_input.tokenize("Hello World! Gradio speaking."), (
|
||||
['Hello', 'World!', 'Gradio', 'speaking.'],
|
||||
['World! Gradio speaking.', 'Hello Gradio speaking.', 'Hello World! speaking.', 'Hello World! Gradio'],
|
||||
None))
|
||||
self.assertEqual(
|
||||
text_input.tokenize("Hello World! Gradio speaking."),
|
||||
(
|
||||
["Hello", "World!", "Gradio", "speaking."],
|
||||
[
|
||||
"World! Gradio speaking.",
|
||||
"Hello Gradio speaking.",
|
||||
"Hello World! speaking.",
|
||||
"Hello World! Gradio",
|
||||
],
|
||||
None,
|
||||
),
|
||||
)
|
||||
text_input.interpretation_replacement = "unknown"
|
||||
self.assertEqual(text_input.tokenize("Hello World! Gradio speaking."), (
|
||||
['Hello', 'World!', 'Gradio', 'speaking.'],
|
||||
['unknown World! Gradio speaking.', 'Hello unknown Gradio speaking.', 'Hello World! unknown speaking.',
|
||||
'Hello World! Gradio unknown'], None))
|
||||
self.assertEqual(
|
||||
text_input.tokenize("Hello World! Gradio speaking."),
|
||||
(
|
||||
["Hello", "World!", "Gradio", "speaking."],
|
||||
[
|
||||
"unknown World! Gradio speaking.",
|
||||
"Hello unknown Gradio speaking.",
|
||||
"Hello World! unknown speaking.",
|
||||
"Hello World! Gradio unknown",
|
||||
],
|
||||
None,
|
||||
),
|
||||
)
|
||||
|
||||
self.assertIsInstance(text_input.generate_sample(), str)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
|
||||
iface = gr.Interface(lambda sentence: max([len(word) for word in sentence.split()]), gr.inputs.Textbox(),
|
||||
gr.outputs.Textbox(), interpretation="default")
|
||||
scores, alternative_outputs = iface.interpret(["Return the length of the longest word in this sentence"])
|
||||
self.assertEqual(scores, [[('Return', 0.0), (' ', 0), ('the', 0.0), (' ', 0), ('length', 0.0), (' ', 0),
|
||||
('of', 0.0), (' ', 0), ('the', 0.0), (' ', 0), ('longest', 0.0), (' ', 0),
|
||||
('word', 0.0), (' ', 0), ('in', 0.0), (' ', 0), ('this', 0.0), (' ', 0),
|
||||
('sentence', 1.0), (' ', 0)]])
|
||||
self.assertEqual(alternative_outputs, [[['8'], ['8'], ['8'], ['8'], ['8'], ['8'], ['8'], ['8'], ['8'], ['7']]])
|
||||
iface = gr.Interface(
|
||||
lambda sentence: max([len(word) for word in sentence.split()]),
|
||||
gr.inputs.Textbox(),
|
||||
gr.outputs.Textbox(),
|
||||
interpretation="default",
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret(
|
||||
["Return the length of the longest word in this sentence"]
|
||||
)
|
||||
self.assertEqual(
|
||||
scores,
|
||||
[
|
||||
[
|
||||
("Return", 0.0),
|
||||
(" ", 0),
|
||||
("the", 0.0),
|
||||
(" ", 0),
|
||||
("length", 0.0),
|
||||
(" ", 0),
|
||||
("of", 0.0),
|
||||
(" ", 0),
|
||||
("the", 0.0),
|
||||
(" ", 0),
|
||||
("longest", 0.0),
|
||||
(" ", 0),
|
||||
("word", 0.0),
|
||||
(" ", 0),
|
||||
("in", 0.0),
|
||||
(" ", 0),
|
||||
("this", 0.0),
|
||||
(" ", 0),
|
||||
("sentence", 1.0),
|
||||
(" ", 0),
|
||||
]
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
[[["8"], ["8"], ["8"], ["8"], ["8"], ["8"], ["8"], ["8"], ["8"], ["7"]]],
|
||||
)
|
||||
|
||||
|
||||
class TestNumber(unittest.TestCase):
|
||||
@ -82,20 +135,50 @@ class TestNumber(unittest.TestCase):
|
||||
self.assertEqual(restored, 3)
|
||||
self.assertIsInstance(numeric_input.generate_sample(), float)
|
||||
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
|
||||
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}))
|
||||
self.assertEqual(
|
||||
numeric_input.get_interpretation_neighbors(1),
|
||||
([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}),
|
||||
)
|
||||
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent")
|
||||
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}))
|
||||
self.assertEqual(
|
||||
numeric_input.get_interpretation_neighbors(1),
|
||||
([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}),
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ['4.0'])
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox", interpretation="default")
|
||||
iface = gr.Interface(lambda x: x ** 2, "number", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4.0"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x ** 2, "number", "textbox", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
self.assertEqual(scores, [[(1.94, -0.23640000000000017), (1.96, -0.15840000000000032),
|
||||
(1.98, -0.07960000000000012), [2, None], (2.02, 0.08040000000000003),
|
||||
(2.04, 0.16159999999999997), (2.06, 0.24359999999999982)]])
|
||||
self.assertEqual(alternative_outputs, [[['3.7636'], ['3.8415999999999997'], ['3.9204'], ['4.0804'], ['4.1616'],
|
||||
['4.2436']]])
|
||||
self.assertEqual(
|
||||
scores,
|
||||
[
|
||||
[
|
||||
(1.94, -0.23640000000000017),
|
||||
(1.96, -0.15840000000000032),
|
||||
(1.98, -0.07960000000000012),
|
||||
[2, None],
|
||||
(2.02, 0.08040000000000003),
|
||||
(2.04, 0.16159999999999997),
|
||||
(2.06, 0.24359999999999982),
|
||||
]
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
[
|
||||
[
|
||||
["3.7636"],
|
||||
["3.8415999999999997"],
|
||||
["3.9204"],
|
||||
["4.0804"],
|
||||
["4.1616"],
|
||||
["4.2436"],
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestSlider(unittest.TestCase):
|
||||
@ -111,26 +194,58 @@ class TestSlider(unittest.TestCase):
|
||||
self.assertEqual(restored, 3)
|
||||
|
||||
self.assertIsInstance(slider_input.generate_sample(), int)
|
||||
slider_input = gr.inputs.Slider(minimum=10, maximum=20, step=1, default=15, label="Slide Your Input")
|
||||
self.assertEqual(slider_input.get_template_context(), {
|
||||
'minimum': 10,
|
||||
'maximum': 20,
|
||||
'step': 1,
|
||||
'default': 15,
|
||||
'name': 'slider',
|
||||
'label': 'Slide Your Input'
|
||||
})
|
||||
slider_input = gr.inputs.Slider(
|
||||
minimum=10, maximum=20, step=1, default=15, label="Slide Your Input"
|
||||
)
|
||||
self.assertEqual(
|
||||
slider_input.get_template_context(),
|
||||
{
|
||||
"minimum": 10,
|
||||
"maximum": 20,
|
||||
"step": 1,
|
||||
"default": 15,
|
||||
"name": "slider",
|
||||
"label": "Slide Your Input",
|
||||
},
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ['4'])
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox", interpretation="default")
|
||||
iface = gr.Interface(lambda x: x ** 2, "slider", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x ** 2, "slider", "textbox", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
self.assertEqual(scores, [[-4.0, 200.08163265306123, 812.3265306122449, 1832.7346938775513, 3261.3061224489797,
|
||||
5098.040816326531, 7342.938775510205, 9996.0]])
|
||||
self.assertEqual(alternative_outputs, [[['0.0'], ['204.08163265306123'], ['816.3265306122449'],
|
||||
['1836.7346938775513'], ['3265.3061224489797'], ['5102.040816326531'],
|
||||
['7346.938775510205'], ['10000.0']]])
|
||||
self.assertEqual(
|
||||
scores,
|
||||
[
|
||||
[
|
||||
-4.0,
|
||||
200.08163265306123,
|
||||
812.3265306122449,
|
||||
1832.7346938775513,
|
||||
3261.3061224489797,
|
||||
5098.040816326531,
|
||||
7342.938775510205,
|
||||
9996.0,
|
||||
]
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
[
|
||||
[
|
||||
["0.0"],
|
||||
["204.08163265306123"],
|
||||
["816.3265306122449"],
|
||||
["1836.7346938775513"],
|
||||
["3265.3061224489797"],
|
||||
["5102.040816326531"],
|
||||
["7346.938775510205"],
|
||||
["10000.0"],
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestCheckbox(unittest.TestCase):
|
||||
@ -146,22 +261,23 @@ class TestCheckbox(unittest.TestCase):
|
||||
self.assertEqual(restored, True)
|
||||
self.assertIsInstance(bool_input.generate_sample(), bool)
|
||||
bool_input = gr.inputs.Checkbox(default=True, label="Check Your Input")
|
||||
self.assertEqual(bool_input.get_template_context(), {
|
||||
'default': True,
|
||||
'name': 'checkbox',
|
||||
'label': 'Check Your Input'
|
||||
})
|
||||
self.assertEqual(
|
||||
bool_input.get_template_context(),
|
||||
{"default": True, "name": "checkbox", "label": "Check Your Input"},
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "textbox")
|
||||
self.assertEqual(iface.process([True])[0], ['1'])
|
||||
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "textbox", interpretation="default")
|
||||
self.assertEqual(iface.process([True])[0], ["1"])
|
||||
iface = gr.Interface(
|
||||
lambda x: 1 if x else 0, "checkbox", "textbox", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([False])
|
||||
self.assertEqual(scores, [(None, 1.0)])
|
||||
self.assertEqual(alternative_outputs, [[['1']]])
|
||||
self.assertEqual(alternative_outputs, [[["1"]]])
|
||||
scores, alternative_outputs = iface.interpret([True])
|
||||
self.assertEqual(scores, [(-1.0, None)])
|
||||
self.assertEqual(alternative_outputs, [[['0']]])
|
||||
self.assertEqual(alternative_outputs, [[["0"]]])
|
||||
|
||||
|
||||
class TestCheckboxGroup(unittest.TestCase):
|
||||
@ -171,19 +287,25 @@ class TestCheckboxGroup(unittest.TestCase):
|
||||
self.assertEqual(checkboxes_input.preprocess_example(["a", "c"]), ["a", "c"])
|
||||
self.assertEqual(checkboxes_input.serialize(["a", "c"], True), ["a", "c"])
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = checkboxes_input.save_flagged(tmpdirname, "checkboxes_input", ["a", "c"], None)
|
||||
to_save = checkboxes_input.save_flagged(
|
||||
tmpdirname, "checkboxes_input", ["a", "c"], None
|
||||
)
|
||||
self.assertEqual(to_save, '["a", "c"]')
|
||||
restored = checkboxes_input.restore_flagged(tmpdirname, to_save, None)
|
||||
self.assertEqual(restored, ["a", "c"])
|
||||
self.assertIsInstance(checkboxes_input.generate_sample(), list)
|
||||
checkboxes_input = gr.inputs.CheckboxGroup(choices=["a", "b", "c"], default=["a", "c"],
|
||||
label="Check Your Inputs")
|
||||
self.assertEqual(checkboxes_input.get_template_context(), {
|
||||
'choices': ['a', 'b', 'c'],
|
||||
'default': ['a', 'c'],
|
||||
'name': 'checkboxgroup',
|
||||
'label': 'Check Your Inputs'
|
||||
})
|
||||
checkboxes_input = gr.inputs.CheckboxGroup(
|
||||
choices=["a", "b", "c"], default=["a", "c"], label="Check Your Inputs"
|
||||
)
|
||||
self.assertEqual(
|
||||
checkboxes_input.get_template_context(),
|
||||
{
|
||||
"choices": ["a", "b", "c"],
|
||||
"default": ["a", "c"],
|
||||
"name": "checkboxgroup",
|
||||
"label": "Check Your Inputs",
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.CheckboxGroup(["a"], type="unknown")
|
||||
wrong_type.preprocess(0)
|
||||
@ -194,11 +316,16 @@ class TestCheckboxGroup(unittest.TestCase):
|
||||
self.assertEqual(iface.process([["a", "c"]])[0], ["a|c"])
|
||||
self.assertEqual(iface.process([[]])[0], [""])
|
||||
checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(lambda x: "|".join(map(str, x)), checkboxes_input, "textbox", interpretation="default")
|
||||
iface = gr.Interface(
|
||||
lambda x: "|".join(map(str, x)),
|
||||
checkboxes_input,
|
||||
"textbox",
|
||||
interpretation="default",
|
||||
)
|
||||
self.assertEqual(iface.process([["a", "c"]])[0], ["0|2"])
|
||||
scores, alternative_outputs = iface.interpret([["a", "c"]])
|
||||
self.assertEqual(scores, [[[-1, None], [None, -1], [-1, None]]])
|
||||
self.assertEqual(alternative_outputs, [[['2'], ['0|2|1'], ['0']]])
|
||||
self.assertEqual(alternative_outputs, [[["2"], ["0|2|1"], ["0"]]])
|
||||
|
||||
|
||||
class TestRadio(unittest.TestCase):
|
||||
@ -209,20 +336,24 @@ class TestRadio(unittest.TestCase):
|
||||
self.assertEqual(radio_input.serialize("a", True), "a")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = radio_input.save_flagged(tmpdirname, "radio_input", "a", None)
|
||||
self.assertEqual(to_save, 'a')
|
||||
self.assertEqual(to_save, "a")
|
||||
restored = radio_input.restore_flagged(tmpdirname, to_save, None)
|
||||
self.assertEqual(restored, "a")
|
||||
self.assertIsInstance(radio_input.generate_sample(), str)
|
||||
radio_input = gr.inputs.Radio(choices=["a", "b", "c"], default="a",
|
||||
label="Pick Your One Input")
|
||||
self.assertEqual(radio_input.get_template_context(), {
|
||||
'choices': ['a', 'b', 'c'],
|
||||
'default': 'a',
|
||||
'name': 'radio',
|
||||
'label': 'Pick Your One Input'
|
||||
})
|
||||
radio_input = gr.inputs.Radio(
|
||||
choices=["a", "b", "c"], default="a", label="Pick Your One Input"
|
||||
)
|
||||
self.assertEqual(
|
||||
radio_input.get_template_context(),
|
||||
{
|
||||
"choices": ["a", "b", "c"],
|
||||
"default": "a",
|
||||
"name": "radio",
|
||||
"label": "Pick Your One Input",
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Radio(["a","b"], type="unknown")
|
||||
wrong_type = gr.inputs.Radio(["a", "b"], type="unknown")
|
||||
wrong_type.preprocess(0)
|
||||
|
||||
def test_in_interface(self):
|
||||
@ -230,7 +361,9 @@ class TestRadio(unittest.TestCase):
|
||||
iface = gr.Interface(lambda x: 2 * x, radio_input, "textbox")
|
||||
self.assertEqual(iface.process(["c"])[0], ["cc"])
|
||||
radio_input = gr.inputs.Radio(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(lambda x: 2 * x, radio_input, "number", interpretation="default")
|
||||
iface = gr.Interface(
|
||||
lambda x: 2 * x, radio_input, "number", interpretation="default"
|
||||
)
|
||||
self.assertEqual(iface.process(["c"])[0], [4])
|
||||
scores, alternative_outputs = iface.interpret(["b"])
|
||||
self.assertEqual(scores, [[-2.0, None, 2.0]])
|
||||
@ -244,19 +377,25 @@ class TestDropdown(unittest.TestCase):
|
||||
self.assertEqual(dropdown_input.preprocess_example("a"), "a")
|
||||
self.assertEqual(dropdown_input.serialize("a", True), "a")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = dropdown_input.save_flagged(tmpdirname, "dropdown_input", "a", None)
|
||||
self.assertEqual(to_save, 'a')
|
||||
to_save = dropdown_input.save_flagged(
|
||||
tmpdirname, "dropdown_input", "a", None
|
||||
)
|
||||
self.assertEqual(to_save, "a")
|
||||
restored = dropdown_input.restore_flagged(tmpdirname, to_save, None)
|
||||
self.assertEqual(restored, "a")
|
||||
self.assertIsInstance(dropdown_input.generate_sample(), str)
|
||||
dropdown_input = gr.inputs.Dropdown(choices=["a", "b", "c"], default="a",
|
||||
label="Drop Your Input")
|
||||
self.assertEqual(dropdown_input.get_template_context(), {
|
||||
'choices': ['a', 'b', 'c'],
|
||||
'default': 'a',
|
||||
'name': 'dropdown',
|
||||
'label': 'Drop Your Input'
|
||||
})
|
||||
dropdown_input = gr.inputs.Dropdown(
|
||||
choices=["a", "b", "c"], default="a", label="Drop Your Input"
|
||||
)
|
||||
self.assertEqual(
|
||||
dropdown_input.get_template_context(),
|
||||
{
|
||||
"choices": ["a", "b", "c"],
|
||||
"default": "a",
|
||||
"name": "dropdown",
|
||||
"label": "Drop Your Input",
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Dropdown(["a"], type="unknown")
|
||||
wrong_type.preprocess(0)
|
||||
@ -266,7 +405,9 @@ class TestDropdown(unittest.TestCase):
|
||||
iface = gr.Interface(lambda x: 2 * x, dropdown_input, "textbox")
|
||||
self.assertEqual(iface.process(["c"])[0], ["cc"])
|
||||
dropdown = gr.inputs.Dropdown(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(lambda x: 2 * x, dropdown, "number", interpretation="default")
|
||||
iface = gr.Interface(
|
||||
lambda x: 2 * x, dropdown, "number", interpretation="default"
|
||||
)
|
||||
self.assertEqual(iface.process(["c"])[0], [4])
|
||||
scores, alternative_outputs = iface.interpret(["b"])
|
||||
self.assertEqual(scores, [[-2.0, None, 2.0]])
|
||||
@ -293,16 +434,21 @@ class TestImage(unittest.TestCase):
|
||||
self.assertEqual(restored, "image_input/1.png")
|
||||
|
||||
self.assertIsInstance(image_input.generate_sample(), str)
|
||||
image_input = gr.inputs.Image(source="upload", tool="editor", type="pil", label="Upload Your Image")
|
||||
self.assertEqual(image_input.get_template_context(), {
|
||||
'image_mode': 'RGB',
|
||||
'shape': None,
|
||||
'source': 'upload',
|
||||
'tool': 'editor',
|
||||
'optional': False,
|
||||
'name': 'image',
|
||||
'label': 'Upload Your Image'
|
||||
})
|
||||
image_input = gr.inputs.Image(
|
||||
source="upload", tool="editor", type="pil", label="Upload Your Image"
|
||||
)
|
||||
self.assertEqual(
|
||||
image_input.get_template_context(),
|
||||
{
|
||||
"image_mode": "RGB",
|
||||
"shape": None,
|
||||
"source": "upload",
|
||||
"tool": "editor",
|
||||
"optional": False,
|
||||
"name": "image",
|
||||
"label": "Upload Your Image",
|
||||
},
|
||||
)
|
||||
self.assertIsNone(image_input.preprocess(None))
|
||||
image_input = gr.inputs.Image(invert_colors=True)
|
||||
self.assertIsNotNone(image_input.preprocess(img))
|
||||
@ -318,7 +464,7 @@ class TestImage(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Image(type="unknown")
|
||||
wrong_type.serialize("test/test_files/bus.png", False)
|
||||
img_pil = PIL.Image.open('test/test_files/bus.png')
|
||||
img_pil = PIL.Image.open("test/test_files/bus.png")
|
||||
image_input = gr.inputs.Image(type="numpy")
|
||||
self.assertIsInstance(image_input.serialize(img_pil, False), str)
|
||||
image_input = gr.inputs.Image(type="pil")
|
||||
@ -332,21 +478,40 @@ class TestImage(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
img = gr.test_data.BASE64_IMAGE
|
||||
image_input = gr.inputs.Image()
|
||||
iface = gr.Interface(lambda x: PIL.Image.open(x).rotate(90, expand=True),
|
||||
gr.inputs.Image(shape=(30, 10), type="file"), "image")
|
||||
iface = gr.Interface(
|
||||
lambda x: PIL.Image.open(x).rotate(90, expand=True),
|
||||
gr.inputs.Image(shape=(30, 10), type="file"),
|
||||
"image",
|
||||
)
|
||||
output = iface.process([img])[0][0]
|
||||
self.assertEqual(gr.processing_utils.decode_base64_to_image(output).size, (10, 30))
|
||||
iface = gr.Interface(lambda x: np.sum(x), image_input, "textbox", interpretation="default")
|
||||
self.assertEqual(
|
||||
gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
|
||||
)
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "textbox", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
self.assertEqual(scores, gr.test_data.SUM_PIXELS_INTERPRETATION["scores"])
|
||||
self.assertEqual(alternative_outputs, gr.test_data.SUM_PIXELS_INTERPRETATION["alternative_outputs"])
|
||||
iface = gr.Interface(lambda x: np.sum(x), image_input, "label", interpretation="shap")
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
gr.test_data.SUM_PIXELS_INTERPRETATION["alternative_outputs"],
|
||||
)
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "label", interpretation="shap"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
self.assertEqual(len(scores[0]), len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0]))
|
||||
self.assertEqual(len(alternative_outputs[0]),
|
||||
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["alternative_outputs"][0]))
|
||||
self.assertEqual(
|
||||
len(scores[0]),
|
||||
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0]),
|
||||
)
|
||||
self.assertEqual(
|
||||
len(alternative_outputs[0]),
|
||||
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["alternative_outputs"][0]),
|
||||
)
|
||||
image_input = gr.inputs.Image(shape=(30, 10))
|
||||
iface = gr.Interface(lambda x: np.sum(x), image_input, "textbox", interpretation="default")
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "textbox", interpretation="default"
|
||||
)
|
||||
self.assertIsNotNone(iface.interpret([img]))
|
||||
|
||||
|
||||
@ -357,8 +522,11 @@ class TestAudio(unittest.TestCase):
|
||||
output = audio_input.preprocess(x_wav)
|
||||
self.assertEqual(output[0], 8000)
|
||||
self.assertEqual(output[1].shape, (8046,))
|
||||
self.assertEqual(audio_input.serialize("test/test_files/audio_sample.wav", True)["data"], x_wav["data"])
|
||||
|
||||
self.assertEqual(
|
||||
audio_input.serialize("test/test_files/audio_sample.wav", True)["data"],
|
||||
x_wav["data"],
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = audio_input.save_flagged(tmpdirname, "audio_input", x_wav, None)
|
||||
self.assertEqual("audio_input/0.wav", to_save)
|
||||
@ -369,12 +537,15 @@ class TestAudio(unittest.TestCase):
|
||||
|
||||
self.assertIsInstance(audio_input.generate_sample(), dict)
|
||||
audio_input = gr.inputs.Audio(label="Upload Your Audio")
|
||||
self.assertEqual(audio_input.get_template_context(), {
|
||||
'source': 'upload',
|
||||
'optional': False,
|
||||
'name': 'audio',
|
||||
'label': 'Upload Your Audio'
|
||||
})
|
||||
self.assertEqual(
|
||||
audio_input.get_template_context(),
|
||||
{
|
||||
"source": "upload",
|
||||
"optional": False,
|
||||
"name": "audio",
|
||||
"label": "Upload Your Audio",
|
||||
},
|
||||
)
|
||||
self.assertIsNone(audio_input.preprocess(None))
|
||||
x_wav["is_example"] = True
|
||||
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
|
||||
@ -394,7 +565,6 @@ class TestAudio(unittest.TestCase):
|
||||
x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav")
|
||||
self.assertIsInstance(audio_input.serialize(x_wav, False), dict)
|
||||
|
||||
|
||||
# def test_in_interface(self):
|
||||
# x_wav = gr.test_data.BASE64_AUDIO
|
||||
# def max_amplitude_from_wav_file(wav_file):
|
||||
@ -405,7 +575,7 @@ class TestAudio(unittest.TestCase):
|
||||
# max_amplitude_from_wav_file,
|
||||
# gr.inputs.Audio(type="file"),
|
||||
# "number", interpretation="default")
|
||||
# # TODO(aliabd): investigate why this sometimes fails (returns 5239 or 576)
|
||||
# # TODO(aliabd): investigate why this sometimes fails (returns 5239 or 576)
|
||||
# self.assertEqual(iface.process([x_wav])[0], [576])
|
||||
# scores, alternative_outputs = iface.interpret([x_wav])
|
||||
# self.assertEqual(scores, ... )
|
||||
@ -418,8 +588,11 @@ class TestFile(unittest.TestCase):
|
||||
file_input = gr.inputs.File()
|
||||
output = file_input.preprocess(x_file)
|
||||
self.assertIsInstance(output, tempfile._TemporaryFileWrapper)
|
||||
self.assertEqual(file_input.serialize("test/test_files/sample_file.pdf", True), 'test/test_files/sample_file.pdf')
|
||||
|
||||
self.assertEqual(
|
||||
file_input.serialize("test/test_files/sample_file.pdf", True),
|
||||
"test/test_files/sample_file.pdf",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = file_input.save_flagged(tmpdirname, "file_input", [x_file], None)
|
||||
self.assertEqual("file_input/0", to_save)
|
||||
@ -430,12 +603,15 @@ class TestFile(unittest.TestCase):
|
||||
|
||||
self.assertIsInstance(file_input.generate_sample(), dict)
|
||||
file_input = gr.inputs.File(label="Upload Your File")
|
||||
self.assertEqual(file_input.get_template_context(), {
|
||||
'file_count': 'single',
|
||||
'optional': False,
|
||||
'name': 'file',
|
||||
'label': 'Upload Your File'
|
||||
})
|
||||
self.assertEqual(
|
||||
file_input.get_template_context(),
|
||||
{
|
||||
"file_count": "single",
|
||||
"optional": False,
|
||||
"name": "file",
|
||||
"label": "Upload Your File",
|
||||
},
|
||||
)
|
||||
self.assertIsNone(file_input.preprocess(None))
|
||||
x_file["is_example"] = True
|
||||
self.assertIsNotNone(file_input.preprocess(x_file))
|
||||
@ -445,39 +621,46 @@ class TestFile(unittest.TestCase):
|
||||
|
||||
def get_size_of_file(file_obj):
|
||||
return os.path.getsize(file_obj.name)
|
||||
iface = gr.Interface(
|
||||
get_size_of_file, "file", "number")
|
||||
|
||||
iface = gr.Interface(get_size_of_file, "file", "number")
|
||||
self.assertEqual(iface.process([[x_file]])[0], [10558])
|
||||
|
||||
|
||||
class TestDataframe(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
x_data = [["Tim", 12, False], ["Jan", 24, True]]
|
||||
dataframe_input = gr.inputs.Dataframe(headers=["Name","Age","Member"])
|
||||
dataframe_input = gr.inputs.Dataframe(headers=["Name", "Age", "Member"])
|
||||
output = dataframe_input.preprocess(x_data)
|
||||
self.assertEqual(output["Age"][1], 24)
|
||||
self.assertEqual(output["Member"][0], False)
|
||||
self.assertEqual(dataframe_input.preprocess_example(x_data), x_data)
|
||||
self.assertEqual(dataframe_input.serialize(x_data, True), x_data)
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = dataframe_input.save_flagged(tmpdirname, "dataframe_input", x_data, None)
|
||||
to_save = dataframe_input.save_flagged(
|
||||
tmpdirname, "dataframe_input", x_data, None
|
||||
)
|
||||
self.assertEqual(json.dumps(x_data), to_save)
|
||||
restored = dataframe_input.restore_flagged(tmpdirname, to_save, None)
|
||||
self.assertEqual(x_data, restored)
|
||||
|
||||
self.assertIsInstance(dataframe_input.generate_sample(), list)
|
||||
dataframe_input = gr.inputs.Dataframe(headers=["Name", "Age", "Member"], label="Dataframe Input")
|
||||
self.assertEqual(dataframe_input.get_template_context(), {
|
||||
'headers': ['Name', 'Age', 'Member'],
|
||||
'datatype': 'str',
|
||||
'row_count': 3,
|
||||
'col_count': 3,
|
||||
'col_width': None,
|
||||
'default': [[None, None, None], [None, None, None], [None, None, None]],
|
||||
'name': 'dataframe',
|
||||
'label': 'Dataframe Input'
|
||||
})
|
||||
dataframe_input = gr.inputs.Dataframe(
|
||||
headers=["Name", "Age", "Member"], label="Dataframe Input"
|
||||
)
|
||||
self.assertEqual(
|
||||
dataframe_input.get_template_context(),
|
||||
{
|
||||
"headers": ["Name", "Age", "Member"],
|
||||
"datatype": "str",
|
||||
"row_count": 3,
|
||||
"col_count": 3,
|
||||
"col_width": None,
|
||||
"default": [[None, None, None], [None, None, None], [None, None, None]],
|
||||
"name": "dataframe",
|
||||
"label": "Dataframe Input",
|
||||
},
|
||||
)
|
||||
dataframe_input = gr.inputs.Dataframe()
|
||||
output = dataframe_input.preprocess(x_data)
|
||||
self.assertEqual(output[1][1], 24)
|
||||
@ -493,6 +676,7 @@ class TestDataframe(unittest.TestCase):
|
||||
|
||||
def get_last(l):
|
||||
return l[-1]
|
||||
|
||||
iface = gr.Interface(get_last, "list", "text")
|
||||
self.assertEqual(iface.process([x_data])[0], ["Sal"])
|
||||
|
||||
@ -503,7 +687,7 @@ class TestVideo(unittest.TestCase):
|
||||
video_input = gr.inputs.Video()
|
||||
output = video_input.preprocess(x_video)
|
||||
self.assertIsInstance(output, str)
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = video_input.save_flagged(tmpdirname, "video_input", x_video, None)
|
||||
self.assertEqual("video_input/0.mp4", to_save)
|
||||
@ -511,15 +695,18 @@ class TestVideo(unittest.TestCase):
|
||||
self.assertEqual("video_input/1.mp4", to_save)
|
||||
restored = video_input.restore_flagged(tmpdirname, to_save, None)
|
||||
self.assertEqual(restored, "video_input/1.mp4")
|
||||
|
||||
|
||||
self.assertIsInstance(video_input.generate_sample(), dict)
|
||||
video_input = gr.inputs.Video(label="Upload Your Video")
|
||||
self.assertEqual(video_input.get_template_context(), {
|
||||
'source': 'upload',
|
||||
'optional': False,
|
||||
'name': 'video',
|
||||
'label': 'Upload Your Video'
|
||||
})
|
||||
self.assertEqual(
|
||||
video_input.get_template_context(),
|
||||
{
|
||||
"source": "upload",
|
||||
"optional": False,
|
||||
"name": "video",
|
||||
"label": "Upload Your Video",
|
||||
},
|
||||
)
|
||||
self.assertIsNone(video_input.preprocess(None))
|
||||
x_video["is_example"] = True
|
||||
self.assertIsNotNone(video_input.preprocess(x_video))
|
||||
@ -528,72 +715,73 @@ class TestVideo(unittest.TestCase):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
video_input.serialize(x_video, True)
|
||||
|
||||
|
||||
def test_in_interface(self):
|
||||
x_video = gr.test_data.BASE64_VIDEO
|
||||
iface = gr.Interface(
|
||||
lambda x:x,
|
||||
"video",
|
||||
"playable_video")
|
||||
iface = gr.Interface(lambda x: x, "video", "playable_video")
|
||||
self.assertEqual(iface.process([x_video])[0][0]["data"], x_video["data"])
|
||||
|
||||
|
||||
class TestTimeseries(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
timeseries_input = gr.inputs.Timeseries(
|
||||
x="time",
|
||||
y=["retail", "food", "other"]
|
||||
)
|
||||
x_timeseries = {"data": [[1] + [2] * len(timeseries_input.y)] * 4, "headers": [timeseries_input.x] +
|
||||
timeseries_input.y}
|
||||
timeseries_input = gr.inputs.Timeseries(x="time", y=["retail", "food", "other"])
|
||||
x_timeseries = {
|
||||
"data": [[1] + [2] * len(timeseries_input.y)] * 4,
|
||||
"headers": [timeseries_input.x] + timeseries_input.y,
|
||||
}
|
||||
output = timeseries_input.preprocess(x_timeseries)
|
||||
self.assertIsInstance(output, pandas.core.frame.DataFrame)
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = timeseries_input.save_flagged(tmpdirname, "video_input", x_timeseries, None)
|
||||
to_save = timeseries_input.save_flagged(
|
||||
tmpdirname, "video_input", x_timeseries, None
|
||||
)
|
||||
self.assertEqual(json.dumps(x_timeseries), to_save)
|
||||
restored = timeseries_input.restore_flagged(tmpdirname, to_save, None)
|
||||
self.assertEqual(x_timeseries, restored)
|
||||
|
||||
self.assertIsInstance(timeseries_input.generate_sample(), dict)
|
||||
timeseries_input = gr.inputs.Timeseries(
|
||||
x="time",
|
||||
y="retail", label="Upload Your Timeseries"
|
||||
x="time", y="retail", label="Upload Your Timeseries"
|
||||
)
|
||||
self.assertEqual(
|
||||
timeseries_input.get_template_context(),
|
||||
{
|
||||
"x": "time",
|
||||
"y": ["retail"],
|
||||
"optional": False,
|
||||
"name": "timeseries",
|
||||
"label": "Upload Your Timeseries",
|
||||
},
|
||||
)
|
||||
self.assertEqual(timeseries_input.get_template_context(), {
|
||||
'x': 'time',
|
||||
'y': ['retail'],
|
||||
'optional': False,
|
||||
'name': 'timeseries',
|
||||
'label': 'Upload Your Timeseries'
|
||||
})
|
||||
self.assertIsNone(timeseries_input.preprocess(None))
|
||||
x_timeseries["range"] = (0, 1)
|
||||
self.assertIsNotNone(timeseries_input.preprocess(x_timeseries))
|
||||
|
||||
def test_in_interface(self):
|
||||
timeseries_input = gr.inputs.Timeseries(
|
||||
x="time",
|
||||
y=["retail", "food", "other"]
|
||||
timeseries_input = gr.inputs.Timeseries(x="time", y=["retail", "food", "other"])
|
||||
x_timeseries = {
|
||||
"data": [[1] + [2] * len(timeseries_input.y)] * 4,
|
||||
"headers": [timeseries_input.x] + timeseries_input.y,
|
||||
}
|
||||
iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
|
||||
self.assertEqual(
|
||||
iface.process([x_timeseries])[0],
|
||||
[
|
||||
{
|
||||
"headers": ["time", "retail", "food", "other"],
|
||||
"data": [[1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2]],
|
||||
}
|
||||
],
|
||||
)
|
||||
x_timeseries = {"data": [[1] + [2] * len(timeseries_input.y)] * 4, "headers": [timeseries_input.x] +
|
||||
timeseries_input.y}
|
||||
iface = gr.Interface(
|
||||
lambda x: x,
|
||||
timeseries_input,
|
||||
"dataframe")
|
||||
self.assertEqual(iface.process([x_timeseries])[0], [{'headers': ['time', 'retail', 'food', 'other'],
|
||||
'data': [[1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2],
|
||||
[1, 2, 2, 2]]}])
|
||||
|
||||
|
||||
class TestNames(unittest.TestCase):
|
||||
# this ensures that `inputs.get_input_instance()` works correctly when instantiating from components
|
||||
def test_no_duplicate_uncased_names(self):
|
||||
def test_no_duplicate_uncased_names(self):
|
||||
subclasses = gr.inputs.InputComponent.__subclasses__()
|
||||
unique_subclasses_uncased = set([s.__name__.lower() for s in subclasses])
|
||||
self.assertEqual(len(subclasses), len(unique_subclasses_uncased))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,19 +1,22 @@
|
||||
import io
|
||||
import socket
|
||||
import sys
|
||||
import tempfile
|
||||
from gradio.interface import *
|
||||
import threading
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import requests
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
import io
|
||||
import threading
|
||||
from comet_ml import Experiment
|
||||
|
||||
import mlflow
|
||||
import requests
|
||||
import wandb
|
||||
import socket
|
||||
from comet_ml import Experiment
|
||||
|
||||
from gradio.interface import *
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def captured_output():
|
||||
new_out, new_err = io.StringIO(), io.StringIO()
|
||||
@ -24,6 +27,7 @@ def captured_output():
|
||||
finally:
|
||||
sys.stdout, sys.stderr = old_out, old_err
|
||||
|
||||
|
||||
class TestInterface(unittest.TestCase):
|
||||
def test_close(self):
|
||||
io = Interface(lambda input: None, "textbox", "label")
|
||||
@ -33,32 +37,34 @@ class TestInterface(unittest.TestCase):
|
||||
io.close()
|
||||
with self.assertRaises(Exception):
|
||||
response = requests.get(local_url)
|
||||
|
||||
|
||||
def test_close_all(self):
|
||||
interface = Interface(lambda input: None, "textbox", "label")
|
||||
interface.close = mock.MagicMock()
|
||||
close_all()
|
||||
interface.close.assert_called()
|
||||
|
||||
|
||||
def test_examples_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
Interface(lambda x: x, examples=1234)
|
||||
|
||||
|
||||
def test_examples_valid_path(self):
|
||||
path = os.path.join(os.path.dirname(__file__), 'test_data/flagged_with_log')
|
||||
path = os.path.join(os.path.dirname(__file__), "test_data/flagged_with_log")
|
||||
interface = Interface(lambda x: 3 * x, "number", "number", examples=path)
|
||||
self.assertEqual(len(interface.get_config_file()['examples']), 2)
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), 'test_data/flagged_no_log')
|
||||
self.assertEqual(len(interface.get_config_file()["examples"]), 2)
|
||||
|
||||
path = os.path.join(os.path.dirname(__file__), "test_data/flagged_no_log")
|
||||
interface = Interface(lambda x: 3 * x, "number", "number", examples=path)
|
||||
self.assertEqual(len(interface.get_config_file()['examples']), 3)
|
||||
|
||||
self.assertEqual(len(interface.get_config_file()["examples"]), 3)
|
||||
|
||||
def test_examples_not_valid_path(self):
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
interface = Interface(lambda x: x, "textbox", "label", examples='invalid-path')
|
||||
interface = Interface(
|
||||
lambda x: x, "textbox", "label", examples="invalid-path"
|
||||
)
|
||||
interface.launch(prevent_thread_lock=True)
|
||||
interface.close()
|
||||
|
||||
|
||||
def test_test_launch(self):
|
||||
with captured_output() as (out, err):
|
||||
prediction_fn = lambda x: x
|
||||
@ -66,8 +72,8 @@ class TestInterface(unittest.TestCase):
|
||||
interface = Interface(prediction_fn, "textbox", "label")
|
||||
interface.test_launch()
|
||||
output = out.getvalue().strip()
|
||||
self.assertEqual(output, 'Test launch: prediction_fn()... PASSED')
|
||||
|
||||
self.assertEqual(output, "Test launch: prediction_fn()... PASSED")
|
||||
|
||||
@mock.patch("time.sleep")
|
||||
def test_block_thread(self, mock_sleep):
|
||||
with self.assertRaises(KeyboardInterrupt):
|
||||
@ -76,19 +82,20 @@ class TestInterface(unittest.TestCase):
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
interface.launch(prevent_thread_lock=False)
|
||||
output = out.getvalue().strip()
|
||||
self.assertEqual(output, 'Keyboard interruption in main thread... closing server.')
|
||||
self.assertEqual(
|
||||
output, "Keyboard interruption in main thread... closing server."
|
||||
)
|
||||
|
||||
@mock.patch('gradio.utils.colab_check')
|
||||
@mock.patch("gradio.utils.colab_check")
|
||||
def test_launch_colab_share(self, mock_colab_check):
|
||||
mock_colab_check.return_value = True
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
_, _, share_url = interface.launch(prevent_thread_lock=True)
|
||||
self.assertIsNotNone(share_url)
|
||||
interface.close()
|
||||
|
||||
|
||||
@mock.patch('gradio.utils.colab_check')
|
||||
@mock.patch('gradio.networking.setup_tunnel')
|
||||
|
||||
@mock.patch("gradio.utils.colab_check")
|
||||
@mock.patch("gradio.networking.setup_tunnel")
|
||||
def test_launch_colab_share_error(self, mock_setup_tunnel, mock_colab_check):
|
||||
mock_setup_tunnel.side_effect = RuntimeError()
|
||||
mock_colab_check.return_value = True
|
||||
@ -96,40 +103,43 @@ class TestInterface(unittest.TestCase):
|
||||
_, _, share_url = interface.launch(prevent_thread_lock=True)
|
||||
self.assertIsNone(share_url)
|
||||
interface.close()
|
||||
|
||||
|
||||
def test_interface_representation(self):
|
||||
prediction_fn = lambda x: x
|
||||
prediction_fn.__name__ = "prediction_fn"
|
||||
repr = str(Interface(prediction_fn, "textbox", "label")).split('\n')
|
||||
repr = str(Interface(prediction_fn, "textbox", "label")).split("\n")
|
||||
self.assertTrue(prediction_fn.__name__ in repr[0])
|
||||
self.assertEqual(len(repr[0]), len(repr[1]))
|
||||
|
||||
|
||||
def test_interface_load(self):
|
||||
io = Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
||||
io = Interface.load(
|
||||
"models/distilbert-base-uncased-finetuned-sst-2-english",
|
||||
alias="sentiment_classifier",
|
||||
)
|
||||
output = io("I am happy, I love you.")
|
||||
self.assertGreater(output['POSITIVE'], 0.5)
|
||||
|
||||
self.assertGreater(output["POSITIVE"], 0.5)
|
||||
|
||||
def test_interface_none_interp(self):
|
||||
interface = Interface(lambda x: x, "textbox", "label", interpretation=[None])
|
||||
scores, alternative_outputs = interface.interpret(["quickest brown fox"])
|
||||
self.assertIsNone(scores[0])
|
||||
|
||||
@mock.patch('webbrowser.open')
|
||||
|
||||
@mock.patch("webbrowser.open")
|
||||
def test_interface_browser(self, mock_browser):
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
interface.launch(inbrowser=True, prevent_thread_lock=True)
|
||||
mock_browser.assert_called_once()
|
||||
interface.close()
|
||||
|
||||
|
||||
def test_examples_list(self):
|
||||
examples = ['test1', 'test2']
|
||||
examples = ["test1", "test2"]
|
||||
interface = Interface(lambda x: x, "textbox", "label", examples=examples)
|
||||
interface.launch(prevent_thread_lock=True)
|
||||
self.assertEqual(len(interface.examples), 2)
|
||||
self.assertEqual(len(interface.examples[0]), 1)
|
||||
interface.close()
|
||||
|
||||
@mock.patch('IPython.display.display')
|
||||
@mock.patch("IPython.display.display")
|
||||
def test_inline_display(self, mock_display):
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
interface.launch(inline=True, prevent_thread_lock=True)
|
||||
@ -137,8 +147,8 @@ class TestInterface(unittest.TestCase):
|
||||
interface.launch(inline=True, share=True, prevent_thread_lock=True)
|
||||
self.assertEqual(mock_display.call_count, 2)
|
||||
interface.close()
|
||||
|
||||
@mock.patch('comet_ml.Experiment')
|
||||
|
||||
@mock.patch("comet_ml.Experiment")
|
||||
def test_integration_comet(self, mock_experiment):
|
||||
experiment = mock_experiment()
|
||||
experiment.log_text = mock.MagicMock()
|
||||
@ -146,44 +156,52 @@ class TestInterface(unittest.TestCase):
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
interface.launch(prevent_thread_lock=True)
|
||||
interface.integrate(comet_ml=experiment)
|
||||
experiment.log_text.assert_called_with('gradio: ' + interface.local_url)
|
||||
interface.share_url = 'tmp' # used to avoid creating real share links.
|
||||
experiment.log_text.assert_called_with("gradio: " + interface.local_url)
|
||||
interface.share_url = "tmp" # used to avoid creating real share links.
|
||||
interface.integrate(comet_ml=experiment)
|
||||
experiment.log_text.assert_called_with('gradio: ' + interface.share_url)
|
||||
experiment.log_text.assert_called_with("gradio: " + interface.share_url)
|
||||
self.assertEqual(experiment.log_other.call_count, 2)
|
||||
interface.share_url = None
|
||||
interface.close()
|
||||
|
||||
|
||||
def test_integration_mlflow(self):
|
||||
mlflow.log_param = mock.MagicMock()
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
interface.launch(prevent_thread_lock=True)
|
||||
interface.integrate(mlflow=mlflow)
|
||||
mlflow.log_param.assert_called_with("Gradio Interface Local Link", interface.local_url)
|
||||
interface.share_url = 'tmp' # used to avoid creating real share links.
|
||||
mlflow.log_param.assert_called_with(
|
||||
"Gradio Interface Local Link", interface.local_url
|
||||
)
|
||||
interface.share_url = "tmp" # used to avoid creating real share links.
|
||||
interface.integrate(mlflow=mlflow)
|
||||
mlflow.log_param.assert_called_with("Gradio Interface Share Link", interface.share_url)
|
||||
mlflow.log_param.assert_called_with(
|
||||
"Gradio Interface Share Link", interface.share_url
|
||||
)
|
||||
interface.share_url = None
|
||||
interface.close()
|
||||
|
||||
|
||||
def test_integration_wandb(self):
|
||||
with captured_output() as (out, err):
|
||||
wandb.log = mock.MagicMock()
|
||||
wandb.Html = mock.MagicMock()
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
interface.integrate(wandb=wandb)
|
||||
self.assertEqual(out.getvalue().strip(), "The WandB integration requires you to `launch(share=True)` first.")
|
||||
interface.share_url = 'tmp'
|
||||
self.assertEqual(
|
||||
out.getvalue().strip(),
|
||||
"The WandB integration requires you to `launch(share=True)` first.",
|
||||
)
|
||||
interface.share_url = "tmp"
|
||||
interface.integrate(wandb=wandb)
|
||||
wandb.log.assert_called_once()
|
||||
|
||||
@mock.patch('requests.post')
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_integration_analytics(self, mock_post):
|
||||
mlflow.log_param = mock.MagicMock()
|
||||
interface = Interface(lambda x: x, "textbox", "label")
|
||||
interface.analytics_enabled = True
|
||||
interface.integrate(mlflow=mlflow)
|
||||
mock_post.assert_called_once()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,45 +1,71 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gradio.interpretation
|
||||
import gradio.test_data
|
||||
from gradio.processing_utils import decode_base64_to_image, encode_array_to_base64
|
||||
from gradio import Interface
|
||||
import numpy as np
|
||||
import os
|
||||
from gradio.processing_utils import (decode_base64_to_image,
|
||||
encode_array_to_base64)
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestDefault(unittest.TestCase):
|
||||
def test_default_text(self):
|
||||
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
|
||||
text_interface = Interface(max_word_len, "textbox", "label", interpretation="default")
|
||||
text_interface = Interface(
|
||||
max_word_len, "textbox", "label", interpretation="default"
|
||||
)
|
||||
interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
|
||||
self.assertGreater(interpretation[0][1], 0) # Checks to see if the first word has >0 score.
|
||||
self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score.
|
||||
self.assertGreater(
|
||||
interpretation[0][1], 0
|
||||
) # Checks to see if the first word has >0 score.
|
||||
self.assertEqual(
|
||||
interpretation[-1][1], 0
|
||||
) # Checks to see if the last word has 0 score.
|
||||
|
||||
|
||||
class TestShapley(unittest.TestCase):
|
||||
def test_shapley_text(self):
|
||||
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
|
||||
text_interface = Interface(max_word_len, "textbox", "label", interpretation="shapley")
|
||||
text_interface = Interface(
|
||||
max_word_len, "textbox", "label", interpretation="shapley"
|
||||
)
|
||||
interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
|
||||
self.assertGreater(interpretation[0][1], 0) # Checks to see if the first word has >0 score.
|
||||
self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score.
|
||||
self.assertGreater(
|
||||
interpretation[0][1], 0
|
||||
) # Checks to see if the first word has >0 score.
|
||||
self.assertEqual(
|
||||
interpretation[-1][1], 0
|
||||
) # Checks to see if the last word has 0 score.
|
||||
|
||||
|
||||
class TestCustom(unittest.TestCase):
|
||||
def test_custom_text(self):
|
||||
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
|
||||
custom = lambda text: [(char, 1) for char in text]
|
||||
text_interface = Interface(max_word_len, "textbox", "label", interpretation=custom)
|
||||
text_interface = Interface(
|
||||
max_word_len, "textbox", "label", interpretation=custom
|
||||
)
|
||||
result = text_interface.interpret(["quickest brown fox"])[0][0]
|
||||
self.assertEqual(result[0][1], 1) # Checks to see if the first letter has score of 1.
|
||||
self.assertEqual(
|
||||
result[0][1], 1
|
||||
) # Checks to see if the first letter has score of 1.
|
||||
|
||||
def test_custom_img(self):
|
||||
max_pixel_value = lambda img: img.max()
|
||||
custom = lambda img: img.tolist()
|
||||
img_interface = Interface(max_pixel_value, "image", "label", interpretation=custom)
|
||||
img_interface = Interface(
|
||||
max_pixel_value, "image", "label", interpretation=custom
|
||||
)
|
||||
result = img_interface.interpret([gradio.test_data.BASE64_IMAGE])[0][0]
|
||||
expected_result = np.asarray(decode_base64_to_image(gradio.test_data.BASE64_IMAGE).convert('RGB')).tolist()
|
||||
expected_result = np.asarray(
|
||||
decode_base64_to_image(gradio.test_data.BASE64_IMAGE).convert("RGB")
|
||||
).tolist()
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
|
||||
class TestHelperMethods(unittest.TestCase):
|
||||
def test_diff(self):
|
||||
@ -52,9 +78,13 @@ class TestHelperMethods(unittest.TestCase):
|
||||
|
||||
def test_quantify_difference_with_textbox(self):
|
||||
iface = Interface(lambda text: text, ["textbox"], ["textbox"])
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test"])
|
||||
diff = gradio.interpretation.quantify_difference_in_label(
|
||||
iface, ["test"], ["test"]
|
||||
)
|
||||
self.assertEquals(diff, 0)
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test_diff"])
|
||||
diff = gradio.interpretation.quantify_difference_in_label(
|
||||
iface, ["test"], ["test_diff"]
|
||||
)
|
||||
self.assertEquals(diff, 1)
|
||||
|
||||
def test_quantify_difference_with_label(self):
|
||||
@ -66,48 +96,43 @@ class TestHelperMethods(unittest.TestCase):
|
||||
|
||||
def test_quantify_difference_with_confidences(self):
|
||||
iface = Interface(lambda text: len(text), ["textbox"], ["label"])
|
||||
output_1 = {
|
||||
"cat": 0.9,
|
||||
"dog": 0.1
|
||||
}
|
||||
output_2 = {
|
||||
"cat": 0.6,
|
||||
"dog": 0.4
|
||||
}
|
||||
output_3 = {
|
||||
"cat": 0.1,
|
||||
"dog": 0.6
|
||||
}
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_2])
|
||||
output_1 = {"cat": 0.9, "dog": 0.1}
|
||||
output_2 = {"cat": 0.6, "dog": 0.4}
|
||||
output_3 = {"cat": 0.1, "dog": 0.6}
|
||||
diff = gradio.interpretation.quantify_difference_in_label(
|
||||
iface, [output_1], [output_2]
|
||||
)
|
||||
self.assertAlmostEquals(diff, 0.3)
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_3])
|
||||
diff = gradio.interpretation.quantify_difference_in_label(
|
||||
iface, [output_1], [output_3]
|
||||
)
|
||||
self.assertAlmostEquals(diff, 0.8)
|
||||
|
||||
def test_get_regression_value(self):
|
||||
iface = Interface(lambda text: text, ["textbox"], ["label"])
|
||||
output_1 = {
|
||||
"cat": 0.9,
|
||||
"dog": 0.1
|
||||
}
|
||||
output_2 = {
|
||||
"cat": float("nan"),
|
||||
"dog": 0.4
|
||||
}
|
||||
output_3 = {
|
||||
"cat": 0.1,
|
||||
"dog": 0.6
|
||||
}
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_2])
|
||||
output_1 = {"cat": 0.9, "dog": 0.1}
|
||||
output_2 = {"cat": float("nan"), "dog": 0.4}
|
||||
output_3 = {"cat": 0.1, "dog": 0.6}
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(
|
||||
iface, [output_1], [output_2]
|
||||
)
|
||||
self.assertEquals(diff, 0)
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_3])
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(
|
||||
iface, [output_1], [output_3]
|
||||
)
|
||||
self.assertAlmostEquals(diff, 0.1)
|
||||
|
||||
def test_get_classification_value(self):
|
||||
iface = Interface(lambda text: text, ["textbox"], ["label"])
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(iface, ["cat"], ["test"])
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(
|
||||
iface, ["cat"], ["test"]
|
||||
)
|
||||
self.assertEquals(diff, 1)
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(iface, ["test"], ["test"])
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(
|
||||
iface, ["test"], ["test"]
|
||||
)
|
||||
self.assertEquals(diff, 0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import gradio as gr
|
||||
from gradio import mix
|
||||
import os
|
||||
|
||||
|
||||
"""
|
||||
WARNING: Some of these tests have an external dependency: namely that Hugging Face's Hub and Space APIs do not change, and they keep their most famous models up. So if, e.g. Spaces is down, then these test will not pass.
|
||||
@ -14,8 +14,7 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
class TestSeries(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
io1 = gr.Interface(lambda x: x + " World", "textbox",
|
||||
gr.outputs.Textbox())
|
||||
io1 = gr.Interface(lambda x: x + " World", "textbox", gr.outputs.Textbox())
|
||||
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.outputs.Textbox())
|
||||
series = mix.Series(io1, io2)
|
||||
self.assertEqual(series.process(["Hello"])[0], ["Hello World!"])
|
||||
@ -25,18 +24,18 @@ class TestSeries(unittest.TestCase):
|
||||
io2 = gr.Interface.load("spaces/abidlabs/image-classifier")
|
||||
series = mix.Series(io1, io2)
|
||||
output = series("test/test_data/lion.jpg")
|
||||
self.assertGreater(output['lion'], 0.5)
|
||||
self.assertGreater(output["lion"], 0.5)
|
||||
|
||||
|
||||
class TestParallel(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
io1 = gr.Interface(lambda x: x + " World 1!", "textbox",
|
||||
gr.outputs.Textbox())
|
||||
io2 = gr.Interface(lambda x: x + " World 2!", "textbox",
|
||||
gr.outputs.Textbox())
|
||||
io1 = gr.Interface(lambda x: x + " World 1!", "textbox", gr.outputs.Textbox())
|
||||
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.outputs.Textbox())
|
||||
parallel = mix.Parallel(io1, io2)
|
||||
self.assertEqual(parallel.process(["Hello"])[0], ["Hello World 1!",
|
||||
"Hello World 2!"])
|
||||
self.assertEqual(
|
||||
parallel.process(["Hello"])[0], ["Hello World 1!", "Hello World 2!"]
|
||||
)
|
||||
|
||||
def test_with_external(self):
|
||||
io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish")
|
||||
io2 = gr.Interface.load("spaces/abidlabs/english2german")
|
||||
@ -46,5 +45,5 @@ class TestParallel(unittest.TestCase):
|
||||
self.assertIn("hallo", hello_de.lower())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,15 +1,15 @@
|
||||
"""Contains tests for networking.py and app.py"""
|
||||
|
||||
import aiohttp
|
||||
from fastapi.testclient import TestClient
|
||||
import os
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import urllib.request
|
||||
import warnings
|
||||
|
||||
from gradio import flagging, Interface, networking, reset_all, utils
|
||||
import aiohttp
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from gradio import Interface, flagging, networking, reset_all, utils
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -37,46 +37,46 @@ class TestPort(unittest.TestCase):
|
||||
|
||||
class TestRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_get_main_route(self):
|
||||
response = self.client.get('/')
|
||||
response = self.client.get("/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_get_api_route(self):
|
||||
response = self.client.get('/api/')
|
||||
response = self.client.get("/api/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_static_files_served_safely(self):
|
||||
# Make sure things outside the static folder are not accessible
|
||||
response = self.client.get(r'/static/..%2findex.html')
|
||||
response = self.client.get(r"/static/..%2findex.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
response = self.client.get(r'/static/..%2f..%2fapi_docs.html')
|
||||
response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_get_config_route(self):
|
||||
response = self.client.get('/config/')
|
||||
response = self.client.get("/config/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_predict_route(self):
|
||||
response = self.client.post('/api/predict/', json={"data": ["test"]})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
response = self.client.post("/api/predict/", json={"data": ["test"]})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["test"])
|
||||
self.assertTrue("durations" in output)
|
||||
self.assertTrue("avg_durations" in output)
|
||||
self.assertEqual(output["data"], ["test"])
|
||||
self.assertTrue("durations" in output)
|
||||
self.assertTrue("avg_durations" in output)
|
||||
|
||||
# def test_queue_push_route(self):
|
||||
# networking.queue.push = mock.MagicMock(return_value=(None, None))
|
||||
# response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"})
|
||||
# self.assertEqual(response.status_code, 200)
|
||||
# self.assertEqual(response.status_code, 200)
|
||||
|
||||
# def test_queue_push_route(self):
|
||||
# networking.queue.get_status = mock.MagicMock(return_value=(None, None))
|
||||
# response = self.client.post('/api/queue/status/', json={"hash": "test"})
|
||||
# self.assertEqual(response.status_code, 200)
|
||||
# self.assertEqual(response.status_code, 200)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
@ -85,15 +85,21 @@ class TestRoutes(unittest.TestCase):
|
||||
|
||||
class TestAuthenticatedRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(auth=("test", "correct_password"), prevent_thread_lock=True)
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(
|
||||
auth=("test", "correct_password"), prevent_thread_lock=True
|
||||
)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_post_login(self):
|
||||
response = self.client.post('/login', data=dict(username="test", password="correct_password"))
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="correct_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
response = self.client.post('/login', data=dict(username="test", password="incorrect_password"))
|
||||
self.assertEqual(response.status_code, 400)
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="incorrect_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
@ -102,10 +108,10 @@ class TestAuthenticatedRoutes(unittest.TestCase):
|
||||
|
||||
class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
def test_show_error(self):
|
||||
io = Interface(lambda x: 1/x, "number", "number")
|
||||
io = Interface(lambda x: 1 / x, "number", "number")
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
response = client.post('/api/predict/', json={"data": [0]})
|
||||
response = client.post("/api/predict/", json={"data": [0]})
|
||||
self.assertEqual(response.status_code, 500)
|
||||
self.assertTrue("error" in response.json())
|
||||
io.close()
|
||||
@ -119,13 +125,18 @@ class TestFlagging(unittest.TestCase):
|
||||
aiohttp.ClientSession.post.__aenter__ = None
|
||||
aiohttp.ClientSession.post.__aexit__ = None
|
||||
io = Interface(
|
||||
lambda x: x, "text", "text",
|
||||
analytics_enabled=True, flagging_callback=callback)
|
||||
lambda x: x,
|
||||
"text",
|
||||
"text",
|
||||
analytics_enabled=True,
|
||||
flagging_callback=callback,
|
||||
)
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
response = client.post(
|
||||
'/api/flag/',
|
||||
json={"data": {"input_data": ["test"], "output_data": ["test"]}})
|
||||
"/api/flag/",
|
||||
json={"data": {"input_data": ["test"], "output_data": ["test"]}},
|
||||
)
|
||||
aiohttp.ClientSession.post.assert_called()
|
||||
callback.flag.assert_called_once()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
@ -135,16 +146,19 @@ class TestFlagging(unittest.TestCase):
|
||||
class TestInterpretation(unittest.TestCase):
|
||||
def test_interpretation(self):
|
||||
io = Interface(
|
||||
lambda x: len(x), "text", "label",
|
||||
interpretation="default", analytics_enabled=True)
|
||||
lambda x: len(x),
|
||||
"text",
|
||||
"label",
|
||||
interpretation="default",
|
||||
analytics_enabled=True,
|
||||
)
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
aiohttp.ClientSession.post = mock.MagicMock()
|
||||
aiohttp.ClientSession.post.__aenter__ = None
|
||||
aiohttp.ClientSession.post.__aexit__ = None
|
||||
aiohttp.ClientSession.post.__aexit__ = None
|
||||
io.interpret = mock.MagicMock(return_value=(None, None))
|
||||
response = client.post(
|
||||
'/api/interpret/', json={"data": ["test test"]})
|
||||
response = client.post("/api/interpret/", json={"data": ["test test"]})
|
||||
aiohttp.ClientSession.post.assert_called()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
io.close()
|
||||
@ -187,5 +201,5 @@ class TestURLs(unittest.TestCase):
|
||||
# networking.queue.fail_job.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,11 +1,12 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -26,13 +27,15 @@ class TestTextbox(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox())
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["o"])
|
||||
iface = gr.Interface(lambda x: x / 2, "number", gr.outputs.Textbox(type="number"))
|
||||
iface = gr.Interface(
|
||||
lambda x: x / 2, "number", gr.outputs.Textbox(type="number")
|
||||
)
|
||||
self.assertEqual(iface.process([10])[0], [5])
|
||||
|
||||
|
||||
class TestLabel(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
y = 'happy'
|
||||
y = "happy"
|
||||
label_output = gr.outputs.Label()
|
||||
label = label_output.postprocess(y)
|
||||
self.assertDictEqual(label, {"label": "happy"})
|
||||
@ -41,45 +44,55 @@ class TestLabel(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
to_save = label_output.save_flagged(tmpdir, "label_output", label, None)
|
||||
self.assertEqual(to_save, y)
|
||||
y = {
|
||||
3: 0.7,
|
||||
1: 0.2,
|
||||
0: 0.1
|
||||
}
|
||||
y = {3: 0.7, 1: 0.2, 0: 0.1}
|
||||
label_output = gr.outputs.Label()
|
||||
label = label_output.postprocess(y)
|
||||
self.assertDictEqual(label, {
|
||||
"label": 3,
|
||||
"confidences": [
|
||||
{"label": 3, "confidence": 0.7},
|
||||
{"label": 1, "confidence": 0.2},
|
||||
{"label": 0, "confidence": 0.1},
|
||||
]
|
||||
})
|
||||
self.assertDictEqual(
|
||||
label,
|
||||
{
|
||||
"label": 3,
|
||||
"confidences": [
|
||||
{"label": 3, "confidence": 0.7},
|
||||
{"label": 1, "confidence": 0.2},
|
||||
{"label": 0, "confidence": 0.1},
|
||||
],
|
||||
},
|
||||
)
|
||||
label_output = gr.outputs.Label(num_top_classes=2)
|
||||
label = label_output.postprocess(y)
|
||||
self.assertDictEqual(label, {
|
||||
"label": 3,
|
||||
"confidences": [
|
||||
{"label": 3, "confidence": 0.7},
|
||||
{"label": 1, "confidence": 0.2},
|
||||
]
|
||||
})
|
||||
self.assertDictEqual(
|
||||
label,
|
||||
{
|
||||
"label": 3,
|
||||
"confidences": [
|
||||
{"label": 3, "confidence": 0.7},
|
||||
{"label": 1, "confidence": 0.2},
|
||||
],
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
label_output.postprocess([1, 2, 3])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
to_save = label_output.save_flagged(tmpdir, "label_output", label, None)
|
||||
self.assertEqual(to_save, '{"3": 0.7, "1": 0.2}')
|
||||
self.assertEqual(label_output.restore_flagged(tmpdir, to_save, None),
|
||||
{'label': '3', 'confidences': [{"label": "3", "confidence": 0.7}, {"label": "1", "confidence": 0.2}]})
|
||||
self.assertEqual(
|
||||
label_output.restore_flagged(tmpdir, to_save, None),
|
||||
{
|
||||
"label": "3",
|
||||
"confidences": [
|
||||
{"label": "3", "confidence": 0.7},
|
||||
{"label": "1", "confidence": 0.2},
|
||||
],
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
label_output = gr.outputs.Label(type="unknown")
|
||||
label_output.deserialize([1, 2, 3])
|
||||
|
||||
def test_in_interface(self):
|
||||
x_img = gr.test_data.BASE64_IMAGE
|
||||
|
||||
|
||||
def rgb_distribution(img):
|
||||
rgb_dist = np.mean(img, axis=(0, 1))
|
||||
rgb_dist /= np.sum(rgb_dist)
|
||||
@ -92,22 +105,33 @@ class TestLabel(unittest.TestCase):
|
||||
|
||||
iface = gr.Interface(rgb_distribution, "image", "label")
|
||||
output = iface.process([x_img])[0][0]
|
||||
self.assertDictEqual(output, {
|
||||
'label': 'red',
|
||||
'confidences': [
|
||||
{'label': 'red', 'confidence': 0.44},
|
||||
{'label': 'green', 'confidence': 0.28},
|
||||
{'label': 'blue', 'confidence': 0.28}
|
||||
]
|
||||
})
|
||||
self.assertDictEqual(
|
||||
output,
|
||||
{
|
||||
"label": "red",
|
||||
"confidences": [
|
||||
{"label": "red", "confidence": 0.44},
|
||||
{"label": "green", "confidence": 0.28},
|
||||
{"label": "blue", "confidence": 0.28},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestImage(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
y_img = gr.processing_utils.decode_base64_to_image(gr.test_data.BASE64_IMAGE)
|
||||
image_output = gr.outputs.Image()
|
||||
self.assertTrue(image_output.postprocess(y_img).startswith(""))
|
||||
self.assertTrue(image_output.postprocess(np.array(y_img)).startswith(""))
|
||||
self.assertTrue(
|
||||
image_output.postprocess(y_img).startswith(
|
||||
""
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
image_output.postprocess(np.array(y_img)).startswith(
|
||||
""
|
||||
)
|
||||
)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
plot_output = gr.outputs.Image(plot=True)
|
||||
|
||||
@ -115,15 +139,23 @@ class TestImage(unittest.TestCase):
|
||||
ypoints = np.array([0, 250])
|
||||
fig = plt.figure()
|
||||
p = plt.plot(xpoints, ypoints)
|
||||
self.assertTrue(plot_output.postprocess(fig).startswith("data:image/png;base64,"))
|
||||
self.assertTrue(
|
||||
plot_output.postprocess(fig).startswith("data:image/png;base64,")
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
image_output.postprocess([1, 2, 3])
|
||||
image_output = gr.outputs.Image(type="numpy")
|
||||
self.assertTrue(image_output.postprocess(y_img).startswith("data:image/png;base64,"))
|
||||
self.assertTrue(
|
||||
image_output.postprocess(y_img).startswith("data:image/png;base64,")
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = image_output.save_flagged(tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None)
|
||||
to_save = image_output.save_flagged(
|
||||
tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None
|
||||
)
|
||||
self.assertEqual("image_output/0.png", to_save)
|
||||
to_save = image_output.save_flagged(tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None)
|
||||
to_save = image_output.save_flagged(
|
||||
tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None
|
||||
)
|
||||
self.assertEqual("image_output/1.png", to_save)
|
||||
|
||||
def test_in_interface(self):
|
||||
@ -131,19 +163,29 @@ class TestImage(unittest.TestCase):
|
||||
return np.random.randint(0, 256, (width, height, 3))
|
||||
|
||||
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
|
||||
self.assertTrue(iface.process([10, 20])[0][0].startswith("data:image/png;base64"))
|
||||
self.assertTrue(
|
||||
iface.process([10, 20])[0][0].startswith("data:image/png;base64")
|
||||
)
|
||||
|
||||
|
||||
class TestVideo(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
y_vid = "test/test_files/video_sample.mp4"
|
||||
video_output = gr.outputs.Video()
|
||||
self.assertTrue(video_output.postprocess(y_vid)["data"].startswith("data:video/mp4;base64,"))
|
||||
self.assertTrue(video_output.deserialize(gr.test_data.BASE64_VIDEO["data"]).endswith(".mp4"))
|
||||
self.assertTrue(
|
||||
video_output.postprocess(y_vid)["data"].startswith("data:video/mp4;base64,")
|
||||
)
|
||||
self.assertTrue(
|
||||
video_output.deserialize(gr.test_data.BASE64_VIDEO["data"]).endswith(".mp4")
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = video_output.save_flagged(tmpdirname, "video_output", gr.test_data.BASE64_VIDEO, None)
|
||||
to_save = video_output.save_flagged(
|
||||
tmpdirname, "video_output", gr.test_data.BASE64_VIDEO, None
|
||||
)
|
||||
self.assertEqual("video_output/0.mp4", to_save)
|
||||
to_save = video_output.save_flagged(tmpdirname, "video_output", gr.test_data.BASE64_VIDEO, None)
|
||||
to_save = video_output.save_flagged(
|
||||
tmpdirname, "video_output", gr.test_data.BASE64_VIDEO, None
|
||||
)
|
||||
self.assertEqual("video_output/1.mp4", to_save)
|
||||
|
||||
|
||||
@ -159,7 +201,10 @@ class TestKeyValues(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = kv_output.save_flagged(tmpdirname, "kv_output", kv_list, None)
|
||||
self.assertEqual(to_save, '[["a", 1], ["b", 2]]')
|
||||
self.assertEqual(kv_output.restore_flagged(tmpdirname, to_save, None), [["a", 1], ["b", 2]])
|
||||
self.assertEqual(
|
||||
kv_output.restore_flagged(tmpdirname, to_save, None),
|
||||
[["a", 1], ["b", 2]],
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
def letter_distribution(word):
|
||||
@ -169,27 +214,31 @@ class TestKeyValues(unittest.TestCase):
|
||||
return dist
|
||||
|
||||
iface = gr.Interface(letter_distribution, "text", "key_values")
|
||||
self.assertListEqual(iface.process(["alpaca"])[0][0], [
|
||||
("a", 3), ("l", 1), ("p", 1), ("c", 1)])
|
||||
self.assertListEqual(
|
||||
iface.process(["alpaca"])[0][0], [("a", 3), ("l", 1), ("p", 1), ("c", 1)]
|
||||
)
|
||||
|
||||
|
||||
class TestHighlightedText(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
ht_output = gr.outputs.HighlightedText(color_map={"pos": "green", "neg": "red"})
|
||||
self.assertEqual(ht_output.get_template_context(), {
|
||||
'color_map': {'pos': 'green', 'neg': 'red'},
|
||||
'name': 'highlightedtext',
|
||||
'label': None,
|
||||
'show_legend': False
|
||||
})
|
||||
ht = {
|
||||
"pos": "Hello ",
|
||||
"neg": "World"
|
||||
}
|
||||
self.assertEqual(
|
||||
ht_output.get_template_context(),
|
||||
{
|
||||
"color_map": {"pos": "green", "neg": "red"},
|
||||
"name": "highlightedtext",
|
||||
"label": None,
|
||||
"show_legend": False,
|
||||
},
|
||||
)
|
||||
ht = {"pos": "Hello ", "neg": "World"}
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = ht_output.save_flagged(tmpdirname, "ht_output", ht, None)
|
||||
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
|
||||
self.assertEqual(ht_output.restore_flagged(tmpdirname, to_save, None), {"pos": "Hello ", "neg": "World"})
|
||||
self.assertEqual(
|
||||
ht_output.restore_flagged(tmpdirname, to_save, None),
|
||||
{"pos": "Hello ", "neg": "World"},
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
def highlight_vowels(sentence):
|
||||
@ -206,29 +255,42 @@ class TestHighlightedText(unittest.TestCase):
|
||||
cur_phrase += letter
|
||||
phrases.append((cur_phrase, mode))
|
||||
return phrases
|
||||
|
||||
|
||||
iface = gr.Interface(highlight_vowels, "text", "highlight")
|
||||
self.assertListEqual(iface.process(["Helloooo"])[0][0], [
|
||||
("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")])
|
||||
self.assertListEqual(
|
||||
iface.process(["Helloooo"])[0][0],
|
||||
[("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")],
|
||||
)
|
||||
|
||||
|
||||
class TestAudio(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
y_audio = gr.processing_utils.decode_base64_to_file(gr.test_data.BASE64_AUDIO["data"])
|
||||
y_audio = gr.processing_utils.decode_base64_to_file(
|
||||
gr.test_data.BASE64_AUDIO["data"]
|
||||
)
|
||||
audio_output = gr.outputs.Audio(type="file")
|
||||
self.assertTrue(audio_output.postprocess(y_audio.name).startswith("data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA"))
|
||||
self.assertEqual(audio_output.get_template_context(), {
|
||||
'name': 'audio',
|
||||
'label': None
|
||||
})
|
||||
self.assertTrue(
|
||||
audio_output.postprocess(y_audio.name).startswith(
|
||||
"data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA"
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
audio_output.get_template_context(), {"name": "audio", "label": None}
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.outputs.Audio(type="unknown")
|
||||
wrong_type.postprocess(y_audio.name)
|
||||
self.assertTrue(audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav"))
|
||||
self.assertTrue(
|
||||
audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav")
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = audio_output.save_flagged(tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None)
|
||||
to_save = audio_output.save_flagged(
|
||||
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
|
||||
)
|
||||
self.assertEqual("audio_output/0.wav", to_save)
|
||||
to_save = audio_output.save_flagged(tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None)
|
||||
to_save = audio_output.save_flagged(
|
||||
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
|
||||
)
|
||||
self.assertEqual("audio_output/1.wav", to_save)
|
||||
|
||||
def test_in_interface(self):
|
||||
@ -242,16 +304,18 @@ class TestAudio(unittest.TestCase):
|
||||
class TestJSON(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
js_output = gr.outputs.JSON()
|
||||
self.assertTrue(js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"')
|
||||
js = {
|
||||
"pos": "Hello ",
|
||||
"neg": "World"
|
||||
}
|
||||
self.assertTrue(
|
||||
js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"'
|
||||
)
|
||||
js = {"pos": "Hello ", "neg": "World"}
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = js_output.save_flagged(tmpdirname, "js_output", js, None)
|
||||
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
|
||||
self.assertEqual(js_output.restore_flagged(tmpdirname, to_save, None), {"pos": "Hello ", "neg": "World"})
|
||||
|
||||
self.assertEqual(
|
||||
js_output.restore_flagged(tmpdirname, to_save, None),
|
||||
{"pos": "Hello ", "neg": "World"},
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
def get_avg_age_per_gender(data):
|
||||
return {
|
||||
@ -263,7 +327,8 @@ class TestJSON(unittest.TestCase):
|
||||
iface = gr.Interface(
|
||||
get_avg_age_per_gender,
|
||||
gr.inputs.Dataframe(headers=["gender", "age"]),
|
||||
"json")
|
||||
"json",
|
||||
)
|
||||
y_data = [
|
||||
["M", 30],
|
||||
["F", 20],
|
||||
@ -271,9 +336,7 @@ class TestJSON(unittest.TestCase):
|
||||
["O", 20],
|
||||
["F", 30],
|
||||
]
|
||||
self.assertDictEqual(iface.process([y_data])[0][0], {
|
||||
"M": 35, "F": 25, "O": 20
|
||||
})
|
||||
self.assertDictEqual(iface.process([y_data])[0][0], {"M": 35, "F": 25, "O": 20})
|
||||
|
||||
|
||||
class TestHTML(unittest.TestCase):
|
||||
@ -293,80 +356,116 @@ class TestFile(unittest.TestCase):
|
||||
return "test.txt"
|
||||
|
||||
iface = gr.Interface(write_file, "text", "file")
|
||||
self.assertDictEqual(iface.process(["hello world"])[0][0], {
|
||||
'name': 'test.txt', 'size': 11, 'data': 'data:text/plain;base64,aGVsbG8gd29ybGQ='
|
||||
})
|
||||
self.assertDictEqual(
|
||||
iface.process(["hello world"])[0][0],
|
||||
{
|
||||
"name": "test.txt",
|
||||
"size": 11,
|
||||
"data": "data:text/plain;base64,aGVsbG8gd29ybGQ=",
|
||||
},
|
||||
)
|
||||
file_output = gr.outputs.File()
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = file_output.save_flagged(tmpdirname, "file_output", gr.test_data.BASE64_FILE, None)
|
||||
to_save = file_output.save_flagged(
|
||||
tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
|
||||
)
|
||||
self.assertEqual("file_output/0", to_save)
|
||||
to_save = file_output.save_flagged(tmpdirname, "file_output", gr.test_data.BASE64_FILE, None)
|
||||
to_save = file_output.save_flagged(
|
||||
tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
|
||||
)
|
||||
self.assertEqual("file_output/1", to_save)
|
||||
|
||||
|
||||
class TestDataframe(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
dataframe_output = gr.outputs.Dataframe()
|
||||
output = dataframe_output.postprocess(np.zeros((2,2)))
|
||||
self.assertDictEqual(output, {"data": [[0,0],[0,0]]})
|
||||
output = dataframe_output.postprocess([[1,3,5]])
|
||||
output = dataframe_output.postprocess(np.zeros((2, 2)))
|
||||
self.assertDictEqual(output, {"data": [[0, 0], [0, 0]]})
|
||||
output = dataframe_output.postprocess([[1, 3, 5]])
|
||||
self.assertDictEqual(output, {"data": [[1, 3, 5]]})
|
||||
output = dataframe_output.postprocess(pd.DataFrame(
|
||||
[[2, True], [3, True], [4, False]], columns=["num", "prime"]))
|
||||
self.assertDictEqual(output,
|
||||
{"headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]]})
|
||||
self.assertEqual(dataframe_output.get_template_context(), {
|
||||
'headers': None,
|
||||
'max_rows': 20,
|
||||
'max_cols': None,
|
||||
'overflow_row_behaviour': 'paginate',
|
||||
'name': 'dataframe',
|
||||
'label': None
|
||||
})
|
||||
output = dataframe_output.postprocess(
|
||||
pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"])
|
||||
)
|
||||
self.assertDictEqual(
|
||||
output,
|
||||
{"headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]]},
|
||||
)
|
||||
self.assertEqual(
|
||||
dataframe_output.get_template_context(),
|
||||
{
|
||||
"headers": None,
|
||||
"max_rows": 20,
|
||||
"max_cols": None,
|
||||
"overflow_row_behaviour": "paginate",
|
||||
"name": "dataframe",
|
||||
"label": None,
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.outputs.Dataframe(type="unknown")
|
||||
wrong_type.postprocess(0)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = dataframe_output.save_flagged(tmpdirname, "dataframe_output", output, None)
|
||||
self.assertEqual(to_save, '[[2, true], [3, true], [4, false]]')
|
||||
self.assertEqual(dataframe_output.restore_flagged(tmpdirname, to_save, None), {"data": [[2, True], [3, True], [4, False]]})
|
||||
to_save = dataframe_output.save_flagged(
|
||||
tmpdirname, "dataframe_output", output, None
|
||||
)
|
||||
self.assertEqual(to_save, "[[2, true], [3, true], [4, false]]")
|
||||
self.assertEqual(
|
||||
dataframe_output.restore_flagged(tmpdirname, to_save, None),
|
||||
{"data": [[2, True], [3, True], [4, False]]},
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
def check_odd(array):
|
||||
return array % 2 == 0
|
||||
|
||||
iface = gr.Interface(check_odd, "numpy", "numpy")
|
||||
self.assertEqual(
|
||||
iface.process([[2, 3, 4]])[0][0],
|
||||
{"data": [[True, False, True]]})
|
||||
iface.process([[2, 3, 4]])[0][0], {"data": [[True, False, True]]}
|
||||
)
|
||||
|
||||
|
||||
class TestCarousel(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease")
|
||||
|
||||
output = carousel_output.postprocess([["Hello World", "test/test_files/bus.png"],
|
||||
["Bye World", "test/test_files/bus.png"]])
|
||||
self.assertEqual(output, [['Hello World', gr.test_data.BASE64_IMAGE],
|
||||
['Bye World', gr.test_data.BASE64_IMAGE]])
|
||||
output = carousel_output.postprocess(
|
||||
[
|
||||
["Hello World", "test/test_files/bus.png"],
|
||||
["Bye World", "test/test_files/bus.png"],
|
||||
]
|
||||
)
|
||||
self.assertEqual(
|
||||
output,
|
||||
[
|
||||
["Hello World", gr.test_data.BASE64_IMAGE],
|
||||
["Bye World", gr.test_data.BASE64_IMAGE],
|
||||
],
|
||||
)
|
||||
|
||||
carousel_output = gr.outputs.Carousel("text", label="Disease")
|
||||
output = carousel_output.postprocess([["Hello World"], ["Bye World"]])
|
||||
self.assertEqual(output, [['Hello World'], ['Bye World']])
|
||||
self.assertEqual(carousel_output.get_template_context(), {
|
||||
'components': [{'name': 'textbox', 'label': None}],
|
||||
'name': 'carousel',
|
||||
'label': 'Disease'
|
||||
})
|
||||
self.assertEqual(output, [["Hello World"], ["Bye World"]])
|
||||
self.assertEqual(
|
||||
carousel_output.get_template_context(),
|
||||
{
|
||||
"components": [{"name": "textbox", "label": None}],
|
||||
"name": "carousel",
|
||||
"label": "Disease",
|
||||
},
|
||||
)
|
||||
output = carousel_output.postprocess(["Hello World", "Bye World"])
|
||||
self.assertEqual(output, [['Hello World'], ['Bye World']])
|
||||
self.assertEqual(output, [["Hello World"], ["Bye World"]])
|
||||
with self.assertRaises(ValueError):
|
||||
carousel_output.postprocess('Hello World!')
|
||||
carousel_output.postprocess("Hello World!")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = carousel_output.save_flagged(tmpdirname, "carousel_output", output, None)
|
||||
to_save = carousel_output.save_flagged(
|
||||
tmpdirname, "carousel_output", output, None
|
||||
)
|
||||
self.assertEqual(to_save, '[["Hello World"], ["Bye World"]]')
|
||||
|
||||
def test_in_interface(self):
|
||||
carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease")
|
||||
|
||||
def report(img):
|
||||
results = []
|
||||
for i, mode in enumerate(["Red", "Green", "Blue"]):
|
||||
@ -374,48 +473,82 @@ class TestCarousel(unittest.TestCase):
|
||||
color_filter[i] = 1
|
||||
results.append([mode, img * color_filter])
|
||||
return results
|
||||
|
||||
iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output)
|
||||
self.assertEqual(
|
||||
iface.process([gr.test_data.BASE64_IMAGE])[0], [[['Red',
|
||||
''],
|
||||
['Green',
|
||||
''],
|
||||
['Blue',
|
||||
'']]])
|
||||
iface.process([gr.test_data.BASE64_IMAGE])[0],
|
||||
[
|
||||
[
|
||||
[
|
||||
"Red",
|
||||
"",
|
||||
],
|
||||
[
|
||||
"Green",
|
||||
"",
|
||||
],
|
||||
[
|
||||
"Blue",
|
||||
"",
|
||||
],
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestTimeseries(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
timeseries_output = gr.outputs.Timeseries(label="Disease")
|
||||
self.assertEqual(timeseries_output.get_template_context(), {
|
||||
'x': None, 'y': None, 'name': 'timeseries', 'label': 'Disease'
|
||||
})
|
||||
data = {'Name': ['Tom', 'nick', 'krish', 'jack'], 'Age': [20, 21, 19, 18]}
|
||||
self.assertEqual(
|
||||
timeseries_output.get_template_context(),
|
||||
{"x": None, "y": None, "name": "timeseries", "label": "Disease"},
|
||||
)
|
||||
data = {"Name": ["Tom", "nick", "krish", "jack"], "Age": [20, 21, 19, 18]}
|
||||
df = pd.DataFrame(data)
|
||||
self.assertEqual(timeseries_output.postprocess(df),{'headers': ['Name', 'Age'],
|
||||
'data': [['Tom', 20], ['nick', 21], ['krish', 19],
|
||||
['jack', 18]]})
|
||||
self.assertEqual(
|
||||
timeseries_output.postprocess(df),
|
||||
{
|
||||
"headers": ["Name", "Age"],
|
||||
"data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
|
||||
},
|
||||
)
|
||||
|
||||
timeseries_output = gr.outputs.Timeseries(y="Age", label="Disease")
|
||||
output = timeseries_output.postprocess(df)
|
||||
self.assertEqual(output, {'headers': ['Name', 'Age'],
|
||||
'data': [['Tom', 20], ['nick', 21], ['krish', 19],
|
||||
['jack', 18]]})
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"headers": ["Name", "Age"],
|
||||
"data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
|
||||
},
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = timeseries_output.save_flagged(tmpdirname, "timeseries_output", output, None)
|
||||
self.assertEqual(to_save, '{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], '
|
||||
'["jack", 18]]}')
|
||||
self.assertEqual(timeseries_output.restore_flagged(tmpdirname, to_save, None), {"headers": ["Name", "Age"],
|
||||
"data": [["Tom", 20], ["nick", 21],
|
||||
["krish", 19], ["jack", 18]]})
|
||||
to_save = timeseries_output.save_flagged(
|
||||
tmpdirname, "timeseries_output", output, None
|
||||
)
|
||||
self.assertEqual(
|
||||
to_save,
|
||||
'{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], '
|
||||
'["jack", 18]]}',
|
||||
)
|
||||
self.assertEqual(
|
||||
timeseries_output.restore_flagged(tmpdirname, to_save, None),
|
||||
{
|
||||
"headers": ["Name", "Age"],
|
||||
"data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestNames(unittest.TestCase):
|
||||
def test_no_duplicate_uncased_names(self): # this ensures that get_input_instance() works correctly when instantiating from components
|
||||
def test_no_duplicate_uncased_names(
|
||||
self,
|
||||
): # this ensures that get_input_instance() works correctly when instantiating from components
|
||||
subclasses = gr.outputs.OutputComponent.__subclasses__()
|
||||
unique_subclasses_uncased = set([s.__name__.lower() for s in subclasses])
|
||||
self.assertEqual(len(subclasses), len(unique_subclasses_uncased))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,12 +1,13 @@
|
||||
import unittest
|
||||
import os
|
||||
import pathlib
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
import gradio as gr
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -14,23 +15,27 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
class ImagePreprocessing(unittest.TestCase):
|
||||
def test_decode_base64_to_image(self):
|
||||
output_image = gr.processing_utils.decode_base64_to_image(
|
||||
gr.test_data.BASE64_IMAGE)
|
||||
gr.test_data.BASE64_IMAGE
|
||||
)
|
||||
self.assertIsInstance(output_image, Image.Image)
|
||||
|
||||
def test_encode_url_or_file_to_base64(self):
|
||||
output_base64 = gr.processing_utils.encode_url_or_file_to_base64(
|
||||
"test/test_data/test_image.png")
|
||||
"test/test_data/test_image.png"
|
||||
)
|
||||
self.assertEquals(output_base64, gr.test_data.BASE64_IMAGE)
|
||||
|
||||
def test_encode_file_to_base64(self):
|
||||
output_base64 = gr.processing_utils.encode_file_to_base64(
|
||||
"test/test_data/test_image.png")
|
||||
"test/test_data/test_image.png"
|
||||
)
|
||||
self.assertEquals(output_base64, gr.test_data.BASE64_IMAGE)
|
||||
|
||||
def test_encode_url_to_base64(self):
|
||||
output_base64 = gr.processing_utils.encode_url_to_base64(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/master/test"
|
||||
"/test_data/test_image.png")
|
||||
"/test_data/test_image.png"
|
||||
)
|
||||
self.assertEqual(output_base64, gr.test_data.BASE64_IMAGE)
|
||||
|
||||
# def test_encode_plot_to_base64(self): # Commented out because this is throwing errors on Windows. Possibly due to different matplotlib behavior on Windows?
|
||||
@ -49,19 +54,21 @@ class ImagePreprocessing(unittest.TestCase):
|
||||
img = Image.open("test/test_data/test_image.png")
|
||||
new_img = gr.processing_utils.resize_and_crop(img, (20, 20))
|
||||
self.assertEqual(new_img.size, (20, 20))
|
||||
self.assertRaises(ValueError, gr.processing_utils.resize_and_crop,
|
||||
**{'img': img, 'size': (20, 20), 'crop_type': 'test'})
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
gr.processing_utils.resize_and_crop,
|
||||
**{"img": img, "size": (20, 20), "crop_type": "test"}
|
||||
)
|
||||
|
||||
|
||||
class AudioPreprocessing(unittest.TestCase):
|
||||
def test_audio_from_file(self):
|
||||
audio = gr.processing_utils.audio_from_file(
|
||||
"test/test_data/test_audio.wav")
|
||||
audio = gr.processing_utils.audio_from_file("test/test_data/test_audio.wav")
|
||||
self.assertEqual(audio[0], 22050)
|
||||
self.assertIsInstance(audio[1], np.ndarray)
|
||||
|
||||
def test_audio_to_file(self):
|
||||
audio = gr.processing_utils.audio_from_file(
|
||||
"test/test_data/test_audio.wav")
|
||||
audio = gr.processing_utils.audio_from_file("test/test_data/test_audio.wav")
|
||||
gr.processing_utils.audio_to_file(audio[0], audio[1], "test_audio_to_file")
|
||||
self.assertTrue(os.path.exists("test_audio_to_file"))
|
||||
os.remove("test_audio_to_file")
|
||||
@ -69,31 +76,39 @@ class AudioPreprocessing(unittest.TestCase):
|
||||
|
||||
class OutputPreprocessing(unittest.TestCase):
|
||||
def test_decode_base64_to_binary(self):
|
||||
binary = gr.processing_utils.decode_base64_to_binary(
|
||||
gr.test_data.BASE64_IMAGE)
|
||||
binary = gr.processing_utils.decode_base64_to_binary(gr.test_data.BASE64_IMAGE)
|
||||
self.assertEqual(gr.test_data.BINARY_IMAGE, binary)
|
||||
|
||||
def test_decode_base64_to_file(self):
|
||||
temp_file = gr.processing_utils.decode_base64_to_file(
|
||||
gr.test_data.BASE64_IMAGE)
|
||||
temp_file = gr.processing_utils.decode_base64_to_file(gr.test_data.BASE64_IMAGE)
|
||||
self.assertIsInstance(temp_file, tempfile._TemporaryFileWrapper)
|
||||
|
||||
def test_create_tmp_copy_of_file(self):
|
||||
temp_file = gr.processing_utils.create_tmp_copy_of_file(
|
||||
"test.txt")
|
||||
temp_file = gr.processing_utils.create_tmp_copy_of_file("test.txt")
|
||||
self.assertIsInstance(temp_file, tempfile._TemporaryFileWrapper)
|
||||
|
||||
float_dtype_list = [float, float, np.double, np.single, np.float32,
|
||||
np.float64, 'float32', 'float64']
|
||||
float_dtype_list = [
|
||||
float,
|
||||
float,
|
||||
np.double,
|
||||
np.single,
|
||||
np.float32,
|
||||
np.float64,
|
||||
"float32",
|
||||
"float64",
|
||||
]
|
||||
|
||||
def test_float_conversion_dtype(self):
|
||||
"""Test any convertion from a float dtype to an other."""
|
||||
|
||||
x = np.array([-1, 1])
|
||||
# Test all combinations of dtypes conversions
|
||||
dtype_combin = np.array(np.meshgrid(
|
||||
OutputPreprocessing.float_dtype_list,
|
||||
OutputPreprocessing.float_dtype_list)).T.reshape(-1, 2)
|
||||
dtype_combin = np.array(
|
||||
np.meshgrid(
|
||||
OutputPreprocessing.float_dtype_list,
|
||||
OutputPreprocessing.float_dtype_list,
|
||||
)
|
||||
).T.reshape(-1, 2)
|
||||
|
||||
for dtype_in, dtype_out in dtype_combin:
|
||||
x = x.astype(dtype_in)
|
||||
@ -108,5 +123,6 @@ class OutputPreprocessing(unittest.TestCase):
|
||||
y = gr.processing_utils._convert(x, np.floating)
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,13 +1,14 @@
|
||||
import io
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from gradio import tunneling, networking, Interface
|
||||
import threading
|
||||
import paramiko
|
||||
import os
|
||||
|
||||
import paramiko
|
||||
|
||||
from gradio import Interface, networking, tunneling
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -27,13 +28,13 @@ class TestTunneling(unittest.TestCase):
|
||||
io.close()
|
||||
|
||||
|
||||
class TestVerbose(unittest.TestCase):
|
||||
"""Not absolutely needed but just including them for the sake of completion."""
|
||||
|
||||
class TestVerbose(unittest.TestCase):
|
||||
"""Not absolutely needed but just including them for the sake of completion."""
|
||||
|
||||
def setUp(self):
|
||||
self.message = "print test"
|
||||
self.capturedOutput = io.StringIO() # Create StringIO object
|
||||
sys.stdout = self.capturedOutput # and redirect stdout.
|
||||
self.capturedOutput = io.StringIO() # Create StringIO object
|
||||
sys.stdout = self.capturedOutput # and redirect stdout.
|
||||
|
||||
def test_verbose_debug_true(self):
|
||||
tunneling.verbose(self.message, debug_mode=True)
|
||||
@ -41,10 +42,11 @@ class TestVerbose(unittest.TestCase):
|
||||
|
||||
def test_verbose_debug_false(self):
|
||||
tunneling.verbose(self.message, debug_mode=False)
|
||||
self.assertEqual(self.capturedOutput.getvalue().strip(), '')
|
||||
self.assertEqual(self.capturedOutput.getvalue().strip(), "")
|
||||
|
||||
def tearDown(self):
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,14 +1,15 @@
|
||||
import ipaddress
|
||||
import os
|
||||
import pkg_resources
|
||||
import requests
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import warnings
|
||||
|
||||
import pkg_resources
|
||||
import requests
|
||||
|
||||
import gradio
|
||||
from gradio.utils import *
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
@ -21,7 +22,10 @@ class TestUtils(unittest.TestCase):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
version_check()
|
||||
self.assertEqual(str(w[-1].message), "gradio is not setup or installed properly. Unable to get version info.")
|
||||
self.assertEqual(
|
||||
str(w[-1].message),
|
||||
"gradio is not setup or installed properly. Unable to get version info.",
|
||||
)
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_should_warn_with_unable_to_parse(self, mock_get):
|
||||
@ -31,7 +35,9 @@ class TestUtils(unittest.TestCase):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
version_check()
|
||||
self.assertEqual(str(w[-1].message), "unable to parse version details from package URL.")
|
||||
self.assertEqual(
|
||||
str(w[-1].message), "unable to parse version details from package URL."
|
||||
)
|
||||
|
||||
@mock.patch("requests.Response.json")
|
||||
def test_should_warn_url_not_having_version(self, mock_json):
|
||||
@ -41,17 +47,18 @@ class TestUtils(unittest.TestCase):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
version_check()
|
||||
self.assertEqual(str(w[-1].message), "package URL does not contain version info.")
|
||||
|
||||
|
||||
self.assertEqual(
|
||||
str(w[-1].message), "package URL does not contain version info."
|
||||
)
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post):
|
||||
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
error_analytics("placeholder", "placeholder")
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("requests.post")
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_successful(self, mock_post):
|
||||
error_analytics("placeholder", "placeholder")
|
||||
mock_post.assert_called()
|
||||
@ -61,29 +68,29 @@ class TestUtils(unittest.TestCase):
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
launch_analytics(data={})
|
||||
mock_post.assert_called()
|
||||
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
def test_colab_check_no_ipython(self, mock_get_ipython):
|
||||
mock_get_ipython.return_value = None
|
||||
assert colab_check() is False
|
||||
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
def test_ipython_check_import_fail(self, mock_get_ipython):
|
||||
mock_get_ipython.side_effect = ImportError()
|
||||
assert ipython_check() is False
|
||||
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
def test_ipython_check_no_ipython(self, mock_get_ipython):
|
||||
mock_get_ipython.return_value = None
|
||||
assert ipython_check() is False
|
||||
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_readme_to_html_doesnt_crash_on_connection_error(self, mock_get):
|
||||
mock_get.side_effect = requests.ConnectionError()
|
||||
readme_to_html("placeholder")
|
||||
|
||||
|
||||
def test_readme_to_html_correct_parse(self):
|
||||
readme_to_html("https://github.com/gradio-app/gradio/blob/master/README.md")
|
||||
readme_to_html("https://github.com/gradio-app/gradio/blob/master/README.md")
|
||||
|
||||
|
||||
class TestIPAddress(unittest.TestCase):
|
||||
@ -101,6 +108,5 @@ class TestIPAddress(unittest.TestCase):
|
||||
self.assertEqual(ip, "No internet connection")
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,8 +1,10 @@
|
||||
import os, sys
|
||||
import subprocess
|
||||
from jinja2 import Template
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
GRADIO_DIR = os.path.join(os.getcwd(), os.pardir, os.pardir)
|
||||
GRADIO_DEMO_DIR = os.path.join(GRADIO_DIR, "demo")
|
||||
@ -22,17 +24,19 @@ for demo_name in demos_to_run:
|
||||
demo_folder = os.path.join(GRADIO_DEMO_DIR, demo_name)
|
||||
requirements_file = os.path.join(demo_folder, "requirements.txt")
|
||||
if os.path.exists(requirements_file):
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", requirements_file])
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "-r", requirements_file]
|
||||
)
|
||||
setup_file = os.path.join(demo_folder, "setup.sh")
|
||||
if os.path.exists(setup_file):
|
||||
subprocess.check_call(["sh", setup_file])
|
||||
subprocess.check_call(["sh", setup_file])
|
||||
demo_port_sets.append((demo_name, port))
|
||||
port += 1
|
||||
|
||||
with open("nginx_template.conf") as nginx_template_conf:
|
||||
with open("nginx_template.conf") as nginx_template_conf:
|
||||
template = Template(nginx_template_conf.read())
|
||||
output_nginx_conf = template.render(demo_port_sets=demo_port_sets)
|
||||
with open("nginx.conf", "w") as nginx_conf:
|
||||
nginx_conf.write(output_nginx_conf)
|
||||
with open("demos.json", "w") as demos_file:
|
||||
json.dump(demo_port_sets, demos_file)
|
||||
json.dump(demo_port_sets, demos_file)
|
||||
|
@ -1,11 +1,12 @@
|
||||
import time
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
LAUNCH_PERIOD = 60
|
||||
GRADIO_DEMO_DIR = "../../demo"
|
||||
@ -14,18 +15,18 @@ sys.path.insert(0, GRADIO_DEMO_DIR)
|
||||
with open("demos.json") as demos_file:
|
||||
demo_port_sets = json.load(demos_file)
|
||||
|
||||
|
||||
def launch_demo(demo_folder):
|
||||
subprocess.call(f"cd {demo_folder} && python run.py", shell=True)
|
||||
|
||||
|
||||
for demo_name, port in demo_port_sets:
|
||||
demo_folder = os.path.join(GRADIO_DEMO_DIR, demo_name)
|
||||
demo_file = os.path.join(demo_folder, "run.py")
|
||||
with open(demo_file, 'r') as file:
|
||||
with open(demo_file, "r") as file:
|
||||
filedata = file.read()
|
||||
filedata = filedata.replace(
|
||||
f'iface.launch()',
|
||||
f'iface.launch(server_port={port})')
|
||||
with open(demo_file, 'w') as file:
|
||||
filedata = filedata.replace(f"iface.launch()", f"iface.launch(server_port={port})")
|
||||
with open(demo_file, "w") as file:
|
||||
file.write(filedata)
|
||||
demo_thread = threading.Thread(target=launch_demo, args=(demo_folder,))
|
||||
demo_thread.start()
|
||||
|
@ -1,81 +1,108 @@
|
||||
import os
|
||||
import json
|
||||
from jinja2 import Template
|
||||
import requests
|
||||
import markdown2
|
||||
import re
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio.interface import Interface
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import markdown2
|
||||
import requests
|
||||
from jinja2 import Template
|
||||
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.interface import Interface
|
||||
from gradio.outputs import OutputComponent
|
||||
|
||||
GRADIO_DIR = "../../"
|
||||
GRADIO_GUIDES_DIR = os.path.join(GRADIO_DIR, "guides")
|
||||
GRADIO_DEMO_DIR = os.path.join(GRADIO_DIR, "demo")
|
||||
|
||||
guide_names = [] # used for dropdown in navbar
|
||||
guide_names = [] # used for dropdown in navbar
|
||||
for guide in sorted(os.listdir(GRADIO_GUIDES_DIR)):
|
||||
if "template" in guide or "getting_started" in guide:
|
||||
continue
|
||||
guide_name = guide[:-3]
|
||||
pretty_guide_name = " ".join([word.capitalize().replace("Ml", "ML")
|
||||
for word in guide_name.split("_")])
|
||||
pretty_guide_name = " ".join(
|
||||
[word.capitalize().replace("Ml", "ML") for word in guide_name.split("_")]
|
||||
)
|
||||
guide_names.append((guide_name, pretty_guide_name))
|
||||
|
||||
|
||||
def render_index():
|
||||
os.makedirs("generated", exist_ok=True)
|
||||
with open("src/tweets.json", encoding='utf-8') as tweets_file:
|
||||
with open("src/tweets.json", encoding="utf-8") as tweets_file:
|
||||
tweets = json.load(tweets_file)
|
||||
star_count = "{:,}".format(requests.get("https://api.github.com/repos/gradio-app/gradio"
|
||||
).json()["stargazers_count"])
|
||||
with open("src/index_template.html", encoding='utf-8') as template_file:
|
||||
star_count = "{:,}".format(
|
||||
requests.get("https://api.github.com/repos/gradio-app/gradio").json()[
|
||||
"stargazers_count"
|
||||
]
|
||||
)
|
||||
with open("src/index_template.html", encoding="utf-8") as template_file:
|
||||
template = Template(template_file.read())
|
||||
output_html = template.render(tweets=tweets, star_count=star_count, guide_names=guide_names)
|
||||
with open(os.path.join("generated", "index.html"), "w", encoding='utf-8') as generated_template:
|
||||
output_html = template.render(
|
||||
tweets=tweets, star_count=star_count, guide_names=guide_names
|
||||
)
|
||||
with open(
|
||||
os.path.join("generated", "index.html"), "w", encoding="utf-8"
|
||||
) as generated_template:
|
||||
generated_template.write(output_html)
|
||||
|
||||
|
||||
def render_guides():
|
||||
guides = []
|
||||
for guide in os.listdir(GRADIO_GUIDES_DIR):
|
||||
if "template" in guide:
|
||||
continue
|
||||
with open(os.path.join(GRADIO_GUIDES_DIR, guide), encoding='utf-8') as guide_file:
|
||||
with open(
|
||||
os.path.join(GRADIO_GUIDES_DIR, guide), encoding="utf-8"
|
||||
) as guide_file:
|
||||
guide_text = guide_file.read()
|
||||
code_tags = re.findall(r'\{\{ code\["([^\s]*)"\] \}\}', guide_text)
|
||||
demo_names = re.findall(r'\{\{ demos\["([^\s]*)"\] \}\}', guide_text)
|
||||
code, demos = {}, {}
|
||||
guide_text = guide_text.replace(
|
||||
"website/src/assets", "/assets").replace(
|
||||
"```python\n", "<pre><code class='lang-python'>").replace(
|
||||
"```bash\n", "<pre><code class='lang-bash'>").replace(
|
||||
"```directory\n", "<pre><code class='lang-bash'>").replace(
|
||||
"```csv\n", "<pre><code class='lang-bash'>").replace(
|
||||
"```", "</code></pre>")
|
||||
guide_text = (
|
||||
guide_text.replace("website/src/assets", "/assets")
|
||||
.replace("```python\n", "<pre><code class='lang-python'>")
|
||||
.replace("```bash\n", "<pre><code class='lang-bash'>")
|
||||
.replace("```directory\n", "<pre><code class='lang-bash'>")
|
||||
.replace("```csv\n", "<pre><code class='lang-bash'>")
|
||||
.replace("```", "</code></pre>")
|
||||
)
|
||||
for code_src in code_tags:
|
||||
with open(os.path.join(GRADIO_DEMO_DIR, code_src, "run.py")) as code_file:
|
||||
python_code = code_file.read().replace(
|
||||
'if __name__ == "__main__":\n iface.launch()', "iface.launch()")
|
||||
code[code_src] = "<pre><code class='lang-python'>" + \
|
||||
python_code + "</code></pre>"
|
||||
'if __name__ == "__main__":\n iface.launch()', "iface.launch()"
|
||||
)
|
||||
code[code_src] = (
|
||||
"<pre><code class='lang-python'>" + python_code + "</code></pre>"
|
||||
)
|
||||
for demo_name in demo_names:
|
||||
demos[demo_name] = "<div id='interface_" + demo_name + "'></div>"
|
||||
guide_template = Template(guide_text)
|
||||
guide_output = guide_template.render(code=code, demos=demos)
|
||||
output_html = markdown2.markdown(guide_output)
|
||||
output_html = output_html.replace("<a ", "<a target='blank' ")
|
||||
for match in re.findall(r'<h3>([A-Za-z0-9 ]*)<\/h3>', output_html):
|
||||
for match in re.findall(r"<h3>([A-Za-z0-9 ]*)<\/h3>", output_html):
|
||||
output_html = output_html.replace(
|
||||
f"<h3>{match}</h3>", f"<h3 id={match.lower().replace(' ', '_')}>{match}</h3>")
|
||||
f"<h3>{match}</h3>",
|
||||
f"<h3 id={match.lower().replace(' ', '_')}>{match}</h3>",
|
||||
)
|
||||
os.makedirs("generated", exist_ok=True)
|
||||
guide = guide[:-3]
|
||||
os.makedirs(os.path.join(
|
||||
"generated", guide), exist_ok=True)
|
||||
with open("src/guides_template.html", encoding='utf-8') as general_template_file:
|
||||
os.makedirs(os.path.join("generated", guide), exist_ok=True)
|
||||
with open(
|
||||
"src/guides_template.html", encoding="utf-8"
|
||||
) as general_template_file:
|
||||
general_template = Template(general_template_file.read())
|
||||
with open(os.path.join("generated", guide, "index.html"), "w", encoding='utf-8') as generated_template:
|
||||
output_html = general_template.render(template_html=output_html, demo_names=demo_names, guide_names=guide_names)
|
||||
with open(
|
||||
os.path.join("generated", guide, "index.html"), "w", encoding="utf-8"
|
||||
) as generated_template:
|
||||
output_html = general_template.render(
|
||||
template_html=output_html,
|
||||
demo_names=demo_names,
|
||||
guide_names=guide_names,
|
||||
)
|
||||
generated_template.write(output_html)
|
||||
|
||||
|
||||
def render_docs():
|
||||
if os.path.exists("generated/colab_links.json"):
|
||||
with open("generated/colab_links.json") as demo_links_file:
|
||||
@ -110,10 +137,15 @@ def render_docs():
|
||||
name = line[:space_index]
|
||||
documented_params.add(name)
|
||||
params_doc.append(
|
||||
(name, line[space_index+2:colon_index-1], line[colon_index+2:]))
|
||||
(
|
||||
name,
|
||||
line[space_index + 2 : colon_index - 1],
|
||||
line[colon_index + 2 :],
|
||||
)
|
||||
)
|
||||
elif mode == "out":
|
||||
colon_index = line.index(":")
|
||||
return_doc.append((line[1:colon_index-1], line[colon_index+2:]))
|
||||
return_doc.append((line[1 : colon_index - 1], line[colon_index + 2 :]))
|
||||
params = inspect.getfullargspec(func)
|
||||
param_set = []
|
||||
for i in range(len(params.args)):
|
||||
@ -139,13 +171,20 @@ def render_docs():
|
||||
inp["doc"] = "\n".join(doc_lines[:-2])
|
||||
inp["type"] = doc_lines[-2].split("type: ")[-1]
|
||||
inp["demos"] = doc_lines[-1][7:].split(", ")
|
||||
_, inp["params"], inp["params_doc"], _ = get_function_documentation(cls.__init__)
|
||||
_, inp["params"], inp["params_doc"], _ = get_function_documentation(
|
||||
cls.__init__
|
||||
)
|
||||
inp["shortcuts"] = list(cls.get_shortcut_implementations().items())
|
||||
if "interpret" in cls.__dict__:
|
||||
inp["interpret"], inp["interpret_params"], inp["interpret_params_doc"], _ = get_function_documentation(
|
||||
cls.interpret)
|
||||
(
|
||||
inp["interpret"],
|
||||
inp["interpret_params"],
|
||||
inp["interpret_params_doc"],
|
||||
_,
|
||||
) = get_function_documentation(cls.interpret)
|
||||
_, _, _, inp["interpret_returns_doc"] = get_function_documentation(
|
||||
cls.get_interpretation_scores)
|
||||
cls.get_interpretation_scores
|
||||
)
|
||||
|
||||
return inp
|
||||
|
||||
@ -178,24 +217,32 @@ def render_docs():
|
||||
os.makedirs("generated", exist_ok=True)
|
||||
with open("src/docs_template.html") as template_file:
|
||||
template = Template(template_file.read())
|
||||
output_html = template.render(docs=docs, demo_links=demo_links, guide_names=guide_names)
|
||||
output_html = template.render(
|
||||
docs=docs, demo_links=demo_links, guide_names=guide_names
|
||||
)
|
||||
os.makedirs(os.path.join("generated", "docs"), exist_ok=True)
|
||||
with open(os.path.join("generated", "docs", "index.html"), "w") as generated_template:
|
||||
with open(
|
||||
os.path.join("generated", "docs", "index.html"), "w"
|
||||
) as generated_template:
|
||||
generated_template.write(output_html)
|
||||
|
||||
|
||||
def render_other():
|
||||
os.makedirs("generated", exist_ok=True)
|
||||
for template_filename in os.listdir("src/other_templates"):
|
||||
with open(os.path.join("src/other_templates", template_filename)) as template_file:
|
||||
with open(
|
||||
os.path.join("src/other_templates", template_filename)
|
||||
) as template_file:
|
||||
template = Template(template_file.read())
|
||||
output_html = template.render(guide_names=guide_names)
|
||||
folder_name = template_filename[:-14]
|
||||
os.makedirs(os.path.join("generated", folder_name), exist_ok=True)
|
||||
with open(os.path.join("generated", folder_name, "index.html"), "w", encoding='utf-8') as generated_template:
|
||||
with open(
|
||||
os.path.join("generated", folder_name, "index.html"), "w", encoding="utf-8"
|
||||
) as generated_template:
|
||||
generated_template.write(output_html)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_index()
|
||||
render_guides()
|
||||
|
@ -8,11 +8,15 @@ for folder in ("generated", "dist"):
|
||||
for root, _, files in os.walk(folder):
|
||||
for file in files:
|
||||
if file.endswith(".html"):
|
||||
with open(os.path.join(root, file), encoding='utf-8') as old_file:
|
||||
with open(os.path.join(root, file), encoding="utf-8") as old_file:
|
||||
content = old_file.read()
|
||||
for old_name, new_name in style_map.items():
|
||||
content = content.replace(old_name, new_name)
|
||||
with open(os.path.join(root, file), "w", encoding='utf-8') as new_file:
|
||||
with open(os.path.join(root, file), "w", encoding="utf-8") as new_file:
|
||||
new_file.write(content)
|
||||
elif file.startswith("style.") and file.endswith(".css") and file not in list(style_map.values()):
|
||||
elif (
|
||||
file.startswith("style.")
|
||||
and file.endswith(".css")
|
||||
and file not in list(style_map.values())
|
||||
):
|
||||
os.remove(os.path.join(root, file))
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
from pydrive.auth import GoogleAuth
|
||||
from pydrive.drive import GoogleDrive
|
||||
|
||||
@ -23,17 +24,24 @@ GRADIO_DEMO_DIR = "../../demo/"
|
||||
NOTEBOOK_TYPE = "application/vnd.google.colaboratory"
|
||||
GOOGLE_FOLDER_TYPE = "application/vnd.google-apps.folder"
|
||||
|
||||
|
||||
def run():
|
||||
drive = GoogleDrive(gauth)
|
||||
|
||||
demo_links = {}
|
||||
with open("colab_template.ipynb") as notebook_template_file:
|
||||
notebook_template = notebook_template_file.read()
|
||||
file_list = drive.ListFile({'q': f"title='Demos' and mimeType='{GOOGLE_FOLDER_TYPE}' and 'root' in parents and trashed=false"}).GetList()
|
||||
file_list = drive.ListFile(
|
||||
{
|
||||
"q": f"title='Demos' and mimeType='{GOOGLE_FOLDER_TYPE}' and 'root' in parents and trashed=false"
|
||||
}
|
||||
).GetList()
|
||||
if len(file_list) > 0:
|
||||
demo_folder = file_list[0].metadata["id"]
|
||||
else:
|
||||
demo_folder_file = drive.CreateFile({'title' : "Demos", 'mimeType' : GOOGLE_FOLDER_TYPE})
|
||||
demo_folder_file = drive.CreateFile(
|
||||
{"title": "Demos", "mimeType": GOOGLE_FOLDER_TYPE}
|
||||
)
|
||||
demo_folder_file.Upload()
|
||||
demo_folder = demo_folder_file.metadata["id"]
|
||||
for demo_name in os.listdir(GRADIO_DEMO_DIR):
|
||||
@ -41,29 +49,40 @@ def run():
|
||||
print("--- " + demo_name + " ---")
|
||||
with open(os.path.join(GRADIO_DEMO_DIR, demo_name, "run.py")) as demo_file:
|
||||
demo_content = demo_file.read()
|
||||
demo_content = demo_content.replace('if __name__ == "__main__":\n iface.launch()', "iface.launch()")
|
||||
demo_content = demo_content.replace(
|
||||
'if __name__ == "__main__":\n iface.launch()', "iface.launch()"
|
||||
)
|
||||
lines = demo_content.split("/n")
|
||||
demo_content = [line + "\n" if i != len(lines) - 1 else line for i, line in enumerate(lines)]
|
||||
demo_content = [
|
||||
line + "\n" if i != len(lines) - 1 else line
|
||||
for i, line in enumerate(lines)
|
||||
]
|
||||
notebook = json.loads(notebook_template)
|
||||
notebook["cells"][1]["source"] = demo_content
|
||||
file_list = drive.ListFile({'q': f"title='{notebook_title}' and mimeType='{NOTEBOOK_TYPE}' and 'root' in parents and trashed=false"}).GetList()
|
||||
file_list = drive.ListFile(
|
||||
{
|
||||
"q": f"title='{notebook_title}' and mimeType='{NOTEBOOK_TYPE}' and 'root' in parents and trashed=false"
|
||||
}
|
||||
).GetList()
|
||||
if len(file_list) > 0:
|
||||
drive_file = file_list[0]
|
||||
else:
|
||||
drive_file = drive.CreateFile({
|
||||
'title': notebook_title,
|
||||
'mimeType': NOTEBOOK_TYPE,
|
||||
'parents': [{"id": demo_folder}]
|
||||
})
|
||||
drive_file = drive.CreateFile(
|
||||
{
|
||||
"title": notebook_title,
|
||||
"mimeType": NOTEBOOK_TYPE,
|
||||
"parents": [{"id": demo_folder}],
|
||||
}
|
||||
)
|
||||
drive_file.SetContentString(json.dumps(notebook))
|
||||
drive_file.Upload()
|
||||
drive_file.InsertPermission({
|
||||
'type': 'anyone',
|
||||
'value': 'anyone',
|
||||
'role': 'reader'})
|
||||
demo_links[demo_name] = drive_file['alternateLink']
|
||||
drive_file.InsertPermission(
|
||||
{"type": "anyone", "value": "anyone", "role": "reader"}
|
||||
)
|
||||
demo_links[demo_name] = drive_file["alternateLink"]
|
||||
with open("../.env", "w") as env_file:
|
||||
env_file.write(f"COLAB_NOTEBOOK_LINKS='{json.dumps(demo_links)}'")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
run()
|
||||
|
Loading…
x
Reference in New Issue
Block a user