fix: <not found> in obj_ref logit
This commit is contained in:
parent
dfe586b9ab
commit
2fba923103
@ -156,7 +156,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
|
||||
|
||||
def _candidate_variable(self, obs):
|
||||
candidate_leng = [len(ob['candidate'])+1 for ob in obs]
|
||||
candidate_leng = [len(ob['candidate']) for ob in obs]
|
||||
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||
|
||||
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
|
||||
@ -165,8 +165,10 @@ class Seq2SeqAgent(BaseAgent):
|
||||
for j, cc in enumerate(ob['candidate']):
|
||||
candidate_feat[i, j, :] = cc['feature']
|
||||
result = torch.from_numpy(candidate_feat)
|
||||
'''
|
||||
for i, ob in enumerate(obs):
|
||||
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
|
||||
'''
|
||||
|
||||
result = result.cuda()
|
||||
|
||||
@ -216,10 +218,13 @@ class Seq2SeqAgent(BaseAgent):
|
||||
break
|
||||
else: # Stop here
|
||||
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
|
||||
a[i] = cand_size - 1
|
||||
'''
|
||||
if ob['found']:
|
||||
a[i] = cand_size - 1
|
||||
else:
|
||||
a[i] = candidate_leng[i] - 1
|
||||
'''
|
||||
|
||||
return torch.from_numpy(a).cuda()
|
||||
|
||||
@ -231,9 +236,13 @@ class Seq2SeqAgent(BaseAgent):
|
||||
else:
|
||||
candidate_objs = ob['candidate_obj'][2]
|
||||
for k, kid in enumerate(candidate_objs):
|
||||
if kid == ob['objId'] and ob['found']:
|
||||
a[i] = k
|
||||
break
|
||||
if kid == ob['objId']:
|
||||
if ob['found']:
|
||||
a[i] = k
|
||||
break
|
||||
else:
|
||||
a[i] = len(candidate_objs)
|
||||
break
|
||||
else:
|
||||
a[i] = args.ignoreid
|
||||
return torch.from_numpy(a).cuda()
|
||||
@ -430,17 +439,34 @@ class Seq2SeqAgent(BaseAgent):
|
||||
_, ref_t = logit_REF[i].max(0)
|
||||
if ref_t != obj_leng[i]-1: # decide not to do REF
|
||||
traj[i]['ref'] = perm_obs[i]['candidate_obj'][2][ref_t]
|
||||
else:
|
||||
traj[i]['ref'] = 'NOT_FOUND'
|
||||
else:
|
||||
just_ended[i] = False
|
||||
|
||||
if (next_id == args.ignoreid) or (ended[i]):
|
||||
cpu_a_t[i] = found[i]
|
||||
elif (next_id == visual_temp_mask.size(1)):
|
||||
'''
|
||||
if self.feedback == 'argmax':
|
||||
_, ref_t = logit_REF[i].max(0)
|
||||
if ref_t != obj_leng[i]-1:
|
||||
cpu_a_t[i] = -1
|
||||
found[i] = -1
|
||||
else:
|
||||
cpu_a_t[i] = -2
|
||||
found[i] = -2
|
||||
else:
|
||||
'''
|
||||
cpu_a_t[i] = -1
|
||||
found[i] = -1
|
||||
elif (next_id == (candidate_leng[i]-1)):
|
||||
cpu_a_t[i] = -2
|
||||
found[i] = -2
|
||||
if self.feedback == 'argmax':
|
||||
_, ref_t = logit_REF[i].max(0)
|
||||
if ref_t == obj_leng[i]-1:
|
||||
found[i] = -2
|
||||
else:
|
||||
found[i] = -1
|
||||
|
||||
'''
|
||||
print("MODE: ", self.feedback)
|
||||
print("logit: ", logit)
|
||||
@ -515,19 +541,6 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# reward[i] = -2.0
|
||||
if dist[i] < 1.0: # Correct
|
||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||
if ob['found']:
|
||||
reward[i] += 1
|
||||
else:
|
||||
reward[i] -= 2
|
||||
else: # Incorrect
|
||||
reward[i] = -2.0
|
||||
elif action_idx == -2:
|
||||
if dist[i] < 1.0: # Correct
|
||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||
if ob['found']:
|
||||
reward[i] -= 2
|
||||
else:
|
||||
reward[i] += 1
|
||||
else: # Incorrect
|
||||
reward[i] = -2.0
|
||||
else: # The action is not end
|
||||
@ -551,10 +564,8 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# Update the finished actions
|
||||
# -1 means ended or ignored (already ended)
|
||||
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
|
||||
ended[:] = np.logical_or(ended, (cpu_a_t == -2))
|
||||
|
||||
# Early exit if all ended
|
||||
target_found = [ (-1 if i['found'] else -2) for i in perm_obs ]
|
||||
if ended.all():
|
||||
'''
|
||||
if train_ml is None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user