2
0
mirror of https://github.com/gradio-app/gradio.git synced 2025-04-24 13:01:18 +08:00

Added support for TokenClassificationPipeline ()

* fix for 8628

* formatting update

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
cswamy 2024-07-24 21:08:15 +01:00 committed by GitHub
parent f8ccb5e663
commit 70a0c56200
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 3 deletions

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Added support for TokenClassificationPipeline

@ -48,16 +48,20 @@ def load_from_pipeline(
pipelines.text_classification.TextClassificationPipeline,
pipelines.text2text_generation.Text2TextGenerationPipeline,
pipelines.text2text_generation.TranslationPipeline,
pipelines.token_classification.TokenClassificationPipeline,
),
):
data = pipeline(*data)
else:
data = pipeline(**data) # type: ignore
# special case for object-detection
# original input image sent to postprocess function
# special case for object-detection and token-classification pipelines
# original input image / text sent to postprocess function
if isinstance(
pipeline,
pipelines.object_detection.ObjectDetectionPipeline,
(
pipelines.object_detection.ObjectDetectionPipeline,
pipelines.token_classification.TokenClassificationPipeline,
),
):
output = pipeline_info["postprocess"](data, params[0])
else:

@ -89,6 +89,16 @@ def handle_transformers_pipeline(pipeline: Any) -> Optional[Dict[str, Any]]:
"preprocess": lambda x: [x],
"postprocess": lambda r: {i["label"]: i["score"] for i in r},
}
if is_transformers_pipeline_type(pipeline, "TokenClassificationPipeline"):
return {
"inputs": components.Textbox(label="Input", render=False),
"outputs": components.HighlightedText(label="Entities", render=False),
"preprocess": lambda x: [x],
"postprocess": lambda r, text: {
"text": text,
"entities": r,
},
}
if is_transformers_pipeline_type(pipeline, "TextGenerationPipeline"):
return {
"inputs": components.Textbox(label="Input", render=False),