feat: add NOT_FOUND token
This commit is contained in:
parent
7329f7fa0a
commit
857c7e8e10
@ -147,7 +147,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
|
||||
|
||||
def _candidate_variable(self, obs):
|
||||
candidate_leng = [len(ob['candidate']) + 1 for ob in obs] # +1 is for the end
|
||||
candidate_leng = [len(ob['candidate']) + 2 for ob in obs] # +1 is for the end
|
||||
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||
|
||||
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
|
||||
@ -155,6 +155,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
for i, ob in enumerate(obs):
|
||||
for j, cc in enumerate(ob['candidate']):
|
||||
candidate_feat[i, j, :] = cc['feature']
|
||||
candidate_feat[i, -1, :] = np.ones((self.feature_size + args.angle_feat_size))
|
||||
|
||||
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
|
||||
|
||||
@ -186,7 +187,10 @@ class Seq2SeqAgent(BaseAgent):
|
||||
break
|
||||
else: # Stop here
|
||||
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
|
||||
a[i] = len(ob['candidate'])
|
||||
if ob['swap']: # instruction 有被換過,所以要 not found
|
||||
a[i] = len(ob['candidate'])
|
||||
else: # STOP
|
||||
a[i] = len(ob['candidate'])-1
|
||||
return torch.from_numpy(a).cuda()
|
||||
|
||||
def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
|
||||
@ -205,7 +209,8 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
for i, idx in enumerate(perm_idx):
|
||||
action = a_t[i]
|
||||
if action != -1: # -1 is the <stop> action
|
||||
print('action: ', action)
|
||||
if action != -1 and action != -2: # -1 is the <stop> action
|
||||
select_candidate = perm_obs[i]['candidate'][action]
|
||||
src_point = perm_obs[i]['viewIndex']
|
||||
trg_point = select_candidate['pointId']
|
||||
@ -228,6 +233,11 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# print("action: {} view_index: {}".format(action, state.viewIndex))
|
||||
if traj is not None:
|
||||
traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
|
||||
elif action == -1:
|
||||
print('<STOP>')
|
||||
elif action == -2:
|
||||
print('<NOT_FOUND>')
|
||||
|
||||
|
||||
def rollout(self, train_ml=None, train_rl=True, reset=True):
|
||||
"""
|
||||
@ -253,7 +263,6 @@ class Seq2SeqAgent(BaseAgent):
|
||||
sentence, language_attention_mask, token_type_ids, \
|
||||
seq_lengths, perm_idx = self._sort_batch(obs)
|
||||
|
||||
print("perm_index:", perm_idx)
|
||||
perm_obs = obs[perm_idx]
|
||||
|
||||
|
||||
@ -297,7 +306,6 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||
|
||||
|
||||
# the first [CLS] token, initialized by the language BERT, serves
|
||||
# as the agent's state passing through time steps
|
||||
if (t >= 1) or (args.vlnbert=='prevalent'):
|
||||
@ -328,7 +336,17 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
# Supervised training
|
||||
target = self._teacher_action(perm_obs, ended)
|
||||
print("target: ", target.shape)
|
||||
for i, d in enumerate(target):
|
||||
print(perm_obs[i]['swap'], perm_obs[i]['instructions'])
|
||||
print(d)
|
||||
_, at_t = logit.max(1)
|
||||
if at_t[i].item() == candidate_leng[i]-1:
|
||||
print("-2")
|
||||
elif at_t[i].item() == candidate_leng[i]-2:
|
||||
print("-1")
|
||||
else:
|
||||
print(at_t[i].item())
|
||||
print()
|
||||
ml_loss += self.criterion(logit, target)
|
||||
|
||||
# Determine next model inputs
|
||||
@ -349,12 +367,15 @@ class Seq2SeqAgent(BaseAgent):
|
||||
else:
|
||||
print(self.feedback)
|
||||
sys.exit('Invalid feedback option')
|
||||
|
||||
# Prepare environment action
|
||||
# NOTE: Env action is in the perm_obs space
|
||||
cpu_a_t = a_t.cpu().numpy()
|
||||
for i, next_id in enumerate(cpu_a_t):
|
||||
if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]: # The last action is <end>
|
||||
if next_id == (candidate_leng[i]-2) or next_id == args.ignoreid or ended[i]: # The last action is <end>
|
||||
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
|
||||
elif next_id == (candidate_leng[i]-1):
|
||||
cpu_a_t[i] = -2
|
||||
|
||||
# Make action and get the new state
|
||||
self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
|
||||
@ -381,8 +402,22 @@ class Seq2SeqAgent(BaseAgent):
|
||||
if action_idx == -1: # If the action now is end
|
||||
if dist[i] < 3.0: # Correct
|
||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||
if ob['swap']:
|
||||
reward[i] -= 2
|
||||
else:
|
||||
reward[i] += 1
|
||||
else: # Incorrect
|
||||
reward[i] = -2.0
|
||||
elif action_idx == -2: # NOT_FOUND reward 設定在這裏
|
||||
if dist[i] < 3.0:
|
||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||
if ob['swap']:
|
||||
reward[i] += 3 # 偵測到錯誤 instruction,多加一分
|
||||
else:
|
||||
reward[i] -= 2
|
||||
else: # Incorrect
|
||||
reward[i] = -2.0
|
||||
reward[i] += 1 # distance > 3, 確實沒找到東西,從扣二變成扣一
|
||||
else: # The action is not end
|
||||
# Path fidelity rewards (distance & nDTW)
|
||||
reward[i] = - (dist[i] - last_dist[i])
|
||||
@ -404,6 +439,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# Update the finished actions
|
||||
# -1 means ended or ignored (already ended)
|
||||
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
|
||||
ended[:] = np.logical_or(ended, (cpu_a_t == -2))
|
||||
|
||||
# Early exit if all ended
|
||||
if ended.all():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user