Compare commits

..

No commits in common. "b96106fa691f243c5a8b162115d6227e30133cd0" and "5e424ede40aedbcce5d4d84cf1607d060bd95f76" have entirely different histories.

2 changed files with 6 additions and 18 deletions

View File

@ -213,9 +213,10 @@ class GMapObjectNavAgent(Seq2SeqAgent):
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起 # 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1) vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1) vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
print('vp_nav_masks:', vp_nav_masks.shape) print('vp_nav_masks:', vp_nav_masks.shape, vp_nav_masks)
print('vp_obj_masks:', vp_obj_masks.shape) print('vp_obj_masks:', vp_obj_masks.shape, vp_obj_masks)
vp_masks = gen_seq_masks(view_lens+obj_lens+2) vp_masks = gen_seq_masks(view_lens+obj_lens+2)
print('vp_masks:', vp_masks)
print() print()
return { return {
@ -259,7 +260,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
return torch.from_numpy(a).cuda() return torch.from_numpy(a).cuda()
def _teacher_object(self, obs, ended, view_lens, obj_logits): def _teacher_object(self, obs, ended, view_lens):
targets = np.zeros(len(obs), dtype=np.int64) targets = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs): for i, ob in enumerate(obs):
if ended[i]: if ended[i]:
@ -269,18 +270,12 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if i_vp not in ob['gt_end_vps']: if i_vp not in ob['gt_end_vps']:
targets[i] = self.args.ignoreid targets[i] = self.args.ignoreid
else: else:
i_objids = ob['obj_ids'] i_objids = ob['obj_ids']
targets[i] = self.args.ignoreid targets[i] = self.args.ignoreid
for j, obj_id in enumerate(i_objids): for j, obj_id in enumerate(i_objids):
if str(obj_id) == str(ob['gt_obj_id']): if str(obj_id) == str(ob['gt_obj_id']):
targets[i] = j + view_lens[i] + 1
if ob['gt_found'] == True: # 可以找得到
targets[i] = j + view_lens[i] + 1
else:
targets[i] = len(obj_logits[i])-1 # 不能找到,
break break
return torch.from_numpy(targets).cuda() return torch.from_numpy(targets).cuda()
def make_equiv_action(self, a_t, gmaps, obs, traj=None): def make_equiv_action(self, a_t, gmaps, obs, traj=None):
@ -414,9 +409,6 @@ class GMapObjectNavAgent(Seq2SeqAgent):
else: else:
og = None og = None
# 如果有找到og 會是 object id
# 如果是 not foundog 會是 -1
# 如果這個 viewpoint 看不到物件og 會是 None
gmap.node_stop_scores[i_vp] = { gmap.node_stop_scores[i_vp] = {
'stop': nav_probs[i, 0].data.item(), 'stop': nav_probs[i, 0].data.item(),
'og': og, 'og': og,
@ -440,9 +432,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
) )
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
# objec grounding # objec grounding
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits) obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'])
print("TARGET OBJECT:", obj_targets)
print('obj logits:', obj_logits)
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id']) # print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
og_loss += self.criterion(obj_logits, obj_targets) og_loss += self.criterion(obj_logits, obj_targets)
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none')) # print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))

View File

@ -112,7 +112,6 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
) )
# first evaluation # first evaluation
'''
if args.eval_first: if args.eval_first:
loss_str = "validation before training" loss_str = "validation before training"
for env_name, env in val_envs.items(): for env_name, env in val_envs.items():
@ -136,7 +135,6 @@ def train(args, train_env, val_envs, aug_env=None, rank=-1):
write_to_record_file( write_to_record_file(
'\nListener training starts, start iteration: %s' % str(start_iter), record_file '\nListener training starts, start iteration: %s' % str(start_iter), record_file
) )
'''
best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}} best_val = {'val_unseen': {"spl": 0., "sr": 0., "state":""}}