feat: add not found token in pano features & masks

This commit is contained in:
Ting-Jun Wang 2023-12-11 03:18:19 +08:00
parent bc8bc1b9d4
commit 1b731a14a3
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -174,11 +174,14 @@ class GMapObjectNavAgent(Seq2SeqAgent):
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
batch_size = len(obs)
print("PANO shape", pano_embeds.shape)
# add [stop] token
# add [stop] token & [NOT FOUND] token
# [STOP] 在最前面, [NOT FOUND] 在最後面
vp_img_embeds = torch.cat(
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds], 1
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
)
print("SHAPE:", vp_img_embeds.shape)
batch_vp_pos_fts = []
for i, gmap in enumerate(gmaps):
@ -192,19 +195,34 @@ class GMapObjectNavAgent(Seq2SeqAgent):
)
# add [stop] token at beginning
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
print("vp_pos_fts:", vp_pos_fts.shape)
vp_pos_fts[:, :7] = cur_start_pos_fts
print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
print("cur_start_pos_fts:", cur_start_pos_fts.shape)
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
batch_vp_pos_fts = pad_tensors(batch_vp_pos_fts).cuda()
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1], 1)
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2], 1)
# 要把 stop 和 not found 的 mask 補上去
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
print('vp_nav_masks:', vp_nav_masks.shape, vp_nav_masks)
print('vp_obj_masks:', vp_obj_masks.shape, vp_obj_masks)
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
print('vp_masks:', vp_masks)
print()
return {
'vp_img_embeds': vp_img_embeds,
'vp_pos_fts': batch_vp_pos_fts,
'vp_masks': gen_seq_masks(view_lens+obj_lens+1),
'vp_masks': vp_masks,
'vp_nav_masks': vp_nav_masks,
'vp_obj_masks': vp_obj_masks,
'vp_cand_vpids': [[None]+x for x in cand_vpids],
@ -298,7 +316,8 @@ class GMapObjectNavAgent(Seq2SeqAgent):
batch_size = len(obs)
# build graph: keep the start viewpoint
gmaps = [GraphMap(ob['viewpoint']) for ob in obs]
gmaps = [GraphMap(ob['viewpoint']) for ob in obs] # input the start point
for i, ob in enumerate(obs):
gmaps[i].update_graph(ob)
@ -307,6 +326,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
'instr_id': ob['instr_id'],
'path': [[ob['viewpoint']]],
'pred_objid': None,
'found': None,
'details': {},
} for ob in obs]