feat: complete but still exists bugs
This commit is contained in:
parent
d02bb2332c
commit
da0640ab06
@ -38,7 +38,7 @@ class BaseAgent(object):
|
||||
json.dump(output, f)
|
||||
|
||||
def get_results(self):
|
||||
output = [{'instr_id': k, 'trajectory': v, 'ref': r} for k, (v,r) in self.results.items()]
|
||||
output = [{'instr_id': k, 'trajectory': v, 'ref': r, 'found': found} for k, (v,r, found) in self.results.items()]
|
||||
return output
|
||||
|
||||
def rollout(self, **args):
|
||||
@ -59,17 +59,19 @@ class BaseAgent(object):
|
||||
if iters is not None:
|
||||
# For each time, it will run the first 'iters' iterations. (It was shuffled before)
|
||||
for i in range(iters):
|
||||
for traj in self.rollout(**kwargs):
|
||||
trajs, found = self.rollout(**kwargs)
|
||||
for index, traj in enumerate(trajs):
|
||||
self.loss = 0
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['ref'])
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['ref'], found[index])
|
||||
else: # Do a full round
|
||||
while True:
|
||||
for traj in self.rollout(**kwargs):
|
||||
trajs, found = self.rollout(**kwargs)
|
||||
for index, traj in enumerate(trajs):
|
||||
if traj['instr_id'] in self.results:
|
||||
looped = True
|
||||
else:
|
||||
self.loss = 0
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['ref'])
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['ref'], found[index])
|
||||
if looped:
|
||||
break
|
||||
|
||||
@ -154,15 +156,21 @@ class Seq2SeqAgent(BaseAgent):
|
||||
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
|
||||
|
||||
def _candidate_variable(self, obs):
|
||||
candidate_leng = [len(ob['candidate']) for ob in obs]
|
||||
candidate_leng = [len(ob['candidate'])+1 for ob in obs]
|
||||
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||
|
||||
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
|
||||
# which is zero in my implementation
|
||||
for i, ob in enumerate(obs):
|
||||
for j, cc in enumerate(ob['candidate']):
|
||||
candidate_feat[i, j, :] = cc['feature']
|
||||
result = torch.from_numpy(candidate_feat)
|
||||
for i, ob in enumerate(obs):
|
||||
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
|
||||
|
||||
result = result.cuda()
|
||||
|
||||
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
|
||||
return result, candidate_leng
|
||||
|
||||
def _object_variable(self, obs):
|
||||
cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs] # +1 is for no REF
|
||||
@ -190,7 +198,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
|
||||
|
||||
def _teacher_action(self, obs, ended, cand_size):
|
||||
def _teacher_action(self, obs, ended, cand_size, candidate_leng):
|
||||
"""
|
||||
Extract teacher actions into variable.
|
||||
:param obs: The observation.
|
||||
@ -208,7 +216,11 @@ class Seq2SeqAgent(BaseAgent):
|
||||
break
|
||||
else: # Stop here
|
||||
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
|
||||
a[i] = cand_size - 1
|
||||
if ob['found']:
|
||||
a[i] = cand_size - 1
|
||||
else:
|
||||
a[i] = candidate_leng[i] - 1
|
||||
|
||||
return torch.from_numpy(a).cuda()
|
||||
|
||||
def _teacher_REF(self, obs, just_ended):
|
||||
@ -242,7 +254,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
for i, idx in enumerate(perm_idx):
|
||||
action = a_t[i]
|
||||
if action != -1: # -1 is the <stop> action
|
||||
if action != -1 and action != -2: # -1 is the <stop> action
|
||||
select_candidate = perm_obs[i]['candidate'][action]
|
||||
src_point = perm_obs[i]['viewIndex']
|
||||
trg_point = select_candidate['pointId']
|
||||
@ -315,6 +327,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# Initialization the tracking state
|
||||
ended = np.array([False] * batch_size) # Indices match permuation of the model, not env
|
||||
just_ended = np.array([False] * batch_size)
|
||||
found = np.array([None] * batch_size)
|
||||
|
||||
# Init the logs
|
||||
rewards = []
|
||||
@ -330,6 +343,15 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)
|
||||
|
||||
|
||||
'''
|
||||
for i in candidate_feat:
|
||||
print(candidate_leng)
|
||||
for j in i:
|
||||
print(j)
|
||||
print()
|
||||
'''
|
||||
|
||||
# the first [CLS] token, initialized by the language BERT, servers
|
||||
# as the agent's state passing through time steps
|
||||
language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
|
||||
@ -358,6 +380,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
h_t, logit, logit_REF = self.vln_bert(**visual_inputs)
|
||||
hidden_states.append(h_t)
|
||||
|
||||
|
||||
# print('time step', t)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
@ -372,7 +395,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
logit_REF.masked_fill_(candidate_mask_obj, -float('inf'))
|
||||
|
||||
# Supervised training
|
||||
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1))
|
||||
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
|
||||
ml_loss += self.criterion(logit, target)
|
||||
|
||||
# Determine next model inputs
|
||||
@ -400,7 +423,8 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# NOTE: Env action is in the perm_obs space
|
||||
cpu_a_t = a_t.cpu().numpy()
|
||||
for i, next_id in enumerate(cpu_a_t):
|
||||
if ((next_id == visual_temp_mask.size(1)) or (t == self.episode_len-1)) and (not ended[i]): # just stopped and forced stopped
|
||||
if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i]-1)) or (t == self.episode_len-1)) \
|
||||
and (not ended[i]): # just stopped and forced stopped
|
||||
just_ended[i] = True
|
||||
if self.feedback == 'argmax':
|
||||
_, ref_t = logit_REF[i].max(0)
|
||||
@ -409,8 +433,25 @@ class Seq2SeqAgent(BaseAgent):
|
||||
else:
|
||||
just_ended[i] = False
|
||||
|
||||
if (next_id == visual_temp_mask.size(1)) or (next_id == args.ignoreid) or (ended[i]): # The last action is <end>
|
||||
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
|
||||
if (next_id == args.ignoreid) or (ended[i]):
|
||||
cpu_a_t[i] = found[i]
|
||||
elif (next_id == visual_temp_mask.size(1)):
|
||||
cpu_a_t[i] = -1
|
||||
found[i] = -1
|
||||
elif (next_id == (candidate_leng[i]-1)):
|
||||
cpu_a_t[i] = -2
|
||||
found[i] = -2
|
||||
'''
|
||||
print("MODE: ", self.feedback)
|
||||
print("logit: ", logit)
|
||||
print("leng:", candidate_leng)
|
||||
print("cpu_a_t: ", cpu_a_t)
|
||||
if train_ml is not None:
|
||||
print("target: ", target)
|
||||
for i in perm_obs:
|
||||
print(i['found'], i['instructions'])
|
||||
print()
|
||||
'''
|
||||
|
||||
''' Supervised training for REF '''
|
||||
if train_ml is not None:
|
||||
@ -456,6 +497,19 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# reward[i] = -2.0
|
||||
if dist[i] < 1.0: # Correct
|
||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||
if ob['found']:
|
||||
reward[i] += 1
|
||||
else:
|
||||
reward[i] -= 2
|
||||
else: # Incorrect
|
||||
reward[i] = -2.0
|
||||
elif action_idx == -2:
|
||||
if dist[i] < 1.0: # Correct
|
||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||
if ob['found']:
|
||||
reward[i] -= 2
|
||||
else:
|
||||
reward[i] += 1
|
||||
else: # Incorrect
|
||||
reward[i] = -2.0
|
||||
else: # The action is not end
|
||||
@ -479,9 +533,15 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# Update the finished actions
|
||||
# -1 means ended or ignored (already ended)
|
||||
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
|
||||
ended[:] = np.logical_or(ended, (cpu_a_t == -2))
|
||||
|
||||
# Early exit if all ended
|
||||
target_found = [ (-1 if i['found'] else -2) for i in perm_obs ]
|
||||
if ended.all():
|
||||
'''
|
||||
if train_ml is None:
|
||||
print(target_found, found)
|
||||
'''
|
||||
break
|
||||
|
||||
if train_rl:
|
||||
@ -563,7 +623,8 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
return traj
|
||||
|
||||
return traj, found
|
||||
|
||||
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
|
||||
''' Evaluate once on each instruction in the current environment '''
|
||||
|
||||
@ -115,6 +115,7 @@ class R2RBatch():
|
||||
new_item = dict(item)
|
||||
new_item['instr_id'] = '%s_%d' % (item['id'], j)
|
||||
new_item['instructions'] = instr
|
||||
new_item['found'] = item['found'][j]
|
||||
|
||||
''' BERT tokenizer '''
|
||||
instr_tokens = tokenizer.tokenize(instr)
|
||||
@ -332,7 +333,8 @@ class R2RBatch():
|
||||
'gt_path' : item['path'],
|
||||
'path_id' : item['id'],
|
||||
'objId': str(item['objId']), # target objId
|
||||
'candidate_obj': (obj_local_pos, obj_features, candidate_objId)
|
||||
'candidate_obj': (obj_local_pos, obj_features, candidate_objId),
|
||||
'found': item['found']
|
||||
})
|
||||
if 'instr_encoding' in item:
|
||||
obs[-1]['instr_encoding'] = item['instr_encoding']
|
||||
|
||||
@ -62,11 +62,12 @@ class Evaluation(object):
|
||||
near_d = d
|
||||
return near_id
|
||||
|
||||
def _score_item(self, instr_id, path, ref_objId):
|
||||
def _score_item(self, instr_id, path, ref_objId, predict_found):
|
||||
''' Calculate error based on the final position in trajectory, and also
|
||||
the closest position (oracle stopping rule).
|
||||
The path contains [view_id, angle, vofv] '''
|
||||
gt = self.gt[instr_id[:-2]]
|
||||
index = int(instr_id.split('_')[-1])
|
||||
start = gt['path'][0]
|
||||
assert start == path[0][0], 'Result trajectories should include the start position'
|
||||
goal = gt['path'][-1]
|
||||
@ -86,6 +87,16 @@ class Evaluation(object):
|
||||
self.distances[gt['scan']][start][goal]
|
||||
)
|
||||
|
||||
# print(predict_found, gt['found'], gt['found'][index])
|
||||
|
||||
if gt['found'][index] == True:
|
||||
if predict_found == -1:
|
||||
self.scores['found_count'] += 1
|
||||
else:
|
||||
if predict_found == -2:
|
||||
self.scores['found_count'] += 1
|
||||
|
||||
|
||||
# REF sucess or not
|
||||
if ref_objId == str(gt['objId']):
|
||||
self.scores['rgs'].append(1)
|
||||
@ -111,33 +122,11 @@ class Evaluation(object):
|
||||
self.scores['oracle_visible'].append(oracle_succ)
|
||||
|
||||
|
||||
# # if self.scores['nav_errors'][-1] < self.error_margin:
|
||||
# # print('item', item)
|
||||
# ndtw_path = [k[0] for k in item['trajectory']]
|
||||
# # print('path', ndtw_path)
|
||||
#
|
||||
# path_id = item['instr_id'][:-2]
|
||||
# # print('path id', path_id)
|
||||
# path_scan_id, path_ref = self.scan_gts[path_id]
|
||||
# # print('path_scan_id', path_scan_id)
|
||||
# # print('path_ref', path_ref)
|
||||
#
|
||||
# path_act = []
|
||||
# for jdx, pid in enumerate(ndtw_path):
|
||||
# if jdx != 0:
|
||||
# if pid != path_act[-1]:
|
||||
# path_act.append(pid)
|
||||
# else:
|
||||
# path_act.append(pid)
|
||||
# # print('path act', path_act)
|
||||
#
|
||||
# ndtw_score = self.ndtw_criterion[path_scan_id](path_act, path_ref, metric='ndtw')
|
||||
# ndtw_scores.append(ndtw_score)
|
||||
# print('nDTW score: ', np.average(ndtw_scores))
|
||||
|
||||
def score(self, output_file):
|
||||
''' Evaluate each agent trajectory based on how close it got to the goal location '''
|
||||
self.scores = defaultdict(list)
|
||||
self.scores['found_count'] = 0
|
||||
instr_ids = set(self.instr_ids)
|
||||
if type(output_file) is str:
|
||||
with open(output_file) as f:
|
||||
@ -145,12 +134,15 @@ class Evaluation(object):
|
||||
else:
|
||||
results = output_file
|
||||
|
||||
|
||||
print('result length', len(results))
|
||||
path_counter = 0
|
||||
for item in results:
|
||||
# Check against expected ids
|
||||
if item['instr_id'] in instr_ids:
|
||||
instr_ids.remove(item['instr_id'])
|
||||
self._score_item(item['instr_id'], item['trajectory'], item['ref'])
|
||||
self._score_item(item['instr_id'], item['trajectory'], item['ref'], item['found'])
|
||||
path_counter += 1
|
||||
|
||||
if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial)
|
||||
assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
|
||||
@ -159,9 +151,11 @@ class Evaluation(object):
|
||||
|
||||
score_summary = {
|
||||
'steps': np.average(self.scores['trajectory_steps']),
|
||||
'lengths': np.average(self.scores['trajectory_lengths'])
|
||||
'lengths': np.average(self.scores['trajectory_lengths']),
|
||||
'found_score': self.scores['found_count'] / path_counter
|
||||
}
|
||||
end_successes = sum(self.scores['visible'])
|
||||
|
||||
score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible']))
|
||||
oracle_successes = sum(self.scores['oracle_visible'])
|
||||
score_summary['oracle_rate'] = float(oracle_successes) / float(len(self.scores['oracle_visible']))
|
||||
@ -172,7 +166,10 @@ class Evaluation(object):
|
||||
]
|
||||
score_summary['spl'] = np.average(spl)
|
||||
|
||||
assert len(self.scores['rgs']) == len(self.instr_ids)
|
||||
try:
|
||||
assert len(self.scores['rgs']) == len(self.instr_ids)
|
||||
except:
|
||||
print(len(self.scores['rgs']), len(self.instr_ids))
|
||||
num_rgs = sum(self.scores['rgs'])
|
||||
score_summary['rgs'] = float(num_rgs)/float(len(self.scores['rgs']))
|
||||
|
||||
|
||||
@ -218,7 +218,8 @@ def train_val(test_only=False):
|
||||
val_env_names = ['val_train_seen']
|
||||
else:
|
||||
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
|
||||
val_env_names = ['val_seen', 'val_unseen']
|
||||
# val_env_names = ['val_seen', 'val_unseen']
|
||||
val_env_names = ['train', 'val_unseen']
|
||||
# val_env_names = ['val_unseen']
|
||||
|
||||
train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
|
||||
@ -242,6 +243,7 @@ def train_val(test_only=False):
|
||||
|
||||
if args.train == 'listener':
|
||||
train(train_env, tok, args.iters, val_envs=val_envs)
|
||||
# train(train_env, tok, 1000, val_envs=val_envs, log_every=10)
|
||||
elif args.train == 'validlistener':
|
||||
valid(train_env, tok, val_envs=val_envs)
|
||||
else:
|
||||
|
||||
@ -582,7 +582,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_lengt
|
||||
str_format = "{0:." + str(decimals) + "f}"
|
||||
percents = str_format.format(100 * (iteration / float(total)))
|
||||
filled_length = int(round(bar_length * iteration / float(total)))
|
||||
bar = 'LL' * filled_length + '-' * (bar_length - filled_length)
|
||||
bar = 'L' * filled_length + '-' * (bar_length - filled_length)
|
||||
|
||||
sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user