fix: remove obj loss

This commit is contained in:
Ting-Jun Wang 2024-07-16 13:49:22 +08:00
parent 2a561bcf01
commit 287a35965e
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -443,9 +443,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
)
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
# objec grounding
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
# obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
og_loss += self.criterion(obj_logits, obj_targets)
# og_loss += self.criterion(obj_logits, obj_targets)
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
# print(t, 'og_loss', og_loss.item(), self.criterion(obj_logits, obj_targets).item())
@ -532,11 +532,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if train_ml is not None:
ml_loss = ml_loss * train_ml / batch_size
og_loss = og_loss * train_ml / batch_size
# og_loss = og_loss * train_ml / batch_size
self.loss += ml_loss
self.loss += og_loss
# self.loss += og_loss
self.logs['IL_loss'].append(ml_loss.item())
self.logs['OG_loss'].append(og_loss.item())
# self.logs['OG_loss'].append(og_loss.item())
'''
print("TRAJ:")