feat: complete adversarial version

parent 7fab347934
commit ad72df7970
@@ -36,7 +36,7 @@ class BaseAgent(object):
             json.dump(output, f)
 
     def get_results(self):
-        output = [{'instr_id': k, 'trajectory': v, 'predObjId': r} for k, (v,r) in self.results.items()]
+        output = [{'instr_id': k, 'trajectory': v, 'predObjId': r, 'found': found} for k, (v, r, found) in self.results.items()]
         return output
 
     def rollout(self, **args):
@@ -57,17 +57,19 @@ class BaseAgent(object):
         if iters is not None:
             # For each time, it will run the first 'iters' iterations. (It was shuffled before)
             for i in range(iters):
-                for traj in self.rollout(**kwargs):
+                trajs, found = self.rollout(**kwargs)
+                for index, traj in enumerate(trajs):
                     self.loss = 0
-                    self.results[traj['instr_id']] = (traj['path'], traj['predObjId'])
+                    self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
         else:   # Do a full round
             while True:
-                for traj in self.rollout(**kwargs):
+                trajs, found = self.rollout(**kwargs)
+                for index, traj in enumerate(trajs):
                     if traj['instr_id'] in self.results:
                         looped = True
                     else:
                         self.loss = 0
-                        self.results[traj['instr_id']] = (traj['path'], traj['predObjId'])
+                        self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
                 if looped:
                     break
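Note on the two hunks above: `rollout` is now expected to return a pair `(trajs, found)` instead of a bare list of trajectories, and each `self.results` entry grows a third element that `get_results` serialises as a `found` field. A minimal sketch of that data flow; the instruction ids, paths, and found codes below are invented for illustration only:

# Hypothetical two-trajectory batch; only the shapes matter here.
trajs = [{'instr_id': '4182_2', 'path': [('v1', 0.0, 0.0)], 'predObjId': '12'},
         {'instr_id': '4182_3', 'path': [('v7', 0.0, 0.0)], 'predObjId': '5'}]
found = [-1, -2]                      # one stop-time code per trajectory

results = {}
for index, traj in enumerate(trajs):
    results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])

# get_results() unpacks the 3-tuple and writes the extra field alongside it.
output = [{'instr_id': k, 'trajectory': v, 'predObjId': r, 'found': f}
          for k, (v, r, f) in results.items()]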
@@ -169,8 +171,14 @@ class Seq2SeqAgent(BaseAgent):
         for i, ob in enumerate(obs):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
+        result = torch.from_numpy(candidate_feat)
+        '''
+        for i, ob in enumerate(obs):
+            result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
+        '''
+        result = result.cuda()
 
-        return torch.from_numpy(candidate_feat).cuda(), candidate_leng
+        return result, candidate_leng
 
     def _object_variable(self, obs):
         cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs]      # +1 is for no REF
@@ -202,7 +210,7 @@ class Seq2SeqAgent(BaseAgent):
 
         return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
 
-    def _teacher_action(self, obs, ended, cand_size):
+    def _teacher_action(self, obs, ended, cand_size, candidate_leng):
         """
         Extract teacher actions into variable.
         :param obs: The observation.
@@ -221,6 +229,12 @@ class Seq2SeqAgent(BaseAgent):
                 else:   # Stop here
                     assert ob['teacher'] == ob['viewpoint']         # The teacher action should be "STAY HERE"
                     a[i] = cand_size - 1
+                    '''
+                    if ob['found']:
+                        a[i] = cand_size - 1
+                    else:
+                        a[i] = candidate_leng[i] - 1
+                    '''
         return torch.from_numpy(a).cuda()
 
     def _teacher_REF(self, obs, just_ended):
@@ -232,8 +246,12 @@ class Seq2SeqAgent(BaseAgent):
                 candidate_objs = ob['candidate_obj'][2]
                 for k, kid in enumerate(candidate_objs):
                     if kid == ob['objId']:
-                        a[i] = k
-                        break
+                        if ob['found']:
+                            a[i] = k
+                            break
+                        else:
+                            a[i] = len(candidate_objs)
+                            break
             else:
                 a[i] = args.ignoreid
         return torch.from_numpy(a).cuda()
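Reading this hunk together with `_object_variable`, where `cand_obj_leng` reserves one extra slot per observation for "no REF", the teacher target for an adversarial item (one whose `found` flag is False) appears to be that extra slot at index `len(candidate_objs)`. A small self-contained sketch of that rule; the object ids, the target, and the helper name `teacher_ref` are all invented for illustration:

# Invented candidate list: index 3 (== len) plays the role of the "no REF" slot.
candidate_objs = ['31', '57', '103']
target_objId = '57'
IGNORE_ID = -100                       # stand-in for args.ignoreid

def teacher_ref(found):
    for k, kid in enumerate(candidate_objs):
        if kid == target_objId:
            return k if found else len(candidate_objs)
    return IGNORE_ID

assert teacher_ref(found=True) == 1    # ground the referred object
assert teacher_ref(found=False) == 3   # point at the "no REF" slot instead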
@@ -256,7 +274,7 @@ class Seq2SeqAgent(BaseAgent):
 
         for i, idx in enumerate(perm_idx):
             action = a_t[i]
-            if action != -1:            # -1 is the <stop> action
+            if action != -1 and action != -2:            # -1 is the <stop> action
                 select_candidate = perm_obs[i]['candidate'][action]
                 src_point = perm_obs[i]['viewIndex']
                 trg_point = select_candidate['pointId']
@@ -296,6 +314,7 @@ class Seq2SeqAgent(BaseAgent):
         else:
             obs = np.array(self.env._get_obs())
 
+
         batch_size = len(obs)
 
         # Reorder the language input for the encoder (do not ruin the original code)
@@ -334,6 +353,7 @@ class Seq2SeqAgent(BaseAgent):
         # Initialization the tracking state
         ended = np.array([False] * batch_size)  # Indices match permuation of the model, not env
         just_ended = np.array([False] * batch_size)
+        found = np.array([None] * batch_size)
 
         # Init the logs
         rewards = []
@@ -398,7 +418,7 @@ class Seq2SeqAgent(BaseAgent):
 
             if train_ml is not None:
                 # Supervised training
-                target = self._teacher_action(perm_obs, ended, candidate_mask.size(1))
+                target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
                 ml_loss += self.criterion(logit, target)
 
             # Determine next model inputs
@@ -424,12 +444,15 @@ class Seq2SeqAgent(BaseAgent):
             # NOTE: Env action is in the perm_obs space
             cpu_a_t = a_t.cpu().numpy()
             for i, next_id in enumerate(cpu_a_t):
-                if ((next_id == visual_temp_mask.size(1)) or (t == self.episode_len-1)) and (not ended[i]):    # just stopped and forced stopped
+                if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i]-1)) or (t == self.episode_len-1)) \
+                        and (not ended[i]):    # just stopped and forced stopped
                     just_ended[i] = True
                     if self.feedback == 'argmax':
                         _, ref_t = logit_REF[i].max(0)
                         if ref_t != obj_leng[i]-1:   # decide not to do REF
                             traj[i]['predObjId'] = perm_obs[i]['candidate_obj'][2][ref_t]
+                        else:
+                            traj[i]['ref'] = 'NOT_FOUND'
 
                     if args.submit:
                         if obj_leng[i] == 1:
@@ -443,8 +466,18 @@ class Seq2SeqAgent(BaseAgent):
                 else:
                     just_ended[i] = False
 
-                if (next_id == visual_temp_mask.size(1)) or (next_id == args.ignoreid) or (ended[i]):    # The last action is <end>
-                    cpu_a_t[i] = -1             # Change the <end> and ignore action to -1
+                if (next_id == args.ignoreid) or (ended[i]):
+                    cpu_a_t[i] = found[i]
+                elif (next_id == visual_temp_mask.size(1)):
+                    cpu_a_t[i] = -1
+                    found[i] = -1
+                    if self.feedback == 'argmax':
+                        _, ref_t = logit_REF[1].max(0)
+                        if ref_t == obj_leng[i]-1:
+                            found[i] = -2
+                        else:
+                            found[i] = -1
+
 
             ''' Supervised training for REF '''
             if train_ml is not None:
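As the stop-handling above reads, `found[i]` encodes the agent's stop-time claim: -1 when the REF head grounds some object (the agent claims the target is present), -2 when it selects the extra "no object" slot (`ref_t == obj_leng[i]-1`); the evaluation code later compares these codes against the ground-truth `found` flag. One detail worth flagging: `logit_REF[1].max(0)` indexes position 1 where the earlier branch uses `logit_REF[i]`, which looks unintended. A hedged, condensed restatement of the encoding, with `stop_code` as a hypothetical helper name:

def stop_code(ref_t, num_objs):
    """-2 if the REF head picked the extra 'no object' slot, else -1."""
    return -2 if ref_t == num_objs - 1 else -1

assert stop_code(ref_t=3, num_objs=4) == -2   # agent declares "not found"
assert stop_code(ref_t=0, num_objs=4) == -1   # agent declares "found"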
@@ -600,7 +633,7 @@ class Seq2SeqAgent(BaseAgent):
 
         # import pdb; pdb.set_trace()
 
-        return traj
+        return traj, found
 
     def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
         ''' Evaluate once on each instruction in the current environment '''
@@ -127,6 +127,7 @@ class R2RBatch():
                     new_item = dict(item)
                     new_item['instr_id'] = '%s_%d' % (item['id'], j)
                     new_item['instructions'] = instr
+                    new_item['found'] = item['found'][j]
 
                     ''' BERT tokenizer '''
                     instr_tokens = tokenizer.tokenize(instr)
@@ -339,7 +340,8 @@ class R2RBatch():
                 'gt_path' : item['path'],
                 'path_id' : item['id'],
                 'objId': str(item['objId']) if 'objId' in item else str(None),  # target objId
-                'candidate_obj': (obj_local_pos[:args.maxObject], obj_features[:args.maxObject], candidate_objId[:args.maxObject])
+                'candidate_obj': (obj_local_pos[:args.maxObject], obj_features[:args.maxObject], candidate_objId[:args.maxObject]),
+                'found': item['found']
             })
             if 'instr_encoding' in item:
                 obs[-1]['instr_encoding'] = item['instr_encoding']
@@ -50,11 +50,12 @@ class Evaluation(object):
                 near_d = d
         return near_id
 
-    def _score_item(self, instr_id, path, ref_objId):
+    def _score_item(self, instr_id, path, ref_objId, predict_found):
         ''' Calculate error based on the final position in trajectory, and also
            the closest position (oracle stopping rule).
            The path contains [view_id, angle, vofv] '''
         gt = self.gt[instr_id[:-2]]  # pathId_objId
+        index = int(instr_id.split('_')[-1])
         start = gt['path'][0]
         assert start == path[0][0], 'Result trajectories should include the start position'
         goal = gt['path'][-1]
@@ -74,6 +75,19 @@ class Evaluation(object):
             self.distances[gt['scan']][start][goal]
         )
 
+        if gt['found'][index] == True:
+            if predict_found == -1:
+                self.scores['found_count'] += 1
+                self.scores['foundable'].append(1)
+            else:
+                self.scores['foundable'].append(0)
+        else:
+            if predict_found == -2:
+                self.scores['found_count'] += 1
+                self.scores['foundable'].append(1)
+            else:
+                self.scores['foundable'].append(0)
+
         # REF success or not
         if (ref_objId == str(gt.get('objId', 0))) or (ref_objId == gt.get('objId', 0)):
             self.scores['rgs'].append(1)
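The branch above rewards agreement between the ground-truth `found` flag and the agent's stop-time code: a true label expects -1, a false label expects -2. An equivalent condensed form, with `found_agreement` chosen here purely as an illustrative helper name:

def found_agreement(gt_found, predict_found):
    """1 if the agent's found / not-found call matches the label, else 0."""
    expected = -1 if gt_found else -2
    return int(predict_found == expected)

assert found_agreement(True, -1) == 1
assert found_agreement(False, -2) == 1
assert found_agreement(True, -2) == 0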
@@ -104,6 +118,8 @@ class Evaluation(object):
     def score(self, output_file):
         ''' Evaluate each agent trajectory based on how close it got to the goal location '''
         self.scores = defaultdict(list)
+        self.scores['found_count'] = 0
+        self.scores['foundable'] = []
         instr_ids = set(self.instr_ids)
         if type(output_file) is str:
             with open(output_file) as f:
@@ -112,11 +128,13 @@ class Evaluation(object):
             results = output_file
 
         print('result length', len(results))
+        path_counter = 0
         for item in results:
             # Check against expected ids
             if item['instr_id'] in instr_ids:
                 instr_ids.remove(item['instr_id'])
-                self._score_item(item['instr_id'], item['trajectory'], item['predObjId'])
+                self._score_item(item['instr_id'], item['trajectory'], item['predObjId'], item['found'])
+                path_counter += 1
 
         if 'train' not in self.splits:  # Exclude the training from this. (Because training eval may be partial)
             assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
@@ -125,7 +143,8 @@ class Evaluation(object):
 
         score_summary = {
             'steps': np.average(self.scores['trajectory_steps']),
-            'lengths': np.average(self.scores['trajectory_lengths'])
+            'lengths': np.average(self.scores['trajectory_lengths']),
+            'found_score': self.scores['found_count'] / path_counter
         }
         end_successes = sum(self.scores['visible'])
         score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible']))
@@ -137,8 +156,18 @@ class Evaluation(object):
             zip(self.scores['visible'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
         ]
         score_summary['spl'] = np.average(spl)
+        # sspl
+        sspl = [float(foundable == 1) * float(visible == 1) * l / max(l, p, 0.01)
+            for foundable, visible, p, l in
+            zip(self.scores['foundable'], self.scores['visible'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
+        ]
+        score_summary['sspl'] = np.average(sspl)
 
         assert len(self.scores['rgs']) == len(self.instr_ids)
+        try:
+            assert len(self.scores['rgs']) == len(self.instr_ids)
+        except:
+            print(len(self.scores['rgs']), len(self.instr_ids))
         num_rgs = sum(self.scores['rgs'])
         score_summary['rgs'] = float(num_rgs) / float(len(self.scores['rgs']))
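The new `sspl` term above weights each episode's SPL-style ratio by both navigation success (`visible`) and found-agreement (`foundable`); in the zip, `p` is the executed trajectory length and `l` the shortest-path length. A toy calculation with invented numbers:

l, p = 8.0, 10.0              # shortest-path length vs. executed length
visible, foundable = 1, 1     # success flag and found-agreement flag
sspl_i = float(foundable == 1) * float(visible == 1) * l / max(l, p, 0.01)
print(sspl_i)                 # 0.8, identical to SPL when both flags are 1;
                              # if either flag is 0 the episode contributes 0.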
@@ -110,14 +110,14 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
             score_summary, _ = evaluator.score(result)
             loss_str += ", %s " % env_name
             for metric, val in score_summary.items():
-                if metric in ['spl']:
-                    writer.add_scalar("spl/%s" % env_name, val, idx)
+                if metric in ['sspl']:
+                    writer.add_scalar("sspl/%s" % env_name, val, idx)
                     if env_name in best_val:
-                        if val > best_val[env_name]['spl']:
-                            best_val[env_name]['spl'] = val
+                        if val > best_val[env_name]['sspl']:
+                            best_val[env_name]['sspl'] = val
                             best_val[env_name]['update'] = True
-                        elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
-                            best_val[env_name]['spl'] = val
+                        elif (val == best_val[env_name]['sspl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
+                            best_val[env_name]['sspl'] = val
                             best_val[env_name]['update'] = True
                 loss_str += ', %s: %.4f' % (metric, val)
 
@@ -236,6 +236,7 @@ def train_val(test_only=False):
 
     if args.train == 'listener':
         train(train_env, tok, args.iters, log_every=args.log_every, val_envs=val_envs)
+        # train(train_env, tok, args.iters, log_every=100, val_envs=val_envs)
     elif args.train == 'validlistener':
         valid(train_env, tok, val_envs=val_envs)
     else:
@@ -1,7 +1,7 @@
 export AIRBERT_ROOT=$(pwd)
 export PYTHONPATH=${PYTHONPATH}:${AIRBERT_ROOT}/build
 
-name=REVERIE-RC-VLN-BERT-original/train-init.airbert
+name=REVERIE-RC-VLN-BERT-original/train-init.airbert-ver2
 
 flag="--vlnbert vilbert
 
@@ -13,7 +13,7 @@ flag="--vlnbert vilbert
       --features places365
       --maxAction 15
       --maxInput 50
-      --batchSize 4
+      --batchSize 8
       --feedback sample
       --lr 1e-5
       --iters 200000