mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
got the class probs to be working!
This commit is contained in:
parent
e25730f542
commit
8403c86949
File diff suppressed because one or more lines are too long
@ -117,7 +117,10 @@ class ImageUpload(AbstractInput):
|
||||
image_encoded = content.split(',')[1]
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert(self.image_mode)
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
if self.num_channels is None:
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height)
|
||||
else:
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
|
||||
|
||||
|
@ -102,7 +102,6 @@ class Interface:
|
||||
while True:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
print('>>>>>>>>>msg', msg)
|
||||
processed_input = self.input_interface.preprocess(msg)
|
||||
prediction = self.predict(processed_input)
|
||||
processed_output = self.output_interface.postprocess(prediction)
|
||||
|
@ -6,7 +6,7 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
|
||||
import json
|
||||
|
||||
class AbstractOutput(ABC):
|
||||
"""
|
||||
@ -60,20 +60,21 @@ class Label(AbstractOutput):
|
||||
if prediction.size == 1: # if it's single value
|
||||
response[Label.LABEL_KEY] = np.asscalar(prediction)
|
||||
elif len(prediction.shape) == 1: # if a 1D
|
||||
response[Label.LABEL_KEY] = prediction.argmax()
|
||||
response[Label.LABEL_KEY] = int(prediction.argmax())
|
||||
if self.show_confidences:
|
||||
response[Label.CONFIDENCES_KEY] = []
|
||||
for i in range(self.num_top_classes):
|
||||
response[Label.CONFIDENCES_KEY].append({
|
||||
Label.LABEL_KEY: prediction.argmax(),
|
||||
Label.CONFIDENCE_KEY: prediction.max(),
|
||||
Label.LABEL_KEY: int(prediction.argmax()),
|
||||
Label.CONFIDENCE_KEY: float(prediction.max()),
|
||||
})
|
||||
prediction[prediction.argmax()] = 0
|
||||
elif isinstance(prediction, str):
|
||||
response[Label.LABEL_KEY] = prediction
|
||||
else:
|
||||
raise ValueError("Unable to post-process model prediction.")
|
||||
return response
|
||||
print(response)
|
||||
return json.dumps(response)
|
||||
|
||||
|
||||
class Textbox(AbstractOutput):
|
||||
|
@ -26,7 +26,7 @@ try {
|
||||
ws.onmessage = function (event) {
|
||||
sleep(300).then(() => {
|
||||
// $(".output_class").text(event.data);
|
||||
var data = event.data
|
||||
var data = JSON.parse(event.data)
|
||||
$(".output_class").text(data["label"])
|
||||
$(".confidence_intervals").empty()
|
||||
if ("confidences" in data) {
|
||||
|
@ -2,6 +2,7 @@ import numpy as np
|
||||
import unittest
|
||||
import os
|
||||
from gradio import outputs
|
||||
import json
|
||||
|
||||
PACKAGE_NAME = 'gradio'
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user