{ "cells": [ { "cell_type": "code", "execution_count": 5, "id": "3619bc9d", "metadata": {}, "outputs": [], "source": [ "import json, os\n", "import openai\n", "from openai import OpenAI\n", "from datetime import datetime\n", "from typing import Tuple, List\n", "import spacy\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "# ==============================================\n", "# !!!!! OPEN AI API KEY HERE !!!!!\n", "# Please delete your API key before the commit.\n", "# ==============================================\n", "OPENAI_API_KEY = ''\n", "\n", "DATASET = 'val_seen'\n", "#!this version is proccess on the adversarial instruction file generated by H.T., not original REVERIE data\n", "REVERIE_TRAIN_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_train.json'\n", "REVERIE_VAL_UNSEEN_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_val_unseen.json'\n", "REVERIE_VAL_SEEN_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_val_seen.json'\n", "BBOXES_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/BBoxes.json'\n", "NAVIGABLE_PATH = '/data/NavGPT_data/navigable'\n", "SKYBOX_PATH = '/home/snsd0805/code/research/VLN/base_dir/v1/scans'" ] }, { "cell_type": "code", "execution_count": 2, "id": "f33686ff", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_train.json'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "REVERIE_TRAIN_JSON_FILE" ] }, { "cell_type": "code", "execution_count": 3, "id": "ce31f8d8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting en-core-web-sm==3.7.1\n", " Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)\n", "\u001b[K |████████████████████████████████| 12.8 MB 1.6 MB/s eta 0:00:01\n", "\u001b[?25hRequirement already satisfied: spacy<3.8.0,>=3.7.2 in /home/snsd0805/.local/lib/python3.8/site-packages (from en-core-web-sm==3.7.1) (3.7.2)\n", "Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.1.2)\n", "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.5.2)\n", "Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.4.8)\n", "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.9)\n", "Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.3.4)\n", "Requirement already satisfied: jinja2 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.1.2)\n", "Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (6.4.0)\n", "Requirement already satisfied: numpy>=1.15.0; python_version < \"3.9\" in 
/home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.24.4)\n", "Requirement already satisfied: packaging>=20.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (23.2)\n", "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.66.1)\n", "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.12)\n", "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.31.0)\n", "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.10)\n", "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.0)\n", "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.10)\n", "Requirement already satisfied: thinc<8.3.0,>=8.1.8 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.2.2)\n", "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.5)\n", "Requirement already satisfied: typer<0.10.0,>=0.3.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.9.0)\n", "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.8)\n", "Requirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (45.2.0)\n", "Requirement already satisfied: typing-extensions>=4.6.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.8.0)\n", "Requirement already satisfied: pydantic-core==2.14.5 in /home/snsd0805/.local/lib/python3.8/site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.14.5)\n", "Requirement already satisfied: annotated-types>=0.4.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.6.0)\n", "Requirement already satisfied: confection<0.2.0,>=0.0.4 in /home/snsd0805/.local/lib/python3.8/site-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.1.4)\n", "Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.16.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from jinja2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /home/snsd0805/.local/lib/python3.8/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.0)\n", "Requirement already satisfied: 
urllib3<3,>=1.21.1 in /usr/lib/python3/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.25.8)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2019.11.28)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.8)\n", "Requirement already satisfied: blis<0.8.0,>=0.7.8 in /home/snsd0805/.local/lib/python3.8/site-packages (from thinc<8.3.0,>=8.1.8->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.7.11)\n", "Requirement already satisfied: click<9.0.0,>=7.1.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from typer<0.10.0,>=0.3.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.1.7)\n", "Installing collected packages: en-core-web-sm\n", "Successfully installed en-core-web-sm-3.7.1\n", "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n", "You can now load the package via spacy.load('en_core_web_sm')\n" ] } ], "source": [ "! python3 -m spacy download en_core_web_sm" ] }, { "cell_type": "code", "execution_count": 6, "id": "781790cd", "metadata": {}, "outputs": [], "source": [ "def load_json(fn):\n", " with open(fn) as f:\n", " ret = json.load(f)\n", " return ret\n", "\n", "def dump_json(data, fn, force=False):\n", " if not force:\n", " assert not os.path.exists(fn)\n", " with open(fn, 'w') as f:\n", " json.dump(data, f)" ] }, { "cell_type": "code", "execution_count": null, "id": "b2159db6", "metadata": {}, "outputs": [], "source": [ "# ============================================================\n", "# ChatGPT to extract the room & the target object\n", "# ============================================================" ] }, { "cell_type": "code", "execution_count": 7, "id": "29fe59c8", "metadata": {}, "outputs": [], "source": [ "TEMPLATE = '''\n", "Please extract the target room, the goal object, and the relations between the goal object and other reference objects.\n", "Example:\n", "inputs:\n", "{}\n", "outputs:\n", "{}\n", "Now it is your turn:\n", "inputs: \n", "___inputs___\n", "outputs:\n", "'''\n", "\n", "def get_template() -> str:\n", " inputs = 'In the kitchen above the brown shelves and above the basket on the top shelf there is a ceiling beam Please firm this beam'\n", " outputs = {\n", " 'room': 'kitchen',\n", " 'goal': 'ceiling beam',\n", " 'goal_relations':[\n", " {'relation': 'above', 'reference': 'brown shelves'},\n", " {'relation': 'above', 'reference': 'basket'},\n", " {'relation': 'on the top', 'reference': 'shelf'},\n", " ]\n", " }\n", " template = TEMPLATE.format(inputs, json.dumps(outputs, indent=4))\n", " return template" ] }, { "cell_type": "code", "execution_count": 8, "id": "3e146e99", "metadata": {}, "outputs": [], "source": [ "def query(openai: OpenAI, prompt: str) -> Tuple[str, int]:\n", " response = client.chat.completions.create(\n", " model=\"gpt-3.5-turbo-1106\",\n", " response_format={ \"type\": \"json_object\" },\n", " messages=[\n", " {\"role\": \"system\", \"content\": \"Please output JSON.\"},\n", " {\"role\": \"user\", \"content\": prompt}\n", " ]\n", " )\n", " \n", " return (\n", " json.loads(response.choices[0].message.content),\n", " response.usage.total_tokens\n", " )" ] }, { "cell_type": "code", "execution_count": 9, "id": "b334e925", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2996\n", "1096\n", "750\n" ] } ], "source": [ "reverie_train 
= load_json(REVERIE_TRAIN_JSON_FILE)\n", "reverie_val_unseen = load_json(REVERIE_VAL_UNSEEN_JSON_FILE)\n", "reverie_val_seen = load_json(REVERIE_VAL_SEEN_JSON_FILE)\n", "print(len(reverie_train))\n", "print(len(reverie_val_unseen))\n", "print(len(reverie_val_seen))" ] }, { "cell_type": "code", "execution_count": 10, "id": "0917c229", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "val_seen\n", "\n", "2024-01-28 16:45:49.797527 0 / 750\n", " use 0 tokens\n", "..........\n", "2024-01-28 16:46:04.932742 10 / 750\n", " use 2480 tokens\n", "..........\n", "2024-01-28 16:46:23.465098 20 / 750\n", " use 5206 tokens\n", "..........\n", "2024-01-28 16:46:38.780172 30 / 750\n", " use 7705 tokens\n", "..........\n", "2024-01-28 16:46:52.773380 40 / 750\n", " use 10161 tokens\n", "..........\n", "2024-01-28 16:47:05.971142 50 / 750\n", " use 12556 tokens\n", "..........\n", "2024-01-28 16:47:19.135185 60 / 750\n", " use 14983 tokens\n", "..........\n", "2024-01-28 16:47:31.982835 70 / 750\n", " use 17418 tokens\n", "..........\n", "2024-01-28 16:47:44.867015 80 / 750\n", " use 19848 tokens\n", "..........\n", "2024-01-28 16:47:58.097011 90 / 750\n", " use 22322 tokens\n", "..........\n", "2024-01-28 16:48:09.810830 100 / 750\n", " use 24670 tokens\n", "..........\n", "2024-01-28 16:48:25.051226 110 / 750\n", " use 27178 tokens\n", "..........\n", "2024-01-28 16:48:39.153634 120 / 750\n", " use 29675 tokens\n", "..........\n", "2024-01-28 16:48:52.691056 130 / 750\n", " use 32131 tokens\n", "..........\n", "2024-01-28 16:49:08.061178 140 / 750\n", " use 34721 tokens\n", "..........\n", "2024-01-28 16:49:20.385052 150 / 750\n", " use 37036 tokens\n", "..........\n", "2024-01-28 16:49:35.084404 160 / 750\n", " use 39550 tokens\n", "..........\n", "2024-01-28 16:49:48.060591 170 / 750\n", " use 41995 tokens\n", "..........\n", "2024-01-28 16:50:00.565658 180 / 750\n", " use 44445 tokens\n", "..........\n", "2024-01-28 16:50:13.490748 190 / 750\n", " use 46862 tokens\n", "..........\n", "2024-01-28 16:50:26.500824 200 / 750\n", " use 49205 tokens\n", "..........\n", "2024-01-28 16:50:38.342902 210 / 750\n", " use 51623 tokens\n", "..........\n", "2024-01-28 16:50:50.995957 220 / 750\n", " use 54082 tokens\n", "..........\n", "2024-01-28 16:51:02.467171 230 / 750\n", " use 56477 tokens\n", "..........\n", "2024-01-28 16:51:15.902792 240 / 750\n", " use 58961 tokens\n", "..........\n", "2024-01-28 16:51:28.502589 250 / 750\n", " use 61385 tokens\n", "..........\n", "2024-01-28 16:51:41.633834 260 / 750\n", " use 63868 tokens\n", "..........\n", "2024-01-28 16:51:54.141903 270 / 750\n", " use 66339 tokens\n", "..........\n", "2024-01-28 16:52:08.683939 280 / 750\n", " use 68949 tokens\n", "..........\n", "2024-01-28 16:52:22.242175 290 / 750\n", " use 71424 tokens\n", "..........\n", "2024-01-28 16:52:35.615204 300 / 750\n", " use 73965 tokens\n", "..........\n", "2024-01-28 16:52:47.855405 310 / 750\n", " use 76411 tokens\n", "..........\n", "2024-01-28 16:53:01.057109 320 / 750\n", " use 78988 tokens\n", "..........\n", "2024-01-28 16:53:14.308833 330 / 750\n", " use 81496 tokens\n", "..........\n", "2024-01-28 16:53:27.172938 340 / 750\n", " use 83995 tokens\n", "..........\n", "2024-01-28 16:53:41.768159 350 / 750\n", " use 86587 tokens\n", "..........\n", "2024-01-28 16:53:54.263125 360 / 750\n", " use 89018 tokens\n", "..........\n", "2024-01-28 16:54:05.971150 370 / 750\n", " use 91426 tokens\n", "..........\n", "2024-01-28 16:54:18.277461 380 / 750\n", " use 93908 
tokens\n", "..........\n", "2024-01-28 16:54:30.239469 390 / 750\n", " use 96333 tokens\n", "..........\n", "2024-01-28 16:54:43.432008 400 / 750\n", " use 98827 tokens\n", "..........\n", "2024-01-28 16:54:54.753545 410 / 750\n", " use 101234 tokens\n", "..........\n", "2024-01-28 16:55:08.248909 420 / 750\n", " use 103824 tokens\n", "..........\n", "2024-01-28 16:55:20.266048 430 / 750\n", " use 106241 tokens\n", "..........\n", "2024-01-28 16:55:31.412194 440 / 750\n", " use 108579 tokens\n", "..........\n", "2024-01-28 16:55:44.964667 450 / 750\n", " use 111007 tokens\n", "..........\n", "2024-01-28 16:55:56.905568 460 / 750\n", " use 113448 tokens\n", "..........\n", "2024-01-28 16:56:09.677100 470 / 750\n", " use 115910 tokens\n", "..........\n", "2024-01-28 16:56:23.082325 480 / 750\n", " use 118380 tokens\n", "..........\n", "2024-01-28 16:56:35.981859 490 / 750\n", " use 120884 tokens\n", "..........\n", "2024-01-28 16:56:50.095850 500 / 750\n", " use 123367 tokens\n", "..........\n", "2024-01-28 16:57:01.456319 510 / 750\n", " use 125768 tokens\n", "..........\n", "2024-01-28 16:57:13.588327 520 / 750\n", " use 128230 tokens\n", "..........\n", "2024-01-28 16:57:27.429472 530 / 750\n", " use 130816 tokens\n", "..........\n", "2024-01-28 16:57:39.262664 540 / 750\n", " use 133245 tokens\n", "..........\n", "2024-01-28 16:57:52.110883 550 / 750\n", " use 135746 tokens\n", "..........\n", "2024-01-28 16:58:05.209857 560 / 750\n", " use 138337 tokens\n", "..........\n", "2024-01-28 16:58:17.967249 570 / 750\n", " use 140782 tokens\n", "..........\n", "2024-01-28 16:58:29.888118 580 / 750\n", " use 143165 tokens\n", "..........\n", "2024-01-28 16:58:42.603735 590 / 750\n", " use 145588 tokens\n", "..........\n", "2024-01-28 16:58:55.585690 600 / 750\n", " use 148139 tokens\n", "..........\n", "2024-01-28 16:59:07.646169 610 / 750\n", " use 150567 tokens\n", "..........\n", "2024-01-28 16:59:21.987738 620 / 750\n", " use 153205 tokens\n", "..........\n", "2024-01-28 16:59:34.235770 630 / 750\n", " use 155655 tokens\n", "..........\n", "2024-01-28 16:59:47.266655 640 / 750\n", " use 158165 tokens\n", "..........\n", "2024-01-28 16:59:58.403109 650 / 750\n", " use 160519 tokens\n", "..........\n", "2024-01-28 17:00:10.445611 660 / 750\n", " use 162942 tokens\n", "..........\n", "2024-01-28 17:00:24.864396 670 / 750\n", " use 165361 tokens\n", "..........\n", "2024-01-28 17:00:37.851239 680 / 750\n", " use 167854 tokens\n", "..........\n", "2024-01-28 17:00:50.721513 690 / 750\n", " use 170355 tokens\n", "..........\n", "2024-01-28 17:01:04.352591 700 / 750\n", " use 172922 tokens\n", "..........\n", "2024-01-28 17:01:15.395867 710 / 750\n", " use 175265 tokens\n", "..........\n", "2024-01-28 17:01:29.026544 720 / 750\n", " use 177693 tokens\n", "..........\n", "2024-01-28 17:01:42.652462 730 / 750\n", " use 180203 tokens\n", "..........\n", "2024-01-28 17:01:57.032790 740 / 750\n", " use 182718 tokens\n", ".........." 
] } ], "source": [ "# OpenAI GPT to extract the relationship between objects\n", "client = OpenAI(api_key=OPENAI_API_KEY)\n", "template = get_template()\n", "\n", "logs = {}\n", "tokens = 0\n", "\n", "if DATASET == 'train':\n", " print('train')\n", " dataset = reverie_train\n", "elif DATASET == 'val_unseen':\n", " print('val_unseen')\n", " dataset = reverie_val_unseen\n", "else:\n", " print('val_seen')\n", " dataset = reverie_val_seen\n", "\n", "for idx, r in enumerate(dataset):\n", " if idx%10==0:\n", " dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n", " print('\\n', end='')\n", " print(datetime.now(), idx, '/', len(dataset))\n", " print(f\" use {tokens} tokens\")\n", " \n", " instruction = r['instructions'][0]\n", " prompt = template.replace(\"___inputs___\", instruction)\n", " response, total_tokens = query(client, prompt)\n", " tokens += total_tokens\n", " \n", " query_name = f'reverie__{DATASET}__{idx}__0'\n", " logs[query_name] = response\n", " print('.', end='')\n", " \n", "dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "175cc6fb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Used 185368 tokens\n" ] } ], "source": [ "print(f'Used {tokens} tokens')" ] }, { "cell_type": "code", "execution_count": null, "id": "d6596246", "metadata": {}, "outputs": [], "source": [ "# ============================================================\n", "# Method #1 : ChatGPT api to replace the target object\n", "# ============================================================" ] }, { "cell_type": "code", "execution_count": 17, "id": "0c83e389", "metadata": {}, "outputs": [], "source": [ "def show_skybox(scan: str, viewpoint:str) -> None:\n", " img_path = f'{SKYBOX_PATH}/{scan}/matterport_skybox_images/{viewpoint}_skybox_small.jpg'\n", " im = Image.open(img_path)\n", " from matplotlib import rcParams\n", "\n", " plt.imshow(im)\n", " plt.show()\n", " print(img_path)\n", "\n", "\n", "def get_navigable_viewpoints(scan: str, viewpoint: str) -> list:\n", " '''\n", " Get all neighbor vps around it.\n", " '''\n", " data = load_json(f'{NAVIGABLE_PATH}/{scan}_navigable.json') \n", " navigable_viewpoints = []\n", " for k, v in data[viewpoint].items():\n", " navigable_viewpoints.append(k)\n", " \n", " return navigable_viewpoints\n", "\n", "def get_objects_in_the_rooms(bboxes: dict, scan: str, viewpoint: str) -> list:\n", " '''\n", " Get all touchable objects around this viewpoint.\n", " \n", " Touchable: define by REVERIE datasets, means the objects is close to this point (maybe 1m).\n", " '''\n", " objs = set()\n", " for k, v in bboxes[f'{scan}_{viewpoint}'].items():\n", " objs.add(v['name'].replace('#', ' '))\n", " return list(objs)\n", "\n", "def get_avoid_objs(bboxes: dict, scan: str, viewpoint: str) -> list:\n", " '''\n", " Get objects around this viewpoint\n", " \n", " First, it call get_navigable_viewpoints() to get the neighbor viewpoints.\n", " Then, it get all the objects around its neighbor, we assume these objects is all visible bbox in this room\n", " We need this list to avoid generating the objects that exist in this room\n", "\n", " '''\n", " vps = get_navigable_viewpoints(scan, viewpoint)\n", " objs = get_objects_in_the_rooms(bboxes, scan, viewpoint)\n", " for i in vps:\n", " tmp_objs = get_objects_in_the_rooms(bboxes, scan, i)\n", " objs += tmp_objs\n", " \n", " return list(set(objs))\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "ae77b862", "metadata": {}, "outputs": 
[ { "name": "stdout", "output_type": "stream", "text": [ "\n", "You should replace the target object and return me a new instruction.\n", "Notice: the new target object must be suitable for this room (room name), and it must doesn't look like any objects(different type) in avoid_objects list.\n", "Some times, you can change the verb which suitable for the new target objects.\n", "\n", "Example:\n", "inputs:\n", "{\n", " \"instruction\": \"Go to bedroom at the back left side of the house and turn on the lamp nearest the bedroom door\",\n", " \"room_name\": \"bedroom\",\n", " \"avoid_objects\": [\n", " \"window\",\n", " \"lamp\",\n", " \"picture\",\n", " \"bed\"\n", " ],\n", " \"target_object\": \"lamp\"\n", "}\n", "outputs:\n", "{\n", " \"new_instruction\": \"Go to bedroom at the back left side of the house and take the mirror nearest the bedroom door\"\n", "}\n", "Now it is your turn:\n", "inputs: \n", "___inputs___\n", "outputs:\n", "\n" ] } ], "source": [ "REPLACE_OBJECT_TEMPLATE = '''\n", "You should replace the target object and return me a new instruction.\n", "Notice: the new target object must be suitable for this room (room name), and it must doesn't look like any objects(different type) in avoid_objects list.\n", "Some times, you can change the verb which suitable for the new target objects.\n", "\n", "Example:\n", "inputs:\n", "{}\n", "outputs:\n", "{}\n", "Now it is your turn:\n", "inputs: \n", "___inputs___\n", "outputs:\n", "'''\n", "\n", "def get_replace_object_template() -> str:\n", " inputs = {\n", " 'instruction': 'Go to bedroom at the back left side of the house and turn on the lamp nearest the bedroom door',\n", " 'room_name': 'bedroom',\n", " 'avoid_objects': ['window', 'lamp', 'picture', 'bed'],\n", " 'target_object': 'lamp'\n", " }\n", " outputs = {\n", " 'new_instruction': 'Go to bedroom at the back left side of the house and take the mirror nearest the bedroom door',\n", " }\n", " template = REPLACE_OBJECT_TEMPLATE.format(json.dumps(inputs, indent=4), json.dumps(outputs, indent=4))\n", " return template\n", "print(get_replace_object_template())" ] }, { "cell_type": "code", "execution_count": 14, "id": "70c9888b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "2024-01-28 17:02:12.982758 0 / 750\n", " use 0 tokens\n", "..........\n", "2024-01-28 17:02:21.516631 10 / 750\n", " use 3207 tokens\n", "..........\n", "2024-01-28 17:02:30.687271 20 / 750\n", " use 6593 tokens\n", "..........\n", "2024-01-28 17:02:38.764408 30 / 750\n", " use 9909 tokens\n", "..........\n", "2024-01-28 17:02:46.047457 40 / 750\n", " use 13275 tokens\n", "..........\n", "2024-01-28 17:02:53.616187 50 / 750\n", " use 16683 tokens\n", "..........\n", "2024-01-28 17:03:00.754883 60 / 750\n", " use 19998 tokens\n", "..........\n", "2024-01-28 17:03:10.039162 70 / 750\n", " use 23367 tokens\n", "..........\n", "2024-01-28 17:03:18.122595 80 / 750\n", " use 26633 tokens\n", "..........\n", "2024-01-28 17:03:26.190372 90 / 750\n", " use 29905 tokens\n", "..........\n", "2024-01-28 17:03:33.605275 100 / 750\n", " use 33258 tokens\n", "..........\n", "2024-01-28 17:03:41.508207 110 / 750\n", " use 36639 tokens\n", "..........\n", "2024-01-28 17:03:49.187187 120 / 750\n", " use 39933 tokens\n", "..........\n", "2024-01-28 17:03:57.650307 130 / 750\n", " use 43265 tokens\n", "..........\n", "2024-01-28 17:04:05.802962 140 / 750\n", " use 46670 tokens\n", "..........\n", "2024-01-28 17:04:14.862087 150 / 750\n", " use 49887 tokens\n", "..........\n", "2024-01-28 
17:04:24.654696 160 / 750\n", " use 53189 tokens\n", "..........\n", "2024-01-28 17:04:32.348524 170 / 750\n", " use 56649 tokens\n", "..........\n", "2024-01-28 17:04:39.992408 180 / 750\n", " use 60088 tokens\n", "..........\n", "2024-01-28 17:04:47.581791 190 / 750\n", " use 63313 tokens\n", "..........\n", "2024-01-28 17:04:54.968194 200 / 750\n", " use 66443 tokens\n", "..........\n", "2024-01-28 17:05:03.081295 210 / 750\n", " use 69914 tokens\n", "..........\n", "2024-01-28 17:05:11.042498 220 / 750\n", " use 73115 tokens\n", "..........\n", "2024-01-28 17:05:18.786897 230 / 750\n", " use 76566 tokens\n", "..........\n", "2024-01-28 17:05:27.032178 240 / 750\n", " use 79859 tokens\n", "..........\n", "2024-01-28 17:05:34.505171 250 / 750\n", " use 83234 tokens\n", "..........\n", "2024-01-28 17:05:42.915753 260 / 750\n", " use 86662 tokens\n", "..........\n", "2024-01-28 17:05:51.055077 270 / 750\n", " use 90100 tokens\n", "..........\n", "2024-01-28 17:05:59.541581 280 / 750\n", " use 93532 tokens\n", "..........\n", "2024-01-28 17:06:07.890632 290 / 750\n", " use 96856 tokens\n", "..........\n", "2024-01-28 17:06:16.479621 300 / 750\n", " use 100353 tokens\n", "..........\n", "2024-01-28 17:06:25.275359 310 / 750\n", " use 103854 tokens\n", "..........\n", "2024-01-28 17:06:33.737390 320 / 750\n", " use 107260 tokens\n", "..........\n", "2024-01-28 17:06:41.555370 330 / 750\n", " use 110745 tokens\n", "..........\n", "2024-01-28 17:06:49.479090 340 / 750\n", " use 114082 tokens\n", "..........\n", "2024-01-28 17:06:58.041339 350 / 750\n", " use 117357 tokens\n", "..........\n", "2024-01-28 17:07:05.942328 360 / 750\n", " use 120673 tokens\n", "..........\n", "2024-01-28 17:07:14.103199 370 / 750\n", " use 123899 tokens\n", "..........\n", "2024-01-28 17:07:22.556955 380 / 750\n", " use 127261 tokens\n", "..........\n", "2024-01-28 17:07:30.841584 390 / 750\n", " use 130624 tokens\n", "..........\n", "2024-01-28 17:07:40.961980 400 / 750\n", " use 134033 tokens\n", "..........\n", "2024-01-28 17:07:49.035868 410 / 750\n", " use 137208 tokens\n", "..........\n", "2024-01-28 17:07:58.045275 420 / 750\n", " use 140563 tokens\n", "..........\n", "2024-01-28 17:08:06.704814 430 / 750\n", " use 143926 tokens\n", "..........\n", "2024-01-28 17:08:14.719195 440 / 750\n", " use 147137 tokens\n", "..........\n", "2024-01-28 17:08:23.215680 450 / 750\n", " use 150332 tokens\n", "..........\n", "2024-01-28 17:08:31.857775 460 / 750\n", " use 153765 tokens\n", "..........\n", "2024-01-28 17:08:39.738408 470 / 750\n", " use 156996 tokens\n", "..........\n", "2024-01-28 17:08:47.790493 480 / 750\n", " use 160213 tokens\n", "..........\n", "2024-01-28 17:08:55.185835 490 / 750\n", " use 163563 tokens\n", "..........\n", "2024-01-28 17:09:03.000253 500 / 750\n", " use 166932 tokens\n", "..........\n", "2024-01-28 17:09:10.783133 510 / 750\n", " use 170240 tokens\n", "..........\n", "2024-01-28 17:09:18.652092 520 / 750\n", " use 173502 tokens\n", "..........\n", "2024-01-28 17:09:26.422400 530 / 750\n", " use 176843 tokens\n", "..........\n", "2024-01-28 17:09:34.636908 540 / 750\n", " use 180363 tokens\n", "..........\n", "2024-01-28 17:09:43.233021 550 / 750\n", " use 183593 tokens\n", "..........\n", "2024-01-28 17:09:51.823002 560 / 750\n", " use 187014 tokens\n", "..........\n", "2024-01-28 17:09:59.527840 570 / 750\n", " use 190399 tokens\n", "..........\n", "2024-01-28 17:10:08.005529 580 / 750\n", " use 193703 tokens\n", "..........\n", "2024-01-28 17:10:17.602830 590 / 750\n", " use 197048 
tokens\n", "..........\n", "2024-01-28 17:10:26.569711 600 / 750\n", " use 200428 tokens\n", "..........\n", "2024-01-28 17:10:34.537832 610 / 750\n", " use 203882 tokens\n", "..........\n", "2024-01-28 17:10:43.276852 620 / 750\n", " use 207415 tokens\n", "..........\n", "2024-01-28 17:10:51.461752 630 / 750\n", " use 210626 tokens\n", "..........\n", "2024-01-28 17:10:59.469175 640 / 750\n", " use 213857 tokens\n", "..........\n", "2024-01-28 17:11:07.635625 650 / 750\n", " use 217043 tokens\n", "..........\n", "2024-01-28 17:11:17.050346 660 / 750\n", " use 220181 tokens\n", "..........\n", "2024-01-28 17:11:24.829593 670 / 750\n", " use 223295 tokens\n", "..........\n", "2024-01-28 17:11:32.432651 680 / 750\n", " use 226488 tokens\n", "..........\n", "2024-01-28 17:11:40.744801 690 / 750\n", " use 229875 tokens\n", "..........\n", "2024-01-28 17:11:48.741426 700 / 750\n", " use 233400 tokens\n", "..........\n", "2024-01-28 17:11:56.509583 710 / 750\n", " use 236675 tokens\n", "..........\n", "2024-01-28 17:12:03.689453 720 / 750\n", " use 239984 tokens\n", "..........\n", "2024-01-28 17:12:12.271545 730 / 750\n", " use 243264 tokens\n", "..........\n", "2024-01-28 17:12:20.567358 740 / 750\n", " use 246521 tokens\n", ".........." ] } ], "source": [ "logs = load_json(f'gpt_outputs_{DATASET}.json')\n", "nlp = spacy.load(\"en_core_web_sm\")\n", "bboxes = load_json(BBOXES_JSON_FILE)\n", "\n", "client = OpenAI(api_key=OPENAI_API_KEY)\n", "template = get_replace_object_template()\n", "\n", "tokens = 0\n", "for idx, r in enumerate(dataset):\n", " if idx%10==0:\n", " dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n", " print('\\n', end='')\n", " print(datetime.now(), idx, '/', len(dataset))\n", " print(f\" use {tokens} tokens\")\n", "\n", " log = logs[f'reverie__{DATASET}__{idx}__0']\n", " scan = r['scan']\n", " target_vp = r['path'][-1]\n", " avoid_objs = get_avoid_objs(bboxes, scan, target_vp)\n", "# print(log['room'])\n", "# print(log['goal'])\n", "# print(r['instructions'][0])\n", "# print(f'avoid {avoid_objs}')\n", " try:\n", " inputs = {\n", " 'instruction': r['instructions'][0],\n", " 'room_name': log['room'],\n", " 'avoid_objects': avoid_objs,\n", " 'target_object': log['goal']\n", " } \n", " prompt = template.replace('___inputs___', json.dumps(inputs, indent=4))\n", " response, total_tokens = query(client, prompt)\n", " tokens += total_tokens\n", "\n", " log['instruction'] = r['instructions'][0]\n", " log['new_instruction'] = response['new_instruction']\n", " print('.', end='')\n", " except:\n", " print(log)\n", "\n", "dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4b713517", "metadata": {}, "outputs": [], "source": [ "# ============================================================\n", "# Transform the log into REVERIE data \n", "# ============================================================" ] }, { "cell_type": "code", "execution_count": null, "id": "f8f42d90", "metadata": {}, "outputs": [], "source": [ "for index in range(len(dataset)):\n", " original_data = dataset[index]\n", " log = logs[f'reverie__{DATASET}__{index}__0']\n", " print(original_data['scan'])\n", " print(original_data['path'][-1])\n", " print(original_data['instructions'][0])\n", " print(original_data['instructions'][1])\n", " print(log)\n", " if 'new_instruction' in log:\n", " original_data['instructions'][1] = log['new_instruction']\n", " print(original_data['instructions'][1])\n", " print(get_avoid_objs(bboxes, 
original_data['scan'], original_data['path'][-1]))\n", " else:\n", " # no 'new_instruction' was generated for this item; keep the original instruction\n", " pass\n", " print()\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "b169a40c", "metadata": {}, "outputs": [], "source": [ "dump_json(dataset, f'REVERIE_{DATASET}.json.adversarial', force=True)" ] },
{ "cell_type": "code", "execution_count": null, "id": "1b45d85c", "metadata": {}, "outputs": [], "source": [ "# ============================================================\n", "# Method #2 : NLP tool to build the swap pool\n", "# ( Not implemented yet )\n", "# ============================================================" ] },
{ "cell_type": "code", "execution_count": null, "id": "c5c751d7", "metadata": {}, "outputs": [], "source": [ "def get_subject(nlp, text) -> str:\n", " '''\n", " Heuristically extract the head room noun (plus a compound modifier such as 'living')\n", " from a room description, then strip the word 'area', newlines, and spaces.\n", " '''\n", " doc = nlp(text)\n", " \n", " # find subject\n", " subject = None\n", " for index, token in enumerate(doc):\n", "# print(\"--->\", token, token.dep_)\n", " if \"ROOT\" in token.dep_ or 'obj' in token.dep_:\n", " if doc[index-1].dep_ == 'compound' or \\\n", " (doc[index-1].dep_ == 'amod' and doc[index-1].text == 'living' ):\n", " \n", " if doc[index-1].text != 'level' and doc[index-1].text != 'floor':\n", " subject = doc[index-1].text + \" \" + token.text\n", " else:\n", " subject = token.text\n", " else:\n", " subject = token.text\n", " if subject:\n", " subject = subject.replace('area', '')\n", " subject = subject.replace('\\n', '')\n", " subject = subject.replace(' ', '')\n", " return subject" ] },
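{ "cell_type": "code", "execution_count": null, "id": "d8e9f0a1", "metadata": {}, "outputs": [], "source": [ "# --- Quick check of get_subject (illustrative sketch, not part of the original run) ---\n", "# The phrases below are made-up room descriptions. The exact result depends on the spaCy\n", "# parse, but descriptions such as 'living room' are expected to collapse to compact keys\n", "# like 'livingroom', which is the key format in the room histogram printed below.\n", "demo_nlp = spacy.load('en_core_web_sm')\n", "for phrase in ['living room', 'master bedroom']:\n", "    print(phrase, '->', get_subject(demo_nlp, phrase))" ] },
{ "cell_type": "code", "execution_count": 7, "id": "413d2456", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NO ROOM: {'room': None, 'goal': 'open double doors', 'goal_relations': [{'relation': 'to the left', 'reference': None}]}\n", " Go through the open double doors to the left\n", "NO ROOM: {'room': None, 'goal': 'cabinet handle', 'goal_relations': [{'relation': 'bottom left', 'reference': 'washer and dryer'}, {'relation': 'beside', 'reference': 'washer and dryer'}, {'relation': 'on the second level', 'reference': None}]}\n", " tighten the screws in the cabinet handle that is the bottom left beside the washer and dryer on the second level\n", "NO ROOM: {'error': 'Unable to extract the goal object and its relations. 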
Please provide a valid input.'}\n", " Move to the closet and take the dress of the rack\n", "NO ROOM: {'room': ['kitchen', 'family room'], 'goal': 'table', 'goal_relations': [{'relation': 'between', 'reference': 'white chairs'}]}\n", " Walk through the kitchen into the family room and touch the table between the white chairs\n", "stairs 22\n", "foyer 10\n", "laundryroom 182\n", "lounge 126\n", "floor 22\n", "porch 22\n", "masterbathroom 10\n", "dining 5\n", "utilityroom 29\n", "deck 6\n", "terrace 10\n", "livingroom 201\n", "office 169\n", "lobby 12\n", "masterbedroom 6\n", "spa 13\n", "staircase 9\n", "kitchen 172\n", "entryway 40\n", "gym 8\n", "bathroom 693\n", "room 7\n", "balcony 25\n", "level 92\n", "meetingroom 25\n", "toilet 5\n", "hall 8\n", "closet 98\n", "bar 7\n", "hallway 232\n", "familyroom 120\n", "bedroom 356\n", "sparoom 9\n", "diningroom 172\n" ] } ], "source": [ "logs = load_json('gpt_outputs.json')\n", "nlp = spacy.load(\"en_core_web_sm\")\n", "bboxes = load_json(BBOXES_JSON_FILE)\n", "\n", "rooms = {}\n", "\n", "for idx, r in enumerate(reverie_train):\n", " try:\n", " room_descr = logs[f'reverie__train__{idx}__0']['room']\n", " room = get_subject(nlp, room_descr)\n", " if room in rooms:\n", " rooms[room] += 1\n", " else:\n", " rooms[room] = 1\n", " except:\n", " print(\"NO ROOM:\", logs[f'reverie__train__{idx}__0'])\n", " print(\" \", r['instructions'][0])\n", " \n", "selected_rooms = set()\n", "for k, v in rooms.items():\n", " if v >= 5:\n", " selected_rooms.add(k)\n", "\n", "for i in selected_rooms:\n", " print(i, rooms[i])" ] }, { "cell_type": "code", "execution_count": null, "id": "08fc5b71", "metadata": {}, "outputs": [], "source": [ "# ============================================================\n", "# DEBUG \n", "# ============================================================" ] }, { "cell_type": "code", "execution_count": null, "id": "a1524ce5", "metadata": {}, "outputs": [], "source": [ "reverie_train = load_json(REVERIE_TRAIN_JSON_FILE)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }