From 857c7e8e10502cb331e190b7dc464a039a28e51e Mon Sep 17 00:00:00 2001
From: Ting-Jun Wang
Date: Sat, 4 Nov 2023 22:53:01 +0800
Subject: [PATCH] feat: add NOT_FOUND token

---
 r2r_src/agent.py | 50 +++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 43 insertions(+), 7 deletions(-)

diff --git a/r2r_src/agent.py b/r2r_src/agent.py
index 53fffb6..d8f8ef4 100644
--- a/r2r_src/agent.py
+++ b/r2r_src/agent.py
@@ -147,7 +147,7 @@ class Seq2SeqAgent(BaseAgent):
         return Variable(torch.from_numpy(features), requires_grad=False).cuda()

     def _candidate_variable(self, obs):
-        candidate_leng = [len(ob['candidate']) + 1 for ob in obs]       # +1 is for the end
+        candidate_leng = [len(ob['candidate']) + 2 for ob in obs]       # +2 is for the end and NOT_FOUND
         candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
         # Note: The candidate_feat at len(ob['candidate']) is the feature for the END
@@ -155,6 +155,7 @@ class Seq2SeqAgent(BaseAgent):
         for i, ob in enumerate(obs):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
+            candidate_feat[i, candidate_leng[i] - 1, :] = np.ones((self.feature_size + args.angle_feat_size))  # all-ones NOT_FOUND feature in the slot after END

         return torch.from_numpy(candidate_feat).cuda(), candidate_leng

@@ -186,7 +187,10 @@ class Seq2SeqAgent(BaseAgent):
                         break
                 else:   # Stop here
                     assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
-                    a[i] = len(ob['candidate'])
+                    if ob['swap']:      # the instruction was swapped, so the teacher label is NOT_FOUND (candidate_leng-1)
+                        a[i] = len(ob['candidate']) + 1
+                    else:               # STOP, i.e. the <end> index (candidate_leng-2)
+                        a[i] = len(ob['candidate'])
         return torch.from_numpy(a).cuda()

     def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
@@ -205,7 +209,8 @@ class Seq2SeqAgent(BaseAgent):
         for i, idx in enumerate(perm_idx):
             action = a_t[i]
-            if action != -1:            # -1 is the <stop> action
+            print('action: ', action)
+            if action != -1 and action != -2:   # -1 is the <stop> action, -2 is NOT_FOUND
                 select_candidate = perm_obs[i]['candidate'][action]
                 src_point = perm_obs[i]['viewIndex']
                 trg_point = select_candidate['pointId']
@@ -228,6 +233,11 @@ class Seq2SeqAgent(BaseAgent):
                 # print("action: {} view_index: {}".format(action, state.viewIndex))
                 if traj is not None:
                     traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
+            elif action == -1:          # <stop>: no simulator action is taken
+                print('')
+            elif action == -2:          # NOT_FOUND: no simulator action is taken
+                print('')
+

     def rollout(self, train_ml=None, train_rl=True, reset=True):
         """
@@ -253,7 +263,6 @@ class Seq2SeqAgent(BaseAgent):
         sentence, language_attention_mask, token_type_ids, \
             seq_lengths, perm_idx = self._sort_batch(obs)
-        print("perm_index:", perm_idx)

         perm_obs = obs[perm_idx]

@@ -297,7 +306,6 @@ class Seq2SeqAgent(BaseAgent):
             input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
-
             # the first [CLS] token, initialized by the language BERT, serves
             # as the agent's state passing through time steps
             if (t >= 1) or (args.vlnbert=='prevalent'):
@@ -328,7 +336,17 @@ class Seq2SeqAgent(BaseAgent):
             # Supervised training
             target = self._teacher_action(perm_obs, ended)
-            print("target: ", target.shape)
+            for i, d in enumerate(target):
+                print(perm_obs[i]['swap'], perm_obs[i]['instructions'])
+                print(d)
+                _, at_t = logit.max(1)
+                if at_t[i].item() == candidate_leng[i]-1:
+                    print("-2")
+                elif at_t[i].item() == candidate_leng[i]-2:
+                    print("-1")
+                else:
+                    print(at_t[i].item())
+                print()
             ml_loss += self.criterion(logit, target)

             # Determine next model inputs
@@ -349,12 +367,15 @@ class Seq2SeqAgent(BaseAgent):
             else:
                 print(self.feedback)
                 sys.exit('Invalid feedback option')
+
             # Prepare environment action
             # NOTE: Env action is in the perm_obs space
             cpu_a_t = a_t.cpu().numpy()
             for i, next_id in enumerate(cpu_a_t):
-                if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]:    # The last action is <end>
+                if next_id == (candidate_leng[i]-2) or next_id == args.ignoreid or ended[i]:    # candidate_leng[i]-2 is the <end> action
                     cpu_a_t[i] = -1             # Change the <end> and ignore action to -1
+                elif next_id == (candidate_leng[i]-1):
+                    cpu_a_t[i] = -2             # Change the NOT_FOUND action to -2

             # Make action and get the new state
             self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
@@ -381,8 +402,22 @@ class Seq2SeqAgent(BaseAgent):
                         if action_idx == -1:                              # If the action now is end
                             if dist[i] < 3.0:                             # Correct
                                 reward[i] = 2.0 + ndtw_score[i] * 2.0
+                                if ob['swap']:
+                                    reward[i] -= 2
+                                else:
+                                    reward[i] += 1
                             else:                                         # Incorrect
                                 reward[i] = -2.0
+                        elif action_idx == -2:                            # the NOT_FOUND reward is set here
+                            if dist[i] < 3.0:
+                                reward[i] = 2.0 + ndtw_score[i] * 2.0
+                                if ob['swap']:
+                                    reward[i] += 3                        # the wrong (swapped) instruction was detected, so add a bonus
+                                else:
+                                    reward[i] -= 2
+                            else:                                         # Incorrect
+                                reward[i] = -2.0
+                                reward[i] += 1                            # distance > 3: the target really was not found, so the penalty goes from -2 to -1
                         else:                                             # The action is not end
                             # Path fidelity rewards (distance & nDTW)
                             reward[i] = - (dist[i] - last_dist[i])
@@ -404,6 +439,7 @@ class Seq2SeqAgent(BaseAgent):
             # Update the finished actions
             # -1 means ended or ignored (already ended)
             ended[:] = np.logical_or(ended, (cpu_a_t == -1))
+            ended[:] = np.logical_or(ended, (cpu_a_t == -2))

             # Early exit if all ended
             if ended.all():