mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
Merge branch 'master' into aliabd/components-tests
This commit is contained in:
commit
7ce37829e0
@ -35,7 +35,7 @@ module.exports = {
|
|||||||
};
|
};
|
||||||
paths.appBuild = webpackConfig.output.path;
|
paths.appBuild = webpackConfig.output.path;
|
||||||
return webpackConfig;
|
return webpackConfig;
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
style: {
|
style: {
|
||||||
postcss: {
|
postcss: {
|
||||||
|
26719
frontend/package-lock.json
generated
26719
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -33,9 +33,9 @@
|
|||||||
"webpack": "^4.44.2"
|
"webpack": "^4.44.2"
|
||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "REACT_APP_BACKEND_URL='http://localhost:7860/' craco start",
|
"start": "cross-env REACT_APP_BACKEND_URL='http://localhost:7860/' craco start",
|
||||||
"format": "prettier-eslint --write '**/*.js*'",
|
"format": "prettier-eslint --write '**/*.js*'",
|
||||||
"build": "REACT_APP_BACKEND_URL='' GENERATE_SOURCEMAP=false craco build",
|
"build": "cross-env REACT_APP_BACKEND_URL='' GENERATE_SOURCEMAP=false craco build",
|
||||||
"eject": "react-scripts eject"
|
"eject": "react-scripts eject"
|
||||||
},
|
},
|
||||||
"eslintConfig": {
|
"eslintConfig": {
|
||||||
@ -57,6 +57,7 @@
|
|||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"autoprefixer": "^9.8.6",
|
"autoprefixer": "^9.8.6",
|
||||||
|
"cross-env": "^7.0.3",
|
||||||
"eslint": "^7.32.0",
|
"eslint": "^7.32.0",
|
||||||
"mini-css-extract-plugin": "^0.11.3",
|
"mini-css-extract-plugin": "^0.11.3",
|
||||||
"postcss": "^7.0.36",
|
"postcss": "^7.0.36",
|
||||||
|
@ -77,16 +77,16 @@ class AudioInput extends BaseComponent {
|
|||||||
data: this.props.value["data"],
|
data: this.props.value["data"],
|
||||||
is_example: this.props.value["is_example"],
|
is_example: this.props.value["is_example"],
|
||||||
crop_min: crop_min,
|
crop_min: crop_min,
|
||||||
crop_max: crop_max,
|
crop_max: crop_max
|
||||||
});
|
});
|
||||||
this.setState({ editorMode: !this.state.editorMode })
|
this.setState({ editorMode: !this.state.editorMode });
|
||||||
}
|
};
|
||||||
crop = (min, max, lastChange) => {
|
crop = (min, max, lastChange) => {
|
||||||
if (this.state.duration) {
|
if (this.state.duration) {
|
||||||
if (lastChange === "min") {
|
if (lastChange === "min") {
|
||||||
this.audioRef.current.currentTime = (min / 100.) * this.state.duration;
|
this.audioRef.current.currentTime = (min / 100) * this.state.duration;
|
||||||
} else {
|
} else {
|
||||||
this.audioRef.current.currentTime = (max / 100.) * this.state.duration;
|
this.audioRef.current.currentTime = (max / 100) * this.state.duration;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this.props.handleChange({
|
this.props.handleChange({
|
||||||
@ -94,20 +94,24 @@ class AudioInput extends BaseComponent {
|
|||||||
data: this.props.value["data"],
|
data: this.props.value["data"],
|
||||||
is_example: this.props.value["is_example"],
|
is_example: this.props.value["is_example"],
|
||||||
crop_min: min,
|
crop_min: min,
|
||||||
crop_max: max,
|
crop_max: max
|
||||||
})
|
});
|
||||||
}
|
};
|
||||||
reset_playback_within_crop = () => {
|
reset_playback_within_crop = () => {
|
||||||
let position_ratio = this.audioRef.current.currentTime / this.state.duration;
|
let position_ratio =
|
||||||
|
this.audioRef.current.currentTime / this.state.duration;
|
||||||
let min_ratio = this.props.value.crop_min / 100;
|
let min_ratio = this.props.value.crop_min / 100;
|
||||||
let max_ratio = this.props.value.crop_max / 100;
|
let max_ratio = this.props.value.crop_max / 100;
|
||||||
if ((position_ratio > max_ratio - 0.00001) || (position_ratio < min_ratio - 0.00001)) {
|
if (
|
||||||
|
position_ratio > max_ratio - 0.00001 ||
|
||||||
|
position_ratio < min_ratio - 0.00001
|
||||||
|
) {
|
||||||
this.audioRef.current.currentTime = this.state.duration * min_ratio;
|
this.audioRef.current.currentTime = this.state.duration * min_ratio;
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
render() {
|
render() {
|
||||||
if (this.props.value !== null) {
|
if (this.props.value !== null) {
|
||||||
if (
|
if (
|
||||||
@ -124,17 +128,27 @@ class AudioInput extends BaseComponent {
|
|||||||
<div className="input_audio">
|
<div className="input_audio">
|
||||||
<div className="edit_buttons">
|
<div className="edit_buttons">
|
||||||
<button
|
<button
|
||||||
className={classNames("edit_button", { "active": this.state.editorMode })}
|
className={classNames("edit_button", {
|
||||||
|
active: this.state.editorMode
|
||||||
|
})}
|
||||||
onClick={this.toggleEditor}
|
onClick={this.toggleEditor}
|
||||||
>
|
>
|
||||||
<img src={edit_icon} />
|
<img src={edit_icon} />
|
||||||
</button>
|
</button>
|
||||||
<button className="clear_button" onClick={this.props.handleChange.bind(this, null)}>
|
<button
|
||||||
|
className="clear_button"
|
||||||
|
onClick={this.props.handleChange.bind(this, null)}
|
||||||
|
>
|
||||||
<img src={clear_icon} />
|
<img src={clear_icon} />
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
<audio controls key={this.key} ref={this.audioRef}
|
<audio
|
||||||
onLoadedMetadata={e => this.setState({ duration: e.nativeEvent.target.duration })}
|
controls
|
||||||
|
key={this.key}
|
||||||
|
ref={this.audioRef}
|
||||||
|
onLoadedMetadata={(e) =>
|
||||||
|
this.setState({ duration: e.nativeEvent.target.duration })
|
||||||
|
}
|
||||||
onPlay={() => {
|
onPlay={() => {
|
||||||
this.reset_playback_within_crop();
|
this.reset_playback_within_crop();
|
||||||
this.audioRef.current.play();
|
this.audioRef.current.play();
|
||||||
@ -162,9 +176,17 @@ class AudioInput extends BaseComponent {
|
|||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
{this.state.editorMode ?
|
{this.state.editorMode ? (
|
||||||
<MultiRangeSlider min={0} max={100} onChange={({ min, max, lastChange }) => this.crop(min, max, lastChange)} />
|
<MultiRangeSlider
|
||||||
: false}
|
min={0}
|
||||||
|
max={100}
|
||||||
|
onChange={({ min, max, lastChange }) =>
|
||||||
|
this.crop(min, max, lastChange)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
false
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
|
@ -51,7 +51,10 @@ class FileInput extends BaseComponent {
|
|||||||
return (
|
return (
|
||||||
<div className="input_file">
|
<div className="input_file">
|
||||||
<div className="file_preview_holder">
|
<div className="file_preview_holder">
|
||||||
<button className="clear_button" onClick={this.handleChange.bind(this, null)}>
|
<button
|
||||||
|
className="clear_button"
|
||||||
|
onClick={this.handleChange.bind(this, null)}
|
||||||
|
>
|
||||||
<img src={clear_icon} />
|
<img src={clear_icon} />
|
||||||
</button>
|
</button>
|
||||||
<div className="file_name">{file_name}</div>
|
<div className="file_name">{file_name}</div>
|
||||||
|
@ -154,8 +154,11 @@ class ImageInput extends BaseComponent {
|
|||||||
<button className="edit_button" onClick={this.openEditor}>
|
<button className="edit_button" onClick={this.openEditor}>
|
||||||
<img src={edit_icon} />
|
<img src={edit_icon} />
|
||||||
</button>
|
</button>
|
||||||
<button className="clear_button" onClick={this.handleChange.bind(this, null)}>
|
<button
|
||||||
<img src={clear_icon} />
|
className="clear_button"
|
||||||
|
onClick={this.handleChange.bind(this, null)}
|
||||||
|
>
|
||||||
|
<img src={clear_icon} />
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
@ -27,20 +27,29 @@ class VideoInput extends BaseComponent {
|
|||||||
evt.stopPropagation();
|
evt.stopPropagation();
|
||||||
};
|
};
|
||||||
if (this.props.value != null) {
|
if (this.props.value != null) {
|
||||||
return <div className="input_video">
|
return (
|
||||||
<div className="edit_buttons">
|
<div className="input_video">
|
||||||
<button className="clear_button" onClick={this.props.handleChange.bind(this, null)}>
|
<div className="edit_buttons">
|
||||||
<img src={clear_icon} />
|
<button
|
||||||
</button>
|
className="clear_button"
|
||||||
|
onClick={this.props.handleChange.bind(this, null)}
|
||||||
|
>
|
||||||
|
<img src={clear_icon} />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
{isPlayable("video", this.props.value["name"]) ? (
|
||||||
|
<div className="video_preview_holder">
|
||||||
|
<video
|
||||||
|
className="video_preview"
|
||||||
|
controls
|
||||||
|
src={this.props.value["data"]}
|
||||||
|
></video>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="video_file_holder">{this.props.value["name"]}</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
{isPlayable("video", this.props.value["name"]) ? <div className="video_preview_holder">
|
);
|
||||||
<video
|
|
||||||
className="video_preview"
|
|
||||||
controls
|
|
||||||
src={this.props.value["data"]}
|
|
||||||
></video>
|
|
||||||
</div> : <div className="video_file_holder">{this.props.value["name"]}</div>}
|
|
||||||
</div>
|
|
||||||
} else {
|
} else {
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
|
@ -28,16 +28,21 @@ class LabelOutput extends BaseComponent {
|
|||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
return (
|
||||||
|
<div className="output_label">
|
||||||
|
<div className="output_class">{this.props.value["label"]}</div>
|
||||||
|
<div className="confidence_intervals">
|
||||||
|
<div className="labels" style={{ maxWidth: "120px" }}>
|
||||||
|
{labels}
|
||||||
|
</div>
|
||||||
|
<div className="confidences">{confidences}</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<div className="output_label">
|
<div className="output_label">
|
||||||
<div className="output_class">{this.props.value["label"]}</div>
|
<div className="output_class_without_confidences">{this.props.value["label"]}</div>
|
||||||
<div className="confidence_intervals">
|
|
||||||
<div className="labels" style={{ maxWidth: "120px" }}>
|
|
||||||
{labels}
|
|
||||||
</div>
|
|
||||||
<div className="confidences">{confidences}</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -39,19 +39,22 @@ export class GradioPage extends React.Component {
|
|||||||
false
|
false
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<a href="/api/" target="_blank" class="footer" rel="noreferrer">
|
<div className="footer">
|
||||||
<span>view the api </span><img class="logo" src="https://i.ibb.co/6DVLqmf/noun-tools-2220412.png" alt="logo"/>
|
<a
|
||||||
<span> |</span>
|
href="/api/"
|
||||||
<a
|
target="_blank"
|
||||||
|
className="footer"
|
||||||
|
rel="noreferrer">
|
||||||
|
view the api <img className="logo" src="https://i.ibb.co/6DVLqmf/noun-tools-2220412.png" alt="api"/>
|
||||||
|
</a> | <a
|
||||||
href="https://gradio.app"
|
href="https://gradio.app"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
className="footer"
|
className="footer"
|
||||||
rel="noreferrer"
|
rel="noreferrer"
|
||||||
>
|
> built with
|
||||||
<span> built with</span>
|
|
||||||
<img className="logo" src={logo} alt="logo" />
|
<img className="logo" src={logo} alt="logo" />
|
||||||
</a>
|
</a>
|
||||||
</a>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
@ -170,8 +173,13 @@ export class GradioInterface extends React.Component {
|
|||||||
if (this.state.flag_index !== undefined) {
|
if (this.state.flag_index !== undefined) {
|
||||||
component_state["flag_index"] = this.state.flag_index;
|
component_state["flag_index"] = this.state.flag_index;
|
||||||
} else {
|
} else {
|
||||||
for (let i = 0; i < this.props.input_components.length; i++) {
|
for (let [i, input_component] of this.props.input_components.entries()) {
|
||||||
component_state["input_data"].push(this.state[i]);
|
let InputComponentClass = input_component_set.find(
|
||||||
|
(c) => c.name === input_component.name
|
||||||
|
).component;
|
||||||
|
component_state["input_data"].push(
|
||||||
|
InputComponentClass.postprocess(this.state[i])
|
||||||
|
);
|
||||||
}
|
}
|
||||||
for (let i = 0; i < this.props.output_components.length; i++) {
|
for (let i = 0; i < this.props.output_components.length; i++) {
|
||||||
component_state["output_data"].push(
|
component_state["output_data"].push(
|
||||||
@ -192,11 +200,17 @@ export class GradioInterface extends React.Component {
|
|||||||
}
|
}
|
||||||
this.pending_response = true;
|
this.pending_response = true;
|
||||||
let input_state = [];
|
let input_state = [];
|
||||||
for (let i = 0; i < this.props.input_components.length; i++) {
|
for (let [i, input_component] of this.props.input_components.entries()) {
|
||||||
if (this.state[i] === null) {
|
if (
|
||||||
|
this.state[i] === null &&
|
||||||
|
this.props.input_components[i].optional !== true
|
||||||
|
) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
input_state[i] = this.state[i];
|
let InputComponentClass = input_component_set.find(
|
||||||
|
(c) => c.name === input_component.name
|
||||||
|
).component;
|
||||||
|
input_state[i] = InputComponentClass.postprocess(this.state[i]);
|
||||||
}
|
}
|
||||||
this.setState({ submitting: true, has_changed: false, error: false });
|
this.setState({ submitting: true, has_changed: false, error: false });
|
||||||
this.props
|
this.props
|
||||||
@ -339,7 +353,6 @@ export class GradioInterface extends React.Component {
|
|||||||
<button className="panel_button submit" onClick={this.submit}>
|
<button className="panel_button submit" onClick={this.submit}>
|
||||||
Submit
|
Submit
|
||||||
</button>
|
</button>
|
||||||
|
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
@ -86,8 +86,7 @@ function load_config(config) {
|
|||||||
if (config.theme !== null && config.theme.startsWith("dark")) {
|
if (config.theme !== null && config.theme.startsWith("dark")) {
|
||||||
target.classList.add("dark");
|
target.classList.add("dark");
|
||||||
config.theme = config.theme.substring(4);
|
config.theme = config.theme.substring(4);
|
||||||
}
|
} else if (url.searchParams.get("__dark-theme") === "true") {
|
||||||
else if (url.searchParams.get("__dark-theme") === "true") {
|
|
||||||
target.classList.add("dark");
|
target.classList.add("dark");
|
||||||
}
|
}
|
||||||
ReactDOM.render(
|
ReactDOM.render(
|
||||||
|
@ -17,10 +17,10 @@
|
|||||||
@apply flex-grow flex-shrink-0;
|
@apply flex-grow flex-shrink-0;
|
||||||
}
|
}
|
||||||
.footer {
|
.footer {
|
||||||
@apply flex-shrink-0 flex gap-1 items-center text-gray-400 dark:text-gray-50 justify-center py-2;
|
@apply flex-shrink-0 inline-flex gap-1 items-center text-gray-400 dark:text-gray-50 justify-center py-2;
|
||||||
}
|
}
|
||||||
.api {
|
.api {
|
||||||
@apply flex-shrink-0 flex gap-1 items-center text-gray-400 dark:text-gray-50 justify-end py-2;
|
@apply flex-shrink-0 inline-flex gap-1 items-center text-gray-400 dark:text-gray-50 justify-end py-2;
|
||||||
}
|
}
|
||||||
.logo {
|
.logo {
|
||||||
@apply h-6;
|
@apply h-6;
|
||||||
@ -512,7 +512,8 @@
|
|||||||
}
|
}
|
||||||
.output_label {
|
.output_label {
|
||||||
@apply dark:text-gray-50;
|
@apply dark:text-gray-50;
|
||||||
.output_class {
|
.output_class,
|
||||||
|
.output_class_without_confidences {
|
||||||
@apply font-bold text-2xl py-8 px-4 flex-grow flex items-center justify-center;
|
@apply font-bold text-2xl py-8 px-4 flex-grow flex items-center justify-center;
|
||||||
}
|
}
|
||||||
.confidence_intervals {
|
.confidence_intervals {
|
||||||
|
@ -459,8 +459,10 @@ html {
|
|||||||
word-break: break-word;
|
word-break: break-word;
|
||||||
@apply w-full bg-white dark:text-blue-900 dark:font-semibold dark:bg-gray-200 border-gray-400 box-border p-1 whitespace-pre-wrap;
|
@apply w-full bg-white dark:text-blue-900 dark:font-semibold dark:bg-gray-200 border-gray-400 box-border p-1 whitespace-pre-wrap;
|
||||||
}
|
}
|
||||||
|
|
||||||
.output_label {
|
.output_label {
|
||||||
.output_class {
|
.output_class,
|
||||||
|
.output_class_without_confidences {
|
||||||
@apply font-bold text-xl py-6 px-4 flex-grow flex items-center justify-center;
|
@apply font-bold text-xl py-6 px-4 flex-grow flex items-center justify-center;
|
||||||
}
|
}
|
||||||
.confidence_intervals {
|
.confidence_intervals {
|
||||||
|
@ -456,6 +456,9 @@
|
|||||||
.output_class {
|
.output_class {
|
||||||
@apply hidden;
|
@apply hidden;
|
||||||
}
|
}
|
||||||
|
.output_class_without_confidences{
|
||||||
|
@apply font-bold text-2xl py-8 px-4 flex-grow flex items-center justify-center dark:text-gray-50;
|
||||||
|
}
|
||||||
.confidence_intervals {
|
.confidence_intervals {
|
||||||
@apply flex text-xl;
|
@apply flex text-xl;
|
||||||
}
|
}
|
||||||
|
@ -464,7 +464,8 @@
|
|||||||
}
|
}
|
||||||
.output_label {
|
.output_label {
|
||||||
@apply dark:text-gray-50;
|
@apply dark:text-gray-50;
|
||||||
.output_class {
|
.output_class,
|
||||||
|
.output_class_without_confidences {
|
||||||
@apply font-bold text-2xl py-8 px-4 flex-grow flex items-center justify-center;
|
@apply font-bold text-2xl py-8 px-4 flex-grow flex items-center justify-center;
|
||||||
}
|
}
|
||||||
.confidence_intervals {
|
.confidence_intervals {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
Metadata-Version: 1.0
|
Metadata-Version: 1.0
|
||||||
Name: gradio
|
Name: gradio
|
||||||
Version: 2.4.0
|
Version: 2.4.1
|
||||||
Summary: Python library for easily interacting with trained machine learning models
|
Summary: Python library for easily interacting with trained machine learning models
|
||||||
Home-page: https://github.com/gradio-app/gradio-UI
|
Home-page: https://github.com/gradio-app/gradio-UI
|
||||||
Author: Abubakar Abid
|
Author: Abubakar Abid
|
||||||
|
@ -46,7 +46,10 @@ class Component():
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
def save_flagged_file(self, dir, label, data, encryption_key):
|
def save_flagged_file(self, dir, label, data, encryption_key):
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
file = processing_utils.decode_base64_to_file(data, encryption_key)
|
file = processing_utils.decode_base64_to_file(data, encryption_key)
|
||||||
|
label = "".join([char for char in label if char.isalnum() or char in "._- "])
|
||||||
old_file_name = file.name
|
old_file_name = file.name
|
||||||
output_dir = os.path.join(dir, label)
|
output_dir = os.path.join(dir, label)
|
||||||
if os.path.exists(output_dir):
|
if os.path.exists(output_dir):
|
||||||
|
@ -69,7 +69,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
|||||||
},
|
},
|
||||||
'fill-mask': {
|
'fill-mask': {
|
||||||
'inputs': inputs.Textbox(label="Input"),
|
'inputs': inputs.Textbox(label="Input"),
|
||||||
'outputs': "label",
|
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||||
'preprocess': lambda x: {"inputs": x},
|
'preprocess': lambda x: {"inputs": x},
|
||||||
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r.json()}
|
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r.json()}
|
||||||
},
|
},
|
||||||
|
@ -939,7 +939,7 @@ class Video(InputComponent):
|
|||||||
"""
|
"""
|
||||||
Returns: (str) path to video file
|
Returns: (str) path to video file
|
||||||
"""
|
"""
|
||||||
return self.save_flagged_file(dir, label, data, encryption_key)
|
return self.save_flagged_file(dir, label, None if data is None else data["data"], encryption_key)
|
||||||
|
|
||||||
def generate_sample(self):
|
def generate_sample(self):
|
||||||
return test_data.BASE64_VIDEO
|
return test_data.BASE64_VIDEO
|
||||||
@ -1113,7 +1113,7 @@ class Audio(InputComponent):
|
|||||||
"""
|
"""
|
||||||
Returns: (str) path to audio file
|
Returns: (str) path to audio file
|
||||||
"""
|
"""
|
||||||
return self.save_flagged_file(dir, label, data, encryption_key)
|
return self.save_flagged_file(dir, label, None if data is None else data["data"], encryption_key)
|
||||||
|
|
||||||
def generate_sample(self):
|
def generate_sample(self):
|
||||||
return test_data.BASE64_AUDIO
|
return test_data.BASE64_AUDIO
|
||||||
@ -1192,7 +1192,7 @@ class File(InputComponent):
|
|||||||
"""
|
"""
|
||||||
Returns: (str) path to file
|
Returns: (str) path to file
|
||||||
"""
|
"""
|
||||||
return self.save_flagged_file(dir, label, data["data"], encryption_key)
|
return self.save_flagged_file(dir, label, None if data is None else data[0]["data"], encryption_key)
|
||||||
|
|
||||||
def generate_sample(self):
|
def generate_sample(self):
|
||||||
return test_data.BASE64_FILE
|
return test_data.BASE64_FILE
|
||||||
|
@ -490,12 +490,6 @@ class Interface:
|
|||||||
interpretation = [interpretation]
|
interpretation = [interpretation]
|
||||||
return interpretation, []
|
return interpretation, []
|
||||||
|
|
||||||
def close(self):
|
|
||||||
# checks to see if server is running
|
|
||||||
if self.simple_server and not (self.simple_server.fileno() == -1):
|
|
||||||
print("Closing Gradio server on port {}...".format(self.server_port))
|
|
||||||
networking.close_server(self.simple_server)
|
|
||||||
|
|
||||||
def run_until_interrupted(self, thread, path_to_local_server):
|
def run_until_interrupted(self, thread, path_to_local_server):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
@ -590,15 +584,15 @@ class Interface:
|
|||||||
self.share = share
|
self.share = share
|
||||||
|
|
||||||
if share:
|
if share:
|
||||||
if private_endpoint:
|
|
||||||
print(strings.en["PRIVATE_LINK_MESSAGE"])
|
|
||||||
else:
|
|
||||||
print(strings.en["SHARE_LINK_MESSAGE"])
|
|
||||||
try:
|
try:
|
||||||
share_url = networking.setup_tunnel(
|
share_url = networking.setup_tunnel(
|
||||||
server_port, private_endpoint)
|
server_port, private_endpoint)
|
||||||
self.share_url = share_url
|
self.share_url = share_url
|
||||||
print(strings.en["SHARE_LINK_DISPLAY"].format(share_url))
|
print(strings.en["SHARE_LINK_DISPLAY"].format(share_url))
|
||||||
|
if private_endpoint:
|
||||||
|
print(strings.en["PRIVATE_LINK_MESSAGE"])
|
||||||
|
else:
|
||||||
|
print(strings.en["SHARE_LINK_MESSAGE"])
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
send_error_analytics(self.analytics_enabled)
|
send_error_analytics(self.analytics_enabled)
|
||||||
share_url = None
|
share_url = None
|
||||||
@ -647,6 +641,19 @@ class Interface:
|
|||||||
|
|
||||||
return app, path_to_local_server, share_url
|
return app, path_to_local_server, share_url
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
try:
|
||||||
|
if self.share_url:
|
||||||
|
requests.get("{}/shutdown".format(self.share_url))
|
||||||
|
print("Closing Gradio server on port {}...".format(self.server_port))
|
||||||
|
elif self.local_url:
|
||||||
|
requests.get("{}shutdown".format(self.local_url))
|
||||||
|
print("Closing Gradio server on port {}...".format(self.server_port))
|
||||||
|
else:
|
||||||
|
pass # server not running
|
||||||
|
except (requests.ConnectionError, ConnectionResetError):
|
||||||
|
pass # server is already closed
|
||||||
|
|
||||||
def integrate(self, comet_ml=None, wandb=None, mlflow=None):
|
def integrate(self, comet_ml=None, wandb=None, mlflow=None):
|
||||||
analytics_integration = ""
|
analytics_integration = ""
|
||||||
if comet_ml is not None:
|
if comet_ml is not None:
|
||||||
@ -738,6 +745,8 @@ def send_launch_analytics(analytics_enabled, inbrowser, is_colab, share, share_u
|
|||||||
pass # do not push analytics if no network
|
pass # do not push analytics if no network
|
||||||
|
|
||||||
|
|
||||||
def reset_all():
|
def close_all():
|
||||||
for io in Interface.get_instances():
|
for io in Interface.get_instances():
|
||||||
io.close()
|
io.close()
|
||||||
|
|
||||||
|
reset_all = close_all # for backwards compatibility
|
||||||
|
@ -174,6 +174,15 @@ def enable_sharing(path):
|
|||||||
return jsonify(success=True)
|
return jsonify(success=True)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/shutdown", methods=['GET'])
|
||||||
|
def shutdown():
|
||||||
|
shutdown_func = request.environ.get('werkzeug.server.shutdown')
|
||||||
|
if shutdown_func is None:
|
||||||
|
raise RuntimeError('Not running werkzeug')
|
||||||
|
shutdown_func()
|
||||||
|
return "Shutting down..."
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/predict/", methods=["POST"])
|
@app.route("/api/predict/", methods=["POST"])
|
||||||
@login_check
|
@login_check
|
||||||
def predict():
|
def predict():
|
||||||
@ -383,6 +392,7 @@ def queue_push():
|
|||||||
job_hash, queue_position = queue.push({"data": data}, action)
|
job_hash, queue_position = queue.push({"data": data}, action)
|
||||||
return {"hash": job_hash, "queue_position": queue_position}
|
return {"hash": job_hash, "queue_position": queue_position}
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/queue/status/", methods=["POST"])
|
@app.route("/api/queue/status/", methods=["POST"])
|
||||||
@login_check
|
@login_check
|
||||||
def queue_status():
|
def queue_status():
|
||||||
@ -390,15 +400,19 @@ def queue_status():
|
|||||||
status, data = queue.get_status(hash)
|
status, data = queue.get_status(hash)
|
||||||
return {"status": status, "data": data}
|
return {"status": status, "data": data}
|
||||||
|
|
||||||
def queue_thread(path_to_local_server):
|
|
||||||
|
def queue_thread(path_to_local_server, test_mode=False):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
next_job = queue.pop()
|
next_job = queue.pop()
|
||||||
|
print(next_job)
|
||||||
if next_job is not None:
|
if next_job is not None:
|
||||||
_, hash, input_data, task_type = next_job
|
_, hash, input_data, task_type = next_job
|
||||||
|
print(hash)
|
||||||
queue.start_job(hash)
|
queue.start_job(hash)
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
path_to_local_server + "/api/" + task_type + "/", json=input_data)
|
path_to_local_server + "/api/" + task_type + "/", json=input_data)
|
||||||
|
print('response', response)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
queue.pass_job(hash, response.json())
|
queue.pass_job(hash, response.json())
|
||||||
else:
|
else:
|
||||||
@ -408,6 +422,9 @@ def queue_thread(path_to_local_server):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
pass
|
pass
|
||||||
|
if test_mode:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
def start_server(interface, server_name, server_port=None, auth=None, ssl=None):
|
def start_server(interface, server_name, server_port=None, auth=None, ssl=None):
|
||||||
if server_port is None:
|
if server_port is None:
|
||||||
|
@ -243,7 +243,7 @@ class Image(OutputComponent):
|
|||||||
"""
|
"""
|
||||||
Returns: (str) path to image file
|
Returns: (str) path to image file
|
||||||
"""
|
"""
|
||||||
return self.save_flagged_file(dir, label, data[0], encryption_key)
|
return self.save_flagged_file(dir, label, data, encryption_key)
|
||||||
|
|
||||||
|
|
||||||
class Video(OutputComponent):
|
class Video(OutputComponent):
|
||||||
@ -298,7 +298,7 @@ class Video(OutputComponent):
|
|||||||
"""
|
"""
|
||||||
Returns: (str) path to image file
|
Returns: (str) path to image file
|
||||||
"""
|
"""
|
||||||
return self.save_flagged_file(dir, label, data, encryption_key)
|
return self.save_flagged_file(dir, label, data['data'], encryption_key)
|
||||||
|
|
||||||
|
|
||||||
class KeyValues(OutputComponent):
|
class KeyValues(OutputComponent):
|
||||||
|
@ -5,24 +5,21 @@ import json
|
|||||||
MESSAGING_API_ENDPOINT = "https://api.gradio.app/gradio-messaging/en"
|
MESSAGING_API_ENDPOINT = "https://api.gradio.app/gradio-messaging/en"
|
||||||
|
|
||||||
en = {
|
en = {
|
||||||
"BETA_MESSAGE": "NOTE: Gradio is in beta stage, please report all bugs to: gradio.app@gmail.com",
|
"RUNNING_LOCALLY": "Running on local URL: {}",
|
||||||
"RUNNING_LOCALLY": "Running locally at: {}",
|
"SHARE_LINK_DISPLAY": "Running on public URL: {}",
|
||||||
"NGROK_NO_INTERNET": "Unable to create public link for interface, please check internet connection or try "
|
|
||||||
"restarting python interpreter.",
|
|
||||||
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
|
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
|
||||||
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in `launch()`.",
|
"PUBLIC_SHARE_TRUE": "\nTo create a public link, set `share=True` in `launch()`.",
|
||||||
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} (may take up to a minute for link to be usable)",
|
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} (may take up to a minute for link to be usable)",
|
||||||
"GENERATING_PUBLIC_LINK": "Generating public link (may take a few seconds...):",
|
"GENERATING_PUBLIC_LINK": "Generating public link (may take a few seconds...):",
|
||||||
"TF1_ERROR": "It looks like you might be using tensorflow < 2.0. Please pass capture_session=True in Interface() to"
|
"TF1_ERROR": "It looks like you might be using tensorflow < 2.0. Please pass capture_session=True in Interface() to"
|
||||||
" avoid the 'Tensor is not an element of this graph.' error.",
|
" avoid the 'Tensor is not an element of this graph.' error.",
|
||||||
"BETA_INVITE": "\nWe want to invite you to become a beta user.\nYou'll get early access to new and premium "
|
"BETA_INVITE": "\nWe want to invite you to become a beta user.\nYou'll get early access to new and premium "
|
||||||
"features (persistent links, hosting, and more).\nIf you're interested please email beta@gradio.app\n",
|
"features (persistent links, hosting, and more).\nIf you're interested please email: beta@gradio.app\n",
|
||||||
"COLAB_DEBUG_TRUE": "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. "
|
"COLAB_DEBUG_TRUE": "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. "
|
||||||
"To turn off, set debug=False in launch().",
|
"To turn off, set debug=False in launch().",
|
||||||
"COLAB_DEBUG_FALSE": "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()",
|
"COLAB_DEBUG_FALSE": "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()",
|
||||||
"SHARE_LINK_MESSAGE": "This share link will expire in 72 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)",
|
"SHARE_LINK_MESSAGE": "\nThis share link will expire in 72 hours. To get longer links, send an email to: support@gradio.app",
|
||||||
"PRIVATE_LINK_MESSAGE": "Since this is a private endpoint, this share link will never expire.",
|
"PRIVATE_LINK_MESSAGE": "Since this is a private endpoint, this share link will never expire.",
|
||||||
"SHARE_LINK_DISPLAY": "Running on External URL: {}",
|
|
||||||
"INLINE_DISPLAY_BELOW": "Interface loading below...",
|
"INLINE_DISPLAY_BELOW": "Interface loading below...",
|
||||||
"MEDIA_PERMISSIONS_IN_COLAB": "Your interface requires microphone or webcam permissions - this may cause issues in Colab. Use the External URL in case of issues.",
|
"MEDIA_PERMISSIONS_IN_COLAB": "Your interface requires microphone or webcam permissions - this may cause issues in Colab. Use the External URL in case of issues.",
|
||||||
"TIPS": [
|
"TIPS": [
|
||||||
|
@ -284,8 +284,8 @@
|
|||||||
<p>  {</p>
|
<p>  {</p>
|
||||||
<p>    "data": [{%for i in range(0, len_outputs)%} <span>{{ output_types[i]
|
<p>    "data": [{%for i in range(0, len_outputs)%} <span>{{ output_types[i]
|
||||||
}}</span>{% if i != len_outputs - 1 %} ,{% endif %}{%endfor%} ],</p>
|
}}</span>{% if i != len_outputs - 1 %} ,{% endif %}{%endfor%} ],</p>
|
||||||
<p>    "durations": [ float ],</p>
|
<p>    "durations": [ float ], // the time taken for the prediction to complete</p>
|
||||||
<p>    "avg_durations": [ float ]</p>
|
<p>    "avg_durations": [ float ] // the average time taken for all predictions so far (used to estimate the runtime)</p>
|
||||||
<p>  }</p>
|
<p>  }</p>
|
||||||
</div>
|
</div>
|
||||||
<h4 id="try-it">Try it (live demo): </h4>
|
<h4 id="try-it">Try it (live demo): </h4>
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
{
|
{
|
||||||
"files": {
|
"files": {
|
||||||
"main.css": "/static/css/main.e23a1a2e.css",
|
"main.css": "/static/css/main.ccb63765.css",
|
||||||
"main.js": "/static/bundle.js",
|
"main.js": "/static/bundle.js",
|
||||||
"index.html": "/index.html",
|
"index.html": "/index.html",
|
||||||
"static/media/arrow-left.e497f657.svg": "/static/media/arrow-left.e497f657.svg",
|
"static/media/arrow-left.794a4706.svg": "/static/media/arrow-left.794a4706.svg",
|
||||||
"static/media/arrow-right.ea6059fd.svg": "/static/media/arrow-right.ea6059fd.svg",
|
"static/media/arrow-right.5a7d4ada.svg": "/static/media/arrow-right.5a7d4ada.svg",
|
||||||
"static/media/clear.33f9b5f3.svg": "/static/media/clear.33f9b5f3.svg",
|
"static/media/clear.85cf6de8.svg": "/static/media/clear.85cf6de8.svg",
|
||||||
"static/media/edit.44bd4fe1.svg": "/static/media/edit.44bd4fe1.svg",
|
"static/media/edit.c6b7d6f7.svg": "/static/media/edit.c6b7d6f7.svg",
|
||||||
"static/media/logo.411acfd1.svg": "/static/media/logo.411acfd1.svg",
|
"static/media/logo.36a8f455.svg": "/static/media/logo.36a8f455.svg",
|
||||||
"static/media/logo_loading.e93acd82.jpg": "/static/media/logo_loading.e93acd82.jpg"
|
"static/media/logo_loading.e93acd82.jpg": "/static/media/logo_loading.e93acd82.jpg"
|
||||||
},
|
},
|
||||||
"entrypoints": [
|
"entrypoints": [
|
||||||
"static/bundle.css",
|
"static/bundle.css",
|
||||||
"static/css/main.e23a1a2e.css",
|
"static/css/main.ccb63765.css",
|
||||||
"static/bundle.js"
|
"static/bundle.js"
|
||||||
]
|
]
|
||||||
}
|
}
|
@ -8,4 +8,4 @@
|
|||||||
window.config = {{ config|tojson }};
|
window.config = {{ config|tojson }};
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
window.config = {};
|
window.config = {};
|
||||||
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.e23a1a2e.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
|
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.ccb63765.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
|
File diff suppressed because one or more lines are too long
@ -1 +1 @@
|
|||||||
2.4.0
|
2.4.1
|
||||||
|
2
setup.py
2
setup.py
@ -5,7 +5,7 @@ except ImportError:
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='gradio',
|
name='gradio',
|
||||||
version='2.4.0',
|
version='2.4.1',
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
description='Python library for easily interacting with trained machine learning models',
|
description='Python library for easily interacting with trained machine learning models',
|
||||||
author='Abubakar Abid',
|
author='Abubakar Abid',
|
||||||
|
BIN
test/images/test_image.png
Normal file
BIN
test/images/test_image.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 KiB |
@ -158,7 +158,7 @@ class TestDemo(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
elem.click()
|
elem.click()
|
||||||
elem = WebDriverWait(driver, TIMEOUT).until(
|
elem = WebDriverWait(driver, TIMEOUT).until(
|
||||||
EC.presence_of_element_located((By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_class"))
|
EC.presence_of_element_located((By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_class_without_confidences"))
|
||||||
)
|
)
|
||||||
|
|
||||||
total_sleep = 0
|
total_sleep = 0
|
||||||
|
@ -7,22 +7,105 @@ WARNING: These tests have an external dependency: namely that Hugging Face's Hub
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
class TestHuggingFaceModelAPI(unittest.TestCase):
|
class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||||
|
def test_question_answering(self):
|
||||||
|
model_type = "question-answering"
|
||||||
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
|
"deepset/roberta-base-squad2", api_key=None, alias=model_type)
|
||||||
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
|
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["outputs"][1], gr.outputs.Label)
|
||||||
|
|
||||||
def test_text_generation(self):
|
def test_text_generation(self):
|
||||||
model_type = "text_generation"
|
model_type = "text_generation"
|
||||||
interface_info = gr.external.get_huggingface_interface("gpt2", api_key=None, alias=None)
|
interface_info = gr.external.get_huggingface_interface("gpt2",
|
||||||
self.assertEqual(interface_info["fn"].__name__, "gpt2")
|
api_key=None,
|
||||||
|
alias=model_type)
|
||||||
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||||
|
|
||||||
def test_sentiment_classifier(self):
|
def test_summarization(self):
|
||||||
model_type = "sentiment_classifier"
|
model_type = "summarization"
|
||||||
interface_info = gr.external.get_huggingface_interface(
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
"distilbert-base-uncased-finetuned-sst-2-english", api_key=None,
|
"facebook/bart-large-cnn", api_key=None, alias=model_type)
|
||||||
alias=model_type)
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||||
|
|
||||||
|
def test_translation(self):
|
||||||
|
model_type = "translation"
|
||||||
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
|
"facebook/bart-large-cnn", api_key=None, alias=model_type)
|
||||||
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||||
|
|
||||||
|
def test_text2text_generation(self):
|
||||||
|
model_type = "text2text-generation"
|
||||||
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
|
"sshleifer/tiny-mbart", api_key=None, alias=model_type)
|
||||||
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||||
|
|
||||||
|
def test_text_classification(self):
|
||||||
|
model_type = "text-classification"
|
||||||
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
|
"distilbert-base-uncased-finetuned-sst-2-english",
|
||||||
|
api_key=None, alias=model_type)
|
||||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||||
|
|
||||||
|
def test_fill_mask(self):
|
||||||
|
model_type = "fill-mask"
|
||||||
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
|
"bert-base-uncased",
|
||||||
|
api_key=None, alias=model_type)
|
||||||
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||||
|
|
||||||
|
def test_zero_shot_classification(self):
|
||||||
|
model_type = "zero-shot-classification"
|
||||||
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
|
"facebook/bart-large-mnli",
|
||||||
|
api_key=None, alias=model_type)
|
||||||
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
|
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["inputs"][2], gr.inputs.Checkbox)
|
||||||
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||||
|
|
||||||
|
def test_automatic_speech_recognition(self):
|
||||||
|
model_type = "automatic-speech-recognition"
|
||||||
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
|
"facebook/wav2vec2-base-960h",
|
||||||
|
api_key=None, alias=model_type)
|
||||||
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
|
||||||
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||||
|
|
||||||
|
def test_image_classification(self):
|
||||||
|
model_type = "image-classification"
|
||||||
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
|
"google/vit-base-patch16-224",
|
||||||
|
api_key=None, alias=model_type)
|
||||||
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Image)
|
||||||
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||||
|
|
||||||
|
def test_feature_extraction(self):
|
||||||
|
model_type = "feature-extraction"
|
||||||
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
|
"sentence-transformers/distilbert-base-nli-mean-tokens",
|
||||||
|
api_key=None, alias=model_type)
|
||||||
|
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||||
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||||
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Dataframe)
|
||||||
|
|
||||||
def test_sentence_similarity(self):
|
def test_sentence_similarity(self):
|
||||||
model_type = "text-to-speech"
|
model_type = "text-to-speech"
|
||||||
interface_info = gr.external.get_huggingface_interface(
|
interface_info = gr.external.get_huggingface_interface(
|
||||||
@ -50,8 +133,6 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
|||||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)
|
self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)
|
||||||
|
|
||||||
|
|
||||||
class TestHuggingFaceSpaceAPI(unittest.TestCase):
|
|
||||||
def test_english_to_spanish(self):
|
def test_english_to_spanish(self):
|
||||||
interface_info = gr.external.get_spaces_interface("abidlabs/english_to_spanish", api_key=None, alias=None)
|
interface_info = gr.external.get_spaces_interface("abidlabs/english_to_spanish", api_key=None, alias=None)
|
||||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||||
@ -63,19 +144,6 @@ class TestLoadInterface(unittest.TestCase):
|
|||||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||||
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
|
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
|
||||||
|
|
||||||
def test_distilbert_classification(self):
|
|
||||||
interface_info = gr.external.load_interface("distilbert-base-uncased-finetuned-sst-2-english", src="huggingface", alias="sentiment_classifier")
|
|
||||||
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
|
|
||||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
|
||||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
|
||||||
|
|
||||||
def test_models_src(self):
|
|
||||||
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
|
||||||
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
|
|
||||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
|
||||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
|
||||||
|
|
||||||
class TestCallingLoadInterface(unittest.TestCase):
|
|
||||||
def test_sentiment_model(self):
|
def test_sentiment_model(self):
|
||||||
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
||||||
io = gr.Interface(**interface_info)
|
io = gr.Interface(**interface_info)
|
||||||
|
@ -64,12 +64,12 @@ class TestFlaskRoutes(unittest.TestCase):
|
|||||||
response = self.client.get('/')
|
response = self.client.get('/')
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
def test_get_config_route(self):
|
def test_get_api_route(self):
|
||||||
response = self.client.get('/config/')
|
response = self.client.get('/api/')
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
def test_get_static_route(self):
|
def test_get_config_route(self):
|
||||||
response = self.client.get('/static/bundle.css')
|
response = self.client.get('/config/')
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
def test_enable_sharing_route(self):
|
def test_enable_sharing_route(self):
|
||||||
@ -212,5 +212,29 @@ class TestURLs(unittest.TestCase):
|
|||||||
self.assertTrue(res)
|
self.assertTrue(res)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueuing(unittest.TestCase):
|
||||||
|
def test_queueing(self):
|
||||||
|
io = gr.Interface(lambda x: x, "text", "text")
|
||||||
|
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||||
|
client = app.test_client()
|
||||||
|
# mock queue methods and post method
|
||||||
|
networking.queue.pop = mock.MagicMock(return_value=(None, None, None, 'predict'))
|
||||||
|
networking.queue.pass_job = mock.MagicMock(return_value=(None, None))
|
||||||
|
networking.queue.fail_job = mock.MagicMock(return_value=(None, None))
|
||||||
|
networking.queue.start_job = mock.MagicMock(return_value=None)
|
||||||
|
requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=200))
|
||||||
|
# execute queue action successfully
|
||||||
|
networking.queue_thread('test_path', test_mode=True)
|
||||||
|
networking.queue.pass_job.assert_called_once()
|
||||||
|
# execute queue action unsuccessfully
|
||||||
|
requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=500))
|
||||||
|
networking.queue_thread('test_path', test_mode=True)
|
||||||
|
networking.queue.fail_job.assert_called_once()
|
||||||
|
# no more things on the queue so methods shouldn't be called any more times
|
||||||
|
networking.queue.pop = mock.MagicMock(return_value=None)
|
||||||
|
networking.queue.pass_job.assert_called_once()
|
||||||
|
networking.queue.fail_job.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
70
test/test_processing_utils.py
Normal file
70
test/test_processing_utils.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import unittest
|
||||||
|
import pathlib
|
||||||
|
import gradio as gr
|
||||||
|
from PIL import Image
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class ImagePreprocessing(unittest.TestCase):
|
||||||
|
def test_decode_base64_to_image(self):
|
||||||
|
output_image = gr.processing_utils.decode_base64_to_image(
|
||||||
|
gr.test_data.BASE64_IMAGE)
|
||||||
|
self.assertIsInstance(output_image, Image.Image)
|
||||||
|
|
||||||
|
def test_encode_url_or_file_to_base64(self):
|
||||||
|
output_base64 = gr.processing_utils.encode_url_or_file_to_base64(
|
||||||
|
"test/images/test_image.png")
|
||||||
|
self.assertEquals(output_base64, gr.test_data.BASE64_IMAGE)
|
||||||
|
|
||||||
|
def test_encode_file_to_base64(self):
|
||||||
|
output_base64 = gr.processing_utils.encode_file_to_base64(
|
||||||
|
"test/images/test_image.png")
|
||||||
|
self.assertEquals(output_base64, gr.test_data.BASE64_IMAGE)
|
||||||
|
|
||||||
|
def test_encode_url_to_base64(self):
|
||||||
|
output_base64 = gr.processing_utils.encode_url_to_base64(
|
||||||
|
"https://raw.githubusercontent.com/gradio-app/gradio/master/test"
|
||||||
|
"/images/test_image.png")
|
||||||
|
self.assertEqual(output_base64, gr.test_data.BASE64_IMAGE)
|
||||||
|
|
||||||
|
def test_encode_plot_to_base64(self):
|
||||||
|
plt.plot([1, 2, 3, 4])
|
||||||
|
output_base64 = gr.processing_utils.encode_plot_to_base64(plt)
|
||||||
|
self.assertEqual(output_base64, gr.test_data.BASE64_PLT_IMG)
|
||||||
|
|
||||||
|
def test_encode_array_to_base64(self):
|
||||||
|
img = Image.open("test/images/test_image.png")
|
||||||
|
img = img.convert("RGB")
|
||||||
|
numpy_data = np.asarray(img, dtype=np.uint8)
|
||||||
|
output_base64 = gr.processing_utils.encode_array_to_base64(numpy_data)
|
||||||
|
# self.assertEqual(output_base64, gr.test_data.BASE64_IMAGE)
|
||||||
|
|
||||||
|
class OutputPreprocessing(unittest.TestCase):
|
||||||
|
|
||||||
|
float_dtype_list = [float, float, np.double, np.single, np.float32,
|
||||||
|
np.float64, 'float32', 'float64']
|
||||||
|
def test_float_conversion_dtype(self):
|
||||||
|
"""Test any convertion from a float dtype to an other."""
|
||||||
|
|
||||||
|
x = np.array([-1, 1])
|
||||||
|
# Test all combinations of dtypes conversions
|
||||||
|
dtype_combin = np.array(np.meshgrid(
|
||||||
|
OutputPreprocessing.float_dtype_list,
|
||||||
|
OutputPreprocessing.float_dtype_list)).T.reshape(-1, 2)
|
||||||
|
|
||||||
|
for dtype_in, dtype_out in dtype_combin:
|
||||||
|
x = x.astype(dtype_in)
|
||||||
|
y = gr.processing_utils._convert(x, dtype_out)
|
||||||
|
assert y.dtype == np.dtype(dtype_out)
|
||||||
|
|
||||||
|
def test_subclass_conversion(self):
|
||||||
|
"""Check subclass conversion behavior"""
|
||||||
|
x = np.array([-1, 1])
|
||||||
|
|
||||||
|
for dtype in OutputPreprocessing.float_dtype_list:
|
||||||
|
x = x.astype(dtype)
|
||||||
|
y = gr.processing_utils._convert(x, np.floating)
|
||||||
|
assert y.dtype == x.dtype
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user