feat: test torchvision's resnet model
This commit is contained in:
parent
d64cffa63f
commit
9c09c2bb3e
44
src/test_backbone.py
Normal file
44
src/test_backbone.py
Normal file
@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.models as models
|
||||
import time
|
||||
from torchinfo import summary
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
device = torch.device('cuda')
|
||||
|
||||
# 創建 resnet50 模型
|
||||
model = models.resnet50(weights=torchvision.models.ResNet50_Weights).to(device)
|
||||
model = torch.nn.Sequential(*list(model.children())[:-3]) # 去除 FC & global pooling && 一層
|
||||
|
||||
# 定義 transform
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
param.requires_grad = False
|
||||
# print(name, param.requires_grad)
|
||||
# 將模型設置為 evaluation 模式
|
||||
model.eval()
|
||||
|
||||
# summary(model, input_size=(32, 3, 224, 224))
|
||||
# time.sleep(2)
|
||||
|
||||
with torch.no_grad():
|
||||
image = Image.open('../data/textcaps/train_images/{}'.format('0a9d5e54fcf25702.jpg'))
|
||||
image = transform(image).to(device).unsqueeze(0)
|
||||
|
||||
output = model(image)
|
||||
output = output[0].to('cpu')
|
||||
for index, feature in enumerate(output):
|
||||
plt.imshow(feature)
|
||||
plt.savefig("result/tmp/{}.jpg".format(index))
|
||||
print(output.shape)
|
||||
# print(model)
|
||||
Loading…
Reference in New Issue
Block a user