fix: <not found> in in obj_ref logit

This commit is contained in:
Ting-Jun Wang 2023-11-19 21:31:27 +08:00
parent dfe586b9ab
commit 2fba923103
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354

View File

@ -156,7 +156,7 @@ class Seq2SeqAgent(BaseAgent):
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
def _candidate_variable(self, obs):
candidate_leng = [len(ob['candidate'])+1 for ob in obs]
candidate_leng = [len(ob['candidate']) for ob in obs]
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
@ -165,8 +165,10 @@ class Seq2SeqAgent(BaseAgent):
for j, cc in enumerate(ob['candidate']):
candidate_feat[i, j, :] = cc['feature']
result = torch.from_numpy(candidate_feat)
'''
for i, ob in enumerate(obs):
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
'''
result = result.cuda()
@ -216,10 +218,13 @@ class Seq2SeqAgent(BaseAgent):
break
else: # Stop here
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
a[i] = cand_size - 1
'''
if ob['found']:
a[i] = cand_size - 1
else:
a[i] = candidate_leng[i] - 1
'''
return torch.from_numpy(a).cuda()
@ -231,9 +236,13 @@ class Seq2SeqAgent(BaseAgent):
else:
candidate_objs = ob['candidate_obj'][2]
for k, kid in enumerate(candidate_objs):
if kid == ob['objId'] and ob['found']:
a[i] = k
break
if kid == ob['objId']:
if ob['found']:
a[i] = k
break
else:
a[i] = len(candidate_objs)
break
else:
a[i] = args.ignoreid
return torch.from_numpy(a).cuda()
@ -430,17 +439,34 @@ class Seq2SeqAgent(BaseAgent):
_, ref_t = logit_REF[i].max(0)
if ref_t != obj_leng[i]-1: # decide not to do REF
traj[i]['ref'] = perm_obs[i]['candidate_obj'][2][ref_t]
else:
traj[i]['ref'] = 'NOT_FOUND'
else:
just_ended[i] = False
if (next_id == args.ignoreid) or (ended[i]):
cpu_a_t[i] = found[i]
elif (next_id == visual_temp_mask.size(1)):
'''
if self.feedback == 'argmax':
_, ref_t = logit_REF[i].max(0)
if ref_t != obj_leng[i]-1:
cpu_a_t[i] = -1
found[i] = -1
else:
cpu_a_t[i] = -2
found[i] = -2
else:
'''
cpu_a_t[i] = -1
found[i] = -1
elif (next_id == (candidate_leng[i]-1)):
cpu_a_t[i] = -2
found[i] = -2
if self.feedback == 'argmax':
_, ref_t = logit_REF[i].max(0)
if ref_t == obj_leng[i]-1:
found[i] = -2
else:
found[i] = -1
'''
print("MODE: ", self.feedback)
print("logit: ", logit)
@ -515,19 +541,6 @@ class Seq2SeqAgent(BaseAgent):
# reward[i] = -2.0
if dist[i] < 1.0: # Correct
reward[i] = 2.0 + ndtw_score[i] * 2.0
if ob['found']:
reward[i] += 1
else:
reward[i] -= 2
else: # Incorrect
reward[i] = -2.0
elif action_idx == -2:
if dist[i] < 1.0: # Correct
reward[i] = 2.0 + ndtw_score[i] * 2.0
if ob['found']:
reward[i] -= 2
else:
reward[i] += 1
else: # Incorrect
reward[i] = -2.0
else: # The action is not end
@ -551,10 +564,8 @@ class Seq2SeqAgent(BaseAgent):
# Update the finished actions
# -1 means ended or ignored (already ended)
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
ended[:] = np.logical_or(ended, (cpu_a_t == -2))
# Early exit if all ended
target_found = [ (-1 if i['found'] else -2) for i in perm_obs ]
if ended.all():
'''
if train_ml is None: