mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
completed pytorch image classification guide
This commit is contained in:
parent
08ea95479a
commit
a70fe9aeff
@ -1,104 +1,91 @@
|
||||
# Image Classification in Pytorch
|
||||
# Image Classification in PyTorch
|
||||
|
||||
related_spaces: https://huggingface.co/spaces/nateraw/quickdraw
|
||||
related_spaces: abidlabs/pytorch-image-classifier
|
||||
tags: VISION, RESNET, PYTORCH
|
||||
|
||||
## Introduction
|
||||
|
||||
Image classification is a central task in computer vision. And building better classifiers to classify what object is present in a picture is an active area of research, as it has applications stretching from autonomous vehicles to medical imaging.
|
||||
Image classification is a central task in computer vision. Building better classifiers to classify what object is present in a picture is an active area of research, as it has applications stretching from autonomous vehicles to medical imaging.
|
||||
|
||||
Such models are perfect to use with Gradio's *image* input component, so in this tutorial we will build a web demo to classify images using Gradio. We will be able to build the whole web application in Python, and will look like this (try one of the examples!):
|
||||
|
||||
<iframe src="https://hf.space/gradioiframe/abidlabs/pytorch-image-classifier/+" frameBorder="0" height="450" title="Gradio app" class="container p-0 flex-grow space-iframe" allow="accelerometer; ambient-light-sensor; autoplay; battery; camera; document-domain; encrypted-media; fullscreen; geolocation; gyroscope; layout-animations; legacy-image-formats; magnetometer; microphone; midi; oversized-images; payment; picture-in-picture; publickey-credentials-get; sync-xhr; usb; vr ; wake-lock; xr-spatial-tracking" sandbox="allow-forms allow-modals allow-popups allow-popups-to-escape-sandbox allow-same-origin allow-scripts allow-downloads"></iframe>
|
||||
|
||||
|
||||
Let's get started!
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Make sure you have the `gradio` Python package already [installed](/getting_started). We will be using a pretrained image classification model, so also install `torch`.
|
||||
Make sure you have the `gradio` Python package already [installed](/getting_started). We will be using a pretrained image classification model, so you should also have `torch` installed.
|
||||
|
||||
## Step 1 — Setting up the Image Classification Model
|
||||
|
||||
First, you will need a sketch recognition model. Since many researchers have already trained their own models on the Quick Draw dataset, we will use a pretrained model in this tutorial. Our model is a light 1.5 MB model trained by Nate Raw, that [you can download here](https://huggingface.co/spaces/nateraw/quickdraw/blob/main/pytorch_model.bin).
|
||||
|
||||
If you are interested, here [is the code](https://github.com/nateraw/quickdraw-pytorch) that was used to train the model. We will simply load the pretrained model in PyTorch, as follows:
|
||||
First, you will need a image classification model. For this tutorial, we will use a pretrained Resnet-18 model, as it is easily downloadable from [PyTorch Hub](https://pytorch.org/hub/pytorch_vision_resnet/). You can use a different pretrained model or train your own.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
model = nn.Sequential(
|
||||
nn.Conv2d(1, 32, 3, padding='same'),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Conv2d(32, 64, 3, padding='same'),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Conv2d(64, 128, 3, padding='same'),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Flatten(),
|
||||
nn.Linear(1152, 256),
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, len(LABELS)),
|
||||
)
|
||||
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model.eval()
|
||||
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
|
||||
```
|
||||
|
||||
Because we will be using the model for inference, we have called the `.eval()` method.
|
||||
|
||||
## Step 2 — Defining a `predict` function
|
||||
|
||||
Next, you will need to define a function that takes in the *user input*, which in this case is a sketched image, and returns the prediction. The prediction should be returned as a dictionary whose keys are class name and values are confidence probabilities. We will load the class names from this [text file](https://huggingface.co/spaces/nateraw/quickdraw/blob/main/class_names.txt).
|
||||
Next, you will need to define a function that takes in the *user input*, which in this case is an image, and returns the prediction. The prediction should be returned as a dictionary whose keys are class name and values are confidence probabilities. We will load the class names from this [text file](https://git.io/JJkYN).
|
||||
|
||||
In the case of our pretrained model, it will look like this:
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
import requests
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
LABELS = Path('class_names.txt').read_text().splitlines()
|
||||
# Download human-readable labels for ImageNet.
|
||||
response = requests.get("https://git.io/JJkYN")
|
||||
labels = response.text.split("\n")
|
||||
|
||||
def predict(img):
|
||||
x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
|
||||
with torch.no_grad():
|
||||
out = model(x)
|
||||
probabilities = torch.nn.functional.softmax(out[0], dim=0)
|
||||
values, indices = torch.topk(probabilities, 5)
|
||||
confidences = {LABELS[i]: v.item() for i, v in zip(indices, values)}
|
||||
return confidences
|
||||
def predict(inp):
|
||||
inp = Image.fromarray(inp.astype('uint8'), 'RGB')
|
||||
inp = transforms.ToTensor()(inp).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
|
||||
confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||
return confidences
|
||||
```
|
||||
|
||||
Let's break this down. The function takes one parameters:
|
||||
|
||||
* `img`: the input image as a `numpy` array
|
||||
* `inp`: the input image as a `numpy` array
|
||||
|
||||
Then, the function converts the image to a PyTorch `tensor`, passes it through the model, and returns:
|
||||
Then, the function converts the image to a PIL Image and then eventually a PyTorch `tensor`, passes it through the model, and returns:
|
||||
|
||||
* `confidences`: the top five predictions, as a dictionary whose keys are class labels and whose values are confidence probabilities
|
||||
* `confidences`: the predictions, as a dictionary whose keys are class labels and whose values are confidence probabilities
|
||||
|
||||
## Step 3 — Creating a Gradio Interface
|
||||
|
||||
Now that we have our predictive function set up, we can create a Gradio Interface around it.
|
||||
|
||||
In this case, the input component is a sketchpad. To create a sketchpad input, we can use the convenient string shortcut, `"sketchpad"` which creates a canvas for a user to draw on and handles the preprocessing to convert that to a numpy array.
|
||||
In this case, the input component is a drag-and-drop image component. To create this input, we can use the convenient string shortcut, `"image"` which creates the component and handles the preprocessing to convert that to a numpy array.
|
||||
|
||||
The output component will be a `"label"`, which displays the top labels in a nice form.
|
||||
The output component will be a `"label"`, which displays the top labels in a nice form. Since we don't want to show all 1,000 class labels, we will customize it to show only the top 3 images.
|
||||
|
||||
Finally, we'll add one more parameter, setting `live=True`, which allows our interface to run in real time, adjusting its predictions every time a user draws on the sketchpad. The code for Gradio looks like this:
|
||||
Finally, we'll add one more parameter, the `examples`, which allows us to prepopulate our interfaces with a few predefined examples. The code for Gradio looks like this:
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
|
||||
gr.Interface(fn=predict,
|
||||
inputs="sketchpad",
|
||||
outputs="label",
|
||||
live=True).launch()
|
||||
inputs="image",
|
||||
outputs=gr.outputs.Label(num_top_classes=3),
|
||||
examples=["lion.jpg", "cheetah.jpg"]).launch()
|
||||
```
|
||||
|
||||
This produces the following interface, which you can try right here in your browser (try drawing something, like a "snake" or a "laptop"):
|
||||
This produces the following interface, which you can try right here in your browser (try uploading your own examples!):
|
||||
|
||||
<iframe src="https://hf.space/gradioiframe/abidlabs/draw2/+" frameBorder="0" height="450" title="Gradio app" class="container p-0 flex-grow space-iframe" allow="accelerometer; ambient-light-sensor; autoplay; battery; camera; document-domain; encrypted-media; fullscreen; geolocation; gyroscope; layout-animations; legacy-image-formats; magnetometer; microphone; midi; oversized-images; payment; picture-in-picture; publickey-credentials-get; sync-xhr; usb; vr ; wake-lock; xr-spatial-tracking" sandbox="allow-forms allow-modals allow-popups allow-popups-to-escape-sandbox allow-same-origin allow-scripts allow-downloads"></iframe>
|
||||
<iframe src="https://hf.space/gradioiframe/abidlabs/pytorch-image-classifier/+" frameBorder="0" height="450" title="Gradio app" class="container p-0 flex-grow space-iframe" allow="accelerometer; ambient-light-sensor; autoplay; battery; camera; document-domain; encrypted-media; fullscreen; geolocation; gyroscope; layout-animations; legacy-image-formats; magnetometer; microphone; midi; oversized-images; payment; picture-in-picture; publickey-credentials-get; sync-xhr; usb; vr ; wake-lock; xr-spatial-tracking" sandbox="allow-forms allow-modals allow-popups allow-popups-to-escape-sandbox allow-same-origin allow-scripts allow-downloads"></iframe>
|
||||
|
||||
----------
|
||||
|
||||
And you're done! That's all the code you need to build a Pictionary-style guessing app. Have fun and try to find some edge cases 🧐
|
||||
And you're done! That's all the code you need to build a web demo for an image classifier. If you'd like to share with others, try setting `share=True` when you `launch()` the Interface!
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user