mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
move interpretation logic
This commit is contained in:
parent
9579280f5e
commit
12f8d8fbf6
@ -18,6 +18,7 @@ import sys
|
||||
import weakref
|
||||
import analytics
|
||||
import os
|
||||
import copy
|
||||
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
@ -255,6 +256,26 @@ class Interface:
|
||||
processed_output = [output_interface.postprocess(
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output, durations
|
||||
|
||||
def interpret(self, raw_input):
|
||||
if self.interpretation == "default":
|
||||
interpreter = gradio.interpretation.default()
|
||||
processed_input = []
|
||||
for i, x in enumerate(raw_input):
|
||||
input_interface = copy.deepcopy(self.input_interfaces[i])
|
||||
interface_type = type(input_interface)
|
||||
if interface_type in gradio.interpretation.expected_types:
|
||||
input_interface.type = gradio.interpretation.expected_types[interface_type]
|
||||
processed_input.append(input_interface.preprocess(x))
|
||||
interpretation = interpreter(self, processed_input)
|
||||
else:
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)]
|
||||
interpreter = self.interpretation
|
||||
interpretation = interpreter(*processed_input)
|
||||
if len(raw_input) == 1:
|
||||
interpretation = [interpretation]
|
||||
return interpretation
|
||||
|
||||
def close(self):
|
||||
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
|
||||
|
@ -18,7 +18,6 @@ from shutil import copyfile
|
||||
import requests
|
||||
import sys
|
||||
import csv
|
||||
import copy
|
||||
import logging
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
@ -87,7 +86,11 @@ def main():
|
||||
|
||||
@app.route("/static/<path:path>")
|
||||
def static(path):
|
||||
return send_file(os.path.join(STATIC_PATH_LIB, path))
|
||||
path = os.path.join(STATIC_PATH_LIB, path)
|
||||
if os.path.exists(path):
|
||||
return send_file()
|
||||
else:
|
||||
abort(404)
|
||||
|
||||
|
||||
@app.route("/config/", methods=["GET"])
|
||||
@ -150,23 +153,7 @@ def flag():
|
||||
@app.route("/api/interpret/", methods=["POST"])
|
||||
def interpret():
|
||||
raw_input = request.json["data"]
|
||||
if app.interface.interpretation == "default":
|
||||
interpreter = gr.interpretation.default()
|
||||
processed_input = []
|
||||
for i, x in enumerate(raw_input):
|
||||
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
|
||||
interface_type = type(input_interface)
|
||||
if interface_type in gr.interpretation.expected_types:
|
||||
input_interface.type = gr.interpretation.expected_types[interface_type]
|
||||
processed_input.append(input_interface.preprocess(x))
|
||||
interpretation = interpreter(app.interface, processed_input)
|
||||
else:
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
interpreter = app.interface.interpretation
|
||||
interpretation = interpreter(*processed_input)
|
||||
if len(raw_input) == 1:
|
||||
interpretation = [interpretation]
|
||||
interpretation = app.interface.interpret(raw_input)
|
||||
return jsonify(interpretation)
|
||||
|
||||
|
||||
|
@ -18,6 +18,7 @@ import sys
|
||||
import weakref
|
||||
import analytics
|
||||
import os
|
||||
import copy
|
||||
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
@ -255,6 +256,26 @@ class Interface:
|
||||
processed_output = [output_interface.postprocess(
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output, durations
|
||||
|
||||
def interpret(self, raw_input):
|
||||
if self.interpretation == "default":
|
||||
interpreter = gradio.interpretation.default()
|
||||
processed_input = []
|
||||
for i, x in enumerate(raw_input):
|
||||
input_interface = copy.deepcopy(self.input_interfaces[i])
|
||||
interface_type = type(input_interface)
|
||||
if interface_type in gradio.interpretation.expected_types:
|
||||
input_interface.type = gradio.interpretation.expected_types[interface_type]
|
||||
processed_input.append(input_interface.preprocess(x))
|
||||
interpretation = interpreter(self, processed_input)
|
||||
else:
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)]
|
||||
interpreter = self.interpretation
|
||||
interpretation = interpreter(*processed_input)
|
||||
if len(raw_input) == 1:
|
||||
interpretation = [interpretation]
|
||||
return interpretation
|
||||
|
||||
def close(self):
|
||||
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
|
||||
|
@ -18,7 +18,6 @@ from shutil import copyfile
|
||||
import requests
|
||||
import sys
|
||||
import csv
|
||||
import copy
|
||||
import logging
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
@ -87,7 +86,11 @@ def main():
|
||||
|
||||
@app.route("/static/<path:path>")
|
||||
def static(path):
|
||||
return send_file(os.path.join(STATIC_PATH_LIB, path))
|
||||
path = os.path.join(STATIC_PATH_LIB, path)
|
||||
if os.path.exists(path):
|
||||
return send_file()
|
||||
else:
|
||||
abort(404)
|
||||
|
||||
|
||||
@app.route("/config/", methods=["GET"])
|
||||
@ -150,23 +153,7 @@ def flag():
|
||||
@app.route("/api/interpret/", methods=["POST"])
|
||||
def interpret():
|
||||
raw_input = request.json["data"]
|
||||
if app.interface.interpretation == "default":
|
||||
interpreter = gr.interpretation.default()
|
||||
processed_input = []
|
||||
for i, x in enumerate(raw_input):
|
||||
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
|
||||
interface_type = type(input_interface)
|
||||
if interface_type in gr.interpretation.expected_types:
|
||||
input_interface.type = gr.interpretation.expected_types[interface_type]
|
||||
processed_input.append(input_interface.preprocess(x))
|
||||
interpretation = interpreter(app.interface, processed_input)
|
||||
else:
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
interpreter = app.interface.interpretation
|
||||
interpretation = interpreter(*processed_input)
|
||||
if len(raw_input) == 1:
|
||||
interpretation = [interpretation]
|
||||
interpretation = app.interface.interpret(raw_input)
|
||||
return jsonify(interpretation)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user