diff --git a/src/test_backbone.py b/src/test_backbone.py new file mode 100644 index 0000000..17eb3aa --- /dev/null +++ b/src/test_backbone.py @@ -0,0 +1,44 @@ +import torch +import torchvision +import torchvision.models as models +import time +from torchinfo import summary +import torchvision.transforms as transforms +from PIL import Image +import matplotlib.pyplot as plt + +device = torch.device('cuda') + +# 創建 resnet50 模型 +model = models.resnet50(weights=torchvision.models.ResNet50_Weights).to(device) +model = torch.nn.Sequential(*list(model.children())[:-3]) # 去除 FC & global pooling && 一層 + +# 定義 transform +transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) +]) + + +for name, param in model.named_parameters(): + param.requires_grad = False + # print(name, param.requires_grad) +# 將模型設置為 evaluation 模式 +model.eval() + +# summary(model, input_size=(32, 3, 224, 224)) +# time.sleep(2) + +with torch.no_grad(): + image = Image.open('../data/textcaps/train_images/{}'.format('0a9d5e54fcf25702.jpg')) + image = transform(image).to(device).unsqueeze(0) + + output = model(image) + output = output[0].to('cpu') + for index, feature in enumerate(output): + plt.imshow(feature) + plt.savefig("result/tmp/{}.jpg".format(index)) + print(output.shape) +# print(model)