From 1b731a14a3d815903b605c93bf44aacbcafc300e Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Mon, 11 Dec 2023 03:18:19 +0800 Subject: [PATCH] feat: add not found token in pano features & masks --- map_nav_src/reverie/agent_obj.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/map_nav_src/reverie/agent_obj.py b/map_nav_src/reverie/agent_obj.py index d343d9d..05a3eb6 100644 --- a/map_nav_src/reverie/agent_obj.py +++ b/map_nav_src/reverie/agent_obj.py @@ -174,11 +174,14 @@ class GMapObjectNavAgent(Seq2SeqAgent): def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types): batch_size = len(obs) + print("PANO shape", pano_embeds.shape) - # add [stop] token + # add [stop] token & [NOT FOUND] token + # [STOP] 在最前面, [NOT FOUND] 在最後面 vp_img_embeds = torch.cat( - [torch.zeros_like(pano_embeds[:, :1]), pano_embeds], 1 + [torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1 ) + print("SHAPE:", vp_img_embeds.shape) batch_vp_pos_fts = [] for i, gmap in enumerate(gmaps): @@ -192,19 +195,34 @@ class GMapObjectNavAgent(Seq2SeqAgent): ) # add [stop] token at beginning vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32) + print("vp_pos_fts:", vp_pos_fts.shape) + vp_pos_fts[:, :7] = cur_start_pos_fts + print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape) + print("cur_start_pos_fts:", cur_start_pos_fts.shape) + vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts + print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape) + print("cur_cand_pos_fts:", cur_cand_pos_fts.shape) + batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts)) batch_vp_pos_fts = pad_tensors(batch_vp_pos_fts).cuda() - vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1], 1) - vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2], 1) + # 要把 stop 和 not found 的 mask 補上去 + # 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起 + vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1) + vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1) + print('vp_nav_masks:', vp_nav_masks.shape, vp_nav_masks) + print('vp_obj_masks:', vp_obj_masks.shape, vp_obj_masks) + vp_masks = gen_seq_masks(view_lens+obj_lens+2) + print('vp_masks:', vp_masks) + print() return { 'vp_img_embeds': vp_img_embeds, 'vp_pos_fts': batch_vp_pos_fts, - 'vp_masks': gen_seq_masks(view_lens+obj_lens+1), + 'vp_masks': vp_masks, 'vp_nav_masks': vp_nav_masks, 'vp_obj_masks': vp_obj_masks, 'vp_cand_vpids': [[None]+x for x in cand_vpids], @@ -298,7 +316,8 @@ class GMapObjectNavAgent(Seq2SeqAgent): batch_size = len(obs) # build graph: keep the start viewpoint - gmaps = [GraphMap(ob['viewpoint']) for ob in obs] + + gmaps = [GraphMap(ob['viewpoint']) for ob in obs] # input the start point for i, ob in enumerate(obs): gmaps[i].update_graph(ob) @@ -307,6 +326,7 @@ class GMapObjectNavAgent(Seq2SeqAgent): 'instr_id': ob['instr_id'], 'path': [[ob['viewpoint']]], 'pred_objid': None, + 'found': None, 'details': {}, } for ob in obs]