gradio/demo/image_classifier.py

36 lines
889 B
Python
Raw Normal View History

2020-10-22 20:07:43 +08:00
# Demo: (Image) -> (Label)
2020-09-17 07:43:37 +08:00
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import requests
from urllib.request import urlretrieve
2020-09-21 22:54:34 +08:00
# Download human-readable labels for ImageNet, one label per line.
# NOTE(review): label i is assumed to correspond to model output i; confirm
# the file's ordering (e.g. whether it has an extra leading entry) against
# the 1000-class output of MobileNetV2.
# A timeout keeps the demo from hanging forever on an unreachable host.
response = requests.get("https://git.io/JJkYN", timeout=30)
# Fail loudly on an HTTP error instead of parsing an error page as labels.
response.raise_for_status()
# splitlines() also handles CRLF line endings, so no label keeps a trailing
# "\r"; indices of the real label lines are unchanged.
labels = response.text.splitlines()

# Keras MobileNetV2 built with its default arguments
# (presumably ImageNet-pretrained weights -- confirm against Keras docs).
mobile_net = tf.keras.applications.MobileNetV2()
def image_classifier(im):
    """Classify a single image with MobileNetV2 and return class confidences.

    Args:
        im: one image as an array-like of pixels. The Gradio image input
            in this script is configured with ``shape=(224, 224)``;
            presumably it arrives as an (H, W, 3) RGB numpy array -- TODO
            confirm.

    Returns:
        dict mapping each label string to the model's confidence for that
        class as a Python float, the format passed to ``gr.outputs.Label``.

    Raises:
        IndexError: if ``labels`` has fewer entries than the model has
            outputs.
    """
    # The model predicts on batches, so add a leading batch axis.
    batch = np.expand_dims(im, axis=0)
    # Use the preprocessing paired with MobileNetV2 (the model loaded in
    # this script) rather than the v1 ``mobilenet`` module's function.
    batch = tf.keras.applications.mobilenet_v2.preprocess_input(batch)
    prediction = mobile_net.predict(batch).flatten()
    # One entry per model output instead of a hard-coded 1000, so the
    # mapping follows the model's actual class count.
    return {labels[i]: float(score) for i, score in enumerate(prediction)}
2020-09-21 22:54:34 +08:00
# Gradio components: an image input configured with shape (224, 224), and a
# label output that displays the three highest-scoring classes.
image = gr.inputs.Image(shape=(224, 224))
label = gr.outputs.Label(num_top_classes=3)

# Sample images offered in the UI; each example holds a single input value.
example_paths = ("images/cheetah1.jpg", "images/lion.jpg")
examples = [[path] for path in example_paths]

io = gr.Interface(
    image_classifier,
    image,
    label,
    capture_session=True,
    interpretation="default",
    examples=examples,
)
io.launch()