feat: load confidence & clip target

This commit is contained in:
Ting-Jun Wang 2024-06-30 19:44:31 +08:00
parent cd2e0a30e4
commit 1547974692
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 35 additions and 7 deletions

View File

@ -26,6 +26,8 @@ from langchain.schema import (
) )
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from data_utils import load_json
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from prompt.planner_prompt import ( from prompt.planner_prompt import (
ACTION_PROMPT, ACTION_PROMPT,
@ -52,6 +54,8 @@ TEMP_STEPS_COUNTER = 0
STEPS_COUNTER = 0 STEPS_COUNTER = 0
NOW_LOCATION = None NOW_LOCATION = None
# Minimum CLIP confidence for a target to count as found: is_found() tests
# `prob >= THRESHOLD` for each object at the viewpoint.
THRESHOLD = 0.2812
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = ( MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
"Invalid Format: Missing 'Action:' after 'Thought:" "Invalid Format: Missing 'Action:' after 'Thought:"
) )
@ -62,6 +66,18 @@ FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
"Parsing LLM output produced both a final answer and a parse-able action:" "Parsing LLM output produced both a final answer and a parse-able action:"
) )
# Load the CLIP confidence table once, when this module is first executed.
# NOTE: the path is relative, so it is resolved against the current working
# directory of the process, not against this file's location.
print("Load CLIP confidence file...")
# is_found() indexes this as confidences[scan][vp][obj][clip_target].
confidences = load_json('../datasets/REVERIE/annotations/confidence.json')
print("Loaded")
def is_found(scan, vp, clip_target, *, threshold=None, table=None):
    """Return True if any object at a viewpoint reaches the confidence threshold.

    Looks up ``table[scan][vp]`` and checks, for each object key in it, whether
    ``table[scan][vp][obj][clip_target] >= threshold``.

    Args:
        scan: Key of the scan in the confidence table.
        vp: Key of the viewpoint within that scan.
        clip_target: Key of the target whose confidence is read per object.
        threshold: Optional override of the module-level ``THRESHOLD``.
        table: Optional override of the module-level ``confidences`` table.

    Returns:
        True when at least one object meets the threshold, False otherwise
        (including when the viewpoint has no objects).

    Raises:
        KeyError: If ``scan``, ``vp``, or ``clip_target`` (for an object that
            is inspected) is missing from the table.
    """
    # Resolve the defaults lazily so explicit arguments never touch the globals.
    if threshold is None:
        threshold = THRESHOLD
    if table is None:
        table = confidences

    per_object = table[scan][vp]
    # any() stops at the first hit instead of scanning every remaining object.
    return any(per_object[obj][clip_target] >= threshold for obj in per_object)
class NavGPTOutputParser(AgentOutputParser): class NavGPTOutputParser(AgentOutputParser):
"""MRKL Output parser for the chat agent.""" """MRKL Output parser for the chat agent."""
@ -759,9 +775,11 @@ class NavGPTAgent(BaseAgent):
print(obs[0]['start']) print(obs[0]['start'])
print(obs[0]['target']) print(obs[0]['target'])
print(obs[0]['new_reverie_id']) print(obs[0]['new_reverie_id'])
print(obs[0]['clip_target'])
NOW_LOCATION = obs[0]['start'] NOW_LOCATION = obs[0]['start']
print("==") print("==")
# Initialize the trajectory # Initialize the trajectory

View File

@ -45,3 +45,8 @@ def construct_reverie_instrs(anno_dir, dataset, splits):
del new_item['instr_encodings'] del new_item['instr_encodings']
data.append(new_item) data.append(new_item)
return data return data
def load_json(f):
    """Parse the JSON file at ``f`` and return the decoded value.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Args:
        f: Path of the JSON file to read (anything ``open`` accepts).

    Returns:
        The Python object produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(f, encoding="utf-8") as fp:
        return json.load(fp)

View File

@ -144,7 +144,8 @@ class Simulator(object):
heading: int, heading: int,
elevation: int, elevation: int,
start: str, start: str,
target: str target: str,
clip_target: str,
): ):
self.heading = heading self.heading = heading
self.elevation = elevation self.elevation = elevation
@ -152,6 +153,7 @@ class Simulator(object):
self.viewpoint_ID = viewpoint_ID self.viewpoint_ID = viewpoint_ID
self.start = start self.start = start
self.target = target self.target = target
self.clip_target = clip_target
# Load navigable dict # Load navigable dict
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json') navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
with open(navigable_path, 'r') as f: with open(navigable_path, 'r') as f:
@ -185,7 +187,8 @@ class Simulator(object):
'elevation': self.elevation, 'elevation': self.elevation,
'candidate': self.candidate, 'candidate': self.candidate,
'start': self.start, 'start': self.start,
'target': self.target 'target': self.target,
'clip_target': self.clip_target,
} }
return self.state return self.state
@ -230,9 +233,9 @@ class EnvBatch(object):
def _make_id(self, scanId, viewpointId): def _make_id(self, scanId, viewpointId):
return scanId + '_' + viewpointId return scanId + '_' + viewpointId
def newEpisodes(self, scanIds, viewpointIds, headings, starts, targets): def newEpisodes(self, scanIds, viewpointIds, headings, starts, targets, clip_targets):
for i, (scanId, viewpointId, heading, start, target) in enumerate(zip(scanIds, viewpointIds, headings, starts, targets)): for i, (scanId, viewpointId, heading, start, target, clip_target) in enumerate(zip(scanIds, viewpointIds, headings, starts, targets, clip_targets)):
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, start, target) self.sims[i].newEpisode(scanId, viewpointId, heading, 0, start, target, clip_target)
def getStates(self): def getStates(self):
""" """
@ -358,7 +361,8 @@ class REVERIENavBatch(object):
'path_id' : item['path_id'], 'path_id' : item['path_id'],
'start': item['start'], 'start': item['start'],
'new_reverie_id': item['new_reverie_id'], 'new_reverie_id': item['new_reverie_id'],
'target': item['target'] 'target': item['target'],
'clip_target': item['clip_target']
} }
# RL reward. The negative distance between the state and the final state # RL reward. The negative distance between the state and the final state
# There are multiple gt end viewpoints on REVERIE. # There are multiple gt end viewpoints on REVERIE.
@ -382,7 +386,8 @@ class REVERIENavBatch(object):
headings = [item['heading'] for item in self.batch] headings = [item['heading'] for item in self.batch]
starts = [item['start'] for item in self.batch] starts = [item['start'] for item in self.batch]
targets = [item['target'] for item in self.batch] targets = [item['target'] for item in self.batch]
self.env.newEpisodes(scanIds, starts, headings, starts, targets) clip_targets = [item['clip_target'] for item in self.batch]
self.env.newEpisodes(scanIds, starts, headings, starts, targets, clip_targets)
return self._get_obs() return self._get_obs()
def step(self, next_viewpoint_IDs): def step(self, next_viewpoint_IDs):