From 03a3e5b489bbbaea619a6d667c17d744ae59bd3e Mon Sep 17 00:00:00 2001
From: Ting-Jun Wang
Date: Mon, 6 Nov 2023 18:31:14 +0800
Subject: [PATCH] feat: add NOT_FOUND action in rollout

---
 r2r_src/agent.py | 74 ++++++++++++++++++++++++++++++++++++++++++------
 r2r_src/env.py   |  5 +++-
 2 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/r2r_src/agent.py b/r2r_src/agent.py
index aae150a..66d9a74 100644
--- a/r2r_src/agent.py
+++ b/r2r_src/agent.py
@@ -147,7 +147,7 @@ class Seq2SeqAgent(BaseAgent):
         return Variable(torch.from_numpy(features), requires_grad=False).cuda()
 
     def _candidate_variable(self, obs):
-        candidate_leng = [len(ob['candidate']) + 1 for ob in obs]       # +1 is for the end
+        candidate_leng = [len(ob['candidate']) + 2 for ob in obs]       # +2 is for the END and NOT_FOUND slots
         candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
         # Note: The candidate_feat at len(ob['candidate']) is the feature for the END
@@ -155,6 +155,8 @@ class Seq2SeqAgent(BaseAgent):
         for i, ob in enumerate(obs):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
+            candidate_feat[i, len(ob['candidate']), :] = np.zeros(self.feature_size+args.angle_feat_size, dtype=np.float32)    # END slot: all-zero feature
+            candidate_feat[i, len(ob['candidate'])+1, :] = np.ones(self.feature_size+args.angle_feat_size, dtype=np.float32)   # NOT_FOUND slot: all-one feature
 
         return torch.from_numpy(candidate_feat).cuda(), candidate_leng
 
@@ -186,10 +188,13 @@ class Seq2SeqAgent(BaseAgent):
                         break
             else:   # Stop here
                 assert ob['teacher'] == ob['viewpoint']         # The teacher action should be "STAY HERE"
-                a[i] = len(ob['candidate'])
+                if ob['found']:
+                    a[i] = len(ob['candidate'])                 # END slot
+                else:
+                    a[i] = len(ob['candidate'])+1               # NOT_FOUND slot
         return torch.from_numpy(a).cuda()
 
-    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
+    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None, found=None):
         """
         Interface between Panoramic view and Egocentric view
         It will convert the action panoramic view action a_t to equivalent egocentric view actions for the simulator
@@ -205,7 +210,7 @@ class Seq2SeqAgent(BaseAgent):
         for i, idx in enumerate(perm_idx):
             action = a_t[i]
-            if action != -1:            # -1 is the <end> action
+            if action != -1 and action != -2:            # -1 is the <end> action, -2 is NOT_FOUND
                 select_candidate = perm_obs[i]['candidate'][action]
                 src_point = perm_obs[i]['viewIndex']
                 trg_point = select_candidate['pointId']
@@ -228,6 +233,10 @@ class Seq2SeqAgent(BaseAgent):
                 # print("action: {} view_index: {}".format(action, state.viewIndex))
                 if traj is not None:
                     traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
+            elif action == -1 or action == -2:
+                if found is not None:
+                    found[i] = action           # remember whether the agent stopped with END (-1) or NOT_FOUND (-2)
+
 
     def rollout(self, train_ml=None, train_rl=True, reset=True):
         """
@@ -246,7 +255,7 @@ class Seq2SeqAgent(BaseAgent):
             obs = np.array(self.env.reset())
         else:
             obs = np.array(self.env._get_obs())
-
+
         batch_size = len(obs)
 
         # Language input
@@ -270,6 +279,8 @@ class Seq2SeqAgent(BaseAgent):
             'instr_id': ob['instr_id'],
             'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
         } for ob in perm_obs]
+
+        found = [None for _ in range(len(perm_obs))]    # per-episode stop type, filled in by make_equiv_action
 
         # Init the reward shaping
         last_dist = np.zeros(batch_size, np.float32)
@@ -293,6 +304,15 @@ class Seq2SeqAgent(BaseAgent):
         for t in range(self.episode_len):
 
             input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
+
+            '''
+            # show feature
+            for index, feat in enumerate(candidate_feat):
+                for ff in feat:
+                    print(ff)
+                print(candidate_leng[index])
+                print()
+            '''
 
             # the first [CLS] token, initialized by the language BERT, serves
@@ -324,9 +344,22 @@ class Seq2SeqAgent(BaseAgent):
             # Supervised training
             target = self._teacher_action(perm_obs, ended)
+            for i in perm_obs:
+                print(i['found'], end=' ')
             ml_loss += self.criterion(logit, target)
+
+            '''
+            for index, mask in enumerate(candidate_mask):
+                print(mask)
+                print(candidate_leng[index])
+                print(logit[index])
+                print(target[index])
+                print("\n\n")
+            '''
+
             # Determine next model inputs
+
             if self.feedback == 'teacher':
                 a_t = target                 # teacher forcing
             elif self.feedback == 'argmax':
@@ -344,15 +377,24 @@ class Seq2SeqAgent(BaseAgent):
             else:
                 print(self.feedback)
                 sys.exit('Invalid feedback option')
+
             # Prepare environment action
             # NOTE: Env action is in the perm_obs space
             cpu_a_t = a_t.cpu().numpy()
             for i, next_id in enumerate(cpu_a_t):
-                if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]:    # The last action is <end>
-                    cpu_a_t[i] = -1             # Change the <end> and ignore action to -1
+                if next_id == (args.ignoreid) or ended[i]:
+                    cpu_a_t[i] = found[i]                   # already ended/ignored: keep the earlier stop type
+                elif next_id == (candidate_leng[i]-2):
+                    cpu_a_t[i] = -1                         # END: stop and declare the target found
+                elif next_id == (candidate_leng[i]-1):
+                    cpu_a_t[i] = -2                         # NOT_FOUND: stop and declare the target absent
+
+
+            print(cpu_a_t)
 
             # Make action and get the new state
-            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
+            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found)
+            print(self.feedback, found)
             obs = np.array(self.env._get_obs())
             perm_obs = obs[perm_idx]            # Perm the obs for the resu
@@ -376,6 +418,20 @@ class Seq2SeqAgent(BaseAgent):
                 if action_idx == -1:                              # If the action now is end
                     if dist[i] < 3.0:                             # Correct
                         reward[i] = 2.0 + ndtw_score[i] * 2.0
+                        if ob['found']:
+                            reward[i] += 1
+                        else:
+                            reward[i] -= 2
+                    else:                                         # Incorrect
+                        reward[i] = -2.0
+
+                elif action_idx == -2:                            # If the action now is NOT_FOUND
+                    if dist[i] < 3.0:
+                        reward[i] = 2.0 + ndtw_score[i] * 2.0
+                        if ob['found']:
+                            reward[i] -= 2
+                        else:
+                            reward[i] += 1
                     else:                                         # Incorrect
                         reward[i] = -2.0
                 else:                                             # The action is not end
@@ -399,6 +455,7 @@ class Seq2SeqAgent(BaseAgent):
             # Update the finished actions
             # -1 means ended or ignored (already ended)
             ended[:] = np.logical_or(ended, (cpu_a_t == -1))
+            ended[:] = np.logical_or(ended, (cpu_a_t == -2))
 
             # Early exit if all ended
             if ended.all():
@@ -476,6 +533,7 @@ class Seq2SeqAgent(BaseAgent):
         else:
             self.losses.append(self.loss.item() / self.episode_len)    # This argument is useless.
 
+        print('\n')
         return traj
 
     def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
diff --git a/r2r_src/env.py b/r2r_src/env.py
index 5b36e62..64e2b54 100644
--- a/r2r_src/env.py
+++ b/r2r_src/env.py
@@ -127,6 +127,7 @@ class R2RBatch():
                 new_item = dict(item)
                 new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
                 new_item['instructions'] = instr
+                new_item['found'] = item['found'][j]
 
                 ''' BERT tokenizer '''
                 instr_tokens = tokenizer.tokenize(instr)
@@ -328,6 +329,7 @@ class R2RBatch():
             # [visual_feature, angle_feature] for views
             feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)
+
             obs.append({
                 'instr_id' : item['instr_id'],
                 'scan' : state.scanId,
@@ -341,7 +343,8 @@ class R2RBatch():
                 'instructions' : item['instructions'],
                 'teacher' : self._shortest_path_action(state, item['path'][-1]),
                 'gt_path' : item['path'],
-                'path_id' : item['path_id']
+                'path_id' : item['path_id'],
+                'found': item['found']
             })
             if 'instr_encoding' in item:
                 obs[-1]['instr_encoding'] = item['instr_encoding']
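
Note (not part of the patch): a minimal standalone sketch of the action encoding the diff introduces, for readers of this change. With K candidates, candidate_leng = K + 2; slot K is END (stop and declare the target found, env-space code -1) and slot K+1 is NOT_FOUND (stop and declare the target absent, env-space code -2). The names to_env_action, END, NOT_FOUND and IGNORE_ID below are illustrative only and do not exist in the repository; IGNORE_ID stands in for args.ignoreid.

    END, NOT_FOUND = -1, -2
    IGNORE_ID = 255                          # stand-in for args.ignoreid

    def to_env_action(next_id, num_candidates, ended, prev_stop=END):
        """Map a model prediction over K+2 candidate slots to the env-space code."""
        if next_id == IGNORE_ID or ended:    # episode already finished: keep the earlier stop type
            return prev_stop
        if next_id == num_candidates:        # slot K   -> END
            return END
        if next_id == num_candidates + 1:    # slot K+1 -> NOT_FOUND
            return NOT_FOUND
        return next_id                       # ordinary candidate index: move there

    assert to_env_action(4, 4, ended=False) == END        # picked the END slot
    assert to_env_action(5, 4, ended=False) == NOT_FOUND  # picked the NOT_FOUND slot
    assert to_env_action(2, 4, ended=False) == 2          # move to candidate 2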