feat: change _teacher_object() to allow the "not found" token

This commit is contained in:
Ting-Jun Wang 2023-12-11 04:33:37 +08:00
parent 5e424ede40
commit f6c4a4f87e
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -213,10 +213,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起 # 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1) vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1) vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
print('vp_nav_masks:', vp_nav_masks.shape, vp_nav_masks) print('vp_nav_masks:', vp_nav_masks.shape)
print('vp_obj_masks:', vp_obj_masks.shape, vp_obj_masks) print('vp_obj_masks:', vp_obj_masks.shape)
vp_masks = gen_seq_masks(view_lens+obj_lens+2) vp_masks = gen_seq_masks(view_lens+obj_lens+2)
print('vp_masks:', vp_masks)
print() print()
return { return {
@ -260,7 +259,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
return torch.from_numpy(a).cuda() return torch.from_numpy(a).cuda()
def _teacher_object(self, obs, ended, view_lens): def _teacher_object(self, obs, ended, view_lens, obj_logits):
targets = np.zeros(len(obs), dtype=np.int64) targets = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs): for i, ob in enumerate(obs):
if ended[i]: if ended[i]:
@ -270,12 +269,18 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if i_vp not in ob['gt_end_vps']: if i_vp not in ob['gt_end_vps']:
targets[i] = self.args.ignoreid targets[i] = self.args.ignoreid
else: else:
i_objids = ob['obj_ids'] i_objids = ob['obj_ids']
targets[i] = self.args.ignoreid targets[i] = self.args.ignoreid
for j, obj_id in enumerate(i_objids): for j, obj_id in enumerate(i_objids):
if str(obj_id) == str(ob['gt_obj_id']): if str(obj_id) == str(ob['gt_obj_id']):
targets[i] = j + view_lens[i] + 1
if ob['gt_found'] == True: # 可以找得到
targets[i] = j + view_lens[i] + 1
else:
targets[i] = len(obj_logits[i])-1 # 不能找到,目標設為最後一個位置(not found token)
break break
return torch.from_numpy(targets).cuda() return torch.from_numpy(targets).cuda()
def make_equiv_action(self, a_t, gmaps, obs, traj=None): def make_equiv_action(self, a_t, gmaps, obs, traj=None):
@ -409,6 +414,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
else: else:
og = None og = None
# 如果有找到,og 會是 object id
# 如果是 not found,og 會是 -1
# 如果這個 viewpoint 看不到物件,og 會是 None
gmap.node_stop_scores[i_vp] = { gmap.node_stop_scores[i_vp] = {
'stop': nav_probs[i, 0].data.item(), 'stop': nav_probs[i, 0].data.item(),
'og': og, 'og': og,
@ -432,7 +440,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
) )
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
# object grounding # object grounding
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens']) obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
print("TARGET OBJECT:", obj_targets)
print('obj logits:', obj_logits)
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id']) # print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
og_loss += self.criterion(obj_logits, obj_targets) og_loss += self.criterion(obj_logits, obj_targets)
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none')) # print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))