# gradio/demo/image_classifier.py
import gradio as gr
import tensorflow as tf
# from vis.utils import utils
# from vis.visualization import visualize_cam
import numpy as np
from PIL import Image
import requests
from urllib.request import urlretrieve
# Download human-readable labels for ImageNet (one label per line).
# The timeout keeps a stalled connection from hanging start-up, and
# raise_for_status() stops an HTTP error page from being parsed as labels.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
# splitlines() drops the trailing empty entry (and any "\r") that a plain
# split("\n") would leave in the label list.
labels = response.text.splitlines()
# MobileNetV2 built with its default constructor arguments.
mobile_net = tf.keras.applications.MobileNetV2()
def image_classifier(im):
    """Classify one image with ``mobile_net`` and return per-label scores.

    Args:
        im: Array for a single image; a leading batch axis is added here
            before prediction. Presumably (224, 224, 3) to match the input
            component defined in this file -- TODO confirm.

    Returns:
        dict mapping each label string from ``labels`` to the model's score
        for that class, converted to a plain Python float.

    Raises:
        IndexError: if ``labels`` has fewer entries than the model has
            outputs (e.g. the label download returned something unexpected).
    """
    batch = np.expand_dims(im, axis=0)
    # Use the preprocessing module that belongs to MobileNetV2 rather than
    # the one for the original MobileNet.
    batch = tf.keras.applications.mobilenet_v2.preprocess_input(batch)
    scores = mobile_net.predict(batch).flatten()
    # Iterate over the model's actual outputs instead of a hard-coded 1000.
    return {labels[i]: float(score) for i, score in enumerate(scores)}
def image_explain(im):
    """Compute a class-activation map for the top-scoring class of ``im``.

    NOTE(review): this function relies on ``utils`` and ``visualize_cam``,
    whose imports (``vis.utils`` / ``vis.visualization``) are commented out
    at the top of this file; they must be restored before it can run. It is
    also not passed to the ``gr.Interface`` call in this file.

    Args:
        im: Array for a single image, in the same form ``image_classifier``
            receives -- presumably (224, 224, 3); TODO confirm.

    Returns:
        The value returned by ``visualize_cam`` for the top-1 class.
    """
    # Rank the classes by score so the explanation targets the top-1 class.
    batch = np.expand_dims(im, axis=0)
    batch = tf.keras.applications.mobilenet_v2.preprocess_input(batch)
    scores = mobile_net.predict(batch).flatten()
    class_idxs_sorted = np.argsort(scores)[::-1]

    # The final layer is the one whose activation is swapped to linear below,
    # so it is also the layer handed to visualize_cam.
    layer_idx = -1
    penultimate_layer_idx = 2

    final_layer = mobile_net.layers[layer_idx]
    original_activation = final_layer.activation
    final_layer.activation = tf.keras.activations.linear
    try:
        # Presumably returns a rebuilt model with the linear activation
        # applied -- verify against the keras-vis documentation.
        model = utils.apply_modifications(mobile_net)
    finally:
        # Restore the shared classifier so image_classifier keeps producing
        # its normal scores after an explanation has been requested.
        final_layer.activation = original_activation

    grad_top1 = visualize_cam(
        model,
        layer_idx,
        class_idxs_sorted[0],
        im,
        penultimate_layer_idx=penultimate_layer_idx,
        backprop_modifier=None,
        grad_modifier=None,
    )
    print(grad_top1)
    return grad_top1
# Input component: images are requested at 224x224 for the classifier.
imagein = gr.inputs.Image(shape=(224, 224))
# Output component: display only the three highest-scoring labels.
label = gr.outputs.Label(num_top_classes=3)

# Build the demo UI around image_classifier and start serving it.
# NOTE(review): capture_session / interpret_by are passed through unchanged;
# confirm they are still accepted by the installed Gradio version.
gr.Interface(
    image_classifier,
    imagein,
    label,
    capture_session=True,
    interpret_by="default",
    examples=[
        ["images/cheetah1.jpg"],
        ["images/lion.jpg"],
    ],
).launch()