feat: confusion matrix

Ting-Jun Wang 2024-01-15 23:12:53 +08:00
parent d63e1f0cd0
commit 35031e918b
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
2 changed files with 934 additions and 578 deletions


@@ -1,571 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#1. load skybox image of the path[-3:]\n",
"#2. display instruction, adversarial instruction, (navgpt captions and objects)\n",
"\n",
"from PIL import Image\n",
"import json, os\n",
"\n",
"def load_json(fn):\n",
" with open(fn) as f:\n",
" ret = json.load(f)\n",
" return ret\n",
"\n",
"def dump_json(data, fn, force=False):\n",
" if not force:\n",
" assert not os.path.exists(fn)\n",
" with open(fn, 'w') as f:\n",
" json.dump(data, f)\n",
" \n",
"def concat_images(images):\n",
" #images = [Image.open(x) for x in ['Test1.jpg', 'Test2.jpg', 'Test3.jpg']]\n",
" widths, heights = zip(*(i.size for i in images))\n",
"\n",
" total_width = sum(widths)+(len(images)-1)*10\n",
" max_height = max(heights)\n",
"\n",
" new_im = Image.new('RGB', (total_width, max_height))\n",
"\n",
" x_offset = 0\n",
" for im in images:\n",
" new_im.paste(im, (x_offset,0))\n",
" x_offset += im.size[0]+10\n",
" return new_im"
]
},
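{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sanity check for concat_images, assuming only PIL: two tiny in-memory\n",
"# images should be joined left-to-right with a 10px gap between them.\n",
"a = Image.new('RGB', (32, 32), 'red')\n",
"b = Image.new('RGB', (48, 32), 'blue')\n",
"strip = concat_images([a, b])\n",
"assert strip.size == (32 + 48 + 10, 32)\n",
"strip"
]
},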
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"reverie_val_unseen = load_json('../REVERIE/tasks/REVERIE/data/REVERIE_val_unseen_first_ins.json')\n",
"reverie_val_unseen_fnf = load_json(\"..//fnf/reverie_val_unseen_fnf.json\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"fnf_pairs = []\n",
"for idx, r in enumerate(reverie_val_unseen_fnf):\n",
" if not r['found']:\n",
" fnf_pairs.append((reverie_val_unseen_fnf[idx-1], r))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"idx = 1\n",
"fnf_pair = fnf_pairs[idx]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"scan_id, path = fnf_pair[0]['scan'], fnf_pair[0]['path']\n",
"fn_template = '/work/ganymede9487/mp3d/unzipped/{}/matterport_skybox_images/{}_skybox{}_sami.jpg'\n",
"skybox_image_files = [fn_template.format(scan_id, path[-1], skybox_idx) for skybox_idx in range(6)]\n",
"skybox_images = []\n",
"for skybox_image_file in skybox_image_files:\n",
" skybox_images.append(Image.open(skybox_image_file))\n",
"\n",
"concat_images(skybox_images)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Go to the lounge on level 1 with the fire extinguisher and push the rope around the table closer to the walls\n",
"Go to the lounge on level 1 with the fire extinguisher and push the towel around the table closer to the walls\n"
]
}
],
"source": [
"print(fnf_pair[0]['instruction'])\n",
"print(fnf_pair[1]['instruction'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#0 ok\n",
"#1 error 2 (not reasonable)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualizer\n",
"# data loader part\n",
"from PIL import Image\n",
"import networkx as nx\n",
"import json\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Rectangle\n",
"from io import BytesIO\n",
"import IPython.display\n",
"\n",
"# ================================== DATA DOWNLOAD ==================================\n",
"#\n",
"# DATASET: default to \"val unseen\"\n",
"# IMG_DIR_PATH: where to store the skyboxes images\n",
" # which the Matterport Sim used.\n",
"# ANNOTATION_DIR_PATH: where to store the \"adversarial instruction ver\" annotations\n",
" # download here: https://snsd0805.com/data/adversarial_annotations.zip\n",
"# PREDICTS_PATH: predict output (download here (duet version): )\n",
" # download here(duet's prediction): https://snsd0805.com/data/submit_val_unseen_dynamic.json\n",
"# CONNECTIVITY_PATH: where to store the connectivity file\n",
" # download here (provided by NavGPT): https://www.dropbox.com/sh/i8ng3iq5kpa68nu/AAB53bvCFY_ihYx1mkLlOB-ea?dl=1\n",
"# NAVIGABLE_PATH: \n",
" # the download link is same as CONNECTIVITY_PATH's\n",
"# BBOXES_FILE_PATH: \n",
" # download here: https://snsd0805.com/data/BBoxes.json\n",
"#\n",
"# ===================================================================================\n",
"DATASET = \"val_unseen\"\n",
"IMG_DIR_PATH = \"/home/snsd0805/code/research/VLN/base_dir/v1/scans\"\n",
"ANNOTATION_DIR_PATH = f\"/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_{DATASET}.json\"\n",
"PREDICTS_PATH = f\"/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/exprs_map/finetune/dagger-vitbase-seed.0/preds/submit_{DATASET}_dynamic.json\"\n",
"CONNECTIVITY_PATH = \"/data/Matterport3DSimulator-duet/connectivity\"\n",
"NAVIGABLE_PATH = \"/data/NavGPT_data/navigable\"\n",
"BBOXES_FILE_PATH = \"/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/BBoxes.json\"\n",
"\n",
"def load_predicts() -> dict:\n",
" with open(PREDICTS_PATH) as fp:\n",
" data = json.load(fp)\n",
" return data\n",
"\n",
"def load_annotations() -> dict:\n",
" with open(ANNOTATION_DIR_PATH) as fp:\n",
" origin_data = json.load(fp)\n",
" data = {}\n",
" for item in origin_data:\n",
" data[item['id']] = item\n",
" return data\n",
"\n",
"def get_annotation(annotations: dict, adversarial_instr_id: str) -> dict:\n",
" origin_instr_id, index = adversarial_instr_id[:-2], int(adversarial_instr_id[-1])\n",
" \n",
" ans = annotations[origin_instr_id]\n",
" if 'instructions_l' in ans:\n",
" del ans['instructions_l']\n",
" ans['original_instruction'] = ans['instructions'][0]\n",
" ans['instructions'] = ans['instructions'][index]\n",
" ans['found'] = ans['found'][index]\n",
"\n",
" return ans\n",
"\n",
"def load_scan_viewpoints():\n",
" '''\n",
" get all scan's viewpoints.\n",
" '''\n",
" scans_viewpoints = {}\n",
" \n",
" # load scan list\n",
" with open(f'{CONNECTIVITY_PATH}/scans.txt') as fp:\n",
" scans = [ scan.replace('\\n', '') for scan in fp.readlines() ]\n",
" \n",
" \n",
" # load all viewpoints from scan list\n",
" for scan in scans:\n",
" with open(f'{CONNECTIVITY_PATH}/{scan}_connectivity.json') as fp:\n",
" data = json.load(fp)\n",
" \n",
" # load navigable list\n",
" with open(f'{NAVIGABLE_PATH}/{scan}_navigable.json') as fp:\n",
" navigable_viewpoints = json.load(fp)\n",
" \n",
" # save all viewpoint in the scan\n",
" viewpoints = []\n",
" for viewpoint in data:\n",
" if viewpoint['included']:\n",
" viewpoints.append({\n",
" 'viewpoint': viewpoint['image_id'],\n",
" 'location': {\n",
" 'x': viewpoint['pose'][3],\n",
" 'y': viewpoint['pose'][7],\n",
" 'z': viewpoint['pose'][11],\n",
" },\n",
" 'navigable_viewpoints': list(navigable_viewpoints[viewpoint['image_id']].keys())\n",
" })\n",
" scans_viewpoints[scan] = viewpoints\n",
" return scans_viewpoints\n",
"\n",
"def get_viewpoints(scans_viewpoints: dict, scan: str) -> list:\n",
" return scans_viewpoints[scan]\n",
"\n",
"def load_bboxes_data() -> dict:\n",
" with open(BBOXES_FILE_PATH) as fp:\n",
" data = json.load(fp)\n",
" return data"
]
},
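{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of the found / not-found confusion matrix over all predictions, assuming\n",
"# the annotation and prediction formats used by the loaders above (each annotation\n",
"# carries a per-instruction 'found' flag, each prediction a 'found' field).\n",
"cm_annotations = load_annotations()\n",
"cm_predicts = load_predicts()\n",
"cm = {(gt, pred): 0 for gt in (True, False) for pred in (True, False)}\n",
"for p in cm_predicts:\n",
"    a = get_annotation(cm_annotations, p['instr_id'])\n",
"    cm[(bool(a['found']), bool(p['found']))] += 1\n",
"print('               pred=found  pred=not-found')\n",
"print('gt=found      ', cm[(True, True)], '         ', cm[(True, False)])\n",
"print('gt=not-found  ', cm[(False, True)], '         ', cm[(False, False)])"
]
},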
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot trajactory part\n",
"\n",
"def plot_all_scan_viewpoints(viewpoints: list, plot: bool=True):\n",
" nodes, edges = [], []\n",
" G = nx.Graph()\n",
" positions = {}\n",
"\n",
" for viewpoint in viewpoints:\n",
" x, y = viewpoint['location']['x'], viewpoint['location']['y']\n",
" viewpoint_id = viewpoint['viewpoint']\n",
" G.add_node(viewpoint_id)\n",
" positions[viewpoint_id] = (x, y)\n",
" for neighbor in viewpoint['navigable_viewpoints']:\n",
" G.add_edge(viewpoint_id, neighbor)\n",
" if plot:\n",
" nx.draw(G, positions, node_size=50)\n",
" plt.show() \n",
" else:\n",
" return G, positions\n",
"\n",
" \n",
" \n",
" \n",
"# overview\n",
"def plot_trajs_overview(viewpoints: list, gt_anno: dict, predict: dict):\n",
" G, positions = plot_all_scan_viewpoints(viewpoints, plot=False)\n",
" \n",
" edge_colors = [ 'black' for _ in G.edges ]\n",
"\n",
" max_x, max_y = -100, -100\n",
" min_x, min_y = 100, 100\n",
"\n",
" \n",
" # GT path\n",
" previous_node_name = None\n",
" for node_name in gt_anno['path']:\n",
" pos = positions[node_name]\n",
" max_x = max(max_x, pos[0])\n",
" max_y = max(max_y, pos[1])\n",
" min_x = min(min_x, pos[0])\n",
" min_y = min(min_y, pos[1])\n",
"\n",
" if previous_node_name != None:\n",
" for index, edge in enumerate(G.edges):\n",
" if edge == (previous_node_name, node_name) or edge == (node_name, previous_node_name):\n",
" edge_colors[index] = 'red'\n",
" previous_node_name = node_name\n",
" \n",
" # Predicted Path\n",
" previous_node_name = None\n",
" for node_name in predict['trajectory']:\n",
" node_name = node_name[0]\n",
" pos = positions[node_name]\n",
" max_x = max(max_x, pos[0])\n",
" max_y = max(max_y, pos[1])\n",
" min_x = min(min_x, pos[0])\n",
" min_y = min(min_y, pos[1])\n",
"\n",
" if previous_node_name != None:\n",
" for index, edge in enumerate(G.edges):\n",
" if edge == (previous_node_name, node_name) or edge == (node_name, previous_node_name):\n",
" if edge_colors[index] == 'red':\n",
" edge_colors[index] = 'orange'\n",
" else:\n",
" edge_colors[index] = 'blue'\n",
" previous_node_name = node_name\n",
"\n",
" current_axis = plt.gca()\n",
" current_axis.add_patch(Rectangle((min_x, min_y), max_x-min_x, max_y-min_y, fill=None, edgecolor='red', linewidth=0.3))\n",
"\n",
" \n",
" nx.draw(G, positions, node_size=20, edge_color=edge_colors)\n",
" plt.show()\n",
" \n",
" \n",
" \n",
" \n",
"# local traj\n",
"def plot_trajs_local(viewpoints: list, gt_anno: dict, predict: dict):\n",
" def get_viewpoints_dict(viewpoints) -> dict:\n",
" ans = {}\n",
" for viewpoint in viewpoints:\n",
" viewpoint_id = viewpoint['viewpoint']\n",
" ans[viewpoint_id] = viewpoint\n",
" return ans\n",
"\n",
" # list to dict, to find position\n",
" viewpoints = get_viewpoints_dict(viewpoints)\n",
"\n",
" # GT traj path\n",
" G = nx.Graph()\n",
" positions = {}\n",
" edge_colors = {}\n",
" previous_node_name = None\n",
" \n",
" # GT path\n",
" for index, node_name in enumerate(gt_anno['path']):\n",
" # new node\n",
" G.add_node(node_name)\n",
"\n",
" # find node's position\n",
" viewpoint = viewpoints[node_name]\n",
" x, y = viewpoint['location']['x'], viewpoint['location']['y']\n",
" positions[node_name] = (x, y)\n",
" \n",
" # path edge\n",
" if previous_node_name != None: # not start\n",
" G.add_edge(previous_node_name, node_name)\n",
" edge_colors[f'{previous_node_name}_{node_name}'] = 'red'\n",
" previous_node_name = node_name\n",
" \n",
" plt.text(x-0.2, y, f'{index}', fontsize=15, color='red', ha='center', va='center')\n",
"\n",
" if node_name == gt_anno['path'][0]:\n",
" plt.text(x, y, 'START', fontsize=12, color='darkred', ha='center', va='center')\n",
" if node_name == gt_anno['path'][-1]:\n",
" plt.text(x, y, 'STOP', fontsize=12, color='darkred', ha='center', va='center')\n",
" \n",
" \n",
" # Predicted path\n",
" previous_node_name = None\n",
" for node_index, node_name in enumerate(predict['trajectory']):\n",
" node_name = node_name[0]\n",
" # new node\n",
" \n",
" G.add_node(node_name)\n",
"\n",
" # find node's position\n",
" viewpoint = viewpoints[node_name]\n",
" x, y = viewpoint['location']['x'], viewpoint['location']['y']\n",
" positions[node_name] = (x, y)\n",
" \n",
" plt.text(x+0.2, y, f'{node_index}', fontsize=15, color='blue', ha='center', va='center')\n",
" \n",
" # path edge\n",
" if previous_node_name != None: # not start\n",
" find = False\n",
" for index, edge in enumerate(G.edges):\n",
" if edge == (previous_node_name, node_name) or edge == (node_name, previous_node_name):\n",
" if f'{previous_node_name}_{node_name}' in edge_colors:\n",
" if edge_colors[f'{previous_node_name}_{node_name}'] == 'red':\n",
" edge_colors[f'{previous_node_name}_{node_name}'] = 'orange'\n",
" else:\n",
" edge_colors[f'{previous_node_name}_{node_name}'] = 'darkblue'\n",
" else:\n",
" if edge_colors[f'{node_name}_{previous_node_name}'] == 'red':\n",
" edge_colors[f'{node_name}_{previous_node_name}'] = 'orange'\n",
" else:\n",
" edge_colors[f'{node_name}_{previous_node_name}'] = 'darkblue'\n",
" find = True\n",
" break\n",
" if not find:\n",
" G.add_edge(previous_node_name, node_name)\n",
" edge_colors[f'{previous_node_name}_{node_name}'] = 'blue'\n",
" previous_node_name = node_name\n",
" \n",
" # edge colors\n",
" colors = []\n",
" for edge in G.edges:\n",
" if f'{edge[0]}_{edge[1]}' in edge_colors:\n",
" colors.append(edge_colors[f'{edge[0]}_{edge[1]}'])\n",
" else:\n",
" colors.append(edge_colors[f'{edge[1]}_{edge[0]}'])\n",
" # STOP & START label\n",
" start, stop = predict['trajectory'][0][0], predict['trajectory'][-1][0]\n",
" x, y = viewpoints[start]['location']['x'], viewpoints[start]['location']['y']\n",
" plt.text(x, y, 'START', fontsize=12, color='darkblue', ha='center', va='center')\n",
" x, y = viewpoints[stop]['location']['x'], viewpoints[stop]['location']['y']\n",
" plt.text(x, y, 'STOP', fontsize=12, color='darkblue', ha='center', va='center')\n",
"\n",
" nx.draw(G, positions, node_size=20, edge_color=colors, with_labels=False)\n",
" plt.show()\n",
"\n",
"\n",
"\n",
"def plot_trajs(viewpoints: list, gt_anno: dict, predict: dict, overview: bool):\n",
" if overview:\n",
" plot_trajs_overview(viewpoints, gt_anno, predict)\n",
" else:\n",
" plot_trajs_local(viewpoints, gt_anno, predict)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# skybox part\n",
"def concat_images(images):\n",
" #images = [Image.open(x) for x in ['Test1.jpg', 'Test2.jpg', 'Test3.jpg']]\n",
" widths, heights = zip(*(i.size for i in images))\n",
"\n",
" total_width = sum(widths)+(len(images)-1)*10\n",
" max_height = max(heights)\n",
"\n",
" new_im = Image.new('RGB', (total_width, max_height))\n",
"\n",
" x_offset = 0\n",
" for im in images:\n",
" new_im.paste(im, (x_offset,0))\n",
" x_offset += im.size[0]+10\n",
" return new_im\n",
"\n",
"def get_skyboxes(scan: str, viewpoint: list) -> list:\n",
" ans = []\n",
" path = f\"{IMG_DIR_PATH}/{scan}/matterport_skybox_images\"\n",
" for i in range(6):\n",
" ans.append(Image.open(f'{path}/{viewpoint}_skybox{i}_sami.jpg', 'r'))\n",
" return ans\n",
"\n",
"def get_pano(scan: str, viewpoint: dict) -> Image:\n",
" pano = concat_images(get_skyboxes(scan, viewpoint))\n",
" pano = pano.resize([pano.size[0]//4, pano.size[1]//4])\n",
" return pano\n",
"\n",
"def display_img(im):\n",
" bio = BytesIO()\n",
" im.save(bio, format='png')\n",
" display(IPython.display.Image(bio.getvalue(), format='png'))\n",
"\n",
"def get_obj_name(bboxes_data: dict, scan: str, node: str, obj_id: str):\n",
" return bboxes_data[f'{scan}_{node}'][str(obj_id)]['name']\n",
" \n",
"def plot_gt_panos(scan: str, traj: list, bboxes_data: dict, obj_id: str):\n",
" for index, node in enumerate(traj):\n",
" print(f\"GT path ({index}): {node}\")\n",
" pano = get_pano(scan, node)\n",
" display_img(pano)\n",
" if index == len(traj)-1:\n",
" try:\n",
" obj_name = get_obj_name(bboxes_data, scan, node, obj_id)\n",
" print(f\"GT object: {obj_name}({obj_id})\")\n",
" except:\n",
" print(\"KEY ERROR\", scan, node, obj_id)\n",
" \n",
"def plot_predict_panos(scan: str, traj: list, bboxes_data: dict, obj_id: str):\n",
" for index, node in enumerate(traj):\n",
" node = node[0]\n",
" print(f\"Predict path ({index}): {node}\")\n",
" pano = get_pano(scan, node)\n",
" display_img(pano)\n",
" if index == len(traj)-1:\n",
" if obj_id != -1:\n",
" try:\n",
" obj_name = get_obj_name(bboxes_data, scan, node, obj_id)\n",
" print(f\"Predict object: {obj_name}({obj_id})\")\n",
" except:\n",
" print(\"KEY ERROR\", scan, node, obj_id)\n",
"\n",
" else:\n",
" print(f\"Predict object: NOT_FOUND\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# main \n",
"\n",
"annotations = load_annotations()\n",
"predicts = load_predicts()\n",
"scans_viewpoints = load_scan_viewpoints()\n",
"bboxes_data = load_bboxes_data()\n",
"\n",
"NUMBER = 3\n",
"success = 0\n",
"for i in range(NUMBER):\n",
" predict = predicts[i]\n",
" annotation = get_annotation(annotations, predict['instr_id'])\n",
" viewpoints = get_viewpoints(scans_viewpoints, annotation['scan'])\n",
" if annotation['found'] == predict['found']:\n",
" success += 1\n",
" print(i)\n",
" print(predict['instr_id'])\n",
" print(annotation['scan'])\n",
" print()\n",
" print(annotation['instructions'])\n",
" if annotation['found'] == False:\n",
" print(\"Original: \", annotation['original_instruction'])\n",
" print()\n",
" print(\"GT found: \", predict['gt_found'], annotation['found'])\n",
" print(\"predict found: \", predict['found'])\n",
" print(\"===\")\n",
" print(\"target object ID: \", annotation['objId'])\n",
" print(\"predicted object ID: \", predict['pred_objid'])\n",
"# print(annotation.keys())\n",
"# print(predict.keys())\n",
"# plot_all_scan_viewpoints(viewpoints)\n",
" plot_trajs(viewpoints, annotation, predict, overview=True)\n",
" plot_trajs(viewpoints, annotation, predict, overview=False)\n",
" \n",
" plot_gt_panos(annotation['scan'], annotation['path'], bboxes_data, obj_id=annotation['objId'])\n",
" print(\"* \" * 30)\n",
" plot_predict_panos(annotation['scan'], predict['trajectory'], bboxes_data, obj_id=predict['pred_objid'])\n",
" print(\"=\"*100)\n",
" # print(ans)\n",
"# print(viewpoints)\n",
"print(success/NUMBER)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

File diff suppressed because one or more lines are too long