From 287a35965eef5e2b6af3a555ed74c82a2b89cf0d Mon Sep 17 00:00:00 2001 From: Ting-Jun Wang Date: Tue, 16 Jul 2024 13:49:22 +0800 Subject: [PATCH] fix: remove obj loss --- map_nav_src/reverie/agent_obj.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/map_nav_src/reverie/agent_obj.py b/map_nav_src/reverie/agent_obj.py index 13b092a..0322ad8 100644 --- a/map_nav_src/reverie/agent_obj.py +++ b/map_nav_src/reverie/agent_obj.py @@ -443,9 +443,9 @@ class GMapObjectNavAgent(Seq2SeqAgent): ) ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local # objec grounding - obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits) + # obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits) # print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id']) - og_loss += self.criterion(obj_logits, obj_targets) + # og_loss += self.criterion(obj_logits, obj_targets) # print(F.cross_entropy(obj_logits, obj_targets, reduction='none')) # print(t, 'og_loss', og_loss.item(), self.criterion(obj_logits, obj_targets).item()) @@ -532,11 +532,11 @@ class GMapObjectNavAgent(Seq2SeqAgent): if train_ml is not None: ml_loss = ml_loss * train_ml / batch_size - og_loss = og_loss * train_ml / batch_size + # og_loss = og_loss * train_ml / batch_size self.loss += ml_loss - self.loss += og_loss + # self.loss += og_loss self.logs['IL_loss'].append(ml_loss.item()) - self.logs['OG_loss'].append(og_loss.item()) + # self.logs['OG_loss'].append(og_loss.item()) ''' print("TRAJ:")