mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
597337dcb8
* added playground with 12 demos * change name to recipes, restyle navbar * add explanatory text to page * fix demo mapping * categorize demos, clean up design * styling * category naming and emojis * refactor and add text demos * add view code button * remove opening slash in embed * styling * add image demos * adding plot demos * remove see code button * removed submodules * changes * add audio models * remove fun section * remove tests in image segmentation demo repo * requested changes * add outbreak_forecast * fix broken demos * remove images and models, add new demos * remove readmes, change to run.py, add description as comment * move to /demos folder, clean up dict * add upload_to_spaces script * fix script, clean repos, and add to docker file * fix python versioning issue * env variable * fix * env fixes * spaces instead of tabs * revert to original networking.py * fix rate limiting in asr and autocomplete * change name to demos * clean up navbar * move url and description, remove code comments * add tabs to demos * remove margins and footer from embedded demo * font consistency Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
import gradio as gr
|
|
import torch
|
|
import random
|
|
import numpy as np
|
|
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
|
|
|
|
# Both the model and its preprocessor are loaded from this one checkpoint.
_CHECKPOINT = "facebook/maskformer-swin-tiny-ade"

# All inference in this demo runs on the CPU.
device = torch.device("cpu")

# Load the segmentation model, move it to `device`, and put it in
# evaluation mode because it is only ever used for inference here.
model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT)
model = model.to(device)
model.eval()

# Preprocessor (feature extractor) matching the same checkpoint; it is used
# both to prepare inputs and to post-process model outputs.
preprocessor = MaskFormerFeatureExtractor.from_pretrained(_CHECKPOINT)
|
|
|
|
def visualize_instance_seg_mask(mask):
    """Colorize a 2-D label map by giving every distinct label a random RGB colour.

    Args:
        mask: 2-D array of per-pixel labels with shape (height, width).

    Returns:
        A float64 array of shape (height, width, 3) with values in [0, 1].
        All pixels that share a label get the same colour.
    """
    mask = np.asarray(mask)
    # `inverse` maps each pixel to the index of its label in `labels`.
    # Reshape it explicitly, because NumPy 1.x returns it flattened while
    # NumPy 2.x keeps the input shape.
    labels, inverse = np.unique(mask, return_inverse=True)
    # One random colour per distinct label, in sorted-label order. All three
    # channels span 0..255; previously red was limited to 0..1 by mistake.
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=float,
    ).reshape(-1, 3)  # keep shape (0, 3) when there are no labels
    # Vectorized lookup replaces the per-pixel Python double loop.
    image = palette[inverse.reshape(mask.shape)]
    return image / 255
|
|
|
|
def query_image(img):
    """Segment `img` with MaskFormer and return a colourized label map.

    Args:
        img: Image array from the Gradio ``Image`` input. Only its first two
            dimensions (height, width) are read here, to size the output.

    Returns:
        The array produced by ``visualize_instance_seg_mask`` for the
        per-pixel label map at the input's height and width.
    """
    height, width = img.shape[0], img.shape[1]

    model_inputs = preprocessor(images=img, return_tensors="pt")
    with torch.no_grad():
        model_outputs = model(**model_inputs)

    # Move both logits tensors to the CPU before post-processing.
    model_outputs.class_queries_logits = model_outputs.class_queries_logits.cpu()
    model_outputs.masks_queries_logits = model_outputs.masks_queries_logits.cpu()

    first_result = preprocessor.post_process_segmentation(
        outputs=model_outputs, target_size=(height, width)
    )[0]
    per_class_scores = first_result.cpu().detach()

    # argmax over the first axis picks one label index per position;
    # presumably that axis holds the classes and the remaining axes are
    # (height, width). TODO confirm against the preprocessor's documentation.
    label_map = per_class_scores.argmax(dim=0).numpy()
    return visualize_instance_seg_mask(label_map)
|
|
|
|
# Gradio UI: one image in, one image out, wired to `query_image`.
# NOTE(review): the example assumes example_2.png is reachable from the
# working directory the demo is launched from.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="MaskFormer Demo",
    examples=[["example_2.png"]],
)

demo.launch()