feat: complete, but bugs still exist

This commit is contained in:
Ting-Jun Wang 2023-11-11 13:54:45 +08:00
parent d02bb2332c
commit da0640ab06
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
5 changed files with 107 additions and 45 deletions

View File

@ -38,7 +38,7 @@ class BaseAgent(object):
json.dump(output, f) json.dump(output, f)
def get_results(self): def get_results(self):
output = [{'instr_id': k, 'trajectory': v, 'ref': r} for k, (v,r) in self.results.items()] output = [{'instr_id': k, 'trajectory': v, 'ref': r, 'found': found} for k, (v,r, found) in self.results.items()]
return output return output
def rollout(self, **args): def rollout(self, **args):
@ -59,17 +59,19 @@ class BaseAgent(object):
if iters is not None: if iters is not None:
# For each time, it will run the first 'iters' iterations. (It was shuffled before) # For each time, it will run the first 'iters' iterations. (It was shuffled before)
for i in range(iters): for i in range(iters):
for traj in self.rollout(**kwargs): trajs, found = self.rollout(**kwargs)
for index, traj in enumerate(trajs):
self.loss = 0 self.loss = 0
self.results[traj['instr_id']] = (traj['path'], traj['ref']) self.results[traj['instr_id']] = (traj['path'], traj['ref'], found[index])
else: # Do a full round else: # Do a full round
while True: while True:
for traj in self.rollout(**kwargs): trajs, found = self.rollout(**kwargs)
for index, traj in enumerate(trajs):
if traj['instr_id'] in self.results: if traj['instr_id'] in self.results:
looped = True looped = True
else: else:
self.loss = 0 self.loss = 0
self.results[traj['instr_id']] = (traj['path'], traj['ref']) self.results[traj['instr_id']] = (traj['path'], traj['ref'], found[index])
if looped: if looped:
break break
@ -154,15 +156,21 @@ class Seq2SeqAgent(BaseAgent):
return Variable(torch.from_numpy(features), requires_grad=False).cuda() return Variable(torch.from_numpy(features), requires_grad=False).cuda()
def _candidate_variable(self, obs): def _candidate_variable(self, obs):
candidate_leng = [len(ob['candidate']) for ob in obs] candidate_leng = [len(ob['candidate'])+1 for ob in obs]
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32) candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END # Note: The candidate_feat at len(ob['candidate']) is the feature for the END
# which is zero in my implementation # which is zero in my implementation
for i, ob in enumerate(obs): for i, ob in enumerate(obs):
for j, cc in enumerate(ob['candidate']): for j, cc in enumerate(ob['candidate']):
candidate_feat[i, j, :] = cc['feature'] candidate_feat[i, j, :] = cc['feature']
result = torch.from_numpy(candidate_feat)
for i, ob in enumerate(obs):
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
return torch.from_numpy(candidate_feat).cuda(), candidate_leng result = result.cuda()
return result, candidate_leng
def _object_variable(self, obs): def _object_variable(self, obs):
cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs] # +1 is for no REF cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs] # +1 is for no REF
@ -190,7 +198,7 @@ class Seq2SeqAgent(BaseAgent):
return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
def _teacher_action(self, obs, ended, cand_size): def _teacher_action(self, obs, ended, cand_size, candidate_leng):
""" """
Extract teacher actions into variable. Extract teacher actions into variable.
:param obs: The observation. :param obs: The observation.
@ -208,7 +216,11 @@ class Seq2SeqAgent(BaseAgent):
break break
else: # Stop here else: # Stop here
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE" assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
if ob['found']:
a[i] = cand_size - 1 a[i] = cand_size - 1
else:
a[i] = candidate_leng[i] - 1
return torch.from_numpy(a).cuda() return torch.from_numpy(a).cuda()
def _teacher_REF(self, obs, just_ended): def _teacher_REF(self, obs, just_ended):
@ -242,7 +254,7 @@ class Seq2SeqAgent(BaseAgent):
for i, idx in enumerate(perm_idx): for i, idx in enumerate(perm_idx):
action = a_t[i] action = a_t[i]
if action != -1: # -1 is the <stop> action if action != -1 and action != -2: # -1 is the <stop> action
select_candidate = perm_obs[i]['candidate'][action] select_candidate = perm_obs[i]['candidate'][action]
src_point = perm_obs[i]['viewIndex'] src_point = perm_obs[i]['viewIndex']
trg_point = select_candidate['pointId'] trg_point = select_candidate['pointId']
@ -315,6 +327,7 @@ class Seq2SeqAgent(BaseAgent):
# Initialization the tracking state # Initialization the tracking state
ended = np.array([False] * batch_size) # Indices match permuation of the model, not env ended = np.array([False] * batch_size) # Indices match permuation of the model, not env
just_ended = np.array([False] * batch_size) just_ended = np.array([False] * batch_size)
found = np.array([None] * batch_size)
# Init the logs # Init the logs
rewards = [] rewards = []
@ -330,6 +343,15 @@ class Seq2SeqAgent(BaseAgent):
input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs) input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)
'''
for i in candidate_feat:
print(candidate_leng)
for j in i:
print(j)
print()
'''
# the first [CLS] token, initialized by the language BERT, servers # the first [CLS] token, initialized by the language BERT, servers
# as the agent's state passing through time steps # as the agent's state passing through time steps
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1) language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
@ -358,6 +380,7 @@ class Seq2SeqAgent(BaseAgent):
h_t, logit, logit_REF = self.vln_bert(**visual_inputs) h_t, logit, logit_REF = self.vln_bert(**visual_inputs)
hidden_states.append(h_t) hidden_states.append(h_t)
# print('time step', t) # print('time step', t)
# import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
@ -372,7 +395,7 @@ class Seq2SeqAgent(BaseAgent):
logit_REF.masked_fill_(candidate_mask_obj, -float('inf')) logit_REF.masked_fill_(candidate_mask_obj, -float('inf'))
# Supervised training # Supervised training
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1)) target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
ml_loss += self.criterion(logit, target) ml_loss += self.criterion(logit, target)
# Determine next model inputs # Determine next model inputs
@ -400,7 +423,8 @@ class Seq2SeqAgent(BaseAgent):
# NOTE: Env action is in the perm_obs space # NOTE: Env action is in the perm_obs space
cpu_a_t = a_t.cpu().numpy() cpu_a_t = a_t.cpu().numpy()
for i, next_id in enumerate(cpu_a_t): for i, next_id in enumerate(cpu_a_t):
if ((next_id == visual_temp_mask.size(1)) or (t == self.episode_len-1)) and (not ended[i]): # just stopped and forced stopped if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i]-1)) or (t == self.episode_len-1)) \
and (not ended[i]): # just stopped and forced stopped
just_ended[i] = True just_ended[i] = True
if self.feedback == 'argmax': if self.feedback == 'argmax':
_, ref_t = logit_REF[i].max(0) _, ref_t = logit_REF[i].max(0)
@ -409,8 +433,25 @@ class Seq2SeqAgent(BaseAgent):
else: else:
just_ended[i] = False just_ended[i] = False
if (next_id == visual_temp_mask.size(1)) or (next_id == args.ignoreid) or (ended[i]): # The last action is <end> if (next_id == args.ignoreid) or (ended[i]):
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1 cpu_a_t[i] = found[i]
elif (next_id == visual_temp_mask.size(1)):
cpu_a_t[i] = -1
found[i] = -1
elif (next_id == (candidate_leng[i]-1)):
cpu_a_t[i] = -2
found[i] = -2
'''
print("MODE: ", self.feedback)
print("logit: ", logit)
print("leng:", candidate_leng)
print("cpu_a_t: ", cpu_a_t)
if train_ml is not None:
print("target: ", target)
for i in perm_obs:
print(i['found'], i['instructions'])
print()
'''
''' Supervised training for REF ''' ''' Supervised training for REF '''
if train_ml is not None: if train_ml is not None:
@ -456,6 +497,19 @@ class Seq2SeqAgent(BaseAgent):
# reward[i] = -2.0 # reward[i] = -2.0
if dist[i] < 1.0: # Correct if dist[i] < 1.0: # Correct
reward[i] = 2.0 + ndtw_score[i] * 2.0 reward[i] = 2.0 + ndtw_score[i] * 2.0
if ob['found']:
reward[i] += 1
else:
reward[i] -= 2
else: # Incorrect
reward[i] = -2.0
elif action_idx == -2:
if dist[i] < 1.0: # Correct
reward[i] = 2.0 + ndtw_score[i] * 2.0
if ob['found']:
reward[i] -= 2
else:
reward[i] += 1
else: # Incorrect else: # Incorrect
reward[i] = -2.0 reward[i] = -2.0
else: # The action is not end else: # The action is not end
@ -479,9 +533,15 @@ class Seq2SeqAgent(BaseAgent):
# Update the finished actions # Update the finished actions
# -1 means ended or ignored (already ended) # -1 means ended or ignored (already ended)
ended[:] = np.logical_or(ended, (cpu_a_t == -1)) ended[:] = np.logical_or(ended, (cpu_a_t == -1))
ended[:] = np.logical_or(ended, (cpu_a_t == -2))
# Early exit if all ended # Early exit if all ended
target_found = [ (-1 if i['found'] else -2) for i in perm_obs ]
if ended.all(): if ended.all():
'''
if train_ml is None:
print(target_found, found)
'''
break break
if train_rl: if train_rl:
@ -563,7 +623,8 @@ class Seq2SeqAgent(BaseAgent):
# import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
return traj
return traj, found
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None): def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
''' Evaluate once on each instruction in the current environment ''' ''' Evaluate once on each instruction in the current environment '''

View File

@ -115,6 +115,7 @@ class R2RBatch():
new_item = dict(item) new_item = dict(item)
new_item['instr_id'] = '%s_%d' % (item['id'], j) new_item['instr_id'] = '%s_%d' % (item['id'], j)
new_item['instructions'] = instr new_item['instructions'] = instr
new_item['found'] = item['found'][j]
''' BERT tokenizer ''' ''' BERT tokenizer '''
instr_tokens = tokenizer.tokenize(instr) instr_tokens = tokenizer.tokenize(instr)
@ -332,7 +333,8 @@ class R2RBatch():
'gt_path' : item['path'], 'gt_path' : item['path'],
'path_id' : item['id'], 'path_id' : item['id'],
'objId': str(item['objId']), # target objId 'objId': str(item['objId']), # target objId
'candidate_obj': (obj_local_pos, obj_features, candidate_objId) 'candidate_obj': (obj_local_pos, obj_features, candidate_objId),
'found': item['found']
}) })
if 'instr_encoding' in item: if 'instr_encoding' in item:
obs[-1]['instr_encoding'] = item['instr_encoding'] obs[-1]['instr_encoding'] = item['instr_encoding']

View File

@ -62,11 +62,12 @@ class Evaluation(object):
near_d = d near_d = d
return near_id return near_id
def _score_item(self, instr_id, path, ref_objId): def _score_item(self, instr_id, path, ref_objId, predict_found):
''' Calculate error based on the final position in trajectory, and also ''' Calculate error based on the final position in trajectory, and also
the closest position (oracle stopping rule). the closest position (oracle stopping rule).
The path contains [view_id, angle, vofv] ''' The path contains [view_id, angle, vofv] '''
gt = self.gt[instr_id[:-2]] gt = self.gt[instr_id[:-2]]
index = int(instr_id.split('_')[-1])
start = gt['path'][0] start = gt['path'][0]
assert start == path[0][0], 'Result trajectories should include the start position' assert start == path[0][0], 'Result trajectories should include the start position'
goal = gt['path'][-1] goal = gt['path'][-1]
@ -86,6 +87,16 @@ class Evaluation(object):
self.distances[gt['scan']][start][goal] self.distances[gt['scan']][start][goal]
) )
# print(predict_found, gt['found'], gt['found'][index])
if gt['found'][index] == True:
if predict_found == -1:
self.scores['found_count'] += 1
else:
if predict_found == -2:
self.scores['found_count'] += 1
# REF sucess or not # REF sucess or not
if ref_objId == str(gt['objId']): if ref_objId == str(gt['objId']):
self.scores['rgs'].append(1) self.scores['rgs'].append(1)
@ -111,33 +122,11 @@ class Evaluation(object):
self.scores['oracle_visible'].append(oracle_succ) self.scores['oracle_visible'].append(oracle_succ)
# # if self.scores['nav_errors'][-1] < self.error_margin:
# # print('item', item)
# ndtw_path = [k[0] for k in item['trajectory']]
# # print('path', ndtw_path)
#
# path_id = item['instr_id'][:-2]
# # print('path id', path_id)
# path_scan_id, path_ref = self.scan_gts[path_id]
# # print('path_scan_id', path_scan_id)
# # print('path_ref', path_ref)
#
# path_act = []
# for jdx, pid in enumerate(ndtw_path):
# if jdx != 0:
# if pid != path_act[-1]:
# path_act.append(pid)
# else:
# path_act.append(pid)
# # print('path act', path_act)
#
# ndtw_score = self.ndtw_criterion[path_scan_id](path_act, path_ref, metric='ndtw')
# ndtw_scores.append(ndtw_score)
# print('nDTW score: ', np.average(ndtw_scores))
def score(self, output_file): def score(self, output_file):
''' Evaluate each agent trajectory based on how close it got to the goal location ''' ''' Evaluate each agent trajectory based on how close it got to the goal location '''
self.scores = defaultdict(list) self.scores = defaultdict(list)
self.scores['found_count'] = 0
instr_ids = set(self.instr_ids) instr_ids = set(self.instr_ids)
if type(output_file) is str: if type(output_file) is str:
with open(output_file) as f: with open(output_file) as f:
@ -145,12 +134,15 @@ class Evaluation(object):
else: else:
results = output_file results = output_file
print('result length', len(results)) print('result length', len(results))
path_counter = 0
for item in results: for item in results:
# Check against expected ids # Check against expected ids
if item['instr_id'] in instr_ids: if item['instr_id'] in instr_ids:
instr_ids.remove(item['instr_id']) instr_ids.remove(item['instr_id'])
self._score_item(item['instr_id'], item['trajectory'], item['ref']) self._score_item(item['instr_id'], item['trajectory'], item['ref'], item['found'])
path_counter += 1
if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial) if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial)
assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\ assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
@ -159,9 +151,11 @@ class Evaluation(object):
score_summary = { score_summary = {
'steps': np.average(self.scores['trajectory_steps']), 'steps': np.average(self.scores['trajectory_steps']),
'lengths': np.average(self.scores['trajectory_lengths']) 'lengths': np.average(self.scores['trajectory_lengths']),
'found_score': self.scores['found_count'] / path_counter
} }
end_successes = sum(self.scores['visible']) end_successes = sum(self.scores['visible'])
score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible'])) score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible']))
oracle_successes = sum(self.scores['oracle_visible']) oracle_successes = sum(self.scores['oracle_visible'])
score_summary['oracle_rate'] = float(oracle_successes) / float(len(self.scores['oracle_visible'])) score_summary['oracle_rate'] = float(oracle_successes) / float(len(self.scores['oracle_visible']))
@ -172,7 +166,10 @@ class Evaluation(object):
] ]
score_summary['spl'] = np.average(spl) score_summary['spl'] = np.average(spl)
try:
assert len(self.scores['rgs']) == len(self.instr_ids) assert len(self.scores['rgs']) == len(self.instr_ids)
except:
print(len(self.scores['rgs']), len(self.instr_ids))
num_rgs = sum(self.scores['rgs']) num_rgs = sum(self.scores['rgs'])
score_summary['rgs'] = float(num_rgs)/float(len(self.scores['rgs'])) score_summary['rgs'] = float(num_rgs)/float(len(self.scores['rgs']))

View File

@ -218,7 +218,8 @@ def train_val(test_only=False):
val_env_names = ['val_train_seen'] val_env_names = ['val_train_seen']
else: else:
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())]) featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
val_env_names = ['val_seen', 'val_unseen'] # val_env_names = ['val_seen', 'val_unseen']
val_env_names = ['train', 'val_unseen']
# val_env_names = ['val_unseen'] # val_env_names = ['val_unseen']
train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok) train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
@ -242,6 +243,7 @@ def train_val(test_only=False):
if args.train == 'listener': if args.train == 'listener':
train(train_env, tok, args.iters, val_envs=val_envs) train(train_env, tok, args.iters, val_envs=val_envs)
# train(train_env, tok, 1000, val_envs=val_envs, log_every=10)
elif args.train == 'validlistener': elif args.train == 'validlistener':
valid(train_env, tok, val_envs=val_envs) valid(train_env, tok, val_envs=val_envs)
else: else:

View File

@ -582,7 +582,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_lengt
str_format = "{0:." + str(decimals) + "f}" str_format = "{0:." + str(decimals) + "f}"
percents = str_format.format(100 * (iteration / float(total))) percents = str_format.format(100 * (iteration / float(total)))
filled_length = int(round(bar_length * iteration / float(total))) filled_length = int(round(bar_length * iteration / float(total)))
bar = 'LL' * filled_length + '-' * (bar_length - filled_length) bar = 'L' * filled_length + '-' * (bar_length - filled_length)
sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),