feat: test torchvision's resnet model

This commit is contained in:
snsd0805 2023-04-13 22:55:51 +08:00
parent d64cffa63f
commit 9c09c2bb3e
Signed by: snsd0805
GPG Key ID: 569349933C77A854

44
src/test_backbone.py Normal file
View File

@ -0,0 +1,44 @@
import torch
import torchvision
import torchvision.models as models
import time
from torchinfo import summary
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
device = torch.device('cuda')
# 創建 resnet50 模型
model = models.resnet50(weights=torchvision.models.ResNet50_Weights).to(device)
model = torch.nn.Sequential(*list(model.children())[:-3]) # 去除 FC & global pooling && 一層
# 定義 transform
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
for name, param in model.named_parameters():
param.requires_grad = False
# print(name, param.requires_grad)
# 將模型設置為 evaluation 模式
model.eval()
# summary(model, input_size=(32, 3, 224, 224))
# time.sleep(2)
with torch.no_grad():
image = Image.open('../data/textcaps/train_images/{}'.format('0a9d5e54fcf25702.jpg'))
image = transform(image).to(device).unsqueeze(0)
output = model(image)
output = output[0].to('cpu')
for index, feature in enumerate(output):
plt.imshow(feature)
plt.savefig("result/tmp/{}.jpg".format(index))
print(output.shape)
# print(model)