From f2e343ff5064757925241e1be246f4d87c5dea69 Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Mon, 5 Feb 2024 14:51:23 +0800 Subject: [PATCH] feat: advanced adversarial instruction (chatgpt) --- .../instruction_editing_with_chatgpt.ipynb | 1164 +++++++++++++++++ 1 file changed, 1164 insertions(+) create mode 100644 instruction_generation/instruction_editing_with_chatgpt.ipynb diff --git a/instruction_generation/instruction_editing_with_chatgpt.ipynb b/instruction_generation/instruction_editing_with_chatgpt.ipynb new file mode 100644 index 0000000..27c814f --- /dev/null +++ b/instruction_generation/instruction_editing_with_chatgpt.ipynb @@ -0,0 +1,1164 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "id": "3619bc9d", + "metadata": {}, + "outputs": [], + "source": [ + "import json, os\n", + "import openai\n", + "from openai import OpenAI\n", + "from datetime import datetime\n", + "from typing import Tuple, List\n", + "import spacy\n", + "from PIL import Image\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "# ==============================================\n", + "# !!!!! 
OPEN AI API KEY HERE !!!!!\n", + "# Please delete your API key before the commit.\n", + "# ==============================================\n", + "OPENAI_API_KEY = ''\n", + "\n", + "DATASET = 'val_seen'\n", + "#!this version is proccess on the adversarial instruction file generated by H.T., not original REVERIE data\n", + "REVERIE_TRAIN_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_train.json'\n", + "REVERIE_VAL_UNSEEN_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_val_unseen.json'\n", + "REVERIE_VAL_SEEN_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_val_seen.json'\n", + "BBOXES_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/BBoxes.json'\n", + "NAVIGABLE_PATH = '/data/NavGPT_data/navigable'\n", + "SKYBOX_PATH = '/home/snsd0805/code/research/VLN/base_dir/v1/scans'" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f33686ff", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_train.json'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "REVERIE_TRAIN_JSON_FILE" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ce31f8d8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting en-core-web-sm==3.7.1\n", + " Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)\n", + "\u001b[K |████████████████████████████████| 12.8 MB 1.6 MB/s eta 0:00:01\n", + "\u001b[?25hRequirement already satisfied: spacy<3.8.0,>=3.7.2 in /home/snsd0805/.local/lib/python3.8/site-packages (from en-core-web-sm==3.7.1) (3.7.2)\n", + "Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in 
/home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.1.2)\n", + "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.5.2)\n", + "Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.4.8)\n", + "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.9)\n", + "Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.3.4)\n", + "Requirement already satisfied: jinja2 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.1.2)\n", + "Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (6.4.0)\n", + "Requirement already satisfied: numpy>=1.15.0; python_version < \"3.9\" in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.24.4)\n", + "Requirement already satisfied: packaging>=20.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (23.2)\n", + "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.66.1)\n", + "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.12)\n", + "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from 
spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.31.0)\n", + "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.10)\n", + "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.0)\n", + "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.10)\n", + "Requirement already satisfied: thinc<8.3.0,>=8.1.8 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.2.2)\n", + "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.5)\n", + "Requirement already satisfied: typer<0.10.0,>=0.3.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.9.0)\n", + "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.8)\n", + "Requirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (45.2.0)\n", + "Requirement already satisfied: typing-extensions>=4.6.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.8.0)\n", + "Requirement already satisfied: pydantic-core==2.14.5 in /home/snsd0805/.local/lib/python3.8/site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.14.5)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from 
pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.6.0)\n", + "Requirement already satisfied: confection<0.2.0,>=0.0.4 in /home/snsd0805/.local/lib/python3.8/site-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.1.4)\n", + "Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.16.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from jinja2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /home/snsd0805/.local/lib/python3.8/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.0)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/lib/python3/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.25.8)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2019.11.28)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.8)\n", + "Requirement already satisfied: blis<0.8.0,>=0.7.8 in /home/snsd0805/.local/lib/python3.8/site-packages (from thinc<8.3.0,>=8.1.8->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.7.11)\n", + "Requirement already satisfied: click<9.0.0,>=7.1.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from typer<0.10.0,>=0.3.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.1.7)\n", + "Installing collected packages: en-core-web-sm\n", + "Successfully installed en-core-web-sm-3.7.1\n", + "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n", + "You can now load the package via 
spacy.load('en_core_web_sm')\n" + ] + } + ], + "source": [ + "! python3 -m spacy download en_core_web_sm" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "781790cd", + "metadata": {}, + "outputs": [], + "source": [ + "def load_json(fn):\n", + " with open(fn) as f:\n", + " ret = json.load(f)\n", + " return ret\n", + "\n", + "def dump_json(data, fn, force=False):\n", + " if not force:\n", + " assert not os.path.exists(fn)\n", + " with open(fn, 'w') as f:\n", + " json.dump(data, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2159db6", + "metadata": {}, + "outputs": [], + "source": [ + "# ============================================================\n", + "# ChatGPT to extract the room & the target object\n", + "# ============================================================" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "29fe59c8", + "metadata": {}, + "outputs": [], + "source": [ + "TEMPLATE = '''\n", + "Please extract the target room, the goal object, and the relations between the goal object and other reference objects.\n", + "Example:\n", + "inputs:\n", + "{}\n", + "outputs:\n", + "{}\n", + "Now it is your turn:\n", + "inputs: \n", + "___inputs___\n", + "outputs:\n", + "'''\n", + "\n", + "def get_template() -> str:\n", + " inputs = 'In the kitchen above the brown shelves and above the basket on the top shelf there is a ceiling beam Please firm this beam'\n", + " outputs = {\n", + " 'room': 'kitchen',\n", + " 'goal': 'ceiling beam',\n", + " 'goal_relations':[\n", + " {'relation': 'above', 'reference': 'brown shelves'},\n", + " {'relation': 'above', 'reference': 'basket'},\n", + " {'relation': 'on the top', 'reference': 'shelf'},\n", + " ]\n", + " }\n", + " template = TEMPLATE.format(inputs, json.dumps(outputs, indent=4))\n", + " return template" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3e146e99", + "metadata": {}, + "outputs": [], + "source": [ + "def query(openai: 
OpenAI, prompt: str) -> Tuple[str, int]:\n", + " response = client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo-1106\",\n", + " response_format={ \"type\": \"json_object\" },\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"Please output JSON.\"},\n", + " {\"role\": \"user\", \"content\": prompt}\n", + " ]\n", + " )\n", + " \n", + " return (\n", + " json.loads(response.choices[0].message.content),\n", + " response.usage.total_tokens\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b334e925", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2996\n", + "1096\n", + "750\n" + ] + } + ], + "source": [ + "reverie_train = load_json(REVERIE_TRAIN_JSON_FILE)\n", + "reverie_val_unseen = load_json(REVERIE_VAL_UNSEEN_JSON_FILE)\n", + "reverie_val_seen = load_json(REVERIE_VAL_SEEN_JSON_FILE)\n", + "print(len(reverie_train))\n", + "print(len(reverie_val_unseen))\n", + "print(len(reverie_val_seen))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0917c229", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val_seen\n", + "\n", + "2024-01-28 16:45:49.797527 0 / 750\n", + " use 0 tokens\n", + "..........\n", + "2024-01-28 16:46:04.932742 10 / 750\n", + " use 2480 tokens\n", + "..........\n", + "2024-01-28 16:46:23.465098 20 / 750\n", + " use 5206 tokens\n", + "..........\n", + "2024-01-28 16:46:38.780172 30 / 750\n", + " use 7705 tokens\n", + "..........\n", + "2024-01-28 16:46:52.773380 40 / 750\n", + " use 10161 tokens\n", + "..........\n", + "2024-01-28 16:47:05.971142 50 / 750\n", + " use 12556 tokens\n", + "..........\n", + "2024-01-28 16:47:19.135185 60 / 750\n", + " use 14983 tokens\n", + "..........\n", + "2024-01-28 16:47:31.982835 70 / 750\n", + " use 17418 tokens\n", + "..........\n", + "2024-01-28 16:47:44.867015 80 / 750\n", + " use 19848 tokens\n", + "..........\n", + "2024-01-28 
16:47:58.097011 90 / 750\n", + " use 22322 tokens\n", + "..........\n", + "2024-01-28 16:48:09.810830 100 / 750\n", + " use 24670 tokens\n", + "..........\n", + "2024-01-28 16:48:25.051226 110 / 750\n", + " use 27178 tokens\n", + "..........\n", + "2024-01-28 16:48:39.153634 120 / 750\n", + " use 29675 tokens\n", + "..........\n", + "2024-01-28 16:48:52.691056 130 / 750\n", + " use 32131 tokens\n", + "..........\n", + "2024-01-28 16:49:08.061178 140 / 750\n", + " use 34721 tokens\n", + "..........\n", + "2024-01-28 16:49:20.385052 150 / 750\n", + " use 37036 tokens\n", + "..........\n", + "2024-01-28 16:49:35.084404 160 / 750\n", + " use 39550 tokens\n", + "..........\n", + "2024-01-28 16:49:48.060591 170 / 750\n", + " use 41995 tokens\n", + "..........\n", + "2024-01-28 16:50:00.565658 180 / 750\n", + " use 44445 tokens\n", + "..........\n", + "2024-01-28 16:50:13.490748 190 / 750\n", + " use 46862 tokens\n", + "..........\n", + "2024-01-28 16:50:26.500824 200 / 750\n", + " use 49205 tokens\n", + "..........\n", + "2024-01-28 16:50:38.342902 210 / 750\n", + " use 51623 tokens\n", + "..........\n", + "2024-01-28 16:50:50.995957 220 / 750\n", + " use 54082 tokens\n", + "..........\n", + "2024-01-28 16:51:02.467171 230 / 750\n", + " use 56477 tokens\n", + "..........\n", + "2024-01-28 16:51:15.902792 240 / 750\n", + " use 58961 tokens\n", + "..........\n", + "2024-01-28 16:51:28.502589 250 / 750\n", + " use 61385 tokens\n", + "..........\n", + "2024-01-28 16:51:41.633834 260 / 750\n", + " use 63868 tokens\n", + "..........\n", + "2024-01-28 16:51:54.141903 270 / 750\n", + " use 66339 tokens\n", + "..........\n", + "2024-01-28 16:52:08.683939 280 / 750\n", + " use 68949 tokens\n", + "..........\n", + "2024-01-28 16:52:22.242175 290 / 750\n", + " use 71424 tokens\n", + "..........\n", + "2024-01-28 16:52:35.615204 300 / 750\n", + " use 73965 tokens\n", + "..........\n", + "2024-01-28 16:52:47.855405 310 / 750\n", + " use 76411 tokens\n", + "..........\n", + "2024-01-28 
16:53:01.057109 320 / 750\n", + " use 78988 tokens\n", + "..........\n", + "2024-01-28 16:53:14.308833 330 / 750\n", + " use 81496 tokens\n", + "..........\n", + "2024-01-28 16:53:27.172938 340 / 750\n", + " use 83995 tokens\n", + "..........\n", + "2024-01-28 16:53:41.768159 350 / 750\n", + " use 86587 tokens\n", + "..........\n", + "2024-01-28 16:53:54.263125 360 / 750\n", + " use 89018 tokens\n", + "..........\n", + "2024-01-28 16:54:05.971150 370 / 750\n", + " use 91426 tokens\n", + "..........\n", + "2024-01-28 16:54:18.277461 380 / 750\n", + " use 93908 tokens\n", + "..........\n", + "2024-01-28 16:54:30.239469 390 / 750\n", + " use 96333 tokens\n", + "..........\n", + "2024-01-28 16:54:43.432008 400 / 750\n", + " use 98827 tokens\n", + "..........\n", + "2024-01-28 16:54:54.753545 410 / 750\n", + " use 101234 tokens\n", + "..........\n", + "2024-01-28 16:55:08.248909 420 / 750\n", + " use 103824 tokens\n", + "..........\n", + "2024-01-28 16:55:20.266048 430 / 750\n", + " use 106241 tokens\n", + "..........\n", + "2024-01-28 16:55:31.412194 440 / 750\n", + " use 108579 tokens\n", + "..........\n", + "2024-01-28 16:55:44.964667 450 / 750\n", + " use 111007 tokens\n", + "..........\n", + "2024-01-28 16:55:56.905568 460 / 750\n", + " use 113448 tokens\n", + "..........\n", + "2024-01-28 16:56:09.677100 470 / 750\n", + " use 115910 tokens\n", + "..........\n", + "2024-01-28 16:56:23.082325 480 / 750\n", + " use 118380 tokens\n", + "..........\n", + "2024-01-28 16:56:35.981859 490 / 750\n", + " use 120884 tokens\n", + "..........\n", + "2024-01-28 16:56:50.095850 500 / 750\n", + " use 123367 tokens\n", + "..........\n", + "2024-01-28 16:57:01.456319 510 / 750\n", + " use 125768 tokens\n", + "..........\n", + "2024-01-28 16:57:13.588327 520 / 750\n", + " use 128230 tokens\n", + "..........\n", + "2024-01-28 16:57:27.429472 530 / 750\n", + " use 130816 tokens\n", + "..........\n", + "2024-01-28 16:57:39.262664 540 / 750\n", + " use 133245 tokens\n", + 
"..........\n", + "2024-01-28 16:57:52.110883 550 / 750\n", + " use 135746 tokens\n", + "..........\n", + "2024-01-28 16:58:05.209857 560 / 750\n", + " use 138337 tokens\n", + "..........\n", + "2024-01-28 16:58:17.967249 570 / 750\n", + " use 140782 tokens\n", + "..........\n", + "2024-01-28 16:58:29.888118 580 / 750\n", + " use 143165 tokens\n", + "..........\n", + "2024-01-28 16:58:42.603735 590 / 750\n", + " use 145588 tokens\n", + "..........\n", + "2024-01-28 16:58:55.585690 600 / 750\n", + " use 148139 tokens\n", + "..........\n", + "2024-01-28 16:59:07.646169 610 / 750\n", + " use 150567 tokens\n", + "..........\n", + "2024-01-28 16:59:21.987738 620 / 750\n", + " use 153205 tokens\n", + "..........\n", + "2024-01-28 16:59:34.235770 630 / 750\n", + " use 155655 tokens\n", + "..........\n", + "2024-01-28 16:59:47.266655 640 / 750\n", + " use 158165 tokens\n", + "..........\n", + "2024-01-28 16:59:58.403109 650 / 750\n", + " use 160519 tokens\n", + "..........\n", + "2024-01-28 17:00:10.445611 660 / 750\n", + " use 162942 tokens\n", + "..........\n", + "2024-01-28 17:00:24.864396 670 / 750\n", + " use 165361 tokens\n", + "..........\n", + "2024-01-28 17:00:37.851239 680 / 750\n", + " use 167854 tokens\n", + "..........\n", + "2024-01-28 17:00:50.721513 690 / 750\n", + " use 170355 tokens\n", + "..........\n", + "2024-01-28 17:01:04.352591 700 / 750\n", + " use 172922 tokens\n", + "..........\n", + "2024-01-28 17:01:15.395867 710 / 750\n", + " use 175265 tokens\n", + "..........\n", + "2024-01-28 17:01:29.026544 720 / 750\n", + " use 177693 tokens\n", + "..........\n", + "2024-01-28 17:01:42.652462 730 / 750\n", + " use 180203 tokens\n", + "..........\n", + "2024-01-28 17:01:57.032790 740 / 750\n", + " use 182718 tokens\n", + ".........." 
+ ] + } + ], + "source": [ + "# OpenAI GPT to extract the relationship between objects\n", + "client = OpenAI(api_key=OPENAI_API_KEY)\n", + "template = get_template()\n", + "\n", + "logs = {}\n", + "tokens = 0\n", + "\n", + "if DATASET == 'train':\n", + " print('train')\n", + " dataset = reverie_train\n", + "elif DATASET == 'val_unseen':\n", + " print('val_unseen')\n", + " dataset = reverie_val_unseen\n", + "else:\n", + " print('val_seen')\n", + " dataset = reverie_val_seen\n", + "\n", + "for idx, r in enumerate(dataset):\n", + " if idx%10==0:\n", + " dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n", + " print('\\n', end='')\n", + " print(datetime.now(), idx, '/', len(dataset))\n", + " print(f\" use {tokens} tokens\")\n", + " \n", + " instruction = r['instructions'][0]\n", + " prompt = template.replace(\"___inputs___\", instruction)\n", + " response, total_tokens = query(client, prompt)\n", + " tokens += total_tokens\n", + " \n", + " query_name = f'reverie__{DATASET}__{idx}__0'\n", + " logs[query_name] = response\n", + " print('.', end='')\n", + " \n", + "dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "175cc6fb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Used 185368 tokens\n" + ] + } + ], + "source": [ + "print(f'Used {tokens} tokens')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6596246", + "metadata": {}, + "outputs": [], + "source": [ + "# ============================================================\n", + "# Method #1 : ChatGPT api to replace the target object\n", + "# ============================================================" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "0c83e389", + "metadata": {}, + "outputs": [], + "source": [ + "def show_skybox(scan: str, viewpoint:str) -> None:\n", + " img_path = 
f'{SKYBOX_PATH}/{scan}/matterport_skybox_images/{viewpoint}_skybox_small.jpg'\n", + " im = Image.open(img_path)\n", + " from matplotlib import rcParams\n", + "\n", + " plt.imshow(im)\n", + " plt.show()\n", + " print(img_path)\n", + "\n", + "\n", + "def get_navigable_viewpoints(scan: str, viewpoint: str) -> list:\n", + " '''\n", + " Get all neighbor vps around it.\n", + " '''\n", + " data = load_json(f'{NAVIGABLE_PATH}/{scan}_navigable.json') \n", + " navigable_viewpoints = []\n", + " for k, v in data[viewpoint].items():\n", + " navigable_viewpoints.append(k)\n", + " \n", + " return navigable_viewpoints\n", + "\n", + "def get_objects_in_the_rooms(bboxes: dict, scan: str, viewpoint: str) -> list:\n", + " '''\n", + " Get all touchable objects around this viewpoint.\n", + " \n", + " Touchable: define by REVERIE datasets, means the objects is close to this point (maybe 1m).\n", + " '''\n", + " objs = set()\n", + " for k, v in bboxes[f'{scan}_{viewpoint}'].items():\n", + " objs.add(v['name'].replace('#', ' '))\n", + " return list(objs)\n", + "\n", + "def get_avoid_objs(bboxes: dict, scan: str, viewpoint: str) -> list:\n", + " '''\n", + " Get objects around this viewpoint\n", + " \n", + " First, it call get_navigable_viewpoints() to get the neighbor viewpoints.\n", + " Then, it get all the objects around its neighbor, we assume these objects is all visible bbox in this room\n", + " We need this list to avoid generating the objects that exist in this room\n", + "\n", + " '''\n", + " vps = get_navigable_viewpoints(scan, viewpoint)\n", + " objs = get_objects_in_the_rooms(bboxes, scan, viewpoint)\n", + " for i in vps:\n", + " tmp_objs = get_objects_in_the_rooms(bboxes, scan, i)\n", + " objs += tmp_objs\n", + " \n", + " return list(set(objs))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "ae77b862", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "You should replace the target object and return 
me a new instruction.\n", + "Notice: the new target object must be suitable for this room (room name), and it must doesn't look like any objects(different type) in avoid_objects list.\n", + "Some times, you can change the verb which suitable for the new target objects.\n", + "\n", + "Example:\n", + "inputs:\n", + "{\n", + " \"instruction\": \"Go to bedroom at the back left side of the house and turn on the lamp nearest the bedroom door\",\n", + " \"room_name\": \"bedroom\",\n", + " \"avoid_objects\": [\n", + " \"window\",\n", + " \"lamp\",\n", + " \"picture\",\n", + " \"bed\"\n", + " ],\n", + " \"target_object\": \"lamp\"\n", + "}\n", + "outputs:\n", + "{\n", + " \"new_instruction\": \"Go to bedroom at the back left side of the house and take the mirror nearest the bedroom door\"\n", + "}\n", + "Now it is your turn:\n", + "inputs: \n", + "___inputs___\n", + "outputs:\n", + "\n" + ] + } + ], + "source": [ + "REPLACE_OBJECT_TEMPLATE = '''\n", + "You should replace the target object and return me a new instruction.\n", + "Notice: the new target object must be suitable for this room (room name), and it must doesn't look like any objects(different type) in avoid_objects list.\n", + "Some times, you can change the verb which suitable for the new target objects.\n", + "\n", + "Example:\n", + "inputs:\n", + "{}\n", + "outputs:\n", + "{}\n", + "Now it is your turn:\n", + "inputs: \n", + "___inputs___\n", + "outputs:\n", + "'''\n", + "\n", + "def get_replace_object_template() -> str:\n", + " inputs = {\n", + " 'instruction': 'Go to bedroom at the back left side of the house and turn on the lamp nearest the bedroom door',\n", + " 'room_name': 'bedroom',\n", + " 'avoid_objects': ['window', 'lamp', 'picture', 'bed'],\n", + " 'target_object': 'lamp'\n", + " }\n", + " outputs = {\n", + " 'new_instruction': 'Go to bedroom at the back left side of the house and take the mirror nearest the bedroom door',\n", + " }\n", + " template = REPLACE_OBJECT_TEMPLATE.format(json.dumps(inputs, 
indent=4), json.dumps(outputs, indent=4))\n", + " return template\n", + "print(get_replace_object_template())" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "70c9888b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "2024-01-28 17:02:12.982758 0 / 750\n", + " use 0 tokens\n", + "..........\n", + "2024-01-28 17:02:21.516631 10 / 750\n", + " use 3207 tokens\n", + "..........\n", + "2024-01-28 17:02:30.687271 20 / 750\n", + " use 6593 tokens\n", + "..........\n", + "2024-01-28 17:02:38.764408 30 / 750\n", + " use 9909 tokens\n", + "..........\n", + "2024-01-28 17:02:46.047457 40 / 750\n", + " use 13275 tokens\n", + "..........\n", + "2024-01-28 17:02:53.616187 50 / 750\n", + " use 16683 tokens\n", + "..........\n", + "2024-01-28 17:03:00.754883 60 / 750\n", + " use 19998 tokens\n", + "..........\n", + "2024-01-28 17:03:10.039162 70 / 750\n", + " use 23367 tokens\n", + "..........\n", + "2024-01-28 17:03:18.122595 80 / 750\n", + " use 26633 tokens\n", + "..........\n", + "2024-01-28 17:03:26.190372 90 / 750\n", + " use 29905 tokens\n", + "..........\n", + "2024-01-28 17:03:33.605275 100 / 750\n", + " use 33258 tokens\n", + "..........\n", + "2024-01-28 17:03:41.508207 110 / 750\n", + " use 36639 tokens\n", + "..........\n", + "2024-01-28 17:03:49.187187 120 / 750\n", + " use 39933 tokens\n", + "..........\n", + "2024-01-28 17:03:57.650307 130 / 750\n", + " use 43265 tokens\n", + "..........\n", + "2024-01-28 17:04:05.802962 140 / 750\n", + " use 46670 tokens\n", + "..........\n", + "2024-01-28 17:04:14.862087 150 / 750\n", + " use 49887 tokens\n", + "..........\n", + "2024-01-28 17:04:24.654696 160 / 750\n", + " use 53189 tokens\n", + "..........\n", + "2024-01-28 17:04:32.348524 170 / 750\n", + " use 56649 tokens\n", + "..........\n", + "2024-01-28 17:04:39.992408 180 / 750\n", + " use 60088 tokens\n", + "..........\n", + "2024-01-28 17:04:47.581791 190 / 750\n", + " use 63313 tokens\n", 
+ "..........\n", + "2024-01-28 17:04:54.968194 200 / 750\n", + " use 66443 tokens\n", + "..........\n", + "2024-01-28 17:05:03.081295 210 / 750\n", + " use 69914 tokens\n", + "..........\n", + "2024-01-28 17:05:11.042498 220 / 750\n", + " use 73115 tokens\n", + "..........\n", + "2024-01-28 17:05:18.786897 230 / 750\n", + " use 76566 tokens\n", + "..........\n", + "2024-01-28 17:05:27.032178 240 / 750\n", + " use 79859 tokens\n", + "..........\n", + "2024-01-28 17:05:34.505171 250 / 750\n", + " use 83234 tokens\n", + "..........\n", + "2024-01-28 17:05:42.915753 260 / 750\n", + " use 86662 tokens\n", + "..........\n", + "2024-01-28 17:05:51.055077 270 / 750\n", + " use 90100 tokens\n", + "..........\n", + "2024-01-28 17:05:59.541581 280 / 750\n", + " use 93532 tokens\n", + "..........\n", + "2024-01-28 17:06:07.890632 290 / 750\n", + " use 96856 tokens\n", + "..........\n", + "2024-01-28 17:06:16.479621 300 / 750\n", + " use 100353 tokens\n", + "..........\n", + "2024-01-28 17:06:25.275359 310 / 750\n", + " use 103854 tokens\n", + "..........\n", + "2024-01-28 17:06:33.737390 320 / 750\n", + " use 107260 tokens\n", + "..........\n", + "2024-01-28 17:06:41.555370 330 / 750\n", + " use 110745 tokens\n", + "..........\n", + "2024-01-28 17:06:49.479090 340 / 750\n", + " use 114082 tokens\n", + "..........\n", + "2024-01-28 17:06:58.041339 350 / 750\n", + " use 117357 tokens\n", + "..........\n", + "2024-01-28 17:07:05.942328 360 / 750\n", + " use 120673 tokens\n", + "..........\n", + "2024-01-28 17:07:14.103199 370 / 750\n", + " use 123899 tokens\n", + "..........\n", + "2024-01-28 17:07:22.556955 380 / 750\n", + " use 127261 tokens\n", + "..........\n", + "2024-01-28 17:07:30.841584 390 / 750\n", + " use 130624 tokens\n", + "..........\n", + "2024-01-28 17:07:40.961980 400 / 750\n", + " use 134033 tokens\n", + "..........\n", + "2024-01-28 17:07:49.035868 410 / 750\n", + " use 137208 tokens\n", + "..........\n", + "2024-01-28 17:07:58.045275 420 / 750\n", + " use 
140563 tokens\n", + "..........\n", + "2024-01-28 17:08:06.704814 430 / 750\n", + " use 143926 tokens\n", + "..........\n", + "2024-01-28 17:08:14.719195 440 / 750\n", + " use 147137 tokens\n", + "..........\n", + "2024-01-28 17:08:23.215680 450 / 750\n", + " use 150332 tokens\n", + "..........\n", + "2024-01-28 17:08:31.857775 460 / 750\n", + " use 153765 tokens\n", + "..........\n", + "2024-01-28 17:08:39.738408 470 / 750\n", + " use 156996 tokens\n", + "..........\n", + "2024-01-28 17:08:47.790493 480 / 750\n", + " use 160213 tokens\n", + "..........\n", + "2024-01-28 17:08:55.185835 490 / 750\n", + " use 163563 tokens\n", + "..........\n", + "2024-01-28 17:09:03.000253 500 / 750\n", + " use 166932 tokens\n", + "..........\n", + "2024-01-28 17:09:10.783133 510 / 750\n", + " use 170240 tokens\n", + "..........\n", + "2024-01-28 17:09:18.652092 520 / 750\n", + " use 173502 tokens\n", + "..........\n", + "2024-01-28 17:09:26.422400 530 / 750\n", + " use 176843 tokens\n", + "..........\n", + "2024-01-28 17:09:34.636908 540 / 750\n", + " use 180363 tokens\n", + "..........\n", + "2024-01-28 17:09:43.233021 550 / 750\n", + " use 183593 tokens\n", + "..........\n", + "2024-01-28 17:09:51.823002 560 / 750\n", + " use 187014 tokens\n", + "..........\n", + "2024-01-28 17:09:59.527840 570 / 750\n", + " use 190399 tokens\n", + "..........\n", + "2024-01-28 17:10:08.005529 580 / 750\n", + " use 193703 tokens\n", + "..........\n", + "2024-01-28 17:10:17.602830 590 / 750\n", + " use 197048 tokens\n", + "..........\n", + "2024-01-28 17:10:26.569711 600 / 750\n", + " use 200428 tokens\n", + "..........\n", + "2024-01-28 17:10:34.537832 610 / 750\n", + " use 203882 tokens\n", + "..........\n", + "2024-01-28 17:10:43.276852 620 / 750\n", + " use 207415 tokens\n", + "..........\n", + "2024-01-28 17:10:51.461752 630 / 750\n", + " use 210626 tokens\n", + "..........\n", + "2024-01-28 17:10:59.469175 640 / 750\n", + " use 213857 tokens\n", + "..........\n", + "2024-01-28 
# ---------------------------------------------------------------
# Query ChatGPT to rewrite each REVERIE instruction into an
# adversarial variant, then write the edited dataset back out.
# Depends on helpers defined earlier in this notebook:
#   load_json / dump_json / get_replace_object_template /
#   get_avoid_objs / query, plus `dataset` for the current split.
# ---------------------------------------------------------------
logs = load_json(f'gpt_outputs_{DATASET}.json')
nlp = spacy.load("en_core_web_sm")
bboxes = load_json(BBOXES_JSON_FILE)

client = OpenAI(api_key=OPENAI_API_KEY)
template = get_replace_object_template()

tokens = 0  # running total of OpenAI tokens consumed, for cost tracking
for idx, r in enumerate(dataset):
    # Checkpoint every 10 items so a crash loses at most 10 responses.
    if idx % 10 == 0:
        dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)
        print('\n', end='')
        print(datetime.now(), idx, '/', len(dataset))
        print(f" use {tokens} tokens")

    log = logs[f'reverie__{DATASET}__{idx}__0']
    scan = r['scan']
    target_vp = r['path'][-1]
    # Objects visible at the goal viewpoint must NOT be offered as decoys.
    avoid_objs = get_avoid_objs(bboxes, scan, target_vp)
    try:
        inputs = {
            'instruction': r['instructions'][0],
            'room_name': log['room'],
            'avoid_objects': avoid_objs,
            'target_object': log['goal']
        }
        prompt = template.replace('___inputs___', json.dumps(inputs, indent=4))
        response, total_tokens = query(client, prompt)
        tokens += total_tokens

        log['instruction'] = r['instructions'][0]
        log['new_instruction'] = response['new_instruction']
        print('.', end='')
    # BUG FIX: was a bare `except:` that also swallowed KeyboardInterrupt
    # and hid the reason for the failure; keep best-effort, but report why.
    except Exception as exc:
        print(f'\nfailed on idx {idx}: {exc!r}')
        print(log)

dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)

# ============================================================
# Transform the log into REVERIE data: overwrite the second
# human instruction with the adversarial one when the rewrite
# above succeeded for that entry.
# ============================================================
for index in range(len(dataset)):
    original_data = dataset[index]
    log = logs[f'reverie__{DATASET}__{index}__0']
    print(original_data['scan'])
    print(original_data['path'][-1])
    print(original_data['instructions'][0])
    print(original_data['instructions'][1])
    print(log)
    if 'new_instruction' in log:
        original_data['instructions'][1] = log['new_instruction']
        print(original_data['instructions'][1])
        print(get_avoid_objs(bboxes, original_data['scan'], original_data['path'][-1]))
    else:
        # BUG FIX: the original `del log` only unbound the local name and
        # left `logs` untouched; drop the stale failed entry from `logs`
        # itself. (NOTE(review): assuming the intent was to purge failures
        # from the log file — confirm with the author.)
        del logs[f'reverie__{DATASET}__{index}__0']
    print()

dump_json(dataset, f'REVERIE_{DATASET}.json.adversarial', force=True)
def get_subject(nlp, text) -> str:
    """Extract the head noun of a room description, normalized to one token.

    Scans the dependency parse for ROOT/object tokens (the LAST match wins)
    and prepends a `compound` modifier — or the `amod` "living" — unless
    that modifier is a generic location word ('level'/'floor').

    Args:
        nlp: a spaCy pipeline (or any callable returning a sequence of
            tokens exposing ``.text`` and ``.dep_``).
        text: the room description to parse, e.g. "the living room".

    Returns:
        The subject with 'area', newlines and all spaces removed
        (e.g. "living room" -> "livingroom"), or None when no ROOT/obj
        token is found.
    """
    doc = nlp(text)

    # find subject
    subject = None
    for index, token in enumerate(doc):
        if "ROOT" in token.dep_ or 'obj' in token.dep_:
            # BUG FIX: guard index > 0 — with index == 0 the original
            # `doc[index-1]` wrapped around to the LAST token (Python
            # negative indexing) and could fabricate a bogus compound.
            prev = doc[index - 1] if index > 0 else None
            if prev is not None and (
                prev.dep_ == 'compound'
                or (prev.dep_ == 'amod' and prev.text == 'living')
            ):
                # Skip generic location modifiers such as 'second level'.
                if prev.text != 'level' and prev.text != 'floor':
                    subject = prev.text + " " + token.text
                else:
                    subject = token.text
            else:
                subject = token.text
    if subject:
        # Normalize so "dining area"/"living room" collapse to a single
        # dictionary key ("dining"/"livingroom").
        subject = subject.replace('area', '')
        subject = subject.replace('\n', '')
        subject = subject.replace(' ', '')
    return subject
# ============================================================
# Count how often each normalized room name appears in the
# `room` field of the GPT logs for the REVERIE training split,
# then keep rooms mentioned at least 5 times as the candidate
# swap pool. Entries without a usable room are reported.
# ============================================================
logs = load_json('gpt_outputs.json')
nlp = spacy.load("en_core_web_sm")
bboxes = load_json(BBOXES_JSON_FILE)

rooms = {}

for idx, r in enumerate(reverie_train):
    try:
        room_descr = logs[f'reverie__train__{idx}__0']['room']
        room = get_subject(nlp, room_descr)
        rooms[room] = rooms.get(room, 0) + 1
    # BUG FIX: was a bare `except:` that also caught KeyboardInterrupt;
    # keep the best-effort behavior but only swallow real errors.
    except Exception:
        print("NO ROOM:", logs[f'reverie__train__{idx}__0'])
        print("    ", r['instructions'][0])

# Rooms frequent enough (>= 5 mentions) to be usable replacements.
selected_rooms = {name for name, count in rooms.items() if count >= 5}

for name in selected_rooms:
    print(name, rooms[name])
# DEBUG cell: reload the raw REVERIE training annotations from disk.
reverie_train = load_json(REVERIE_TRAIN_JSON_FILE)