feat: complete NOT_FOUND but always 50:50

- Notice: the RL branch is commented out
- Notice: the found score is always 50:50; there seem to be some bugs
Ting-Jun Wang 2023-11-07 01:21:15 +08:00
parent 03a3e5b489
commit 595866c2f4
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 53 additions and 19 deletions

View File

@@ -35,12 +35,12 @@ class BaseAgent(object):
         self.losses = [] # For learning agents

     def write_results(self):
-        output = [{'instr_id':k, 'trajectory': v} for k,v in self.results.items()]
+        output = [{'instr_id':k, 'trajectory': v[0], 'found': v[1]} for k,v in self.results.items()]
         with open(self.results_path, 'w') as f:
             json.dump(output, f)

     def get_results(self):
-        output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
+        output = [{'instr_id': k, 'trajectory': v[0], 'found': v[1]} for k, v in self.results.items()]
         return output

     def rollout(self, **args):
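
With self.results now mapping each instruction id to a (path, found) tuple, every serialized record gains a 'found' field. A minimal sketch of one such record with hypothetical values: the field names come from the diff, and the -1/-2 encoding of <STOP>/<NOT_FOUND> follows the evaluation changes further down.

# Hypothetical example of one record emitted by write_results().
# All concrete values are made up for illustration.
record = {
    'instr_id': '2134_0',    # assumed "<path_id>_<instruction index>" format, per the eval diff
    'trajectory': [['vp_a', 0.0, 0.0], ['vp_b', 1.57, 0.0]],    # [view_id, angle, vofv] triples
    'found': -1,             # -1 ~ <STOP>, -2 ~ <NOT_FOUND> (encoding taken from the eval diff)
}
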
@@ -61,17 +61,19 @@ class BaseAgent(object):
         if iters is not None:
             # For each time, it will run the first 'iters' iterations. (It was shuffled before)
             for i in range(iters):
-                for traj in self.rollout(**kwargs):
+                traj, found = self.rollout(**kwargs)
+                for index, traj in enumerate(traj):
                     self.loss = 0
-                    self.results[traj['instr_id']] = traj['path']
+                    self.results[traj['instr_id']] = (traj['path'], found[index])
         else:   # Do a full round
             while True:
-                for traj in self.rollout(**kwargs):
+                traj, found = self.rollout(**kwargs)
+                for index, traj in enumerate(traj):
                     if traj['instr_id'] in self.results:
                         looped = True
                     else:
                         self.loss = 0
-                        self.results[traj['instr_id']] = traj['path']
+                        self.results[traj['instr_id']] = (traj['path'], found[index])
                 if looped:
                     break
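
The same change condensed into a standalone sketch: rollout() is assumed to return a pair (trajectories, found) of parallel lists, and each stored result now carries the agent's found/not-found prediction next to the path. Note the diff reuses the name traj for both the returned list and the loop element; the sketch renames the loop variable for clarity.

# Minimal sketch of the new rollout contract; every name except
# rollout/results/instr_id/path is ours.
def collect_results(agent):
    trajs, found = agent.rollout()        # parallel lists: found[i] belongs to trajs[i]
    for index, t in enumerate(trajs):
        agent.results[t['instr_id']] = (t['path'], found[index])
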
@@ -344,8 +346,6 @@ class Seq2SeqAgent(BaseAgent):

             # Supervised training
             target = self._teacher_action(perm_obs, ended)
-            for i in perm_obs:
-                print(i['found'], end=' ')
             ml_loss += self.criterion(logit, target)
@@ -390,14 +390,21 @@ class Seq2SeqAgent(BaseAgent):
                     cpu_a_t[i] = -2
+            print(cpu_a_t)

             # Make action and get the new state
             self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found)
+            print(self.feedback, found)
+            '''
+            print(self.feedback, end=' ')
+            print(cpu_a_t, end=' ')
+            for i in perm_obs:
+                print(i['found'], end=' ')
+            print(found)
+            print()
+            '''
             obs = np.array(self.env._get_obs())
             perm_obs = obs[perm_idx]                    # Perm the obs for the resu

+            '''
             if train_rl:
                 # Calculate the mask and reward
                 dist = np.zeros(batch_size, np.float32)
@@ -451,6 +458,7 @@ class Seq2SeqAgent(BaseAgent):
                 masks.append(mask)
                 last_dist[:] = dist
                 last_ndtw[:] = ndtw_score
+            '''

             # Update the finished actions
             # -1 means ended or ignored (already ended)
@@ -461,7 +469,7 @@ class Seq2SeqAgent(BaseAgent):
             if ended.all():
                 break

+        '''
         if train_rl:
             # Last action in A2C
             input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
@@ -472,7 +480,6 @@ class Seq2SeqAgent(BaseAgent):
             visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)

             self.vln_bert.vln_bert.config.directions = max(candidate_leng)
-            ''' Visual BERT '''
             visual_inputs = {'mode': 'visual',
                             'sentence': language_features,
                             'attention_mask': visual_attention_mask,
@@ -523,6 +530,7 @@ class Seq2SeqAgent(BaseAgent):
             self.loss += rl_loss
             self.logs['RL_loss'].append(rl_loss.item())
+        '''

         if train_ml is not None:
             self.loss += ml_loss * train_ml / batch_size
@@ -533,8 +541,7 @@ class Seq2SeqAgent(BaseAgent):
         else:
             self.losses.append(self.loss.item() / self.episode_len)    # This argument is useless.

-        print('\n')
-        return traj
+        return traj, found

     def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
         ''' Evaluate once on each instruction in the current environment '''

View File

@@ -55,11 +55,16 @@ class Evaluation(object):
                 near_d = d
         return near_id

-    def _score_item(self, instr_id, path):
+    def _score_item(self, instr_id, path, predict_found):
         ''' Calculate error based on the final position in trajectory, and also
             the closest position (oracle stopping rule).
             The path contains [view_id, angle, vofv] '''
         gt = self.gt[instr_id.split('_')[-2]]
+        index = int(instr_id.split('_')[-1])
+        gt_instruction = gt['instructions'][index]
+        gt_found = gt['found'][index]
         start = gt['path'][0]
         assert start == path[0][0], 'Result trajectories should include the start position'
         goal = gt['path'][-1]
@@ -68,6 +73,19 @@ class Evaluation(object):
         self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal])
         self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal])
         self.scores['trajectory_steps'].append(len(path)-1)
+
+        # <STOP> <NOT_FOUND> score
+        score = 0
+        if gt_found == True:
+            if predict_found == -1:
+                score = 1
+        else:
+            if predict_found == -2:
+                score = 1
+        self.scores['found_count'] += score
+
         distance = 0 # length of the path in meters
         prev = path[0]
         for curr in path[1:]:
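
Restated outside the class, the rule awards one point exactly when the prediction matches the ground-truth label. A minimal sketch, assuming the -1/-2 encoding shown above (the function name is ours):

def found_score(gt_found, predict_found):
    # gt_found True:  the correct final action is <STOP>,      encoded as -1
    # gt_found False: the correct final action is <NOT_FOUND>, encoded as -2
    if gt_found:
        return 1 if predict_found == -1 else 0
    return 1 if predict_found == -2 else 0
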
@@ -81,6 +99,7 @@ class Evaluation(object):
     def score(self, output_file):
         ''' Evaluate each agent trajectory based on how close it got to the goal location '''
         self.scores = defaultdict(list)
+        self.scores['found_count'] = 0
         instr_ids = set(self.instr_ids)
         if type(output_file) is str:
             with open(output_file) as f:
@@ -90,12 +109,14 @@ class Evaluation(object):
         # print('result length', len(results))
         # print("RESULT:", results)

+        path_counter = 0
         for item in results:
             # Check against expected ids
             if item['instr_id'] in instr_ids:
                 # print("{} exist".format(item['instr_id']))
                 instr_ids.remove(item['instr_id'])
-                self._score_item(item['instr_id'], item['trajectory'])
+                self._score_item(item['instr_id'], item['trajectory'], item['found'])
+                path_counter += 1
             else:
                 print("{} not exist".format(item['instr_id']))
                 print(item)
@@ -108,7 +129,8 @@ class Evaluation(object):
             'nav_error': np.average(self.scores['nav_errors']),
             'oracle_error': np.average(self.scores['oracle_errors']),
             'steps': np.average(self.scores['trajectory_steps']),
-            'lengths': np.average(self.scores['trajectory_lengths'])
+            'lengths': np.average(self.scores['trajectory_lengths']),
+            'found_score': self.scores['found_count'] / path_counter
         }
         num_successes = len([i for i in self.scores['nav_errors'] if i < self.error_margin])
         score_summary['success_rate'] = float(num_successes)/float(len(self.scores['nav_errors']))
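
found_score is then the mean of these per-path points (found_count / path_counter). A quick sanity check with hypothetical data shows one way the metric can pin to the 50:50 the commit message reports: a degenerate agent that always emits the same token scores exactly 0.5 whenever the ground-truth labels are balanced.

# Degenerate-predictor check (hypothetical data): the agent always answers
# <STOP> (-1) and never <NOT_FOUND> (-2).
gt_labels = [True, False, True, False]    # balanced ground-truth labels
preds = [-1] * len(gt_labels)
hits = sum(1 for g, p in zip(gt_labels, preds)
           if (g and p == -1) or (not g and p == -2))
print(hits / len(gt_labels))              # -> 0.5
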

View File

@@ -105,6 +105,9 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
         # Run validation
         loss_str = "iter {}".format(iter)

+        save_results = []
+
         for env_name, (env, evaluator) in val_envs.items():
             listner.env = env
@@ -112,6 +115,8 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
             listner.test(use_dropout=False, feedback='argmax', iters=None)
             result = listner.get_results()
             score_summary, _ = evaluator.score(result)
+            print(score_summary)
+
             loss_str += ", %s " % env_name
             for metric, val in score_summary.items():
                 if metric in ['spl']:
@@ -199,7 +204,7 @@ def train_val(test_only=False):
     else:
         featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
     # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
-    val_env_names = ['val_unseen']
+    val_env_names = ['train','val_unseen']
     train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)

     from collections import OrderedDict