fix: remove obj loss
This commit is contained in:
parent
2a561bcf01
commit
287a35965e
@ -443,9 +443,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
)
|
||||
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
|
||||
# objec grounding
|
||||
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
|
||||
# obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
|
||||
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
|
||||
og_loss += self.criterion(obj_logits, obj_targets)
|
||||
# og_loss += self.criterion(obj_logits, obj_targets)
|
||||
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
|
||||
# print(t, 'og_loss', og_loss.item(), self.criterion(obj_logits, obj_targets).item())
|
||||
|
||||
@ -532,11 +532,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
||||
|
||||
if train_ml is not None:
|
||||
ml_loss = ml_loss * train_ml / batch_size
|
||||
og_loss = og_loss * train_ml / batch_size
|
||||
# og_loss = og_loss * train_ml / batch_size
|
||||
self.loss += ml_loss
|
||||
self.loss += og_loss
|
||||
# self.loss += og_loss
|
||||
self.logs['IL_loss'].append(ml_loss.item())
|
||||
self.logs['OG_loss'].append(og_loss.item())
|
||||
# self.logs['OG_loss'].append(og_loss.item())
|
||||
|
||||
'''
|
||||
print("TRAJ:")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user