gradio/demo/image_classifier.py

34 lines
951 B
Python
Raw Normal View History

2020-10-22 20:07:43 +08:00
# Demo: (Image) -> (Label)
2020-09-17 07:43:37 +08:00
import gradio as gr
2020-11-19 22:29:32 +08:00
import tensorflow as tf
2020-09-17 07:43:37 +08:00
import numpy as np
2020-11-03 10:21:31 +08:00
import json
2021-02-27 02:51:51 +08:00
from os.path import dirname, realpath, join
2020-09-17 07:43:37 +08:00
2020-11-03 10:21:31 +08:00
# Load human-readable labels for ImageNet.
# The path is anchored to this file's directory so the lookup does not
# depend on the directory the script is launched from.
current_dir = dirname(realpath(__file__))
# Decode explicitly as UTF-8 (the standard JSON encoding) instead of relying
# on the platform's default text encoding.
with open(
    join(current_dir, "files/imagenet_labels.json"), encoding="utf-8"
) as labels_file:
    labels = json.load(labels_file)

# Build the classifier once at import time; image_classifier() reuses this
# single instance for every prediction.
mobile_net = tf.keras.applications.MobileNetV2()
2020-09-17 07:43:37 +08:00
def image_classifier(im):
    """Classify an image with the module-level MobileNetV2 model.

    Args:
        im: Image array supplied by the Gradio image input.
            NOTE(review): presumably (224, 224, 3), given the
            ``shape=(224, 224)`` configured on the input component -- confirm.

    Returns:
        dict mapping each label in ``labels`` to the model's score for that
        class, as a built-in ``float``.
    """
    # The model predicts on batches, so add a leading batch axis.
    arr = np.expand_dims(im, axis=0)
    # Use the preprocessing function that belongs to the MobileNetV2 model.
    arr = tf.keras.applications.mobilenet_v2.preprocess_input(arr)
    prediction = mobile_net.predict(arr).flatten()
    # Take the class count from the prediction itself rather than a
    # hard-coded 1000, and convert NumPy scalars to plain Python floats.
    return {labels[i]: float(score) for i, score in enumerate(prediction)}
2021-02-27 02:51:51 +08:00
# Wire the classifier into a Gradio interface: one image input, one label
# output limited to the three top classes.
iface = gr.Interface(
    fn=image_classifier,
    inputs=gr.inputs.Image(shape=(224, 224)),
    outputs=gr.outputs.Label(num_top_classes=3),
    # NOTE(review): capture_session is presumably required for TensorFlow
    # session handling -- confirm against the installed Gradio version.
    capture_session=True,
    interpretation="default",
    # Sample inputs offered in the UI.
    # NOTE(review): these relative paths are not anchored to this file's
    # directory -- presumably resolved against the working directory.
    examples=[
        ["images/cheetah1.jpg"],
        ["images/lion.jpg"],
    ],
)

# Start the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()