feat: test torchvision's resnet model
This commit is contained in:
parent
d64cffa63f
commit
9c09c2bb3e
44
src/test_backbone.py
Normal file
44
src/test_backbone.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import torchvision.models as models
|
||||||
|
import time
|
||||||
|
from torchinfo import summary
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
device = torch.device('cuda')
|
||||||
|
|
||||||
|
# 創建 resnet50 模型
|
||||||
|
model = models.resnet50(weights=torchvision.models.ResNet50_Weights).to(device)
|
||||||
|
model = torch.nn.Sequential(*list(model.children())[:-3]) # 去除 FC & global pooling && 一層
|
||||||
|
|
||||||
|
# 定義 transform
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
# print(name, param.requires_grad)
|
||||||
|
# 將模型設置為 evaluation 模式
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# summary(model, input_size=(32, 3, 224, 224))
|
||||||
|
# time.sleep(2)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
image = Image.open('../data/textcaps/train_images/{}'.format('0a9d5e54fcf25702.jpg'))
|
||||||
|
image = transform(image).to(device).unsqueeze(0)
|
||||||
|
|
||||||
|
output = model(image)
|
||||||
|
output = output[0].to('cpu')
|
||||||
|
for index, feature in enumerate(output):
|
||||||
|
plt.imshow(feature)
|
||||||
|
plt.savefig("result/tmp/{}.jpg".format(index))
|
||||||
|
print(output.shape)
|
||||||
|
# print(model)
|
||||||
Loading…
Reference in New Issue
Block a user