feat: complete, but bugs still exist

Ting-Jun Wang 2023-11-11 13:54:45 +08:00
parent d02bb2332c
commit da0640ab06
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
5 changed files with 107 additions and 45 deletions

View File

@@ -38,7 +38,7 @@ class BaseAgent(object):
json.dump(output, f)
def get_results(self):
output = [{'instr_id': k, 'trajectory': v, 'ref': r} for k, (v,r) in self.results.items()]
output = [{'instr_id': k, 'trajectory': v, 'ref': r, 'found': found} for k, (v,r, found) in self.results.items()]
return output
def rollout(self, **args):
@@ -59,17 +59,19 @@ class BaseAgent(object):
if iters is not None:
# Each call runs only the first 'iters' iterations (the data was shuffled beforehand)
for i in range(iters):
for traj in self.rollout(**kwargs):
trajs, found = self.rollout(**kwargs)
for index, traj in enumerate(trajs):
self.loss = 0
self.results[traj['instr_id']] = (traj['path'], traj['ref'])
self.results[traj['instr_id']] = (traj['path'], traj['ref'], found[index])
else: # Do a full round
while True:
for traj in self.rollout(**kwargs):
trajs, found = self.rollout(**kwargs)
for index, traj in enumerate(trajs):
if traj['instr_id'] in self.results:
looped = True
else:
self.loss = 0
self.results[traj['instr_id']] = (traj['path'], traj['ref'])
self.results[traj['instr_id']] = (traj['path'], traj['ref'], found[index])
if looped:
break
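A minimal sketch (not part of the commit; toy values throughout) of the new rollout contract this hunk relies on: rollout() now returns (trajs, found), and each results entry carries the per-trajectory found flag that get_results() surfaces:

results = {}
trajs = [{'instr_id': '1234_0', 'path': [('v1', 0.0, 0.0)], 'ref': 3}]
found = [-1]  # per-trajectory stop type: -1 = stopped via END ("found"), -2 = not-found stop

for index, traj in enumerate(trajs):
    results[traj['instr_id']] = (traj['path'], traj['ref'], found[index])

output = [{'instr_id': k, 'trajectory': v, 'ref': r, 'found': f}
          for k, (v, r, f) in results.items()]
print(output)  # [{'instr_id': '1234_0', 'trajectory': [...], 'ref': 3, 'found': -1}]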
@@ -154,15 +156,21 @@ class Seq2SeqAgent(BaseAgent):
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
def _candidate_variable(self, obs):
candidate_leng = [len(ob['candidate']) for ob in obs]
candidate_leng = [len(ob['candidate'])+1 for ob in obs]
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
# Note: the slot at len(ob['candidate']) holds the feature for the END action;
# it is now filled with ones (see below) instead of zeros
for i, ob in enumerate(obs):
for j, cc in enumerate(ob['candidate']):
candidate_feat[i, j, :] = cc['feature']
result = torch.from_numpy(candidate_feat)
for i, ob in enumerate(obs):
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
result = result.cuda()
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
return result, candidate_leng
def _object_variable(self, obs):
cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs] # +1 is for no REF
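To see the END-slot change in isolation, here is a runnable sketch of _candidate_variable's new behavior (toy feature sizes; .cuda() dropped so it runs on CPU): the padded tensor gains one extra slot per sample, filled with ones rather than zeros:

import numpy as np
import torch

feature_size, angle_feat_size = 4, 2  # toy sizes; the repo takes these from the model/args
obs = [{'candidate': [{'feature': np.ones(6, dtype=np.float32)}]},
       {'candidate': []}]

candidate_leng = [len(ob['candidate']) + 1 for ob in obs]  # +1 for the END slot
candidate_feat = np.zeros((len(obs), max(candidate_leng), feature_size + angle_feat_size),
                          dtype=np.float32)
for i, ob in enumerate(obs):
    for j, cc in enumerate(ob['candidate']):
        candidate_feat[i, j, :] = cc['feature']

result = torch.from_numpy(candidate_feat)
for i, ob in enumerate(obs):
    result[i, len(ob['candidate']), :] = torch.ones(feature_size + angle_feat_size)
print(result.shape)  # torch.Size([2, 2, 6]); the END rows are all ones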
@@ -190,7 +198,7 @@ class Seq2SeqAgent(BaseAgent):
return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
def _teacher_action(self, obs, ended, cand_size):
def _teacher_action(self, obs, ended, cand_size, candidate_leng):
"""
Extract teacher actions into variable.
:param obs: The observation.
@@ -208,7 +216,11 @@ class Seq2SeqAgent(BaseAgent):
break
else: # Stop here
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
a[i] = cand_size - 1
if ob['found']:
a[i] = cand_size - 1
else:
a[i] = candidate_leng[i] - 1
return torch.from_numpy(a).cuda()
def _teacher_REF(self, obs, just_ended):
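The two-way stop supervision above can be checked standalone (toy observations; 'found' is the episode label this commit adds): a found episode is taught to pick the global END index, an unfound one the per-sample not-found slot:

import numpy as np

cand_size = 8
candidate_leng = [3, 5]
obs = [{'teacher': 'vA', 'viewpoint': 'vA', 'found': True},
       {'teacher': 'vB', 'viewpoint': 'vB', 'found': False}]

a = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs):
    assert ob['teacher'] == ob['viewpoint']  # teacher says "stay here"
    if ob['found']:
        a[i] = cand_size - 1          # stop via END
    else:
        a[i] = candidate_leng[i] - 1  # stop via the not-found slot
print(a)  # [7 4]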
@@ -242,7 +254,7 @@ class Seq2SeqAgent(BaseAgent):
for i, idx in enumerate(perm_idx):
action = a_t[i]
if action != -1: # -1 is the <stop> action
if action != -1 and action != -2: # -1 is the <stop> (found) action; -2 the <stop> (not found) action
select_candidate = perm_obs[i]['candidate'][action]
src_point = perm_obs[i]['viewIndex']
trg_point = select_candidate['pointId']
@@ -315,6 +327,7 @@ class Seq2SeqAgent(BaseAgent):
# Initialize the tracking state
ended = np.array([False] * batch_size) # Indices match the model's permutation, not the env's
just_ended = np.array([False] * batch_size)
found = np.array([None] * batch_size)
# Init the logs
rewards = []
@@ -330,6 +343,15 @@ class Seq2SeqAgent(BaseAgent):
input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)
'''
for i in candidate_feat:
print(candidate_leng)
for j in i:
print(j)
print()
'''
# the first [CLS] token, initialized by the language BERT, serves
# as the agent's state passing through time steps
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
@@ -358,6 +380,7 @@ class Seq2SeqAgent(BaseAgent):
h_t, logit, logit_REF = self.vln_bert(**visual_inputs)
hidden_states.append(h_t)
# print('time step', t)
# import pdb; pdb.set_trace()
@@ -372,7 +395,7 @@ class Seq2SeqAgent(BaseAgent):
logit_REF.masked_fill_(candidate_mask_obj, -float('inf'))
# Supervised training
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1))
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
ml_loss += self.criterion(logit, target)
# Determine next model inputs
@@ -400,7 +423,8 @@ class Seq2SeqAgent(BaseAgent):
# NOTE: Env action is in the perm_obs space
cpu_a_t = a_t.cpu().numpy()
for i, next_id in enumerate(cpu_a_t):
if ((next_id == visual_temp_mask.size(1)) or (t == self.episode_len-1)) and (not ended[i]): # just stopped and forced stopped
if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i]-1)) or (t == self.episode_len-1)) \
and (not ended[i]): # just stopped (via END or the not-found slot) or forced to stop at the last step
just_ended[i] = True
if self.feedback == 'argmax':
_, ref_t = logit_REF[i].max(0)
@@ -409,8 +433,25 @@ class Seq2SeqAgent(BaseAgent):
else:
just_ended[i] = False
if (next_id == visual_temp_mask.size(1)) or (next_id == args.ignoreid) or (ended[i]): # The last action is <end>
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
if (next_id == args.ignoreid) or (ended[i]):
cpu_a_t[i] = found[i]
elif (next_id == visual_temp_mask.size(1)):
cpu_a_t[i] = -1
found[i] = -1
elif (next_id == (candidate_leng[i]-1)):
cpu_a_t[i] = -2
found[i] = -2
'''
print("MODE: ", self.feedback)
print("logit: ", logit)
print("leng:", candidate_leng)
print("cpu_a_t: ", cpu_a_t)
if train_ml is not None:
print("target: ", target)
for i in perm_obs:
print(i['found'], i['instructions'])
print()
'''
''' Supervised training for REF '''
if train_ml is not None:
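Condensed, runnable walk-through of the new stop bookkeeping above (all sizes and ids are invented; end_index stands in for visual_temp_mask.size(1), ignoreid for args.ignoreid):

import numpy as np

ignoreid, end_index = -100, 5
candidate_leng = [4, 3, 6]
ended = np.array([False, False, True])
found = np.array([None, None, -1], dtype=object)  # third sample already stopped via END
cpu_a_t = np.array([5, 2, 0])  # raw action ids from the policy

for i, next_id in enumerate(cpu_a_t):
    if next_id == ignoreid or ended[i]:
        cpu_a_t[i] = found[i]           # keep the stop type it committed to earlier
    elif next_id == end_index:
        cpu_a_t[i] = -1; found[i] = -1  # stopped via END: claims "found"
    elif next_id == candidate_leng[i] - 1:
        cpu_a_t[i] = -2; found[i] = -2  # stopped via the extra slot: claims "not found"
print(cpu_a_t, found)  # [-1 -2 -1] [-1 -2 -1]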
@@ -456,6 +497,19 @@ class Seq2SeqAgent(BaseAgent):
# reward[i] = -2.0
if dist[i] < 1.0: # Correct
reward[i] = 2.0 + ndtw_score[i] * 2.0
if ob['found']:
reward[i] += 1
else:
reward[i] -= 2
else: # Incorrect
reward[i] = -2.0
elif action_idx == -2:
if dist[i] < 1.0: # Correct
reward[i] = 2.0 + ndtw_score[i] * 2.0
if ob['found']:
reward[i] -= 2
else:
reward[i] += 1
else: # Incorrect
reward[i] = -2.0
else: # The action is not end
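The reward branch above pays for stopping in the right place and, on top of that, for agreeing with the ground-truth found label. The helper below is not in the repo; it just condenses the branch so the arithmetic can be sanity-checked (dist/ndtw values are invented):

def stop_reward(action_idx, dist, ndtw, gt_found):
    if dist < 1.0:                       # stopped at the correct location
        r = 2.0 + ndtw * 2.0
        if action_idx == -1:             # claimed "found"
            r += 1 if gt_found else -2
        else:                            # action_idx == -2, claimed "not found"
            r += -2 if gt_found else 1
        return r
    return -2.0                          # stopped at the wrong location

print(stop_reward(-1, 0.5, 0.9, True))   # 4.8: right spot, right claim
print(stop_reward(-2, 0.5, 0.9, True))   # 1.8: right spot, wrong claim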
@@ -479,9 +533,15 @@ class Seq2SeqAgent(BaseAgent):
# Update the finished actions
# -1/-2 mean ended or ignored (already ended)
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
ended[:] = np.logical_or(ended, (cpu_a_t == -2))
# Early exit if all ended
target_found = [ (-1 if i['found'] else -2) for i in perm_obs ]
if ended.all():
'''
if train_ml is None:
print(target_found, found)
'''
break
if train_rl:
@@ -563,7 +623,8 @@ class Seq2SeqAgent(BaseAgent):
# import pdb; pdb.set_trace()
return traj
return traj, found
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
''' Evaluate once on each instruction in the current environment '''

View File

@@ -115,6 +115,7 @@ class R2RBatch():
new_item = dict(item)
new_item['instr_id'] = '%s_%d' % (item['id'], j)
new_item['instructions'] = instr
new_item['found'] = item['found'][j]
''' BERT tokenizer '''
instr_tokens = tokenizer.tokenize(instr)
@@ -332,7 +333,8 @@ class R2RBatch():
'gt_path' : item['path'],
'path_id' : item['id'],
'objId': str(item['objId']), # target objId
'candidate_obj': (obj_local_pos, obj_features, candidate_objId)
'candidate_obj': (obj_local_pos, obj_features, candidate_objId),
'found': item['found']
})
if 'instr_encoding' in item:
obs[-1]['instr_encoding'] = item['instr_encoding']
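For context, a toy sketch (field values invented) of how env.py now fans a dataset item's per-instruction found labels out into the per-instruction items built in the first hunk of this file:

item = {'id': 'p1', 'instructions': ['go left', 'go right'], 'found': [True, False]}
expanded = []
for j, instr in enumerate(item['instructions']):
    new_item = dict(item)
    new_item['instr_id'] = '%s_%d' % (item['id'], j)
    new_item['instructions'] = instr
    new_item['found'] = item['found'][j]
    expanded.append(new_item)
print([(e['instr_id'], e['found']) for e in expanded])  # [('p1_0', True), ('p1_1', False)]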

View File

@@ -62,11 +62,12 @@ class Evaluation(object):
near_d = d
return near_id
def _score_item(self, instr_id, path, ref_objId):
def _score_item(self, instr_id, path, ref_objId, predict_found):
''' Calculate error based on the final position in trajectory, and also
the closest position (oracle stopping rule).
The path contains [view_id, angle, vofv] '''
gt = self.gt[instr_id[:-2]]
index = int(instr_id.split('_')[-1])
start = gt['path'][0]
assert start == path[0][0], 'Result trajectories should include the start position'
goal = gt['path'][-1]
@@ -86,6 +87,16 @@ class Evaluation(object):
self.distances[gt['scan']][start][goal]
)
# print(predict_found, gt['found'], gt['found'][index])
if gt['found'][index]:  # ground truth: the target was found
if predict_found == -1:  # agent stopped via END, i.e. claimed "found"
self.scores['found_count'] += 1
elif predict_found == -2:  # ground truth "not found" and the agent agreed
self.scores['found_count'] += 1
# REF success or not
if ref_objId == str(gt['objId']):
self.scores['rgs'].append(1)
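The found_count rule above reduces to: a prediction scores when its stop type matches the ground-truth label (-1 for "found", -2 for "not found"). A toy check, not part of the commit:

cases = [(True, -1), (True, -2), (False, -2), (False, -1)]  # (gt_found, predict_found)
found_count = sum(1 for gt_found, pred in cases if (pred == -1) == bool(gt_found))
print(found_count)  # 2: only the first and third cases match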
@@ -111,33 +122,11 @@ class Evaluation(object):
self.scores['oracle_visible'].append(oracle_succ)
# # if self.scores['nav_errors'][-1] < self.error_margin:
# # print('item', item)
# ndtw_path = [k[0] for k in item['trajectory']]
# # print('path', ndtw_path)
#
# path_id = item['instr_id'][:-2]
# # print('path id', path_id)
# path_scan_id, path_ref = self.scan_gts[path_id]
# # print('path_scan_id', path_scan_id)
# # print('path_ref', path_ref)
#
# path_act = []
# for jdx, pid in enumerate(ndtw_path):
# if jdx != 0:
# if pid != path_act[-1]:
# path_act.append(pid)
# else:
# path_act.append(pid)
# # print('path act', path_act)
#
# ndtw_score = self.ndtw_criterion[path_scan_id](path_act, path_ref, metric='ndtw')
# ndtw_scores.append(ndtw_score)
# print('nDTW score: ', np.average(ndtw_scores))
def score(self, output_file):
''' Evaluate each agent trajectory based on how close it got to the goal location '''
self.scores = defaultdict(list)
self.scores['found_count'] = 0
instr_ids = set(self.instr_ids)
if type(output_file) is str:
with open(output_file) as f:
@@ -145,12 +134,15 @@ class Evaluation(object):
else:
results = output_file
print('result length', len(results))
path_counter = 0
for item in results:
# Check against expected ids
if item['instr_id'] in instr_ids:
instr_ids.remove(item['instr_id'])
self._score_item(item['instr_id'], item['trajectory'], item['ref'])
self._score_item(item['instr_id'], item['trajectory'], item['ref'], item['found'])
path_counter += 1
if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial)
assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
@@ -159,9 +151,11 @@ class Evaluation(object):
score_summary = {
'steps': np.average(self.scores['trajectory_steps']),
'lengths': np.average(self.scores['trajectory_lengths'])
'lengths': np.average(self.scores['trajectory_lengths']),
'found_score': self.scores['found_count'] / path_counter
}
end_successes = sum(self.scores['visible'])
score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible']))
oracle_successes = sum(self.scores['oracle_visible'])
score_summary['oracle_rate'] = float(oracle_successes) / float(len(self.scores['oracle_visible']))
@@ -172,7 +166,10 @@ class Evaluation(object):
]
score_summary['spl'] = np.average(spl)
assert len(self.scores['rgs']) == len(self.instr_ids)
try:
assert len(self.scores['rgs']) == len(self.instr_ids)
except AssertionError:
print(len(self.scores['rgs']), len(self.instr_ids))
num_rgs = sum(self.scores['rgs'])
score_summary['rgs'] = float(num_rgs)/float(len(self.scores['rgs']))

View File

@@ -218,7 +218,8 @@ def train_val(test_only=False):
val_env_names = ['val_train_seen']
else:
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
val_env_names = ['val_seen', 'val_unseen']
# val_env_names = ['val_seen', 'val_unseen']
val_env_names = ['train', 'val_unseen']
# val_env_names = ['val_unseen']
train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
@@ -242,6 +243,7 @@ def train_val(test_only=False):
if args.train == 'listener':
train(train_env, tok, args.iters, val_envs=val_envs)
# train(train_env, tok, 1000, val_envs=val_envs, log_every=10)
elif args.train == 'validlistener':
valid(train_env, tok, val_envs=val_envs)
else:

View File

@@ -582,7 +582,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_lengt
str_format = "{0:." + str(decimals) + "f}"
percents = str_format.format(100 * (iteration / float(total)))
filled_length = int(round(bar_length * iteration / float(total)))
bar = 'LL' * filled_length + '-' * (bar_length - filled_length)
bar = 'L' * filled_length + '-' * (bar_length - filled_length)
sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),