From f6c4a4f87e9c3013a9eb50c0ac7d3eba4e95c8be Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Mon, 11 Dec 2023 04:33:37 +0800 Subject: [PATCH] feat: change _teacher_object() to allow not found token --- map_nav_src/reverie/agent_obj.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/map_nav_src/reverie/agent_obj.py b/map_nav_src/reverie/agent_obj.py index c006378..0f30bf1 100644 --- a/map_nav_src/reverie/agent_obj.py +++ b/map_nav_src/reverie/agent_obj.py @@ -213,10 +213,9 @@ class GMapObjectNavAgent(Seq2SeqAgent): # 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起 vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1) vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1) - print('vp_nav_masks:', vp_nav_masks.shape, vp_nav_masks) - print('vp_obj_masks:', vp_obj_masks.shape, vp_obj_masks) + print('vp_nav_masks:', vp_nav_masks.shape) + print('vp_obj_masks:', vp_obj_masks.shape) vp_masks = gen_seq_masks(view_lens+obj_lens+2) - print('vp_masks:', vp_masks) print() return { @@ -260,7 +259,7 @@ class GMapObjectNavAgent(Seq2SeqAgent): return torch.from_numpy(a).cuda() - def _teacher_object(self, obs, ended, view_lens): + def _teacher_object(self, obs, ended, view_lens, obj_logits): targets = np.zeros(len(obs), dtype=np.int64) for i, ob in enumerate(obs): if ended[i]: @@ -270,12 +269,18 @@ class GMapObjectNavAgent(Seq2SeqAgent): if i_vp not in ob['gt_end_vps']: targets[i] = self.args.ignoreid else: + i_objids = ob['obj_ids'] targets[i] = self.args.ignoreid for j, obj_id in enumerate(i_objids): if str(obj_id) == str(ob['gt_obj_id']): - targets[i] = j + view_lens[i] + 1 + + if ob['gt_found'] == True: # 可以找得到 + targets[i] = j + view_lens[i] + 1 + else: + targets[i] = len(obj_logits[i])-1 # 不能找到, break + return torch.from_numpy(targets).cuda() def make_equiv_action(self, a_t, gmaps, obs, traj=None): @@ -409,6 +414,9 @@ class GMapObjectNavAgent(Seq2SeqAgent): else: og = None + # 如果有找到,og 會是 object id + # 如果是 not found,og 會是 -1 + # 如果這個 viewpoint 看不到物件,og 會是 None gmap.node_stop_scores[i_vp] = { 'stop': nav_probs[i, 0].data.item(), 'og': og, @@ -432,7 +440,9 @@ class GMapObjectNavAgent(Seq2SeqAgent): ) ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local # objec grounding - obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens']) + obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits) + print("TARGET OBJECT:", obj_targets) + print('obj logits:', obj_logits) # print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id']) og_loss += self.criterion(obj_logits, obj_targets) # print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))