Semantic Segmentation
Sample inference script for a TorchScript-exported semantic segmentation model
sample_sem_seg.py
1
import torch
2
import numpy as np
3
from PIL import Image
4
import torchvision
5
import json
6
import matplotlib.pyplot as plt
7
import cv2
8
9
def load_class_mapping(path='class_mapping.json'):
    """Load the class mapping exported alongside the model.

    Args:
        path: Path to the ``class_mapping.json`` file.

    Returns:
        A dict mapping each entry's ``model_idx`` to its ``class_name``.
    """
    # JSON is UTF-8 by specification; be explicit so non-ASCII class names
    # load correctly regardless of the platform's default encoding.
    with open(path, encoding='utf-8') as data:
        mappings = json.load(data)
    return {item['model_idx']: item['class_name'] for item in mappings}


def predict_mask(model, image, device):
    """Predict a per-pixel class-index mask for a single image.

    Args:
        model: The loaded TorchScript model (any callable that maps a
            (1, C, H, W) float tensor to per-class scores works).
        image: An (H, W, C) numpy array, already transformed as required
            by the model's transforms.json.
        device: The torch device to run inference on.

    Returns:
        An (H, W) numpy array in host memory holding the argmax class index
        of every pixel, with the same height and width as ``image``.
    """
    h, w = image.shape[:2]
    x = torch.from_numpy(image).to(device)
    with torch.no_grad():
        # Convert to channels first, add a batch dimension, convert to float.
        x = x.permute(2, 0, 1).unsqueeze(dim=0).float()
        # NOTE(review): assumes the model returns (1, num_classes, H', W')
        # scores, as the argmax over dim=1 below implies.
        y = model(x)
        # If the model's output resolution differs from the input, upsample
        # the scores so the mask lines up with the displayed image.
        if y.shape[-2:] != (h, w):
            y = torch.nn.functional.interpolate(
                y, size=(h, w), mode='bilinear', align_corners=False
            )
        # Index the single batch item instead of squeeze(), which would also
        # drop a height or width of size 1.
        mask = torch.argmax(y, dim=1)[0]
    # matplotlib cannot read CUDA tensors; return a host-side numpy array.
    return mask.cpu().numpy()


def main():
    """Run the exported model on one image and overlay the predicted mask."""
    class_mapping = load_class_mapping('class_mapping.json')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a model exported on a GPU load on a CPU-only machine.
    model = torch.jit.load('model.pt', map_location=device)
    # Inference mode: disables training-only behaviour such as dropout.
    model.eval()

    image_path = '/path/to/your/image'
    # Force 3 channels so RGBA, palette or grayscale files don't break the
    # channels-first permute. NOTE(review): assumes the model expects RGB.
    image = Image.open(image_path).convert('RGB')
    # Transform your image according to the transforms.json as in
    # https://help.hasty.ai/model-playground/image-transformations
    image = np.array(image)

    mask = predict_mask(model, image, device)

    # Report the class names of every label that appears in the mask.
    present = [class_mapping.get(int(idx), str(idx)) for idx in np.unique(mask)]
    print('Classes present:', ', '.join(present))

    # Overlay predicted mask on image and display
    plt.imshow(image)
    plt.imshow(mask, alpha=0.5)
    plt.show()


if __name__ == '__main__':
    main()
Copied!
The script above should produce output like this:
Example output from the semantic segmentation inference script; yellow highlights the class present in the image.
Copy link