Compare commits
2 Commits
bc8bc1b9d4
...
5e424ede40
| Author | SHA1 | Date | |
|---|---|---|---|
| | 5e424ede40 | | |
| | 1b731a14a3 | | |
@ -174,11 +174,14 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
|
||||
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
|
||||
batch_size = len(obs)
|
||||
print("PANO shape", pano_embeds.shape)
|
||||
|
||||
# add [stop] token
|
||||
# add [stop] token & [NOT FOUND] token
|
||||
# [STOP] is placed first, [NOT FOUND] is placed last
|
||||
vp_img_embeds = torch.cat(
|
||||
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds], 1
|
||||
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
|
||||
)
|
||||
print("SHAPE:", vp_img_embeds.shape)
|
||||
|
||||
batch_vp_pos_fts = []
|
||||
for i, gmap in enumerate(gmaps):
|
||||
@ -192,19 +195,34 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
)
|
||||
# add [stop] token at beginning
|
||||
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
|
||||
print("vp_pos_fts:", vp_pos_fts.shape)
|
||||
|
||||
vp_pos_fts[:, :7] = cur_start_pos_fts
|
||||
print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
|
||||
print("cur_start_pos_fts:", cur_start_pos_fts.shape)
|
||||
|
||||
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
|
||||
print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
|
||||
print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
|
||||
|
||||
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
|
||||
|
||||
batch_vp_pos_fts = pad_tensors(batch_vp_pos_fts).cuda()
|
||||
|
||||
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1], 1)
|
||||
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2], 1)
|
||||
# the masks for stop and not found need to be added as well
|
||||
# here stop is grouped with the candidates, and not found is grouped with the objects
|
||||
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
|
||||
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
|
||||
print('vp_nav_masks:', vp_nav_masks.shape, vp_nav_masks)
|
||||
print('vp_obj_masks:', vp_obj_masks.shape, vp_obj_masks)
|
||||
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
|
||||
print('vp_masks:', vp_masks)
|
||||
print()
|
||||
|
||||
return {
|
||||
'vp_img_embeds': vp_img_embeds,
|
||||
'vp_pos_fts': batch_vp_pos_fts,
|
||||
'vp_masks': gen_seq_masks(view_lens+obj_lens+1),
|
||||
'vp_masks': vp_masks,
|
||||
'vp_nav_masks': vp_nav_masks,
|
||||
'vp_obj_masks': vp_obj_masks,
|
||||
'vp_cand_vpids': [[None]+x for x in cand_vpids],
|
||||
@ -298,7 +316,8 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
|
||||
batch_size = len(obs)
|
||||
# build graph: keep the start viewpoint
|
||||
gmaps = [GraphMap(ob['viewpoint']) for ob in obs]
|
||||
|
||||
gmaps = [GraphMap(ob['viewpoint']) for ob in obs] # input the start point
|
||||
for i, ob in enumerate(obs):
|
||||
gmaps[i].update_graph(ob)
|
||||
|
||||
@ -307,6 +326,7 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
'instr_id': ob['instr_id'],
|
||||
'path': [[ob['viewpoint']]],
|
||||
'pred_objid': None,
|
||||
'found': None,
|
||||
'details': {},
|
||||
} for ob in obs]
|
||||
|
||||
@ -379,10 +399,19 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
i_vp = obs[i]['viewpoint']
|
||||
# update i_vp: stop and object grounding scores
|
||||
i_objids = obs[i]['obj_ids']
|
||||
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:]
|
||||
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] # the last entry is not found
|
||||
|
||||
if len(i_objids) > 0:
|
||||
if torch.argmax(i_obj_logits) >= len(i_objids): # the not found slot has the largest logit (it is the last slot)
|
||||
og = -1
|
||||
else:
|
||||
og = i_objids[torch.argmax(i_obj_logits)]
|
||||
else:
|
||||
og = None
|
||||
|
||||
gmap.node_stop_scores[i_vp] = {
|
||||
'stop': nav_probs[i, 0].data.item(),
|
||||
'og': i_objids[torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None,
|
||||
'og': og,
|
||||
'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user