feat: complete adversarial version
This commit is contained in:
parent
7fab347934
commit
ad72df7970
@ -36,7 +36,7 @@ class BaseAgent(object):
|
||||
json.dump(output, f)
|
||||
|
||||
def get_results(self):
|
||||
output = [{'instr_id': k, 'trajectory': v, 'predObjId': r} for k, (v,r) in self.results.items()]
|
||||
output = [{'instr_id': k, 'trajectory': v, 'predObjId': r, 'found': found} for k, (v,r, found) in self.results.items()]
|
||||
return output
|
||||
|
||||
def rollout(self, **args):
|
||||
@ -57,17 +57,19 @@ class BaseAgent(object):
|
||||
if iters is not None:
|
||||
# For each time, it will run the first 'iters' iterations. (It was shuffled before)
|
||||
for i in range(iters):
|
||||
for traj in self.rollout(**kwargs):
|
||||
trajs, found = self.rollout(**kwargs)
|
||||
for index, traj in enumerate(trajs):
|
||||
self.loss = 0
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'])
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
|
||||
else: # Do a full round
|
||||
while True:
|
||||
for traj in self.rollout(**kwargs):
|
||||
trajs, found = self.rollout(**kwargs)
|
||||
for index, traj in enumerate(trajs):
|
||||
if traj['instr_id'] in self.results:
|
||||
looped = True
|
||||
else:
|
||||
self.loss = 0
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'])
|
||||
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
|
||||
if looped:
|
||||
break
|
||||
|
||||
@ -169,8 +171,14 @@ class Seq2SeqAgent(BaseAgent):
|
||||
for i, ob in enumerate(obs):
|
||||
for j, cc in enumerate(ob['candidate']):
|
||||
candidate_feat[i, j, :] = cc['feature']
|
||||
result = torch.from_numpy(candidate_feat)
|
||||
'''
|
||||
for i, ob in enumerate(obs):
|
||||
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
|
||||
'''
|
||||
result = result.cuda()
|
||||
|
||||
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
|
||||
return result, candidate_leng
|
||||
|
||||
def _object_variable(self, obs):
|
||||
cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs] # +1 is for no REF
|
||||
@ -202,7 +210,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
|
||||
|
||||
def _teacher_action(self, obs, ended, cand_size):
|
||||
def _teacher_action(self, obs, ended, cand_size, candidate_leng):
|
||||
"""
|
||||
Extract teacher actions into variable.
|
||||
:param obs: The observation.
|
||||
@ -221,6 +229,12 @@ class Seq2SeqAgent(BaseAgent):
|
||||
else: # Stop here
|
||||
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
|
||||
a[i] = cand_size - 1
|
||||
'''
|
||||
if ob['found']:
|
||||
a[i] = cand_size - 1
|
||||
else:
|
||||
a[i] = candidate_leng[i] - 1
|
||||
'''
|
||||
return torch.from_numpy(a).cuda()
|
||||
|
||||
def _teacher_REF(self, obs, just_ended):
|
||||
@ -232,8 +246,12 @@ class Seq2SeqAgent(BaseAgent):
|
||||
candidate_objs = ob['candidate_obj'][2]
|
||||
for k, kid in enumerate(candidate_objs):
|
||||
if kid == ob['objId']:
|
||||
a[i] = k
|
||||
break
|
||||
if ob['found']:
|
||||
a[i] = k
|
||||
break
|
||||
else:
|
||||
a[i] = len(candidate_objs)
|
||||
break
|
||||
else:
|
||||
a[i] = args.ignoreid
|
||||
return torch.from_numpy(a).cuda()
|
||||
@ -256,7 +274,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
for i, idx in enumerate(perm_idx):
|
||||
action = a_t[i]
|
||||
if action != -1: # -1 is the <stop> action
|
||||
if action != -1 and action != -2: # -1 is the <stop> action
|
||||
select_candidate = perm_obs[i]['candidate'][action]
|
||||
src_point = perm_obs[i]['viewIndex']
|
||||
trg_point = select_candidate['pointId']
|
||||
@ -296,6 +314,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
else:
|
||||
obs = np.array(self.env._get_obs())
|
||||
|
||||
|
||||
batch_size = len(obs)
|
||||
|
||||
# Reorder the language input for the encoder (do not ruin the original code)
|
||||
@ -334,6 +353,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# Initialization the tracking state
|
||||
ended = np.array([False] * batch_size) # Indices match permuation of the model, not env
|
||||
just_ended = np.array([False] * batch_size)
|
||||
found = np.array([None] * batch_size)
|
||||
|
||||
# Init the logs
|
||||
rewards = []
|
||||
@ -398,7 +418,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
if train_ml is not None:
|
||||
# Supervised training
|
||||
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1))
|
||||
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
|
||||
ml_loss += self.criterion(logit, target)
|
||||
|
||||
# Determine next model inputs
|
||||
@ -424,12 +444,15 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# NOTE: Env action is in the perm_obs space
|
||||
cpu_a_t = a_t.cpu().numpy()
|
||||
for i, next_id in enumerate(cpu_a_t):
|
||||
if ((next_id == visual_temp_mask.size(1)) or (t == self.episode_len-1)) and (not ended[i]): # just stopped and forced stopped
|
||||
if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i]-1)) or (t == self.episode_len-1)) \
|
||||
and (not ended[i]): # just stoppped and forced stopped
|
||||
just_ended[i] = True
|
||||
if self.feedback == 'argmax':
|
||||
_, ref_t = logit_REF[i].max(0)
|
||||
if ref_t != obj_leng[i]-1: # decide not to do REF
|
||||
traj[i]['predObjId'] = perm_obs[i]['candidate_obj'][2][ref_t]
|
||||
else:
|
||||
traj[i]['ref'] = 'NOT_FOUND'
|
||||
|
||||
if args.submit:
|
||||
if obj_leng[i] == 1:
|
||||
@ -443,8 +466,18 @@ class Seq2SeqAgent(BaseAgent):
|
||||
else:
|
||||
just_ended[i] = False
|
||||
|
||||
if (next_id == visual_temp_mask.size(1)) or (next_id == args.ignoreid) or (ended[i]): # The last action is <end>
|
||||
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
|
||||
|
||||
if (next_id == args.ignoreid) or (ended[i]):
|
||||
cpu_a_t[i] = found[i]
|
||||
elif (next_id == visual_temp_mask.size(1)):
|
||||
cpu_a_t[i] = -1
|
||||
found[i] = -1
|
||||
if self.feedback == 'argmax':
|
||||
_, ref_t = logit_REF[1].max(0)
|
||||
if ref_t == obj_leng[i]-1:
|
||||
found[i] = -2
|
||||
else:
|
||||
found[i] = -1
|
||||
|
||||
''' Supervised training for REF '''
|
||||
if train_ml is not None:
|
||||
@ -600,7 +633,7 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
return traj
|
||||
return traj, found
|
||||
|
||||
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
|
||||
''' Evaluate once on each instruction in the current environment '''
|
||||
|
||||
@ -127,6 +127,7 @@ class R2RBatch():
|
||||
new_item = dict(item)
|
||||
new_item['instr_id'] = '%s_%d' % (item['id'], j)
|
||||
new_item['instructions'] = instr
|
||||
new_item['found'] = item['found'][j]
|
||||
|
||||
''' BERT tokenizer '''
|
||||
instr_tokens = tokenizer.tokenize(instr)
|
||||
@ -339,7 +340,8 @@ class R2RBatch():
|
||||
'gt_path' : item['path'],
|
||||
'path_id' : item['id'],
|
||||
'objId': str(item['objId']) if 'objId' in item else str(None), # target objId
|
||||
'candidate_obj': (obj_local_pos[:args.maxObject], obj_features[:args.maxObject], candidate_objId[:args.maxObject])
|
||||
'candidate_obj': (obj_local_pos[:args.maxObject], obj_features[:args.maxObject], candidate_objId[:args.maxObject]),
|
||||
'found': item['found']
|
||||
})
|
||||
if 'instr_encoding' in item:
|
||||
obs[-1]['instr_encoding'] = item['instr_encoding']
|
||||
|
||||
@ -50,11 +50,12 @@ class Evaluation(object):
|
||||
near_d = d
|
||||
return near_id
|
||||
|
||||
def _score_item(self, instr_id, path, ref_objId):
|
||||
def _score_item(self, instr_id, path, ref_objId, predict_found):
|
||||
''' Calculate error based on the final position in trajectory, and also
|
||||
the closest position (oracle stopping rule).
|
||||
The path contains [view_id, angle, vofv] '''
|
||||
gt = self.gt[instr_id[:-2]] # pathId_objId
|
||||
index = int(instr_id.split('_')[-1])
|
||||
start = gt['path'][0]
|
||||
assert start == path[0][0], 'Result trajectories should include the start position'
|
||||
goal = gt['path'][-1]
|
||||
@ -74,6 +75,19 @@ class Evaluation(object):
|
||||
self.distances[gt['scan']][start][goal]
|
||||
)
|
||||
|
||||
if gt['found'][index] == True:
|
||||
if predict_found == -1:
|
||||
self.scores['found_count'] += 1
|
||||
self.scores['foundable'].append(1)
|
||||
else:
|
||||
self.scores['foundable'].append(0)
|
||||
else:
|
||||
if predict_found == -2:
|
||||
self.scores['found_count'] += 1
|
||||
self.scores['foundable'].append(1)
|
||||
else:
|
||||
self.scores['foundable'].append(0)
|
||||
|
||||
# REF success or not
|
||||
if (ref_objId == str(gt.get('objId', 0))) or (ref_objId == gt.get('objId', 0)):
|
||||
self.scores['rgs'].append(1)
|
||||
@ -104,6 +118,8 @@ class Evaluation(object):
|
||||
def score(self, output_file):
|
||||
''' Evaluate each agent trajectory based on how close it got to the goal location '''
|
||||
self.scores = defaultdict(list)
|
||||
self.scores['found_count'] = 0
|
||||
self.scores['foundable'] = []
|
||||
instr_ids = set(self.instr_ids)
|
||||
if type(output_file) is str:
|
||||
with open(output_file) as f:
|
||||
@ -112,11 +128,13 @@ class Evaluation(object):
|
||||
results = output_file
|
||||
|
||||
print('result length', len(results))
|
||||
path_counter = 0
|
||||
for item in results:
|
||||
# Check against expected ids
|
||||
if item['instr_id'] in instr_ids:
|
||||
instr_ids.remove(item['instr_id'])
|
||||
self._score_item(item['instr_id'], item['trajectory'], item['predObjId'])
|
||||
self._score_item(item['instr_id'], item['trajectory'], item['predObjId'], item['found'])
|
||||
path_counter += 1
|
||||
|
||||
if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial)
|
||||
assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
|
||||
@ -125,7 +143,8 @@ class Evaluation(object):
|
||||
|
||||
score_summary = {
|
||||
'steps': np.average(self.scores['trajectory_steps']),
|
||||
'lengths': np.average(self.scores['trajectory_lengths'])
|
||||
'lengths': np.average(self.scores['trajectory_lengths']),
|
||||
'found_score': self.scores['found_count'] / path_counter
|
||||
}
|
||||
end_successes = sum(self.scores['visible'])
|
||||
score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible']))
|
||||
@ -137,8 +156,18 @@ class Evaluation(object):
|
||||
zip(self.scores['visible'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
|
||||
]
|
||||
score_summary['spl'] = np.average(spl)
|
||||
# sspl
|
||||
sspl = [float( foundable == 1) * float( visible == 1 ) * l / max(l, p, 0.01)
|
||||
for foundable, visible, p, l in
|
||||
zip(self.scores['foundable'], self.scores['visible'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
|
||||
]
|
||||
score_summary['sspl'] = np.average(sspl)
|
||||
|
||||
assert len(self.scores['rgs']) == len(self.instr_ids)
|
||||
try:
|
||||
assert len(self.scores['rgs']) == len(self.instr_ids)
|
||||
except:
|
||||
print(len(self.scores['rgs']), len(self.instr_ids))
|
||||
num_rgs = sum(self.scores['rgs'])
|
||||
score_summary['rgs'] = float(num_rgs) / float(len(self.scores['rgs']))
|
||||
|
||||
|
||||
@ -110,14 +110,14 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
|
||||
score_summary, _ = evaluator.score(result)
|
||||
loss_str += ", %s " % env_name
|
||||
for metric, val in score_summary.items():
|
||||
if metric in ['spl']:
|
||||
writer.add_scalar("spl/%s" % env_name, val, idx)
|
||||
if metric in ['sspl']:
|
||||
writer.add_scalar("sspl/%s" % env_name, val, idx)
|
||||
if env_name in best_val:
|
||||
if val > best_val[env_name]['spl']:
|
||||
best_val[env_name]['spl'] = val
|
||||
if val > best_val[env_name]['sspl']:
|
||||
best_val[env_name]['sspl'] = val
|
||||
best_val[env_name]['update'] = True
|
||||
elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
|
||||
best_val[env_name]['spl'] = val
|
||||
elif (val == best_val[env_name]['sspl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
|
||||
best_val[env_name]['sspl'] = val
|
||||
best_val[env_name]['update'] = True
|
||||
loss_str += ', %s: %.4f' % (metric, val)
|
||||
|
||||
@ -236,6 +236,7 @@ def train_val(test_only=False):
|
||||
|
||||
if args.train == 'listener':
|
||||
train(train_env, tok, args.iters, log_every=args.log_every, val_envs=val_envs)
|
||||
# train(train_env, tok, args.iters, log_every=100, val_envs=val_envs)
|
||||
elif args.train == 'validlistener':
|
||||
valid(train_env, tok, val_envs=val_envs)
|
||||
else:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
export AIRBERT_ROOT=$(pwd)
|
||||
export PYTHONPATH=${PYTHONPATH}:${AIRBERT_ROOT}/build
|
||||
|
||||
name=REVERIE-RC-VLN-BERT-original/train-init.airbert
|
||||
name=REVERIE-RC-VLN-BERT-original/train-init.airbert-ver2
|
||||
|
||||
flag="--vlnbert vilbert
|
||||
|
||||
@ -13,7 +13,7 @@ flag="--vlnbert vilbert
|
||||
--features places365
|
||||
--maxAction 15
|
||||
--maxInput 50
|
||||
--batchSize 4
|
||||
--batchSize 8
|
||||
--feedback sample
|
||||
--lr 1e-5
|
||||
--iters 200000
|
||||
|
||||
Loading…
Reference in New Issue
Block a user