mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
fixed type checking for interpretation
This commit is contained in:
parent
389ee78cb3
commit
1af481f238
@ -13,7 +13,7 @@ def quantify_difference_in_label(interface, original_output, perturbed_output):
|
|||||||
post_original_output = output_component.postprocess(original_output[0])
|
post_original_output = output_component.postprocess(original_output[0])
|
||||||
post_perturbed_output = output_component.postprocess(perturbed_output[0])
|
post_perturbed_output = output_component.postprocess(perturbed_output[0])
|
||||||
|
|
||||||
if type(output_component) == Label:
|
if isinstance(output_component, Label):
|
||||||
original_label = post_original_output["label"]
|
original_label = post_original_output["label"]
|
||||||
perturbed_label = post_perturbed_output["label"]
|
perturbed_label = post_perturbed_output["label"]
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ def quantify_difference_in_label(interface, original_output, perturbed_output):
|
|||||||
score = diff(original_label, perturbed_label)
|
score = diff(original_label, perturbed_label)
|
||||||
return score
|
return score
|
||||||
|
|
||||||
elif type(output_component) == Textbox:
|
elif isinstance(output_component, Textbox):
|
||||||
score = diff(post_original_output, post_perturbed_output)
|
score = diff(post_original_output, post_perturbed_output)
|
||||||
return score
|
return score
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user