Compare commits

...

2 Commits

Author SHA1 Message Date
6257759e64
feat: faster rcnn generate bouding box's json 2023-04-13 22:56:24 +08:00
9c09c2bb3e
feat: test torchvision's resnet model 2023-04-13 22:55:51 +08:00
2 changed files with 129 additions and 0 deletions

44
src/test_backbone.py Normal file
View File

@ -0,0 +1,44 @@
import torch
import torchvision
import torchvision.models as models
import time
from torchinfo import summary
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
device = torch.device('cuda')
# 創建 resnet50 模型
model = models.resnet50(weights=torchvision.models.ResNet50_Weights).to(device)
model = torch.nn.Sequential(*list(model.children())[:-3]) # 去除 FC & global pooling && 一層
# 定義 transform
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
for name, param in model.named_parameters():
param.requires_grad = False
# print(name, param.requires_grad)
# 將模型設置為 evaluation 模式
model.eval()
# summary(model, input_size=(32, 3, 224, 224))
# time.sleep(2)
with torch.no_grad():
image = Image.open('../data/textcaps/train_images/{}'.format('0a9d5e54fcf25702.jpg'))
image = transform(image).to(device).unsqueeze(0)
output = model(image)
output = output[0].to('cpu')
for index, feature in enumerate(output):
plt.imshow(feature)
plt.savefig("result/tmp/{}.jpg".format(index))
print(output.shape)
# print(model)

View File

@ -0,0 +1,85 @@
import torchvision
import torch
import torchvision.transforms as transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
import os
import json
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
device = torch.device('cuda')
classes = {1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange', 56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed', 67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse', 75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'}
# classes is from here
#
#
# with open('../data/coco/annotations/instances_train2017.json') as fp:
# data = json.load(fp)
# for i in data['categories']:
# classes[i['id']] = i['name']
# print(classes)
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1).to(device)
model.eval()
# 定義轉換
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor()
])
images = []
paths = []
count = 0
l = len(os.listdir('../data/textcaps/train_images'))
for filename in os.listdir("../data/textcaps/train_images"):
if len(images) < 32:
image = Image.open('../data/textcaps/train_images/{}'.format(filename))
image = transform(image).to(device)
c, h, w = image.shape
if c == 3:
images.append(image)
paths.append(filename)
else:
with torch.no_grad():
outputs = model(images)
for fileindex, output in enumerate(outputs):
# fig, ax = plt.subplots(1)
# image = images[fileindex].to('cpu')
# ax.imshow(image.permute(1, 2, 0))
datas = {}
datas['filename'] = paths[fileindex]
datas['boxes'] = []
box_counter = 0
for index in range(len(output['boxes'])):
box = output['boxes'][index].tolist()
score = output['scores'][index].item()
label = output['labels'][index].item()
if score > 0.3: # 只畫出信心度大於 0.5 的 bounding box
# rect = patches.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], linewidth=2, edgecolor='r', facecolor='none')
# ax.add_patch(rect)
# ax.text(box[0], box[1]-5, classes[label] + f' {score:.2f}', color='r', fontsize=12)
box_counter += 1
datas['boxes'].append({'box': box, 'label': classes[label], 'score': score})
# print(datas)
# if box_counter < 10:
# plt.savefig("result/{}".format(paths[fileindex]))
with open('result/object_detection/{}.json'.format(paths[fileindex]), 'w') as fp:
json.dump(datas, fp)
images = []
paths = []
print("batch finished.")
count += 1
print("{}/{}".format(count, l))