{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "3619bc9d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import json, os\n",
|
|
"import openai\n",
|
|
"from openai import OpenAI\n",
|
|
"from datetime import datetime\n",
|
|
"from typing import Tuple, List\n",
|
|
"import spacy\n",
|
|
"from PIL import Image\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"\n",
|
|
"# ==============================================\n",
|
|
"# !!!!! OPEN AI API KEY HERE !!!!!\n",
|
|
"# Please delete your API key before the commit.\n",
|
|
"# ==============================================\n",
|
|
"OPENAI_API_KEY = ''\n",
|
|
"\n",
|
|
"DATASET = 'val_seen'\n",
|
|
"#!this version is proccess on the adversarial instruction file generated by H.T., not original REVERIE data\n",
|
|
"REVERIE_TRAIN_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_train.json'\n",
|
|
"REVERIE_VAL_UNSEEN_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_val_unseen.json'\n",
|
|
"REVERIE_VAL_SEEN_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_val_seen.json'\n",
|
|
"BBOXES_JSON_FILE = '/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/BBoxes.json'\n",
|
|
"NAVIGABLE_PATH = '/data/NavGPT_data/navigable'\n",
|
|
"SKYBOX_PATH = '/home/snsd0805/code/research/VLN/base_dir/v1/scans'"
|
|
]
|
|
},
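{
"cell_type": "code",
"execution_count": null,
"id": "added-path-check",
"metadata": {},
"outputs": [],
"source": [
"# --- Added illustration (not part of the original run) ---\n",
"# Optional sanity check: confirm the configured annotation files and data\n",
"# directories exist on this machine before any API calls are made.\n",
"for _p in [REVERIE_TRAIN_JSON_FILE, REVERIE_VAL_UNSEEN_JSON_FILE,\n",
"           REVERIE_VAL_SEEN_JSON_FILE, BBOXES_JSON_FILE,\n",
"           NAVIGABLE_PATH, SKYBOX_PATH]:\n",
"    print(os.path.exists(_p), _p)"
]
},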
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "f33686ff",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'/data/Matterport3DSimulator-duet/VLN-DUET/datasets/REVERIE/annotations/REVERIE_train.json'"
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"REVERIE_TRAIN_JSON_FILE"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "ce31f8d8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Collecting en-core-web-sm==3.7.1\n",
|
|
" Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)\n",
|
|
"\u001b[K |████████████████████████████████| 12.8 MB 1.6 MB/s eta 0:00:01\n",
|
|
"\u001b[?25hRequirement already satisfied: spacy<3.8.0,>=3.7.2 in /home/snsd0805/.local/lib/python3.8/site-packages (from en-core-web-sm==3.7.1) (3.7.2)\n",
|
|
"Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.1.2)\n",
|
|
"Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.5.2)\n",
|
|
"Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.4.8)\n",
|
|
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.9)\n",
|
|
"Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.3.4)\n",
|
|
"Requirement already satisfied: jinja2 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.1.2)\n",
|
|
"Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (6.4.0)\n",
|
|
"Requirement already satisfied: numpy>=1.15.0; python_version < \"3.9\" in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.24.4)\n",
|
|
"Requirement already satisfied: packaging>=20.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (23.2)\n",
|
|
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.66.1)\n",
|
|
"Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.12)\n",
|
|
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.31.0)\n",
|
|
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.10)\n",
|
|
"Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.0)\n",
|
|
"Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.10)\n",
|
|
"Requirement already satisfied: thinc<8.3.0,>=8.1.8 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.2.2)\n",
|
|
"Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.5)\n",
|
|
"Requirement already satisfied: typer<0.10.0,>=0.3.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.9.0)\n",
|
|
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /home/snsd0805/.local/lib/python3.8/site-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.8)\n",
|
|
"Requirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (45.2.0)\n",
|
|
"Requirement already satisfied: typing-extensions>=4.6.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.8.0)\n",
|
|
"Requirement already satisfied: pydantic-core==2.14.5 in /home/snsd0805/.local/lib/python3.8/site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.14.5)\n",
|
|
"Requirement already satisfied: annotated-types>=0.4.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.6.0)\n",
|
|
"Requirement already satisfied: confection<0.2.0,>=0.0.4 in /home/snsd0805/.local/lib/python3.8/site-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.1.4)\n",
|
|
"Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.16.0)\n",
|
|
"Requirement already satisfied: MarkupSafe>=2.0 in /home/snsd0805/.local/lib/python3.8/site-packages (from jinja2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.1.3)\n",
|
|
"Requirement already satisfied: charset-normalizer<4,>=2 in /home/snsd0805/.local/lib/python3.8/site-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.0)\n",
|
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/lib/python3/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.25.8)\n",
|
|
"Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2019.11.28)\n",
|
|
"Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.8)\n",
|
|
"Requirement already satisfied: blis<0.8.0,>=0.7.8 in /home/snsd0805/.local/lib/python3.8/site-packages (from thinc<8.3.0,>=8.1.8->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.7.11)\n",
|
|
"Requirement already satisfied: click<9.0.0,>=7.1.1 in /home/snsd0805/.local/lib/python3.8/site-packages (from typer<0.10.0,>=0.3.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.1.7)\n",
|
|
"Installing collected packages: en-core-web-sm\n",
|
|
"Successfully installed en-core-web-sm-3.7.1\n",
|
|
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
|
|
"You can now load the package via spacy.load('en_core_web_sm')\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"! python3 -m spacy download en_core_web_sm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "781790cd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def load_json(fn):\n",
|
|
" with open(fn) as f:\n",
|
|
" ret = json.load(f)\n",
|
|
" return ret\n",
|
|
"\n",
|
|
"def dump_json(data, fn, force=False):\n",
|
|
" if not force:\n",
|
|
" assert not os.path.exists(fn)\n",
|
|
" with open(fn, 'w') as f:\n",
|
|
" json.dump(data, f)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b2159db6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ============================================================\n",
|
|
"# ChatGPT to extract the room & the target object\n",
|
|
"# ============================================================"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "29fe59c8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"TEMPLATE = '''\n",
|
|
"Please extract the target room, the goal object, and the relations between the goal object and other reference objects.\n",
|
|
"Example:\n",
|
|
"inputs:\n",
|
|
"{}\n",
|
|
"outputs:\n",
|
|
"{}\n",
|
|
"Now it is your turn:\n",
|
|
"inputs: \n",
|
|
"___inputs___\n",
|
|
"outputs:\n",
|
|
"'''\n",
|
|
"\n",
|
|
"def get_template() -> str:\n",
|
|
" inputs = 'In the kitchen above the brown shelves and above the basket on the top shelf there is a ceiling beam Please firm this beam'\n",
|
|
" outputs = {\n",
|
|
" 'room': 'kitchen',\n",
|
|
" 'goal': 'ceiling beam',\n",
|
|
" 'goal_relations':[\n",
|
|
" {'relation': 'above', 'reference': 'brown shelves'},\n",
|
|
" {'relation': 'above', 'reference': 'basket'},\n",
|
|
" {'relation': 'on the top', 'reference': 'shelf'},\n",
|
|
" ]\n",
|
|
" }\n",
|
|
" template = TEMPLATE.format(inputs, json.dumps(outputs, indent=4))\n",
|
|
" return template"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "3e146e99",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def query(openai: OpenAI, prompt: str) -> Tuple[str, int]:\n",
|
|
" response = client.chat.completions.create(\n",
|
|
" model=\"gpt-3.5-turbo-1106\",\n",
|
|
" response_format={ \"type\": \"json_object\" },\n",
|
|
" messages=[\n",
|
|
" {\"role\": \"system\", \"content\": \"Please output JSON.\"},\n",
|
|
" {\"role\": \"user\", \"content\": prompt}\n",
|
|
" ]\n",
|
|
" )\n",
|
|
" \n",
|
|
" return (\n",
|
|
" json.loads(response.choices[0].message.content),\n",
|
|
" response.usage.total_tokens\n",
|
|
" )"
|
|
]
|
|
},
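{
"cell_type": "code",
"execution_count": null,
"id": "added-query-example",
"metadata": {},
"outputs": [],
"source": [
"# --- Added illustration (not part of the original run) ---\n",
"# A minimal sketch of how get_template() and query() fit together for one\n",
"# instruction. The sample instruction below is made up; running this cell\n",
"# assumes OPENAI_API_KEY is set and makes a single paid API call.\n",
"example_client = OpenAI(api_key=OPENAI_API_KEY)\n",
"example_instruction = 'Walk into the kitchen and wipe the counter next to the sink'\n",
"example_prompt = get_template().replace('___inputs___', example_instruction)\n",
"example_response, example_tokens = query(example_client, example_prompt)\n",
"print(example_response)  # expected keys: 'room', 'goal', 'goal_relations'\n",
"print('tokens used:', example_tokens)"
]
},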
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "b334e925",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2996\n",
|
|
"1096\n",
|
|
"750\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"reverie_train = load_json(REVERIE_TRAIN_JSON_FILE)\n",
|
|
"reverie_val_unseen = load_json(REVERIE_VAL_UNSEEN_JSON_FILE)\n",
|
|
"reverie_val_seen = load_json(REVERIE_VAL_SEEN_JSON_FILE)\n",
|
|
"print(len(reverie_train))\n",
|
|
"print(len(reverie_val_unseen))\n",
|
|
"print(len(reverie_val_seen))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "0917c229",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"val_seen\n",
|
|
"\n",
|
|
"2024-01-28 16:45:49.797527 0 / 750\n",
|
|
" use 0 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:46:04.932742 10 / 750\n",
|
|
" use 2480 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:46:23.465098 20 / 750\n",
|
|
" use 5206 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:46:38.780172 30 / 750\n",
|
|
" use 7705 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:46:52.773380 40 / 750\n",
|
|
" use 10161 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:47:05.971142 50 / 750\n",
|
|
" use 12556 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:47:19.135185 60 / 750\n",
|
|
" use 14983 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:47:31.982835 70 / 750\n",
|
|
" use 17418 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:47:44.867015 80 / 750\n",
|
|
" use 19848 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:47:58.097011 90 / 750\n",
|
|
" use 22322 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:48:09.810830 100 / 750\n",
|
|
" use 24670 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:48:25.051226 110 / 750\n",
|
|
" use 27178 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:48:39.153634 120 / 750\n",
|
|
" use 29675 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:48:52.691056 130 / 750\n",
|
|
" use 32131 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:49:08.061178 140 / 750\n",
|
|
" use 34721 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:49:20.385052 150 / 750\n",
|
|
" use 37036 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:49:35.084404 160 / 750\n",
|
|
" use 39550 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:49:48.060591 170 / 750\n",
|
|
" use 41995 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:50:00.565658 180 / 750\n",
|
|
" use 44445 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:50:13.490748 190 / 750\n",
|
|
" use 46862 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:50:26.500824 200 / 750\n",
|
|
" use 49205 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:50:38.342902 210 / 750\n",
|
|
" use 51623 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:50:50.995957 220 / 750\n",
|
|
" use 54082 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:51:02.467171 230 / 750\n",
|
|
" use 56477 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:51:15.902792 240 / 750\n",
|
|
" use 58961 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:51:28.502589 250 / 750\n",
|
|
" use 61385 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:51:41.633834 260 / 750\n",
|
|
" use 63868 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:51:54.141903 270 / 750\n",
|
|
" use 66339 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:52:08.683939 280 / 750\n",
|
|
" use 68949 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:52:22.242175 290 / 750\n",
|
|
" use 71424 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:52:35.615204 300 / 750\n",
|
|
" use 73965 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:52:47.855405 310 / 750\n",
|
|
" use 76411 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:53:01.057109 320 / 750\n",
|
|
" use 78988 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:53:14.308833 330 / 750\n",
|
|
" use 81496 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:53:27.172938 340 / 750\n",
|
|
" use 83995 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:53:41.768159 350 / 750\n",
|
|
" use 86587 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:53:54.263125 360 / 750\n",
|
|
" use 89018 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:54:05.971150 370 / 750\n",
|
|
" use 91426 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:54:18.277461 380 / 750\n",
|
|
" use 93908 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:54:30.239469 390 / 750\n",
|
|
" use 96333 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:54:43.432008 400 / 750\n",
|
|
" use 98827 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:54:54.753545 410 / 750\n",
|
|
" use 101234 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:55:08.248909 420 / 750\n",
|
|
" use 103824 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:55:20.266048 430 / 750\n",
|
|
" use 106241 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:55:31.412194 440 / 750\n",
|
|
" use 108579 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:55:44.964667 450 / 750\n",
|
|
" use 111007 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:55:56.905568 460 / 750\n",
|
|
" use 113448 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:56:09.677100 470 / 750\n",
|
|
" use 115910 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:56:23.082325 480 / 750\n",
|
|
" use 118380 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:56:35.981859 490 / 750\n",
|
|
" use 120884 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:56:50.095850 500 / 750\n",
|
|
" use 123367 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:57:01.456319 510 / 750\n",
|
|
" use 125768 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:57:13.588327 520 / 750\n",
|
|
" use 128230 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:57:27.429472 530 / 750\n",
|
|
" use 130816 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:57:39.262664 540 / 750\n",
|
|
" use 133245 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:57:52.110883 550 / 750\n",
|
|
" use 135746 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:58:05.209857 560 / 750\n",
|
|
" use 138337 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:58:17.967249 570 / 750\n",
|
|
" use 140782 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:58:29.888118 580 / 750\n",
|
|
" use 143165 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:58:42.603735 590 / 750\n",
|
|
" use 145588 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:58:55.585690 600 / 750\n",
|
|
" use 148139 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:59:07.646169 610 / 750\n",
|
|
" use 150567 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:59:21.987738 620 / 750\n",
|
|
" use 153205 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:59:34.235770 630 / 750\n",
|
|
" use 155655 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:59:47.266655 640 / 750\n",
|
|
" use 158165 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 16:59:58.403109 650 / 750\n",
|
|
" use 160519 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:00:10.445611 660 / 750\n",
|
|
" use 162942 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:00:24.864396 670 / 750\n",
|
|
" use 165361 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:00:37.851239 680 / 750\n",
|
|
" use 167854 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:00:50.721513 690 / 750\n",
|
|
" use 170355 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:01:04.352591 700 / 750\n",
|
|
" use 172922 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:01:15.395867 710 / 750\n",
|
|
" use 175265 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:01:29.026544 720 / 750\n",
|
|
" use 177693 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:01:42.652462 730 / 750\n",
|
|
" use 180203 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:01:57.032790 740 / 750\n",
|
|
" use 182718 tokens\n",
|
|
".........."
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# OpenAI GPT to extract the relationship between objects\n",
|
|
"client = OpenAI(api_key=OPENAI_API_KEY)\n",
|
|
"template = get_template()\n",
|
|
"\n",
|
|
"logs = {}\n",
|
|
"tokens = 0\n",
|
|
"\n",
|
|
"if DATASET == 'train':\n",
|
|
" print('train')\n",
|
|
" dataset = reverie_train\n",
|
|
"elif DATASET == 'val_unseen':\n",
|
|
" print('val_unseen')\n",
|
|
" dataset = reverie_val_unseen\n",
|
|
"else:\n",
|
|
" print('val_seen')\n",
|
|
" dataset = reverie_val_seen\n",
|
|
"\n",
|
|
"for idx, r in enumerate(dataset):\n",
|
|
" if idx%10==0:\n",
|
|
" dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n",
|
|
" print('\\n', end='')\n",
|
|
" print(datetime.now(), idx, '/', len(dataset))\n",
|
|
" print(f\" use {tokens} tokens\")\n",
|
|
" \n",
|
|
" instruction = r['instructions'][0]\n",
|
|
" prompt = template.replace(\"___inputs___\", instruction)\n",
|
|
" response, total_tokens = query(client, prompt)\n",
|
|
" tokens += total_tokens\n",
|
|
" \n",
|
|
" query_name = f'reverie__{DATASET}__{idx}__0'\n",
|
|
" logs[query_name] = response\n",
|
|
" print('.', end='')\n",
|
|
" \n",
|
|
"dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "175cc6fb",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Used 185368 tokens\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(f'Used {tokens} tokens')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d6596246",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ============================================================\n",
|
|
"# Method #1 : ChatGPT api to replace the target object\n",
|
|
"# ============================================================"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "0c83e389",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def show_skybox(scan: str, viewpoint:str) -> None:\n",
|
|
" img_path = f'{SKYBOX_PATH}/{scan}/matterport_skybox_images/{viewpoint}_skybox_small.jpg'\n",
|
|
" im = Image.open(img_path)\n",
|
|
" from matplotlib import rcParams\n",
|
|
"\n",
|
|
" plt.imshow(im)\n",
|
|
" plt.show()\n",
|
|
" print(img_path)\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_navigable_viewpoints(scan: str, viewpoint: str) -> list:\n",
|
|
" '''\n",
|
|
" Get all neighbor vps around it.\n",
|
|
" '''\n",
|
|
" data = load_json(f'{NAVIGABLE_PATH}/{scan}_navigable.json') \n",
|
|
" navigable_viewpoints = []\n",
|
|
" for k, v in data[viewpoint].items():\n",
|
|
" navigable_viewpoints.append(k)\n",
|
|
" \n",
|
|
" return navigable_viewpoints\n",
|
|
"\n",
|
|
"def get_objects_in_the_rooms(bboxes: dict, scan: str, viewpoint: str) -> list:\n",
|
|
" '''\n",
|
|
" Get all touchable objects around this viewpoint.\n",
|
|
" \n",
|
|
" Touchable: define by REVERIE datasets, means the objects is close to this point (maybe 1m).\n",
|
|
" '''\n",
|
|
" objs = set()\n",
|
|
" for k, v in bboxes[f'{scan}_{viewpoint}'].items():\n",
|
|
" objs.add(v['name'].replace('#', ' '))\n",
|
|
" return list(objs)\n",
|
|
"\n",
|
|
"def get_avoid_objs(bboxes: dict, scan: str, viewpoint: str) -> list:\n",
|
|
" '''\n",
|
|
" Get objects around this viewpoint\n",
|
|
" \n",
|
|
" First, it call get_navigable_viewpoints() to get the neighbor viewpoints.\n",
|
|
" Then, it get all the objects around its neighbor, we assume these objects is all visible bbox in this room\n",
|
|
" We need this list to avoid generating the objects that exist in this room\n",
|
|
"\n",
|
|
" '''\n",
|
|
" vps = get_navigable_viewpoints(scan, viewpoint)\n",
|
|
" objs = get_objects_in_the_rooms(bboxes, scan, viewpoint)\n",
|
|
" for i in vps:\n",
|
|
" tmp_objs = get_objects_in_the_rooms(bboxes, scan, i)\n",
|
|
" objs += tmp_objs\n",
|
|
" \n",
|
|
" return list(set(objs))\n"
|
|
]
|
|
},
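{
"cell_type": "code",
"execution_count": null,
"id": "added-avoid-objs-example",
"metadata": {},
"outputs": [],
"source": [
"# --- Added illustration (not part of the original run) ---\n",
"# A small sanity check for the helpers above, assuming BBoxes.json and the\n",
"# navigable-viewpoint files exist at the configured paths. It takes the scan\n",
"# and final viewpoint of the first val_seen episode and lists the objects\n",
"# that the adversarial rewrite should avoid.\n",
"example_bboxes = load_json(BBOXES_JSON_FILE)\n",
"example_entry = reverie_val_seen[0]\n",
"example_scan, example_vp = example_entry['scan'], example_entry['path'][-1]\n",
"print(get_navigable_viewpoints(example_scan, example_vp))\n",
"print(get_avoid_objs(example_bboxes, example_scan, example_vp))"
]
},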
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "ae77b862",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"You should replace the target object and return me a new instruction.\n",
|
|
"Notice: the new target object must be suitable for this room (room name), and it must doesn't look like any objects(different type) in avoid_objects list.\n",
|
|
"Some times, you can change the verb which suitable for the new target objects.\n",
|
|
"\n",
|
|
"Example:\n",
|
|
"inputs:\n",
|
|
"{\n",
|
|
" \"instruction\": \"Go to bedroom at the back left side of the house and turn on the lamp nearest the bedroom door\",\n",
|
|
" \"room_name\": \"bedroom\",\n",
|
|
" \"avoid_objects\": [\n",
|
|
" \"window\",\n",
|
|
" \"lamp\",\n",
|
|
" \"picture\",\n",
|
|
" \"bed\"\n",
|
|
" ],\n",
|
|
" \"target_object\": \"lamp\"\n",
|
|
"}\n",
|
|
"outputs:\n",
|
|
"{\n",
|
|
" \"new_instruction\": \"Go to bedroom at the back left side of the house and take the mirror nearest the bedroom door\"\n",
|
|
"}\n",
|
|
"Now it is your turn:\n",
|
|
"inputs: \n",
|
|
"___inputs___\n",
|
|
"outputs:\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"REPLACE_OBJECT_TEMPLATE = '''\n",
|
|
"You should replace the target object and return me a new instruction.\n",
|
|
"Notice: the new target object must be suitable for this room (room name), and it must doesn't look like any objects(different type) in avoid_objects list.\n",
|
|
"Some times, you can change the verb which suitable for the new target objects.\n",
|
|
"\n",
|
|
"Example:\n",
|
|
"inputs:\n",
|
|
"{}\n",
|
|
"outputs:\n",
|
|
"{}\n",
|
|
"Now it is your turn:\n",
|
|
"inputs: \n",
|
|
"___inputs___\n",
|
|
"outputs:\n",
|
|
"'''\n",
|
|
"\n",
|
|
"def get_replace_object_template() -> str:\n",
|
|
" inputs = {\n",
|
|
" 'instruction': 'Go to bedroom at the back left side of the house and turn on the lamp nearest the bedroom door',\n",
|
|
" 'room_name': 'bedroom',\n",
|
|
" 'avoid_objects': ['window', 'lamp', 'picture', 'bed'],\n",
|
|
" 'target_object': 'lamp'\n",
|
|
" }\n",
|
|
" outputs = {\n",
|
|
" 'new_instruction': 'Go to bedroom at the back left side of the house and take the mirror nearest the bedroom door',\n",
|
|
" }\n",
|
|
" template = REPLACE_OBJECT_TEMPLATE.format(json.dumps(inputs, indent=4), json.dumps(outputs, indent=4))\n",
|
|
" return template\n",
|
|
"print(get_replace_object_template())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "70c9888b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"2024-01-28 17:02:12.982758 0 / 750\n",
|
|
" use 0 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:02:21.516631 10 / 750\n",
|
|
" use 3207 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:02:30.687271 20 / 750\n",
|
|
" use 6593 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:02:38.764408 30 / 750\n",
|
|
" use 9909 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:02:46.047457 40 / 750\n",
|
|
" use 13275 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:02:53.616187 50 / 750\n",
|
|
" use 16683 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:03:00.754883 60 / 750\n",
|
|
" use 19998 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:03:10.039162 70 / 750\n",
|
|
" use 23367 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:03:18.122595 80 / 750\n",
|
|
" use 26633 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:03:26.190372 90 / 750\n",
|
|
" use 29905 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:03:33.605275 100 / 750\n",
|
|
" use 33258 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:03:41.508207 110 / 750\n",
|
|
" use 36639 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:03:49.187187 120 / 750\n",
|
|
" use 39933 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:03:57.650307 130 / 750\n",
|
|
" use 43265 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:04:05.802962 140 / 750\n",
|
|
" use 46670 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:04:14.862087 150 / 750\n",
|
|
" use 49887 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:04:24.654696 160 / 750\n",
|
|
" use 53189 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:04:32.348524 170 / 750\n",
|
|
" use 56649 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:04:39.992408 180 / 750\n",
|
|
" use 60088 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:04:47.581791 190 / 750\n",
|
|
" use 63313 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:04:54.968194 200 / 750\n",
|
|
" use 66443 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:05:03.081295 210 / 750\n",
|
|
" use 69914 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:05:11.042498 220 / 750\n",
|
|
" use 73115 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:05:18.786897 230 / 750\n",
|
|
" use 76566 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:05:27.032178 240 / 750\n",
|
|
" use 79859 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:05:34.505171 250 / 750\n",
|
|
" use 83234 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:05:42.915753 260 / 750\n",
|
|
" use 86662 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:05:51.055077 270 / 750\n",
|
|
" use 90100 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:05:59.541581 280 / 750\n",
|
|
" use 93532 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:06:07.890632 290 / 750\n",
|
|
" use 96856 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:06:16.479621 300 / 750\n",
|
|
" use 100353 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:06:25.275359 310 / 750\n",
|
|
" use 103854 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:06:33.737390 320 / 750\n",
|
|
" use 107260 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:06:41.555370 330 / 750\n",
|
|
" use 110745 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:06:49.479090 340 / 750\n",
|
|
" use 114082 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:06:58.041339 350 / 750\n",
|
|
" use 117357 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:07:05.942328 360 / 750\n",
|
|
" use 120673 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:07:14.103199 370 / 750\n",
|
|
" use 123899 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:07:22.556955 380 / 750\n",
|
|
" use 127261 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:07:30.841584 390 / 750\n",
|
|
" use 130624 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:07:40.961980 400 / 750\n",
|
|
" use 134033 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:07:49.035868 410 / 750\n",
|
|
" use 137208 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:07:58.045275 420 / 750\n",
|
|
" use 140563 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:08:06.704814 430 / 750\n",
|
|
" use 143926 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:08:14.719195 440 / 750\n",
|
|
" use 147137 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:08:23.215680 450 / 750\n",
|
|
" use 150332 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:08:31.857775 460 / 750\n",
|
|
" use 153765 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:08:39.738408 470 / 750\n",
|
|
" use 156996 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:08:47.790493 480 / 750\n",
|
|
" use 160213 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:08:55.185835 490 / 750\n",
|
|
" use 163563 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:09:03.000253 500 / 750\n",
|
|
" use 166932 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:09:10.783133 510 / 750\n",
|
|
" use 170240 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:09:18.652092 520 / 750\n",
|
|
" use 173502 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:09:26.422400 530 / 750\n",
|
|
" use 176843 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:09:34.636908 540 / 750\n",
|
|
" use 180363 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:09:43.233021 550 / 750\n",
|
|
" use 183593 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:09:51.823002 560 / 750\n",
|
|
" use 187014 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:09:59.527840 570 / 750\n",
|
|
" use 190399 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:10:08.005529 580 / 750\n",
|
|
" use 193703 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:10:17.602830 590 / 750\n",
|
|
" use 197048 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:10:26.569711 600 / 750\n",
|
|
" use 200428 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:10:34.537832 610 / 750\n",
|
|
" use 203882 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:10:43.276852 620 / 750\n",
|
|
" use 207415 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:10:51.461752 630 / 750\n",
|
|
" use 210626 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:10:59.469175 640 / 750\n",
|
|
" use 213857 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:11:07.635625 650 / 750\n",
|
|
" use 217043 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:11:17.050346 660 / 750\n",
|
|
" use 220181 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:11:24.829593 670 / 750\n",
|
|
" use 223295 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:11:32.432651 680 / 750\n",
|
|
" use 226488 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:11:40.744801 690 / 750\n",
|
|
" use 229875 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:11:48.741426 700 / 750\n",
|
|
" use 233400 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:11:56.509583 710 / 750\n",
|
|
" use 236675 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:12:03.689453 720 / 750\n",
|
|
" use 239984 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:12:12.271545 730 / 750\n",
|
|
" use 243264 tokens\n",
|
|
"..........\n",
|
|
"2024-01-28 17:12:20.567358 740 / 750\n",
|
|
" use 246521 tokens\n",
|
|
".........."
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"logs = load_json(f'gpt_outputs_{DATASET}.json')\n",
|
|
"nlp = spacy.load(\"en_core_web_sm\")\n",
|
|
"bboxes = load_json(BBOXES_JSON_FILE)\n",
|
|
"\n",
|
|
"client = OpenAI(api_key=OPENAI_API_KEY)\n",
|
|
"template = get_replace_object_template()\n",
|
|
"\n",
|
|
"tokens = 0\n",
|
|
"for idx, r in enumerate(dataset):\n",
|
|
" if idx%10==0:\n",
|
|
" dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n",
|
|
" print('\\n', end='')\n",
|
|
" print(datetime.now(), idx, '/', len(dataset))\n",
|
|
" print(f\" use {tokens} tokens\")\n",
|
|
"\n",
|
|
" log = logs[f'reverie__{DATASET}__{idx}__0']\n",
|
|
" scan = r['scan']\n",
|
|
" target_vp = r['path'][-1]\n",
|
|
" avoid_objs = get_avoid_objs(bboxes, scan, target_vp)\n",
|
|
"# print(log['room'])\n",
|
|
"# print(log['goal'])\n",
|
|
"# print(r['instructions'][0])\n",
|
|
"# print(f'avoid {avoid_objs}')\n",
|
|
" try:\n",
|
|
" inputs = {\n",
|
|
" 'instruction': r['instructions'][0],\n",
|
|
" 'room_name': log['room'],\n",
|
|
" 'avoid_objects': avoid_objs,\n",
|
|
" 'target_object': log['goal']\n",
|
|
" } \n",
|
|
" prompt = template.replace('___inputs___', json.dumps(inputs, indent=4))\n",
|
|
" response, total_tokens = query(client, prompt)\n",
|
|
" tokens += total_tokens\n",
|
|
"\n",
|
|
" log['instruction'] = r['instructions'][0]\n",
|
|
" log['new_instruction'] = response['new_instruction']\n",
|
|
" print('.', end='')\n",
|
|
" except:\n",
|
|
" print(log)\n",
|
|
"\n",
|
|
"dump_json(logs, f'gpt_outputs_{DATASET}.json', force=True)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4b713517",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ============================================================\n",
|
|
"# Transform the log into REVERIE data \n",
|
|
"# ============================================================"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f8f42d90",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"for index in range(len(dataset)):\n",
|
|
" original_data = dataset[index]\n",
|
|
" log = logs[f'reverie__{DATASET}__{index}__0']\n",
|
|
" print(original_data['scan'])\n",
|
|
" print(original_data['path'][-1])\n",
|
|
" print(original_data['instructions'][0])\n",
|
|
" print(original_data['instructions'][1])\n",
|
|
" print(log)\n",
|
|
" if 'new_instruction' in log:\n",
|
|
" original_data['instructions'][1] = log['new_instruction']\n",
|
|
" print(original_data['instructions'][1])\n",
|
|
" print(get_avoid_objs(bboxes, original_data['scan'], original_data['path'][-1]))\n",
|
|
" else:\n",
|
|
" del log\n",
|
|
" print()\n"
|
|
]
|
|
},
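{
"cell_type": "code",
"execution_count": null,
"id": "added-replacement-count",
"metadata": {},
"outputs": [],
"source": [
"# --- Added illustration (not part of the original run) ---\n",
"# Quick check before dumping: count how many entries in `dataset` actually\n",
"# received an adversarial instruction from the GPT logs.\n",
"replaced = sum(\n",
"    1 for i in range(len(dataset))\n",
"    if 'new_instruction' in logs.get(f'reverie__{DATASET}__{i}__0', {})\n",
")\n",
"print(f'{replaced} / {len(dataset)} entries have an adversarial instruction')"
]
},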
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b169a40c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dump_json(dataset, f'REVERIE_{DATASET}.json.adversarial', force=True)"
|
|
]
|
|
},
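{
"cell_type": "code",
"execution_count": null,
"id": "added-reload-check",
"metadata": {},
"outputs": [],
"source": [
"# --- Added illustration (not part of the original run) ---\n",
"# Reload the dumped adversarial annotations and compare the original and\n",
"# rewritten instruction for the first entry, as a quick sanity check.\n",
"adversarial = load_json(f'REVERIE_{DATASET}.json.adversarial')\n",
"print('original   :', adversarial[0]['instructions'][0])\n",
"print('adversarial:', adversarial[0]['instructions'][1])"
]
},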
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1b45d85c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ============================================================\n",
|
|
"# Method #2 : NLP tool to build the swap pool\n",
|
|
"# ( Haven't implement )\n",
|
|
"# ============================================================"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c5c751d7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_subject(nlp, text) -> str:\n",
|
|
" doc = nlp(text)\n",
|
|
" \n",
|
|
" # find subject\n",
|
|
" subject = None\n",
|
|
" for index, token in enumerate(doc):\n",
|
|
"# print(\"--->\", token, token.dep_)\n",
|
|
" if \"ROOT\" in token.dep_ or 'obj' in token.dep_:\n",
|
|
" if doc[index-1].dep_ == 'compound' or \\\n",
|
|
" (doc[index-1].dep_ == 'amod' and doc[index-1].text == 'living' ):\n",
|
|
" \n",
|
|
" if doc[index-1].text != 'level' and doc[index-1].text != 'floor':\n",
|
|
" subject = doc[index-1].text + \" \" + token.text\n",
|
|
" else:\n",
|
|
" subject = token.text\n",
|
|
" else:\n",
|
|
" subject = token.text\n",
|
|
" if subject:\n",
|
|
" subject = subject.replace('area', '')\n",
|
|
" subject = subject.replace('\\n', '')\n",
|
|
" subject = subject.replace(' ', '')\n",
|
|
" return subject"
|
|
]
|
|
},
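{
"cell_type": "code",
"execution_count": null,
"id": "added-get-subject-example",
"metadata": {},
"outputs": [],
"source": [
"# --- Added illustration (not part of the original run) ---\n",
"# A minimal sketch of what get_subject() extracts from a few room\n",
"# descriptions; the phrases below are made-up examples, not REVERIE data.\n",
"example_nlp = spacy.load('en_core_web_sm')\n",
"for example_phrase in ['the living room on the first floor', 'master bedroom', 'kitchen area']:\n",
"    print(example_phrase, '->', get_subject(example_nlp, example_phrase))"
]
},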
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "413d2456",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"NO ROOM: {'room': None, 'goal': 'open double doors', 'goal_relations': [{'relation': 'to the left', 'reference': None}]}\n",
|
|
" Go through the open double doors to the left\n",
|
|
"NO ROOM: {'room': None, 'goal': 'cabinet handle', 'goal_relations': [{'relation': 'bottom left', 'reference': 'washer and dryer'}, {'relation': 'beside', 'reference': 'washer and dryer'}, {'relation': 'on the second level', 'reference': None}]}\n",
|
|
" tighten the screws in the cabinet handle that is the bottom left beside the washer and dryer on the second level\n",
|
|
"NO ROOM: {'error': 'Unable to extract the goal object and its relations. Please provide a valid input.'}\n",
|
|
" Move to the closet and take the dress of the rack\n",
|
|
"NO ROOM: {'room': ['kitchen', 'family room'], 'goal': 'table', 'goal_relations': [{'relation': 'between', 'reference': 'white chairs'}]}\n",
|
|
" Walk through the kitchen into the family room and touch the table between the white chairs\n",
|
|
"stairs 22\n",
|
|
"foyer 10\n",
|
|
"laundryroom 182\n",
|
|
"lounge 126\n",
|
|
"floor 22\n",
|
|
"porch 22\n",
|
|
"masterbathroom 10\n",
|
|
"dining 5\n",
|
|
"utilityroom 29\n",
|
|
"deck 6\n",
|
|
"terrace 10\n",
|
|
"livingroom 201\n",
|
|
"office 169\n",
|
|
"lobby 12\n",
|
|
"masterbedroom 6\n",
|
|
"spa 13\n",
|
|
"staircase 9\n",
|
|
"kitchen 172\n",
|
|
"entryway 40\n",
|
|
"gym 8\n",
|
|
"bathroom 693\n",
|
|
"room 7\n",
|
|
"balcony 25\n",
|
|
"level 92\n",
|
|
"meetingroom 25\n",
|
|
"toilet 5\n",
|
|
"hall 8\n",
|
|
"closet 98\n",
|
|
"bar 7\n",
|
|
"hallway 232\n",
|
|
"familyroom 120\n",
|
|
"bedroom 356\n",
|
|
"sparoom 9\n",
|
|
"diningroom 172\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"logs = load_json('gpt_outputs.json')\n",
|
|
"nlp = spacy.load(\"en_core_web_sm\")\n",
|
|
"bboxes = load_json(BBOXES_JSON_FILE)\n",
|
|
"\n",
|
|
"rooms = {}\n",
|
|
"\n",
|
|
"for idx, r in enumerate(reverie_train):\n",
|
|
" try:\n",
|
|
" room_descr = logs[f'reverie__train__{idx}__0']['room']\n",
|
|
" room = get_subject(nlp, room_descr)\n",
|
|
" if room in rooms:\n",
|
|
" rooms[room] += 1\n",
|
|
" else:\n",
|
|
" rooms[room] = 1\n",
|
|
" except:\n",
|
|
" print(\"NO ROOM:\", logs[f'reverie__train__{idx}__0'])\n",
|
|
" print(\" \", r['instructions'][0])\n",
|
|
" \n",
|
|
"selected_rooms = set()\n",
|
|
"for k, v in rooms.items():\n",
|
|
" if v >= 5:\n",
|
|
" selected_rooms.add(k)\n",
|
|
"\n",
|
|
"for i in selected_rooms:\n",
|
|
" print(i, rooms[i])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "08fc5b71",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ============================================================\n",
|
|
"# DEBUG \n",
|
|
"# ============================================================"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a1524ce5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"reverie_train = load_json(REVERIE_TRAIN_JSON_FILE)\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|