feat: complete adversarial version

Ting-Jun Wang 2024-01-15 17:09:40 +08:00
parent 7fab347934
commit ad72df7970
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
5 changed files with 92 additions and 27 deletions

View File

@@ -36,7 +36,7 @@ class BaseAgent(object):
json.dump(output, f)
def get_results(self):
output = [{'instr_id': k, 'trajectory': v, 'predObjId': r} for k, (v,r) in self.results.items()]
output = [{'instr_id': k, 'trajectory': v, 'predObjId': r, 'found': found} for k, (v,r, found) in self.results.items()]
return output
def rollout(self, **args):
@@ -57,17 +57,19 @@ class BaseAgent(object):
if iters is not None:
# Each call runs the first 'iters' iterations (the data was shuffled beforehand)
for i in range(iters):
for traj in self.rollout(**kwargs):
trajs, found = self.rollout(**kwargs)
for index, traj in enumerate(trajs):
self.loss = 0
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'])
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
else: # Do a full round
while True:
for traj in self.rollout(**kwargs):
trajs, found = self.rollout(**kwargs)
for index, traj in enumerate(trajs):
if traj['instr_id'] in self.results:
looped = True
else:
self.loss = 0
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'])
self.results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
if looped:
break
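The loop above assumes rollout() now returns a (trajectories, found) pair and that self.results stores a 3-tuple per instruction. A minimal sketch of that assumed contract, with hypothetical example values:

    # Hypothetical data illustrating the assumed (trajs, found) return of rollout().
    trajs = [{'instr_id': '4182_2', 'path': [('vp_0', 0.0, 0.0)], 'predObjId': '153'}]
    found = [-1]  # one flag per trajectory; the encoding is assigned inside rollout()
    results = {}
    for index, traj in enumerate(trajs):
        results[traj['instr_id']] = (traj['path'], traj['predObjId'], found[index])
    output = [{'instr_id': k, 'trajectory': v, 'predObjId': r, 'found': f}
              for k, (v, r, f) in results.items()]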
@@ -169,8 +171,14 @@ class Seq2SeqAgent(BaseAgent):
for i, ob in enumerate(obs):
for j, cc in enumerate(ob['candidate']):
candidate_feat[i, j, :] = cc['feature']
result = torch.from_numpy(candidate_feat)
'''
for i, ob in enumerate(obs):
result[i, len(ob['candidate']), :] = torch.ones((self.feature_size + args.angle_feat_size), dtype=torch.float32)
'''
result = result.cuda()
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
return result, candidate_leng
def _object_variable(self, obs):
cand_obj_leng = [len(ob['candidate_obj'][2]) + 1 for ob in obs] # +1 is for no REF
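The commented-out block would overwrite the padding slot at index len(ob['candidate']) with an all-ones vector, apparently reserving it as an explicit "target not here" candidate; it stays disabled in this commit. A standalone sketch of that padding idea, with hypothetical sizes:

    import numpy as np
    import torch

    feature_size, angle_feat_size, max_cand = 2048, 128, 5  # hypothetical dimensions
    candidate_feat = np.zeros((2, max_cand + 1, feature_size + angle_feat_size), dtype=np.float32)
    result = torch.from_numpy(candidate_feat)
    num_real = [3, 4]  # hypothetical number of real candidates per sample
    for i, n in enumerate(num_real):
        # fill the slot right after the real candidates with ones, acting as a
        # dedicated "not found" candidate the policy could select
        result[i, n, :] = torch.ones(feature_size + angle_feat_size, dtype=torch.float32)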
@@ -202,7 +210,7 @@ class Seq2SeqAgent(BaseAgent):
return input_a_t, f_t, candidate_feat, candidate_leng, obj_feat, obj_pos, obj_leng
def _teacher_action(self, obs, ended, cand_size):
def _teacher_action(self, obs, ended, cand_size, candidate_leng):
"""
Extract teacher actions into variable.
:param obs: The observation.
@@ -221,6 +229,12 @@ class Seq2SeqAgent(BaseAgent):
else: # Stop here
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
a[i] = cand_size - 1
'''
if ob['found']:
a[i] = cand_size - 1
else:
a[i] = candidate_leng[i] - 1
'''
return torch.from_numpy(a).cuda()
def _teacher_REF(self, obs, just_ended):
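The commented-out branch in _teacher_action above hints at the adversarial teacher: when the instruction's target actually exists (ob['found']), the teacher keeps the usual stop index cand_size - 1; otherwise it would point at the padded slot candidate_leng[i] - 1. A minimal sketch of that rule, assuming those index conventions:

    import numpy as np

    def teacher_stop_action(found, cand_size, cand_len):
        # Assumed convention: cand_size - 1 is the regular <stop> column,
        # cand_len - 1 is the padded "not found" candidate.
        return cand_size - 1 if found else cand_len - 1

    a = np.array([teacher_stop_action(True, 12, 7), teacher_stop_action(False, 12, 7)])
    print(a)  # [11  6]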
@@ -232,8 +246,12 @@ class Seq2SeqAgent(BaseAgent):
candidate_objs = ob['candidate_obj'][2]
for k, kid in enumerate(candidate_objs):
if kid == ob['objId']:
a[i] = k
break
if ob['found']:
a[i] = k
break
else:
a[i] = len(candidate_objs)
break
else:
a[i] = args.ignoreid
return torch.from_numpy(a).cuda()
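The _teacher_REF change applies the same idea to grounding: if the target object id is among the candidates and the instruction is satisfiable (ob['found']), the label is that object's index; if it is marked unsatisfiable, the label is the extra "no object" slot len(candidate_objs). A small sketch under those assumptions:

    def teacher_ref(candidate_objs, obj_id, found, ignore_id=-100):
        # ignore_id stands in for args.ignoreid (the value here is hypothetical)
        for k, kid in enumerate(candidate_objs):
            if kid == obj_id:
                return k if found else len(candidate_objs)
        return ignore_id

    print(teacher_ref(['12', '37', '88'], '37', found=True))   # 1
    print(teacher_ref(['12', '37', '88'], '37', found=False))  # 3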
@@ -256,7 +274,7 @@ class Seq2SeqAgent(BaseAgent):
for i, idx in enumerate(perm_idx):
action = a_t[i]
if action != -1: # -1 is the <stop> action
if action != -1 and action != -2: # -1 is the <stop> action; -2 is the NOT_FOUND stop
select_candidate = perm_obs[i]['candidate'][action]
src_point = perm_obs[i]['viewIndex']
trg_point = select_candidate['pointId']
@@ -296,6 +314,7 @@ class Seq2SeqAgent(BaseAgent):
else:
obs = np.array(self.env._get_obs())
batch_size = len(obs)
# Reorder the language input for the encoder (do not ruin the original code)
@@ -334,6 +353,7 @@ class Seq2SeqAgent(BaseAgent):
# Initialize the tracking state
ended = np.array([False] * batch_size) # Indices match the permutation of the model, not env
just_ended = np.array([False] * batch_size)
found = np.array([None] * batch_size)
# Init the logs
rewards = []
@@ -398,7 +418,7 @@ class Seq2SeqAgent(BaseAgent):
if train_ml is not None:
# Supervised training
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1))
target = self._teacher_action(perm_obs, ended, candidate_mask.size(1), candidate_leng)
ml_loss += self.criterion(logit, target)
# Determine next model inputs
@@ -424,12 +444,15 @@ class Seq2SeqAgent(BaseAgent):
# NOTE: Env action is in the perm_obs space
cpu_a_t = a_t.cpu().numpy()
for i, next_id in enumerate(cpu_a_t):
if ((next_id == visual_temp_mask.size(1)) or (t == self.episode_len-1)) and (not ended[i]): # just stopped and forced stopped
if ((next_id == visual_temp_mask.size(1)) or (next_id == (candidate_leng[i]-1)) or (t == self.episode_len-1)) \
and (not ended[i]): # just stopped or force-stopped
just_ended[i] = True
if self.feedback == 'argmax':
_, ref_t = logit_REF[i].max(0)
if ref_t != obj_leng[i]-1: # the REF head picked an object (the last index means no REF)
traj[i]['predObjId'] = perm_obs[i]['candidate_obj'][2][ref_t]
else:
traj[i]['ref'] = 'NOT_FOUND'
if args.submit:
if obj_leng[i] == 1:
@@ -443,8 +466,18 @@ class Seq2SeqAgent(BaseAgent):
else:
just_ended[i] = False
if (next_id == visual_temp_mask.size(1)) or (next_id == args.ignoreid) or (ended[i]): # The last action is <end>
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
if (next_id == args.ignoreid) or (ended[i]):
cpu_a_t[i] = found[i]
elif (next_id == visual_temp_mask.size(1)):
cpu_a_t[i] = -1
found[i] = -1
if self.feedback == 'argmax':
_, ref_t = logit_REF[i].max(0)
if ref_t == obj_leng[i]-1:
found[i] = -2
else:
found[i] = -1
''' Supervised training for REF '''
if train_ml is not None:
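In the stop-handling block above, found[i] appears to encode how the episode ended: -1 when the agent stops and grounds an object, -2 when the REF head selects the "no object" slot (a NOT_FOUND claim). A compact restatement of that assumed encoding:

    FOUND, NOT_FOUND = -1, -2  # assumed meaning, inferred from the branch above

    def stop_flag(ref_index, obj_len):
        # ref_index is the argmax over the REF logits; obj_len counts the candidate
        # objects plus the trailing "no object" slot.
        return NOT_FOUND if ref_index == obj_len - 1 else FOUND

    print(stop_flag(1, 4))  # -1 (grounded an object)
    print(stop_flag(3, 4))  # -2 (declared NOT_FOUND)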
@@ -600,7 +633,7 @@ class Seq2SeqAgent(BaseAgent):
# import pdb; pdb.set_trace()
return traj
return traj, found
def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
''' Evaluate once on each instruction in the current environment '''

View File

@@ -127,6 +127,7 @@ class R2RBatch():
new_item = dict(item)
new_item['instr_id'] = '%s_%d' % (item['id'], j)
new_item['instructions'] = instr
new_item['found'] = item['found'][j]
''' BERT tokenizer '''
instr_tokens = tokenizer.tokenize(instr)
@@ -339,7 +340,8 @@ class R2RBatch():
'gt_path' : item['path'],
'path_id' : item['id'],
'objId': str(item['objId']) if 'objId' in item else str(None), # target objId
'candidate_obj': (obj_local_pos[:args.maxObject], obj_features[:args.maxObject], candidate_objId[:args.maxObject])
'candidate_obj': (obj_local_pos[:args.maxObject], obj_features[:args.maxObject], candidate_objId[:args.maxObject]),
'found': item['found']
})
if 'instr_encoding' in item:
obs[-1]['instr_encoding'] = item['instr_encoding']
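Both env.py hunks assume every split entry now carries a 'found' list aligned with its 'instructions'. A hypothetical item illustrating that assumed format, expanded the same way as in the loader above:

    # Hypothetical REVERIE-style entry with the extra per-instruction 'found' flags.
    item = {
        'id': 4182,
        'path': ['vp_0', 'vp_1'],
        'objId': 153,
        'instructions': ['Go to the kitchen and find the mug.',
                         'Walk to the kitchen counter and look for a mug.'],
        'found': [True, False],  # whether each instruction's target actually exists
    }
    for j, instr in enumerate(item['instructions']):
        new_item = dict(item)
        new_item['instr_id'] = '%s_%d' % (item['id'], j)
        new_item['instructions'] = instr
        new_item['found'] = item['found'][j]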

View File

@@ -50,11 +50,12 @@ class Evaluation(object):
near_d = d
return near_id
def _score_item(self, instr_id, path, ref_objId):
def _score_item(self, instr_id, path, ref_objId, predict_found):
''' Calculate error based on the final position in trajectory, and also
the closest position (oracle stopping rule).
The path contains [view_id, angle, vofv] '''
gt = self.gt[instr_id[:-2]] # pathId_objId
index = int(instr_id.split('_')[-1])
start = gt['path'][0]
assert start == path[0][0], 'Result trajectories should include the start position'
goal = gt['path'][-1]
@@ -74,6 +75,19 @@ class Evaluation(object):
self.distances[gt['scan']][start][goal]
)
if gt['found'][index] == True:
if predict_found == -1:
self.scores['found_count'] += 1
self.scores['foundable'].append(1)
else:
self.scores['foundable'].append(0)
else:
if predict_found == -2:
self.scores['found_count'] += 1
self.scores['foundable'].append(1)
else:
self.scores['foundable'].append(0)
# REF success or not
if (ref_objId == str(gt.get('objId', 0))) or (ref_objId == gt.get('objId', 0)):
self.scores['rgs'].append(1)
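The new block above scores the found decision: ground-truth found == True counts as correct only when predict_found == -1, and found == False only when predict_found == -2. The same rule in isolation:

    def found_is_correct(gt_found, predict_found):
        # -1: agent claimed it found the object; -2: agent claimed NOT_FOUND (assumed encoding)
        return (gt_found and predict_found == -1) or (not gt_found and predict_found == -2)

    assert found_is_correct(True, -1)
    assert found_is_correct(False, -2)
    assert not found_is_correct(True, -2)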
@@ -104,6 +118,8 @@ class Evaluation(object):
def score(self, output_file):
''' Evaluate each agent trajectory based on how close it got to the goal location '''
self.scores = defaultdict(list)
self.scores['found_count'] = 0
self.scores['foundable'] = []
instr_ids = set(self.instr_ids)
if type(output_file) is str:
with open(output_file) as f:
@@ -112,11 +128,13 @@ class Evaluation(object):
results = output_file
print('result length', len(results))
path_counter = 0
for item in results:
# Check against expected ids
if item['instr_id'] in instr_ids:
instr_ids.remove(item['instr_id'])
self._score_item(item['instr_id'], item['trajectory'], item['predObjId'])
self._score_item(item['instr_id'], item['trajectory'], item['predObjId'], item['found'])
path_counter += 1
if 'train' not in self.splits: # Exclude the training from this. (Because training eval may be partial)
assert len(instr_ids) == 0, 'Missing %d of %d instruction ids from %s - not in %s'\
@@ -125,7 +143,8 @@ class Evaluation(object):
score_summary = {
'steps': np.average(self.scores['trajectory_steps']),
'lengths': np.average(self.scores['trajectory_lengths'])
'lengths': np.average(self.scores['trajectory_lengths']),
'found_score': self.scores['found_count'] / path_counter
}
end_successes = sum(self.scores['visible'])
score_summary['success_rate'] = float(end_successes) / float(len(self.scores['visible']))
@@ -137,8 +156,18 @@ class Evaluation(object):
zip(self.scores['visible'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
]
score_summary['spl'] = np.average(spl)
# sspl
sspl = [float( foundable == 1) * float( visible == 1 ) * l / max(l, p, 0.01)
for foundable, visible, p, l in
zip(self.scores['foundable'], self.scores['visible'], self.scores['trajectory_lengths'], self.scores['shortest_lengths'])
]
score_summary['sspl'] = np.average(sspl)
assert len(self.scores['rgs']) == len(self.instr_ids)
try:
assert len(self.scores['rgs']) == len(self.instr_ids)
except:
print(len(self.scores['rgs']), len(self.instr_ids))
num_rgs = sum(self.scores['rgs'])
score_summary['rgs'] = float(num_rgs) / float(len(self.scores['rgs']))
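The added sspl term reads as SPL additionally gated by the found decision: an episode contributes shortest / max(shortest, traversed) only when it is both successful (visible) and its found prediction was correct (foundable). A small numeric sketch under that reading:

    import numpy as np

    foundable = [1, 1, 0]
    visible   = [1, 0, 1]
    traversed = [12.0, 9.0, 15.0]  # trajectory_lengths
    shortest  = [10.0, 9.0, 10.0]  # shortest_lengths

    sspl = [float(f == 1) * float(v == 1) * l / max(l, p, 0.01)
            for f, v, p, l in zip(foundable, visible, traversed, shortest)]
    print(np.average(sspl))  # only the first episode counts: (10/12) / 3 ≈ 0.278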

View File

@@ -110,14 +110,14 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
score_summary, _ = evaluator.score(result)
loss_str += ", %s " % env_name
for metric, val in score_summary.items():
if metric in ['spl']:
writer.add_scalar("spl/%s" % env_name, val, idx)
if metric in ['sspl']:
writer.add_scalar("sspl/%s" % env_name, val, idx)
if env_name in best_val:
if val > best_val[env_name]['spl']:
best_val[env_name]['spl'] = val
if val > best_val[env_name]['sspl']:
best_val[env_name]['sspl'] = val
best_val[env_name]['update'] = True
elif (val == best_val[env_name]['spl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
best_val[env_name]['spl'] = val
elif (val == best_val[env_name]['sspl']) and (score_summary['success_rate'] > best_val[env_name]['sr']):
best_val[env_name]['sspl'] = val
best_val[env_name]['update'] = True
loss_str += ', %s: %.4f' % (metric, val)
@@ -236,6 +236,7 @@ def train_val(test_only=False):
if args.train == 'listener':
train(train_env, tok, args.iters, log_every=args.log_every, val_envs=val_envs)
# train(train_env, tok, args.iters, log_every=100, val_envs=val_envs)
elif args.train == 'validlistener':
valid(train_env, tok, val_envs=val_envs)
else:
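Since model selection now keys on best_val[env_name]['sspl'], the best_val dictionary defined earlier in train() presumably needs an 'sspl' field in place of 'spl'. A sketch of an assumed initialization, mirroring the usual 'spl'/'sr' pattern (keys and defaults are hypothetical):

    best_val = {
        'val_seen':   {'sspl': 0., 'sr': 0., 'state': '', 'update': False},
        'val_unseen': {'sspl': 0., 'sr': 0., 'state': '', 'update': False},
    }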

View File

@@ -1,7 +1,7 @@
export AIRBERT_ROOT=$(pwd)
export PYTHONPATH=${PYTHONPATH}:${AIRBERT_ROOT}/build
name=REVERIE-RC-VLN-BERT-original/train-init.airbert
name=REVERIE-RC-VLN-BERT-original/train-init.airbert-ver2
flag="--vlnbert vilbert
@@ -13,7 +13,7 @@ flag="--vlnbert vilbert
--features places365
--maxAction 15
--maxInput 50
--batchSize 4
--batchSize 8
--feedback sample
--lr 1e-5
--iters 200000