feat: load confidence & clip target

This commit is contained in:
Ting-Jun Wang 2024-06-30 19:44:31 +08:00
parent cd2e0a30e4
commit 1547974692
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 35 additions and 7 deletions

View File

@ -26,6 +26,8 @@ from langchain.schema import (
)
from langchain.base_language import BaseLanguageModel
from data_utils import load_json
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from prompt.planner_prompt import (
ACTION_PROMPT,
@ -52,6 +54,8 @@ TEMP_STEPS_COUNTER = 0
STEPS_COUNTER = 0
NOW_LOCATION = None
THRESHOLD = 0.2812
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
"Invalid Format: Missing 'Action:' after 'Thought:"
)
@ -62,6 +66,18 @@ FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
"Parsing LLM output produced both a final answer and a parse-able action:"
)
print("Load CLIP confidence file...")
confidences = load_json('../datasets/REVERIE/annotations/confidence.json')
print("Loaded")
def is_found(scan, vp, clip_target):
    """Report whether any object at a viewpoint reaches the confidence threshold.

    Looks up the entry ``confidences[scan][vp]`` in the module-level
    ``confidences`` table and checks, for each object listed there, the value
    stored under ``clip_target``.

    Args:
        scan: Key of the scan in ``confidences``.
        vp: Key of the viewpoint within that scan.
        clip_target: Key of the target whose confidence value is compared.

    Returns:
        True if at least one object's value for ``clip_target`` is greater
        than or equal to ``THRESHOLD``; False otherwise (including when the
        viewpoint lists no objects).

    Raises:
        KeyError: If ``scan`` or ``vp`` is absent from ``confidences``, or if
            an inspected object has no entry for ``clip_target``.
    """
    # Resolve the per-viewpoint table once instead of on every iteration.
    objects = confidences[scan][vp]
    # any() stops at the first object that meets the threshold, so the rest
    # of the viewpoint is not scanned once the answer is known.
    return any(objects[obj][clip_target] >= THRESHOLD for obj in objects)
class NavGPTOutputParser(AgentOutputParser):
"""MRKL Output parser for the chat agent."""
@ -759,9 +775,11 @@ class NavGPTAgent(BaseAgent):
print(obs[0]['start'])
print(obs[0]['target'])
print(obs[0]['new_reverie_id'])
print(obs[0]['clip_target'])
NOW_LOCATION = obs[0]['start']
print("==")
# Initialize the trajectory

View File

@ -45,3 +45,8 @@ def construct_reverie_instrs(anno_dir, dataset, splits):
del new_item['instr_encodings']
data.append(new_item)
return data
def load_json(f):
    """Read a JSON file and return the decoded Python object.

    Args:
        f: Path of the JSON file to read (anything accepted by ``open``).

    Returns:
        The object produced by ``json.load`` for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # Decode explicitly as UTF-8 rather than the platform's locale encoding,
    # so non-ASCII content loads the same way on every system.
    with open(f, encoding="utf-8") as fp:
        return json.load(fp)

View File

@ -144,7 +144,8 @@ class Simulator(object):
heading: int,
elevation: int,
start: str,
target: str
target: str,
clip_target: str,
):
self.heading = heading
self.elevation = elevation
@ -152,6 +153,7 @@ class Simulator(object):
self.viewpoint_ID = viewpoint_ID
self.start = start
self.target = target
self.clip_target = clip_target
# Load navigable dict
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
with open(navigable_path, 'r') as f:
@ -185,7 +187,8 @@ class Simulator(object):
'elevation': self.elevation,
'candidate': self.candidate,
'start': self.start,
'target': self.target
'target': self.target,
'clip_target': self.clip_target,
}
return self.state
@ -230,9 +233,9 @@ class EnvBatch(object):
def _make_id(self, scanId, viewpointId):
    """Return ``scanId`` and ``viewpointId`` joined by a single underscore."""
    return scanId + '_' + viewpointId
def newEpisodes(self, scanIds, viewpointIds, headings, starts, targets):
for i, (scanId, viewpointId, heading, start, target) in enumerate(zip(scanIds, viewpointIds, headings, starts, targets)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, start, target)
def newEpisodes(self, scanIds, viewpointIds, headings, starts, targets, clip_targets):
    """Start one new episode per simulator from parallel argument lists.

    The i-th entries of the six sequences are passed to ``newEpisode`` on
    ``self.sims[i]``. The fourth positional argument of that call is always
    the literal 0 (presumably the elevation; confirm against ``newEpisode``).
    Iteration stops at the shortest input sequence.
    """
    episode_args = zip(scanIds, viewpointIds, headings, starts, targets, clip_targets)
    for sim_index, episode in enumerate(episode_args):
        scan_id, viewpoint_id, heading, start, target, clip_target = episode
        simulator = self.sims[sim_index]
        simulator.newEpisode(scan_id, viewpoint_id, heading, 0, start, target, clip_target)
def getStates(self):
"""
@ -358,7 +361,8 @@ class REVERIENavBatch(object):
'path_id' : item['path_id'],
'start': item['start'],
'new_reverie_id': item['new_reverie_id'],
'target': item['target']
'target': item['target'],
'clip_target': item['clip_target']
}
# RL reward. The negative distance between the state and the final state
# There are multiple gt end viewpoints on REVERIE.
@ -382,7 +386,8 @@ class REVERIENavBatch(object):
headings = [item['heading'] for item in self.batch]
starts = [item['start'] for item in self.batch]
targets = [item['target'] for item in self.batch]
self.env.newEpisodes(scanIds, starts, headings, starts, targets)
clip_targets = [item['clip_target'] for item in self.batch]
self.env.newEpisodes(scanIds, starts, headings, starts, targets, clip_targets)
return self._get_obs()
def step(self, next_viewpoint_IDs):