feat: load confidence & clip target
This commit is contained in:
parent
cd2e0a30e4
commit
1547974692
@ -26,6 +26,8 @@ from langchain.schema import (
|
|||||||
)
|
)
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
|
||||||
|
from data_utils import load_json
|
||||||
|
|
||||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||||
from prompt.planner_prompt import (
|
from prompt.planner_prompt import (
|
||||||
ACTION_PROMPT,
|
ACTION_PROMPT,
|
||||||
@ -52,6 +54,8 @@ TEMP_STEPS_COUNTER = 0
|
|||||||
STEPS_COUNTER = 0
|
STEPS_COUNTER = 0
|
||||||
NOW_LOCATION = None
|
NOW_LOCATION = None
|
||||||
|
|
||||||
|
THRESHOLD = 0.2812
|
||||||
|
|
||||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
|
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE = (
|
||||||
"Invalid Format: Missing 'Action:' after 'Thought:"
|
"Invalid Format: Missing 'Action:' after 'Thought:"
|
||||||
)
|
)
|
||||||
@ -62,6 +66,18 @@ FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
|
|||||||
"Parsing LLM output produced both a final answer and a parse-able action:"
|
"Parsing LLM output produced both a final answer and a parse-able action:"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print("Load CLIP confidence file...")
|
||||||
|
confidences = load_json('../datasets/REVERIE/annotations/confidence.json')
|
||||||
|
print("Loaded")
|
||||||
|
|
||||||
|
def is_found(scan, vp, clip_target):
|
||||||
|
found = False
|
||||||
|
for obj in confidences[scan][vp]:
|
||||||
|
prob = confidences[scan][vp][obj][clip_target]
|
||||||
|
|
||||||
|
if prob >= THRESHOLD:
|
||||||
|
found = True
|
||||||
|
return found
|
||||||
|
|
||||||
class NavGPTOutputParser(AgentOutputParser):
|
class NavGPTOutputParser(AgentOutputParser):
|
||||||
"""MRKL Output parser for the chat agent."""
|
"""MRKL Output parser for the chat agent."""
|
||||||
@ -759,9 +775,11 @@ class NavGPTAgent(BaseAgent):
|
|||||||
print(obs[0]['start'])
|
print(obs[0]['start'])
|
||||||
print(obs[0]['target'])
|
print(obs[0]['target'])
|
||||||
print(obs[0]['new_reverie_id'])
|
print(obs[0]['new_reverie_id'])
|
||||||
|
print(obs[0]['clip_target'])
|
||||||
NOW_LOCATION = obs[0]['start']
|
NOW_LOCATION = obs[0]['start']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("==")
|
print("==")
|
||||||
|
|
||||||
# Initialize the trajectory
|
# Initialize the trajectory
|
||||||
|
|||||||
@ -45,3 +45,8 @@ def construct_reverie_instrs(anno_dir, dataset, splits):
|
|||||||
del new_item['instr_encodings']
|
del new_item['instr_encodings']
|
||||||
data.append(new_item)
|
data.append(new_item)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def load_json(f):
|
||||||
|
with open(f) as fp:
|
||||||
|
data = json.load(fp)
|
||||||
|
return data
|
||||||
|
|||||||
@ -144,7 +144,8 @@ class Simulator(object):
|
|||||||
heading: int,
|
heading: int,
|
||||||
elevation: int,
|
elevation: int,
|
||||||
start: str,
|
start: str,
|
||||||
target: str
|
target: str,
|
||||||
|
clip_target: str,
|
||||||
):
|
):
|
||||||
self.heading = heading
|
self.heading = heading
|
||||||
self.elevation = elevation
|
self.elevation = elevation
|
||||||
@ -152,6 +153,7 @@ class Simulator(object):
|
|||||||
self.viewpoint_ID = viewpoint_ID
|
self.viewpoint_ID = viewpoint_ID
|
||||||
self.start = start
|
self.start = start
|
||||||
self.target = target
|
self.target = target
|
||||||
|
self.clip_target = clip_target
|
||||||
# Load navigable dict
|
# Load navigable dict
|
||||||
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
navigable_path = os.path.join(self.navigable_dir, self.scan_ID + '_navigable.json')
|
||||||
with open(navigable_path, 'r') as f:
|
with open(navigable_path, 'r') as f:
|
||||||
@ -185,7 +187,8 @@ class Simulator(object):
|
|||||||
'elevation': self.elevation,
|
'elevation': self.elevation,
|
||||||
'candidate': self.candidate,
|
'candidate': self.candidate,
|
||||||
'start': self.start,
|
'start': self.start,
|
||||||
'target': self.target
|
'target': self.target,
|
||||||
|
'clip_target': self.clip_target,
|
||||||
}
|
}
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
@ -230,9 +233,9 @@ class EnvBatch(object):
|
|||||||
def _make_id(self, scanId, viewpointId):
|
def _make_id(self, scanId, viewpointId):
|
||||||
return scanId + '_' + viewpointId
|
return scanId + '_' + viewpointId
|
||||||
|
|
||||||
def newEpisodes(self, scanIds, viewpointIds, headings, starts, targets):
|
def newEpisodes(self, scanIds, viewpointIds, headings, starts, targets, clip_targets):
|
||||||
for i, (scanId, viewpointId, heading, start, target) in enumerate(zip(scanIds, viewpointIds, headings, starts, targets)):
|
for i, (scanId, viewpointId, heading, start, target, clip_target) in enumerate(zip(scanIds, viewpointIds, headings, starts, targets, clip_targets)):
|
||||||
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, start, target)
|
self.sims[i].newEpisode(scanId, viewpointId, heading, 0, start, target, clip_target)
|
||||||
|
|
||||||
def getStates(self):
|
def getStates(self):
|
||||||
"""
|
"""
|
||||||
@ -358,7 +361,8 @@ class REVERIENavBatch(object):
|
|||||||
'path_id' : item['path_id'],
|
'path_id' : item['path_id'],
|
||||||
'start': item['start'],
|
'start': item['start'],
|
||||||
'new_reverie_id': item['new_reverie_id'],
|
'new_reverie_id': item['new_reverie_id'],
|
||||||
'target': item['target']
|
'target': item['target'],
|
||||||
|
'clip_target': item['clip_target']
|
||||||
}
|
}
|
||||||
# RL reward. The negative distance between the state and the final state
|
# RL reward. The negative distance between the state and the final state
|
||||||
# There are multiple gt end viewpoints on REVERIE.
|
# There are multiple gt end viewpoints on REVERIE.
|
||||||
@ -382,7 +386,8 @@ class REVERIENavBatch(object):
|
|||||||
headings = [item['heading'] for item in self.batch]
|
headings = [item['heading'] for item in self.batch]
|
||||||
starts = [item['start'] for item in self.batch]
|
starts = [item['start'] for item in self.batch]
|
||||||
targets = [item['target'] for item in self.batch]
|
targets = [item['target'] for item in self.batch]
|
||||||
self.env.newEpisodes(scanIds, starts, headings, starts, targets)
|
clip_targets = [item['clip_target'] for item in self.batch]
|
||||||
|
self.env.newEpisodes(scanIds, starts, headings, starts, targets, clip_targets)
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def step(self, next_viewpoint_IDs):
|
def step(self, next_viewpoint_IDs):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user