feat: complete adversarial version

Ting-Jun Wang 2024-01-15 17:09:40 +08:00
parent 7fab347934
commit ad72df7970
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
5 changed files with 92 additions and 27 deletions

View File

@@ -36,7 +36,7 @@ class BaseAgent(object):
             json.dump(output, f)

     def get_results(self):
-        output = [{'instr_id': k, 'trajectory': v, 'predObjId': r} for k, (v,r) in self.results.items()]
+        output = [{'instr_id': k, 'trajectory': v, 'predObjId': r, 'found': found} for k, (v,r, found) in self.results.items()]
         return output

     def rollout(self, **args):
@@ -57,17 +57,19 @@ class BaseAgent(object):
         if iters is not None:
             # For each time, it will run the first 'iters' iterations. (It was shuffled before)
             for i in range(iters):
-                for traj in self.rollout(**kwargs):
+                trajs, found = self.rollout(**kwargs)
+                for index, traj in enumerate(trajs):
                     self.loss = 0
-                    self.results[traj['instr_id']] = (traj['path'], traj['predObjId'])
+                    self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
         else:   # Do a full round
             while True:
-                for traj in self.rollout(**kwargs):
+                trajs, found = self.rollout(**kwargs)
+                for index, traj in enumerate(trajs):
                     if traj['instr_id'] in self.results:
                         looped = True
                     else:
                         self.loss = 0
-                        self.results[traj['instr_id']] = (traj['path'], traj['predObjId'])
+                        self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
                 if looped:
                     break
@@ -169,8 +171,14 @@ class Seq2SeqAgent(BaseAgent):
         for i, ob in enumerate(obs):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
+        result = torch.from_numpy(candidate_feat)
+        '''
+        for i, ob in enumerate(obs):
+            result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
+        '''
+        result = result.cuda()

-        return torch.from_numpy(candidate_feat).cuda(), candidate_leng
+        return result, candidate_leng

     def _object_variable(self, obs):
         cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs]    # +1 is for no REF
@@ -202,7 +210,7 @@ class Seq2SeqAgent(BaseAgent):
         return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng

-    def _teacher_action(self, obs, ended, cand_size):
+    def _teacher_action(self, obs, ended, cand_size, candidate_leng):
         """
         Extract teacher actions into variable.
         :param obs: The observation.
@@ -221,6 +229,12 @@ class Seq2SeqAgent(BaseAgent):
                 else:   # Stop here
                     assert ob['teacher'] == ob['viewpoint']    # The teacher action should be "STAY HERE"
                     a[i] = cand_size - 1
+                    '''
+                    if ob['found']:
+                        a[i] = cand_size - 1
+                    else:
+                        a[i] = candidate_leng[i] - 1
+                    '''
         return torch.from_numpy(a).cuda()

     def _teacher_REF(self, obs, just_ended):
@@ -232,8 +246,12 @@ class Seq2SeqAgent(BaseAgent):
                 candidate_objs = ob['candidate_obj'][2]
                 for k, kid in enumerate(candidate_objs):
                     if kid == ob['objId']:
-                        a[i] = k
-                        break
+                        if ob['found']:
+                            a[i] = k
+                            break
+                        else:
+                            a[i] = len(candidate_objs)
+                            break
             else:
                 a[i] = args.ignoreid
         return torch.from_numpy(a).cuda()
@@ -256,7 +274,7 @@ class Seq2SeqAgent(BaseAgent):
         for i, idx in enumerate(perm_idx):
             action = a_t[i]
-            if action != -1:            # -1 is the <stop> action
+            if action != -1 and action != -2:            # -1 is the <stop> action
                 select_candidate = perm_obs[i]['candidate'][action]
                 src_point = perm_obs[i]['viewIndex']
                 trg_point = select_candidate['pointId']
@@ -296,6 +314,7 @@ class Seq2SeqAgent(BaseAgent):
         else:
             obs = np.array(self.env._get_obs())

         batch_size = len(obs)

         # Reorder the language input for the encoder (do not ruin the original code)
@@ -334,6 +353,7 @@ class Seq2SeqAgent(BaseAgent):
         # Initialization the tracking state
         ended = np.array([False] * batch_size)    # Indices match permuation of the model, not env
         just_ended = np.array([False] * batch_size)
+        found = np.array([None] * batch_size)

         # Init the logs
         rewards = []
@@ -398,7 +418,7 @@ class Seq2SeqAgent(BaseAgent):
             if train_ml is not None:
                 # Supervised training
-                target = self._teacher_action(perm_obs, ended, candidate_mask.size(1))
+                target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
                 ml_loss += self.criterion(logit, target)

             # Determine next model inputs
@@ -424,12 +444,15 @@ class Seq2SeqAgent(BaseAgent):
             # NOTE: Env action is in the perm_obs space
             cpu_a_t = a_t.cpu().numpy()
             for i, next_id in enumerate(cpu_a_t):
-                if ((next_id == visual_temp_mask.size(1)) or (t == self.episode_len-1)) and (not ended[i]):    # just stopped and forced stopped
+                if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i]-1)) or (t == self.episode_len-1)) \
+                        and (not ended[i]):    # just stoppped and forced stopped
                     just_ended[i] = True
                     if self.feedback == 'argmax':
                         _, ref_t = logit_REF[i].max(0)
                         if ref_t != obj_leng[i]-1:    # decide not to do REF
                             traj[i]['predObjId'] = perm_obs[i]['candidate_obj'][2][ref_t]
+                        else:
+                            traj[i]['ref'] = 'NOT_FOUND'

                     if args.submit:
                         if obj_leng[i] == 1:
@@ -443,8 +466,18 @@ class Seq2SeqAgent(BaseAgent):
                 else:
                     just_ended[i] = False

-                if (next_id == visual_temp_mask.size(1)) or (next_id == args.ignoreid) or (ended[i]):    # The last action is <end>
-                    cpu_a_t[i] = -1             # Change the <end> and ignore action to -1
+                if (next_id == args.ignoreid) or (ended[i]):
+                    cpu_a_t[i] = found[i]
+                elif (next_id == visual_temp_mask.size(1)):
+                    cpu_a_t[i] = -1
+                    found[i] = -1
+                    if self.feedback == 'argmax':
+                        _, ref_t = logit_REF[1].max(0)
+                        if ref_t == obj_leng[i]-1:
+                            found[i] = -2
+                        else:
+                            found[i] = -1

             ''' Supervised training for REF '''
             if train_ml is not None:
@@ -600,7 +633,7 @@ class Seq2SeqAgent(BaseAgent):
         # import pdb; pdb.set_trace()
-        return traj
+        return traj, found

     def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
         ''' Evaluate once on each instruction in the current environment '''

View File

@@ -127,6 +127,7 @@ class R2RBatch():
                 new_item = dict(item)
                 new_item['instr_id'] = '%s_%d' % (item['id'], j)
                 new_item['instructions'] = instr
+                new_item['found'] = item['found'][j]

                 ''' BERT tokenizer '''
                 instr_tokens = tokenizer.tokenize(instr)
@@ -339,7 +340,8 @@ class R2RBatch():
                 'gt_path' : item['path'],
                 'path_id' : item['id'],
                 'objId': str(item['objId']) if 'objId' in item else str(None),    # target objId
-                'candidate_obj': (obj_local_pos[:args.maxObject], obj_features[:args.maxObject], candidate_objId[:args.maxObject])
+                'candidate_obj': (obj_local_pos[:args.maxObject], obj_features[:args.maxObject], candidate_objId[:args.maxObject]),
+                'found': item['found']
             })
             if 'instr_encoding' in item:
                 obs[-1]['instr_encoding'] = item['instr_encoding']

View File

@@ -50,11 +50,12 @@ class Evaluation(object):
                 near_d = d
         return near_id

-    def _score_item(self, instr_id, path, ref_objId):
+    def _score_item(self, instr_id, path, ref_objId, predict_found):
        ''' Calculate error based on the final position in trajectory, and also
            the closest position (oracle stopping rule).
            The path contains [view_id, angle, vofv] '''
        gt = self.gt[instr_id[:-2]]    # pathId_objId
+       index = int(instr_id.split('_')[-1])
        start = gt['path'][0]
        assert start == path[0][0], 'Result trajectories should include the start position'
        goal = gt['path'][-1]
@@ -74,6 +75,19 @@ class Evaluation(object):
            self.distances[gt['scan']][start][goal]
        )

+       if gt['found'][index] == True:
+           if predict_found == -1:
+               self.scores['found_count'] += 1
+               self.scores['foundable'].append(1)
+           else:
+               self.scores['foundable'].append(0)
+       else:
+           if predict_found == -2:
+               self.scores['found_count'] += 1
+               self.scores['foundable'].append(1)
+           else:
+               self.scores['foundable'].append(0)
+
        # REF success or not
        if (ref_objId == str(gt.get('objId', 0))) or (ref_objId == gt.get('objId', 0)):
            self.scores['rgs'].append(1)
@@ -104,6 +118,8 @@ class Evaluation(object):
    def score(self, output_file):
        ''' Evaluate each agent trajectory based on how close it got to the goal location '''
        self.scores = defaultdict(list)
+       self.scores['found_count'] = 0
+       self.scores['foundable'] = []
        instr_ids = set(self.instr_ids)
        if type(output_file) is str:
            with open(output_file) as f:
@@ -112,11 +128,13 @@ class Evaluation(object):
            results = output_file
        print('result length', len(results))
+       path_counter = 0
        for item in results:
            # Check against expected ids
            if item['instr_id'] in instr_ids:
                instr_ids.remove(item['instr_id'])
-               self._score_item(item['instr_id'], item['trajectory'], item['predObjId'])
+               self._score_item(item['instr_id'], item['trajectory'], item['predObjId'], item['found'])
+               path_counter += 1

        if 'train' not in self.splits:    # Exclude the training from this. (Because training eval may be partial)
            assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
@@ -125,7 +143,8 @@ class Evaluation(object):
        score_summary = {
            'steps': np.average(self.scores['trajectory_steps']),
-           'lengths': np.average(self.scores['trajectory_lengths'])
+           'lengths': np.average(self.scores['trajectory_lengths']),
+           'found_score': self.scores['found_count'] / path_counter
        }

        end_successes = sum(self.scores['visible'])
        score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible']))
@@ -137,8 +156,18 @@ class Evaluation(object):
            zip(self.scores['visible'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
        ]
        score_summary['spl'] = np.average(spl)
+       # sspl
+       sspl = [float( foundable == 1) * float( visible == 1 ) * l / max(l, p, 0.01)
+           for foundable, visible, p, l in
+           zip(self.scores['foundable'], self.scores['visible'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
+       ]
+       score_summary['sspl'] = np.average(sspl)

-       assert len(self.scores['rgs']) == len(self.instr_ids)
+       try:
+           assert len(self.scores['rgs']) == len(self.instr_ids)
+       except:
+           print(len(self.scores['rgs']), len(self.instr_ids))
        num_rgs = sum(self.scores['rgs'])
        score_summary['rgs'] = float(num_rgs) / float(len(self.scores['rgs']))

View File

@@ -110,14 +110,14 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
            score_summary, _ = evaluator.score(result)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
-               if metric in ['spl']:
-                   writer.add_scalar("spl/%s" % env_name, val, idx)
+               if metric in ['sspl']:
+                   writer.add_scalar("sspl/%s" % env_name, val, idx)
                    if env_name in best_val:
-                       if val > best_val[env_name]['spl']:
-                           best_val[env_name]['spl'] = val
+                       if val > best_val[env_name]['sspl']:
+                           best_val[env_name]['sspl'] = val
                            best_val[env_name]['update'] = True
-                       elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
-                           best_val[env_name]['spl'] = val
+                       elif (val == best_val[env_name]['sspl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
+                           best_val[env_name]['sspl'] = val
                            best_val[env_name]['update'] = True
                loss_str += ', %s: %.4f' % (metric, val)
@@ -236,6 +236,7 @@ def train_val(test_only=False):
    if args.train == 'listener':
        train(train_env, tok, args.iters, log_every=args.log_every, val_envs=val_envs)
+       # train(train_env, tok, args.iters, log_every=100, val_envs=val_envs)
    elif args.train == 'validlistener':
        valid(train_env, tok, val_envs=val_envs)
    else:

View File

@@ -1,7 +1,7 @@
 export AIRBERT_ROOT=$(pwd)
 export PYTHONPATH=${PYTHONPATH}:${AIRBERT_ROOT}/build

-name=REVERIE-RC-VLN-BERT-original/train-init.airbert
+name=REVERIE-RC-VLN-BERT-original/train-init.airbert-ver2

 flag="--vlnbert vilbert
@@ -13,7 +13,7 @@ flag="--vlnbert vilbert
      --features places365
      --maxAction 15
      --maxInput 50
-     --batchSize 4
+     --batchSize 8
      --feedback sample
      --lr 1e-5
      --iters 200000