Merge branch 'master' into aliabd/components-tests

This commit is contained in:
Ali Abdalla 2021-11-03 18:15:39 -07:00 committed by GitHub
commit 7ce37829e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 565 additions and 26673 deletions

View File

@ -35,7 +35,7 @@ module.exports = {
};
paths.appBuild = webpackConfig.output.path;
return webpackConfig;
},
}
},
style: {
postcss: {

26719
frontend/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -33,9 +33,9 @@
"webpack": "^4.44.2"
},
"scripts": {
"start": "REACT_APP_BACKEND_URL='http://localhost:7860/' craco start",
"start": "cross-env REACT_APP_BACKEND_URL='http://localhost:7860/' craco start",
"format": "prettier-eslint --write '**/*.js*'",
"build": "REACT_APP_BACKEND_URL='' GENERATE_SOURCEMAP=false craco build",
"build": "cross-env REACT_APP_BACKEND_URL='' GENERATE_SOURCEMAP=false craco build",
"eject": "react-scripts eject"
},
"eslintConfig": {
@ -57,6 +57,7 @@
},
"devDependencies": {
"autoprefixer": "^9.8.6",
"cross-env": "^7.0.3",
"eslint": "^7.32.0",
"mini-css-extract-plugin": "^0.11.3",
"postcss": "^7.0.36",

View File

@ -77,16 +77,16 @@ class AudioInput extends BaseComponent {
data: this.props.value["data"],
is_example: this.props.value["is_example"],
crop_min: crop_min,
crop_max: crop_max,
crop_max: crop_max
});
this.setState({ editorMode: !this.state.editorMode })
}
this.setState({ editorMode: !this.state.editorMode });
};
crop = (min, max, lastChange) => {
if (this.state.duration) {
if (lastChange === "min") {
this.audioRef.current.currentTime = (min / 100.) * this.state.duration;
this.audioRef.current.currentTime = (min / 100) * this.state.duration;
} else {
this.audioRef.current.currentTime = (max / 100.) * this.state.duration;
this.audioRef.current.currentTime = (max / 100) * this.state.duration;
}
}
this.props.handleChange({
@ -94,20 +94,24 @@ class AudioInput extends BaseComponent {
data: this.props.value["data"],
is_example: this.props.value["is_example"],
crop_min: min,
crop_max: max,
})
}
crop_max: max
});
};
reset_playback_within_crop = () => {
let position_ratio = this.audioRef.current.currentTime / this.state.duration;
let position_ratio =
this.audioRef.current.currentTime / this.state.duration;
let min_ratio = this.props.value.crop_min / 100;
let max_ratio = this.props.value.crop_max / 100;
if ((position_ratio > max_ratio - 0.00001) || (position_ratio < min_ratio - 0.00001)) {
if (
position_ratio > max_ratio - 0.00001 ||
position_ratio < min_ratio - 0.00001
) {
this.audioRef.current.currentTime = this.state.duration * min_ratio;
return true;
} else {
return false;
}
}
};
render() {
if (this.props.value !== null) {
if (
@ -124,17 +128,27 @@ class AudioInput extends BaseComponent {
<div className="input_audio">
<div className="edit_buttons">
<button
className={classNames("edit_button", { "active": this.state.editorMode })}
className={classNames("edit_button", {
active: this.state.editorMode
})}
onClick={this.toggleEditor}
>
<img src={edit_icon} />
</button>
<button className="clear_button" onClick={this.props.handleChange.bind(this, null)}>
<button
className="clear_button"
onClick={this.props.handleChange.bind(this, null)}
>
<img src={clear_icon} />
</button>
</div>
<audio controls key={this.key} ref={this.audioRef}
onLoadedMetadata={e => this.setState({ duration: e.nativeEvent.target.duration })}
<audio
controls
key={this.key}
ref={this.audioRef}
onLoadedMetadata={(e) =>
this.setState({ duration: e.nativeEvent.target.duration })
}
onPlay={() => {
this.reset_playback_within_crop();
this.audioRef.current.play();
@ -162,9 +176,17 @@ class AudioInput extends BaseComponent {
))}
</div>
)}
{this.state.editorMode ?
<MultiRangeSlider min={0} max={100} onChange={({ min, max, lastChange }) => this.crop(min, max, lastChange)} />
: false}
{this.state.editorMode ? (
<MultiRangeSlider
min={0}
max={100}
onChange={({ min, max, lastChange }) =>
this.crop(min, max, lastChange)
}
/>
) : (
false
)}
</div>
);
} else {

View File

@ -51,7 +51,10 @@ class FileInput extends BaseComponent {
return (
<div className="input_file">
<div className="file_preview_holder">
<button className="clear_button" onClick={this.handleChange.bind(this, null)}>
<button
className="clear_button"
onClick={this.handleChange.bind(this, null)}
>
<img src={clear_icon} />
</button>
<div className="file_name">{file_name}</div>

View File

@ -154,8 +154,11 @@ class ImageInput extends BaseComponent {
<button className="edit_button" onClick={this.openEditor}>
<img src={edit_icon} />
</button>
<button className="clear_button" onClick={this.handleChange.bind(this, null)}>
<img src={clear_icon} />
<button
className="clear_button"
onClick={this.handleChange.bind(this, null)}
>
<img src={clear_icon} />
</button>
</div>
)

View File

@ -27,20 +27,29 @@ class VideoInput extends BaseComponent {
evt.stopPropagation();
};
if (this.props.value != null) {
return <div className="input_video">
<div className="edit_buttons">
<button className="clear_button" onClick={this.props.handleChange.bind(this, null)}>
<img src={clear_icon} />
</button>
return (
<div className="input_video">
<div className="edit_buttons">
<button
className="clear_button"
onClick={this.props.handleChange.bind(this, null)}
>
<img src={clear_icon} />
</button>
</div>
{isPlayable("video", this.props.value["name"]) ? (
<div className="video_preview_holder">
<video
className="video_preview"
controls
src={this.props.value["data"]}
></video>
</div>
) : (
<div className="video_file_holder">{this.props.value["name"]}</div>
)}
</div>
{isPlayable("video", this.props.value["name"]) ? <div className="video_preview_holder">
<video
className="video_preview"
controls
src={this.props.value["data"]}
></video>
</div> : <div className="video_file_holder">{this.props.value["name"]}</div>}
</div>
);
} else {
return (
<div

View File

@ -28,16 +28,21 @@ class LabelOutput extends BaseComponent {
</div>
);
}
return (
<div className="output_label">
<div className="output_class">{this.props.value["label"]}</div>
<div className="confidence_intervals">
<div className="labels" style={{ maxWidth: "120px" }}>
{labels}
</div>
<div className="confidences">{confidences}</div>
</div>
</div>
);
}
return (
<div className="output_label">
<div className="output_class">{this.props.value["label"]}</div>
<div className="confidence_intervals">
<div className="labels" style={{ maxWidth: "120px" }}>
{labels}
</div>
<div className="confidences">{confidences}</div>
</div>
<div className="output_class_without_confidences">{this.props.value["label"]}</div>
</div>
);
}

View File

@ -39,19 +39,22 @@ export class GradioPage extends React.Component {
false
)}
</div>
<a href="/api/" target="_blank" class="footer" rel="noreferrer">
<span>view the api </span><img class="logo" src="https://i.ibb.co/6DVLqmf/noun-tools-2220412.png" alt="logo"/>
<span> |</span>
<a
<div className="footer">
<a
href="/api/"
target="_blank"
className="footer"
rel="noreferrer">
view the api <img className="logo" src="https://i.ibb.co/6DVLqmf/noun-tools-2220412.png" alt="api"/>
</a> | <a
href="https://gradio.app"
target="_blank"
className="footer"
rel="noreferrer"
>
<span> built with</span>
> built with
<img className="logo" src={logo} alt="logo" />
</a>
</a>
</div>
</div>
</div>
);
@ -73,7 +76,7 @@ export class GradioInterface extends React.Component {
? "file" +
this.props.examples_dir +
(this.props.examples_dir.endswith("/") ? "" : "/")
: "file");
: "file");
}
get_default_state = () => {
let state = {};
@ -170,8 +173,13 @@ export class GradioInterface extends React.Component {
if (this.state.flag_index !== undefined) {
component_state["flag_index"] = this.state.flag_index;
} else {
for (let i = 0; i < this.props.input_components.length; i++) {
component_state["input_data"].push(this.state[i]);
for (let [i, input_component] of this.props.input_components.entries()) {
let InputComponentClass = input_component_set.find(
(c) => c.name === input_component.name
).component;
component_state["input_data"].push(
InputComponentClass.postprocess(this.state[i])
);
}
for (let i = 0; i < this.props.output_components.length; i++) {
component_state["output_data"].push(
@ -192,11 +200,17 @@ export class GradioInterface extends React.Component {
}
this.pending_response = true;
let input_state = [];
for (let i = 0; i < this.props.input_components.length; i++) {
if (this.state[i] === null) {
for (let [i, input_component] of this.props.input_components.entries()) {
if (
this.state[i] === null &&
this.props.input_components[i].optional !== true
) {
return;
}
input_state[i] = this.state[i];
let InputComponentClass = input_component_set.find(
(c) => c.name === input_component.name
).component;
input_state[i] = InputComponentClass.postprocess(this.state[i]);
}
this.setState({ submitting: true, has_changed: false, error: false });
this.props
@ -339,7 +353,6 @@ export class GradioInterface extends React.Component {
<button className="panel_button submit" onClick={this.submit}>
Submit
</button>
)}
</div>
</div>

View File

@ -86,8 +86,7 @@ function load_config(config) {
if (config.theme !== null && config.theme.startsWith("dark")) {
target.classList.add("dark");
config.theme = config.theme.substring(4);
}
else if (url.searchParams.get("__dark-theme") === "true") {
} else if (url.searchParams.get("__dark-theme") === "true") {
target.classList.add("dark");
}
ReactDOM.render(

View File

@ -17,10 +17,10 @@
@apply flex-grow flex-shrink-0;
}
.footer {
@apply flex-shrink-0 flex gap-1 items-center text-gray-400 dark:text-gray-50 justify-center py-2;
@apply flex-shrink-0 inline-flex gap-1 items-center text-gray-400 dark:text-gray-50 justify-center py-2;
}
.api {
@apply flex-shrink-0 flex gap-1 items-center text-gray-400 dark:text-gray-50 justify-end py-2;
@apply flex-shrink-0 inline-flex gap-1 items-center text-gray-400 dark:text-gray-50 justify-end py-2;
}
.logo {
@apply h-6;
@ -512,7 +512,8 @@
}
.output_label {
@apply dark:text-gray-50;
.output_class {
.output_class,
.output_class_without_confidences {
@apply font-bold text-2xl py-8 px-4 flex-grow flex items-center justify-center;
}
.confidence_intervals {

View File

@ -459,8 +459,10 @@ html {
word-break: break-word;
@apply w-full bg-white dark:text-blue-900 dark:font-semibold dark:bg-gray-200 border-gray-400 box-border p-1 whitespace-pre-wrap;
}
.output_label {
.output_class {
.output_class,
.output_class_without_confidences {
@apply font-bold text-xl py-6 px-4 flex-grow flex items-center justify-center;
}
.confidence_intervals {

View File

@ -456,6 +456,9 @@
.output_class {
@apply hidden;
}
.output_class_without_confidences{
@apply font-bold text-2xl py-8 px-4 flex-grow flex items-center justify-center dark:text-gray-50;
}
.confidence_intervals {
@apply flex text-xl;
}

View File

@ -464,7 +464,8 @@
}
.output_label {
@apply dark:text-gray-50;
.output_class {
.output_class,
.output_class_without_confidences {
@apply font-bold text-2xl py-8 px-4 flex-grow flex items-center justify-center;
}
.confidence_intervals {

View File

@ -1,6 +1,6 @@
Metadata-Version: 1.0
Name: gradio
Version: 2.4.0
Version: 2.4.1
Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid

View File

@ -46,7 +46,10 @@ class Component():
return data
def save_flagged_file(self, dir, label, data, encryption_key):
if data is None:
return None
file = processing_utils.decode_base64_to_file(data, encryption_key)
label = "".join([char for char in label if char.isalnum() or char in "._- "])
old_file_name = file.name
output_dir = os.path.join(dir, label)
if os.path.exists(output_dir):

View File

@ -69,7 +69,7 @@ def get_huggingface_interface(model_name, api_key, alias):
},
'fill-mask': {
'inputs': inputs.Textbox(label="Input"),
'outputs': "label",
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r.json()}
},

View File

@ -939,7 +939,7 @@ class Video(InputComponent):
"""
Returns: (str) path to video file
"""
return self.save_flagged_file(dir, label, data, encryption_key)
return self.save_flagged_file(dir, label, None if data is None else data["data"], encryption_key)
def generate_sample(self):
return test_data.BASE64_VIDEO
@ -1113,7 +1113,7 @@ class Audio(InputComponent):
"""
Returns: (str) path to audio file
"""
return self.save_flagged_file(dir, label, data, encryption_key)
return self.save_flagged_file(dir, label, None if data is None else data["data"], encryption_key)
def generate_sample(self):
return test_data.BASE64_AUDIO
@ -1192,7 +1192,7 @@ class File(InputComponent):
"""
Returns: (str) path to file
"""
return self.save_flagged_file(dir, label, data["data"], encryption_key)
return self.save_flagged_file(dir, label, None if data is None else data[0]["data"], encryption_key)
def generate_sample(self):
return test_data.BASE64_FILE

View File

@ -490,12 +490,6 @@ class Interface:
interpretation = [interpretation]
return interpretation, []
def close(self):
# checks to see if server is running
if self.simple_server and not (self.simple_server.fileno() == -1):
print("Closing Gradio server on port {}...".format(self.server_port))
networking.close_server(self.simple_server)
def run_until_interrupted(self, thread, path_to_local_server):
try:
while True:
@ -590,15 +584,15 @@ class Interface:
self.share = share
if share:
if private_endpoint:
print(strings.en["PRIVATE_LINK_MESSAGE"])
else:
print(strings.en["SHARE_LINK_MESSAGE"])
try:
share_url = networking.setup_tunnel(
server_port, private_endpoint)
self.share_url = share_url
print(strings.en["SHARE_LINK_DISPLAY"].format(share_url))
if private_endpoint:
print(strings.en["PRIVATE_LINK_MESSAGE"])
else:
print(strings.en["SHARE_LINK_MESSAGE"])
except RuntimeError:
send_error_analytics(self.analytics_enabled)
share_url = None
@ -647,6 +641,19 @@ class Interface:
return app, path_to_local_server, share_url
def close(self):
try:
if self.share_url:
requests.get("{}/shutdown".format(self.share_url))
print("Closing Gradio server on port {}...".format(self.server_port))
elif self.local_url:
requests.get("{}shutdown".format(self.local_url))
print("Closing Gradio server on port {}...".format(self.server_port))
else:
pass # server not running
except (requests.ConnectionError, ConnectionResetError):
pass # server is already closed
def integrate(self, comet_ml=None, wandb=None, mlflow=None):
analytics_integration = ""
if comet_ml is not None:
@ -738,6 +745,8 @@ def send_launch_analytics(analytics_enabled, inbrowser, is_colab, share, share_u
pass # do not push analytics if no network
def reset_all():
def close_all():
for io in Interface.get_instances():
io.close()
reset_all = close_all # for backwards compatibility

View File

@ -174,6 +174,15 @@ def enable_sharing(path):
return jsonify(success=True)
@app.route("/shutdown", methods=['GET'])
def shutdown():
shutdown_func = request.environ.get('werkzeug.server.shutdown')
if shutdown_func is None:
raise RuntimeError('Not running werkzeug')
shutdown_func()
return "Shutting down..."
@app.route("/api/predict/", methods=["POST"])
@login_check
def predict():
@ -383,6 +392,7 @@ def queue_push():
job_hash, queue_position = queue.push({"data": data}, action)
return {"hash": job_hash, "queue_position": queue_position}
@app.route("/api/queue/status/", methods=["POST"])
@login_check
def queue_status():
@ -390,15 +400,19 @@ def queue_status():
status, data = queue.get_status(hash)
return {"status": status, "data": data}
def queue_thread(path_to_local_server):
def queue_thread(path_to_local_server, test_mode=False):
while True:
try:
next_job = queue.pop()
print(next_job)
if next_job is not None:
_, hash, input_data, task_type = next_job
print(hash)
queue.start_job(hash)
response = requests.post(
path_to_local_server + "/api/" + task_type + "/", json=input_data)
print('response', response)
if response.status_code == 200:
queue.pass_job(hash, response.json())
else:
@ -408,6 +422,9 @@ def queue_thread(path_to_local_server):
except Exception as e:
time.sleep(1)
pass
if test_mode:
break
def start_server(interface, server_name, server_port=None, auth=None, ssl=None):
if server_port is None:

View File

@ -243,7 +243,7 @@ class Image(OutputComponent):
"""
Returns: (str) path to image file
"""
return self.save_flagged_file(dir, label, data[0], encryption_key)
return self.save_flagged_file(dir, label, data, encryption_key)
class Video(OutputComponent):
@ -298,7 +298,7 @@ class Video(OutputComponent):
"""
Returns: (str) path to image file
"""
return self.save_flagged_file(dir, label, data, encryption_key)
return self.save_flagged_file(dir, label, data['data'], encryption_key)
class KeyValues(OutputComponent):

View File

@ -5,24 +5,21 @@ import json
MESSAGING_API_ENDPOINT = "https://api.gradio.app/gradio-messaging/en"
en = {
"BETA_MESSAGE": "NOTE: Gradio is in beta stage, please report all bugs to: gradio.app@gmail.com",
"RUNNING_LOCALLY": "Running locally at: {}",
"NGROK_NO_INTERNET": "Unable to create public link for interface, please check internet connection or try "
"restarting python interpreter.",
"RUNNING_LOCALLY": "Running on local URL: {}",
"SHARE_LINK_DISPLAY": "Running on public URL: {}",
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in `launch()`.",
"PUBLIC_SHARE_TRUE": "\nTo create a public link, set `share=True` in `launch()`.",
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} (may take up to a minute for link to be usable)",
"GENERATING_PUBLIC_LINK": "Generating public link (may take a few seconds...):",
"TF1_ERROR": "It looks like you might be using tensorflow < 2.0. Please pass capture_session=True in Interface() to"
" avoid the 'Tensor is not an element of this graph.' error.",
"BETA_INVITE": "\nWe want to invite you to become a beta user.\nYou'll get early access to new and premium "
"features (persistent links, hosting, and more).\nIf you're interested please email beta@gradio.app\n",
"features (persistent links, hosting, and more).\nIf you're interested please email: beta@gradio.app\n",
"COLAB_DEBUG_TRUE": "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. "
"To turn off, set debug=False in launch().",
"COLAB_DEBUG_FALSE": "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()",
"SHARE_LINK_MESSAGE": "This share link will expire in 72 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)",
"SHARE_LINK_MESSAGE": "\nThis share link will expire in 72 hours. To get longer links, send an email to: support@gradio.app",
"PRIVATE_LINK_MESSAGE": "Since this is a private endpoint, this share link will never expire.",
"SHARE_LINK_DISPLAY": "Running on External URL: {}",
"INLINE_DISPLAY_BELOW": "Interface loading below...",
"MEDIA_PERMISSIONS_IN_COLAB": "Your interface requires microphone or webcam permissions - this may cause issues in Colab. Use the External URL in case of issues.",
"TIPS": [

View File

@ -284,8 +284,8 @@
<p>&emsp;&emsp;{</p>
<p>&emsp;&emsp;&emsp;&emsp;"data": [{%for i in range(0, len_outputs)%} <span>{{ output_types[i]
}}</span>{% if i != len_outputs - 1 %} ,{% endif %}{%endfor%} ],</p>
<p>&emsp;&emsp;&emsp;&emsp;"durations": [ float ],</p>
<p>&emsp;&emsp;&emsp;&emsp;"avg_durations": [ float ]</p>
<p>&emsp;&emsp;&emsp;&emsp;"durations": [ float ], // the time taken for the prediction to complete</p>
<p>&emsp;&emsp;&emsp;&emsp;"avg_durations": [ float ] // the average time taken for all predictions so far (used to estimate the runtime)</p>
<p>&emsp;&emsp;}</p>
</div>
<h4 id="try-it">Try it (live demo): </h4>

View File

@ -1,18 +1,18 @@
{
"files": {
"main.css": "/static/css/main.e23a1a2e.css",
"main.css": "/static/css/main.ccb63765.css",
"main.js": "/static/bundle.js",
"index.html": "/index.html",
"static/media/arrow-left.e497f657.svg": "/static/media/arrow-left.e497f657.svg",
"static/media/arrow-right.ea6059fd.svg": "/static/media/arrow-right.ea6059fd.svg",
"static/media/clear.33f9b5f3.svg": "/static/media/clear.33f9b5f3.svg",
"static/media/edit.44bd4fe1.svg": "/static/media/edit.44bd4fe1.svg",
"static/media/logo.411acfd1.svg": "/static/media/logo.411acfd1.svg",
"static/media/arrow-left.794a4706.svg": "/static/media/arrow-left.794a4706.svg",
"static/media/arrow-right.5a7d4ada.svg": "/static/media/arrow-right.5a7d4ada.svg",
"static/media/clear.85cf6de8.svg": "/static/media/clear.85cf6de8.svg",
"static/media/edit.c6b7d6f7.svg": "/static/media/edit.c6b7d6f7.svg",
"static/media/logo.36a8f455.svg": "/static/media/logo.36a8f455.svg",
"static/media/logo_loading.e93acd82.jpg": "/static/media/logo_loading.e93acd82.jpg"
},
"entrypoints": [
"static/bundle.css",
"static/css/main.e23a1a2e.css",
"static/css/main.ccb63765.css",
"static/bundle.js"
]
}

View File

@ -8,4 +8,4 @@
window.config = {{ config|tojson }};
} catch (e) {
window.config = {};
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.e23a1a2e.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.ccb63765.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>

File diff suppressed because one or more lines are too long

View File

@ -1 +1 @@
2.4.0
2.4.1

View File

@ -5,7 +5,7 @@ except ImportError:
setup(
name='gradio',
version='2.4.0',
version='2.4.1',
include_package_data=True,
description='Python library for easily interacting with trained machine learning models',
author='Abubakar Abid',

BIN
test/images/test_image.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

View File

@ -158,7 +158,7 @@ class TestDemo(unittest.TestCase):
)
elem.click()
elem = WebDriverWait(driver, TIMEOUT).until(
EC.presence_of_element_located((By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_class"))
EC.presence_of_element_located((By.CSS_SELECTOR, ".panel:nth-child(2) .component:nth-child(2) .output_class_without_confidences"))
)
total_sleep = 0

View File

@ -7,22 +7,105 @@ WARNING: These tests have an external dependency: namely that Hugging Face's Hub
"""
class TestHuggingFaceModelAPI(unittest.TestCase):
def test_question_answering(self):
model_type = "question-answering"
interface_info = gr.external.get_huggingface_interface(
"deepset/roberta-base-squad2", api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
self.assertIsInstance(interface_info["outputs"][1], gr.outputs.Label)
def test_text_generation(self):
model_type = "text_generation"
interface_info = gr.external.get_huggingface_interface("gpt2", api_key=None, alias=None)
self.assertEqual(interface_info["fn"].__name__, "gpt2")
interface_info = gr.external.get_huggingface_interface("gpt2",
api_key=None,
alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_sentiment_classifier(self):
model_type = "sentiment_classifier"
def test_summarization(self):
model_type = "summarization"
interface_info = gr.external.get_huggingface_interface(
"distilbert-base-uncased-finetuned-sst-2-english", api_key=None,
alias=model_type)
"facebook/bart-large-cnn", api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_translation(self):
model_type = "translation"
interface_info = gr.external.get_huggingface_interface(
"facebook/bart-large-cnn", api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_text2text_generation(self):
model_type = "text2text-generation"
interface_info = gr.external.get_huggingface_interface(
"sshleifer/tiny-mbart", api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_text_classification(self):
model_type = "text-classification"
interface_info = gr.external.get_huggingface_interface(
"distilbert-base-uncased-finetuned-sst-2-english",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_fill_mask(self):
model_type = "fill-mask"
interface_info = gr.external.get_huggingface_interface(
"bert-base-uncased",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_zero_shot_classification(self):
model_type = "zero-shot-classification"
interface_info = gr.external.get_huggingface_interface(
"facebook/bart-large-mnli",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
self.assertIsInstance(interface_info["inputs"][2], gr.inputs.Checkbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_automatic_speech_recognition(self):
model_type = "automatic-speech-recognition"
interface_info = gr.external.get_huggingface_interface(
"facebook/wav2vec2-base-960h",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_image_classification(self):
model_type = "image-classification"
interface_info = gr.external.get_huggingface_interface(
"google/vit-base-patch16-224",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Image)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_feature_extraction(self):
model_type = "feature-extraction"
interface_info = gr.external.get_huggingface_interface(
"sentence-transformers/distilbert-base-nli-mean-tokens",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Dataframe)
def test_sentence_similarity(self):
model_type = "text-to-speech"
interface_info = gr.external.get_huggingface_interface(
@ -50,8 +133,6 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)
class TestHuggingFaceSpaceAPI(unittest.TestCase):
def test_english_to_spanish(self):
interface_info = gr.external.get_spaces_interface("abidlabs/english_to_spanish", api_key=None, alias=None)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
@ -61,21 +142,8 @@ class TestLoadInterface(unittest.TestCase):
def test_english_to_spanish(self):
interface_info = gr.external.load_interface("spaces/abidlabs/english_to_spanish")
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
def test_distilbert_classification(self):
interface_info = gr.external.load_interface("distilbert-base-uncased-finetuned-sst-2-english", src="huggingface", alias="sentiment_classifier")
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_models_src(self):
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
class TestCallingLoadInterface(unittest.TestCase):
def test_sentiment_model(self):
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
io = gr.Interface(**interface_info)

View File

@ -64,12 +64,12 @@ class TestFlaskRoutes(unittest.TestCase):
response = self.client.get('/')
self.assertEqual(response.status_code, 200)
def test_get_config_route(self):
response = self.client.get('/config/')
def test_get_api_route(self):
response = self.client.get('/api/')
self.assertEqual(response.status_code, 200)
def test_get_static_route(self):
response = self.client.get('/static/bundle.css')
def test_get_config_route(self):
response = self.client.get('/config/')
self.assertEqual(response.status_code, 200)
def test_enable_sharing_route(self):
@ -212,5 +212,29 @@ class TestURLs(unittest.TestCase):
self.assertTrue(res)
class TestQueuing(unittest.TestCase):
def test_queueing(self):
io = gr.Interface(lambda x: x, "text", "text")
app, _, _ = io.launch(prevent_thread_lock=True)
client = app.test_client()
# mock queue methods and post method
networking.queue.pop = mock.MagicMock(return_value=(None, None, None, 'predict'))
networking.queue.pass_job = mock.MagicMock(return_value=(None, None))
networking.queue.fail_job = mock.MagicMock(return_value=(None, None))
networking.queue.start_job = mock.MagicMock(return_value=None)
requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=200))
# execute queue action successfully
networking.queue_thread('test_path', test_mode=True)
networking.queue.pass_job.assert_called_once()
# execute queue action unsuccessfully
requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=500))
networking.queue_thread('test_path', test_mode=True)
networking.queue.fail_job.assert_called_once()
# no more things on the queue so methods shouldn't be called any more times
networking.queue.pop = mock.MagicMock(return_value=None)
networking.queue.pass_job.assert_called_once()
networking.queue.fail_job.assert_called_once()
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,70 @@
import unittest
import pathlib
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
class ImagePreprocessing(unittest.TestCase):
def test_decode_base64_to_image(self):
output_image = gr.processing_utils.decode_base64_to_image(
gr.test_data.BASE64_IMAGE)
self.assertIsInstance(output_image, Image.Image)
def test_encode_url_or_file_to_base64(self):
output_base64 = gr.processing_utils.encode_url_or_file_to_base64(
"test/images/test_image.png")
self.assertEquals(output_base64, gr.test_data.BASE64_IMAGE)
def test_encode_file_to_base64(self):
output_base64 = gr.processing_utils.encode_file_to_base64(
"test/images/test_image.png")
self.assertEquals(output_base64, gr.test_data.BASE64_IMAGE)
def test_encode_url_to_base64(self):
output_base64 = gr.processing_utils.encode_url_to_base64(
"https://raw.githubusercontent.com/gradio-app/gradio/master/test"
"/images/test_image.png")
self.assertEqual(output_base64, gr.test_data.BASE64_IMAGE)
def test_encode_plot_to_base64(self):
plt.plot([1, 2, 3, 4])
output_base64 = gr.processing_utils.encode_plot_to_base64(plt)
self.assertEqual(output_base64, gr.test_data.BASE64_PLT_IMG)
def test_encode_array_to_base64(self):
img = Image.open("test/images/test_image.png")
img = img.convert("RGB")
numpy_data = np.asarray(img, dtype=np.uint8)
output_base64 = gr.processing_utils.encode_array_to_base64(numpy_data)
# self.assertEqual(output_base64, gr.test_data.BASE64_IMAGE)
class OutputPreprocessing(unittest.TestCase):
float_dtype_list = [float, float, np.double, np.single, np.float32,
np.float64, 'float32', 'float64']
def test_float_conversion_dtype(self):
"""Test any convertion from a float dtype to an other."""
x = np.array([-1, 1])
# Test all combinations of dtypes conversions
dtype_combin = np.array(np.meshgrid(
OutputPreprocessing.float_dtype_list,
OutputPreprocessing.float_dtype_list)).T.reshape(-1, 2)
for dtype_in, dtype_out in dtype_combin:
x = x.astype(dtype_in)
y = gr.processing_utils._convert(x, dtype_out)
assert y.dtype == np.dtype(dtype_out)
def test_subclass_conversion(self):
"""Check subclass conversion behavior"""
x = np.array([-1, 1])
for dtype in OutputPreprocessing.float_dtype_list:
x = x.astype(dtype)
y = gr.processing_utils._convert(x, np.floating)
assert y.dtype == x.dtype
if __name__ == '__main__':
unittest.main()