Added support for TokenClassificationPipeline (#8888)

* fix for 8628

* formatting update

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
cswamy 2024-07-24 21:08:15 +01:00 committed by GitHub
parent f8ccb5e663
commit 70a0c56200
3 changed files with 22 additions and 3 deletions


@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Added support for TokenClassificationPipeline


@@ -48,16 +48,20 @@ def load_from_pipeline(
                 pipelines.text_classification.TextClassificationPipeline,
                 pipelines.text2text_generation.Text2TextGenerationPipeline,
                 pipelines.text2text_generation.TranslationPipeline,
+                pipelines.token_classification.TokenClassificationPipeline,
             ),
         ):
             data = pipeline(*data)
         else:
             data = pipeline(**data)  # type: ignore
-        # special case for object-detection
-        # original input image sent to postprocess function
+        # special case for object-detection and token-classification pipelines
+        # original input image / text sent to postprocess function
         if isinstance(
             pipeline,
-            pipelines.object_detection.ObjectDetectionPipeline,
+            (
+                pipelines.object_detection.ObjectDetectionPipeline,
+                pipelines.token_classification.TokenClassificationPipeline,
+            ),
         ):
             output = pipeline_info["postprocess"](data, params[0])
         else:
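
The hunk above routes TokenClassificationPipeline through the same branch as ObjectDetectionPipeline, so the postprocess function receives both the raw pipeline output and the original input text (params[0]), which it needs to compute highlight spans. A minimal usage sketch of the resulting behaviour, assuming the default token-classification checkpoint can be downloaded on first run:

# Sketch only: the default "token-classification" model is fetched on first
# use; any NER checkpoint name could be passed to pipeline() instead.
import gradio as gr
from transformers import pipeline

ner = pipeline("token-classification")
demo = gr.Interface.from_pipeline(ner)  # maps Textbox input to HighlightedText output
demo.launch()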


@@ -89,6 +89,16 @@ def handle_transformers_pipeline(pipeline: Any) -> Optional[Dict[str, Any]]:
             "preprocess": lambda x: [x],
             "postprocess": lambda r: {i["label"]: i["score"] for i in r},
         }
+    if is_transformers_pipeline_type(pipeline, "TokenClassificationPipeline"):
+        return {
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.HighlightedText(label="Entities", render=False),
+            "preprocess": lambda x: [x],
+            "postprocess": lambda r, text: {
+                "text": text,
+                "entities": r,
+            },
+        }
     if is_transformers_pipeline_type(pipeline, "TextGenerationPipeline"):
         return {
             "inputs": components.Textbox(label="Input", render=False),