2
0
mirror of https://github.com/gradio-app/gradio.git synced 2024-12-21 02:19:59 +08:00
gradio/demo/image_segmentation/run.py
Ali Abdalla 597337dcb8
Adding a Playground Tab to the Website ()
* added playground with 12 demos

* change name to recipes, restyle navbar

* add explanatory text to page

* fix demo mapping

* categorize demos, clean up design

* styling

* category naming and emojis

* refactor and add text demos

* add view code button

* remove opening slash in embed

* styling

* add image demos

* adding plot demos

* remove see code button

* removed submodules

* changes

* add audio models

* remove fun section

* remove tests in image segmentation demo repo

* requested changes

* add outbreak_forecast

* fix broken demos

* remove images and models, add new demos

* remove readmes, change to run.py, add description as comment

* move to /demos folder, clean up dict

* add upload_to_spaces script

* fix script, clean repos, and add to docker file

* fix python versioning issue

* env variable

* fix

* env fixes

* spaces instead of tabs

* revert to original networking.py

* fix rate limiting in asr and autocomplete

* change name to demos

* clean up navbar

* move url and description, remove code comments

* add tabs to demos

* remove margins and footer from embedded demo

* font consistency

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
2022-09-15 08:24:10 -07:00

42 lines
1.5 KiB
Python

import gradio as gr
import torch
import random
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
# Everything runs on the CPU. The feature extractor and the network are both
# loaded once, at import time, from the same pretrained checkpoint.
device = torch.device("cpu")
preprocessor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-tiny-ade")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-tiny-ade")
model = model.to(device)
# Switch to evaluation mode for inference.
model.eval()
def visualize_instance_seg_mask(mask):
    """Render a 2-D label mask as an RGB image with one random color per label.

    Args:
        mask: 2-D NumPy array of labels; equal values get the same color.

    Returns:
        A float array of shape ``mask.shape + (3,)`` with channel values
        scaled into the range [0, 1].
    """
    mask = np.asarray(mask)
    # `inverse` holds, for every pixel, the index of its label in `labels`.
    labels, inverse = np.unique(mask, return_inverse=True)
    # Draw every channel from the full 0-255 range. The red channel used to be
    # drawn from {0, 1}, which left the output with almost no red component.
    palette = np.array(
        [
            (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            for _ in labels
        ],
        dtype=float,
    ).reshape(-1, 3)
    palette = palette / 255
    # The shape of `inverse` differs between NumPy versions; normalize it so the
    # palette lookup below always yields an (H, W, 3) image.
    return palette[inverse.reshape(mask.shape)]
def query_image(img):
    """Run the module-level MaskFormer model on ``img`` and return a colorized mask.

    ``img`` is indexed as an array whose first two dimensions give the target
    output size. The post-processed model output is reduced with an argmax
    over dim 0 and passed to ``visualize_instance_seg_mask`` for rendering.
    """
    height, width = img.shape[0], img.shape[1]
    encoded = preprocessor(images=img, return_tensors="pt")

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        outputs = model(**encoded)

    # Make sure both logit tensors live on the CPU before post-processing.
    outputs.class_queries_logits = outputs.class_queries_logits.cpu()
    outputs.masks_queries_logits = outputs.masks_queries_logits.cpu()

    processed = preprocessor.post_process_segmentation(
        outputs=outputs, target_size=(height, width)
    )
    first_result = processed[0].cpu().detach()
    label_map = torch.argmax(first_result, dim=0).numpy()
    return visualize_instance_seg_mask(label_map)
# Wire the segmentation function into a simple image-in / image-out web UI
# with one bundled example, then start serving it.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="MaskFormer Demo",
    examples=[["example_2.png"]],
)
demo.launch()