2022-01-21 21:44:12 +08:00
|
|
|
import os
|
2022-04-05 18:08:53 +08:00
|
|
|
from copy import deepcopy
|
2022-01-21 21:44:12 +08:00
|
|
|
|
|
|
|
import numpy as np
|
2022-08-11 06:29:14 +08:00
|
|
|
import pytest
|
|
|
|
import pytest_asyncio
|
2022-01-21 21:44:12 +08:00
|
|
|
|
2020-10-05 20:43:42 +08:00
|
|
|
import gradio.interpretation
|
2022-04-20 02:27:32 +08:00
|
|
|
from gradio import Interface, media_data
|
2022-02-10 03:52:26 +08:00
|
|
|
from gradio.processing_utils import decode_base64_to_image
|
2020-10-05 20:43:42 +08:00
|
|
|
|
2021-11-10 02:30:59 +08:00
|
|
|
# Opt this test module out of Gradio analytics.
# NOTE(review): this runs after the gradio imports above. Presumably gradio reads
# the variable when an Interface is constructed, not at import time — confirm.
os.environ.update(GRADIO_ANALYTICS_ENABLED="False")
|
2020-10-05 20:43:42 +08:00
|
|
|
|
2022-01-21 21:44:12 +08:00
|
|
|
|
2022-08-11 06:29:14 +08:00
|
|
|
class TestDefault:
    """Tests for ``interpretation="default"`` on a text ``Interface``."""

    @pytest.mark.asyncio
    async def test_default_text(self):
        """Only the word that determines the prediction should score above 0.

        The model returns the length of the longest word. In
        "quickest brown fox" that value (8) comes solely from the first word,
        "quickest"; the last word, "fox", does not set the maximum.
        Presumably the default method scores each word by how much the output
        changes when that word is perturbed, which explains the two asserts.
        """

        def max_word_len(text):
            # Generator rather than a throwaway list inside max().
            return max(len(word) for word in text.split(" "))

        text_interface = Interface(
            max_word_len, "textbox", "label", interpretation="default"
        )
        result = await text_interface.interpret(["quickest brown fox"])
        # Indexed below as a sequence of (token, score) pairs.
        interpretation = result[0]["interpretation"]
        assert interpretation[0][1] > 0  # The first word has a >0 score.
        assert interpretation[-1][1] == 0  # The last word has a 0 score.
|
2022-01-21 21:44:12 +08:00
|
|
|
|
2021-10-27 06:36:12 +08:00
|
|
|
|
2022-08-11 06:29:14 +08:00
|
|
|
class TestShapley:
    """Tests for ``interpretation="shapley"`` on a text ``Interface``."""

    @pytest.mark.asyncio
    async def test_shapley_text(self):
        """The word that sets the model's output should get a positive score.

        The model returns the length of the longest word. For
        "quickest brown fox" that value comes from the first word,
        "quickest".
        """

        def max_word_len(text):
            # Generator rather than a throwaway list inside max().
            return max(len(word) for word in text.split(" "))

        text_interface = Interface(
            max_word_len, "textbox", "label", interpretation="shapley"
        )
        result = await text_interface.interpret(["quickest brown fox"])
        # First entry of the interpretation, indexed below as (token, score).
        first_entry = result[0]["interpretation"][0]
        assert first_entry[1] > 0  # The first word has a >0 score.
|
2022-01-21 21:44:12 +08:00
|
|
|
|
2020-10-05 20:43:42 +08:00
|
|
|
|
2022-08-11 06:29:14 +08:00
|
|
|
class TestCustom:
    """Tests for passing a user-supplied callable as ``interpretation``."""

    @pytest.mark.asyncio
    async def test_custom_text(self):
        """A custom text interpreter's per-character scores come back as given."""

        def max_word_len(text):
            # Generator rather than a throwaway list inside max().
            return max(len(word) for word in text.split(" "))

        def custom(text):
            # Assign every character a score of 1.
            return [(char, 1) for char in text]

        text_interface = Interface(
            max_word_len, "textbox", "label", interpretation=custom
        )
        result = (await text_interface.interpret(["quickest brown fox"]))[0][
            "interpretation"
        ][0]
        assert result[1] == 1  # The first letter has a score of 1.

    @pytest.mark.asyncio
    async def test_custom_img(self):
        """A custom image interpreter's output is returned unchanged."""

        def max_pixel_value(img):
            return img.max()

        def custom(img):
            # Echo the image passed to the interpreter back as nested lists.
            return img.tolist()

        img_interface = Interface(
            max_pixel_value, "image", "label", interpretation=custom
        )
        result = (
            await img_interface.interpret([deepcopy(media_data.BASE64_IMAGE)])
        )[0]["interpretation"]
        # Presumably the image input is preprocessed into an RGB numpy array
        # before reaching `custom`, so echoing it back should reproduce the
        # decoded RGB pixels — confirm against the Image component.
        expected_result = np.asarray(
            decode_base64_to_image(deepcopy(media_data.BASE64_IMAGE)).convert("RGB")
        ).tolist()
        assert result == expected_result
|
2022-01-21 21:44:12 +08:00
|
|
|
|
2020-10-05 20:43:42 +08:00
|
|
|
|
2022-08-11 06:29:14 +08:00
|
|
|
class TestHelperMethods:
    """Unit tests for the standalone helpers in ``gradio.interpretation``."""

    def test_diff(self):
        """``diff`` on numeric-like and plain string inputs."""
        # 13 vs "2" gives 11, so the string is presumably coerced to a number.
        diff = gradio.interpretation.diff(13, "2")
        assert diff == 11
        # Differing non-numeric strings give 1; identical strings give 0.
        diff = gradio.interpretation.diff("cat", "dog")
        assert diff == 1
        diff = gradio.interpretation.diff("cat", "cat")
        assert diff == 0

    def test_quantify_difference_with_number(self):
        """Numeric outputs: the result is original minus perturbed (4 - 6)."""
        iface = Interface(lambda text: text, ["textbox"], ["number"])
        diff = gradio.interpretation.quantify_difference_in_label(iface, [4], [6])
        assert diff == -2

    def test_quantify_difference_with_label(self):
        """Label outputs holding numeric strings are differenced numerically."""
        iface = Interface(lambda text: len(text), ["textbox"], ["label"])
        diff = gradio.interpretation.quantify_difference_in_label(iface, ["3"], ["10"])
        assert diff == -7
        diff = gradio.interpretation.quantify_difference_in_label(iface, ["0"], ["100"])
        assert diff == -100

    def test_quantify_difference_with_confidences(self):
        """Confidence dicts: the expected values equal the drop in "cat" confidence.

        "cat" is the top label of ``output_1``: 0.9 - 0.6 = 0.3 and
        0.9 - 0.1 = 0.8.
        """
        iface = Interface(lambda text: len(text), ["textbox"], ["label"])
        output_1 = {"cat": 0.9, "dog": 0.1}
        output_2 = {"cat": 0.6, "dog": 0.4}
        output_3 = {"cat": 0.1, "dog": 0.6}
        diff = gradio.interpretation.quantify_difference_in_label(
            iface, [output_1], [output_2]
        )
        # approx wraps the *expected* value so tolerance is relative to it.
        assert diff == pytest.approx(0.3)
        diff = gradio.interpretation.quantify_difference_in_label(
            iface, [output_1], [output_3]
        )
        assert diff == pytest.approx(0.8)

    def test_get_regression_value(self):
        """Confidence-dict outputs, including a NaN confidence in the perturbed run."""
        iface = Interface(lambda text: text, ["textbox"], ["label"])
        output_1 = {"cat": 0.9, "dog": 0.1}
        output_2 = {"cat": float("nan"), "dog": 0.4}
        output_3 = {"cat": 0.1, "dog": 0.6}
        diff = gradio.interpretation.get_regression_or_classification_value(
            iface, [output_1], [output_2]
        )
        # A NaN confidence for the original top label yields 0.
        assert diff == 0
        diff = gradio.interpretation.get_regression_or_classification_value(
            iface, [output_1], [output_3]
        )
        # NOTE(review): 0.1 matches output_3's "cat" confidence — confirm intent.
        assert diff == pytest.approx(0.1)

    def test_get_classification_value(self):
        """Plain string labels: 1 when they differ, 0 when they match."""
        iface = Interface(lambda text: text, ["textbox"], ["label"])
        diff = gradio.interpretation.get_regression_or_classification_value(
            iface, ["cat"], ["test"]
        )
        assert diff == 1
        diff = gradio.interpretation.get_regression_or_classification_value(
            iface, ["test"], ["test"]
        )
        assert diff == 0
|