mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
Merge branch 'dawood/hf-token-class' of https://github.com/gradio-app/gradio into dawood/hf-token-class
This commit is contained in:
commit
155c403432
@ -4,7 +4,7 @@ import re
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from gradio import inputs, outputs
|
from gradio import inputs, outputs, utils
|
||||||
|
|
||||||
|
|
||||||
def get_huggingface_interface(model_name, api_key, alias):
|
def get_huggingface_interface(model_name, api_key, alias):
|
||||||
@ -208,7 +208,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
|||||||
"inputs": inputs.Textbox(label="Input"),
|
"inputs": inputs.Textbox(label="Input"),
|
||||||
"outputs": outputs.HighlightedText(label="Output"),
|
"outputs": outputs.HighlightedText(label="Output"),
|
||||||
"preprocess": lambda x: {"inputs": x},
|
"preprocess": lambda x: {"inputs": x},
|
||||||
"postprocess": lambda r: [(i["word"], i["entity_group"]) for i in r.json()],
|
"postprocess": lambda r: r, # Handled as a special case in query_huggingface_api()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -232,6 +232,10 @@ def get_huggingface_interface(model_name, api_key, alias):
|
|||||||
response.status_code
|
response.status_code
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if p == "token-classification": # Handle as a special case since HF API only returns the named entities and we need the input as well
|
||||||
|
ner_groups = response.json()
|
||||||
|
input_string = params[0]
|
||||||
|
response = utils.format_ner_list(input_string, ner_groups)
|
||||||
output = pipeline["postprocess"](response)
|
output = pipeline["postprocess"](response)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -292,3 +292,20 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
|
|||||||
v.default if v.default is not inspect.Parameter.empty else None
|
v.default if v.default is not inspect.Parameter.empty else None
|
||||||
for v in signature.parameters.values()
|
for v in signature.parameters.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def format_ner_list(input_string: str, ner_groups: Dict[str: str | int]):
|
||||||
|
if len(ner_groups) == 0:
|
||||||
|
return [(input_string, None)]
|
||||||
|
|
||||||
|
output = []
|
||||||
|
prev_end = 0
|
||||||
|
|
||||||
|
for group in ner_groups:
|
||||||
|
entity, start, end = group["entity_group"], group["start"], group["end"]
|
||||||
|
output.append((input_string[prev_end:start], None))
|
||||||
|
output.append((input_string[start:end], entity))
|
||||||
|
prev_end = end
|
||||||
|
|
||||||
|
output.append((input_string[end:], None))
|
||||||
|
return output
|
@ -16,6 +16,7 @@ from gradio.utils import (
|
|||||||
launch_analytics,
|
launch_analytics,
|
||||||
readme_to_html,
|
readme_to_html,
|
||||||
version_check,
|
version_check,
|
||||||
|
format_ner_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||||
@ -116,5 +117,30 @@ class TestIPAddress(unittest.TestCase):
|
|||||||
self.assertEqual(ip, "No internet connection")
|
self.assertEqual(ip, "No internet connection")
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatNERList(unittest.TestCase):
|
||||||
|
def test_format_ner_list_standard(self):
|
||||||
|
string = "Wolfgang lives in Berlin"
|
||||||
|
groups = [{"entity_group": "PER", "start": 0, "end": 8},
|
||||||
|
{"entity_group": "LOC", "start": 18, "end": 24}]
|
||||||
|
result = [('', None),
|
||||||
|
("Wolfgang", "PER"),
|
||||||
|
(" lives in ", None),
|
||||||
|
("Berlin", "LOC"),
|
||||||
|
('', None)]
|
||||||
|
self.assertEqual(
|
||||||
|
format_ner_list(string, groups),
|
||||||
|
result
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_format_ner_list_empty(self):
|
||||||
|
string = "I live in a city"
|
||||||
|
groups = []
|
||||||
|
result = [("I live in a city", None)]
|
||||||
|
self.assertEqual(
|
||||||
|
format_ner_list(string, groups),
|
||||||
|
result
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user