feat: complete NOT_FOUND but always 50:50
- Notice: RL is commented out
- Notice: the found prediction always scores about 50:50; there seem to be some bugs
parent 03a3e5b489
commit 595866c2f4
@@ -35,12 +35,12 @@ class BaseAgent(object):
         self.losses = [] # For learning agents
 
     def write_results(self):
-        output = [{'instr_id':k, 'trajectory': v} for k,v in self.results.items()]
+        output = [{'instr_id':k, 'trajectory': v[0], 'found': v[1]} for k,v in self.results.items()]
         with open(self.results_path, 'w') as f:
             json.dump(output, f)
 
     def get_results(self):
-        output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
+        output = [{'instr_id': k, 'trajectory': v[0], 'found': v[1]} for k, v in self.results.items()]
         return output
 
     def rollout(self, **args):
@@ -61,17 +61,19 @@ class BaseAgent(object):
         if iters is not None:
             # For each time, it will run the first 'iters' iterations. (It was shuffled before)
             for i in range(iters):
-                for traj in self.rollout(**kwargs):
+                traj, found = self.rollout(**kwargs)
+                for index, traj in enumerate(traj):
                     self.loss = 0
-                    self.results[traj['instr_id']] = traj['path']
+                    self.results[traj['instr_id']] = (traj['path'], found[index])
         else:   # Do a full round
             while True:
-                for traj in self.rollout(**kwargs):
+                traj, found = self.rollout(**kwargs)
+                for index, traj in enumerate(traj):
                     if traj['instr_id'] in self.results:
                         looped = True
                     else:
                         self.loss = 0
-                        self.results[traj['instr_id']] = traj['path']
+                        self.results[traj['instr_id']] = (traj['path'], found[index])
                 if looped:
                     break
 
@@ -344,8 +346,6 @@ class Seq2SeqAgent(BaseAgent):
 
             # Supervised training
             target = self._teacher_action(perm_obs, ended)
-            for i in perm_obs:
-                print(i['found'], end=' ')
             ml_loss += self.criterion(logit, target)
 
 
@@ -390,14 +390,21 @@ class Seq2SeqAgent(BaseAgent):
                     cpu_a_t[i] = -2
 
 
-            print(cpu_a_t)
 
             # Make action and get the new state
             self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found)
-            print(self.feedback, found)
+            '''
+            print(self.feedback, end=' ')
+            print(cpu_a_t, end=' ')
+            for i in perm_obs:
+                print(i['found'], end=' ')
+            print(found)
+            print()
+            '''
             obs = np.array(self.env._get_obs())
             perm_obs = obs[perm_idx]    # Perm the obs for the resu
 
+            '''
             if train_rl:
                 # Calculate the mask and reward
                 dist = np.zeros(batch_size, np.float32)
@@ -451,6 +458,7 @@ class Seq2SeqAgent(BaseAgent):
                 masks.append(mask)
                 last_dist[:] = dist
                 last_ndtw[:] = ndtw_score
+            '''
 
             # Update the finished actions
             # -1 means ended or ignored (already ended)
@@ -461,7 +469,7 @@ class Seq2SeqAgent(BaseAgent):
             if ended.all():
                 break
 
-
+        '''
         if train_rl:
             # Last action in A2C
             input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
@@ -472,7 +480,6 @@ class Seq2SeqAgent(BaseAgent):
             visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)
 
             self.vln_bert.vln_bert.config.directions = max(candidate_leng)
-            ''' Visual BERT '''
             visual_inputs = {'mode': 'visual',
                             'sentence': language_features,
                             'attention_mask': visual_attention_mask,
@@ -523,6 +530,7 @@ class Seq2SeqAgent(BaseAgent):
 
             self.loss += rl_loss
             self.logs['RL_loss'].append(rl_loss.item())
+        '''
 
         if train_ml is not None:
             self.loss += ml_loss * train_ml / batch_size
@@ -533,8 +541,7 @@ class Seq2SeqAgent(BaseAgent):
         else:
             self.losses.append(self.loss.item() / self.episode_len)    # This argument is useless.
 
-        print('\n')
-        return traj
+        return traj, found
 
     def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
        ''' Evaluate once on each instruction in the current environment '''
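
Notice: with this change, rollout() is expected to return a pair (trajectories, found), and every result record written by write_results() / get_results() gains a 'found' field next to 'trajectory'. A minimal sketch of the new record shape, with illustrative values only (the concrete id, viewpoints, and the -1/-2 encoding are assumptions read off this diff, not repo constants):

    record = {
        'instr_id': '1234_0',            # illustrative: path id + instruction index
        'trajectory': [                  # list of (view_id, heading, elevation) steps
            ('vp_start', 0.0, 0.0),
            ('vp_next', 1.57, 0.0),
        ],
        'found': -1,                     # assumed: -1 = stopped with "found", -2 = NOT_FOUND
    }
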
@@ -55,11 +55,16 @@ class Evaluation(object):
                near_d = d
        return near_id
 
-    def _score_item(self, instr_id, path):
+    def _score_item(self, instr_id, path, predict_found):
        ''' Calculate error based on the final position in trajectory, and also
            the closest position (oracle stopping rule).
            The path contains [view_id, angle, vofv] '''
        gt = self.gt[instr_id.split('_')[-2]]
+        index = int(instr_id.split('_')[-1])
+
+        gt_instruction = gt['instructions'][index]
+        gt_found = gt['found'][index]
+
        start = gt['path'][0]
        assert start == path[0][0], 'Result trajectories should include the start position'
        goal = gt['path'][-1]
@@ -68,6 +73,19 @@ class Evaluation(object):
        self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal])
        self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal])
        self.scores['trajectory_steps'].append(len(path)-1)
+
+        # <STOP> <NOT_FOUND> score
+        score = 0
+        if gt_found == True:
+            if predict_found == -1:
+                score = 1
+        else:
+            if predict_found == -2:
+                score = 1
+        self.scores['found_count'] += score
+
+
+
        distance = 0 # length of the path in meters
        prev = path[0]
        for curr in path[1:]:
@@ -81,6 +99,7 @@ class Evaluation(object):
    def score(self, output_file):
        ''' Evaluate each agent trajectory based on how close it got to the goal location '''
        self.scores = defaultdict(list)
+        self.scores['found_count'] = 0
        instr_ids = set(self.instr_ids)
        if type(output_file) is str:
            with open(output_file) as f:
@@ -90,12 +109,14 @@ class Evaluation(object):
 
        # print('result length', len(results))
        # print("RESULT:", results)
+        path_counter = 0
        for item in results:
            # Check against expected ids
            if item['instr_id'] in instr_ids:
                # print("{} exist".format(item['instr_id']))
                instr_ids.remove(item['instr_id'])
-                self._score_item(item['instr_id'], item['trajectory'])
+                self._score_item(item['instr_id'], item['trajectory'], item['found'])
+                path_counter += 1
            else:
                print("{} not exist".format(item['instr_id']))
                print(item)
@@ -108,7 +129,8 @@ class Evaluation(object):
            'nav_error': np.average(self.scores['nav_errors']),
            'oracle_error': np.average(self.scores['oracle_errors']),
            'steps': np.average(self.scores['trajectory_steps']),
-            'lengths': np.average(self.scores['trajectory_lengths'])
+            'lengths': np.average(self.scores['trajectory_lengths']),
+            'found_score': self.scores['found_count'] / path_counter
        }
        num_successes = len([i for i in self.scores['nav_errors'] if i < self.error_margin])
        score_summary['success_rate'] = float(num_successes)/float(len(self.scores['nav_errors']))
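
Notice: found_score, as added above, is the fraction of items where the predicted stop type matches the ground-truth found flag. A minimal standalone sketch of the same rule, assuming the -1/-2 stop encoding read off this diff (the found_score helper below is illustrative and not part of the repo):

    def found_score(items):
        # items: iterable of (gt_found, predict_found) pairs,
        # with the assumed encoding -1 = stopped with "found", -2 = NOT_FOUND
        found_count = 0
        path_counter = 0
        for gt_found, predict_found in items:
            if gt_found:
                found_count += int(predict_found == -1)
            else:
                found_count += int(predict_found == -2)
            path_counter += 1
        return found_count / path_counter

    # If the agent emits the same stop type for every item, the score collapses to the
    # base rate of found == True in the split, which on a roughly balanced split is the
    # "always 50:50" behaviour named in the commit message.
    print(found_score([(True, -1), (False, -1), (True, -1), (False, -2)]))  # 0.75
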
@@ -105,6 +105,9 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
 
        # Run validation
        loss_str = "iter {}".format(iter)
+
+
+        save_results = []
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env
 
@@ -112,6 +115,8 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
            listner.test(use_dropout=False, feedback='argmax', iters=None)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
+
+            print(score_summary)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                if metric in ['spl']:
@@ -199,7 +204,7 @@ def train_val(test_only=False):
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
-        val_env_names = ['val_unseen']
+        val_env_names = ['train','val_unseen']
 
    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
    from collections import OrderedDict
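
Notice: with print(score_summary) enabled, each validation split now also prints the new metric; an illustrative example of the printed dict (all numbers made up):

    # {'nav_error': 5.1, 'oracle_error': 3.8, 'steps': 6.2, 'lengths': 10.4,
    #  'found_score': 0.5, 'success_rate': 0.45, ...}
    # found_score stuck near 0.5 regardless of training is the "always 50:50" symptom.
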