mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
fixed bugs & added custom embedding fn
This commit is contained in:
parent
0eab6a27a4
commit
efe292cda9
@ -15,7 +15,7 @@ def recognize_digit(image):
|
||||
prediction = model.predict(image).tolist()[0]
|
||||
return {str(i): prediction[i] for i in range(10)}
|
||||
|
||||
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=True)
|
||||
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False)
|
||||
|
||||
io = gr.Interface(
|
||||
recognize_digit,
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 1.7 KiB |
Binary file not shown.
Before Width: | Height: | Size: 3.6 KiB |
Binary file not shown.
Before Width: | Height: | Size: 1.7 KiB |
Binary file not shown.
Before Width: | Height: | Size: 1.6 KiB |
File diff suppressed because one or more lines are too long
@ -123,6 +123,7 @@ class Interface:
|
||||
self.analytics_enabled=analytics_enabled
|
||||
self.save_to = None
|
||||
self.share = None
|
||||
self.embedding_fn = embedding_fn
|
||||
|
||||
data = {'fn': fn,
|
||||
'inputs': inputs,
|
||||
@ -243,14 +244,12 @@ class Interface:
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output, durations
|
||||
|
||||
def embedding_fn(self, raw_input):
|
||||
if self.interpretation == "default":
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)]
|
||||
def embed(self, processed_input):
|
||||
if self.embedding_fn == "default":
|
||||
embedding = np.concatenate([input_interface.embed(processed_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)])
|
||||
else:
|
||||
raise NotImplementedError("Only default embedding is currently supported")
|
||||
embedding = self.embedding_fn(processed_input)
|
||||
return embedding
|
||||
|
||||
def interpret(self, raw_input):
|
||||
|
@ -124,12 +124,14 @@ def predict():
|
||||
@app.route("/api/score_similarity/", methods=["POST"])
|
||||
def score_similarity():
|
||||
raw_input = request.json["data"]
|
||||
input_embedding = app.interface.embedding_fn(raw_input)
|
||||
preprocessed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
input_embedding = app.interface.embed(preprocessed_input)
|
||||
scores = list()
|
||||
for example in app.interface.examples:
|
||||
preprocessed_example = [iface.preprocess_example(example)
|
||||
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
|
||||
for iface, example in zip(app.interface.input_interfaces, example)]
|
||||
example_embedding = app.interface.embedding_fn(preprocessed_example)
|
||||
example_embedding = app.interface.embed(preprocessed_example)
|
||||
scores.append(calculate_similarity(input_embedding, example_embedding))
|
||||
return jsonify({"data": scores})
|
||||
|
||||
|
@ -47,7 +47,7 @@ var io_master_template = {
|
||||
console.log(output.data)
|
||||
this.target.find(".loading").addClass("invisible");
|
||||
this.target.find(".output_interfaces").css("opacity", 1);
|
||||
let html = "<th>DIFFS</th>"
|
||||
let html = "<th>SIMILARITY</th>"
|
||||
this.target.find(".examples > table > thead > tr").append(html);
|
||||
for (let i = 0; i < output["data"].length; i++) {
|
||||
let html = "<td>" + output["data"][i] + "</td>"
|
||||
|
@ -318,7 +318,7 @@ function gradio(config, fn, target, example_file_path) {
|
||||
}
|
||||
})
|
||||
target.find(".interpret").click(function() {
|
||||
target.find(".interpretation_explained").removeClass("invisible");
|
||||
// target.find(".interpretation_explained").removeClass("invisible");
|
||||
if (io_master.last_output) {
|
||||
io_master.score_similarity();
|
||||
// io_master.interpret(); // TODO(UNDO)
|
||||
|
Loading…
Reference in New Issue
Block a user