Add support for object-detection models in gr.load() (#7716)

* add support for object-detection models in gr.load

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-03-15 13:09:44 -07:00 committed by GitHub
parent 28342a2040
commit 188b86b766
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 0 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Add support for object-detection models in `gr.load()`

View File

@ -343,6 +343,11 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwarg
label="Predictions", type="array", headers=["prediction"]
)
fn = external_utils.tabular_wrapper
# example model: microsoft/table-transformer-detection
elif p == "object-detection":
inputs = components.Image(type="filepath", label="Input Image")
outputs = components.AnnotatedImage(label="Annotations")
fn = external_utils.object_detection_wrapper(client)
else:
raise ValueError(f"Unsupported pipeline type: {p}")

View File

@ -169,6 +169,26 @@ def token_classification_wrapper(client: InferenceClient):
return token_classification_inner
def object_detection_wrapper(client: InferenceClient):
def object_detection_inner(input: str):
annotations = client.object_detection(input)
formatted_annotations = [
(
(
a["box"]["xmin"],
a["box"]["ymin"],
a["box"]["xmax"],
a["box"]["ymax"],
),
a["label"],
)
for a in annotations
]
return (input, formatted_annotations)
return object_detection_inner
def chatbot_preprocess(text, state):
if not state:
return text, [], []