fix: <not found> in obj_ref logit
This commit is contained in:
parent
dfe586b9ab
commit
2fba923103
@ -156,7 +156,7 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
|
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
|
||||||
|
|
||||||
def _candidate_variable(self, obs):
|
def _candidate_variable(self, obs):
|
||||||
candidate_leng = [len(ob['candidate'])+1 for ob in obs]
|
candidate_leng = [len(ob['candidate']) for ob in obs]
|
||||||
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||||
|
|
||||||
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
|
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
|
||||||
@ -165,8 +165,10 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
for j, cc in enumerate(ob['candidate']):
|
for j, cc in enumerate(ob['candidate']):
|
||||||
candidate_feat[i, j, :] = cc['feature']
|
candidate_feat[i, j, :] = cc['feature']
|
||||||
result = torch.from_numpy(candidate_feat)
|
result = torch.from_numpy(candidate_feat)
|
||||||
|
'''
|
||||||
for i, ob in enumerate(obs):
|
for i, ob in enumerate(obs):
|
||||||
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
|
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
|
||||||
|
'''
|
||||||
|
|
||||||
result = result.cuda()
|
result = result.cuda()
|
||||||
|
|
||||||
@ -216,10 +218,13 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
break
|
break
|
||||||
else: # Stop here
|
else: # Stop here
|
||||||
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
|
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
|
||||||
|
a[i] = cand_size - 1
|
||||||
|
'''
|
||||||
if ob['found']:
|
if ob['found']:
|
||||||
a[i] = cand_size - 1
|
a[i] = cand_size - 1
|
||||||
else:
|
else:
|
||||||
a[i] = candidate_leng[i] - 1
|
a[i] = candidate_leng[i] - 1
|
||||||
|
'''
|
||||||
|
|
||||||
return torch.from_numpy(a).cuda()
|
return torch.from_numpy(a).cuda()
|
||||||
|
|
||||||
@ -231,9 +236,13 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
else:
|
else:
|
||||||
candidate_objs = ob['candidate_obj'][2]
|
candidate_objs = ob['candidate_obj'][2]
|
||||||
for k, kid in enumerate(candidate_objs):
|
for k, kid in enumerate(candidate_objs):
|
||||||
if kid == ob['objId'] and ob['found']:
|
if kid == ob['objId']:
|
||||||
a[i] = k
|
if ob['found']:
|
||||||
break
|
a[i] = k
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
a[i] = len(candidate_objs)
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
a[i] = args.ignoreid
|
a[i] = args.ignoreid
|
||||||
return torch.from_numpy(a).cuda()
|
return torch.from_numpy(a).cuda()
|
||||||
@ -430,17 +439,34 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
_, ref_t = logit_REF[i].max(0)
|
_, ref_t = logit_REF[i].max(0)
|
||||||
if ref_t != obj_leng[i]-1: # decide not to do REF
|
if ref_t != obj_leng[i]-1: # decide not to do REF
|
||||||
traj[i]['ref'] = perm_obs[i]['candidate_obj'][2][ref_t]
|
traj[i]['ref'] = perm_obs[i]['candidate_obj'][2][ref_t]
|
||||||
|
else:
|
||||||
|
traj[i]['ref'] = 'NOT_FOUND'
|
||||||
else:
|
else:
|
||||||
just_ended[i] = False
|
just_ended[i] = False
|
||||||
|
|
||||||
if (next_id == args.ignoreid) or (ended[i]):
|
if (next_id == args.ignoreid) or (ended[i]):
|
||||||
cpu_a_t[i] = found[i]
|
cpu_a_t[i] = found[i]
|
||||||
elif (next_id == visual_temp_mask.size(1)):
|
elif (next_id == visual_temp_mask.size(1)):
|
||||||
|
'''
|
||||||
|
if self.feedback == 'argmax':
|
||||||
|
_, ref_t = logit_REF[i].max(0)
|
||||||
|
if ref_t != obj_leng[i]-1:
|
||||||
|
cpu_a_t[i] = -1
|
||||||
|
found[i] = -1
|
||||||
|
else:
|
||||||
|
cpu_a_t[i] = -2
|
||||||
|
found[i] = -2
|
||||||
|
else:
|
||||||
|
'''
|
||||||
cpu_a_t[i] = -1
|
cpu_a_t[i] = -1
|
||||||
found[i] = -1
|
found[i] = -1
|
||||||
elif (next_id == (candidate_leng[i]-1)):
|
if self.feedback == 'argmax':
|
||||||
cpu_a_t[i] = -2
|
_, ref_t = logit_REF[i].max(0)
|
||||||
found[i] = -2
|
if ref_t == obj_leng[i]-1:
|
||||||
|
found[i] = -2
|
||||||
|
else:
|
||||||
|
found[i] = -1
|
||||||
|
|
||||||
'''
|
'''
|
||||||
print("MODE: ", self.feedback)
|
print("MODE: ", self.feedback)
|
||||||
print("logit: ", logit)
|
print("logit: ", logit)
|
||||||
@ -515,19 +541,6 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
# reward[i] = -2.0
|
# reward[i] = -2.0
|
||||||
if dist[i] < 1.0: # Correct
|
if dist[i] < 1.0: # Correct
|
||||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||||
if ob['found']:
|
|
||||||
reward[i] += 1
|
|
||||||
else:
|
|
||||||
reward[i] -= 2
|
|
||||||
else: # Incorrect
|
|
||||||
reward[i] = -2.0
|
|
||||||
elif action_idx == -2:
|
|
||||||
if dist[i] < 1.0: # Correct
|
|
||||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
|
||||||
if ob['found']:
|
|
||||||
reward[i] -= 2
|
|
||||||
else:
|
|
||||||
reward[i] += 1
|
|
||||||
else: # Incorrect
|
else: # Incorrect
|
||||||
reward[i] = -2.0
|
reward[i] = -2.0
|
||||||
else: # The action is not end
|
else: # The action is not end
|
||||||
@ -551,10 +564,8 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
# Update the finished actions
|
# Update the finished actions
|
||||||
# -1 means ended or ignored (already ended)
|
# -1 means ended or ignored (already ended)
|
||||||
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
|
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
|
||||||
ended[:] = np.logical_or(ended, (cpu_a_t == -2))
|
|
||||||
|
|
||||||
# Early exit if all ended
|
# Early exit if all ended
|
||||||
target_found = [ (-1 if i['found'] else -2) for i in perm_obs ]
|
|
||||||
if ended.all():
|
if ended.all():
|
||||||
'''
|
'''
|
||||||
if train_ml is None:
|
if train_ml is None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user