feat: complete adversarial's rollout()

This commit is contained in:
Ting-Jun Wang 2023-12-11 05:25:23 +08:00
parent b96106fa69
commit fb82daf16a
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
2 changed files with 24 additions and 14 deletions

View File

@ -174,14 +174,14 @@ class GMapObjectNavAgent(Seq2SeqAgent):
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
batch_size = len(obs)
print("PANO shape", pano_embeds.shape)
# print("PANO shape", pano_embeds.shape)
# add [stop] token & [NOT FOUND] token
# [STOP] 在最前面, [NOT FOUND] 在最後面
vp_img_embeds = torch.cat(
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
)
print("SHAPE:", vp_img_embeds.shape)
# print("SHAPE:", vp_img_embeds.shape)
batch_vp_pos_fts = []
for i, gmap in enumerate(gmaps):
@ -195,15 +195,15 @@ class GMapObjectNavAgent(Seq2SeqAgent):
)
# add [stop] token at beginning
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
print("vp_pos_fts:", vp_pos_fts.shape)
# print("vp_pos_fts:", vp_pos_fts.shape)
vp_pos_fts[:, :7] = cur_start_pos_fts
print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
print("cur_start_pos_fts:", cur_start_pos_fts.shape)
# print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
# print("cur_start_pos_fts:", cur_start_pos_fts.shape)
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
# print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
# print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
@ -213,10 +213,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
print('vp_nav_masks:', vp_nav_masks.shape)
print('vp_obj_masks:', vp_obj_masks.shape)
# print('vp_nav_masks:', vp_nav_masks.shape)
# print('vp_obj_masks:', vp_obj_masks.shape)
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
print()
# print()
return {
'vp_img_embeds': vp_img_embeds,
@ -331,7 +331,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
'instr_id': ob['instr_id'],
'path': [[ob['viewpoint']]],
'pred_objid': None,
'gt_objid': None,
'found': None,
'gt_found': None,
'details': {},
} for ob in obs]
@ -441,8 +443,6 @@ class GMapObjectNavAgent(Seq2SeqAgent):
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
# objec grounding
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
print("TARGET OBJECT:", obj_targets)
print('obj logits:', obj_logits)
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
og_loss += self.criterion(obj_logits, obj_targets)
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
@ -490,9 +490,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
else:
cpu_a_t.append(nav_vpids[i][a_t[i]])
original_gt_founds = [ ob['gt_found'] for ob in obs ]
# Make action and get the new state
self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
for i in range(batch_size):
traj[i]['gt_found'] = original_gt_founds[i]
if (not ended[i]) and just_ended[i]:
stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
for k, v in gmaps[i].node_stop_scores.items():
@ -502,6 +504,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if stop_node is not None and obs[i]['viewpoint'] != stop_node:
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
traj[i]['pred_objid'] = stop_score['og']
if stop_score['og'] == -1:
traj[i]['found'] = False
else:
traj[i]['found'] = True
if self.args.detailed_output:
for k, v in gmaps[i].node_stop_scores.items():
traj[i]['details'][k] = {
@ -531,4 +537,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
self.logs['IL_loss'].append(ml_loss.item())
self.logs['OG_loss'].append(og_loss.item())
'''
print("TRAJ:")
for i in traj:
print(" GT: {}, PREDICT: {}, SCORE: {}".format(i['gt_found'], i['found'], 1 if i['gt_found']==i['found'] else 0))
'''
return traj

View File

@ -112,7 +112,6 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
)
# first evaluation
'''
if args.eval_first:
loss_str = "validation before training"
for env_name, env in val_envs.items():
@ -136,7 +135,6 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
write_to_record_file(
'\nListener training starts, start iteration: %s' % str(start_iter), record_file
)
'''
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}