Compare commits

6 commits: main ... adversaria

| Author | SHA1 | Date |
|---|---|---|
| | 4073c52bb8 | |
| | 595866c2f4 | |
| | 03a3e5b489 | |
| | 4936098b5e | |
| | a5db597de5 | |
| | ab5010d32d | |
adversarial_summary.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import json
import os
import re


def remove_non_ascii(text):
    return re.sub(r'[^\x00-\x7F]', ' ', text)


for file in ['train', 'val_unseen', 'val_seen', 'train_seen', 'test', 'val_train_seen']:
    print(file)
    if os.path.isfile('data/adversarial/reverie_{}_fnf.json'.format(file)):
        with open('data/adversarial/reverie_{}_fnf.json'.format(file)) as fp:
            data = json.load(fp)


        result = {}
        for i in data:
            instruction_id = i['path_id']
            if instruction_id not in result:
                result[instruction_id] = {
                    'distance': float(i['distance']),
                    'scan': i['scan'],
                    'path_id': int(i['path_id']),
                    'path': i['path'],
                    'heading': float(i['heading']),
                    'instructions': [ remove_non_ascii(i['instruction'])],
                    'found': [ i['found'] ],
                    'id': i['id'],
                    'objId': i['objId']
                }
            else:
                result[instruction_id]['instructions'].append(remove_non_ascii(i['instruction']))
                result[instruction_id]['found'].append( i['found'] )

        output = []
        for k, item in result.items():
            output.append(item)
    else:
        output = []

    with open('data/adversarial/R2R_{}.json'.format(file), 'w') as fp:
        json.dump(output, fp)
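The script above groups per-instruction REVERIE records by `path_id` and writes one entry per path, with parallel `instructions` and `found` lists. A minimal sketch (not part of the commit) of how one generated `data/adversarial/R2R_{split}.json` file could be sanity-checked, assuming the field names used above:

```python
import json

def check_split(path):
    """Sanity-check one generated split file (illustrative only)."""
    with open(path) as fp:
        items = json.load(fp)
    for item in items:
        # Each grouped entry keeps one 'found' flag per instruction,
        # so the two lists must stay aligned.
        assert len(item['instructions']) == len(item['found'])
        assert all(isinstance(t, str) for t in item['instructions'])
    print('{}: {} paths, {} instructions'.format(
        path, len(items), sum(len(i['instructions']) for i in items)))

if __name__ == '__main__':
    check_split('data/adversarial/R2R_val_unseen.json')  # example split
```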
r2r_src/agent.py (101 changed lines)
@@ -35,12 +35,12 @@ class BaseAgent(object):
        self.losses = [] # For learning agents

    def write_results(self):
-        output = [{'instr_id':k, 'trajectory': v} for k,v in self.results.items()]
+        output = [{'instr_id':k, 'trajectory': v[0], 'found': v[1]} for k,v in self.results.items()]
        with open(self.results_path, 'w') as f:
            json.dump(output, f)

    def get_results(self):
-        output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
+        output = [{'instr_id': k, 'trajectory': v[0], 'found': v[1]} for k, v in self.results.items()]
        return output

    def rollout(self, **args):
@@ -61,17 +61,19 @@ class BaseAgent(object):
        if iters is not None:
            # For each time, it will run the first 'iters' iterations. (It was shuffled before)
            for i in range(iters):
-                for traj in self.rollout(**kwargs):
+                traj, found = self.rollout(**kwargs)
+                for index, traj in enumerate(traj):
                    self.loss = 0
-                    self.results[traj['instr_id']] = traj['path']
+                    self.results[traj['instr_id']] = (traj['path'], found[index])
        else:   # Do a full round
            while True:
-                for traj in self.rollout(**kwargs):
+                traj, found = self.rollout(**kwargs)
+                for index, traj in enumerate(traj):
                    if traj['instr_id'] in self.results:
                        looped = True
                    else:
                        self.loss = 0
-                        self.results[traj['instr_id']] = traj['path']
+                        self.results[traj['instr_id']] = (traj['path'], found[index])
                if looped:
                    break
@@ -147,7 +149,7 @@ class Seq2SeqAgent(BaseAgent):
        return Variable(torch.from_numpy(features), requires_grad=False).cuda()

    def _candidate_variable(self, obs):
-        candidate_leng = [len(ob['candidate']) + 1 for ob in obs]  # +1 is for the end
+        candidate_leng = [len(ob['candidate']) + 2 for ob in obs]  # +1 is for the end
        candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)

        # Note: The candidate_feat at len(ob['candidate']) is the feature for the END
@@ -155,6 +157,8 @@ class Seq2SeqAgent(BaseAgent):
        for i, ob in enumerate(obs):
            for j, cc in enumerate(ob['candidate']):
                candidate_feat[i, j, :] = cc['feature']
+            candidate_feat[i, len(ob['candidate']), :] = np.zeros(self.feature_size+args.angle_feat_size, dtype=np.float32)  # <STOP>
+            candidate_feat[i, len(ob['candidate'])+1, :] = np.ones(self.feature_size+args.angle_feat_size, dtype=np.float32)  # <NOT_FOUND>

        return torch.from_numpy(candidate_feat).cuda(), candidate_leng
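For context on `_candidate_variable`: after this change each observation reserves two extra slots beyond its K real candidates, an all-zeros `<STOP>` vector at index K and an all-ones `<NOT_FOUND>` vector at index K+1, so `candidate_leng` becomes K+2. A standalone sketch with toy sizes (the real model uses `feature_size + args.angle_feat_size` dimensions):

```python
import numpy as np

feature_dim = 8        # toy size, stands in for feature_size + angle_feat_size
num_candidates = 3     # K navigable candidates for one observation

candidate_leng = num_candidates + 2
candidate_feat = np.zeros((candidate_leng, feature_dim), dtype=np.float32)
candidate_feat[:num_candidates] = np.random.rand(num_candidates, feature_dim)  # real candidates
candidate_feat[num_candidates] = 0.0       # <STOP> slot (all zeros)
candidate_feat[num_candidates + 1] = 1.0   # <NOT_FOUND> slot (all ones)

print(candidate_feat.shape)  # (5, 8): K candidates + <STOP> + <NOT_FOUND>
```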
@@ -186,10 +190,13 @@ class Seq2SeqAgent(BaseAgent):
                        break
                else:   # Stop here
                    assert ob['teacher'] == ob['viewpoint']      # The teacher action should be "STAY HERE"
-                    a[i] = len(ob['candidate'])
+                    if ob['found']:
+                        a[i] = len(ob['candidate'])
+                    else:
+                        a[i] = len(ob['candidate'])+1
        return torch.from_numpy(a).cuda()

-    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
+    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None, found=None):
        """
        Interface between Panoramic view and Egocentric view
        It will convert the action panoramic view action a_t to equivalent egocentric view actions for the simulator
@@ -205,7 +212,7 @@ class Seq2SeqAgent(BaseAgent):

        for i, idx in enumerate(perm_idx):
            action = a_t[i]
-            if action != -1:            # -1 is the <stop> action
+            if action != -1 and action != -2:            # -1 is the <stop> action
                select_candidate = perm_obs[i]['candidate'][action]
                src_point = perm_obs[i]['viewIndex']
                trg_point = select_candidate['pointId']
@@ -228,6 +235,10 @@ class Seq2SeqAgent(BaseAgent):
                # print("action: {} view_index: {}".format(action, state.viewIndex))
                if traj is not None:
                    traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
+            elif action == -1 or action == -2:
+                if found is not None:
+                    found[i] = action

    def rollout(self, train_ml=None, train_rl=True, reset=True):
        """
@@ -271,6 +282,8 @@ class Seq2SeqAgent(BaseAgent):
            'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
        } for ob in perm_obs]

+        found = [ None for _ in range(len(perm_obs)) ]
+
        # Init the reward shaping
        last_dist = np.zeros(batch_size, np.float32)
        last_ndtw = np.zeros(batch_size, np.float32)
@@ -294,9 +307,15 @@ class Seq2SeqAgent(BaseAgent):

            input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)

+            print("input_a_t: ", input_a_t.shape)
+            print("candidate_feat: ", candidate_feat.shape)
+            print("candidate_leng: ", candidate_leng)
+            '''
+            # show feature
+            for index, feat in enumerate(candidate_feat):
+                for ff in feat:
+                    print(ff)
+                print(candidate_leng[index])
+                print()
+            '''

            # the first [CLS] token, initialized by the language BERT, serves
            # as the agent's state passing through time steps
@@ -329,7 +348,18 @@ class Seq2SeqAgent(BaseAgent):
            target = self._teacher_action(perm_obs, ended)
            ml_loss += self.criterion(logit, target)

+            '''
+            for index, mask in enumerate(candidate_mask):
+                print(mask)
+                print(candidate_leng[index])
+                print(logit[index])
+                print(target[index])
+                print("\n\n")
+            '''

            # Determine next model inputs
            if self.feedback == 'teacher':
                a_t = target                 # teacher forcing
            elif self.feedback == 'argmax':
@@ -347,18 +377,34 @@ class Seq2SeqAgent(BaseAgent):
            else:
                print(self.feedback)
                sys.exit('Invalid feedback option')

            # Prepare environment action
            # NOTE: Env action is in the perm_obs space
            cpu_a_t = a_t.cpu().numpy()
            for i, next_id in enumerate(cpu_a_t):
-                if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]:    # The last action is <end>
-                    cpu_a_t[i] = -1             # Change the <end> and ignore action to -1
+                if next_id == (args.ignoreid) or ended[i]:
+                    cpu_a_t[i] = found[i]
+                elif next_id == (candidate_leng[i]-2):
+                    cpu_a_t[i] = -1
+                elif next_id == (candidate_leng[i]-1):
+                    cpu_a_t[i] = -2

            # Make action and get the new state
-            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
+            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found)

+            '''
+            print(self.feedback, end=' ')
+            print(cpu_a_t, end=' ')
+            for i in perm_obs:
+                print(i['found'], end=' ')
+            print(found)
+            print()
+            '''
            obs = np.array(self.env._get_obs())
            perm_obs = obs[perm_idx]            # Perm the obs for the resu

+            '''
            if train_rl:
                # Calculate the mask and reward
                dist = np.zeros(batch_size, np.float32)
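The re-coding loop above maps the model's chosen slot onto environment actions: the `<STOP>` slot (`candidate_leng-2`) becomes `-1`, the `<NOT_FOUND>` slot (`candidate_leng-1`) becomes `-2`, and ignored or already-ended entries keep whatever was previously recorded in `found[i]`. A compact restatement of that rule as a sketch (`IGNOREID` stands in for `args.ignoreid`):

```python
IGNOREID = -100  # placeholder for args.ignoreid

def recode(next_id, cand_len, ended, prev_found):
    """Map a chosen candidate slot to the action code used downstream."""
    if next_id == IGNOREID or ended:
        return prev_found            # keep the earlier stop decision
    if next_id == cand_len - 2:
        return -1                    # stopped and declared "found"
    if next_id == cand_len - 1:
        return -2                    # stopped and declared "not found"
    return next_id                   # ordinary navigation action

print(recode(3, 5, False, None))     # -1: <STOP> slot of a 3-candidate observation
print(recode(4, 5, False, None))     # -2: <NOT_FOUND> slot
print(recode(1, 5, False, None))     #  1: move to candidate 1
```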
@@ -379,6 +425,20 @@ class Seq2SeqAgent(BaseAgent):
                        if action_idx == -1:        # If the action now is end
                            if dist[i] < 3.0:       # Correct
                                reward[i] = 2.0 + ndtw_score[i] * 2.0
+                                if ob['found']:
+                                    reward[i] += 1
+                                else:
+                                    reward[i] -= 2
                            else:                   # Incorrect
                                reward[i] = -2.0

+                        elif action_idx == -2:
+                            if dist[i] < 3.0:
+                                reward[i] = 2.0 + ndtw_score[i] * 2.0
+                                if ob['found']:
+                                    reward[i] -= 2
+                                else:
+                                    reward[i] += 1
+                            else:                   # Incorrect
+                                reward[i] = -2.0
                        else:                       # The action is not end
@@ -398,17 +458,18 @@ class Seq2SeqAgent(BaseAgent):
                masks.append(mask)
                last_dist[:] = dist
                last_ndtw[:] = ndtw_score
+            '''

            # Update the finished actions
            # -1 means ended or ignored (already ended)
            ended[:] = np.logical_or(ended, (cpu_a_t == -1))
+            ended[:] = np.logical_or(ended, (cpu_a_t == -2))

            # Early exit if all ended
            if ended.all():
                break

        print()

+        '''
        if train_rl:
            # Last action in A2C
            input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
@@ -419,7 +480,6 @@ class Seq2SeqAgent(BaseAgent):
            visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)

            self.vln_bert.vln_bert.config.directions = max(candidate_leng)
            ''' Visual BERT '''
            visual_inputs = {'mode': 'visual',
                             'sentence': language_features,
                             'attention_mask': visual_attention_mask,
@@ -470,6 +530,7 @@ class Seq2SeqAgent(BaseAgent):

            self.loss += rl_loss
            self.logs['RL_loss'].append(rl_loss.item())
+        '''

        if train_ml is not None:
            self.loss += ml_loss * train_ml / batch_size
@@ -480,7 +541,7 @@ class Seq2SeqAgent(BaseAgent):
        else:
            self.losses.append(self.loss.item() / self.episode_len)    # This argument is useless.

-        return traj
+        return traj, found

    def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
        ''' Evaluate once on each instruction in the current environment '''
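With `rollout` now returning `(traj, found)` and `BaseAgent` storing a `(path, found)` pair per instruction, `write_results` emits a `found` field alongside each trajectory. A sketch of one resulting record, with hypothetical values and the `-1`/`-2` stop encoding from above:

```python
import json

record = {
    'instr_id': '4213_1',   # hypothetical path_id / instruction index
    'trajectory': [['vp_start', 0.0, 0.0], ['vp_goal', 1.57, 0.0]],
    'found': -1,            # agent stopped and claimed the target was found
}
print(json.dumps(record))
```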
@@ -127,6 +127,7 @@ class R2RBatch():
                        new_item = dict(item)
                        new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
                        new_item['instructions'] = instr
+                        new_item['found'] = item['found'][j]

                        ''' BERT tokenizer '''
                        instr_tokens = tokenizer.tokenize(instr)
@@ -328,6 +329,7 @@ class R2RBatch():
            # [visual_feature, angle_feature] for views
            feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)


            obs.append({
                'instr_id' : item['instr_id'],
                'scan' : state.scanId,
@@ -341,7 +343,8 @@ class R2RBatch():
                'instructions' : item['instructions'],
                'teacher' : self._shortest_path_action(state, item['path'][-1]),
                'gt_path' : item['path'],
-                'path_id' : item['path_id']
+                'path_id' : item['path_id'],
+                'found': item['found']
            })
            if 'instr_encoding' in item:
                obs[-1]['instr_encoding'] = item['instr_encoding']
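The `R2RBatch` changes carry the per-instruction `found` flag from the split JSON into every observation. A toy sketch of the expansion step, mirroring the loop in the first hunk above (field names as in the diff, data invented):

```python
item = {
    'path_id': 4213,
    'instructions': ['walk to the lamp', 'go to the bedside lamp'],
    'found': [True, False],
}

expanded = []
for j, instr in enumerate(item['instructions']):
    new_item = dict(item)
    new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
    new_item['instructions'] = instr
    new_item['found'] = item['found'][j]   # one flag per instruction
    expanded.append(new_item)

print([(e['instr_id'], e['found']) for e in expanded])
# [('4213_0', True), ('4213_1', False)]
```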
@@ -55,11 +55,16 @@ class Evaluation(object):
                near_d = d
        return near_id

-    def _score_item(self, instr_id, path):
+    def _score_item(self, instr_id, path, predict_found):
        ''' Calculate error based on the final position in trajectory, and also
            the closest position (oracle stopping rule).
            The path contains [view_id, angle, vofv] '''
        gt = self.gt[instr_id.split('_')[-2]]
+        index = int(instr_id.split('_')[-1])
+
+        gt_instruction = gt['instructions'][index]
+        gt_found = gt['found'][index]

        start = gt['path'][0]
        assert start == path[0][0], 'Result trajectories should include the start position'
        goal = gt['path'][-1]
@@ -68,6 +73,19 @@ class Evaluation(object):
        self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal])
        self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal])
        self.scores['trajectory_steps'].append(len(path)-1)

+        # <STOP> <NOT_FOUND> score
+        score = 0
+        if gt_found == True:
+            if predict_found == -1:
+                score = 1
+        else:
+            if predict_found == -2:
+                score = 1
+        self.scores['found_count'] += score


        distance = 0 # length of the path in meters
        prev = path[0]
        for curr in path[1:]:
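The block above scores the stop decision: a point is earned when the predicted stop type matches the ground truth (`-1` for instructions whose target exists, `-2` otherwise); `found_count / path_counter` is reported later in the diff as `found_score`. The rule as a small self-contained check:

```python
def found_point(gt_found, predict_found):
    """1 if the agent's stop type matches the ground-truth 'found' flag."""
    return int((gt_found and predict_found == -1) or
               (not gt_found and predict_found == -2))

assert found_point(True, -1) == 1
assert found_point(True, -2) == 0
assert found_point(False, -2) == 1
assert found_point(False, -1) == 0
```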
@@ -81,6 +99,7 @@ class Evaluation(object):
    def score(self, output_file):
        ''' Evaluate each agent trajectory based on how close it got to the goal location '''
        self.scores = defaultdict(list)
+        self.scores['found_count'] = 0
        instr_ids = set(self.instr_ids)
        if type(output_file) is str:
            with open(output_file) as f:
@@ -90,12 +109,14 @@ class Evaluation(object):

        # print('result length', len(results))
        # print("RESULT:", results)
+        path_counter = 0
        for item in results:
            # Check against expected ids
            if item['instr_id'] in instr_ids:
                # print("{} exist".format(item['instr_id']))
                instr_ids.remove(item['instr_id'])
-                self._score_item(item['instr_id'], item['trajectory'])
+                self._score_item(item['instr_id'], item['trajectory'], item['found'])
+                path_counter += 1
            else:
                print("{} not exist".format(item['instr_id']))
                print(item)
@@ -108,7 +129,8 @@ class Evaluation(object):
            'nav_error': np.average(self.scores['nav_errors']),
            'oracle_error': np.average(self.scores['oracle_errors']),
            'steps': np.average(self.scores['trajectory_steps']),
-            'lengths': np.average(self.scores['trajectory_lengths'])
+            'lengths': np.average(self.scores['trajectory_lengths']),
+            'found_score': self.scores['found_count'] / path_counter
        }
        num_successes = len([i for i in self.scores['nav_errors'] if i < self.error_margin])
        score_summary['success_rate'] = float(num_successes)/float(len(self.scores['nav_errors']))
@@ -105,6 +105,9 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):

    # Run validation
    loss_str = "iter {}".format(iter)

+    save_results = []
    for env_name, (env, evaluator) in val_envs.items():
        listner.env = env
@@ -112,6 +115,8 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
        listner.test(use_dropout=False, feedback='argmax', iters=None)
        result = listner.get_results()
        score_summary, _ = evaluator.score(result)

+        print(score_summary)
        loss_str += ", %s " % env_name
        for metric, val in score_summary.items():
            if metric in ['spl']:
@@ -195,11 +200,11 @@ def train_val(test_only=False):

    if test_only:
        featurized_scans = None
-        val_env_names = ['val_train_seen']
+        val_env_names = ['val_unseen']
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
        # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
-        val_env_names = ['val_train_seen']
+        val_env_names = ['train','val_unseen']

    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
    from collections import OrderedDict