fix: argmax of obj_logits, -1 if not found
This commit is contained in:
parent
1b731a14a3
commit
5e424ede40
@ -399,10 +399,19 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
i_vp = obs[i]['viewpoint']
|
i_vp = obs[i]['viewpoint']
|
||||||
# update i_vp: stop and object grounding scores
|
# update i_vp: stop and object grounding scores
|
||||||
i_objids = obs[i]['obj_ids']
|
i_objids = obs[i]['obj_ids']
|
||||||
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:]
|
i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] # 最後一個是 not found
|
||||||
|
|
||||||
|
if len(i_objids) > 0:
|
||||||
|
if torch.argmax(i_obj_logits) >= len(i_objids): # not found 那格 logit 最大(會在最後一格)
|
||||||
|
og = -1
|
||||||
|
else:
|
||||||
|
og = i_objids[torch.argmax(i_obj_logits)]
|
||||||
|
else:
|
||||||
|
og = None
|
||||||
|
|
||||||
gmap.node_stop_scores[i_vp] = {
|
gmap.node_stop_scores[i_vp] = {
|
||||||
'stop': nav_probs[i, 0].data.item(),
|
'stop': nav_probs[i, 0].data.item(),
|
||||||
'og': i_objids[torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None,
|
'og': og,
|
||||||
'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
|
'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user