mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
move interpretation code to separate file
This commit is contained in:
parent
7e264abaad
commit
c8b6dec050
@ -40,7 +40,7 @@ label = gr.outputs.Label(num_top_classes=3)
|
||||
|
||||
gr.Interface(image_classifier, imagein, label,
|
||||
capture_session=True,
|
||||
explain_by="default",
|
||||
interpret_by="default",
|
||||
examples=[
|
||||
["images/cheetah1.jpg"],
|
||||
["images/lion.jpg"]
|
||||
|
@ -9,7 +9,7 @@ def sentiment_analysis(text):
|
||||
del scores["compound"]
|
||||
return scores
|
||||
|
||||
io = gr.Interface(sentiment_analysis, "textbox", "label", explain_by="default")
|
||||
io = gr.Interface(sentiment_analysis, "textbox", "label", interpret_by="default")
|
||||
|
||||
io.test_launch()
|
||||
io.launch()
|
@ -7,15 +7,9 @@ import tempfile
|
||||
import webbrowser
|
||||
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.inputs import Image
|
||||
from gradio.inputs import Textbox
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio import networking, strings, utils, processing_utils
|
||||
from gradio import networking, strings, utils
|
||||
from distutils.version import StrictVersion
|
||||
from skimage.segmentation import slic
|
||||
from skimage.util import img_as_float
|
||||
from gradio import processing_utils
|
||||
import PIL
|
||||
import pkg_resources
|
||||
import requests
|
||||
import random
|
||||
@ -26,7 +20,7 @@ import sys
|
||||
import weakref
|
||||
import analytics
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
@ -53,7 +47,8 @@ class Interface:
|
||||
|
||||
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
||||
live=False, show_input=True, show_output=True,
|
||||
capture_session=False, explain_by=None, title=None, description=None,
|
||||
capture_session=False, interpret_by=None, title=None,
|
||||
description=None,
|
||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||
allow_screenshot=True, allow_flagging=True,
|
||||
flagging_dir="flagged", analytics_enabled=True):
|
||||
@ -117,7 +112,7 @@ class Interface:
|
||||
self.show_output = show_output
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
self.capture_session = capture_session
|
||||
self.explain_by = explain_by
|
||||
self.interpret_by = interpret_by
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
self.title = title
|
||||
@ -186,7 +181,7 @@ class Interface:
|
||||
"thumbnail": self.thumbnail,
|
||||
"allow_screenshot": self.allow_screenshot,
|
||||
"allow_flagging": self.allow_flagging,
|
||||
"allow_interpretation": self.explain_by is not None
|
||||
"allow_interpretation": self.interpret_by is not None
|
||||
}
|
||||
try:
|
||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||
@ -422,95 +417,6 @@ class Interface:
|
||||
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
||||
def tokenize_text(self, text):
|
||||
leave_one_out_tokens = []
|
||||
tokens = text.split()
|
||||
for idx, _ in enumerate(tokens):
|
||||
new_token_array = tokens.copy()
|
||||
del new_token_array[idx]
|
||||
leave_one_out_tokens.append(new_token_array)
|
||||
return tokens, leave_one_out_tokens
|
||||
|
||||
def tokenize_image(self, image):
|
||||
image = np.array(processing_utils.decode_base64_to_image(image))
|
||||
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||
leave_one_out_tokens = []
|
||||
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||
mask = segments_slic == segVal
|
||||
white_screen = np.copy(image)
|
||||
white_screen[segments_slic == segVal] = 255
|
||||
leave_one_out_tokens.append((mask, white_screen))
|
||||
return leave_one_out_tokens
|
||||
|
||||
def score_text(self, tokens, leave_one_out_tokens, text):
|
||||
original_label = ""
|
||||
original_confidence = 0
|
||||
tokens = text.split()
|
||||
|
||||
input_text = " ".join(tokens)
|
||||
original_output = self.process([input_text])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
scores = []
|
||||
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||
input_text = " ".join(input_text)
|
||||
raw_output = self.process([input_text])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
scores.append(original_confidence - output[original_label])
|
||||
|
||||
scores_by_char = []
|
||||
for idx, token in enumerate(tokens):
|
||||
if idx != 0:
|
||||
scores_by_char.append((" ", 0))
|
||||
for char in token:
|
||||
scores_by_char.append((char, scores[idx]))
|
||||
return scores_by_char
|
||||
|
||||
def score_image(self, leave_one_out_tokens, image):
|
||||
original_output = self.process([image])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
image_interface = self.input_interfaces[0]
|
||||
shape = processing_utils.decode_base64_to_image(image).size
|
||||
output_scores = np.full((shape[1], shape[0]), 0.0)
|
||||
|
||||
for mask, input_image in leave_one_out_tokens:
|
||||
input_image_base64 = processing_utils.encode_array_to_base64(
|
||||
input_image)
|
||||
raw_output = self.process([input_image_base64])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
score = original_confidence - output[original_label]
|
||||
output_scores += score * mask
|
||||
max_val = np.max(np.abs(output_scores))
|
||||
if max_val > 0:
|
||||
output_scores = output_scores / max_val
|
||||
return output_scores.tolist()
|
||||
|
||||
def simple_explanation(self, x):
|
||||
if isinstance(self.input_interfaces[0], Textbox):
|
||||
tokens, leave_one_out_tokens = self.tokenize_text(x[0])
|
||||
return [self.score_text(tokens, leave_one_out_tokens, x[0])]
|
||||
elif isinstance(self.input_interfaces[0], Image):
|
||||
leave_one_out_tokens = self.tokenize_image(x[0])
|
||||
return [self.score_image(leave_one_out_tokens, x[0])]
|
||||
else:
|
||||
print("Not valid input type")
|
||||
|
||||
def explain(self, x):
|
||||
if self.explain_by == "default":
|
||||
return self.simple_explanation(x)
|
||||
else:
|
||||
preprocessed_x = [input_interface(x_i) for x_i, input_interface in zip(x, self.input_interfaces)]
|
||||
return self.explain_by(*preprocessed_x)
|
||||
|
||||
def reset_all():
|
||||
for io in Interface.get_instances():
|
||||
io.close()
|
||||
|
99
gradio/interpretation.py
Normal file
99
gradio/interpretation.py
Normal file
@ -0,0 +1,99 @@
|
||||
from gradio.inputs import Image
|
||||
from gradio.inputs import Textbox
|
||||
from gradio import processing_utils
|
||||
from skimage.segmentation import slic
|
||||
import numpy as np
|
||||
|
||||
|
||||
def tokenize_text(text):
|
||||
leave_one_out_tokens = []
|
||||
tokens = text.split()
|
||||
for idx, _ in enumerate(tokens):
|
||||
new_token_array = tokens.copy()
|
||||
del new_token_array[idx]
|
||||
leave_one_out_tokens.append(new_token_array)
|
||||
return tokens, leave_one_out_tokens
|
||||
|
||||
|
||||
def tokenize_image(image):
|
||||
image = np.array(processing_utils.decode_base64_to_image(image))
|
||||
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||
leave_one_out_tokens = []
|
||||
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||
mask = segments_slic == segVal
|
||||
white_screen = np.copy(image)
|
||||
white_screen[segments_slic == segVal] = 255
|
||||
leave_one_out_tokens.append((mask, white_screen))
|
||||
return leave_one_out_tokens
|
||||
|
||||
|
||||
def score_text(interface, leave_one_out_tokens, text):
|
||||
tokens = text.split()
|
||||
|
||||
input_text = " ".join(tokens)
|
||||
original_output = interface.process([input_text])
|
||||
output = {result["label"]: result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
scores = []
|
||||
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||
input_text = " ".join(input_text)
|
||||
raw_output = interface.process([input_text])
|
||||
output = {result["label"]: result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
scores.append(original_confidence - output[original_label])
|
||||
|
||||
scores_by_char = []
|
||||
for idx, token in enumerate(tokens):
|
||||
if idx != 0:
|
||||
scores_by_char.append((" ", 0))
|
||||
for char in token:
|
||||
scores_by_char.append((char, scores[idx]))
|
||||
return scores_by_char
|
||||
|
||||
|
||||
def score_image(interface, leave_one_out_tokens, image):
|
||||
original_output = interface.process([image])
|
||||
output = {result["label"]: result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
shape = processing_utils.decode_base64_to_image(image).size
|
||||
output_scores = np.full((shape[1], shape[0]), 0.0)
|
||||
|
||||
for mask, input_image in leave_one_out_tokens:
|
||||
input_image_base64 = processing_utils.encode_array_to_base64(
|
||||
input_image)
|
||||
raw_output = interface.process([input_image_base64])
|
||||
output = {result["label"]: result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
score = original_confidence - output[original_label]
|
||||
output_scores += score * mask
|
||||
max_val = np.max(np.abs(output_scores))
|
||||
if max_val > 0:
|
||||
output_scores = output_scores / max_val
|
||||
return output_scores.tolist()
|
||||
|
||||
|
||||
def simple_interpretation(interface, x):
|
||||
if isinstance(interface.input_interfaces[0], Textbox):
|
||||
tokens, leave_one_out_tokens = tokenize_text(interface,
|
||||
x[0])
|
||||
return [score_text(interface, tokens, leave_one_out_tokens, x[0])]
|
||||
elif isinstance(interface.input_interfaces[0], Image):
|
||||
leave_one_out_tokens = tokenize_image(x[0])
|
||||
return [score_image(interface, leave_one_out_tokens, x[0])]
|
||||
else:
|
||||
print("Not valid input type")
|
||||
|
||||
|
||||
def interpret(interface, x):
|
||||
if interface.interpret_by == "default":
|
||||
return simple_interpretation(interface, x)
|
||||
else:
|
||||
preprocessed_x = [input_interface(x_i) for x_i, input_interface in
|
||||
zip(x, interface.input_interfaces)]
|
||||
return interface.interpret_by(*preprocessed_x)
|
@ -12,6 +12,7 @@ from gradio import inputs, outputs
|
||||
import time
|
||||
import json
|
||||
from gradio.tunneling import create_tunnel
|
||||
from gradio.interpretation import interpret
|
||||
import urllib.request
|
||||
from shutil import copyfile
|
||||
import requests
|
||||
@ -199,7 +200,7 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
interpretation = interface.explain(msg["data"])
|
||||
interpretation = interpret(interface, msg["data"])
|
||||
self.wfile.write(json.dumps(interpretation).encode())
|
||||
else:
|
||||
self.send_error(404, 'Path not found: {}'.format(self.path))
|
||||
|
Loading…
Reference in New Issue
Block a user