fix: argmax of obj_logits, -1 if not found

This commit is contained in:
Ting-Jun Wang 2023-12-11 03:38:40 +08:00
parent 1b731a14a3
commit 5e424ede40
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -399,10 +399,19 @@ class GMapObjectNavAgent(Seq2SeqAgent):
i_vp = obs[i]['viewpoint'] i_vp = obs[i]['viewpoint']
# update i_vp: stop and object grounding scores # update i_vp: stop and object grounding scores
i_objids = obs[i]['obj_ids'] i_objids = obs[i]['obj_ids']
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] # 最後一個是 not found
if len(i_objids) > 0:
if torch.argmax(i_obj_logits) >= len(i_objids): # not found 那格 logit 最大(會在最後一格)
og = -1
else:
og = i_objids[torch.argmax(i_obj_logits)]
else:
og = None
gmap.node_stop_scores[i_vp] = { gmap.node_stop_scores[i_vp] = {
'stop': nav_probs[i, 0].data.item(), 'stop': nav_probs[i, 0].data.item(),
'og': i_objids[torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None, 'og': og,
'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]}, 'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
} }