feat: add not found token in pano features & masks
This commit is contained in:
parent
bc8bc1b9d4
commit
1b731a14a3
@ -174,11 +174,14 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
|
||||
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
|
||||
batch_size = len(obs)
|
||||
print("PANO shape", pano_embeds.shape)
|
||||
|
||||
# add [stop] token
|
||||
# add [stop] token & [NOT FOUND] token
|
||||
# [STOP] 在最前面, [NOT FOUND] 在最後面
|
||||
vp_img_embeds = torch.cat(
|
||||
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds], 1
|
||||
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
|
||||
)
|
||||
print("SHAPE:", vp_img_embeds.shape)
|
||||
|
||||
batch_vp_pos_fts = []
|
||||
for i, gmap in enumerate(gmaps):
|
||||
@ -192,19 +195,34 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
)
|
||||
# add [stop] token at beginning
|
||||
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
|
||||
print("vp_pos_fts:", vp_pos_fts.shape)
|
||||
|
||||
vp_pos_fts[:, :7] = cur_start_pos_fts
|
||||
print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
|
||||
print("cur_start_pos_fts:", cur_start_pos_fts.shape)
|
||||
|
||||
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
|
||||
print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
|
||||
print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
|
||||
|
||||
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
|
||||
|
||||
batch_vp_pos_fts = pad_tensors(batch_vp_pos_fts).cuda()
|
||||
|
||||
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1], 1)
|
||||
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2], 1)
|
||||
# 要把 stop 和 not found 的 mask 補上去
|
||||
# 這邊把 stop 跟 candidate 放一起、把 not found 跟 object 放一起
|
||||
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
|
||||
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
|
||||
print('vp_nav_masks:', vp_nav_masks.shape, vp_nav_masks)
|
||||
print('vp_obj_masks:', vp_obj_masks.shape, vp_obj_masks)
|
||||
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
|
||||
print('vp_masks:', vp_masks)
|
||||
print()
|
||||
|
||||
return {
|
||||
'vp_img_embeds': vp_img_embeds,
|
||||
'vp_pos_fts': batch_vp_pos_fts,
|
||||
'vp_masks': gen_seq_masks(view_lens+obj_lens+1),
|
||||
'vp_masks': vp_masks,
|
||||
'vp_nav_masks': vp_nav_masks,
|
||||
'vp_obj_masks': vp_obj_masks,
|
||||
'vp_cand_vpids': [[None]+x for x in cand_vpids],
|
||||
@ -298,7 +316,8 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
|
||||
batch_size = len(obs)
|
||||
# build graph: keep the start viewpoint
|
||||
gmaps = [GraphMap(ob['viewpoint']) for ob in obs]
|
||||
|
||||
gmaps = [GraphMap(ob['viewpoint']) for ob in obs] # input the start point
|
||||
for i, ob in enumerate(obs):
|
||||
gmaps[i].update_graph(ob)
|
||||
|
||||
@ -307,6 +326,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
'instr_id': ob['instr_id'],
|
||||
'path': [[ob['viewpoint']]],
|
||||
'pred_objid': None,
|
||||
'found': None,
|
||||
'details': {},
|
||||
} for ob in obs]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user