From d63e1f0cd03e0d53093507ffbc53f71f2d36759b Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Mon, 15 Jan 2024 17:29:07 +0800 Subject: [PATCH] feat: visualizer --- .../visualizer-checkpoint.ipynb | 427 +++++++++++++++++- visualization/visualizer.ipynb | 427 +++++++++++++++++- 2 files changed, 848 insertions(+), 6 deletions(-) diff --git a/visualization/.ipynb_checkpoints/visualizer-checkpoint.ipynb b/visualization/.ipynb_checkpoints/visualizer-checkpoint.ipynb index 00995d0..32dd74d 100644 --- a/visualization/.ipynb_checkpoints/visualizer-checkpoint.ipynb +++ b/visualization/.ipynb_checkpoints/visualizer-checkpoint.ipynb @@ -124,13 +124,434 @@ "#0 ok\n", "#1 error 2 (not reasonable)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualizer\n", + "# data loader part\n", + "from PIL import Image\n", + "import networkx as nx\n", + "import json\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Rectangle\n", + "from io import BytesIO\n", + "import IPython.display\n", + "\n", + "# ================================== DATA DOWNLOAD ==================================\n", + "#\n", + "# DATASET: default to \"val unseen\"\n", + "# IMG_DIR_PATH: where to store the skyboxes images\n", + " # which the Matterport Sim used.\n", + "# ANNOTATION_DIR_PATH: where to store the \"adversarial instruction ver\" annotations\n", + " # download here: https://snsd0805.com/data/adversarial_annotations.zip\n", + "# PREDICTS_PATH: predict output (download here (duet version): )\n", + " # download here(duet's prediction): https://snsd0805.com/data/submit_val_unseen_dynamic.json\n", + "# CONNECTIVITY_PATH: where to store the connectivity file\n", + " # download here (provided by NavGPT): https://www.dropbox.com/sh/i8ng3iq5kpa68nu/AAB53bvCFY_ihYx1mkLlOB-ea?dl=1\n", + "# NAVIGABLE_PATH: \n", + " # the download link is same as CONNECTIVITY_PATH's\n", + "# BBOXES_FILE_PATH: \n", + " # download here: https://snsd0805.com/data/BBoxes.json\n", + "#\n", + "# ===================================================================================\n", + "DATASET = \"val_unseen\"\n", + "IMG_DIR_PATH = \"/home/snsd0805/code/research/VLN/base_dir/v1/scans\"\n", + "ANNOTATION_DIR_PATH = f\"/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_{DATASET}.json\"\n", + "PREDICTS_PATH = f\"/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/exprs_map/finetune/dagger-vitbase-seed.0/preds/submit_{DATASET}_dynamic.json\"\n", + "CONNECTIVITY_PATH = \"/data/Matterport3DSimulator-duet/connectivity\"\n", + "NAVIGABLE_PATH = \"/data/NavGPT_data/navigable\"\n", + "BBOXES_FILE_PATH = \"/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/BBoxes.json\"\n", + "\n", + "def load_predicts() -> dict:\n", + " with open(PREDICTS_PATH) as fp:\n", + " data = json.load(fp)\n", + " return data\n", + "\n", + "def load_annotations() -> dict:\n", + " with open(ANNOTATION_DIR_PATH) as fp:\n", + " origin_data = json.load(fp)\n", + " data = {}\n", + " for item in origin_data:\n", + " data[item['id']] = item\n", + " return data\n", + "\n", + "def get_annotation(annotations: dict, adversarial_instr_id: str) -> dict:\n", + " origin_instr_id, index = adversarial_instr_id[:-2], int(adversarial_instr_id[-1])\n", + " \n", + " ans = annotations[origin_instr_id]\n", + " if 'instructions_l' in ans:\n", + " del ans['instructions_l']\n", + " ans['original_instruction'] = ans['instructions'][0]\n", + " ans['instructions'] = ans['instructions'][index]\n", + " ans['found'] = ans['found'][index]\n", + "\n", + " return ans\n", + "\n", + "def load_scan_viewpoints():\n", + " '''\n", + " get all scan's viewpoints.\n", + " '''\n", + " scans_viewpoints = {}\n", + " \n", + " # load scan list\n", + " with open(f'{CONNECTIVITY_PATH}/scans.txt') as fp:\n", + " scans = [ scan.replace('\\n', '') for scan in fp.readlines() ]\n", + " \n", + " \n", + " # load all viewpoints from scan list\n", + " for scan in scans:\n", + " with open(f'{CONNECTIVITY_PATH}/{scan}_connectivity.json') as fp:\n", + " data = json.load(fp)\n", + " \n", + " # load navigable list\n", + " with open(f'{NAVIGABLE_PATH}/{scan}_navigable.json') as fp:\n", + " navigable_viewpoints = json.load(fp)\n", + " \n", + " # save all viewpoint in the scan\n", + " viewpoints = []\n", + " for viewpoint in data:\n", + " if viewpoint['included']:\n", + " viewpoints.append({\n", + " 'viewpoint': viewpoint['image_id'],\n", + " 'location': {\n", + " 'x': viewpoint['pose'][3],\n", + " 'y': viewpoint['pose'][7],\n", + " 'z': viewpoint['pose'][11],\n", + " },\n", + " 'navigable_viewpoints': list(navigable_viewpoints[viewpoint['image_id']].keys())\n", + " })\n", + " scans_viewpoints[scan] = viewpoints\n", + " return scans_viewpoints\n", + "\n", + "def get_viewpoints(scans_viewpoints: dict, scan: str) -> list:\n", + " return scans_viewpoints[scan]\n", + "\n", + "def load_bboxes_data() -> dict:\n", + " with open(BBOXES_FILE_PATH) as fp:\n", + " data = json.load(fp)\n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot trajactory part\n", + "\n", + "def plot_all_scan_viewpoints(viewpoints: list, plot: bool=True):\n", + " nodes, edges = [], []\n", + " G = nx.Graph()\n", + " positions = {}\n", + "\n", + " for viewpoint in viewpoints:\n", + " x, y = viewpoint['location']['x'], viewpoint['location']['y']\n", + " viewpoint_id = viewpoint['viewpoint']\n", + " G.add_node(viewpoint_id)\n", + " positions[viewpoint_id] = (x, y)\n", + " for neighbor in viewpoint['navigable_viewpoints']:\n", + " G.add_edge(viewpoint_id, neighbor)\n", + " if plot:\n", + " nx.draw(G, positions, node_size=50)\n", + " plt.show() \n", + " else:\n", + " return G, positions\n", + "\n", + " \n", + " \n", + " \n", + "# overview\n", + "def plot_trajs_overview(viewpoints: list, gt_anno: dict, predict: dict):\n", + " G, positions = plot_all_scan_viewpoints(viewpoints, plot=False)\n", + " \n", + " edge_colors = [ 'black' for _ in G.edges ]\n", + "\n", + " max_x, max_y = -100, -100\n", + " min_x, min_y = 100, 100\n", + "\n", + " \n", + " # GT path\n", + " previous_node_name = None\n", + " for node_name in gt_anno['path']:\n", + " pos = positions[node_name]\n", + " max_x = max(max_x, pos[0])\n", + " max_y = max(max_y, pos[1])\n", + " min_x = min(min_x, pos[0])\n", + " min_y = min(min_y, pos[1])\n", + "\n", + " if previous_node_name != None:\n", + " for index, edge in enumerate(G.edges):\n", + " if edge == (previous_node_name, node_name) or edge == (node_name, previous_node_name):\n", + " edge_colors[index] = 'red'\n", + " previous_node_name = node_name\n", + " \n", + " # Predicted Path\n", + " previous_node_name = None\n", + " for node_name in predict['trajectory']:\n", + " node_name = node_name[0]\n", + " pos = positions[node_name]\n", + " max_x = max(max_x, pos[0])\n", + " max_y = max(max_y, pos[1])\n", + " min_x = min(min_x, pos[0])\n", + " min_y = min(min_y, pos[1])\n", + "\n", + " if previous_node_name != None:\n", + " for index, edge in enumerate(G.edges):\n", + " if edge == (previous_node_name, node_name) or edge == (node_name, previous_node_name):\n", + " if edge_colors[index] == 'red':\n", + " edge_colors[index] = 'orange'\n", + " else:\n", + " edge_colors[index] = 'blue'\n", + " previous_node_name = node_name\n", + "\n", + " current_axis = plt.gca()\n", + " current_axis.add_patch(Rectangle((min_x, min_y), max_x-min_x, max_y-min_y, fill=None, edgecolor='red', linewidth=0.3))\n", + "\n", + " \n", + " nx.draw(G, positions, node_size=20, edge_color=edge_colors)\n", + " plt.show()\n", + " \n", + " \n", + " \n", + " \n", + "# local traj\n", + "def plot_trajs_local(viewpoints: list, gt_anno: dict, predict: dict):\n", + " def get_viewpoints_dict(viewpoints) -> dict:\n", + " ans = {}\n", + " for viewpoint in viewpoints:\n", + " viewpoint_id = viewpoint['viewpoint']\n", + " ans[viewpoint_id] = viewpoint\n", + " return ans\n", + "\n", + " # list to dict, to find position\n", + " viewpoints = get_viewpoints_dict(viewpoints)\n", + "\n", + " # GT traj path\n", + " G = nx.Graph()\n", + " positions = {}\n", + " edge_colors = {}\n", + " previous_node_name = None\n", + " \n", + " # GT path\n", + " for index, node_name in enumerate(gt_anno['path']):\n", + " # new node\n", + " G.add_node(node_name)\n", + "\n", + " # find node's position\n", + " viewpoint = viewpoints[node_name]\n", + " x, y = viewpoint['location']['x'], viewpoint['location']['y']\n", + " positions[node_name] = (x, y)\n", + " \n", + " # path edge\n", + " if previous_node_name != None: # not start\n", + " G.add_edge(previous_node_name, node_name)\n", + " edge_colors[f'{previous_node_name}_{node_name}'] = 'red'\n", + " previous_node_name = node_name\n", + " \n", + " plt.text(x-0.2, y, f'{index}', fontsize=15, color='red', ha='center', va='center')\n", + "\n", + " if node_name == gt_anno['path'][0]:\n", + " plt.text(x, y, 'START', fontsize=12, color='darkred', ha='center', va='center')\n", + " if node_name == gt_anno['path'][-1]:\n", + " plt.text(x, y, 'STOP', fontsize=12, color='darkred', ha='center', va='center')\n", + " \n", + " \n", + " # Predicted path\n", + " previous_node_name = None\n", + " for node_index, node_name in enumerate(predict['trajectory']):\n", + " node_name = node_name[0]\n", + " # new node\n", + " \n", + " G.add_node(node_name)\n", + "\n", + " # find node's position\n", + " viewpoint = viewpoints[node_name]\n", + " x, y = viewpoint['location']['x'], viewpoint['location']['y']\n", + " positions[node_name] = (x, y)\n", + " \n", + " plt.text(x+0.2, y, f'{node_index}', fontsize=15, color='blue', ha='center', va='center')\n", + " \n", + " # path edge\n", + " if previous_node_name != None: # not start\n", + " find = False\n", + " for index, edge in enumerate(G.edges):\n", + " if edge == (previous_node_name, node_name) or edge == (node_name, previous_node_name):\n", + " if f'{previous_node_name}_{node_name}' in edge_colors:\n", + " if edge_colors[f'{previous_node_name}_{node_name}'] == 'red':\n", + " edge_colors[f'{previous_node_name}_{node_name}'] = 'orange'\n", + " else:\n", + " edge_colors[f'{previous_node_name}_{node_name}'] = 'darkblue'\n", + " else:\n", + " if edge_colors[f'{node_name}_{previous_node_name}'] == 'red':\n", + " edge_colors[f'{node_name}_{previous_node_name}'] = 'orange'\n", + " else:\n", + " edge_colors[f'{node_name}_{previous_node_name}'] = 'darkblue'\n", + " find = True\n", + " break\n", + " if not find:\n", + " G.add_edge(previous_node_name, node_name)\n", + " edge_colors[f'{previous_node_name}_{node_name}'] = 'blue'\n", + " previous_node_name = node_name\n", + " \n", + " # edge colors\n", + " colors = []\n", + " for edge in G.edges:\n", + " if f'{edge[0]}_{edge[1]}' in edge_colors:\n", + " colors.append(edge_colors[f'{edge[0]}_{edge[1]}'])\n", + " else:\n", + " colors.append(edge_colors[f'{edge[1]}_{edge[0]}'])\n", + " # STOP & START label\n", + " start, stop = predict['trajectory'][0][0], predict['trajectory'][-1][0]\n", + " x, y = viewpoints[start]['location']['x'], viewpoints[start]['location']['y']\n", + " plt.text(x, y, 'START', fontsize=12, color='darkblue', ha='center', va='center')\n", + " x, y = viewpoints[stop]['location']['x'], viewpoints[stop]['location']['y']\n", + " plt.text(x, y, 'STOP', fontsize=12, color='darkblue', ha='center', va='center')\n", + "\n", + " nx.draw(G, positions, node_size=20, edge_color=colors, with_labels=False)\n", + " plt.show()\n", + "\n", + "\n", + "\n", + "def plot_trajs(viewpoints: list, gt_anno: dict, predict: dict, overview: bool):\n", + " if overview:\n", + " plot_trajs_overview(viewpoints, gt_anno, predict)\n", + " else:\n", + " plot_trajs_local(viewpoints, gt_anno, predict)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# skybox part\n", + "def concat_images(images):\n", + " #images = [Image.open(x) for x in ['Test1.jpg', 'Test2.jpg', 'Test3.jpg']]\n", + " widths, heights = zip(*(i.size for i in images))\n", + "\n", + " total_width = sum(widths)+(len(images)-1)*10\n", + " max_height = max(heights)\n", + "\n", + " new_im = Image.new('RGB', (total_width, max_height))\n", + "\n", + " x_offset = 0\n", + " for im in images:\n", + " new_im.paste(im, (x_offset,0))\n", + " x_offset += im.size[0]+10\n", + " return new_im\n", + "\n", + "def get_skyboxes(scan: str, viewpoint: list) -> list:\n", + " ans = []\n", + " path = f\"{IMG_DIR_PATH}/{scan}/matterport_skybox_images\"\n", + " for i in range(6):\n", + " ans.append(Image.open(f'{path}/{viewpoint}_skybox{i}_sami.jpg', 'r'))\n", + " return ans\n", + "\n", + "def get_pano(scan: str, viewpoint: dict) -> Image:\n", + " pano = concat_images(get_skyboxes(scan, viewpoint))\n", + " pano = pano.resize([pano.size[0]//4, pano.size[1]//4])\n", + " return pano\n", + "\n", + "def display_img(im):\n", + " bio = BytesIO()\n", + " im.save(bio, format='png')\n", + " display(IPython.display.Image(bio.getvalue(), format='png'))\n", + "\n", + "def get_obj_name(bboxes_data: dict, scan: str, node: str, obj_id: str):\n", + " return bboxes_data[f'{scan}_{node}'][str(obj_id)]['name']\n", + " \n", + "def plot_gt_panos(scan: str, traj: list, bboxes_data: dict, obj_id: str):\n", + " for index, node in enumerate(traj):\n", + " print(f\"GT path ({index}): {node}\")\n", + " pano = get_pano(scan, node)\n", + " display_img(pano)\n", + " if index == len(traj)-1:\n", + " try:\n", + " obj_name = get_obj_name(bboxes_data, scan, node, obj_id)\n", + " print(f\"GT object: {obj_name}({obj_id})\")\n", + " except:\n", + " print(\"KEY ERROR\", scan, node, obj_id)\n", + " \n", + "def plot_predict_panos(scan: str, traj: list, bboxes_data: dict, obj_id: str):\n", + " for index, node in enumerate(traj):\n", + " node = node[0]\n", + " print(f\"Predict path ({index}): {node}\")\n", + " pano = get_pano(scan, node)\n", + " display_img(pano)\n", + " if index == len(traj)-1:\n", + " if obj_id != -1:\n", + " try:\n", + " obj_name = get_obj_name(bboxes_data, scan, node, obj_id)\n", + " print(f\"Predict object: {obj_name}({obj_id})\")\n", + " except:\n", + " print(\"KEY ERROR\", scan, node, obj_id)\n", + "\n", + " else:\n", + " print(f\"Predict object: NOT_FOUND\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# main \n", + "\n", + "annotations = load_annotations()\n", + "predicts = load_predicts()\n", + "scans_viewpoints = load_scan_viewpoints()\n", + "bboxes_data = load_bboxes_data()\n", + "\n", + "NUMBER = 3\n", + "success = 0\n", + "for i in range(NUMBER):\n", + " predict = predicts[i]\n", + " annotation = get_annotation(annotations, predict['instr_id'])\n", + " viewpoints = get_viewpoints(scans_viewpoints, annotation['scan'])\n", + " if annotation['found'] == predict['found']:\n", + " success += 1\n", + " print(i)\n", + " print(predict['instr_id'])\n", + " print(annotation['scan'])\n", + " print()\n", + " print(annotation['instructions'])\n", + " if annotation['found'] == False:\n", + " print(\"Original: \", annotation['original_instruction'])\n", + " print()\n", + " print(\"GT found: \", predict['gt_found'], annotation['found'])\n", + " print(\"predict found: \", predict['found'])\n", + " print(\"===\")\n", + " print(\"target object ID: \", annotation['objId'])\n", + " print(\"predicted object ID: \", predict['pred_objid'])\n", + "# print(annotation.keys())\n", + "# print(predict.keys())\n", + "# plot_all_scan_viewpoints(viewpoints)\n", + " plot_trajs(viewpoints, annotation, predict, overview=True)\n", + " plot_trajs(viewpoints, annotation, predict, overview=False)\n", + " \n", + " plot_gt_panos(annotation['scan'], annotation['path'], bboxes_data, obj_id=annotation['objId'])\n", + " print(\"* \" * 30)\n", + " plot_predict_panos(annotation['scan'], predict['trajectory'], bboxes_data, obj_id=predict['pred_objid'])\n", + " print(\"=\"*100)\n", + " # print(ans)\n", + "# print(viewpoints)\n", + "print(success/NUMBER)" + ] } ], "metadata": { "kernelspec": { - "display_name": "minigpt4", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "minigpt4" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -142,7 +563,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/visualization/visualizer.ipynb b/visualization/visualizer.ipynb index 00995d0..32dd74d 100644 --- a/visualization/visualizer.ipynb +++ b/visualization/visualizer.ipynb @@ -124,13 +124,434 @@ "#0 ok\n", "#1 error 2 (not reasonable)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualizer\n", + "# data loader part\n", + "from PIL import Image\n", + "import networkx as nx\n", + "import json\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Rectangle\n", + "from io import BytesIO\n", + "import IPython.display\n", + "\n", + "# ================================== DATA DOWNLOAD ==================================\n", + "#\n", + "# DATASET: default to \"val unseen\"\n", + "# IMG_DIR_PATH: where to store the skyboxes images\n", + " # which the Matterport Sim used.\n", + "# ANNOTATION_DIR_PATH: where to store the \"adversarial instruction ver\" annotations\n", + " # download here: https://snsd0805.com/data/adversarial_annotations.zip\n", + "# PREDICTS_PATH: predict output (download here (duet version): )\n", + " # download here(duet's prediction): https://snsd0805.com/data/submit_val_unseen_dynamic.json\n", + "# CONNECTIVITY_PATH: where to store the connectivity file\n", + " # download here (provided by NavGPT): https://www.dropbox.com/sh/i8ng3iq5kpa68nu/AAB53bvCFY_ihYx1mkLlOB-ea?dl=1\n", + "# NAVIGABLE_PATH: \n", + " # the download link is same as CONNECTIVITY_PATH's\n", + "# BBOXES_FILE_PATH: \n", + " # download here: https://snsd0805.com/data/BBoxes.json\n", + "#\n", + "# ===================================================================================\n", + "DATASET = \"val_unseen\"\n", + "IMG_DIR_PATH = \"/home/snsd0805/code/research/VLN/base_dir/v1/scans\"\n", + "ANNOTATION_DIR_PATH = f\"/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_{DATASET}.json\"\n", + "PREDICTS_PATH = f\"/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/exprs_map/finetune/dagger-vitbase-seed.0/preds/submit_{DATASET}_dynamic.json\"\n", + "CONNECTIVITY_PATH = \"/data/Matterport3DSimulator-duet/connectivity\"\n", + "NAVIGABLE_PATH = \"/data/NavGPT_data/navigable\"\n", + "BBOXES_FILE_PATH = \"/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/BBoxes.json\"\n", + "\n", + "def load_predicts() -> dict:\n", + " with open(PREDICTS_PATH) as fp:\n", + " data = json.load(fp)\n", + " return data\n", + "\n", + "def load_annotations() -> dict:\n", + " with open(ANNOTATION_DIR_PATH) as fp:\n", + " origin_data = json.load(fp)\n", + " data = {}\n", + " for item in origin_data:\n", + " data[item['id']] = item\n", + " return data\n", + "\n", + "def get_annotation(annotations: dict, adversarial_instr_id: str) -> dict:\n", + " origin_instr_id, index = adversarial_instr_id[:-2], int(adversarial_instr_id[-1])\n", + " \n", + " ans = annotations[origin_instr_id]\n", + " if 'instructions_l' in ans:\n", + " del ans['instructions_l']\n", + " ans['original_instruction'] = ans['instructions'][0]\n", + " ans['instructions'] = ans['instructions'][index]\n", + " ans['found'] = ans['found'][index]\n", + "\n", + " return ans\n", + "\n", + "def load_scan_viewpoints():\n", + " '''\n", + " get all scan's viewpoints.\n", + " '''\n", + " scans_viewpoints = {}\n", + " \n", + " # load scan list\n", + " with open(f'{CONNECTIVITY_PATH}/scans.txt') as fp:\n", + " scans = [ scan.replace('\\n', '') for scan in fp.readlines() ]\n", + " \n", + " \n", + " # load all viewpoints from scan list\n", + " for scan in scans:\n", + " with open(f'{CONNECTIVITY_PATH}/{scan}_connectivity.json') as fp:\n", + " data = json.load(fp)\n", + " \n", + " # load navigable list\n", + " with open(f'{NAVIGABLE_PATH}/{scan}_navigable.json') as fp:\n", + " navigable_viewpoints = json.load(fp)\n", + " \n", + " # save all viewpoint in the scan\n", + " viewpoints = []\n", + " for viewpoint in data:\n", + " if viewpoint['included']:\n", + " viewpoints.append({\n", + " 'viewpoint': viewpoint['image_id'],\n", + " 'location': {\n", + " 'x': viewpoint['pose'][3],\n", + " 'y': viewpoint['pose'][7],\n", + " 'z': viewpoint['pose'][11],\n", + " },\n", + " 'navigable_viewpoints': list(navigable_viewpoints[viewpoint['image_id']].keys())\n", + " })\n", + " scans_viewpoints[scan] = viewpoints\n", + " return scans_viewpoints\n", + "\n", + "def get_viewpoints(scans_viewpoints: dict, scan: str) -> list:\n", + " return scans_viewpoints[scan]\n", + "\n", + "def load_bboxes_data() -> dict:\n", + " with open(BBOXES_FILE_PATH) as fp:\n", + " data = json.load(fp)\n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot trajactory part\n", + "\n", + "def plot_all_scan_viewpoints(viewpoints: list, plot: bool=True):\n", + " nodes, edges = [], []\n", + " G = nx.Graph()\n", + " positions = {}\n", + "\n", + " for viewpoint in viewpoints:\n", + " x, y = viewpoint['location']['x'], viewpoint['location']['y']\n", + " viewpoint_id = viewpoint['viewpoint']\n", + " G.add_node(viewpoint_id)\n", + " positions[viewpoint_id] = (x, y)\n", + " for neighbor in viewpoint['navigable_viewpoints']:\n", + " G.add_edge(viewpoint_id, neighbor)\n", + " if plot:\n", + " nx.draw(G, positions, node_size=50)\n", + " plt.show() \n", + " else:\n", + " return G, positions\n", + "\n", + " \n", + " \n", + " \n", + "# overview\n", + "def plot_trajs_overview(viewpoints: list, gt_anno: dict, predict: dict):\n", + " G, positions = plot_all_scan_viewpoints(viewpoints, plot=False)\n", + " \n", + " edge_colors = [ 'black' for _ in G.edges ]\n", + "\n", + " max_x, max_y = -100, -100\n", + " min_x, min_y = 100, 100\n", + "\n", + " \n", + " # GT path\n", + " previous_node_name = None\n", + " for node_name in gt_anno['path']:\n", + " pos = positions[node_name]\n", + " max_x = max(max_x, pos[0])\n", + " max_y = max(max_y, pos[1])\n", + " min_x = min(min_x, pos[0])\n", + " min_y = min(min_y, pos[1])\n", + "\n", + " if previous_node_name != None:\n", + " for index, edge in enumerate(G.edges):\n", + " if edge == (previous_node_name, node_name) or edge == (node_name, previous_node_name):\n", + " edge_colors[index] = 'red'\n", + " previous_node_name = node_name\n", + " \n", + " # Predicted Path\n", + " previous_node_name = None\n", + " for node_name in predict['trajectory']:\n", + " node_name = node_name[0]\n", + " pos = positions[node_name]\n", + " max_x = max(max_x, pos[0])\n", + " max_y = max(max_y, pos[1])\n", + " min_x = min(min_x, pos[0])\n", + " min_y = min(min_y, pos[1])\n", + "\n", + " if previous_node_name != None:\n", + " for index, edge in enumerate(G.edges):\n", + " if edge == (previous_node_name, node_name) or edge == (node_name, previous_node_name):\n", + " if edge_colors[index] == 'red':\n", + " edge_colors[index] = 'orange'\n", + " else:\n", + " edge_colors[index] = 'blue'\n", + " previous_node_name = node_name\n", + "\n", + " current_axis = plt.gca()\n", + " current_axis.add_patch(Rectangle((min_x, min_y), max_x-min_x, max_y-min_y, fill=None, edgecolor='red', linewidth=0.3))\n", + "\n", + " \n", + " nx.draw(G, positions, node_size=20, edge_color=edge_colors)\n", + " plt.show()\n", + " \n", + " \n", + " \n", + " \n", + "# local traj\n", + "def plot_trajs_local(viewpoints: list, gt_anno: dict, predict: dict):\n", + " def get_viewpoints_dict(viewpoints) -> dict:\n", + " ans = {}\n", + " for viewpoint in viewpoints:\n", + " viewpoint_id = viewpoint['viewpoint']\n", + " ans[viewpoint_id] = viewpoint\n", + " return ans\n", + "\n", + " # list to dict, to find position\n", + " viewpoints = get_viewpoints_dict(viewpoints)\n", + "\n", + " # GT traj path\n", + " G = nx.Graph()\n", + " positions = {}\n", + " edge_colors = {}\n", + " previous_node_name = None\n", + " \n", + " # GT path\n", + " for index, node_name in enumerate(gt_anno['path']):\n", + " # new node\n", + " G.add_node(node_name)\n", + "\n", + " # find node's position\n", + " viewpoint = viewpoints[node_name]\n", + " x, y = viewpoint['location']['x'], viewpoint['location']['y']\n", + " positions[node_name] = (x, y)\n", + " \n", + " # path edge\n", + " if previous_node_name != None: # not start\n", + " G.add_edge(previous_node_name, node_name)\n", + " edge_colors[f'{previous_node_name}_{node_name}'] = 'red'\n", + " previous_node_name = node_name\n", + " \n", + " plt.text(x-0.2, y, f'{index}', fontsize=15, color='red', ha='center', va='center')\n", + "\n", + " if node_name == gt_anno['path'][0]:\n", + " plt.text(x, y, 'START', fontsize=12, color='darkred', ha='center', va='center')\n", + " if node_name == gt_anno['path'][-1]:\n", + " plt.text(x, y, 'STOP', fontsize=12, color='darkred', ha='center', va='center')\n", + " \n", + " \n", + " # Predicted path\n", + " previous_node_name = None\n", + " for node_index, node_name in enumerate(predict['trajectory']):\n", + " node_name = node_name[0]\n", + " # new node\n", + " \n", + " G.add_node(node_name)\n", + "\n", + " # find node's position\n", + " viewpoint = viewpoints[node_name]\n", + " x, y = viewpoint['location']['x'], viewpoint['location']['y']\n", + " positions[node_name] = (x, y)\n", + " \n", + " plt.text(x+0.2, y, f'{node_index}', fontsize=15, color='blue', ha='center', va='center')\n", + " \n", + " # path edge\n", + " if previous_node_name != None: # not start\n", + " find = False\n", + " for index, edge in enumerate(G.edges):\n", + " if edge == (previous_node_name, node_name) or edge == (node_name, previous_node_name):\n", + " if f'{previous_node_name}_{node_name}' in edge_colors:\n", + " if edge_colors[f'{previous_node_name}_{node_name}'] == 'red':\n", + " edge_colors[f'{previous_node_name}_{node_name}'] = 'orange'\n", + " else:\n", + " edge_colors[f'{previous_node_name}_{node_name}'] = 'darkblue'\n", + " else:\n", + " if edge_colors[f'{node_name}_{previous_node_name}'] == 'red':\n", + " edge_colors[f'{node_name}_{previous_node_name}'] = 'orange'\n", + " else:\n", + " edge_colors[f'{node_name}_{previous_node_name}'] = 'darkblue'\n", + " find = True\n", + " break\n", + " if not find:\n", + " G.add_edge(previous_node_name, node_name)\n", + " edge_colors[f'{previous_node_name}_{node_name}'] = 'blue'\n", + " previous_node_name = node_name\n", + " \n", + " # edge colors\n", + " colors = []\n", + " for edge in G.edges:\n", + " if f'{edge[0]}_{edge[1]}' in edge_colors:\n", + " colors.append(edge_colors[f'{edge[0]}_{edge[1]}'])\n", + " else:\n", + " colors.append(edge_colors[f'{edge[1]}_{edge[0]}'])\n", + " # STOP & START label\n", + " start, stop = predict['trajectory'][0][0], predict['trajectory'][-1][0]\n", + " x, y = viewpoints[start]['location']['x'], viewpoints[start]['location']['y']\n", + " plt.text(x, y, 'START', fontsize=12, color='darkblue', ha='center', va='center')\n", + " x, y = viewpoints[stop]['location']['x'], viewpoints[stop]['location']['y']\n", + " plt.text(x, y, 'STOP', fontsize=12, color='darkblue', ha='center', va='center')\n", + "\n", + " nx.draw(G, positions, node_size=20, edge_color=colors, with_labels=False)\n", + " plt.show()\n", + "\n", + "\n", + "\n", + "def plot_trajs(viewpoints: list, gt_anno: dict, predict: dict, overview: bool):\n", + " if overview:\n", + " plot_trajs_overview(viewpoints, gt_anno, predict)\n", + " else:\n", + " plot_trajs_local(viewpoints, gt_anno, predict)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# skybox part\n", + "def concat_images(images):\n", + " #images = [Image.open(x) for x in ['Test1.jpg', 'Test2.jpg', 'Test3.jpg']]\n", + " widths, heights = zip(*(i.size for i in images))\n", + "\n", + " total_width = sum(widths)+(len(images)-1)*10\n", + " max_height = max(heights)\n", + "\n", + " new_im = Image.new('RGB', (total_width, max_height))\n", + "\n", + " x_offset = 0\n", + " for im in images:\n", + " new_im.paste(im, (x_offset,0))\n", + " x_offset += im.size[0]+10\n", + " return new_im\n", + "\n", + "def get_skyboxes(scan: str, viewpoint: list) -> list:\n", + " ans = []\n", + " path = f\"{IMG_DIR_PATH}/{scan}/matterport_skybox_images\"\n", + " for i in range(6):\n", + " ans.append(Image.open(f'{path}/{viewpoint}_skybox{i}_sami.jpg', 'r'))\n", + " return ans\n", + "\n", + "def get_pano(scan: str, viewpoint: dict) -> Image:\n", + " pano = concat_images(get_skyboxes(scan, viewpoint))\n", + " pano = pano.resize([pano.size[0]//4, pano.size[1]//4])\n", + " return pano\n", + "\n", + "def display_img(im):\n", + " bio = BytesIO()\n", + " im.save(bio, format='png')\n", + " display(IPython.display.Image(bio.getvalue(), format='png'))\n", + "\n", + "def get_obj_name(bboxes_data: dict, scan: str, node: str, obj_id: str):\n", + " return bboxes_data[f'{scan}_{node}'][str(obj_id)]['name']\n", + " \n", + "def plot_gt_panos(scan: str, traj: list, bboxes_data: dict, obj_id: str):\n", + " for index, node in enumerate(traj):\n", + " print(f\"GT path ({index}): {node}\")\n", + " pano = get_pano(scan, node)\n", + " display_img(pano)\n", + " if index == len(traj)-1:\n", + " try:\n", + " obj_name = get_obj_name(bboxes_data, scan, node, obj_id)\n", + " print(f\"GT object: {obj_name}({obj_id})\")\n", + " except:\n", + " print(\"KEY ERROR\", scan, node, obj_id)\n", + " \n", + "def plot_predict_panos(scan: str, traj: list, bboxes_data: dict, obj_id: str):\n", + " for index, node in enumerate(traj):\n", + " node = node[0]\n", + " print(f\"Predict path ({index}): {node}\")\n", + " pano = get_pano(scan, node)\n", + " display_img(pano)\n", + " if index == len(traj)-1:\n", + " if obj_id != -1:\n", + " try:\n", + " obj_name = get_obj_name(bboxes_data, scan, node, obj_id)\n", + " print(f\"Predict object: {obj_name}({obj_id})\")\n", + " except:\n", + " print(\"KEY ERROR\", scan, node, obj_id)\n", + "\n", + " else:\n", + " print(f\"Predict object: NOT_FOUND\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# main \n", + "\n", + "annotations = load_annotations()\n", + "predicts = load_predicts()\n", + "scans_viewpoints = load_scan_viewpoints()\n", + "bboxes_data = load_bboxes_data()\n", + "\n", + "NUMBER = 3\n", + "success = 0\n", + "for i in range(NUMBER):\n", + " predict = predicts[i]\n", + " annotation = get_annotation(annotations, predict['instr_id'])\n", + " viewpoints = get_viewpoints(scans_viewpoints, annotation['scan'])\n", + " if annotation['found'] == predict['found']:\n", + " success += 1\n", + " print(i)\n", + " print(predict['instr_id'])\n", + " print(annotation['scan'])\n", + " print()\n", + " print(annotation['instructions'])\n", + " if annotation['found'] == False:\n", + " print(\"Original: \", annotation['original_instruction'])\n", + " print()\n", + " print(\"GT found: \", predict['gt_found'], annotation['found'])\n", + " print(\"predict found: \", predict['found'])\n", + " print(\"===\")\n", + " print(\"target object ID: \", annotation['objId'])\n", + " print(\"predicted object ID: \", predict['pred_objid'])\n", + "# print(annotation.keys())\n", + "# print(predict.keys())\n", + "# plot_all_scan_viewpoints(viewpoints)\n", + " plot_trajs(viewpoints, annotation, predict, overview=True)\n", + " plot_trajs(viewpoints, annotation, predict, overview=False)\n", + " \n", + " plot_gt_panos(annotation['scan'], annotation['path'], bboxes_data, obj_id=annotation['objId'])\n", + " print(\"* \" * 30)\n", + " plot_predict_panos(annotation['scan'], predict['trajectory'], bboxes_data, obj_id=predict['pred_objid'])\n", + " print(\"=\"*100)\n", + " # print(ans)\n", + "# print(viewpoints)\n", + "print(success/NUMBER)" + ] } ], "metadata": { "kernelspec": { - "display_name": "minigpt4", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "minigpt4" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -142,7 +563,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.8.10" } }, "nbformat": 4,