fix: remove obj loss
This commit is contained in:
parent
2a561bcf01
commit
287a35965e
@ -443,9 +443,9 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
)
|
)
|
||||||
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
|
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
|
||||||
# objec grounding
|
# objec grounding
|
||||||
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
|
# obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
|
||||||
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
|
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
|
||||||
og_loss += self.criterion(obj_logits, obj_targets)
|
# og_loss += self.criterion(obj_logits, obj_targets)
|
||||||
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
|
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
|
||||||
# print(t, 'og_loss', og_loss.item(), self.criterion(obj_logits, obj_targets).item())
|
# print(t, 'og_loss', og_loss.item(), self.criterion(obj_logits, obj_targets).item())
|
||||||
|
|
||||||
@ -532,11 +532,11 @@ class GMapObjectNavAgent(Seq2SeqAgent):
|
|||||||
|
|
||||||
if train_ml is not None:
|
if train_ml is not None:
|
||||||
ml_loss = ml_loss * train_ml / batch_size
|
ml_loss = ml_loss * train_ml / batch_size
|
||||||
og_loss = og_loss * train_ml / batch_size
|
# og_loss = og_loss * train_ml / batch_size
|
||||||
self.loss += ml_loss
|
self.loss += ml_loss
|
||||||
self.loss += og_loss
|
# self.loss += og_loss
|
||||||
self.logs['IL_loss'].append(ml_loss.item())
|
self.logs['IL_loss'].append(ml_loss.item())
|
||||||
self.logs['OG_loss'].append(og_loss.item())
|
# self.logs['OG_loss'].append(og_loss.item())
|
||||||
|
|
||||||
'''
|
'''
|
||||||
print("TRAJ:")
|
print("TRAJ:")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user