mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
move interpretation logic
This commit is contained in:
parent
9579280f5e
commit
12f8d8fbf6
@ -18,6 +18,7 @@ import sys
|
|||||||
import weakref
|
import weakref
|
||||||
import analytics
|
import analytics
|
||||||
import os
|
import os
|
||||||
|
import copy
|
||||||
|
|
||||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||||
analytics_url = 'https://api.gradio.app/'
|
analytics_url = 'https://api.gradio.app/'
|
||||||
@ -255,6 +256,26 @@ class Interface:
|
|||||||
processed_output = [output_interface.postprocess(
|
processed_output = [output_interface.postprocess(
|
||||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||||
return processed_output, durations
|
return processed_output, durations
|
||||||
|
|
||||||
|
def interpret(self, raw_input):
|
||||||
|
if self.interpretation == "default":
|
||||||
|
interpreter = gradio.interpretation.default()
|
||||||
|
processed_input = []
|
||||||
|
for i, x in enumerate(raw_input):
|
||||||
|
input_interface = copy.deepcopy(self.input_interfaces[i])
|
||||||
|
interface_type = type(input_interface)
|
||||||
|
if interface_type in gradio.interpretation.expected_types:
|
||||||
|
input_interface.type = gradio.interpretation.expected_types[interface_type]
|
||||||
|
processed_input.append(input_interface.preprocess(x))
|
||||||
|
interpretation = interpreter(self, processed_input)
|
||||||
|
else:
|
||||||
|
processed_input = [input_interface.preprocess(raw_input[i])
|
||||||
|
for i, input_interface in enumerate(self.input_interfaces)]
|
||||||
|
interpreter = self.interpretation
|
||||||
|
interpretation = interpreter(*processed_input)
|
||||||
|
if len(raw_input) == 1:
|
||||||
|
interpretation = [interpretation]
|
||||||
|
return interpretation
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
|
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
|
||||||
|
@ -18,7 +18,6 @@ from shutil import copyfile
|
|||||||
import requests
|
import requests
|
||||||
import sys
|
import sys
|
||||||
import csv
|
import csv
|
||||||
import copy
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
INITIAL_PORT_VALUE = int(os.getenv(
|
INITIAL_PORT_VALUE = int(os.getenv(
|
||||||
@ -87,7 +86,11 @@ def main():
|
|||||||
|
|
||||||
@app.route("/static/<path:path>")
|
@app.route("/static/<path:path>")
|
||||||
def static(path):
|
def static(path):
|
||||||
return send_file(os.path.join(STATIC_PATH_LIB, path))
|
path = os.path.join(STATIC_PATH_LIB, path)
|
||||||
|
if os.path.exists(path):
|
||||||
|
return send_file()
|
||||||
|
else:
|
||||||
|
abort(404)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/config/", methods=["GET"])
|
@app.route("/config/", methods=["GET"])
|
||||||
@ -150,23 +153,7 @@ def flag():
|
|||||||
@app.route("/api/interpret/", methods=["POST"])
|
@app.route("/api/interpret/", methods=["POST"])
|
||||||
def interpret():
|
def interpret():
|
||||||
raw_input = request.json["data"]
|
raw_input = request.json["data"]
|
||||||
if app.interface.interpretation == "default":
|
interpretation = app.interface.interpret(raw_input)
|
||||||
interpreter = gr.interpretation.default()
|
|
||||||
processed_input = []
|
|
||||||
for i, x in enumerate(raw_input):
|
|
||||||
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
|
|
||||||
interface_type = type(input_interface)
|
|
||||||
if interface_type in gr.interpretation.expected_types:
|
|
||||||
input_interface.type = gr.interpretation.expected_types[interface_type]
|
|
||||||
processed_input.append(input_interface.preprocess(x))
|
|
||||||
interpretation = interpreter(app.interface, processed_input)
|
|
||||||
else:
|
|
||||||
processed_input = [input_interface.preprocess(raw_input[i])
|
|
||||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
|
||||||
interpreter = app.interface.interpretation
|
|
||||||
interpretation = interpreter(*processed_input)
|
|
||||||
if len(raw_input) == 1:
|
|
||||||
interpretation = [interpretation]
|
|
||||||
return jsonify(interpretation)
|
return jsonify(interpretation)
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ import sys
|
|||||||
import weakref
|
import weakref
|
||||||
import analytics
|
import analytics
|
||||||
import os
|
import os
|
||||||
|
import copy
|
||||||
|
|
||||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||||
analytics_url = 'https://api.gradio.app/'
|
analytics_url = 'https://api.gradio.app/'
|
||||||
@ -255,6 +256,26 @@ class Interface:
|
|||||||
processed_output = [output_interface.postprocess(
|
processed_output = [output_interface.postprocess(
|
||||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||||
return processed_output, durations
|
return processed_output, durations
|
||||||
|
|
||||||
|
def interpret(self, raw_input):
|
||||||
|
if self.interpretation == "default":
|
||||||
|
interpreter = gradio.interpretation.default()
|
||||||
|
processed_input = []
|
||||||
|
for i, x in enumerate(raw_input):
|
||||||
|
input_interface = copy.deepcopy(self.input_interfaces[i])
|
||||||
|
interface_type = type(input_interface)
|
||||||
|
if interface_type in gradio.interpretation.expected_types:
|
||||||
|
input_interface.type = gradio.interpretation.expected_types[interface_type]
|
||||||
|
processed_input.append(input_interface.preprocess(x))
|
||||||
|
interpretation = interpreter(self, processed_input)
|
||||||
|
else:
|
||||||
|
processed_input = [input_interface.preprocess(raw_input[i])
|
||||||
|
for i, input_interface in enumerate(self.input_interfaces)]
|
||||||
|
interpreter = self.interpretation
|
||||||
|
interpretation = interpreter(*processed_input)
|
||||||
|
if len(raw_input) == 1:
|
||||||
|
interpretation = [interpretation]
|
||||||
|
return interpretation
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
|
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
|
||||||
|
@ -18,7 +18,6 @@ from shutil import copyfile
|
|||||||
import requests
|
import requests
|
||||||
import sys
|
import sys
|
||||||
import csv
|
import csv
|
||||||
import copy
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
INITIAL_PORT_VALUE = int(os.getenv(
|
INITIAL_PORT_VALUE = int(os.getenv(
|
||||||
@ -87,7 +86,11 @@ def main():
|
|||||||
|
|
||||||
@app.route("/static/<path:path>")
|
@app.route("/static/<path:path>")
|
||||||
def static(path):
|
def static(path):
|
||||||
return send_file(os.path.join(STATIC_PATH_LIB, path))
|
path = os.path.join(STATIC_PATH_LIB, path)
|
||||||
|
if os.path.exists(path):
|
||||||
|
return send_file()
|
||||||
|
else:
|
||||||
|
abort(404)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/config/", methods=["GET"])
|
@app.route("/config/", methods=["GET"])
|
||||||
@ -150,23 +153,7 @@ def flag():
|
|||||||
@app.route("/api/interpret/", methods=["POST"])
|
@app.route("/api/interpret/", methods=["POST"])
|
||||||
def interpret():
|
def interpret():
|
||||||
raw_input = request.json["data"]
|
raw_input = request.json["data"]
|
||||||
if app.interface.interpretation == "default":
|
interpretation = app.interface.interpret(raw_input)
|
||||||
interpreter = gr.interpretation.default()
|
|
||||||
processed_input = []
|
|
||||||
for i, x in enumerate(raw_input):
|
|
||||||
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
|
|
||||||
interface_type = type(input_interface)
|
|
||||||
if interface_type in gr.interpretation.expected_types:
|
|
||||||
input_interface.type = gr.interpretation.expected_types[interface_type]
|
|
||||||
processed_input.append(input_interface.preprocess(x))
|
|
||||||
interpretation = interpreter(app.interface, processed_input)
|
|
||||||
else:
|
|
||||||
processed_input = [input_interface.preprocess(raw_input[i])
|
|
||||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
|
||||||
interpreter = app.interface.interpretation
|
|
||||||
interpretation = interpreter(*processed_input)
|
|
||||||
if len(raw_input) == 1:
|
|
||||||
interpretation = [interpretation]
|
|
||||||
return jsonify(interpretation)
|
return jsonify(interpretation)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user