text and image interpretation first pass

dawoodkhan82 2020-09-14 13:14:52 -04:00
parent f4a251989b
commit d73b4ec1c8
2 changed files with 146 additions and 3 deletions

gradio/explain.py (new file, 59 lines added)

@@ -0,0 +1,59 @@
from gradio.inputs import Textbox
from gradio.inputs import Image
from gradio import processing_utils
from skimage.color import rgb2gray
from skimage.filters import sobel
from skimage.segmentation import slic
from skimage.util import img_as_float
from skimage import io
import numpy as np


def tokenize_text(text):
    # Leave-one-out variants: index 0 is the full token list, each later
    # entry drops exactly one token.
    leave_one_out_tokens = []
    tokens = text.split()
    leave_one_out_tokens.append(tokens)
    for idx, _ in enumerate(tokens):
        new_token_array = tokens.copy()
        del new_token_array[idx]
        leave_one_out_tokens.append(new_token_array)
    return leave_one_out_tokens


def tokenize_image(image):
    # Segment a downsampled copy of the image into SLIC superpixels and
    # build one variant per segment with that segment masked out.
    img = img_as_float(image[::2, ::2])
    segments_slic = slic(img, n_segments=20, compactness=10, sigma=1)
    leave_one_out_tokens = []
    for (i, segVal) in enumerate(np.unique(segments_slic)):
        mask = np.copy(img)
        mask[segments_slic == segVal] = 255
        leave_one_out_tokens.append(mask)
    return leave_one_out_tokens


def score(outputs):
    print(outputs)


def simple_explanation(interface, input_interfaces,
                       output_interfaces, input):
    if isinstance(input_interfaces[0], Textbox):
        leave_one_out_tokens = tokenize_text(input[0])
        outputs = []
        for input_text in leave_one_out_tokens:
            input_text = " ".join(input_text)
            print("Input Text: ", input_text)
            output = interface.process(input_text)
            outputs.extend(output)
            print("Output: ", output)
        score(outputs)
    elif isinstance(input_interfaces[0], Image):
        leave_one_out_tokens = tokenize_image(input[0])
        outputs = []
        for input_image in leave_one_out_tokens:
            # Masked images are numpy arrays; encode them to base64 before
            # sending them through the interface (the same path score_image
            # in interface.py uses), rather than joining them as text.
            input_image = processing_utils.encode_array_to_base64(input_image)
            output = interface.process([input_image])
            outputs.extend(output)
            print("Output: ", output)
        score(outputs)
    else:
        print("Not a valid input type")

gradio/interface.py

@@ -7,9 +7,13 @@ import tempfile
import webbrowser
from gradio.inputs import InputComponent
from gradio.inputs import Image
from gradio.inputs import Textbox
from gradio.outputs import OutputComponent
from gradio import networking, strings, utils
from gradio import networking, strings, utils, processing_utils
from distutils.version import StrictVersion
from skimage.segmentation import slic
from skimage.util import img_as_float
import pkg_resources
import requests
import random
@@ -20,6 +24,7 @@ import sys
import weakref
import analytics
import os
import numpy as np
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
@@ -46,7 +51,8 @@ class Interface:
    def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
                 live=False, show_input=True, show_output=True,
                 capture_session=False, title=None, description=None,
                 capture_session=False, explain_by="default", title=None,
                 description=None,
                 thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
                 allow_screenshot=True, allow_flagging=True,
                 flagging_dir="flagged", analytics_enabled=True):
@@ -110,6 +116,7 @@ class Interface:
        self.show_output = show_output
        self.flag_hash = random.getrandbits(32)
        self.capture_session = capture_session
        self.explain_by = explain_by
        self.session = None
        self.server_name = server_name
        self.title = title
@@ -190,7 +197,6 @@ class Interface:
                    iface[1]["label"] = ret_name
        except ValueError:
            pass
        return config

    def process(self, raw_input, predict_fn=None):
@@ -414,6 +420,84 @@ class Interface:
        return httpd, path_to_local_server, share_url

    def tokenize_text(self, text):
        # Leave-one-out variants: index 0 is the full token list, each later
        # entry drops exactly one token.
        leave_one_out_tokens = []
        tokens = text.split()
        leave_one_out_tokens.append(tokens)
        for idx, _ in enumerate(tokens):
            new_token_array = tokens.copy()
            del new_token_array[idx]
            leave_one_out_tokens.append(new_token_array)
        return leave_one_out_tokens

    def tokenize_image(self, image):
        # Segment the preprocessed image into SLIC superpixels and build one
        # variant per segment with that segment masked out.
        image = self.input_interfaces[0].preprocess(image)
        segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
        leave_one_out_tokens = []
        for (i, segVal) in enumerate(np.unique(segments_slic)):
            mask = np.copy(image)
            mask[segments_slic == segVal] = 255
            leave_one_out_tokens.append(mask)
        return leave_one_out_tokens

    def score_text(self, leave_one_out_tokens, text):
        # Run the original text (index 0) and every leave-one-out variant
        # through the interface; a token's score is how much the original
        # label's confidence drops when that token is removed.
        original_label = ""
        original_confidence = 0
        tokens = text.split()
        outputs = {}
        for idx, input_text in enumerate(leave_one_out_tokens):
            input_text = " ".join(input_text)
            output = self.process(input_text)
            if idx == 0:
                original_label = output[0][0]['confidences'][0][
                    'label']
                original_confidence = output[0][0]['confidences'][0][
                    'confidence']
            else:
                if output[0][0]['confidences'][0]['label'] == original_label:
                    outputs[tokens[idx-1]] = original_confidence - output[0][0][
                        'confidences'][0]['confidence']
        return outputs

    def score_image(self, leave_one_out_tokens, image):
        # Baseline prediction on the full image.
        original_output = self.process(image)
        original_label = original_output[0][0]['confidences'][0][
            'label']
        original_confidence = original_output[0][0]['confidences'][0][
            'confidence']
        # Start from an all-255 score map and fill each masked region with
        # the confidence drop observed when that region is occluded.
        output_scores = np.full(np.shape(self.input_interfaces[0].preprocess(
            image[0])), 255)
        for input_image in leave_one_out_tokens:
            input_image_base64 = processing_utils.encode_array_to_base64(
                input_image)
            input_image_arr = []
            input_image_arr.append(input_image_base64)
            output = self.process(input_image_arr)
            np.set_printoptions(threshold=sys.maxsize)
            if output[0][0]['confidences'][0]['label'] == original_label:
                input_image[input_image == 255] = (original_confidence -
                                                   output[0][0][
                                                       'confidences'][0][
                                                       'confidence']) * 100
                mask = (output_scores == 255)
                output_scores[mask] = input_image[mask]
        return output_scores

    def simple_explanation(self, input):
        # Dispatch on the first input component: leave-one-out over tokens
        # for text, over SLIC segments for images.
        if isinstance(self.input_interfaces[0], Textbox):
            leave_one_out_tokens = self.tokenize_text(input[0])
            return self.score_text(leave_one_out_tokens, input[0])
        elif isinstance(self.input_interfaces[0], Image):
            leave_one_out_tokens = self.tokenize_image(input[0])
            return self.score_image(leave_one_out_tokens, input)
        else:
            print("Not a valid input type")

    def explain(self, input):
        # explain_by is either the string "default" (use simple_explanation)
        # or a user-supplied callable that receives the raw input.
        if self.explain_by == "default":
            self.simple_explanation(input)
        else:
            self.explain_by(input)

def reset_all():
    for io in Interface.get_instances():
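
A matching sketch for the interface-side hook, under the same assumptions (placeholder classifier, "text"/"label" shortcuts). explain_by="default" routes explain() through simple_explanation(); for a Textbox input, score_text() returns a dict mapping each removed token to the drop in the top label's confidence, though explain() itself does not yet pass that value back in this first pass.

# Illustrative only: the classifier is a placeholder, not part of this commit.
import gradio as gr

def classify_sentiment(text):
    positive = {"good", "great", "love"}
    words = text.split()
    p = sum(w.lower() in positive for w in words) / max(len(words), 1)
    return {"positive": p, "negative": 1 - p}

# explain_by accepts "default" (the built-in leave-one-out pass) or a
# callable, which explain() would invoke as explain_by(input).
iface = gr.Interface(fn=classify_sentiment, inputs="text", outputs="label",
                     explain_by="default")

iface.explain(["I love this great movie"])

# simple_explanation() returns the scores directly:
# {token: drop in the top label's confidence when that token is removed}
scores = iface.simple_explanation(["I love this great movie"])
print(scores)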