Mirror of https://github.com/gradio-app/gradio.git (synced 2025-01-06 10:25:17 +08:00)
text and image interpretation first pass
commit d73b4ec1c8
parent f4a251989b
gradio/explain.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from gradio.inputs import Textbox
from gradio.inputs import Image

from skimage.segmentation import slic
from skimage.util import img_as_float

import numpy as np


def tokenize_text(text):
    # Build leave-one-out variants of the input: the full token list first,
    # then one copy per token with that token removed.
    leave_one_out_tokens = []
    tokens = text.split()
    leave_one_out_tokens.append(tokens)
    for idx, _ in enumerate(tokens):
        new_token_array = tokens.copy()
        del new_token_array[idx]
        leave_one_out_tokens.append(new_token_array)
    return leave_one_out_tokens


def tokenize_image(image):
    # Segment the (downsampled) image into superpixels with SLIC, then build
    # one copy per segment with that segment blanked out.
    img = img_as_float(image[::2, ::2])
    segments_slic = slic(img, n_segments=20, compactness=10, sigma=1)
    leave_one_out_tokens = []
    for (i, segVal) in enumerate(np.unique(segments_slic)):
        mask = np.copy(img)
        mask[segments_slic == segVal] = 255
        leave_one_out_tokens.append(mask)
    return leave_one_out_tokens


def score(outputs):
    # Placeholder scoring step for this first pass: just print the outputs.
    print(outputs)


def simple_explanation(interface, input_interfaces,
                       output_interfaces, input):
    if isinstance(input_interfaces[0], Textbox):
        leave_one_out_tokens = tokenize_text(input[0])
        outputs = []
        for input_text in leave_one_out_tokens:
            input_text = " ".join(input_text)
            print("Input Text: ", input_text)
            output = interface.process(input_text)
            outputs.extend(output)
            print("Output: ", output)
        score(outputs)
    elif isinstance(input_interfaces[0], Image):
        leave_one_out_tokens = tokenize_image(input[0])
        outputs = []
        for input_image in leave_one_out_tokens:
            # Image variants are arrays, not token lists, so they are passed
            # to process() directly instead of being joined as strings.
            output = interface.process(input_image)
            outputs.extend(output)
            print("Output: ", output)
        score(outputs)
    else:
        print("Not valid input type")
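For a concrete sense of what the leave-one-out pass above produces, here is a minimal sketch, assuming this commit's gradio/explain.py is importable; the example sentence is arbitrary:

from gradio.explain import tokenize_text

for variant in tokenize_text("this movie was great"):
    print(variant)
# ['this', 'movie', 'was', 'great']   # full sentence, used as the baseline
# ['movie', 'was', 'great']           # 'this' removed
# ['this', 'was', 'great']            # 'movie' removed
# ['this', 'movie', 'great']          # 'was' removed
# ['this', 'movie', 'was']            # 'great' removed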
gradio/interface.py
@@ -7,9 +7,13 @@ import tempfile
 import webbrowser

 from gradio.inputs import InputComponent
+from gradio.inputs import Image
+from gradio.inputs import Textbox
 from gradio.outputs import OutputComponent
-from gradio import networking, strings, utils
+from gradio import networking, strings, utils, processing_utils
 from distutils.version import StrictVersion
+from skimage.segmentation import slic
+from skimage.util import img_as_float
 import pkg_resources
 import requests
 import random
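The slic and img_as_float imports added here feed the tokenize_image method introduced further down. A minimal sketch of what SLIC returns, assuming scikit-image and numpy are installed; the random array is only a stand-in for a real image:

import numpy as np
from skimage.segmentation import slic
from skimage.util import img_as_float

img = img_as_float(np.random.rand(64, 64, 3))          # stand-in RGB image
segments = slic(img, n_segments=20, compactness=10, sigma=1)

print(segments.shape)       # (64, 64): one integer segment id per pixel
print(np.unique(segments))  # the superpixel ids; one leave-one-out mask is built per id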
@@ -20,6 +24,7 @@ import sys
 import weakref
 import analytics
 import os
+import numpy as np

 PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
 analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
@@ -46,7 +51,8 @@ class Interface:

     def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
                  live=False, show_input=True, show_output=True,
-                 capture_session=False, title=None, description=None,
+                 capture_session=False, explain_by="default", title=None,
+                 description=None,
                  thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
                  allow_screenshot=True, allow_flagging=True,
                  flagging_dir="flagged", analytics_enabled=True):
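The new explain_by argument defaults to "default", which selects the built-in leave-one-out explanation added below; it can also be a user-supplied callable, which explain() invokes directly on the raw input. A hedged usage sketch, where classify_sentiment is a placeholder function rather than anything in this commit:

import gradio as gr

def classify_sentiment(text):
    # Placeholder model: any fn returning {label: confidence} works with a "label" output.
    positive = min(1.0, 0.2 + 0.2 * text.lower().count("great"))
    return {"positive": positive, "negative": 1.0 - positive}

iface = gr.Interface(fn=classify_sentiment,
                     inputs="text",
                     outputs="label",
                     explain_by="default")  # use the built-in leave-one-out explanation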
@@ -110,6 +116,7 @@ class Interface:
         self.show_output = show_output
         self.flag_hash = random.getrandbits(32)
         self.capture_session = capture_session
+        self.explain_by = explain_by
         self.session = None
         self.server_name = server_name
         self.title = title
@@ -190,7 +197,6 @@ class Interface:
                iface[1]["label"] = ret_name
            except ValueError:
                pass

        return config

    def process(self, raw_input, predict_fn=None):
@@ -414,6 +420,84 @@ class Interface:

        return httpd, path_to_local_server, share_url

    def tokenize_text(self, text):
        # Same leave-one-out construction as in gradio/explain.py: the full
        # token list first, then one copy per token with that token removed.
        leave_one_out_tokens = []
        tokens = text.split()
        leave_one_out_tokens.append(tokens)
        for idx, _ in enumerate(tokens):
            new_token_array = tokens.copy()
            del new_token_array[idx]
            leave_one_out_tokens.append(new_token_array)
        return leave_one_out_tokens

    def tokenize_image(self, image):
        # Segment the preprocessed image into superpixels with SLIC and build
        # one copy per segment with that segment blanked out (set to 255).
        image = self.input_interfaces[0].preprocess(image)
        segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
        leave_one_out_tokens = []
        for (i, segVal) in enumerate(np.unique(segments_slic)):
            mask = np.copy(image)
            mask[segments_slic == segVal] = 255
            leave_one_out_tokens.append(mask)
        return leave_one_out_tokens

    def score_text(self, leave_one_out_tokens, text):
        # The first entry of leave_one_out_tokens is the full sentence and sets
        # the baseline label and confidence; each later entry has one token
        # removed, and that token is credited with the resulting confidence drop.
        original_label = ""
        original_confidence = 0
        tokens = text.split()
        outputs = {}
        for idx, input_text in enumerate(leave_one_out_tokens):
            input_text = " ".join(input_text)
            output = self.process(input_text)
            if idx == 0:
                original_label = output[0][0]['confidences'][0]['label']
                original_confidence = output[0][0]['confidences'][0]['confidence']
            else:
                if output[0][0]['confidences'][0]['label'] == original_label:
                    outputs[tokens[idx - 1]] = (
                        original_confidence
                        - output[0][0]['confidences'][0]['confidence'])
        return outputs

    def score_image(self, leave_one_out_tokens, image):
        # Baseline prediction on the full image.
        original_output = self.process(image)
        original_label = original_output[0][0]['confidences'][0]['label']
        original_confidence = original_output[0][0]['confidences'][0]['confidence']
        # Score map initialised to 255 everywhere, with the same shape as the
        # preprocessed image; each masked region is filled in below.
        output_scores = np.full(
            np.shape(self.input_interfaces[0].preprocess(image[0])), 255)
        for input_image in leave_one_out_tokens:
            input_image_base64 = processing_utils.encode_array_to_base64(
                input_image)
            input_image_arr = []
            input_image_arr.append(input_image_base64)
            output = self.process(input_image_arr)
            np.set_printoptions(threshold=sys.maxsize)
            if output[0][0]['confidences'][0]['label'] == original_label:
                # Credit the blanked-out segment with the confidence drop.
                input_image[input_image == 255] = (
                    original_confidence
                    - output[0][0]['confidences'][0]['confidence']) * 100
            mask = (output_scores == 255)
            output_scores[mask] = input_image[mask]
        return output_scores

    def simple_explanation(self, input):
        # Dispatch on the type of the first input component.
        if isinstance(self.input_interfaces[0], Textbox):
            leave_one_out_tokens = self.tokenize_text(input[0])
            return self.score_text(leave_one_out_tokens, input[0])
        elif isinstance(self.input_interfaces[0], Image):
            leave_one_out_tokens = self.tokenize_image(input[0])
            return self.score_image(leave_one_out_tokens, input)
        else:
            print("Not valid input type")

    def explain(self, input):
        # explain_by="default" selects the built-in leave-one-out explanation;
        # any other value is treated as a user-supplied callable.
        if self.explain_by == "default":
            self.simple_explanation(input)
        else:
            self.explain_by(input)

def reset_all():
    for io in Interface.get_instances():
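As a rough illustration of what score_text computes, here is a self-contained sketch of the same leave-one-out confidence-drop bookkeeping, with a stub classifier standing in for Interface.process; the classifier and its numbers are invented for the example:

def stub_classifier(text):
    # Toy stand-in for a model: confidence in "positive" grows with positive words.
    positive_words = {"great", "good", "fun"}
    hits = sum(word in positive_words for word in text.split())
    return "positive", min(1.0, 0.5 + 0.2 * hits)

def leave_one_out_scores(text):
    tokens = text.split()
    base_label, base_conf = stub_classifier(text)
    scores = {}
    for idx, token in enumerate(tokens):
        reduced = " ".join(tokens[:idx] + tokens[idx + 1:])
        label, conf = stub_classifier(reduced)
        if label == base_label:
            # As in score_text: the removed token is credited with the drop
            # in confidence for the original label.
            scores[token] = base_conf - conf
    return scores

print(leave_one_out_scores("the movie was great fun"))
# roughly {'the': 0.0, 'movie': 0.0, 'was': 0.0, 'great': 0.2, 'fun': 0.2}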