mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
new method
This commit is contained in:
parent
ef0caef720
commit
6f9815608f
@ -232,71 +232,6 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
#TODO(abidlabs): clean this up
|
||||
elif self.path == "/api/auto/rotation":
|
||||
from gradio import validation_data, preprocessing_utils
|
||||
import numpy as np
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
|
||||
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
|
||||
for deg in range(-180, 180+45, 45):
|
||||
img = img_orig.rotate(deg)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
'message': f'rotation by {deg} degrees'}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps({}).encode())
|
||||
|
||||
elif self.path == "/api/auto/lighting":
|
||||
from gradio import validation_data, preprocessing_utils
|
||||
import numpy as np
|
||||
from PIL import ImageEnhance
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
enhancer = ImageEnhance.Brightness(img_orig)
|
||||
|
||||
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
|
||||
for i in range(9):
|
||||
img = enhancer.enhance(i/4)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
'message': f'brighting adjustment by a factor of {i}'}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps({}).encode())
|
||||
|
||||
else:
|
||||
self.send_error(404, 'Path not found: %s' % self.path)
|
||||
|
||||
|
@ -20,14 +20,6 @@ class AbstractOutput(ABC):
|
||||
When this is subclassed, it is automatically added to the registry
|
||||
"""
|
||||
|
||||
def __init__(self, postprocessing_fn=None):
|
||||
"""
|
||||
:param postprocessing_fn: an optional postprocessing function that overrides the default
|
||||
"""
|
||||
if postprocessing_fn is not None:
|
||||
self.postprocess = postprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
def get_js_context(self):
|
||||
"""
|
||||
:return: a dictionary with context variables for the javascript file associated with the context
|
||||
@ -40,6 +32,12 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
"""
|
||||
return prediction
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self):
|
||||
"""
|
||||
@ -47,13 +45,6 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
All interfaces should define a default postprocessing method
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rebuild_flagged(self, inp):
|
||||
"""
|
||||
@ -61,63 +52,36 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
import operator
|
||||
class Label(AbstractOutput):
|
||||
LABEL_KEY = 'label'
|
||||
CONFIDENCES_KEY = 'confidences'
|
||||
CONFIDENCE_KEY = 'confidence'
|
||||
|
||||
def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True, label_names=None,
|
||||
max_label_length=None, max_label_words=None, word_delimiter=" "):
|
||||
def __init__(self, num_top_classes=3, show_confidences=True):
|
||||
self.num_top_classes = num_top_classes
|
||||
self.show_confidences = show_confidences
|
||||
self.label_names = label_names
|
||||
self.max_label_length = max_label_length
|
||||
self.max_label_words = max_label_words
|
||||
self.word_delimiter = word_delimiter
|
||||
super().__init__(postprocessing_fn=postprocessing_fn)
|
||||
super().__init__()
|
||||
|
||||
def get_name(self):
|
||||
return 'label'
|
||||
|
||||
def get_label_name(self, label):
|
||||
if self.label_names is None:
|
||||
name = label
|
||||
elif self.label_names == 'imagenet1000': # TODO:(abidlabs) better way to handle this
|
||||
name = imagenet_class_labels.NAMES1000[label]
|
||||
else: # if list or dictionary
|
||||
name = self.label_names[label]
|
||||
if self.max_label_words is not None:
|
||||
name = name.split(self.word_delimiter)[:self.max_label_words]
|
||||
name = self.word_delimiter.join(name)
|
||||
if self.max_label_length is not None:
|
||||
name = name[:self.max_label_length]
|
||||
return name
|
||||
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
response = dict()
|
||||
# TODO(abidlabs): check if list, if so convert to numpy array
|
||||
if isinstance(prediction, np.ndarray):
|
||||
prediction = prediction.squeeze()
|
||||
if prediction.size == 1: # if it's single value
|
||||
response[Label.LABEL_KEY] = self.get_label_name(np.asscalar(prediction))
|
||||
elif len(prediction.shape) == 1: # if a 1D
|
||||
response[Label.LABEL_KEY] = self.get_label_name(int(prediction.argmax()))
|
||||
if self.show_confidences:
|
||||
response[Label.CONFIDENCES_KEY] = []
|
||||
for i in range(self.num_top_classes):
|
||||
response[Label.CONFIDENCES_KEY].append({
|
||||
Label.LABEL_KEY: self.get_label_name(int(prediction.argmax())),
|
||||
Label.CONFIDENCE_KEY: float(prediction.max()),
|
||||
})
|
||||
prediction[prediction.argmax()] = 0
|
||||
elif isinstance(prediction, str):
|
||||
response[Label.LABEL_KEY] = prediction
|
||||
if isinstance(prediction, str):
|
||||
return {"label": str}
|
||||
elif isinstance(prediction, dict):
|
||||
sorted_pred = sorted(
|
||||
prediction.items(),
|
||||
key=operator.itemgetter(1),
|
||||
reverse=True
|
||||
)
|
||||
return {
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{
|
||||
"label": pred[0],
|
||||
"confidence" : pred[1]
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unable to post-process model prediction.")
|
||||
return json.dumps(response)
|
||||
raise ValueError("Function output should be string or dict")
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
|
@ -8,10 +8,9 @@ const label_output = {
|
||||
`,
|
||||
init: function() {},
|
||||
output: function(data) {
|
||||
data = JSON.parse(data)
|
||||
this.target.find(".output_class").html(data["label"])
|
||||
this.target.find(".confidence_intervals > div").empty()
|
||||
if (data.confidences) {
|
||||
if ("confidences" in data) {
|
||||
for (var i = 0; i < data.confidences.length; i++)
|
||||
{
|
||||
let c = data.confidences[i]
|
||||
|
34
demo/digit_classifier.py
Normal file
34
demo/digit_classifier.py
Normal file
@ -0,0 +1,34 @@
|
||||
import tensorflow as tf
|
||||
import gradio
|
||||
from tensorflow.keras.layers import *
|
||||
import gradio as gr
|
||||
|
||||
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
|
||||
x_train, x_test = x_train.reshape(-1,784) / 255.0, x_test.reshape(-1,784) / 255.0
|
||||
|
||||
def get_trained_model(n):
|
||||
model = tf.keras.models.Sequential()
|
||||
model.add(Reshape((28, 28, 1), input_shape=(784,)))
|
||||
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu'))
|
||||
model.add(Conv2D(64, (3, 3), activation='relu'))
|
||||
model.add(MaxPooling2D(pool_size=(2, 2)))
|
||||
model.add(Dropout(0.25))
|
||||
model.add(Flatten())
|
||||
model.add(Dense(128, activation='relu'))
|
||||
model.add(Dropout(0.5))
|
||||
model.add(Dense(10, activation='softmax'))
|
||||
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
|
||||
model.fit(x_train[:n], y_train[:n], epochs=2)
|
||||
print(model.evaluate(x_test, y_test))
|
||||
return model
|
||||
|
||||
model = get_trained_model(n=50000)
|
||||
|
||||
def recognize_digit(image):
|
||||
return {
|
||||
"5": 0.6,
|
||||
"4": 0.12,
|
||||
"6": 0.1
|
||||
}
|
||||
|
||||
gr.Interface(recognize_digit, "sketchpad", "label").launch()
|
@ -232,71 +232,6 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
#TODO(abidlabs): clean this up
|
||||
elif self.path == "/api/auto/rotation":
|
||||
from gradio import validation_data, preprocessing_utils
|
||||
import numpy as np
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
|
||||
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
|
||||
for deg in range(-180, 180+45, 45):
|
||||
img = img_orig.rotate(deg)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
'message': f'rotation by {deg} degrees'}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps({}).encode())
|
||||
|
||||
elif self.path == "/api/auto/lighting":
|
||||
from gradio import validation_data, preprocessing_utils
|
||||
import numpy as np
|
||||
from PIL import ImageEnhance
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
enhancer = ImageEnhance.Brightness(img_orig)
|
||||
|
||||
flag_dir = os.path.join(directory_to_serve, FLAGGING_DIRECTORY)
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
|
||||
for i in range(9):
|
||||
img = enhancer.enhance(i/4)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
'message': f'brighting adjustment by a factor of {i}'}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps({}).encode())
|
||||
|
||||
else:
|
||||
self.send_error(404, 'Path not found: %s' % self.path)
|
||||
|
||||
|
@ -20,14 +20,6 @@ class AbstractOutput(ABC):
|
||||
When this is subclassed, it is automatically added to the registry
|
||||
"""
|
||||
|
||||
def __init__(self, postprocessing_fn=None):
|
||||
"""
|
||||
:param postprocessing_fn: an optional postprocessing function that overrides the default
|
||||
"""
|
||||
if postprocessing_fn is not None:
|
||||
self.postprocess = postprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
def get_js_context(self):
|
||||
"""
|
||||
:return: a dictionary with context variables for the javascript file associated with the context
|
||||
@ -40,6 +32,12 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
"""
|
||||
return prediction
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self):
|
||||
"""
|
||||
@ -47,13 +45,6 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
All interfaces should define a default postprocessing method
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rebuild_flagged(self, inp):
|
||||
"""
|
||||
@ -61,63 +52,36 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
import operator
|
||||
class Label(AbstractOutput):
|
||||
LABEL_KEY = 'label'
|
||||
CONFIDENCES_KEY = 'confidences'
|
||||
CONFIDENCE_KEY = 'confidence'
|
||||
|
||||
def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True, label_names=None,
|
||||
max_label_length=None, max_label_words=None, word_delimiter=" "):
|
||||
def __init__(self, num_top_classes=3, show_confidences=True):
|
||||
self.num_top_classes = num_top_classes
|
||||
self.show_confidences = show_confidences
|
||||
self.label_names = label_names
|
||||
self.max_label_length = max_label_length
|
||||
self.max_label_words = max_label_words
|
||||
self.word_delimiter = word_delimiter
|
||||
super().__init__(postprocessing_fn=postprocessing_fn)
|
||||
super().__init__()
|
||||
|
||||
def get_name(self):
|
||||
return 'label'
|
||||
|
||||
def get_label_name(self, label):
|
||||
if self.label_names is None:
|
||||
name = label
|
||||
elif self.label_names == 'imagenet1000': # TODO:(abidlabs) better way to handle this
|
||||
name = imagenet_class_labels.NAMES1000[label]
|
||||
else: # if list or dictionary
|
||||
name = self.label_names[label]
|
||||
if self.max_label_words is not None:
|
||||
name = name.split(self.word_delimiter)[:self.max_label_words]
|
||||
name = self.word_delimiter.join(name)
|
||||
if self.max_label_length is not None:
|
||||
name = name[:self.max_label_length]
|
||||
return name
|
||||
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
response = dict()
|
||||
# TODO(abidlabs): check if list, if so convert to numpy array
|
||||
if isinstance(prediction, np.ndarray):
|
||||
prediction = prediction.squeeze()
|
||||
if prediction.size == 1: # if it's single value
|
||||
response[Label.LABEL_KEY] = self.get_label_name(np.asscalar(prediction))
|
||||
elif len(prediction.shape) == 1: # if a 1D
|
||||
response[Label.LABEL_KEY] = self.get_label_name(int(prediction.argmax()))
|
||||
if self.show_confidences:
|
||||
response[Label.CONFIDENCES_KEY] = []
|
||||
for i in range(self.num_top_classes):
|
||||
response[Label.CONFIDENCES_KEY].append({
|
||||
Label.LABEL_KEY: self.get_label_name(int(prediction.argmax())),
|
||||
Label.CONFIDENCE_KEY: float(prediction.max()),
|
||||
})
|
||||
prediction[prediction.argmax()] = 0
|
||||
elif isinstance(prediction, str):
|
||||
response[Label.LABEL_KEY] = prediction
|
||||
if isinstance(prediction, str):
|
||||
return {"label": str}
|
||||
elif isinstance(prediction, dict):
|
||||
sorted_pred = sorted(
|
||||
prediction.items(),
|
||||
key=operator.itemgetter(1),
|
||||
reverse=True
|
||||
)
|
||||
return {
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{
|
||||
"label": pred[0],
|
||||
"confidence" : pred[1]
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unable to post-process model prediction.")
|
||||
return json.dumps(response)
|
||||
raise ValueError("Function output should be string or dict")
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
|
@ -8,10 +8,9 @@ const label_output = {
|
||||
`,
|
||||
init: function() {},
|
||||
output: function(data) {
|
||||
data = JSON.parse(data)
|
||||
this.target.find(".output_class").html(data["label"])
|
||||
this.target.find(".confidence_intervals > div").empty()
|
||||
if (data.confidences) {
|
||||
if ("confidences" in data) {
|
||||
for (var i = 0; i < data.confidences.length; i++)
|
||||
{
|
||||
let c = data.confidences[i]
|
||||
|
Loading…
Reference in New Issue
Block a user