diff --git a/r2r_src/agent.py b/r2r_src/agent.py
index c1e3eb6..064c673 100644
--- a/r2r_src/agent.py
+++ b/r2r_src/agent.py
@@ -156,7 +156,7 @@ class Seq2SeqAgent(BaseAgent):
         return Variable(torch.from_numpy(features), requires_grad=False).cuda()
 
     def _candidate_variable(self, obs):
-        candidate_leng = [len(ob['candidate'])+1 for ob in obs]
+        candidate_leng = [len(ob['candidate']) for ob in obs]
         candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
         # Note: The candidate_feat at len(ob['candidate']) is the feature for the END
@@ -165,8 +165,10 @@ class Seq2SeqAgent(BaseAgent):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
         result = torch.from_numpy(candidate_feat)
+        '''
         for i, ob in enumerate(obs):
             result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
+        '''
 
         result = result.cuda()
@@ -216,10 +218,13 @@ class Seq2SeqAgent(BaseAgent):
                     break
             else:   # Stop here
                 assert ob['teacher'] == ob['viewpoint']         # The teacher action should be "STAY HERE"
+                a[i] = cand_size - 1
+                '''
                 if ob['found']:
                     a[i] = cand_size - 1
                 else:
                     a[i] = candidate_leng[i] - 1
+                '''
 
         return torch.from_numpy(a).cuda()
@@ -231,9 +236,13 @@ class Seq2SeqAgent(BaseAgent):
             else:
                 candidate_objs = ob['candidate_obj'][2]
                 for k, kid in enumerate(candidate_objs):
-                    if kid == ob['objId'] and ob['found']:
-                        a[i] = k
-                        break
+                    if kid == ob['objId']:
+                        if ob['found']:
+                            a[i] = k
+                            break
+                        else:
+                            a[i] = len(candidate_objs)
+                            break
                 else:
                     a[i] = args.ignoreid
         return torch.from_numpy(a).cuda()
@@ -430,17 +439,34 @@ class Seq2SeqAgent(BaseAgent):
                     _, ref_t = logit_REF[i].max(0)
                     if ref_t != obj_leng[i]-1:      # decide not to do REF
                         traj[i]['ref'] = perm_obs[i]['candidate_obj'][2][ref_t]
+                    else:
+                        traj[i]['ref'] = 'NOT_FOUND'
                 else:
                     just_ended[i] = False
                 if (next_id == args.ignoreid) or (ended[i]):
                     cpu_a_t[i] = found[i]
                 elif (next_id == visual_temp_mask.size(1)):
+                    '''
+                    if self.feedback == 'argmax':
+                        _, ref_t = logit_REF[i].max(0)
+                        if ref_t != obj_leng[i]-1:
+                            cpu_a_t[i] = -1
+                            found[i] = -1
+                        else:
+                            cpu_a_t[i] = -2
+                            found[i] = -2
+                    else:
+                    '''
                     cpu_a_t[i] = -1
                     found[i] = -1
-                elif (next_id == (candidate_leng[i]-1)):
-                    cpu_a_t[i] = -2
-                    found[i] = -2
+                    if self.feedback == 'argmax':
+                        _, ref_t = logit_REF[i].max(0)
+                        if ref_t == obj_leng[i]-1:
+                            found[i] = -2
+                        else:
+                            found[i] = -1
 
             '''
             print("MODE: ", self.feedback)
             print("logit: ", logit)
@@ -515,19 +541,6 @@ class Seq2SeqAgent(BaseAgent):
                         # reward[i] = -2.0
                         if dist[i] < 1.0:                             # Correct
                             reward[i] = 2.0 + ndtw_score[i] * 2.0
-                            if ob['found']:
-                                reward[i] += 1
-                            else:
-                                reward[i] -= 2
-                        else:                                         # Incorrect
-                            reward[i] = -2.0
-                    elif action_idx == -2:
-                        if dist[i] < 1.0:                             # Correct
-                            reward[i] = 2.0 + ndtw_score[i] * 2.0
-                            if ob['found']:
-                                reward[i] -= 2
-                            else:
-                                reward[i] += 1
                         else:                                         # Incorrect
                             reward[i] = -2.0
                     else:                                             # The action is not end
@@ -551,10 +564,8 @@ class Seq2SeqAgent(BaseAgent):
             # Update the finished actions
             # -1 means ended or ignored (already ended)
             ended[:] = np.logical_or(ended, (cpu_a_t == -1))
-            ended[:] = np.logical_or(ended, (cpu_a_t == -2))
 
             # Early exit if all ended
-            target_found = [ (-1 if i['found'] else -2) for i in perm_obs ]
             if ended.all():
                 '''
                 if train_ml is None:
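
A minimal sketch (not part of the patch) of the convention these hunks converge on: the separate -2 "stop, object not found" action is removed, the agent always stops via a single STOP slot, and found/not-found is read off the REF head instead. It assumes logit_REF[i] is a 1-D tensor over obj_leng[i] object candidates whose last slot stands for "not found"; the helper name resolve_found_state is hypothetical.

import torch

def resolve_found_state(logit_REF_i, obj_leng_i):
    # Sketch of the argmax-time rule in the @@ -430 hunk: the agent emits a
    # single STOP action (cpu_a_t == -1), and the episode outcome is taken
    # from the REF head -- found (-1) unless the predicted object index is
    # the last slot, which is reserved for NOT_FOUND (-2).
    _, ref_t = logit_REF_i.max(0)        # index of the predicted object
    if ref_t.item() == obj_leng_i - 1:   # last slot means "not found"
        return -2
    return -1

# Example: 4 object candidates, slot 3 reserved for NOT_FOUND.
assert resolve_found_state(torch.tensor([0.1, 0.3, 0.2, 0.9]), 4) == -2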