-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
74 lines (61 loc) · 2.34 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from torchstat import stat
# import torchvision.models as models
# from model import mobile_vit_xx_small as create_model
# from model import mobile_vit_x_small as create_model
from model import mobile_vit_small as create_model
def main():
    """Run single-image inference with the model built by ``create_model``.

    Loads the hard-coded test image, applies resize / center-crop /
    normalization, loads the weights from ``./weights/mobilevit_s.pt``,
    prints a torchstat report for the network and finally prints the index
    of the highest-scoring class (out of ``num_classes=1000``).

    Raises:
        AssertionError: if the test image path does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # print(torch.cuda.is_available())

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         # per-channel mean / std (presumably the ImageNet statistics)
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = r"test_data/pic(1).JPEG"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    # Close the file handle deterministically and force three channels:
    # Normalize above is configured for exactly 3 channels, so grayscale,
    # palette or RGBA inputs would otherwise fail.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    plt.imshow(img)

    # [C, H, W]
    img = data_transform(img)
    # expand batch dimension -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)
    # move the input to the inference device once and reuse it below
    img = img.to(device)

    #
    # # read class_indict
    # json_path = './class_indices.json'
    # assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
    #
    # with open(json_path, "r") as f:
    #     class_indict = json.load(f)

    # create model
    model = create_model(num_classes=1000).to(device)

    # load model weights
    model_weight_path = r"./weights/mobilevit_s.pt"
    # model_weight_path = r"./weights/mobilevit_xs.pt"
    # model_weight_path = r"./weights/mobilevit_xxs.pt"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    print(img)

    # Profile on the CPU. Module.to() moves the parameters in place, so the
    # model must be moved back afterwards; otherwise, on a CUDA machine, the
    # weights (CPU) and the input (GPU) end up on different devices and the
    # forward pass below fails.
    print(stat(model.to("cpu"), (3, img_size, img_size)))
    model.to(device)

    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img)).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    # print(predict)
    print(predict_cla)
    # print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
    #                                            predict[predict_cla].numpy())
    # plt.title(print_res)
    # for i in range(len(predict)):
    #     print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
    #                                             predict[i].numpy()))
    # plt.show()
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()