feat: change _teacher_object() to allow not found token
This commit is contained in:
parent
5e424ede40
commit
f6c4a4f87e
@ -213,10 +213,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
|
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
|
||||||
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
|
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
|
||||||
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
|
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
|
||||||
print('vp_nav_masks:', vp_nav_masks.shape, vp_nav_masks)
|
print('vp_nav_masks:', vp_nav_masks.shape)
|
||||||
print('vp_obj_masks:', vp_obj_masks.shape, vp_obj_masks)
|
print('vp_obj_masks:', vp_obj_masks.shape)
|
||||||
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
|
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
|
||||||
print('vp_masks:', vp_masks)
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -260,7 +259,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
|
|
||||||
return torch.from_numpy(a).cuda()
|
return torch.from_numpy(a).cuda()
|
||||||
|
|
||||||
def _teacher_object(self, obs, ended, view_lens):
|
def _teacher_object(self, obs, ended, view_lens, obj_logits):
|
||||||
targets = np.zeros(len(obs), dtype=np.int64)
|
targets = np.zeros(len(obs), dtype=np.int64)
|
||||||
for i, ob in enumerate(obs):
|
for i, ob in enumerate(obs):
|
||||||
if ended[i]:
|
if ended[i]:
|
||||||
@ -270,12 +269,18 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
if i_vp not in ob['gt_end_vps']:
|
if i_vp not in ob['gt_end_vps']:
|
||||||
targets[i] = self.args.ignoreid
|
targets[i] = self.args.ignoreid
|
||||||
else:
|
else:
|
||||||
|
|
||||||
i_objids = ob['obj_ids']
|
i_objids = ob['obj_ids']
|
||||||
targets[i] = self.args.ignoreid
|
targets[i] = self.args.ignoreid
|
||||||
for j, obj_id in enumerate(i_objids):
|
for j, obj_id in enumerate(i_objids):
|
||||||
if str(obj_id) == str(ob['gt_obj_id']):
|
if str(obj_id) == str(ob['gt_obj_id']):
|
||||||
targets[i] = j + view_lens[i] + 1
|
|
||||||
|
if ob['gt_found'] == True: # 可以找得到
|
||||||
|
targets[i] = j + view_lens[i] + 1
|
||||||
|
else:
|
||||||
|
targets[i] = len(obj_logits[i])-1 # 不能找到,
|
||||||
break
|
break
|
||||||
|
|
||||||
return torch.from_numpy(targets).cuda()
|
return torch.from_numpy(targets).cuda()
|
||||||
|
|
||||||
def make_equiv_action(self, a_t, gmaps, obs, traj=None):
|
def make_equiv_action(self, a_t, gmaps, obs, traj=None):
|
||||||
@ -409,6 +414,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
else:
|
else:
|
||||||
og = None
|
og = None
|
||||||
|
|
||||||
|
# 如果有找到,og 會是 object id
|
||||||
|
# 如果是 not found,og 會是 -1
|
||||||
|
# 如果這個 viewpoint 看不到物件,og 會是 None
|
||||||
gmap.node_stop_scores[i_vp] = {
|
gmap.node_stop_scores[i_vp] = {
|
||||||
'stop': nav_probs[i, 0].data.item(),
|
'stop': nav_probs[i, 0].data.item(),
|
||||||
'og': og,
|
'og': og,
|
||||||
@ -432,7 +440,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
)
|
)
|
||||||
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
|
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
|
||||||
# objec grounding
|
# objec grounding
|
||||||
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'])
|
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
|
||||||
|
print("TARGET OBJECT:", obj_targets)
|
||||||
|
print('obj logits:', obj_logits)
|
||||||
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
|
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
|
||||||
og_loss += self.criterion(obj_logits, obj_targets)
|
og_loss += self.criterion(obj_logits, obj_targets)
|
||||||
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
|
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user