mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
reformat
This commit is contained in:
parent
155c403432
commit
10bb240e23
@ -232,7 +232,9 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
response.status_code
|
||||
)
|
||||
)
|
||||
if p == "token-classification": # Handle as a special case since HF API only returns the named entities and we need the input as well
|
||||
if (
|
||||
p == "token-classification"
|
||||
): # Handle as a special case since HF API only returns the named entities and we need the input as well
|
||||
ner_groups = response.json()
|
||||
input_string = params[0]
|
||||
response = utils.format_ner_list(input_string, ner_groups)
|
||||
|
@ -294,18 +294,18 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
|
||||
]
|
||||
|
||||
|
||||
def format_ner_list(input_string: str, ner_groups: Dict[str: str | int]):
|
||||
def format_ner_list(input_string: str, ner_groups: Dict[str : str | int]):
|
||||
if len(ner_groups) == 0:
|
||||
return [(input_string, None)]
|
||||
|
||||
|
||||
output = []
|
||||
prev_end = 0
|
||||
|
||||
|
||||
for group in ner_groups:
|
||||
entity, start, end = group["entity_group"], group["start"], group["end"]
|
||||
output.append((input_string[prev_end:start], None))
|
||||
output.append((input_string[start:end], entity))
|
||||
prev_end = end
|
||||
|
||||
output.append((input_string[end:], None))
|
||||
return output
|
||||
|
||||
output.append((input_string[end:], None))
|
||||
return output
|
||||
|
@ -4,7 +4,7 @@ if [ -z "$(ls | grep CONTRIBUTING.md)" ]; then
|
||||
exit -1
|
||||
else
|
||||
echo "Formatting backend and tests with black and isort, also checking for standards with flake8"
|
||||
python -m black gradio test
|
||||
python -m isort --profile=black gradio test
|
||||
python -m flake8 --ignore=E731,E501,E722,W503,E126,F401,E203 gradio test
|
||||
python3 -m black gradio test
|
||||
python3 -m isort --profile=black gradio test
|
||||
python3 -m flake8 --ignore=E731,E501,E722,W503,E126,F401,E203 gradio test
|
||||
fi
|
||||
|
@ -10,13 +10,13 @@ import requests
|
||||
from gradio.utils import (
|
||||
colab_check,
|
||||
error_analytics,
|
||||
format_ner_list,
|
||||
get_local_ip_address,
|
||||
ipython_check,
|
||||
json,
|
||||
launch_analytics,
|
||||
readme_to_html,
|
||||
version_check,
|
||||
format_ner_list,
|
||||
)
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
@ -120,26 +120,24 @@ class TestIPAddress(unittest.TestCase):
|
||||
class TestFormatNERList(unittest.TestCase):
|
||||
def test_format_ner_list_standard(self):
|
||||
string = "Wolfgang lives in Berlin"
|
||||
groups = [{"entity_group": "PER", "start": 0, "end": 8},
|
||||
{"entity_group": "LOC", "start": 18, "end": 24}]
|
||||
result = [('', None),
|
||||
("Wolfgang", "PER"),
|
||||
(" lives in ", None),
|
||||
("Berlin", "LOC"),
|
||||
('', None)]
|
||||
self.assertEqual(
|
||||
format_ner_list(string, groups),
|
||||
result
|
||||
)
|
||||
groups = [
|
||||
{"entity_group": "PER", "start": 0, "end": 8},
|
||||
{"entity_group": "LOC", "start": 18, "end": 24},
|
||||
]
|
||||
result = [
|
||||
("", None),
|
||||
("Wolfgang", "PER"),
|
||||
(" lives in ", None),
|
||||
("Berlin", "LOC"),
|
||||
("", None),
|
||||
]
|
||||
self.assertEqual(format_ner_list(string, groups), result)
|
||||
|
||||
def test_format_ner_list_empty(self):
|
||||
string = "I live in a city"
|
||||
groups = []
|
||||
result = [("I live in a city", None)]
|
||||
self.assertEqual(
|
||||
format_ner_list(string, groups),
|
||||
result
|
||||
)
|
||||
self.assertEqual(format_ner_list(string, groups), result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user