fix: remove obj loss

This commit is contained in:
Ting-Jun Wang 2024-07-16 13:49:22 +08:00
parent 2a561bcf01
commit 287a35965e
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -443,9 +443,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
) )
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
# objec grounding # objec grounding
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits) # obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id']) # print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
og_loss += self.criterion(obj_logits, obj_targets) # og_loss += self.criterion(obj_logits, obj_targets)
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none')) # print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
# print(t, 'og_loss', og_loss.item(), self.criterion(obj_logits, obj_targets).item()) # print(t, 'og_loss', og_loss.item(), self.criterion(obj_logits, obj_targets).item())
@ -532,11 +532,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
if train_ml is not None: if train_ml is not None:
ml_loss = ml_loss * train_ml / batch_size ml_loss = ml_loss * train_ml / batch_size
og_loss = og_loss * train_ml / batch_size # og_loss = og_loss * train_ml / batch_size
self.loss += ml_loss self.loss += ml_loss
self.loss += og_loss # self.loss += og_loss
self.logs['IL_loss'].append(ml_loss.item()) self.logs['IL_loss'].append(ml_loss.item())
self.logs['OG_loss'].append(og_loss.item()) # self.logs['OG_loss'].append(og_loss.item())
''' '''
print("TRAJ:") print("TRAJ:")