commit c23893bfaba9c557de11373dbfd924d44f98ce4c Author: snsd0805 Date: Wed Apr 12 21:22:43 2023 +0800 feat: use object detection to plot boxes on pictures diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a3722a7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +data/* +src/enviroment + diff --git a/src/test.py b/src/test.py new file mode 100644 index 0000000..21428b2 --- /dev/null +++ b/src/test.py @@ -0,0 +1,72 @@ +import torchvision +import torch +import torchvision.transforms as transforms +from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights +import os +import json +from PIL import Image +import matplotlib.pyplot as plt +import matplotlib.patches as patches + +device = torch.device('cuda') + +classes = {1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange', 56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed', 67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse', 75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'} + +# classes is from here +# +# +# with open('../data/coco/annotations/instances_train2017.json') as fp: +# data = json.load(fp) +# for i in data['categories']: +# classes[i['id']] = i['name'] +# print(classes) + +model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1).to(device) +model.eval() + +# 定義轉換 +transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor() +]) + +images = [] +paths = [] +count = 0 +for filename in os.listdir("../data/textcaps/train_images"): + if len(images) < 8: + image = Image.open('../data/textcaps/train_images/{}'.format(filename)) + image = transform(image).to(device) + c, h, w = image.shape + if c == 3: + images.append(image) + paths.append(filename) + else: + with torch.no_grad(): + outputs = model(images) + + for fileindex, output in enumerate(outputs): + + fig, ax = plt.subplots(1) + image = images[fileindex].to('cpu') + ax.imshow(image.permute(1, 2, 0)) + + for index in range(len(output['boxes'])): + boxes = output['boxes'][index].tolist() + score = output['scores'][index].item() + label = output['labels'][index].item() + if score > 0.5: # 只畫出信心度大於 0.5 的 bounding box + rect = patches.Rectangle((boxes[0], boxes[1]), boxes[2]-boxes[0], boxes[3]-boxes[1], linewidth=2, edgecolor='r', facecolor='none') + ax.add_patch(rect) + ax.text(boxes[0], boxes[1]-5, classes[label] + f' {score:.2f}', color='r', fontsize=12) + plt.savefig(paths[fileindex]) + + images = [] + paths = [] + count += 1 + if count == 25: + break + + + +