feat: complete adversarial's rollout()
This commit is contained in:
parent
b96106fa69
commit
fb82daf16a
@ -174,14 +174,14 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
|
|
||||||
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
|
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
|
||||||
batch_size = len(obs)
|
batch_size = len(obs)
|
||||||
print("PANO shape", pano_embeds.shape)
|
# print("PANO shape", pano_embeds.shape)
|
||||||
|
|
||||||
# add [stop] token & [NOT FOUND] token
|
# add [stop] token & [NOT FOUND] token
|
||||||
# [STOP] is placed at the very front, [NOT FOUND] at the very end
|
# [STOP] is placed at the very front, [NOT FOUND] at the very end
|
||||||
vp_img_embeds = torch.cat(
|
vp_img_embeds = torch.cat(
|
||||||
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
|
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
|
||||||
)
|
)
|
||||||
print("SHAPE:", vp_img_embeds.shape)
|
# print("SHAPE:", vp_img_embeds.shape)
|
||||||
|
|
||||||
batch_vp_pos_fts = []
|
batch_vp_pos_fts = []
|
||||||
for i, gmap in enumerate(gmaps):
|
for i, gmap in enumerate(gmaps):
|
||||||
@ -195,15 +195,15 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
)
|
)
|
||||||
# add [stop] token at beginning
|
# add [stop] token at beginning
|
||||||
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
|
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
|
||||||
print("vp_pos_fts:", vp_pos_fts.shape)
|
# print("vp_pos_fts:", vp_pos_fts.shape)
|
||||||
|
|
||||||
vp_pos_fts[:, :7] = cur_start_pos_fts
|
vp_pos_fts[:, :7] = cur_start_pos_fts
|
||||||
print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
|
# print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
|
||||||
print("cur_start_pos_fts:", cur_start_pos_fts.shape)
|
# print("cur_start_pos_fts:", cur_start_pos_fts.shape)
|
||||||
|
|
||||||
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
|
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
|
||||||
print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
|
# print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
|
||||||
print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
|
# print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
|
||||||
|
|
||||||
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
|
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
|
||||||
|
|
||||||
@ -213,10 +213,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
# Here, [stop] is grouped with the candidates, and [not found] is grouped with the objects
|
# Here, [stop] is grouped with the candidates, and [not found] is grouped with the objects
|
||||||
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
|
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
|
||||||
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
|
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
|
||||||
print('vp_nav_masks:', vp_nav_masks.shape)
|
# print('vp_nav_masks:', vp_nav_masks.shape)
|
||||||
print('vp_obj_masks:', vp_obj_masks.shape)
|
# print('vp_obj_masks:', vp_obj_masks.shape)
|
||||||
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
|
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
|
||||||
print()
|
# print()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'vp_img_embeds': vp_img_embeds,
|
'vp_img_embeds': vp_img_embeds,
|
||||||
@ -331,7 +331,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
'instr_id': ob['instr_id'],
|
'instr_id': ob['instr_id'],
|
||||||
'path': [[ob['viewpoint']]],
|
'path': [[ob['viewpoint']]],
|
||||||
'pred_objid': None,
|
'pred_objid': None,
|
||||||
|
'gt_objid': None,
|
||||||
'found': None,
|
'found': None,
|
||||||
|
'gt_found': None,
|
||||||
'details': {},
|
'details': {},
|
||||||
} for ob in obs]
|
} for ob in obs]
|
||||||
|
|
||||||
@ -441,8 +443,6 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
|
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
|
||||||
# object grounding
|
# object grounding
|
||||||
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
|
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
|
||||||
print("TARGET OBJECT:", obj_targets)
|
|
||||||
print('obj logits:', obj_logits)
|
|
||||||
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
|
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
|
||||||
og_loss += self.criterion(obj_logits, obj_targets)
|
og_loss += self.criterion(obj_logits, obj_targets)
|
||||||
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
|
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
|
||||||
@ -490,9 +490,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
else:
|
else:
|
||||||
cpu_a_t.append(nav_vpids[i][a_t[i]])
|
cpu_a_t.append(nav_vpids[i][a_t[i]])
|
||||||
|
|
||||||
|
original_gt_founds = [ ob['gt_found'] for ob in obs ]
|
||||||
# Make action and get the new state
|
# Make action and get the new state
|
||||||
self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
|
self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
traj[i]['gt_found'] = original_gt_founds[i]
|
||||||
if (not ended[i]) and just_ended[i]:
|
if (not ended[i]) and just_ended[i]:
|
||||||
stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
|
stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
|
||||||
for k, v in gmaps[i].node_stop_scores.items():
|
for k, v in gmaps[i].node_stop_scores.items():
|
||||||
@ -502,6 +504,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
if stop_node is not None and obs[i]['viewpoint'] != stop_node:
|
if stop_node is not None and obs[i]['viewpoint'] != stop_node:
|
||||||
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
|
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
|
||||||
traj[i]['pred_objid'] = stop_score['og']
|
traj[i]['pred_objid'] = stop_score['og']
|
||||||
|
if stop_score['og'] == -1:
|
||||||
|
traj[i]['found'] = False
|
||||||
|
else:
|
||||||
|
traj[i]['found'] = True
|
||||||
if self.args.detailed_output:
|
if self.args.detailed_output:
|
||||||
for k, v in gmaps[i].node_stop_scores.items():
|
for k, v in gmaps[i].node_stop_scores.items():
|
||||||
traj[i]['details'][k] = {
|
traj[i]['details'][k] = {
|
||||||
@ -531,4 +537,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
self.logs['IL_loss'].append(ml_loss.item())
|
self.logs['IL_loss'].append(ml_loss.item())
|
||||||
self.logs['OG_loss'].append(og_loss.item())
|
self.logs['OG_loss'].append(og_loss.item())
|
||||||
|
|
||||||
|
'''
|
||||||
|
print("TRAJ:")
|
||||||
|
for i in traj:
|
||||||
|
print(" GT: {}, PREDICT: {}, SCORE: {}".format(i['gt_found'], i['found'], 1 if i['gt_found']==i['found'] else 0))
|
||||||
|
|
||||||
|
'''
|
||||||
return traj
|
return traj
|
||||||
|
|||||||
@ -112,7 +112,6 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# first evaluation
|
# first evaluation
|
||||||
'''
|
|
||||||
if args.eval_first:
|
if args.eval_first:
|
||||||
loss_str = "validation before training"
|
loss_str = "validation before training"
|
||||||
for env_name, env in val_envs.items():
|
for env_name, env in val_envs.items():
|
||||||
@ -136,7 +135,6 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
|
|||||||
write_to_record_file(
|
write_to_record_file(
|
||||||
'\nListener training starts, start iteration: %s' % str(start_iter), record_file
|
'\nListener training starts, start iteration: %s' % str(start_iter), record_file
|
||||||
)
|
)
|
||||||
'''
|
|
||||||
|
|
||||||
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
|
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user