use threading instead of process

This commit is contained in:
Ali Abid 2020-09-22 11:16:46 -07:00
parent ec4d96732f
commit 5b7a4bbfa7
3 changed files with 11 additions and 26 deletions

View File

@ -6,7 +6,7 @@ import os
import socket
import threading
from flask import Flask, request, jsonify, abort, send_file, render_template
from multiprocessing import Process
import threading
import pkg_resources
from distutils import dir_util
import gradio as gr
@ -35,8 +35,7 @@ app = Flask(__name__,
template_folder=STATIC_TEMPLATE_LIB,
static_folder=STATIC_PATH_LIB)
app.app_globals = {}
# app.config["FLASK_SKIP_DOTENV"] = 1
# app.FLASK_SKIP_DOTENV = 1
def set_meta_tags(title, description, thumbnail):
app.app_globals.update({
@ -160,7 +159,6 @@ def interpret():
def file(path):
return send_file(os.path.join(app.cwd, path))
def start_server(interface, server_port=None):
if server_port is None:
server_port = INITIAL_PORT_VALUE
@ -169,7 +167,7 @@ def start_server(interface, server_port=None):
)
app.interface = interface
app.cwd = os.getcwd()
process = Process(target=app.run, kwargs={"port": port})
process = threading.Thread(target=app.run, kwargs={"port": port})
process.start()
return port, app, process

View File

@ -9,9 +9,8 @@ import requests
from urllib.request import urlretrieve
# # Download human-readable labels for ImageNet.
# response = requests.get("https://git.io/JJkYN")
# labels = response.text.split("\n")
labels = range(1000) # comment this later
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
mobile_net = tf.keras.applications.MobileNetV2()
@ -22,27 +21,16 @@ def image_classifier(im):
prediction = mobile_net.predict(arr).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)}
def image_explain(im):
model.layers[-1].activation = tf.keras.activations.linear
model = utils.apply_modifications(model)
penultimate_layer_idx = 2
class_idx = class_idxs_sorted[0]
seed_input = img
grad_top1 = visualize_cam(model, layer_idx, class_idx, seed_input,
penultimate_layer_idx = penultimate_layer_idx,#None,
backprop_modifier = None,
grad_modifier = None)
print(grad_top_1)
return grad_top1
image = gr.inputs.Image(shape=(224, 224))
label = gr.outputs.Label(num_top_classes=3)
gr.Interface(image_classifier, image, label,
io = gr.Interface(image_classifier, image, label,
capture_session=True,
interpretation="default",
examples=[
["images/cheetah1.jpg"],
["images/lion.jpg"]
]).launch();
])
io.launch()

View File

@ -6,7 +6,7 @@ import os
import socket
import threading
from flask import Flask, request, jsonify, abort, send_file, render_template
from multiprocessing import Process
import threading
import pkg_resources
from distutils import dir_util
import gradio as gr
@ -159,7 +159,6 @@ def interpret():
def file(path):
return send_file(os.path.join(app.cwd, path))
def start_server(interface, server_port=None):
if server_port is None:
server_port = INITIAL_PORT_VALUE
@ -168,7 +167,7 @@ def start_server(interface, server_port=None):
)
app.interface = interface
app.cwd = os.getcwd()
process = Process(target=app.run, kwargs={"port": port})
process = threading.Thread(target=app.run, kwargs={"port": port})
process.start()
return port, app, process