Added support for TokenClassificationPipeline (#8888)
* fix for 8628
* formatting update
* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
commit 70a0c56200 (parent f8ccb5e663)
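For context, a minimal usage sketch of what this change enables. The checkpoint name is only illustrative; any token-classification / NER model on the Hub should work the same way once this commit is in:

import gradio as gr
from transformers import pipeline

# Build a transformers token-classification (NER) pipeline; the model name is
# an example, not something prescribed by this commit.
ner = pipeline("token-classification", model="dslim/bert-base-NER")

# With this change, Interface.from_pipeline recognizes the pipeline and wires a
# Textbox input to a HighlightedText output automatically.
demo = gr.Interface.from_pipeline(ner)

if __name__ == "__main__":
    demo.launch()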
.changeset/shaggy-carrots-drop.md (new file, +5)

@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Added support for TokenClassificationPipeline
@@ -48,16 +48,20 @@ def load_from_pipeline(
                 pipelines.text_classification.TextClassificationPipeline,
                 pipelines.text2text_generation.Text2TextGenerationPipeline,
                 pipelines.text2text_generation.TranslationPipeline,
+                pipelines.token_classification.TokenClassificationPipeline,
             ),
         ):
             data = pipeline(*data)
         else:
             data = pipeline(**data)  # type: ignore
-        # special case for object-detection
-        # original input image sent to postprocess function
+        # special case for object-detection and token-classification pipelines
+        # original input image / text sent to postprocess function
         if isinstance(
             pipeline,
-            pipelines.object_detection.ObjectDetectionPipeline,
+            (
+                pipelines.object_detection.ObjectDetectionPipeline,
+                pipelines.token_classification.TokenClassificationPipeline,
+            ),
         ):
             output = pipeline_info["postprocess"](data, params[0])
         else:
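The second isinstance branch above forwards the original input (params[0]) to postprocess because a token-classification result, like an object-detection result, only carries offsets into the input rather than the input itself; HighlightedText needs the full text alongside those character offsets. A minimal sketch of the resulting payload, with made-up text and entities:

import gradio as gr

# Dict form accepted by HighlightedText: the full text plus character-offset
# entities. Without forwarding the original input, the offsets alone could not
# be rendered. (Sample values below are invented.)
value = {
    "text": "Ada lives in London",
    "entities": [
        {"entity": "B-PER", "start": 0, "end": 3},
        {"entity": "B-LOC", "start": 13, "end": 19},
    ],
}
gr.HighlightedText(value=value, label="Entities")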
@@ -89,6 +89,16 @@ def handle_transformers_pipeline(pipeline: Any) -> Optional[Dict[str, Any]]:
             "preprocess": lambda x: [x],
             "postprocess": lambda r: {i["label"]: i["score"] for i in r},
         }
+    if is_transformers_pipeline_type(pipeline, "TokenClassificationPipeline"):
+        return {
+            "inputs": components.Textbox(label="Input", render=False),
+            "outputs": components.HighlightedText(label="Entities", render=False),
+            "preprocess": lambda x: [x],
+            "postprocess": lambda r, text: {
+                "text": text,
+                "entities": r,
+            },
+        }
     if is_transformers_pipeline_type(pipeline, "TextGenerationPipeline"):
         return {
             "inputs": components.Textbox(label="Input", render=False),
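To see how the new handler's postprocess composes with a raw pipeline result, here is a small sketch; the entity dicts mirror the keys transformers documents for token-classification output, but every value is invented:

# Shape of a raw TokenClassificationPipeline result (values are made up).
raw = [
    {"entity": "B-PER", "score": 0.998, "index": 1, "word": "Ada", "start": 0, "end": 3},
    {"entity": "B-LOC", "score": 0.995, "index": 5, "word": "London", "start": 13, "end": 19},
]
text = "Ada lives in London"

# The postprocess added above simply pairs the entities with the original text,
# which is the dict format HighlightedText consumes.
postprocess = lambda r, text: {"text": text, "entities": r}
print(postprocess(raw, text))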