mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-24 13:01:18 +08:00
Added support for TokenClassificationPipeline (#8888)
* fix for 8628 * formatting update * add changeset --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
f8ccb5e663
commit
70a0c56200
5
.changeset/shaggy-carrots-drop.md
Normal file
5
.changeset/shaggy-carrots-drop.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Added support for TokenClassificationPipeline
|
@ -48,16 +48,20 @@ def load_from_pipeline(
|
||||
pipelines.text_classification.TextClassificationPipeline,
|
||||
pipelines.text2text_generation.Text2TextGenerationPipeline,
|
||||
pipelines.text2text_generation.TranslationPipeline,
|
||||
pipelines.token_classification.TokenClassificationPipeline,
|
||||
),
|
||||
):
|
||||
data = pipeline(*data)
|
||||
else:
|
||||
data = pipeline(**data) # type: ignore
|
||||
# special case for object-detection
|
||||
# original input image sent to postprocess function
|
||||
# special case for object-detection and token-classification pipelines
|
||||
# original input image / text sent to postprocess function
|
||||
if isinstance(
|
||||
pipeline,
|
||||
pipelines.object_detection.ObjectDetectionPipeline,
|
||||
(
|
||||
pipelines.object_detection.ObjectDetectionPipeline,
|
||||
pipelines.token_classification.TokenClassificationPipeline,
|
||||
),
|
||||
):
|
||||
output = pipeline_info["postprocess"](data, params[0])
|
||||
else:
|
||||
|
@ -89,6 +89,16 @@ def handle_transformers_pipeline(pipeline: Any) -> Optional[Dict[str, Any]]:
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r: {i["label"]: i["score"] for i in r},
|
||||
}
|
||||
if is_transformers_pipeline_type(pipeline, "TokenClassificationPipeline"):
|
||||
return {
|
||||
"inputs": components.Textbox(label="Input", render=False),
|
||||
"outputs": components.HighlightedText(label="Entities", render=False),
|
||||
"preprocess": lambda x: [x],
|
||||
"postprocess": lambda r, text: {
|
||||
"text": text,
|
||||
"entities": r,
|
||||
},
|
||||
}
|
||||
if is_transformers_pipeline_type(pipeline, "TextGenerationPipeline"):
|
||||
return {
|
||||
"inputs": components.Textbox(label="Input", render=False),
|
||||
|
Loading…
x
Reference in New Issue
Block a user