feat: load confidence & clip target
This commit is contained in:
parent
cd2e0a30e4
commit
1547974692
@ -26,6 +26,8 @@ from langchain.schema import (
|
||||
)
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
||||
from data_utils import load_json
|
||||
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from prompt.planner_prompt import (
|
||||
ACTION_PROMPT,
|
||||
@ -52,6 +54,8 @@ TEMP_STEPS_COUNTER = 0
|
||||
STEPS_COUNTER = 0
|
||||
NOW_LOCATION = None
|
||||
|
||||
THRESHOLD = 0.2812
|
||||
|
||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
|
||||
"Invalid Format: Missing 'Action:' after 'Thought:"
|
||||
)
|
||||
@ -62,6 +66,18 @@ FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
|
||||
"Parsing LLM output produced both a final answer and a parse-able action:"
|
||||
)
|
||||
|
||||
print("Load CLIP confidence file...")
|
||||
confidences = load_json('../datasets/REVERIE/annotations/confidence.json')
|
||||
print("Loaded")
|
||||
|
||||
def is_found(scan, vp, clip_target):
|
||||
found = False
|
||||
for obj in confidences[scan][vp]:
|
||||
prob = confidences[scan][vp][obj][clip_target]
|
||||
|
||||
if prob >= THRESHOLD:
|
||||
found = True
|
||||
return found
|
||||
|
||||
class NavGPTOutputParser(AgentOutputParser):
|
||||
"""MRKL Output parser for the chat agent."""
|
||||
@ -759,9 +775,11 @@ class NavGPTAgent(BaseAgent):
|
||||
print(obs[0]['start'])
|
||||
print(obs[0]['target'])
|
||||
print(obs[0]['new_reverie_id'])
|
||||
print(obs[0]['clip_target'])
|
||||
NOW_LOCATION = obs[0]['start']
|
||||
|
||||
|
||||
|
||||
print("==")
|
||||
|
||||
# Initialize the trajectory
|
||||
|
||||
@ -45,3 +45,8 @@ def construct_reverie_instrs(anno_dir, dataset, splits):
|
||||
del new_item['instr_encodings']
|
||||
data.append(new_item)
|
||||
return data
|
||||
|
||||
def load_json(f):
|
||||
with open(f) as fp:
|
||||
data = json.load(fp)
|
||||
return data
|
||||
|
||||
@ -144,7 +144,8 @@ class Simulator(object):
|
||||
heading: int,
|
||||
elevation: int,
|
||||
start: str,
|
||||
target: str
|
||||
target: str,
|
||||
clip_target: str,
|
||||
):
|
||||
self.heading = heading
|
||||
self.elevation = elevation
|
||||
@ -152,6 +153,7 @@ class Simulator(object):
|
||||
self.viewpoint_ID = viewpoint_ID
|
||||
self.start = start
|
||||
self.target = target
|
||||
self.clip_target = clip_target
|
||||
# Load navigable dict
|
||||
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
||||
with open(navigable_path, 'r') as f:
|
||||
@ -185,7 +187,8 @@ class Simulator(object):
|
||||
'elevation': self.elevation,
|
||||
'candidate': self.candidate,
|
||||
'start': self.start,
|
||||
'target': self.target
|
||||
'target': self.target,
|
||||
'clip_target': self.clip_target,
|
||||
}
|
||||
return self.state
|
||||
|
||||
@ -230,9 +233,9 @@ class EnvBatch(object):
|
||||
def _make_id(self, scanId, viewpointId):
|
||||
return scanId + '_' + viewpointId
|
||||
|
||||
def newEpisodes(self, scanIds, viewpointIds, headings, starts, targets):
|
||||
for i, (scanId, viewpointId, heading, start, target) in enumerate(zip(scanIds, viewpointIds, headings, starts, targets)):
|
||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, start, target)
|
||||
def newEpisodes(self, scanIds, viewpointIds, headings, starts, targets, clip_targets):
|
||||
for i, (scanId, viewpointId, heading, start, target, clip_target) in enumerate(zip(scanIds, viewpointIds, headings, starts, targets, clip_targets)):
|
||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, start, target, clip_target)
|
||||
|
||||
def getStates(self):
|
||||
"""
|
||||
@ -358,7 +361,8 @@ class REVERIENavBatch(object):
|
||||
'path_id' : item['path_id'],
|
||||
'start': item['start'],
|
||||
'new_reverie_id': item['new_reverie_id'],
|
||||
'target': item['target']
|
||||
'target': item['target'],
|
||||
'clip_target': item['clip_target']
|
||||
}
|
||||
# RL reward. The negative distance between the state and the final state
|
||||
# There are multiple gt end viewpoints on REVERIE.
|
||||
@ -382,7 +386,8 @@ class REVERIENavBatch(object):
|
||||
headings = [item['heading'] for item in self.batch]
|
||||
starts = [item['start'] for item in self.batch]
|
||||
targets = [item['target'] for item in self.batch]
|
||||
self.env.newEpisodes(scanIds, starts, headings, starts, targets)
|
||||
clip_targets = [item['clip_target'] for item in self.batch]
|
||||
self.env.newEpisodes(scanIds, starts, headings, starts, targets, clip_targets)
|
||||
return self._get_obs()
|
||||
|
||||
def step(self, next_viewpoint_IDs):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user