diff --git a/r2r_src/agent.py b/r2r_src/agent.py
index 259ea81..93d6a72 100644
--- a/r2r_src/agent.py
+++ b/r2r_src/agent.py
@@ -38,7 +38,7 @@ class BaseAgent(object):
             json.dump(output, f)
 
     def get_results(self):
-        output = [{'instr_id': k, 'trajectory': v, 'ref': r} for k, (v,r) in self.results.items()]
+        output = [{'instr_id': k, 'trajectory': v, 'ref': r, 'found': found} for k, (v,r, found) in self.results.items()]
         return output
 
     def rollout(self, **args):
@@ -59,17 +59,19 @@ class BaseAgent(object):
         if iters is not None:
             # For each time, it will run the first 'iters' iterations. (It was shuffled before)
             for i in range(iters):
-                for traj in self.rollout(**kwargs):
+                trajs, found = self.rollout(**kwargs)
+                for index, traj in enumerate(trajs):
                     self.loss = 0
-                    self.results[traj['instr_id']] = (traj['path'], traj['ref'])
+                    self.results[traj['instr_id']] = (traj['path'], traj['ref'], found[index])
         else:   # Do a full round
             while True:
-                for traj in self.rollout(**kwargs):
+                trajs, found = self.rollout(**kwargs)
+                for index, traj in enumerate(trajs):
                     if traj['instr_id'] in self.results:
                         looped = True
                     else:
                         self.loss = 0
-                        self.results[traj['instr_id']] = (traj['path'], traj['ref'])
+                        self.results[traj['instr_id']] = (traj['path'], traj['ref'], found[index])
                 if looped:
                     break
@@ -154,15 +156,21 @@ class Seq2SeqAgent(BaseAgent):
         return Variable(torch.from_numpy(features), requires_grad=False).cuda()
 
     def _candidate_variable(self, obs):
-        candidate_leng = [len(ob['candidate']) for ob in obs]
+        candidate_leng = [len(ob['candidate'])+1 for ob in obs]
         candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
+        # Note: the slot at len(ob['candidate']) is the feature for the END action;
+        # it is zero-initialized here and overwritten with ones below
         for i, ob in enumerate(obs):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
+        result = torch.from_numpy(candidate_feat)
+        for i, ob in enumerate(obs):
+            result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
+
+        result = result.cuda()
 
-        return torch.from_numpy(candidate_feat).cuda(), candidate_leng
+        return result, candidate_leng
 
     def _object_variable(self, obs):
         cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs]  # +1 is for no REF
@@ -190,7 +198,7 @@ class Seq2SeqAgent(BaseAgent):
 
         return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
 
-    def _teacher_action(self, obs, ended, cand_size):
+    def _teacher_action(self, obs, ended, cand_size, candidate_leng):
         """
         Extract teacher actions into variable.
         :param obs: The observation.
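The _candidate_variable hunk above reserves one extra slot per observation and fills it with an all-ones feature vector; together with the extra candidate_leng argument to _teacher_action, that slot acts as a second stop action ("stop, target not found"), while the padded last column keeps serving as the usual "stop, target found" END. A minimal standalone sketch of the resulting tensor layout (the feature size, observation dicts, and function name are invented for illustration, and it stays on CPU instead of calling .cuda()):

```python
import numpy as np
import torch

FEAT_DIM = 6  # stand-in for feature_size + args.angle_feat_size

def candidate_variable(obs):
    # one extra slot per observation for the explicit "not found" stop action
    candidate_leng = [len(ob['candidate']) + 1 for ob in obs]
    feat = np.zeros((len(obs), max(candidate_leng), FEAT_DIM), dtype=np.float32)
    for i, ob in enumerate(obs):
        for j, cc in enumerate(ob['candidate']):
            feat[i, j, :] = cc['feature']
    result = torch.from_numpy(feat)
    for i, ob in enumerate(obs):
        # the all-ones row marks "stop, target not found"
        result[i, len(ob['candidate']), :] = torch.ones(FEAT_DIM)
    return result, candidate_leng

obs = [
    {'candidate': [{'feature': np.full(FEAT_DIM, 0.5, dtype=np.float32)}]},
    {'candidate': [{'feature': np.full(FEAT_DIM, 0.2, dtype=np.float32)},
                   {'feature': np.full(FEAT_DIM, 0.7, dtype=np.float32)}]},
]
feat, leng = candidate_variable(obs)
print(leng)        # [2, 3]
print(feat[0, 1])  # all-ones "not found" slot for the first observation
print(feat[0, 2])  # zero padding beyond that observation's extra slot
```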
@@ -208,7 +216,11 @@ class Seq2SeqAgent(BaseAgent):
                     break
             else:   # Stop here
                 assert ob['teacher'] == ob['viewpoint']      # The teacher action should be "STAY HERE"
-                a[i] = cand_size - 1
+                if ob['found']:
+                    a[i] = cand_size - 1
+                else:
+                    a[i] = candidate_leng[i] - 1
+
         return torch.from_numpy(a).cuda()
 
     def _teacher_REF(self, obs, just_ended):
@@ -242,7 +254,7 @@ class Seq2SeqAgent(BaseAgent):
 
         for i, idx in enumerate(perm_idx):
             action = a_t[i]
-            if action != -1:            # -1 is the <end> action
+            if action != -1 and action != -2:            # -1 is the <end> action
                 select_candidate = perm_obs[i]['candidate'][action]
                 src_point = perm_obs[i]['viewIndex']
                 trg_point = select_candidate['pointId']
@@ -315,6 +327,7 @@ class Seq2SeqAgent(BaseAgent):
         # Initialization of the tracking state
         ended = np.array([False] * batch_size)  # Indices match permutation of the model, not env
         just_ended = np.array([False] * batch_size)
+        found = np.array([None] * batch_size)
 
         # Init the logs
         rewards = []
@@ -330,6 +343,15 @@ class Seq2SeqAgent(BaseAgent):
 
             input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng = self.get_input_feat(perm_obs)
 
+
+            '''
+            for i in candidate_feat:
+                print(candidate_leng)
+                for j in i:
+                    print(j)
+                print()
+            '''
+
             # the first [CLS] token, initialized by the language BERT, serves
             # as the agent's state passing through time steps
             language_features = torch.cat((h_t.unsqueeze(1), language_features[:,1:,:]), dim=1)
@@ -358,6 +380,7 @@ class Seq2SeqAgent(BaseAgent):
             h_t, logit, logit_REF = self.vln_bert(**visual_inputs)
             hidden_states.append(h_t)
 
+            # print('time step', t)
             # import pdb; pdb.set_trace()
@@ -372,7 +395,7 @@ class Seq2SeqAgent(BaseAgent):
                 logit_REF.masked_fill_(candidate_mask_obj, -float('inf'))
 
             # Supervised training
-            target = self._teacher_action(perm_obs, ended, candidate_mask.size(1))
+            target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
             ml_loss += self.criterion(logit, target)
 
             # Determine next model inputs
@@ -400,7 +423,8 @@ class Seq2SeqAgent(BaseAgent):
             # NOTE: Env action is in the perm_obs space
             cpu_a_t = a_t.cpu().numpy()
             for i, next_id in enumerate(cpu_a_t):
-                if ((next_id == visual_temp_mask.size(1)) or (t == self.episode_len-1)) and (not ended[i]):    # just stopped and forced stopped
+                if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i]-1)) or (t == self.episode_len-1)) \
+                        and (not ended[i]):    # just stopped and forced stopped
                     just_ended[i] = True
                     if self.feedback == 'argmax':
                         _, ref_t = logit_REF[i].max(0)
@@ -409,8 +433,25 @@ class Seq2SeqAgent(BaseAgent):
                 else:
                     just_ended[i] = False
 
-                if (next_id == visual_temp_mask.size(1)) or (next_id == args.ignoreid) or (ended[i]):    # The last action is <end>
-                    cpu_a_t[i] = -1             # Change the <end> and ignore action to -1
+                if (next_id == args.ignoreid) or (ended[i]):
+                    cpu_a_t[i] = found[i]
+                elif (next_id == visual_temp_mask.size(1)):
+                    cpu_a_t[i] = -1
+                    found[i] = -1
+                elif (next_id == (candidate_leng[i]-1)):
+                    cpu_a_t[i] = -2
+                    found[i] = -2
+            '''
+            print("MODE: ", self.feedback)
+            print("logit: ", logit)
+            print("leng:", candidate_leng)
+            print("cpu_a_t: ", cpu_a_t)
+            if train_ml is not None:
+                print("target: ", target)
+            for i in perm_obs:
+                print(i['found'], i['instructions'])
+            print()
+            '''
 
             ''' Supervised training for REF '''
             if train_ml is not None:
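The _teacher_action and rollout hunks above, together with the reward hunk just below, rest on one convention: the padded last column (cand_size - 1) means "stop, the target was found", the all-ones slot (candidate_leng[i] - 1) means "stop, not found", and the two are folded into the environment-action space as -1 and -2. A compressed, hedged restatement with invented helper names (the real code works on the batched tensors shown in the diff):

```python
# Teacher side: when the expert action is "stay here", the label depends on the
# per-instruction ground-truth 'found' flag.
def teacher_stop_label(found, cand_size, candidate_leng_i):
    return cand_size - 1 if found else candidate_leng_i - 1

# Agent side: a decoded action id is folded into the env-action space.
def fold_stop_action(next_id, stop_found_id, candidate_leng_i):
    if next_id == stop_found_id:         # visual_temp_mask.size(1) in the diff
        return -1                        # "stop, found"
    if next_id == candidate_leng_i - 1:  # the extra all-ones slot
        return -2                        # "stop, not found"
    return next_id                       # ordinary navigation action

# Reward shaping (values copied from the hunk below): a stop within 1 m of the
# goal earns 2 + 2*nDTW, plus 1 if the found/not-found claim matches the ground
# truth and minus 2 if it does not; a distant stop earns -2.
def stop_reward(action_idx, dist_i, ndtw_i, gt_found):
    claimed_found = (action_idx == -1)
    if dist_i < 1.0:
        return 2.0 + ndtw_i * 2.0 + (1.0 if claimed_found == gt_found else -2.0)
    return -2.0

print(teacher_stop_label(True, cand_size=8, candidate_leng_i=5))   # 7
print(teacher_stop_label(False, cand_size=8, candidate_leng_i=5))  # 4
print(stop_reward(-2, dist_i=0.5, ndtw_i=0.9, gt_found=False))     # 4.8
```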
@@ -456,6 +497,19 @@ class Seq2SeqAgent(BaseAgent):
                         # reward[i] = -2.0
                         if dist[i] < 1.0:                             # Correct
                             reward[i] = 2.0 + ndtw_score[i] * 2.0
+                            if ob['found']:
+                                reward[i] += 1
+                            else:
+                                reward[i] -= 2
+                        else:                                         # Incorrect
+                            reward[i] = -2.0
+                    elif action_idx == -2:
+                        if dist[i] < 1.0:                             # Correct
+                            reward[i] = 2.0 + ndtw_score[i] * 2.0
+                            if ob['found']:
+                                reward[i] -= 2
+                            else:
+                                reward[i] += 1
                         else:                                         # Incorrect
                             reward[i] = -2.0
                     else:                                             # The action is not end
@@ -479,9 +533,15 @@
             # Update the finished actions
             # -1 means ended or ignored (already ended)
             ended[:] = np.logical_or(ended, (cpu_a_t == -1))
+            ended[:] = np.logical_or(ended, (cpu_a_t == -2))
 
             # Early exit if all ended
+            target_found = [ (-1 if i['found'] else -2) for i in perm_obs ]
             if ended.all():
+                '''
+                if train_ml is None:
+                    print(target_found, found)
+                '''
                 break
 
             if train_rl:
@@ -563,7 +623,8 @@
 
         # import pdb; pdb.set_trace()
 
-        return traj
+
+        return traj, found
 
     def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
         ''' Evaluate once on each instruction in the current environment '''
diff --git a/r2r_src/env.py b/r2r_src/env.py
index 364548e..ef2b72c 100644
--- a/r2r_src/env.py
+++ b/r2r_src/env.py
@@ -115,6 +115,7 @@ class R2RBatch():
                 new_item = dict(item)
                 new_item['instr_id'] = '%s_%d' % (item['id'], j)
                 new_item['instructions'] = instr
+                new_item['found'] = item['found'][j]
 
                 ''' BERT tokenizer '''
                 instr_tokens = tokenizer.tokenize(instr)
@@ -332,7 +333,8 @@ class R2RBatch():
                 'gt_path' : item['path'],
                 'path_id' : item['id'],
                 'objId': str(item['objId']),  # target objId
-                'candidate_obj': (obj_local_pos, obj_features, candidate_objId)
+                'candidate_obj': (obj_local_pos, obj_features, candidate_objId),
+                'found': item['found']
             })
             if 'instr_encoding' in item:
                 obs[-1]['instr_encoding'] = item['instr_encoding']
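The env.py hunks above assume each dataset item carries a per-instruction 'found' annotation indexed in parallel with 'instructions'. A hypothetical item (field values invented; only the shape of the annotation is taken from the diff) and its expansion into per-instruction entries would look roughly like this:

```python
# Invented example item; only 'instructions' and the parallel boolean list
# 'found' matter for the patch above.
item = {
    'id': 12345,
    'scan': 'scan_placeholder',
    'path': ['viewpoint_a', 'viewpoint_b'],
    'objId': 7,
    'instructions': ['Walk to the kitchen and bring back the kettle.',
                     'Check the kitchen; the kettle may not be there.'],
    'found': [True, False],   # one flag per instruction
}

# R2RBatch copies item['found'][j] into each per-instruction entry and later
# exposes it to the agent and evaluator as ob['found'] / gt['found'].
for j, instr in enumerate(item['instructions']):
    new_item = {
        'instr_id': '%s_%d' % (item['id'], j),
        'instructions': instr,
        'found': item['found'][j],
    }
    print(new_item['instr_id'], new_item['found'])
```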
diff --git a/r2r_src/eval.py b/r2r_src/eval.py
index 44ecf2d..1311e10 100644
--- a/r2r_src/eval.py
+++ b/r2r_src/eval.py
@@ -62,11 +62,12 @@ class Evaluation(object):
                     near_d = d
         return near_id
 
-    def _score_item(self, instr_id, path, ref_objId):
+    def _score_item(self, instr_id, path, ref_objId, predict_found):
         ''' Calculate error based on the final position in trajectory, and also
             the closest position (oracle stopping rule).
             The path contains [view_id, angle, vofv] '''
         gt = self.gt[instr_id[:-2]]
+        index = int(instr_id.split('_')[-1])
         start = gt['path'][0]
         assert start == path[0][0], 'Result trajectories should include the start position'
         goal = gt['path'][-1]
@@ -86,6 +87,16 @@ class Evaluation(object):
             self.distances[gt['scan']][start][goal]
         )
+        # print(predict_found, gt['found'], gt['found'][index])
+
+        if gt['found'][index] == True:
+            if predict_found == -1:
+                self.scores['found_count'] += 1
+        else:
+            if predict_found == -2:
+                self.scores['found_count'] += 1
+
+
         # REF success or not
         if ref_objId == str(gt['objId']):
             self.scores['rgs'].append(1)
@@ -111,33 +122,11 @@
         self.scores['oracle_visible'].append(oracle_succ)
 
-        # # if self.scores['nav_errors'][-1] < self.error_margin:
-        # #     print('item', item)
-        # ndtw_path = [k[0] for k in item['trajectory']]
-        # # print('path', ndtw_path)
-        #
-        # path_id = item['instr_id'][:-2]
-        # # print('path id', path_id)
-        # path_scan_id, path_ref = self.scan_gts[path_id]
-        # # print('path_scan_id', path_scan_id)
-        # # print('path_ref', path_ref)
-        #
-        # path_act = []
-        # for jdx, pid in enumerate(ndtw_path):
-        #     if jdx != 0:
-        #         if pid != path_act[-1]:
-        #             path_act.append(pid)
-        #     else:
-        #         path_act.append(pid)
-        # # print('path act', path_act)
-        #
-        # ndtw_score = self.ndtw_criterion[path_scan_id](path_act, path_ref, metric='ndtw')
-        # ndtw_scores.append(ndtw_score)
-        # print('nDTW score: ', np.average(ndtw_scores))
 
     def score(self, output_file):
         ''' Evaluate each agent trajectory based on how close it got to the goal location '''
         self.scores = defaultdict(list)
+        self.scores['found_count'] = 0
         instr_ids = set(self.instr_ids)
         if type(output_file) is str:
             with open(output_file) as f:
@@ -145,12 +134,15 @@
         else:
             results = output_file
 
+        print('result length', len(results))
+        path_counter = 0
         for item in results:
             # Check against expected ids
             if item['instr_id'] in instr_ids:
                 instr_ids.remove(item['instr_id'])
-                self._score_item(item['instr_id'], item['trajectory'], item['ref'])
+                self._score_item(item['instr_id'], item['trajectory'], item['ref'], item['found'])
+                path_counter += 1
 
         if 'train' not in self.splits:  # Exclude the training from this. (Because training eval may be partial)
             assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
@@ -159,9 +151,11 @@
         score_summary = {
             'steps': np.average(self.scores['trajectory_steps']),
-            'lengths': np.average(self.scores['trajectory_lengths'])
+            'lengths': np.average(self.scores['trajectory_lengths']),
+            'found_score': self.scores['found_count'] / path_counter
         }
         end_successes = sum(self.scores['visible'])
+        score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible']))
         oracle_successes = sum(self.scores['oracle_visible'])
         score_summary['oracle_rate'] = float(oracle_successes) / float(len(self.scores['oracle_visible']))
@@ -172,7 +166,10 @@
         ]
         score_summary['spl'] = np.average(spl)
 
-        assert len(self.scores['rgs']) == len(self.instr_ids)
+        try:
+            assert len(self.scores['rgs']) == len(self.instr_ids)
+        except:
+            print(len(self.scores['rgs']), len(self.instr_ids))
         num_rgs = sum(self.scores['rgs'])
         score_summary['rgs'] = float(num_rgs)/float(len(self.scores['rgs']))
diff --git a/r2r_src/train.py b/r2r_src/train.py
index 722aba0..fc7e151 100644
--- a/r2r_src/train.py
+++ b/r2r_src/train.py
@@ -218,7 +218,8 @@ def train_val(test_only=False):
         val_env_names = ['val_train_seen']
     else:
         featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
-        val_env_names = ['val_seen', 'val_unseen']
+        # val_env_names = ['val_seen', 'val_unseen']
+        val_env_names = ['train', 'val_unseen']
         # val_env_names = ['val_unseen']
 
     train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
@@ -242,6 +243,7 @@
     if args.train == 'listener':
         train(train_env, tok, args.iters, val_envs=val_envs)
+        # train(train_env, tok, 1000, val_envs=val_envs, log_every=10)
     elif args.train == 'validlistener':
         valid(train_env, tok, val_envs=val_envs)
     else:
diff --git a/r2r_src/utils.py b/r2r_src/utils.py
index f964082..00394df 100644
--- a/r2r_src/utils.py
+++ b/r2r_src/utils.py
@@ -582,7 +582,7 @@ def print_progress(iteration, total, prefix='', suffix='', decimals=1, bar_lengt
     str_format = "{0:." + str(decimals) + "f}"
    percents = str_format.format(100 * (iteration / float(total)))
    filled_length = int(round(bar_length * iteration / float(total)))
-    bar = 'LL' * filled_length + '-' * (bar_length - filled_length)
+    bar = 'L' * filled_length + '-' * (bar_length - filled_length)
 
     sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)),
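Taken together, the eval.py changes add a found-agreement metric: a trajectory counts when the predicted stop type (-1 for a "found" stop, -2 for a "not found" stop) matches the ground-truth annotation, and found_score is that count divided by the number of scored trajectories. A compressed sketch of that bookkeeping, with simplified names rather than the repo's Evaluation class:

```python
def found_score(predictions):
    """predictions: iterable of (predict_found, gt_found) pairs, where
    predict_found is -1 (stopped claiming "found") or -2 ("not found")."""
    agree, total = 0, 0
    for predict_found, gt_found in predictions:
        total += 1
        if gt_found and predict_found == -1:
            agree += 1
        elif not gt_found and predict_found == -2:
            agree += 1
    return agree / total if total else 0.0

print(found_score([(-1, True), (-2, False), (-1, False)]))  # 2/3 ~= 0.667
```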