Compare commits


3 Commits

6 changed files with 116 additions and 141 deletions

View File

@@ -1,42 +0,0 @@
-import json
-import os
-import re
-
-def remove_non_ascii(text):
-    return re.sub(r'[^\x00-\x7F]', ' ', text)
-
-for file in ['train', 'val_unseen', 'val_seen', 'train_seen', 'test', 'val_train_seen']:
-    print(file)
-    if os.path.isfile('data/adversarial/reverie_{}_fnf.json'.format(file)):
-        with open('data/adversarial/reverie_{}_fnf.json'.format(file)) as fp:
-            data = json.load(fp)
-        result = {}
-        for i in data:
-            instruction_id = i['path_id']
-            if instruction_id not in result:
-                result[instruction_id] = {
-                    'distance': float(i['distance']),
-                    'scan': i['scan'],
-                    'path_id': int(i['path_id']),
-                    'path': i['path'],
-                    'heading': float(i['heading']),
-                    'instructions': [remove_non_ascii(i['instruction'])],
-                    'found': [i['found']],
-                    'id': i['id'],
-                    'objId': i['objId']
-                }
-            else:
-                result[instruction_id]['instructions'].append(remove_non_ascii(i['instruction']))
-                result[instruction_id]['found'].append(i['found'])
-        output = []
-        for k, item in result.items():
-            output.append(item)
-    else:
-        output = []
-    with open('data/adversarial/R2R_{}.json'.format(file), 'w') as fp:
-        json.dump(output, fp)
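For reference, the deleted script above grouped per-instruction REVERIE "fnf" entries by path_id into R2R-style records. One output record would look roughly like the sketch below; the field names come from the script, but every value is hypothetical.

# Hypothetical example of one record written to data/adversarial/R2R_<split>.json
# by the deleted script above; values are made up, only the keys match the code.
example_record = {
    'distance': 9.27,
    'scan': 'SCAN_ID',
    'path_id': 4332,
    'path': ['VIEWPOINT_1', 'VIEWPOINT_2', 'VIEWPOINT_3'],
    'heading': 4.05,
    'instructions': ['Go to the kitchen and find the red mug.'],
    'found': [True],
    'id': 'SOME_REVERIE_ID',
    'objId': 'SOME_OBJECT_ID',
}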

data/adversarial.py Normal file
View File

@@ -0,0 +1,21 @@
+import json
+import sys
+import random
+
+with open(sys.argv[1]) as fp:
+    data = json.load(fp)
+
+for _, d in enumerate(data):
+    swaps = []
+    for index, ins in enumerate(d['instructions']):
+        p = random.random()
+        if p > 0.5:
+            swaps.append(True)
+            d['instructions'][index] += 'This is swap.'
+        else:
+            swaps.append(False)
+    d['swap'] = swaps
+
+print(data)
+
+with open(sys.argv[1], 'w') as fp:
+    json.dump(data, fp)
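data/adversarial.py rewrites the given split file in place: it appends the marker text 'This is swap.' to roughly half of the instructions and records which ones were touched in a per-record 'swap' list aligned with 'instructions'. A minimal check of the output could look like the sketch below; the file path is an assumption, the field names come from the script.

import json

# Hypothetical split path; any R2R-style JSON processed by data/adversarial.py would do.
path = 'data/adversarial/R2R_train.json'

with open(path) as fp:
    data = json.load(fp)

for record in data:
    # After running data/adversarial.py, 'swap' should hold one boolean per instruction,
    # marking which instructions had the marker text appended.
    assert len(record['swap']) == len(record['instructions'])
    for swapped, instruction in zip(record['swap'], record['instructions']):
        if swapped:
            assert instruction.endswith('This is swap.')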

View File

@@ -61,14 +61,16 @@ class BaseAgent(object):
         if iters is not None:
             # For each time, it will run the first 'iters' iterations. (It was shuffled before)
             for i in range(iters):
-                traj, found = self.rollout(**kwargs)
-                for index, traj in enumerate(traj):
+                trajs, found = self.rollout(**kwargs)
+                print(found)
+                for index, traj in enumerate(trajs):
                     self.loss = 0
                     self.results[traj['instr_id']] = (traj['path'], found[index])
         else:   # Do a full round
             while True:
-                traj, found = self.rollout(**kwargs)
-                for index, traj in enumerate(traj):
+                trajs, found = self.rollout(**kwargs)
+                print("FOUND: ", found)
+                for index, traj in enumerate(trajs):
                     if traj['instr_id'] in self.results:
                         looped = True
                     else:
@@ -157,8 +159,9 @@ class Seq2SeqAgent(BaseAgent):
         for i, ob in enumerate(obs):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
-            candidate_feat[i, len(ob['candidate'])+1, :] = np.ones((self.feature_size + args.angle_feat_size))
+            candidate_feat[i, len(ob['candidate']), :] = np.zeros(self.feature_size+args.angle_feat_size, dtype=np.float32)    # <STOP>
+            candidate_feat[i, len(ob['candidate'])+1, :] = np.ones(self.feature_size+args.angle_feat_size, dtype=np.float32)   # <NOT_FOUND>  (add the NOT_FOUND token)

         return torch.from_numpy(candidate_feat).cuda(), candidate_leng
@@ -190,10 +193,11 @@ class Seq2SeqAgent(BaseAgent):
                         break
                 else:   # Stop here
                     assert ob['teacher'] == ob['viewpoint']         # The teacher action should be "STAY HERE"
-                    if ob['found']:
-                        a[i] = len(ob['candidate'])
-                    else:
-                        a[i] = len(ob['candidate'])+1
+                    if ob['swap']:          # the instruction was swapped, so the target is NOT_FOUND
+                        a[i] = len(ob['candidate'])-1
+                    else:                   # STOP
+                        a[i] = len(ob['candidate'])-2
+        print(" ", a)
         return torch.from_numpy(a).cuda()

     def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None, found=None):
@@ -212,6 +216,7 @@ class Seq2SeqAgent(BaseAgent):
         for i, idx in enumerate(perm_idx):
             action = a_t[i]
+            # print('action: ', action)
             if action != -1 and action != -2:   # -1 is the <stop> action
                 select_candidate = perm_obs[i]['candidate'][action]
                 src_point = perm_obs[i]['viewIndex']
@@ -235,11 +240,18 @@ class Seq2SeqAgent(BaseAgent):
                 # print("action: {} view_index: {}".format(action, state.viewIndex))
                 if traj is not None:
                     traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
-            elif action == -1 or action == -2:
-                if found is not None:
-                    found[i] = action
+            else:
+                found[i] = action
+                '''
+                elif action == -1:
+                    print('<STOP>')
+                elif action == -2:
+                    print('<NOT_FOUND>')
+                '''

     def rollout(self, train_ml=None, train_rl=True, reset=True):
         """
         :param train_ml:    The weight to train with maximum likelihood
@@ -248,6 +260,7 @@ class Seq2SeqAgent(BaseAgent):
         :return:
         """
+        print("ROLLOUT!!!")
         if self.feedback == 'teacher' or self.feedback == 'argmax':
             train_rl = False
@@ -263,8 +276,10 @@ class Seq2SeqAgent(BaseAgent):
         # Language input
         sentence, language_attention_mask, token_type_ids, \
             seq_lengths, perm_idx = self._sort_batch(obs)
         perm_obs = obs[perm_idx]

         ''' Language BERT '''
         language_inputs = {'mode': 'language',
                            'sentence': sentence,
@@ -282,7 +297,8 @@ class Seq2SeqAgent(BaseAgent):
             'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
         } for ob in perm_obs]
-        found = [ None for _ in range(len(perm_obs)) ]
+        found = [None for _ in range(len(perm_obs))]

         # Init the reward shaping
         last_dist = np.zeros(batch_size, np.float32)
@@ -307,15 +323,6 @@ class Seq2SeqAgent(BaseAgent):
             input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)

-            '''
-            # show feature
-            for index, feat in enumerate(candidate_feat):
-                for ff in feat:
-                    print(ff)
-                print(candidate_leng[index])
-                print()
-            '''

             # the first [CLS] token, initialized by the language BERT, serves
             # as the agent's state passing through time steps
@@ -341,30 +348,37 @@ class Seq2SeqAgent(BaseAgent):
             # Mask outputs where agent can't move forward
             # Here the logit is [b, max_candidate]
+            # (8, max(candidate))
             candidate_mask = utils.length2mask(candidate_leng)
             logit.masked_fill_(candidate_mask, -float('inf'))

             # Supervised training
             target = self._teacher_action(perm_obs, ended)
+            for i, d in enumerate(target):
+                # print(perm_obs[i]['swap'], perm_obs[i]['instructions'])
+                # print(d)
+                _, at_t = logit.max(1)
+                '''
+                if at_t[i].item() == candidate_leng[i]-1:
+                    print("-2")
+                elif at_t[i].item() == candidate_leng[i]-2:
+                    print("-1")
+                else:
+                    print(at_t[i].item())
+                print()
+                '''
             ml_loss += self.criterion(logit, target)
+            a_predict = None
+            '''
+            for index, mask in enumerate(candidate_mask):
+                print(mask)
+                print(candidate_leng[index])
+                print(logit[index])
+                print(target[index])
+                print("\n\n")
+            '''

             # Determine next model inputs
             if self.feedback == 'teacher':
                 a_t = target                 # teacher forcing
+                _, a_predict = logit.max(1)
+                a_predict = a_predict.detach()
             elif self.feedback == 'argmax':
                 _, a_t = logit.max(1)        # student forcing - argmax
                 a_t = a_t.detach()
+                a_predict = a_t.detach()
                 log_probs = F.log_softmax(logit, 1)                              # Calculate the log_prob here
                 policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1)))   # Gather the log_prob for each batch
             elif self.feedback == 'sample':
@@ -372,39 +386,42 @@ class Seq2SeqAgent(BaseAgent):
                 c = torch.distributions.Categorical(probs)
                 self.logs['entropy'].append(c.entropy().sum().item())      # For log
                 entropys.append(c.entropy())                                # For optimization
-                a_t = c.sample().detach()
+                new_c = c.sample()
+                a_t = new_c.detach()
+                a_predict = new_c.detach()
                 policy_log_probs.append(c.log_prob(a_t))
             else:
-                print(self.feedback)
+                # print(self.feedback)
                 sys.exit('Invalid feedback option')

             # Prepare environment action
             # NOTE: Env action is in the perm_obs space
             cpu_a_t = a_t.cpu().numpy()
             for i, next_id in enumerate(cpu_a_t):
-                if next_id == (args.ignoreid) or ended[i]:
-                    cpu_a_t[i] = found[i]
+                if next_id == args.ignoreid or ended[i]:
+                    if found[i] == True:
+                        cpu_a_t[i] = -1             # Change the <end> and ignore action to -1
+                    else:
+                        cpu_a_t[i] = -2
                 elif next_id == (candidate_leng[i]-2):
-                    cpu_a_t[i] = -1
+                    cpu_a_t[i] = -1                 # Change the <end> and ignore action to -1
                 elif next_id == (candidate_leng[i]-1):
                     cpu_a_t[i] = -2

-            # Make action and get the new state
-            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found)
-            '''
-            print(self.feedback, end=' ')
-            print(cpu_a_t, end=' ')
-            for i in perm_obs:
-                print(i['found'], end=' ')
-            print(found)
-            print()
-            '''
+            cpu_a_predict = a_predict.cpu().numpy()
+            for i, next_id in enumerate(cpu_a_predict):
+                if next_id == (candidate_leng[i]-2):
+                    cpu_a_predict[i] = -1           # Change the <end> and ignore action to -1
+                elif next_id == (candidate_leng[i]-1):
+                    cpu_a_predict[i] = -2
+
+            # Make action and get the new state
+            print(cpu_a_t)
+            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found=found)
+
             obs = np.array(self.env._get_obs())
             perm_obs = obs[perm_idx]                    # Perm the obs for the resu

-            '''
             if train_rl:
                 # Calculate the mask and reward
                 dist = np.zeros(batch_size, np.float32)
@@ -425,22 +442,22 @@ class Seq2SeqAgent(BaseAgent):
                         if action_idx == -1:                              # If the action now is end
                             if dist[i] < 3.0:                             # Correct
                                 reward[i] = 2.0 + ndtw_score[i] * 2.0
-                                if ob['found']:
-                                    reward[i] += 1
-                                else:
+                                if ob['swap']:
                                     reward[i] -= 2
+                                else:
+                                    reward[i] += 1
                             else:                                         # Incorrect
                                 reward[i] = -2.0
-                        elif action_idx == -2:
+                        elif action_idx == -2:                            # the NOT_FOUND reward is set here
                             if dist[i] < 3.0:
                                 reward[i] = 2.0 + ndtw_score[i] * 2.0
-                                if ob['found']:
-                                    reward[i] -= 2
+                                if ob['swap']:
+                                    reward[i] += 3                        # a swapped instruction was detected, so add a bonus point
                                 else:
-                                    reward[i] += 1
+                                    reward[i] -= 2
                             else:                                         # Incorrect
                                 reward[i] = -2.0
+                                reward[i] += 1                            # distance > 3: the object really was not found, so the penalty goes from -2 to -1
                         else:                                             # The action is not end
                             # Path fidelity rewards (distance & nDTW)
                             reward[i] = - (dist[i] - last_dist[i])
@@ -458,7 +475,6 @@ class Seq2SeqAgent(BaseAgent):
                 masks.append(mask)
                 last_dist[:] = dist
                 last_ndtw[:] = ndtw_score
-            '''

             # Update the finished actions
             # -1 means ended or ignored (already ended)
@@ -469,7 +485,8 @@ class Seq2SeqAgent(BaseAgent):
             if ended.all():
                 break

-        '''
+        # print()
         if train_rl:
             # Last action in A2C
             input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
@@ -480,6 +497,7 @@ class Seq2SeqAgent(BaseAgent):
             visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)

             self.vln_bert.vln_bert.config.directions = max(candidate_leng)
+            ''' Visual BERT '''
             visual_inputs = {'mode': 'visual',
                              'sentence': language_features,
                              'attention_mask': visual_attention_mask,
@@ -530,7 +548,6 @@ class Seq2SeqAgent(BaseAgent):
             self.loss += rl_loss
             self.logs['RL_loss'].append(rl_loss.item())
-        '''

         if train_ml is not None:
             self.loss += ml_loss * train_ml / batch_size
@@ -540,6 +557,7 @@ class Seq2SeqAgent(BaseAgent):
             self.losses.append(0.)
         else:
             self.losses.append(self.loss.item() / self.episode_len)    # This argument is useless.

+        print("\n\n")
         return traj, found
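Reading the agent.py hunks above together: each candidate list is padded with two special slots, and the predicted index is then translated to a symbolic environment action, with candidate_leng[i]-2 becoming -1 (<STOP>) and candidate_leng[i]-1 becoming -2 (<NOT_FOUND>). Below is a minimal sketch of that mapping; the helper name is hypothetical and not part of the diff.

def to_env_action(next_id, candidate_len, ignore_id=None, ended=False, found=None):
    """Map a predicted candidate index to the environment action convention used in the diff.

    candidate_len is the padded candidate length for this observation
    (candidate_leng[i] in the diff); its last two slots are the special tokens.
    """
    if (ignore_id is not None and next_id == ignore_id) or ended:
        # Finished or ignored episodes keep their recorded outcome:
        # -1 (<STOP>) if the instruction was judged found, otherwise -2 (<NOT_FOUND>).
        return -1 if found else -2
    if next_id == candidate_len - 2:
        return -1   # <STOP>
    if next_id == candidate_len - 1:
        return -2   # <NOT_FOUND>
    return next_id  # an actual navigable candidate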

View File

@@ -1,6 +1,8 @@
 ''' Batched Room-to-Room navigation environment '''
 import sys
+from networkx.algorithms import swap
 sys.path.append('buildpy36')
 sys.path.append('Matterport_Simulator/build/')
 import MatterSim
@@ -14,6 +16,7 @@ import os
 import random
 import networkx as nx
 from param import args
+import time

 from utils import load_datasets, load_nav_graphs, pad_instr_tokens
 from IPython import embed
@@ -127,7 +130,7 @@ class R2RBatch():
                             new_item = dict(item)
                             new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
                             new_item['instructions'] = instr
-                            new_item['found'] = item['found'][j]
+                            new_item['swap'] = item['swap'][j]

                             ''' BERT tokenizer '''
                             instr_tokens = tokenizer.tokenize(instr)
@@ -137,10 +140,12 @@ class R2RBatch():
                             if new_item['instr_encoding'] is not None:  # Filter the wrong data
                                 self.data.append(new_item)
                                 scans.append(item['scan'])
                         except:
                             continue

         print("split {} has {} datas in the file.".format(split, max_len))

         if name is None:
             self.name = splits[0] if len(splits) > 0 else "FAKE"
         else:
@@ -329,7 +334,6 @@ class R2RBatch():
             # [visual_feature, angle_feature] for views
             feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)
             obs.append({
                 'instr_id' : item['instr_id'],
                 'scan' : state.scanId,
@@ -344,7 +348,7 @@ class R2RBatch():
                 'teacher' : self._shortest_path_action(state, item['path'][-1]),
                 'gt_path' : item['path'],
                 'path_id' : item['path_id'],
-                'found': item['found']
+                'swap': item['swap']
             })
             if 'instr_encoding' in item:
                 obs[-1]['instr_encoding'] = item['instr_encoding']

View File

@@ -55,16 +55,11 @@ class Evaluation(object):
                     near_d = d
         return near_id

-    def _score_item(self, instr_id, path, predict_found):
+    def _score_item(self, instr_id, path):
         ''' Calculate error based on the final position in trajectory, and also
             the closest position (oracle stopping rule).
             The path contains [view_id, angle, vofv] '''
         gt = self.gt[instr_id.split('_')[-2]]
-        index = int(instr_id.split('_')[-1])
-        gt_instruction = gt['instructions'][index]
-        gt_found = gt['found'][index]
         start = gt['path'][0]
         assert start == path[0][0], 'Result trajectories should include the start position'
         goal = gt['path'][-1]
@@ -73,19 +68,6 @@ class Evaluation(object):
         self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal])
         self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal])
         self.scores['trajectory_steps'].append(len(path)-1)

-        # <STOP> <NOT_FOUND> score
-        score = 0
-        if gt_found == True:
-            if predict_found == -1:
-                score = 1
-        else:
-            if predict_found == -2:
-                score = 1
-        self.scores['found_count'] += score

         distance = 0    # length of the path in meters
         prev = path[0]
         for curr in path[1:]:
@@ -99,7 +81,6 @@ class Evaluation(object):
     def score(self, output_file):
         ''' Evaluate each agent trajectory based on how close it got to the goal location '''
         self.scores = defaultdict(list)
-        self.scores['found_count'] = 0
         instr_ids = set(self.instr_ids)
         if type(output_file) is str:
             with open(output_file) as f:
@@ -109,14 +90,12 @@ class Evaluation(object):
         # print('result length', len(results))
         # print("RESULT:", results)

-        path_counter = 0
         for item in results:
             # Check against expected ids
             if item['instr_id'] in instr_ids:
                 # print("{} exist".format(item['instr_id']))
                 instr_ids.remove(item['instr_id'])
-                self._score_item(item['instr_id'], item['trajectory'], item['found'])
-                path_counter += 1
+                self._score_item(item['instr_id'], item['trajectory'])
             else:
                 print("{} not exist".format(item['instr_id']))
                 print(item)
@@ -129,8 +108,7 @@ class Evaluation(object):
             'nav_error': np.average(self.scores['nav_errors']),
             'oracle_error': np.average(self.scores['oracle_errors']),
             'steps': np.average(self.scores['trajectory_steps']),
-            'lengths': np.average(self.scores['trajectory_lengths']),
-            'found_score': self.scores['found_count'] / path_counter
+            'lengths': np.average(self.scores['trajectory_lengths'])
         }
         num_successes = len([i for i in self.scores['nav_errors'] if i < self.error_margin])
         score_summary['success_rate'] = float(num_successes)/float(len(self.scores['nav_errors']))

View File

@@ -105,9 +105,6 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
         # Run validation
         loss_str = "iter {}".format(iter)

-        save_results = []
         for env_name, (env, evaluator) in val_envs.items():
             listner.env = env
@@ -115,8 +112,6 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
             listner.test(use_dropout=False, feedback='argmax', iters=None)
             result = listner.get_results()
             score_summary, _ = evaluator.score(result)
-            print(score_summary)

             loss_str += ", %s " % env_name
             for metric, val in score_summary.items():
                 if metric in ['spl']:
@@ -200,11 +195,12 @@ def train_val(test_only=False):
     if test_only:
         featurized_scans = None
-        val_env_names = ['val_unseen']
+        val_env_names = ['val_train_seen']
     else:
         featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
         # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
-        val_env_names = ['train','val_unseen']
+        # val_env_names = ['val_train_seen']
+        val_env_names = ['val_unseen']

     train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
     from collections import OrderedDict