Merge branch 'dawood/hf-token-class' of https://github.com/gradio-app/gradio into dawood/hf-token-class

This commit is contained in:
dawoodkhan82 2022-03-21 14:01:24 -04:00
commit 155c403432
3 changed files with 49 additions and 2 deletions

View File

@ -4,7 +4,7 @@ import re
import requests import requests
from gradio import inputs, outputs from gradio import inputs, outputs, utils
def get_huggingface_interface(model_name, api_key, alias): def get_huggingface_interface(model_name, api_key, alias):
@ -208,7 +208,7 @@ def get_huggingface_interface(model_name, api_key, alias):
"inputs": inputs.Textbox(label="Input"), "inputs": inputs.Textbox(label="Input"),
"outputs": outputs.HighlightedText(label="Output"), "outputs": outputs.HighlightedText(label="Output"),
"preprocess": lambda x: {"inputs": x}, "preprocess": lambda x: {"inputs": x},
"postprocess": lambda r: [(i["word"], i["entity_group"]) for i in r.json()], "postprocess": lambda r: r, # Handled as a special case in query_huggingface_api()
}, },
} }
@ -232,6 +232,10 @@ def get_huggingface_interface(model_name, api_key, alias):
response.status_code response.status_code
) )
) )
if p == "token-classification": # Handle as a special case since HF API only returns the named entities and we need the input as well
ner_groups = response.json()
input_string = params[0]
response = utils.format_ner_list(input_string, ner_groups)
output = pipeline["postprocess"](response) output = pipeline["postprocess"](response)
return output return output

View File

@ -292,3 +292,20 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
v.default if v.default is not inspect.Parameter.empty else None v.default if v.default is not inspect.Parameter.empty else None
for v in signature.parameters.values() for v in signature.parameters.values()
] ]
def format_ner_list(
    input_string: str,
    # Quoted so the annotation is never evaluated when the function is defined.
    ner_groups: "List[Dict[str, Union[str, int]]]",
) -> "List[Tuple[str, Optional[str]]]":
    """Split ``input_string`` into labelled and unlabelled segments.

    Each item of ``ner_groups`` is read for three keys: ``"entity_group"``
    (the label), ``"start"`` and ``"end"`` (slice bounds into
    ``input_string``). The groups are processed in the order given.

    Parameters:
        input_string: the text the entity offsets refer to.
        ner_groups: entity dicts as described above; may be empty.

    Returns:
        A list of ``(text, label)`` tuples. Text between entities, before the
        first entity and after the last one is emitted with label ``None``;
        those filler segments are emitted even when they are empty strings.
        With no groups, the whole input is returned as one unlabelled segment.
    """
    if not ner_groups:
        return [(input_string, None)]

    output = []
    prev_end = 0
    for group in ner_groups:
        entity, start, end = group["entity_group"], group["start"], group["end"]
        # Unlabelled text between the previous entity and this one.
        output.append((input_string[prev_end:start], None))
        output.append((input_string[start:end], entity))
        prev_end = end
    # Whatever follows the last entity (possibly an empty string).
    output.append((input_string[prev_end:], None))
    return output

View File

@ -16,6 +16,7 @@ from gradio.utils import (
launch_analytics, launch_analytics,
readme_to_html, readme_to_html,
version_check, version_check,
format_ner_list,
) )
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -116,5 +117,30 @@ class TestIPAddress(unittest.TestCase):
self.assertEqual(ip, "No internet connection") self.assertEqual(ip, "No internet connection")
class TestFormatNERList(unittest.TestCase):
    """Unit tests for ``format_ner_list``."""

    def test_format_ner_list_standard(self):
        """Entities at both ends of the text yield empty unlabelled edges."""
        text = "Wolfgang lives in Berlin"
        person = {"entity_group": "PER", "start": 0, "end": 8}
        place = {"entity_group": "LOC", "start": 18, "end": 24}
        expected = [
            ("", None),
            ("Wolfgang", "PER"),
            (" lives in ", None),
            ("Berlin", "LOC"),
            ("", None),
        ]
        self.assertEqual(format_ner_list(text, [person, place]), expected)

    def test_format_ner_list_empty(self):
        """With no entities the whole text comes back as one unlabelled span."""
        text = "I live in a city"
        expected = [("I live in a city", None)]
        self.assertEqual(format_ner_list(text, []), expected)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()