cache examples functionality; other fixes

This commit is contained in:
Ali Abid 2021-12-30 08:20:28 +00:00
parent 63a4540b92
commit 5020f8aa1f
33 changed files with 532 additions and 354 deletions

1
.gitignore vendored
View File

@ -21,6 +21,7 @@ gradio/templates/frontend/static
*.sqlite3
gradio/launches.json
flagged/
gradio_cached_examples/
# Tests
.coverage

View File

@ -0,0 +1,8 @@
time,value,price
1,1,4
2,3,8
3,6,12
4,10,16
5,15,20
6,21,24
7,28,28
1 time value price
2 1 1 4
3 2 3 8
4 3 6 12
5 4 10 16
6 5 15 20
7 6 21 24
8 7 28 28

Binary file not shown.

View File

@ -23,7 +23,7 @@ def fn(text1, text2, num, slider1, slider2, single_checkbox,
{"name": "Jane", "age": 34, "children": checkboxes}, # JSON
"<button style='background-color: red'>Click Me: " + radio + "</button>", # HTML
"files/titanic.csv",
np.ones((4, 3)), # Dataframe
df1, # Dataframe
[im for im in [im1, im2, im3, im4, "files/cheetah1.jpg"] if im is not None], # Carousel
df2 # Timeseries
)
@ -33,30 +33,24 @@ iface = gr.Interface(
fn,
inputs=[
gr.inputs.Textbox(default="Lorem ipsum", label="Textbox"),
gr.inputs.Textbox(lines=3, placeholder="Type here..",
label="Textbox 2"),
gr.inputs.Textbox(lines=3, placeholder="Type here..", label="Textbox 2"),
gr.inputs.Number(label="Number", default=42),
gr.inputs.Slider(minimum=10, maximum=20, default=15,
label="Slider: 10 - 20"),
gr.inputs.Slider(maximum=20, step=0.04,
label="Slider: step @ 0.04"),
gr.inputs.Slider(minimum=10, maximum=20, default=15, label="Slider: 10 - 20"),
gr.inputs.Slider(maximum=20, step=0.04, label="Slider: step @ 0.04"),
gr.inputs.Checkbox(label="Checkbox"),
gr.inputs.CheckboxGroup(label="CheckboxGroup",
choices=CHOICES, default=CHOICES[0:2]),
gr.inputs.CheckboxGroup(label="CheckboxGroup", choices=CHOICES, default=CHOICES[0:2]),
gr.inputs.Radio(label="Radio", choices=CHOICES, default=CHOICES[2]),
gr.inputs.Dropdown(label="Dropdown", choices=CHOICES),
gr.inputs.Image(label="Image", optional=True),
gr.inputs.Image(label="Image w/ Cropper",
tool="select", optional=True),
gr.inputs.Image(label="Image w/ Cropper", tool="select", optional=True),
gr.inputs.Image(label="Sketchpad", source="canvas", optional=True),
gr.inputs.Image(label="Webcam", source="webcam", optional=True),
gr.inputs.Video(label="Video", optional=True),
gr.inputs.Audio(label="Audio", optional=True),
gr.inputs.Audio(label="Microphone",
source="microphone", optional=True),
gr.inputs.Audio(label="Microphone", source="microphone", optional=True),
gr.inputs.File(label="File", optional=True),
gr.inputs.Dataframe(),
gr.inputs.Timeseries(x="time", y="value", optional=True),
gr.inputs.Timeseries(x="time", y=["price", "value"], optional=True),
],
outputs=[
gr.outputs.Textbox(label="Textbox"),
@ -71,8 +65,11 @@ iface = gr.Interface(
gr.outputs.File(label="File"),
gr.outputs.Dataframe(label="Dataframe"),
gr.outputs.Carousel("image", label="Carousel"),
gr.outputs.Timeseries(x="time", y="value", label="Timeseries")
gr.outputs.Timeseries(x="time", y=["price", "value"], label="Timeseries")
],
examples=[
["the quick brown fox", "jumps over the lazy dog", 10, 12, 4, True, ["foo", "baz"], "baz", "bar", "files/cheetah1.jpg", "files/cheetah1.jpg", "files/cheetah1.jpg", "files/cheetah1.jpg", "files/world.mp4", "files/cantina.wav", "files/cantina.wav","files/titanic.csv", [[1,2,3],[3,4,5]], "files/time.csv"]
] * 3,
theme="huggingface",
title="Kitchen Sink",
description="Try out all the components!",

View File

@ -39,7 +39,11 @@ module.exports = {
},
style: {
postcss: {
plugins: [require("postcss-prefixwrap")(".gradio_app"), require("tailwindcss"), require("autoprefixer")]
plugins: [
require("postcss-prefixwrap")(".gradio_app"),
require("tailwindcss"),
require("autoprefixer")
]
}
}
};

View File

@ -25,13 +25,11 @@
gtag('js', new Date());
gtag('config', 'UA-156449732-1');
window.gradio_mode = "app";
try {
</script>
<script>
window.gradio_config = {{ config|tojson }};
} catch (e) {
}
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
<title>Gradio</title>
</head>

View File

@ -5,13 +5,13 @@ export default class ComponentExample extends React.Component {
render() {
return <div>{this.props.value}</div>;
}
static async preprocess(x) {
static async preprocess(x, examples_dir, component_config) {
return x;
}
}
export class FileComponentExample extends ComponentExample {
static async preprocess(x, examples_dir) {
static async preprocess(x, examples_dir, component_config) {
return {
name: x,
data: examples_dir + "/" + x,
@ -21,7 +21,7 @@ export class FileComponentExample extends ComponentExample {
}
export class DataURLComponentExample extends ComponentExample {
static async preprocess(x, examples_dir) {
static async preprocess(x, examples_dir, component_config) {
let file_url = examples_dir + "/" + x;
let response = await fetch(file_url);
if (!response.ok) {

View File

@ -1,27 +1,22 @@
import React from "react";
import BaseComponent from "../base_component";
import ComponentExample from "../component_example";
import FileComponentExample from "../component_example";
import { CSVToArray } from "../../utils";
import { Scatter } from 'react-chartjs-2';
import { Scatter } from "react-chartjs-2";
import { getNextColor } from "../../utils";
class TimeseriesInput extends BaseComponent {
constructor(props) {
super(props);
this.handleChange = this.handleChange.bind(this);
this.uploader = React.createRef();
this.openFileUpload = this.openFileUpload.bind(this);
this.load_preview_from_files = this.load_preview_from_files.bind(this);
this.load_preview_from_upload = this.load_preview_from_upload.bind(this);
this.load_preview_from_drop = this.load_preview_from_drop.bind(this);
}
handleChange(data) {
handleChange = (data) => {
this.props.handleChange(data);
}
openFileUpload() {
};
openFileUpload = () => {
this.uploader.current.click();
}
render() {
};
render = () => {
let no_action = (evt) => {
evt.preventDefault();
evt.stopPropagation();
@ -29,23 +24,25 @@ class TimeseriesInput extends BaseComponent {
if (this.props.value !== null) {
return (
<div className="input_timeseries">
<Scatter data={{
"datasets": this.state.y_indices.map((y_index, i) => {
return {
label: this.props.y[i],
borderColor: getNextColor(i),
showLine: true,
fill: true,
backgroundColor: getNextColor(i, 0.25),
data: this.props.value["data"].map((row) => {
return {
x: row[this.state.x_index],
y: row[y_index]
}
})
}
})
}} />
<Scatter
data={{
datasets: this.props.y.map((header, i) => {
return {
label: header,
borderColor: getNextColor(i),
showLine: true,
fill: true,
backgroundColor: getNextColor(i, 0.25),
data: this.props.value["data"].map((row) => {
return {
x: row[0],
y: row[i + 1]
};
})
};
})
}}
/>
</div>
);
} else {
@ -88,46 +85,18 @@ class TimeseriesInput extends BaseComponent {
</div>
);
}
}
load_preview_from_drop(evt) {
};
load_preview_from_drop = (evt) => {
this.load_preview_from_files(evt.dataTransfer.files);
}
load_preview_from_upload(evt) {
};
load_preview_from_upload = (evt) => {
this.load_preview_from_files(evt.target.files);
}
load_file(reader) {
};
load_file = (reader) => {
let lines = reader.result;
let headers = null;
let data = null;
if (lines && lines.length > 0) {
let line_array = CSVToArray(lines);
if (line_array.length === 0) {
return;
}
if (this.props.x === null) {
this.setState({ x_index: 0, y_indices: [1] });
data = line_array;
} else {
let x_index = line_array[0].indexOf(this.props.x);
let y_indices = this.props.y.map((y_col) =>
line_array[0].indexOf(y_col)
);
if (x_index === -1) {
alert("Missing x column: " + this.props.x);
return;
}
if (y_indices.includes(-1)) {
alert("Missing y column: " + this.props.y[y_indices.indexOf(-1)]);
return;
}
this.setState({ x_index: x_index, y_indices: y_indices });
headers = line_array[0];
data = line_array.slice(1);
}
this.handleChange({ headers: headers, data: data, range: null });
}
}
load_preview_from_files(files) {
this.handleChange(load_data(lines, this.props.x, this.props.y));
};
load_preview_from_files = (files) => {
if (!files.length || !window.FileReader) {
return;
}
@ -137,7 +106,7 @@ class TimeseriesInput extends BaseComponent {
ReaderObj.readAsBinaryString(file);
ReaderObj.onloadend = this.load_file.bind(this, ReaderObj);
}
}
};
static memo = (a, b) => {
if (a.value instanceof Object && b.value instanceof Object) {
return (
@ -150,10 +119,45 @@ class TimeseriesInput extends BaseComponent {
};
}
class TimeseriesInputExample extends ComponentExample {
class TimeseriesInputExample extends FileComponentExample {
static async preprocess(x, examples_dir, component_config) {
let file_url = examples_dir + "/" + x;
let response = await fetch(file_url);
response = await response.text();
return load_data(response, component_config.x, component_config.y);
}
render() {
return <div className="input_file_example">{this.props.value}</div>;
}
}
var load_data = (lines, x, y) => {
let headers = null;
let data = null;
let line_array = CSVToArray(lines);
if (line_array.length === 0) {
return;
}
if (x === null) {
data = line_array;
} else {
let x_index = line_array[0].indexOf(x);
let y_indices = y.map((y_col) => line_array[0].indexOf(y_col));
if (x_index === -1) {
alert("Missing x column: " + x);
return;
}
if (y_indices.includes(-1)) {
alert("Missing y column: " + y[y_indices.indexOf(-1)]);
return;
}
line_array = line_array.map((line) =>
[line[x_index]].concat(y_indices.map((y_index) => line[y_index]))
);
headers = line_array[0];
data = line_array.slice(1);
}
return { headers: headers, data: data, range: null };
};
export { TimeseriesInput, TimeseriesInputExample };

View File

@ -16,7 +16,7 @@ class VideoInput extends BaseComponent {
this.load_preview_from_drop = this.load_preview_from_drop.bind(this);
this.camera_stream = null;
this.state = {
recording: false,
recording: false
};
}
handleChange(evt) {
@ -30,20 +30,23 @@ class VideoInput extends BaseComponent {
this.media_recorder.stop();
let video_blob = new Blob(this.blobs_recorded, { type: this.mimeType });
var ReaderObj = new FileReader();
ReaderObj.onload = (function(e) {
ReaderObj.onload = function (e) {
let file_name = "sample." + this.mimeType.substring(6);
this.props.handleChange({
name: file_name,
data: e.target.result,
is_example: false
});
}).bind(this);
}.bind(this);
ReaderObj.readAsDataURL(video_blob);
this.setState({ recording: false });
} else {
this.blobs_recorded = [];
this.camera_stream = await navigator.mediaDevices.getUserMedia({ video: true, audio: true });
this.camera_stream = await navigator.mediaDevices.getUserMedia({
video: true,
audio: true
});
this.videoRecorder.current.srcObject = this.camera_stream;
this.videoRecorder.current.volume = 0;
let selectedMimeType = null;
@ -52,22 +55,27 @@ class VideoInput extends BaseComponent {
if (MediaRecorder.isTypeSupported(mimeType)) {
selectedMimeType = mimeType;
break;
}
}
}
if (selectedMimeType === null) {
console.error("No supported MediaRecorder mimeType");
return;
}
this.media_recorder = new MediaRecorder(this.camera_stream, {mimeType: selectedMimeType});
this.media_recorder = new MediaRecorder(this.camera_stream, {
mimeType: selectedMimeType
});
this.mimeType = selectedMimeType;
this.media_recorder.addEventListener('dataavailable', (function (e) {
this.blobs_recorded.push(e.data);
}).bind(this));
this.media_recorder.addEventListener(
"dataavailable",
function (e) {
this.blobs_recorded.push(e.data);
}.bind(this)
);
this.media_recorder.start(200);
this.videoRecorder.current.play();
this.setState({ recording: true });
}
}
};
render() {
let no_action = (evt) => {
evt.preventDefault();
@ -133,10 +141,20 @@ class VideoInput extends BaseComponent {
} else if (this.props.source == "webcam") {
return (
<div className="input_video">
<video ref={this.videoRecorder} class="video_recorder" autoPlay playsInline muted></video>
<video
ref={this.videoRecorder}
class="video_recorder"
autoPlay
playsInline
muted
></video>
<div class="record_holder">
<div class="record_message">
{this.state.recording ? <>Stop Recording</> : <>Click to Record</>}
{this.state.recording ? (
<>Stop Recording</>
) : (
<>Click to Record</>
)}
</div>
<button class="record" onClick={this.record}></button>
</div>

View File

@ -9,12 +9,12 @@ class FileOutput extends BaseComponent {
<div className="output_file">
<a
className="file_display"
href={"data:;base64," + this.props.value.data}
href={this.props.value.data}
download={this.props.value.name}
>
<div className="file_name">{this.props.value.name}</div>
<div className="file_size">
{this.props.value.size === null
{isNaN(this.props.value.size)
? ""
: prettyBytes(this.props.value.size)}
</div>

View File

@ -16,7 +16,7 @@ class HighlightedTextOutput extends BaseComponent {
}
let color = category_map[category];
if (!color) {
color = getNextColor(this.new_category_index)
color = getNextColor(this.new_category_index);
this.new_category_index++;
}
this.color_map[category] = color;

View File

@ -38,11 +38,13 @@ class LabelOutput extends BaseComponent {
<div className="confidences">{confidences}</div>
</div>
</div>
);
);
}
return (
<div className="output_label">
<div className="output_class_without_confidences">{this.props.value["label"]}</div>
<div className="output_class_without_confidences">
{this.props.value["label"]}
</div>
</div>
);
}

View File

@ -1,7 +1,7 @@
import React from "react";
import BaseComponent from "../base_component";
import ComponentExample from "../component_example";
import { Scatter } from 'react-chartjs-2';
import { Scatter } from "react-chartjs-2";
import { getNextColor } from "../../utils";
class TimeseriesOutput extends BaseComponent {
@ -16,23 +16,25 @@ class TimeseriesOutput extends BaseComponent {
let x_index = this.props.value.headers.indexOf(this.props.x);
return (
<div className="output_timeseries">
<Scatter data={{
"datasets": y_indices.map((y_index, i) => {
return {
label: this.props.value.headers[y_index],
borderColor: getNextColor(i),
showLine: true,
fill: true,
backgroundColor: getNextColor(i, 0.25),
data: this.props.value["data"].map((row) => {
return {
x: row[x_index],
y: row[y_index]
}
})
}
})
}} />
<Scatter
data={{
datasets: y_indices.map((y_index, i) => {
return {
label: this.props.value.headers[y_index],
borderColor: getNextColor(i),
showLine: true,
fill: true,
backgroundColor: getNextColor(i, 0.25),
data: this.props.value["data"].map((row) => {
return {
x: row[x_index],
y: row[y_index]
};
})
};
})
}}
/>
</div>
);
} else {

View File

@ -9,7 +9,12 @@ class VideoOutput extends BaseComponent {
if (isPlayable("video", this.props.value["name"])) {
return (
<div className="output_video">
<video controls playsInline preload src={this.props.value["data"]}></video>
<video
controls
playsInline
preload
src={this.props.value["data"]}
></video>
</div>
);
} else {

View File

@ -19,11 +19,16 @@ export class GradioPage extends React.Component {
let space_name = this.props.space;
if (is_embedded) {
let slash_index = space_name.indexOf("/");
space_name = space_name[slash_index + 1].toUpperCase() + space_name.substring(
slash_index + 2);
space_name =
space_name[slash_index + 1].toUpperCase() +
space_name.substring(slash_index + 2);
}
return (
<div class={"gradio_bg"} theme={this.props.theme} is_embedded={is_embedded.toString()}>
<div
class={"gradio_bg"}
theme={this.props.theme}
is_embedded={is_embedded.toString()}
>
<div class="gradio_page">
<div class="content">
{this.props.title ? (
@ -47,27 +52,27 @@ export class GradioPage extends React.Component {
)}
</div>
<div className="footer">
{is_embedded ?
{is_embedded ? (
<>
<a href={"https://huggingface.co/spaces/" + this.props.space}>
{space_name}
</a> built with&nbsp;
<a href="https://gradio.app">
Gradio
</a>, hosted on&nbsp;
</a>{" "}
built with&nbsp;
<a href="https://gradio.app">Gradio</a>, hosted on&nbsp;
<a href="https://huggingface.co/spaces">Hugging Face Spaces</a>.
</>
:
) : (
<>
<a href="api" target="_blank" rel="noreferrer">
view the api <img className="api-logo" src={api_logo} alt="api"/>
view the api{" "}
<img className="api-logo" src={api_logo} alt="api" />
</a>
&bull;
<a href="https://gradio.app" target="_blank" rel="noreferrer">
built with <img className="logo" src={logo} alt="logo" />
</a>
</>
}
)}
</div>
</div>
</div>
@ -88,8 +93,8 @@ export class GradioInterface extends React.Component {
this.props.root +
(this.props.examples_dir === null
? "file" +
this.props.examples_dir +
(this.props.examples_dir.endswith("/") ? "" : "/")
this.props.examples_dir +
(this.props.examples_dir.endswith("/") ? "" : "/")
: "file");
}
get_default_state = () => {
@ -124,18 +129,25 @@ export class GradioInterface extends React.Component {
return;
}
this.pending_response = true;
let input_state = [];
for (let [i, input_component] of this.props.input_components.entries()) {
if (
this.state[i] === null &&
this.props.input_components[i].optional !== true
) {
return;
if (this.state.example_id === null) {
let input_state = [];
for (let [i, input_component] of this.props.input_components.entries()) {
if (
this.state[i] === null &&
this.props.input_components[i].optional !== true
) {
return;
}
let InputComponentClass = input_component_set.find(
(c) => c.name === input_component.name
).component;
input_state[i] = InputComponentClass.postprocess(this.state[i]);
}
let InputComponentClass = input_component_set.find(
(c) => c.name === input_component.name
).component;
input_state[i] = InputComponentClass.postprocess(this.state[i]);
var data = { data: input_state };
var queue = this.props.queue;
} else {
var data = { example_id: this.state.example_id };
var queue = this.props.queue && !this.props.cached_examples;
}
this.setState({
submitting: true,
@ -144,7 +156,7 @@ export class GradioInterface extends React.Component {
flag_index: null
});
this.props
.fn(input_state, "predict", this.queueCallback)
.fn(data, "predict", queue, this.queueCallback)
.then((output) => {
if (output["error"] != null) {
console.error("Error:", output["error"]);
@ -155,8 +167,14 @@ export class GradioInterface extends React.Component {
this.pending_response = false;
let index_start = this.props.input_components.length;
let new_state = {};
new_state["last_duration"] = output["durations"][0];
new_state["avg_duration"] = output["avg_durations"][0];
if (output["durations"]) {
new_state["last_duration"] = output["durations"][0];
} else {
new_state["last_duration"] = null;
}
if (output["avg_duration"]) {
new_state["avg_duration"] = output["avg_durations"][0];
}
for (let [i, value] of output["data"].entries()) {
new_state[index_start + i] = value;
}
@ -206,7 +224,7 @@ export class GradioInterface extends React.Component {
this.setState({ just_flagged: false });
}, 1000);
component_state["flag_option"] = flag_option;
this.props.fn(component_state, "flag");
this.props.fn({ data: component_state }, "flag");
};
interpret = () => {
if (this.pending_response) {
@ -228,7 +246,12 @@ export class GradioInterface extends React.Component {
}
this.setState({ submitting: true, has_changed: false, error: false });
this.props
.fn(input_state, "interpret", this.queueCallback)
.fn(
{ data: input_state },
"interpret",
this.props.queue,
this.queueCallback
)
.then((output) => {
if (!this.pending_response) {
return;
@ -264,8 +287,12 @@ export class GradioInterface extends React.Component {
saveAs(canvas.toDataURL(), "screenshot.png");
});
};
handleChange = (_id, value) => {
let state_change = { [_id]: value, has_changed: true };
handleChange = (_id, value, example_id) => {
let state_change = {
[_id]: value,
has_changed: true,
example_id: example_id === undefined ? null : example_id
};
if (this.props.live && !this.state.submitting) {
this.setState(state_change, this.submit);
} else {
@ -281,18 +308,23 @@ export class GradioInterface extends React.Component {
let component_data = input_component_set.find(
(c) => c.name === component_name
);
var component_config = this.props.input_components[i];
ExampleComponent = component_data.example_component;
} else {
let component_name =
this.props.output_components[i - this.props.input_components.length]
.name;
let component_index = i - this.props.input_components.length;
let component_name = this.props.output_components[component_index].name;
let component_data = output_component_set.find(
(c) => c.name === component_name
);
var component_config = this.props.input_components[component_index];
ExampleComponent = component_data.example_component;
}
ExampleComponent.preprocess(item, this.examples_dir).then((data) => {
this.handleChange(i, data);
ExampleComponent.preprocess(
item,
this.examples_dir,
component_config
).then((data) => {
this.handleChange(i, data, example_id);
});
}
};
@ -378,7 +410,9 @@ export class GradioInterface extends React.Component {
>
<div
className={classNames("component_set", "relative", {
"opacity-50": (this.pending_response && !this.props.live) || this.state.error
"opacity-50":
(this.pending_response && !this.props.live) ||
this.state.error
})}
>
{status}
@ -526,14 +560,14 @@ class GradioInterfaceExamples extends React.Component {
return <th key={i}>{component.label}</th>;
})}
{this.props.examples[0].length >
this.props.input_components.length
this.props.input_components.length
? this.props.output_components.map((component, i) => {
return (
<th key={i + this.props.input_components.length}>
{component.label}
</th>
);
})
return (
<th key={i + this.props.input_components.length}>
{component.label}
</th>
);
})
: false}
</tr>
</thead>

View File

@ -19,12 +19,10 @@ let postData = async (url, body) => {
return output;
};
let fn = async (api_endpoint, queue, data, action, queue_callback) => {
let fn = async (api_endpoint, data, action, queue, queue_callback) => {
if (queue && ["predict", "interpret"].includes(action)) {
const output = await postData(api_endpoint + "queue/push/", {
data: data,
action: action
});
data["action"] = action;
const output = await postData(api_endpoint + "queue/push/", data);
const output_json = await output.json();
let [hash, queue_position] = [
output_json["hash"],
@ -53,7 +51,7 @@ let fn = async (api_endpoint, queue, data, action, queue_callback) => {
return status_obj["data"];
}
} else {
const output = await postData(api_endpoint + action + "/", { data: data });
const output = await postData(api_endpoint + action + "/", data);
return await output.json();
}
};
@ -62,9 +60,7 @@ window.launchGradio = (config, element_query, space) => {
let target = document.querySelector(element_query);
target.classList.add("gradio_app");
if (config.auth_required) {
ReactDOM.render(
<Login {...config} />, target
);
ReactDOM.render(<Login {...config} />, target);
} else {
if (config.css !== null) {
var head = document.head || document.getElementsByTagName("head")[0],
@ -86,21 +82,21 @@ window.launchGradio = (config, element_query, space) => {
<GradioPage
{...config}
space={space}
fn={fn.bind(null, config.root + "api/", config.queue)}
fn={fn.bind(null, config.root + "api/")}
/>,
target
);
}
}
};
window.launchGradioFromSpaces = async (space, target) => {
const space_url = `https://huggingface.co/gradioiframe/${space}/+/`
const space_url = `https://huggingface.co/gradioiframe/${space}/+/`;
let config = await fetch(space_url + "config");
config = await config.json();
delete config.css;
config.root = space_url;
launchGradio(config, target, space);
}
};
async function get_config() {
if (process.env.REACT_APP_BACKEND_URL) {
@ -113,7 +109,7 @@ async function get_config() {
}
}
if (window.gradio_mode == "app") {
get_config().then(config => {
get_config().then((config) => {
launchGradio(config, "#root");
});
}

View File

@ -1,7 +1,10 @@
if (window.gradio_mode === "app") {
__webpack_public_path__ = "";
__webpack_public_path__ = "";
} else if (window.gradio_mode === "website") {
__webpack_public_path__ = "/gradio_static/"
__webpack_public_path__ = "/gradio_static/";
} else {
__webpack_public_path__ = "https://gradio.s3-us-west-2.amazonaws.com/" + process.env.REACT_APP_VERSION + "/";
__webpack_public_path__ =
"https://gradio.s3-us-west-2.amazonaws.com/" +
process.env.REACT_APP_VERSION +
"/";
}

View File

@ -171,7 +171,7 @@
@apply bg-gray-100 hover:bg-gray-200 p-2;
}
.examples_table:not(.gallery) {
@apply table-auto p-2 bg-gray-100 mt-4 rounded;
@apply table-auto p-2 bg-gray-100 mt-4 rounded max-w-full;
tbody tr {
@apply cursor-pointer;
}

View File

@ -11,7 +11,20 @@ html {
@apply absolute dark:text-gray-50 right-2 flex items-center gap-2 text-xs;
}
.load_status img {
@apply h-5;
@apply h-5 ml-2 inline-block;
}
.load_status .loading {
@keyframes ld-breath{
0%{
animation-timing-function:cubic-bezier(0.9647,0.2413,-0.0705,0.7911);
transform:scale(0.9)
}
51%{animation-timing-function:cubic-bezier(0.9226,0.2631,-0.0308,0.7628);
transform:scale(1.2)
}
100%{transform:scale(0.9)}
}
animation:ld-breath 0.75s infinite linear;
}
.panels {
@apply flex flex-wrap justify-center gap-4;
@ -71,6 +84,7 @@ html {
@apply hidden flex hidden flex-grow;
}
.examples {
@apply max-w-full overflow-x-auto;
h4 {
@apply text-lg font-semibold my-2;
}

View File

@ -7,7 +7,20 @@
@apply text-gray-700 dark:text-gray-50 absolute right-2 flex items-center gap-2 text-sm;
}
.load_status img {
@apply h-5;
@apply h-5 ml-2 inline-block;
}
.load_status .loading {
@keyframes ld-breath{
0%{
animation-timing-function:cubic-bezier(0.9647,0.2413,-0.0705,0.7911);
transform:scale(0.9)
}
51%{animation-timing-function:cubic-bezier(0.9226,0.2631,-0.0308,0.7628);
transform:scale(1.2)
}
100%{transform:scale(0.9)}
}
animation:ld-breath 0.75s infinite linear;
}
.panels {
@apply flex flex-wrap justify-center gap-4;
@ -58,6 +71,7 @@
@apply hidden flex hidden flex-grow;
}
.examples {
@apply max-w-full overflow-x-auto;
h4 {
@apply text-lg font-semibold my-2;
}
@ -194,7 +208,7 @@
}
}
.input_image_example {
@apply h-24;
@apply h-24 max-w-none;
}
.input_radio {
@apply flex flex-wrap gap-2;
@ -416,7 +430,7 @@
@apply h-36 object-contain flex justify-center;
}
.video_preview {
@apply w-full;
@apply max-w-none;
}
}
.input_file {

View File

@ -10,6 +10,19 @@
.load_status img {
@apply h-5 ml-2 inline-block;
}
.load_status .loading {
@keyframes ld-breath{
0%{
animation-timing-function:cubic-bezier(0.9647,0.2413,-0.0705,0.7911);
transform:scale(0.9)
}
51%{animation-timing-function:cubic-bezier(0.9226,0.2631,-0.0308,0.7628);
transform:scale(1.2)
}
100%{transform:scale(0.9)}
}
animation:ld-breath 0.75s infinite linear;
}
.panels {
@apply flex flex-wrap justify-center gap-4;
}
@ -71,6 +84,7 @@
@apply rounded-tl-none rounded-bl-none hover:bg-gray-200;
}
.examples {
@apply max-w-full overflow-x-auto;
h4 {
@apply text-lg font-semibold my-2;
}

View File

@ -131,12 +131,12 @@ export function CSVToArray(strData, strDelimiter) {
strDelimiter = strDelimiter || ",";
let objPattern = new RegExp(
"(\\" +
strDelimiter +
"|\\r?\\n|\\r|^)" +
'(?:"([^"]*(?:""[^"]*)*)"|' +
'([^"\\' +
strDelimiter +
"\\r\\n]*))",
strDelimiter +
"|\\r?\\n|\\r|^)" +
'(?:"([^"]*(?:""[^"]*)*)"|' +
'([^"\\' +
strDelimiter +
"\\r\\n]*))",
"gi"
);
let arrData = [[]];
@ -189,9 +189,10 @@ export function getNextColor(index, alpha) {
if (index < default_colors.length) {
var color_set = default_colors[index];
} else {
var color_set = [randInt(128, 240), randInt(128, 240), randInt(128, 240)]
var color_set = [randInt(128, 240), randInt(128, 240), randInt(128, 240)];
}
return "rgba(" +
return (
"rgba(" +
color_set[0] +
", " +
color_set[1] +
@ -199,5 +200,6 @@ export function getNextColor(index, alpha) {
color_set[2] +
", " +
alpha +
")";
}
")"
);
}

View File

@ -39,7 +39,7 @@ class Component():
"""
return data
def restore_flagged(self, data):
def restore_flagged(self, dir, data, encryption_key):
"""
Restores flagged data from logs
"""
@ -65,6 +65,10 @@ class Component():
shutil.move(old_file_name, os.path.join(dir, label, new_file_name))
return label + "/" + new_file_name
def restore_flagged_file(self, dir, file, encryption_key):
data = processing_utils.encode_file_to_base64(os.path.join(dir, file), encryption_key=encryption_key)
return {"name": file, "data": data}
@classmethod
def get_all_shortcut_implementations(cls):
shortcuts = {}

View File

@ -85,31 +85,37 @@ class CSVLogger(FlaggingCallback):
log_fp = "{}/log.csv".format(flagging_dir)
encryption_key = interface.encryption_key if interface.encrypt else None
is_new = not os.path.exists(log_fp)
output_only_mode = input_data is None
if flag_index is None:
csv_data = []
for i, input in enumerate(interface.input_components):
csv_data.append(input.save_flagged(
flagging_dir, interface.config["input_components"][i]["label"], input_data[i], encryption_key))
if not output_only_mode:
for i, input in enumerate(interface.input_components):
csv_data.append(input.save_flagged(
flagging_dir, interface.config["input_components"][i]["label"], input_data[i], encryption_key))
for i, output in enumerate(interface.output_components):
csv_data.append(output.save_flagged(
flagging_dir, interface.config["output_components"][i]["label"], output_data[i], encryption_key) if
output_data[i] is not None else "")
if flag_option is not None:
csv_data.append(flag_option)
if username is not None:
csv_data.append(username)
csv_data.append(str(datetime.datetime.now()))
if not output_only_mode:
if flag_option is not None:
csv_data.append(flag_option)
if username is not None:
csv_data.append(username)
csv_data.append(str(datetime.datetime.now()))
if is_new:
headers = [interface["label"]
for interface in interface.config["input_components"]]
headers = []
if not output_only_mode:
headers += [interface["label"]
for interface in interface.config["input_components"]]
headers += [interface["label"]
for interface in interface.config["output_components"]]
if interface.flagging_options is not None:
headers.append("flag")
if username is not None:
headers.append("username")
headers.append("timestamp")
if not output_only_mode:
if interface.flagging_options is not None:
headers.append("flag")
if username is not None:
headers.append("username")
headers.append("timestamp")
def replace_flag_at_index(file_content):
file_content = io.StringIO(file_content)

View File

@ -14,8 +14,7 @@ import pandas as pd
from ffmpy import FFmpeg
import math
import tempfile
from pathlib import Path
import csv
class InputComponent(Component):
"""
@ -526,7 +525,7 @@ class CheckboxGroup(InputComponent):
"""
return json.dumps(data)
def restore_flagged(self, data):
def restore_flagged(self, dir, data, encryption_key):
return json.loads(data)
def generate_sample(self):
@ -902,6 +901,9 @@ class Video(InputComponent):
"optional": self.optional,
**super().get_template_context()
}
def preprocess_example(self, x):
return {"name": x, "data": None, "is_example": True}
def preprocess(self, x):
"""
@ -935,9 +937,6 @@ class Video(InputComponent):
def serialize(self, x, called_directly):
raise NotImplementedError()
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x, type="video")
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to video file
@ -986,6 +985,9 @@ class Audio(InputComponent):
"mic": {"source": "microphone"}
}
def preprocess_example(self, x):
return {"name": x, "data": None, "is_example": True}
def preprocess(self, x):
"""
Parameters:
@ -1017,9 +1019,6 @@ class Audio(InputComponent):
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'filepath'.")
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x, type="audio")
def serialize(self, x, called_directly):
if x is None:
return None
@ -1037,7 +1036,7 @@ class Audio(InputComponent):
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'filepath'.")
file_data = processing_utils.encode_url_or_file_to_base64(name, type="audio")
file_data = processing_utils.encode_url_or_file_to_base64(name)
return {"name": name, "data": file_data, "is_example": False}
def set_interpret_parameters(self, segments=8):
@ -1067,8 +1066,7 @@ class Audio(InputComponent):
leave_one_out_data[start:stop] = 0
file = tempfile.NamedTemporaryFile(delete=False)
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
out_data = processing_utils.encode_file_to_base64(
file.name, type="audio", ext="wav")
out_data = processing_utils.encode_file_to_base64(file.name)
leave_one_out_sets.append(out_data)
# Handle the tokens
token = np.copy(data)
@ -1076,8 +1074,7 @@ class Audio(InputComponent):
token[stop:] = 0
file = tempfile.NamedTemporaryFile(delete=False)
processing_utils.audio_to_file(sample_rate, token, file.name)
token_data = processing_utils.encode_file_to_base64(
file.name, type="audio", ext="wav")
token_data = processing_utils.encode_file_to_base64(file.name)
tokens.append(token_data)
return tokens, leave_one_out_sets, masks
@ -1101,8 +1098,7 @@ class Audio(InputComponent):
masked_input = masked_input + t*int(b)
file = tempfile.NamedTemporaryFile(delete=False)
processing_utils.audio_to_file(sample_rate, masked_input, file_obj.name)
masked_data = processing_utils.encode_file_to_base64(
file.name, type="audio", ext="wav")
masked_data = processing_utils.encode_file_to_base64(file.name)
masked_inputs.append(masked_data)
return masked_inputs
@ -1160,6 +1156,9 @@ class File(InputComponent):
"files": {"file_count": "multiple"},
}
def preprocess_example(self, x):
return {"name": x, "data": None, "is_example": True}
def preprocess(self, x):
"""
Parameters:
@ -1287,7 +1286,7 @@ class Dataframe(InputComponent):
"""
return json.dumps(data)
def restore_flagged(self, data):
def restore_flagged(self, dir, data, encryption_key):
return json.loads(data)
def generate_sample(self):
@ -1330,6 +1329,9 @@ class Timeseries(InputComponent):
"timeseries": {},
}
def preprocess_example(self, x):
return {"name": x, "is_example": True}
def preprocess(self, x):
"""
Parameters:
@ -1339,7 +1341,10 @@ class Timeseries(InputComponent):
"""
if x is None:
return x
dataframe = pd.DataFrame(data=x["data"], columns=x["headers"])
elif x.get("is_example"):
dataframe = pd.read_csv(x["name"])
else:
dataframe = pd.DataFrame(data=x["data"], columns=x["headers"])
if x.get("range") is not None:
dataframe = dataframe.loc[dataframe[self.x or 0] >= x["range"][0]]
dataframe = dataframe.loc[dataframe[self.x or 0] <= x["range"][1]]
@ -1351,7 +1356,7 @@ class Timeseries(InputComponent):
"""
return json.dumps(data)
def restore_flagged(self, data):
def restore_flagged(self, dir, data, encryption_key):
return json.loads(data)
def generate_sample(self):

View File

@ -22,6 +22,7 @@ from gradio.external import load_interface, load_from_pipeline # type: ignore
from gradio.flagging import FlaggingCallback, CSVLogger # type: ignore
from gradio.inputs import get_input_instance, InputComponent # type: ignore
from gradio.outputs import get_output_instance, OutputComponent # type: ignore
from gradio.process_examples import cache_interface_examples
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import transformers
@ -374,6 +375,15 @@ class Interface:
processed_input, return_duration=True)
processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
for i, output_component in enumerate(self.output_components)]
avg_durations = []
for i, duration in enumerate(durations):
self.predict_durations[i][0] += duration
self.predict_durations[i][1] += 1
avg_durations.append(self.predict_durations[i][0]
/ self.predict_durations[i][1])
self.config["avg_durations"] = avg_durations
return processed_output, durations
def interpret(
@ -431,7 +441,8 @@ class Interface:
enable_queue: bool = False,
height: int = 500,
width: int = 900,
encrypt: bool = False
encrypt: bool = False,
cache_examples: bool = False
) -> Tuple[flask.Flask, str, str]:
"""
Launches the webserver that serves the UI for the interface.
@ -452,14 +463,14 @@ class Interface:
width (int): The width in pixels of the <iframe> element containing the interface (used if inline=True)
height (int): The height in pixels of the <iframe> element containing the interface (used if inline=True)
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
cache_examples (bool): If True, examples outputs will be processed and cached in a folder, and will be used if a user uses an example input.
Returns:
app (flask.Flask): Flask app object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
# Set up local flask server
config = self.get_config_file()
self.config = config
self.cache_examples = cache_examples
if auth and not callable(auth) and not isinstance(auth[0], tuple) and not isinstance(auth[0], list):
auth = [auth]
self.auth = auth
@ -470,20 +481,20 @@ class Interface:
self.width = self.width or width # if width is not set in constructor, use the one provided here
if self.encrypt is None:
self.encrypt = encrypt # if encrypt is not set in constructor, use the one provided here
# Request key for encryption
if self.encrypt:
self.encryption_key = encryptor.get_key(
getpass.getpass("Enter key for encryption: "))
# Store parameters
if self.enable_queue is None:
self.enable_queue = enable_queue
# Setup flagging
if self.allow_flagging:
self.flagging_callback.setup(self.flagging_dir)
config = self.get_config_file()
self.config = config
if self.cache_examples:
cache_interface_examples(self)
# Launch local flask server
server_port, path_to_local_server, app, thread, server = networking.start_server(
self, server_name, server_port, self.auth)

View File

@ -2,8 +2,6 @@
Defines helper methods useful for setting up ports, launching servers, and handling `ngrok`
"""
import csv
import datetime
from flask import Flask, request, session, jsonify, abort, send_file, render_template, redirect
from flask_cachebuster import CacheBuster
from flask_login import LoginManager, login_user, current_user, login_required
@ -28,6 +26,7 @@ from werkzeug.serving import make_server
from gradio import encryptor, queue
from gradio.tunneling import create_tunnel
from gradio.process_examples import load_from_cache, process_example
# By default, the http server will try to open on port 7860. If not available, 7861, 7862, etc.
INITIAL_PORT_VALUE = int(os.getenv('GRADIO_SERVER_PORT', "7860"))
@ -189,35 +188,37 @@ def shutdown():
@app.route("/api/predict/", methods=["POST"])
@login_check
def predict():
raw_input = request.json["data"]
# Capture any errors made and pipe to front end
if app.interface.show_error:
try:
prediction, durations = app.interface.process(raw_input)
except BaseException as error:
traceback.print_exc()
return jsonify({"error": str(error)}), 500
request_data = request.get_json()
flag_index = None
if request_data.get("example_id") != None:
example_id = request_data["example_id"]
if app.interface.cache_examples:
prediction = load_from_cache(app.interface, example_id)
durations = None
else:
prediction, durations = process_example(app.interface, example_id)
else:
prediction, durations = app.interface.process(raw_input)
avg_durations = []
for i, duration in enumerate(durations):
app.interface.predict_durations[i][0] += duration
app.interface.predict_durations[i][1] += 1
avg_durations.append(app.interface.predict_durations[i][0]
/ app.interface.predict_durations[i][1])
app.interface.config["avg_durations"] = avg_durations
output = {"data": prediction, "durations": durations, "avg_durations": avg_durations}
if app.interface.allow_flagging == "auto":
try:
flag_index = app.interface.flagging_handler.flag(raw_input, prediction,
raw_input = request_data["data"]
if app.interface.show_error:
try:
prediction, durations = app.interface.process(raw_input)
except BaseException as error:
traceback.print_exc()
return jsonify({"error": str(error)}), 500
else:
prediction, durations = app.interface.process(raw_input)
if app.interface.allow_flagging == "auto":
flag_index = app.interface.flagging_callback.flag(app.interface, raw_input, prediction,
flag_option=(None if app.interface.flagging_options is None else ""),
username=current_user.id if current_user.is_authenticated else None,
flag_path=os.path.join(app.cwd, app.interface.flagging_dir))
output["flag_index"] = flag_index
except Exception as e:
print(str(e))
pass
return jsonify(output)
output = {
"data": prediction,
"durations": durations,
"avg_durations": app.interface.config.get("avg_durations"),
"flag_index": flag_index
}
return output
def get_types(cls_set, component):
@ -321,9 +322,9 @@ def file(path):
@app.route("/api/queue/push/", methods=["POST"])
@login_check
def queue_push():
data = request.json["data"]
action = request.json["action"]
job_hash, queue_position = queue.push({"data": data}, action)
data = request.get_json()
action = data["action"]
job_hash, queue_position = queue.push(data, action)
return {"hash": job_hash, "queue_position": queue_position}

View File

@ -20,6 +20,7 @@ from types import ModuleType
from ffmpy import FFmpeg
import requests
class OutputComponent(Component):
"""
Output Component. All output components subclass this.
@ -142,11 +143,10 @@ class Label(OutputComponent):
return y['label']
elif self.type == "confidences" or self.type == "auto":
if ('confidences' in y.keys()) and isinstance(y['confidences'], list):
return {k['label']:k['confidence'] for k in y['confidences']}
return {k['label']: k['confidence'] for k in y['confidences']}
else:
return y
raise ValueError("Unable to deserialize output: {}".format(y))
@classmethod
def get_shortcut_implementations(cls):
@ -163,11 +163,11 @@ class Label(OutputComponent):
else:
return data["label"]
def restore_flagged(self, data):
def restore_flagged(self, dir, data, encryption_key):
try:
data = json.loads(data)
return data
except:
return self.postprocess(data)
except ValueError:
return data
@ -240,11 +240,11 @@ class Image(OutputComponent):
return y
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to image file
"""
return self.save_flagged_file(dir, label, data, encryption_key)
def restore_flagged(self, dir, data, encryption_key):
return self.restore_flagged_file(dir, data, encryption_key)["data"]
class Video(OutputComponent):
'''
@ -279,7 +279,7 @@ class Video(OutputComponent):
returned_format = y.split(".")[-1].lower()
if self.type is not None and returned_format != self.type:
output_file_name = y[0: y.rindex(
".") + 1] + self.type
".") + 1] + self.type
ff = FFmpeg(
inputs={y: None},
outputs={output_file_name: None}
@ -288,18 +288,18 @@ class Video(OutputComponent):
y = output_file_name
return {
"name": os.path.basename(y),
"data": processing_utils.encode_file_to_base64(y, type="video")
"data": processing_utils.encode_file_to_base64(y)
}
def deserialize(self, x):
return processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to image file
"""
return self.save_flagged_file(dir, label, data['data'], encryption_key)
def restore_flagged(self, dir, data, encryption_key):
return self.restore_flagged_file(dir, data, encryption_key)
class KeyValues(OutputComponent):
'''
@ -339,7 +339,7 @@ class KeyValues(OutputComponent):
def save_flagged(self, dir, label, data, encryption_key):
return json.dumps(data)
def restore_flagged(self, data):
def restore_flagged(self, dir, data, encryption_key):
return json.loads(data)
@ -385,7 +385,7 @@ class HighlightedText(OutputComponent):
def save_flagged(self, dir, label, data, encryption_key):
return json.dumps(data)
def restore_flagged(self, data):
def restore_flagged(self, dir, data, encryption_key):
return json.loads(data)
@ -426,10 +426,11 @@ class Audio(OutputComponent):
if self.type in ["numpy", "file", "auto"]:
if self.type == "numpy" or (self.type == "auto" and isinstance(y, tuple)):
sample_rate, data = y
file = tempfile.NamedTemporaryFile(prefix="sample", suffix=".wav", delete=False)
file = tempfile.NamedTemporaryFile(
prefix="sample", suffix=".wav", delete=False)
processing_utils.audio_to_file(sample_rate, data, file.name)
y = file.name
return processing_utils.encode_url_or_file_to_base64(y, type="audio", ext="wav")
return processing_utils.encode_url_or_file_to_base64(y)
else:
raise ValueError("Unknown type: " + self.type +
". Please choose from: 'numpy', 'file'.")
@ -438,11 +439,11 @@ class Audio(OutputComponent):
return processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to audio file
"""
return self.save_flagged_file(dir, label, data, encryption_key)
def restore_flagged(self, dir, data, encryption_key):
return self.restore_flagged_file(dir, data, encryption_key)["data"]
class JSON(OutputComponent):
'''
@ -479,7 +480,7 @@ class JSON(OutputComponent):
def save_flagged(self, dir, label, data, encryption_key):
return json.dumps(data)
def restore_flagged(self, data):
def restore_flagged(self, dir, data, encryption_key):
return json.loads(data)
@ -543,15 +544,15 @@ class File(OutputComponent):
return {
"name": os.path.basename(y),
"size": os.path.getsize(y),
"data": processing_utils.encode_file_to_base64(y, header=False)
"data": processing_utils.encode_file_to_base64(y)
}
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to image file
"""
return self.save_flagged_file(dir, label, data["data"], encryption_key)
def restore_flagged(self, dir, data, encryption_key):
return self.restore_flagged_file(dir, data, encryption_key)
class Dataframe(OutputComponent):
"""
@ -624,13 +625,10 @@ class Dataframe(OutputComponent):
". Please choose from: 'pandas', 'numpy', 'array'.")
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (List[List[Union[str, float]]]) 2D array
"""
return json.dumps(data["data"])
def restore_flagged(self, data):
return json.loads(data)
def restore_flagged(self, dir, data, encryption_key):
return {"data": json.loads(data)}
class Carousel(OutputComponent):
@ -685,14 +683,22 @@ class Carousel(OutputComponent):
component.save_flagged(
dir, f"{label}_{j}", data[i][j], encryption_key)
for j, component in enumerate(self.components)
] for i, sample in enumerate(data)])
] for i, _ in enumerate(data)])
def restore_flagged(self, dir, data, encryption_key):
return [
[
component.restore_flagged(dir, sample, encryption_key)
for component, sample in zip(self.components, sample_set)
] for sample_set in json.loads(data)]
def get_output_instance(iface):
if isinstance(iface, str):
shortcut = OutputComponent.get_all_shortcut_implementations()[iface]
return shortcut[0](**shortcut[1])
elif isinstance(iface, dict): # a dict with `name` as the output component type and other keys as parameters
# a dict with `name` as the output component type and other keys as parameters
elif isinstance(iface, dict):
name = iface.pop('name')
for component in OutputComponent.__subclasses__():
if component.__name__.lower() == name:
@ -761,5 +767,5 @@ class Timeseries(OutputComponent):
"""
return json.dumps(data)
def restore_flagged(self, data):
def restore_flagged(self, dir, data, encryption_key):
return json.loads(data)

View File

@ -0,0 +1,38 @@
import os, shutil
from gradio.flagging import CSVLogger
from typing import Any, List
import csv
CACHED_FOLDER = "gradio_cached_examples"
CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv")
def process_example(interface, example_id: int):
example_set = interface.examples[example_id]
raw_input = [interface.input_components[i].preprocess_example(example) for i, example in enumerate(example_set)]
prediction, durations = interface.process(raw_input)
return prediction, durations
def cache_interface_examples(interface) -> None:
if os.path.exists(CACHE_FILE):
print(f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache.")
else:
print(f"Cache at {os.path.abspath(CACHE_FILE)} not found. Caching now in '{CACHED_FOLDER}/' directory.")
cache_logger = CSVLogger()
cache_logger.setup(CACHED_FOLDER)
for example_id, _ in enumerate(interface.examples):
try:
prediction = process_example(interface, example_id)[0]
cache_logger.flag(interface, None, prediction)
except Exception as e:
shutil.rmtree(CACHED_FOLDER)
raise e
def load_from_cache(interface, example_id: int) -> List[Any]:
with open(CACHE_FILE) as cache:
examples = list(csv.reader(cache))
example = examples[example_id + 1] # +1 to adjust for header
output = []
for component, cell in zip(interface.output_components, example):
output.append(component.restore_flagged(
CACHED_FOLDER, cell, interface.encryption_key if interface.encrypt else None))
return output

View File

@ -8,6 +8,7 @@ import os
import numpy as np
from gradio import encryptor
import warnings
import mimetypes
with warnings.catch_warnings():
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
from pydub import AudioSegment
@ -29,33 +30,40 @@ def get_url_or_file_as_bytes(path):
return f.read()
def encode_url_or_file_to_base64(path, type="image", ext=None, header=True):
def encode_url_or_file_to_base64(path):
try:
requests.get(path)
return encode_url_to_base64(path, type, ext, header)
return encode_url_to_base64(path)
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
return encode_file_to_base64(path, type, ext, header)
return encode_file_to_base64(path)
def get_mimetype(filename):
mimetype = mimetypes.guess_type(filename)[0]
if mimetype is not None:
mimetype = mimetype.replace("x-wav", "wav")
return mimetype
def encode_file_to_base64(f, type="image", ext=None, header=True):
def get_extension(encoding):
encoding = encoding.replace("audio/wav", "audio/x-wav")
extension = mimetypes.guess_extension(mimetypes.guess_type(
encoding)[0])
return extension
def encode_file_to_base64(f, encryption_key=None):
with open(f, "rb") as file:
encoded_string = base64.b64encode(file.read())
if encryption_key:
encoded_string = encryptor.decrypt(encryption_key, encoded_string)
base64_str = str(encoded_string, 'utf-8')
if not header:
return base64_str
if ext is None:
ext = f.split(".")[-1]
return "data:" + type + "/" + ext + ";base64," + base64_str
mimetype = get_mimetype(f)
return "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
def encode_url_to_base64(url, type="image", ext=None, header=True):
def encode_url_to_base64(url):
encoded_string = base64.b64encode(requests.get(url).content)
base64_str = str(encoded_string, 'utf-8')
if not header:
return base64_str
if ext is None:
ext = url.split(".")[-1]
return "data:" + type + "/" + ext + ";base64," + base64_str
mimetype = get_mimetype(url)
return "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
def encode_plot_to_base64(plt):
@ -121,30 +129,21 @@ def audio_to_file(sample_rate, data, filename):
# OUTPUT
##################
def decode_base64_to_binary(encoding):
extension = None
if "," in encoding:
header, data = encoding.split(",")
header = header[5:]
if ";base64" in header:
header = header[0:header.index(";base64")]
if "/" in header:
extension = header[header.index("/") + 1:]
else:
data = encoding
return base64.b64decode(data), extension
extension = get_extension(encoding)
data = encoding.split(",")[1]
return base64.b64decode(data), extension
def decode_base64_to_file(encoding, encryption_key=None, file_path=None):
data, mime_extension = decode_base64_to_binary(encoding)
prefix, extension = None, None
data, extension = decode_base64_to_binary(encoding)
prefix = None
if file_path is not None:
filename = os.path.basename(file_path)
prefix = filename
if "." in filename:
prefix = filename[0: filename.index(".")]
extension = filename[filename.index(".") + 1:]
if extension is None:
extension = mime_extension
if extension is None:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix)
else:

View File

@ -1,6 +1,6 @@
{
"files": {
"main.css": "/static/css/main.bd8d8c8b.css",
"main.css": "/static/css/main.8fe32992.css",
"main.js": "/static/bundle.js",
"index.html": "/index.html",
"static/media/api-logo.b3893a52.svg": "/static/media/api-logo.b3893a52.svg",
@ -13,7 +13,7 @@
},
"entrypoints": [
"static/bundle.css",
"static/css/main.bd8d8c8b.css",
"static/css/main.8fe32992.css",
"static/bundle.js"
]
}

View File

@ -1,11 +1 @@
<!doctype html><html lang="en" style="height:100%;margin:0;padding:0"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,shrink-to-fit=no"><title>{{ config['title'] or 'Gradio' }}</title><meta property="og:url" content="https://gradio.app/"/><meta property="og:type" content="website"/><meta property="og:image" content="{{ config['thumbnail'] or '' }}"/><meta property="og:title" content="{{ config['title'] or '' }}"/><meta property="og:description" content="{{ config['description'] or '' }}"/><meta name="twitter:card" content="summary_large_image"><meta name="twitter:creator" content="@teamGradio"><meta name="twitter:title" content="{{ config['title'] or '' }}"><meta name="twitter:description" content="{{ config['description'] or '' }}"><meta name="twitter:image" content="{{ config['thumbnail'] or '' }}"><script async src="https://www.googletagmanager.com/gtag/js?id=UA-156449732-1"></script><script>window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'UA-156449732-1');
window.gradio_mode = "app";
try {
window.gradio_config = {{ config|tojson }};
} catch (e) {
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.bd8d8c8b.css" rel="stylesheet"></head><body style="height:100%;margin:0;padding:0"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
<!doctype html><html lang="en" style="height:100%;margin:0;padding:0"><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1,shrink-to-fit=no"><title>{{ config['title'] or 'Gradio' }}</title><meta property="og:url" content="https://gradio.app/"/><meta property="og:type" content="website"/><meta property="og:image" content="{{ config['thumbnail'] or '' }}"/><meta property="og:title" content="{{ config['title'] or '' }}"/><meta property="og:description" content="{{ config['description'] or '' }}"/><meta name="twitter:card" content="summary_large_image"><meta name="twitter:creator" content="@teamGradio"><meta name="twitter:title" content="{{ config['title'] or '' }}"><meta name="twitter:description" content="{{ config['description'] or '' }}"><meta name="twitter:image" content="{{ config['thumbnail'] or '' }}"><script async src="https://www.googletagmanager.com/gtag/js?id=UA-156449732-1"></script><script>function gtag(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],gtag("js",new Date),gtag("config","UA-156449732-1"),window.gradio_mode="app"</script><script>window.gradio_config = {{ config|tojson }};</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.8fe32992.css" rel="stylesheet"></head><body style="height:100%;margin:0;padding:0"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>

View File

@ -167,6 +167,7 @@ def get_config_file(interface):
"flagging_options": interface.flagging_options,
"allow_interpretation": interface.interpretation is not None,
"queue": interface.enable_queue,
"cached_examples": interface.cache_examples,
"version": pkg_resources.require("gradio")[0].version
}
try:
@ -208,7 +209,8 @@ def get_config_file(interface):
examples = examples[1:] # remove header
for i, example in enumerate(examples):
for j, (component, cell) in enumerate(zip(interface.input_components + interface.output_components, example)):
examples[i][j] = component.restore_flagged(cell)
examples[i][j] = component.restore_flagged(
interface, interface.flagging_dir, cell, interface.encryption_key if interface.encrypt else None)
config["examples"] = examples
config["examples_dir"] = interface.examples
else: