mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
fixed interpretation test
This commit is contained in:
parent
19cf1e7153
commit
29585482f7
@ -98,9 +98,6 @@ class TestTextbox(unittest.TestCase):
|
||||
"number",
|
||||
interpretation="default",
|
||||
)
|
||||
print(iface.interpret(
|
||||
["Return the length of the longest word in this sentence"]
|
||||
))
|
||||
scores = iface.interpret(
|
||||
["Return the length of the longest word in this sentence"]
|
||||
)[0]["interpretation"]
|
||||
@ -185,32 +182,17 @@ class TestNumber(unittest.TestCase):
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "number", "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
scores = iface.interpret([2])[0]["interpretation"]
|
||||
self.assertEqual(
|
||||
scores,
|
||||
[
|
||||
[
|
||||
(1.94, -0.23640000000000017),
|
||||
(1.96, -0.15840000000000032),
|
||||
(1.98, -0.07960000000000012),
|
||||
[2, None],
|
||||
(2.02, 0.08040000000000003),
|
||||
(2.04, 0.16159999999999997),
|
||||
(2.06, 0.24359999999999982),
|
||||
]
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
[
|
||||
[
|
||||
[3.7636],
|
||||
[3.8415999999999997],
|
||||
[3.9204],
|
||||
[4.0804],
|
||||
[4.1616],
|
||||
[4.2436],
|
||||
]
|
||||
(1.94, -0.23640000000000017),
|
||||
(1.96, -0.15840000000000032),
|
||||
(1.98, -0.07960000000000012),
|
||||
[2, None],
|
||||
(2.02, 0.08040000000000003),
|
||||
(2.04, 0.16159999999999997),
|
||||
(2.06, 0.24359999999999982),
|
||||
],
|
||||
)
|
||||
|
||||
@ -223,32 +205,17 @@ class TestNumber(unittest.TestCase):
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "number", "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
scores = iface.interpret([2])[0]["interpretation"]
|
||||
self.assertEqual(
|
||||
scores,
|
||||
[
|
||||
[
|
||||
(1.94, -0.23640000000000017),
|
||||
(1.96, -0.15840000000000032),
|
||||
(1.98, -0.07960000000000012),
|
||||
[2, None],
|
||||
(2.02, 0.08040000000000003),
|
||||
(2.04, 0.16159999999999997),
|
||||
(2.06, 0.24359999999999982),
|
||||
]
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
[
|
||||
[
|
||||
[3.7636],
|
||||
[3.8415999999999997],
|
||||
[3.9204],
|
||||
[4.0804],
|
||||
[4.1616],
|
||||
[4.2436],
|
||||
]
|
||||
(1.94, -0.23640000000000017),
|
||||
(1.96, -0.15840000000000032),
|
||||
(1.98, -0.07960000000000012),
|
||||
[2, None],
|
||||
(2.02, 0.08040000000000003),
|
||||
(2.04, 0.16159999999999997),
|
||||
(2.06, 0.24359999999999982),
|
||||
],
|
||||
)
|
||||
|
||||
@ -296,35 +263,18 @@ class TestSlider(unittest.TestCase):
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "slider", "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
scores = iface.interpret([2])[0]["interpretation"]
|
||||
self.assertEqual(
|
||||
scores,
|
||||
[
|
||||
[
|
||||
-4.0,
|
||||
200.08163265306123,
|
||||
812.3265306122449,
|
||||
1832.7346938775513,
|
||||
3261.3061224489797,
|
||||
5098.040816326531,
|
||||
7342.938775510205,
|
||||
9996.0,
|
||||
]
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
[
|
||||
[
|
||||
[0.0],
|
||||
[204.08163265306123],
|
||||
[816.3265306122449],
|
||||
[1836.7346938775513],
|
||||
[3265.3061224489797],
|
||||
[5102.040816326531],
|
||||
[7346.938775510205],
|
||||
[10000.0],
|
||||
]
|
||||
-4.0,
|
||||
200.08163265306123,
|
||||
812.3265306122449,
|
||||
1832.7346938775513,
|
||||
3261.3061224489797,
|
||||
5098.040816326531,
|
||||
7342.938775510205,
|
||||
9996.0,
|
||||
],
|
||||
)
|
||||
|
||||
@ -365,12 +315,10 @@ class TestCheckbox(unittest.TestCase):
|
||||
iface = gr.Interface(
|
||||
lambda x: 1 if x else 0, "checkbox", "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([False])
|
||||
self.assertEqual(scores, [(None, 1.0)])
|
||||
self.assertEqual(alternative_outputs, [[[1]]])
|
||||
scores, alternative_outputs = iface.interpret([True])
|
||||
self.assertEqual(scores, [(-1.0, None)])
|
||||
self.assertEqual(alternative_outputs, [[[0]]])
|
||||
scores = iface.interpret([False])[0]["interpretation"]
|
||||
self.assertEqual(scores, (None, 1.0))
|
||||
scores = iface.interpret([True])[0]["interpretation"]
|
||||
self.assertEqual(scores, (-1.0, None))
|
||||
|
||||
|
||||
class TestCheckboxGroup(unittest.TestCase):
|
||||
@ -465,9 +413,8 @@ class TestRadio(unittest.TestCase):
|
||||
lambda x: 2 * x, radio_input, "number", interpretation="default"
|
||||
)
|
||||
self.assertEqual(iface.process(["c"])[0], [4])
|
||||
scores, alternative_outputs = iface.interpret(["b"])
|
||||
self.assertEqual(scores, [[-2.0, None, 2.0]])
|
||||
self.assertEqual(alternative_outputs, [[[0], [4]]])
|
||||
scores = iface.interpret(["b"])[0]["interpretation"]
|
||||
self.assertEqual(scores, [-2.0, None, 2.0])
|
||||
|
||||
|
||||
class TestImage(unittest.TestCase):
|
||||
@ -597,29 +544,17 @@ class TestImage(unittest.TestCase):
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
scores = iface.interpret([img])[0]["interpretation"]
|
||||
self.assertEqual(
|
||||
scores, deepcopy(media_data.SUM_PIXELS_INTERPRETATION)["scores"]
|
||||
)
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
deepcopy(media_data.SUM_PIXELS_INTERPRETATION)["alternative_outputs"],
|
||||
scores, deepcopy(media_data.SUM_PIXELS_INTERPRETATION)["scores"][0]
|
||||
)
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "label", interpretation="shap"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
scores = iface.interpret([img])[0]["interpretation"]
|
||||
self.assertEqual(
|
||||
len(scores[0]),
|
||||
len(deepcopy(media_data.SUM_PIXELS_SHAP_INTERPRETATION)["scores"][0]),
|
||||
)
|
||||
self.assertEqual(
|
||||
len(alternative_outputs[0]),
|
||||
len(
|
||||
deepcopy(media_data.SUM_PIXELS_SHAP_INTERPRETATION)[
|
||||
"alternative_outputs"
|
||||
][0]
|
||||
),
|
||||
len(deepcopy(media_data.SUM_PIXELS_SHAP_INTERPRETATION)["scores"][0][0]),
|
||||
)
|
||||
image_input = gr.Image(shape=(30, 10))
|
||||
iface = gr.Interface(
|
||||
|
Loading…
Reference in New Issue
Block a user