From 5e424ede40aedbcce5d4d84cf1607d060bd95f76 Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Mon, 11 Dec 2023 03:38:40 +0800 Subject: [PATCH] fix: argmax of obj_logits, -1 if not found --- map_nav_src/reverie/agent_obj.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/map_nav_src/reverie/agent_obj.py b/map_nav_src/reverie/agent_obj.py index 05a3eb6..c006378 100644 --- a/map_nav_src/reverie/agent_obj.py +++ b/map_nav_src/reverie/agent_obj.py @@ -399,10 +399,19 @@ class GMapObjectNavAgent(Seq2SeqAgent): i_vp = obs[i]['viewpoint'] # update i_vp: stop and object grounding scores i_objids = obs[i]['obj_ids'] - i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] + i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:] # 最後一個是 not found + + if len(i_objids) > 0: + if torch.argmax(i_obj_logits) >= len(i_objids): # not found 那格 logit 最大(會在最後一格) + og = -1 + else: + og = i_objids[torch.argmax(i_obj_logits)] + else: + og = None + gmap.node_stop_scores[i_vp] = { 'stop': nav_probs[i, 0].data.item(), - 'og': i_objids[torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None, + 'og': og, 'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]}, }