gradio/demo/Echocardiogram-Segmentation/run.py
Aarni Koskela ef3862e075
Switch linting to Ruff (#3710)
* Sort requirements.in

* Switch flake8 + isort to ruff

* Apply ruff import order fixes

* Fix ruff complaints in demo/

* Fix ruff complaints in test/

* Use `x is not y`, not `not x is y`

* Remove unused listdir from website generator

* Clean up duplicate dict keys

* Add changelog entry

* Clean up unused imports (except in gradio/__init__.py)

* add space

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
2023-04-03 15:48:18 -07:00

89 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import torch
import torchvision
import wget
# Local folders used by this demo; weights are cached under `destination_for_weights`
# so the checkpoint is only downloaded on the first run.
destination_folder = "output"
destination_for_weights = "weights"

if not os.path.exists(destination_for_weights):
    print("Creating folder at ", destination_for_weights, " to store weights")
    os.mkdir(destination_for_weights)
else:
    print("The weights are at", destination_for_weights)

# EchoNet-Dynamic segmentation checkpoint, fetched from its GitHub release.
segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'
_weights_file = os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))

if os.path.exists(_weights_file):
    print("Segmentation Weights already present")
else:
    print("Downloading Segmentation Weights, ", segmentationWeightsURL, " to ", _weights_file)
    filename = wget.download(segmentationWeightsURL, out=destination_for_weights)

# Ask PyTorch's CUDA caching allocator to release unused cached memory
# before the model is built.
torch.cuda.empty_cache()
def collate_fn(x):
    """Merge a batch of ``(array, name)`` pairs into a single tensor.

    The arrays are concatenated along axis 1, then axes 0 and 1 are swapped so
    the concatenated axis comes first.

    Args:
        x: iterable of ``(array, name)`` pairs. Every array must agree on all
            axes except axis 1 (presumably ``(channels, frames, H, W)`` clips;
            confirm against the caller).

    Returns:
        ``(tensor, names, lengths)``: the merged tensor, a tuple of the names
        in input order, and a list holding each array's size along axis 1 so
        the batch can be split back apart.
    """
    clips, names = zip(*x)
    lengths = [clip.shape[1] for clip in clips]
    joined = np.concatenate(clips, 1)
    merged = torch.as_tensor(np.swapaxes(joined, 0, 1))
    return merged, names, lengths
# DeepLabV3/ResNet-50 segmentation network with its last classifier layer
# replaced by a single-output-channel conv, so it predicts one score map per image.
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)

# Path of the checkpoint file downloaded earlier in this script.
_checkpoint_path = os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))
# Log the file that is actually loaded (the old message omitted the ".pt" extension).
print("loading weights from ", _checkpoint_path)

# NOTE(review): torch.load unpickles the file; this is only acceptable because the
# checkpoint comes from a fixed, trusted release URL.
_dp_prefix = "module."  # key prefix added by torch.nn.DataParallel
if torch.cuda.is_available():
    print("cuda is available, original weights")
    device = torch.device("cuda")
    # The checkpoint's keys appear to carry DataParallel's "module." prefix (the CPU
    # branch strips it), so wrap the model the same way before loading.
    model = torch.nn.DataParallel(model)
    model.to(device)
    checkpoint = torch.load(_checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
else:
    print("cuda is not available, cpu weights")
    device = torch.device("cpu")
    checkpoint = torch.load(_checkpoint_path, map_location="cpu")
    # Drop the DataParallel prefix so the keys match the unwrapped model; keys
    # without the prefix are kept as-is instead of being truncated blindly.
    state_dict_cpu = {
        (k[len(_dp_prefix):] if k.startswith(_dp_prefix) else k): v
        for k, v in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict_cpu)
# Inference only: put the model in evaluation mode.
model.eval()
def segment(input):
    """Overlay the model's left-ventricle prediction on an image.

    Args:
        input: image array as delivered by the Gradio ``Image`` component,
            indexed ``(height, width, channel)`` with exactly 3 channels (the
            transpose and ``reshape(1, 3, 1, 1)`` below require this; assumed RGB).

    Returns:
        A copy of ``input`` in which every pixel whose model score is above 0
        is painted ``[0, 0, 255]`` (blue for RGB input).
    """
    inp = input
    x = inp.transpose([2, 0, 1])  # channels-first
    x = np.expand_dims(x, axis=0)  # adding a batch dimension

    # Per-image, per-channel standardisation.
    mean = x.mean(axis=(0, 2, 3))
    std = x.std(axis=(0, 2, 3))
    # A constant channel has zero std; dividing by it would fill the input with
    # NaNs, so such channels are only mean-centred.
    std = np.where(std == 0, 1.0, std)
    x = (x - mean.reshape(1, 3, 1, 1)) / std.reshape(1, 3, 1, 1)

    with torch.no_grad():
        x = torch.from_numpy(x).float().to(device)
        output = model(x)
    # Batch size is 1 and the classifier has 1 output channel, so [0, 0] is the
    # (H, W) score map. Move it to host memory first: Tensor.numpy() raises on
    # CUDA tensors, which is the path taken when a GPU is available.
    y = output['out'][0, 0].cpu().numpy()

    # Scores are presumably logits, so > 0 corresponds to sigmoid probability > 0.5.
    out = y > 0
    mask = inp.copy()
    mask[out] = np.array([0, 0, 255])
    return mask
import gradio as gr

# shape=(112, 112) asks Gradio to resize the upload to 112x112 before it reaches
# `segment` (presumably the resolution the model was trained at; confirm).
i = gr.Image(shape=(112, 112), label="Echocardiogram")
o = gr.Image(label="Segmentation Mask")
examples = [["img1.jpg"], ["img2.jpg"]]
title = None  # "Left Ventricle Segmentation"
description = "This semantic segmentation model identifies the left ventricle in echocardiogram images."
# Longer description text, kept for reference (not shown in the UI):
# ... videos. Accurate evaluation of the motion and size of the left ventricle is
# crucial for the assessment of cardiac function and ejection fraction. In this
# interface, the user inputs apical-4-chamber images from echocardiography videos
# and the model will output a prediction of the localization of the left ventricle
# in blue. This model was trained on the publicly released EchoNet-Dynamic dataset
# of 10k echocardiogram videos with 20k expert annotations of the left ventricle and
# published as part of Video-based AI for beat-to-beat assessment of cardiac function
# by Ouyang et al. in Nature, 2020.
thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"

demo = gr.Interface(
    segment,
    i,
    o,
    title=title,
    description=description,  # previously defined but never passed to the UI
    examples=examples,
    allow_flagging="never",  # string form; the boolean False is deprecated
    analytics_enabled=False,
    thumbnail=thumbnail,
    cache_examples=False,
)
demo.launch()