move interpretation logic

This commit is contained in:
Ali Abid 2020-09-24 16:10:18 -07:00
parent 9579280f5e
commit 12f8d8fbf6
4 changed files with 54 additions and 38 deletions

View File

@ -18,6 +18,7 @@ import sys
import weakref import weakref
import analytics import analytics
import os import os
import copy
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy" analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/' analytics_url = 'https://api.gradio.app/'
@ -256,6 +257,26 @@ class Interface:
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)] predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output, durations return processed_output, durations
def interpret(self, raw_input):
if self.interpretation == "default":
interpreter = gradio.interpretation.default()
processed_input = []
for i, x in enumerate(raw_input):
input_interface = copy.deepcopy(self.input_interfaces[i])
interface_type = type(input_interface)
if interface_type in gradio.interpretation.expected_types:
input_interface.type = gradio.interpretation.expected_types[interface_type]
processed_input.append(input_interface.preprocess(x))
interpretation = interpreter(self, processed_input)
else:
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
interpreter = self.interpretation
interpretation = interpreter(*processed_input)
if len(raw_input) == 1:
interpretation = [interpretation]
return interpretation
def close(self): def close(self):
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
print("Closing Gradio server on port {}...".format(self.server_port)) print("Closing Gradio server on port {}...".format(self.server_port))

View File

@ -18,7 +18,6 @@ from shutil import copyfile
import requests import requests
import sys import sys
import csv import csv
import copy
import logging import logging
INITIAL_PORT_VALUE = int(os.getenv( INITIAL_PORT_VALUE = int(os.getenv(
@ -87,7 +86,11 @@ def main():
@app.route("/static/<path:path>") @app.route("/static/<path:path>")
def static(path): def static(path):
return send_file(os.path.join(STATIC_PATH_LIB, path)) path = os.path.join(STATIC_PATH_LIB, path)
if os.path.exists(path):
return send_file()
else:
abort(404)
@app.route("/config/", methods=["GET"]) @app.route("/config/", methods=["GET"])
@ -150,23 +153,7 @@ def flag():
@app.route("/api/interpret/", methods=["POST"]) @app.route("/api/interpret/", methods=["POST"])
def interpret(): def interpret():
raw_input = request.json["data"] raw_input = request.json["data"]
if app.interface.interpretation == "default": interpretation = app.interface.interpret(raw_input)
interpreter = gr.interpretation.default()
processed_input = []
for i, x in enumerate(raw_input):
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
interface_type = type(input_interface)
if interface_type in gr.interpretation.expected_types:
input_interface.type = gr.interpretation.expected_types[interface_type]
processed_input.append(input_interface.preprocess(x))
interpretation = interpreter(app.interface, processed_input)
else:
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
interpreter = app.interface.interpretation
interpretation = interpreter(*processed_input)
if len(raw_input) == 1:
interpretation = [interpretation]
return jsonify(interpretation) return jsonify(interpretation)

View File

@ -18,6 +18,7 @@ import sys
import weakref import weakref
import analytics import analytics
import os import os
import copy
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy" analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/' analytics_url = 'https://api.gradio.app/'
@ -256,6 +257,26 @@ class Interface:
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)] predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output, durations return processed_output, durations
def interpret(self, raw_input):
if self.interpretation == "default":
interpreter = gradio.interpretation.default()
processed_input = []
for i, x in enumerate(raw_input):
input_interface = copy.deepcopy(self.input_interfaces[i])
interface_type = type(input_interface)
if interface_type in gradio.interpretation.expected_types:
input_interface.type = gradio.interpretation.expected_types[interface_type]
processed_input.append(input_interface.preprocess(x))
interpretation = interpreter(self, processed_input)
else:
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
interpreter = self.interpretation
interpretation = interpreter(*processed_input)
if len(raw_input) == 1:
interpretation = [interpretation]
return interpretation
def close(self): def close(self):
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
print("Closing Gradio server on port {}...".format(self.server_port)) print("Closing Gradio server on port {}...".format(self.server_port))

View File

@ -18,7 +18,6 @@ from shutil import copyfile
import requests import requests
import sys import sys
import csv import csv
import copy
import logging import logging
INITIAL_PORT_VALUE = int(os.getenv( INITIAL_PORT_VALUE = int(os.getenv(
@ -87,7 +86,11 @@ def main():
@app.route("/static/<path:path>") @app.route("/static/<path:path>")
def static(path): def static(path):
return send_file(os.path.join(STATIC_PATH_LIB, path)) path = os.path.join(STATIC_PATH_LIB, path)
if os.path.exists(path):
return send_file()
else:
abort(404)
@app.route("/config/", methods=["GET"]) @app.route("/config/", methods=["GET"])
@ -150,23 +153,7 @@ def flag():
@app.route("/api/interpret/", methods=["POST"]) @app.route("/api/interpret/", methods=["POST"])
def interpret(): def interpret():
raw_input = request.json["data"] raw_input = request.json["data"]
if app.interface.interpretation == "default": interpretation = app.interface.interpret(raw_input)
interpreter = gr.interpretation.default()
processed_input = []
for i, x in enumerate(raw_input):
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
interface_type = type(input_interface)
if interface_type in gr.interpretation.expected_types:
input_interface.type = gr.interpretation.expected_types[interface_type]
processed_input.append(input_interface.preprocess(x))
interpretation = interpreter(app.interface, processed_input)
else:
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
interpreter = app.interface.interpretation
interpretation = interpreter(*processed_input)
if len(raw_input) == 1:
interpretation = [interpretation]
return jsonify(interpretation) return jsonify(interpretation)