feat: complete adversarial's rollout()
This commit is contained in:
parent
b96106fa69
commit
fb82daf16a
@ -174,14 +174,14 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
|
||||
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
|
||||
batch_size = len(obs)
|
||||
print("PANO shape", pano_embeds.shape)
|
||||
# print("PANO shape", pano_embeds.shape)
|
||||
|
||||
# add [stop] token & [NOT FOUND] token
|
||||
# [STOP] 在最前面, [NOT FOUND] 在最後面
|
||||
vp_img_embeds = torch.cat(
|
||||
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
|
||||
)
|
||||
print("SHAPE:", vp_img_embeds.shape)
|
||||
# print("SHAPE:", vp_img_embeds.shape)
|
||||
|
||||
batch_vp_pos_fts = []
|
||||
for i, gmap in enumerate(gmaps):
|
||||
@ -195,15 +195,15 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
)
|
||||
# add [stop] token at beginning
|
||||
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
|
||||
print("vp_pos_fts:", vp_pos_fts.shape)
|
||||
# print("vp_pos_fts:", vp_pos_fts.shape)
|
||||
|
||||
vp_pos_fts[:, :7] = cur_start_pos_fts
|
||||
print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
|
||||
print("cur_start_pos_fts:", cur_start_pos_fts.shape)
|
||||
# print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
|
||||
# print("cur_start_pos_fts:", cur_start_pos_fts.shape)
|
||||
|
||||
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
|
||||
print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
|
||||
print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
|
||||
# print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
|
||||
# print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
|
||||
|
||||
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
|
||||
|
||||
@ -213,10 +213,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
|
||||
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
|
||||
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
|
||||
print('vp_nav_masks:', vp_nav_masks.shape)
|
||||
print('vp_obj_masks:', vp_obj_masks.shape)
|
||||
# print('vp_nav_masks:', vp_nav_masks.shape)
|
||||
# print('vp_obj_masks:', vp_obj_masks.shape)
|
||||
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
|
||||
print()
|
||||
# print()
|
||||
|
||||
return {
|
||||
'vp_img_embeds': vp_img_embeds,
|
||||
@ -331,7 +331,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
'instr_id': ob['instr_id'],
|
||||
'path': [[ob['viewpoint']]],
|
||||
'pred_objid': None,
|
||||
'gt_objid': None,
|
||||
'found': None,
|
||||
'gt_found': None,
|
||||
'details': {},
|
||||
} for ob in obs]
|
||||
|
||||
@ -441,8 +443,6 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
|
||||
# objec grounding
|
||||
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
|
||||
print("TARGET OBJECT:", obj_targets)
|
||||
print('obj logits:', obj_logits)
|
||||
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
|
||||
og_loss += self.criterion(obj_logits, obj_targets)
|
||||
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
|
||||
@ -490,9 +490,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
else:
|
||||
cpu_a_t.append(nav_vpids[i][a_t[i]])
|
||||
|
||||
original_gt_founds = [ ob['gt_found'] for ob in obs ]
|
||||
# Make action and get the new state
|
||||
self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
|
||||
for i in range(batch_size):
|
||||
traj[i]['gt_found'] = original_gt_founds[i]
|
||||
if (not ended[i]) and just_ended[i]:
|
||||
stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
|
||||
for k, v in gmaps[i].node_stop_scores.items():
|
||||
@ -502,6 +504,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
if stop_node is not None and obs[i]['viewpoint'] != stop_node:
|
||||
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
|
||||
traj[i]['pred_objid'] = stop_score['og']
|
||||
if stop_score['og'] == -1:
|
||||
traj[i]['found'] = False
|
||||
else:
|
||||
traj[i]['found'] = True
|
||||
if self.args.detailed_output:
|
||||
for k, v in gmaps[i].node_stop_scores.items():
|
||||
traj[i]['details'][k] = {
|
||||
@ -531,4 +537,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
self.logs['IL_loss'].append(ml_loss.item())
|
||||
self.logs['OG_loss'].append(og_loss.item())
|
||||
|
||||
'''
|
||||
print("TRAJ:")
|
||||
for i in traj:
|
||||
print(" GT: {}, PREDICT: {}, SCORE: {}".format(i['gt_found'], i['found'], 1 if i['gt_found']==i['found'] else 0))
|
||||
|
||||
'''
|
||||
return traj
|
||||
|
||||
@ -112,7 +112,6 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
||||
)
|
||||
|
||||
# first evaluation
|
||||
'''
|
||||
if args.eval_first:
|
||||
loss_str = "validation before training"
|
||||
for env_name, env in val_envs.items():
|
||||
@ -136,7 +135,6 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
||||
write_to_record_file(
|
||||
'\nListener training starts, start iteration: %s' % str(start_iter), record_file
|
||||
)
|
||||
'''
|
||||
|
||||
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user