got the class probs to be working!

This commit is contained in:
Abubakar Abid 2019-03-05 23:23:04 -08:00
parent e25730f542
commit 8403c86949
6 changed files with 80 additions and 77 deletions

File diff suppressed because one or more lines are too long

View File

@ -117,7 +117,10 @@ class ImageUpload(AbstractInput):
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert(self.image_mode)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
if self.num_channels is None:
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height)
else:
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
return array

View File

@ -102,7 +102,6 @@ class Interface:
while True:
try:
msg = await websocket.recv()
print('>>>>>>>>>msg', msg)
processed_input = self.input_interface.preprocess(msg)
prediction = self.predict(processed_input)
processed_output = self.output_interface.postprocess(prediction)

View File

@ -6,7 +6,7 @@ automatically added to a registry, which allows them to be easily referenced in
from abc import ABC, abstractmethod
import numpy as np
import json
class AbstractOutput(ABC):
"""
@ -60,20 +60,21 @@ class Label(AbstractOutput):
if prediction.size == 1: # if it's single value
response[Label.LABEL_KEY] = np.asscalar(prediction)
elif len(prediction.shape) == 1: # if a 1D
response[Label.LABEL_KEY] = prediction.argmax()
response[Label.LABEL_KEY] = int(prediction.argmax())
if self.show_confidences:
response[Label.CONFIDENCES_KEY] = []
for i in range(self.num_top_classes):
response[Label.CONFIDENCES_KEY].append({
Label.LABEL_KEY: prediction.argmax(),
Label.CONFIDENCE_KEY: prediction.max(),
Label.LABEL_KEY: int(prediction.argmax()),
Label.CONFIDENCE_KEY: float(prediction.max()),
})
prediction[prediction.argmax()] = 0
elif isinstance(prediction, str):
response[Label.LABEL_KEY] = prediction
else:
raise ValueError("Unable to post-process model prediction.")
return response
print(response)
return json.dumps(response)
class Textbox(AbstractOutput):

View File

@ -26,7 +26,7 @@ try {
ws.onmessage = function (event) {
sleep(300).then(() => {
// $(".output_class").text(event.data);
var data = event.data
var data = JSON.parse(event.data)
$(".output_class").text(data["label"])
$(".confidence_intervals").empty()
if ("confidences" in data) {

View File

@ -2,6 +2,7 @@ import numpy as np
import unittest
import os
from gradio import outputs
import json
PACKAGE_NAME = 'gradio'