Instance Segmentation
Sample inference script for a TorchScript-exported instance segmentation model
Mask R-CNN
FBNetV3
sample_instance_segmentation.py
1
import torch
2
import numpy as np
3
from PIL import Image
4
import torchvision
5
import json
6
import matplotlib.pyplot as plt
7
import cv2
8
9
# Load the mapping from model output index to human-readable class name.
# JSON files are UTF-8 by specification; being explicit avoids mis-decoding
# non-ASCII class names on platforms whose default locale encoding differs.
with open('class_mapping.json', encoding='utf-8') as data:
    mappings = json.load(data)

class_mapping = {item['model_idx']: item['class_name'] for item in mappings}
13
14
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps the saved storages onto `device`, so a model exported on
# a GPU machine can still be loaded on a CPU-only machine.
model = torch.jit.load('model.pt', map_location=device)
17
18
image_path = '/path/to/your/image'
# Open the file in a context manager so the handle is closed as soon as the
# pixels have been copied into the NumPy array.
with Image.open(image_path) as pil_image:
    # Transform your image according to the transforms.json as in
    # https://help.hasty.ai/model-playground/image-transformations
    # NOTE(review): the permute(2, 0, 1) below needs an H x W x C array;
    # grayscale/palette images may need pil_image.convert('RGB') first --
    # confirm the channel layout the exported model expects.
    image = np.array(pil_image)
h, w = image.shape[:2]
# Convert to torch tensor
x = torch.from_numpy(image).to(device)
26
with torch.no_grad():
    # Convert to channels first, convert to float datatype
    x = x.permute(2, 0, 1).float()
    y = model(x)
    # Some optional postprocessing: class-agnostic NMS. You can change the
    # 0.5 IoU overlap threshold as needed.
    to_keep = torchvision.ops.nms(y['pred_boxes'], y['scores'], 0.5)
    # Filter every per-detection field, scores included, so all of them stay
    # index-aligned after suppression.
    for key in ('pred_boxes', 'pred_classes', 'pred_masks', 'scores'):
        y[key] = y[key][to_keep]
36
37
# Draw your box predictions:
# int32 label canvas: an int8 canvas overflows once there are >127 instances.
all_masks = np.zeros((h, w), dtype=np.int32)
# Move predictions to the CPU once -- NumPy indexing cannot consume CUDA
# tensors -- and keep each mask paired with its own box and class.
detections = list(zip(y['pred_masks'].cpu(),
                      y['pred_boxes'].cpu(),
                      y['pred_classes'].cpu()))
# Iterate in reverse: NMS returns indices sorted by decreasing score, so the
# highest-scoring mask is painted last and stays on top where masks overlap.
# Reversing the zipped triples (not just the masks) keeps them aligned.
for instance_idx, (mask, bbox, label) in enumerate(reversed(detections),
                                                   start=1):
    x1, y1, x2, y2 = map(int, bbox)
    class_name = class_mapping[label.item()]
    cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 4)
    cv2.putText(
        image,
        class_name,
        (x1, y1),
        cv2.FONT_HERSHEY_SIMPLEX,
        4,
        (255, 0, 0)
    )
    # Assumes each mask is already (h, w) at full image resolution, as the
    # original indexing requires -- TODO confirm against the exported model.
    all_masks[mask.numpy() == 1] = instance_idx
# Display predicted masks, boxes and classes on your image
plt.imshow(image)
plt.imshow(all_masks, alpha=0.5)
plt.show()
62
Copied!
sample_instance_segmentation.py
1
import torch
2
import numpy as np
3
from PIL import Image
4
import torchvision
5
import json
6
import matplotlib.pyplot as plt
7
import cv2
8
9
# Load the mapping from model output index to human-readable class name.
# JSON files are UTF-8 by specification; being explicit avoids mis-decoding
# non-ASCII class names on platforms whose default locale encoding differs.
with open('class_mapping.json', encoding='utf-8') as data:
    mappings = json.load(data)

class_mapping = {item['model_idx']: item['class_name'] for item in mappings}
13
14
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps the saved storages onto `device`, so a model exported on
# a GPU machine can still be loaded on a CPU-only machine.
model = torch.jit.load('model.pt', map_location=device)
17
18
image_path = '/path/to/your/image'
# Open the file in a context manager so the handle is closed as soon as the
# pixels have been copied into the NumPy array.
with Image.open(image_path) as pil_image:
    # Transform your image according to the transforms.json as in
    # https://help.hasty.ai/model-playground/image-transformations
    # NOTE(review): the permute(2, 0, 1) below needs an H x W x C array;
    # grayscale/palette images may need pil_image.convert('RGB') first --
    # confirm the channel layout the exported model expects.
    image = np.array(pil_image)
h, w = image.shape[:2]
# Convert to torch tensor
x = torch.from_numpy(image).to(device)
26
with torch.no_grad():
    # Convert to channels first, convert to float datatype
    x = x.permute(2, 0, 1).float()
    pred_boxes, pred_classes, pred_masks, scores, _ = model(x)
    # Some optional postprocessing: class-agnostic NMS. You can change the
    # 0.5 IoU overlap threshold as needed.
    to_keep = torchvision.ops.nms(pred_boxes, scores, 0.5)
    # Filter every per-detection output, scores included, so all of them
    # stay index-aligned after suppression.
    pred_boxes = pred_boxes[to_keep]
    pred_classes = pred_classes[to_keep]
    pred_masks = pred_masks[to_keep]
    scores = scores[to_keep]
36
37
# Draw your box predictions:
# int32 label canvas: an int8 canvas overflows once there are >127 instances.
all_masks = np.zeros((h, w), dtype=np.int32)
# Move predictions to the CPU once -- Tensor.numpy() raises on CUDA tensors --
# and keep each mask paired with its own box and class.
detections = list(zip(pred_masks.cpu(), pred_boxes.cpu(), pred_classes.cpu()))
# Iterate in reverse: NMS returns indices sorted by decreasing score, so the
# highest-scoring mask is painted last and stays on top where masks overlap.
# Reversing the zipped triples (not just the masks) keeps them aligned.
for instance_idx, (mask, bbox, label) in enumerate(reversed(detections),
                                                   start=1):
    x1, y1, x2, y2 = map(int, bbox)
    class_name = class_mapping[label.item()]
    cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 4)
    cv2.putText(
        image,
        class_name,
        (x1, y1),
        cv2.FONT_HERSHEY_SIMPLEX,
        4,
        (255, 0, 0)
    )
    # NOTE(review): stretching the mask over the whole image is only correct
    # if the model returns image-aligned masks. Box-relative Mask R-CNN masks
    # (e.g. 28x28) would instead need resizing to (x2 - x1, y2 - y1) and
    # pasting at (x1, y1) -- confirm against the exported model.
    mask = cv2.resize(mask.squeeze().numpy(), dsize=(w, h),
                      interpolation=cv2.INTER_LINEAR)
    all_masks[mask > 0.5] = instance_idx
# Display predicted masks, boxes and classes on your image
plt.imshow(image)
plt.imshow(all_masks, alpha=0.5)
plt.show()
64
Copied!
Example output from the instance segmentation sample inference script
Copy link