feat: add NOT_FOUND action in rollout
commit 03a3e5b489
parent 4936098b5e
@@ -147,7 +147,7 @@ class Seq2SeqAgent(BaseAgent):
         return Variable(torch.from_numpy(features), requires_grad=False).cuda()

     def _candidate_variable(self, obs):
-        candidate_leng = [len(ob['candidate']) + 1 for ob in obs]       # +1 is for the end
+        candidate_leng = [len(ob['candidate']) + 2 for ob in obs]       # +2 is for the end and the not_found
         candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)

         # Note: The candidate_feat at len(ob['candidate']) is the feature for the END
@@ -155,6 +155,8 @@ class Seq2SeqAgent(BaseAgent):
         for i, ob in enumerate(obs):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
+            candidate_feat[i, len(ob['candidate']), :] = np.zeros(self.feature_size+args.angle_feat_size, dtype=np.float32)    # <STOP>
+            candidate_feat[i, len(ob['candidate'])+1, :] = np.ones(self.feature_size+args.angle_feat_size, dtype=np.float32)   # <NOT_FOUND>

         return torch.from_numpy(candidate_feat).cuda(), candidate_leng

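To make the new layout concrete: each row of candidate_feat holds the real candidates first, then an all-zeros row for <STOP> and an all-ones row for <NOT_FOUND>. A minimal, self-contained sketch (toy feature sizes standing in for self.feature_size and args.angle_feat_size; not the commit's code):

import numpy as np

dim = 4 + 2                                                    # toy feature_size + angle_feat_size
obs = [{'candidate': [{'feature': np.full(dim, 0.5, dtype=np.float32)}]}]

candidate_leng = [len(ob['candidate']) + 2 for ob in obs]      # +2: <STOP> and <NOT_FOUND>
candidate_feat = np.zeros((len(obs), max(candidate_leng), dim), dtype=np.float32)
for i, ob in enumerate(obs):
    for j, cc in enumerate(ob['candidate']):
        candidate_feat[i, j, :] = cc['feature']
    candidate_feat[i, len(ob['candidate']), :] = 0.0           # <STOP> row (all zeros)
    candidate_feat[i, len(ob['candidate']) + 1, :] = 1.0       # <NOT_FOUND> row (all ones)

print(candidate_leng)          # [3]  -> 1 real candidate + <STOP> + <NOT_FOUND>
print(candidate_feat[0, 1])    # zeros: the <STOP> slot
print(candidate_feat[0, 2])    # ones:  the <NOT_FOUND> slot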
@@ -186,10 +188,13 @@ class Seq2SeqAgent(BaseAgent):
                         break
                 else:   # Stop here
                     assert ob['teacher'] == ob['viewpoint']         # The teacher action should be "STAY HERE"
-                    a[i] = len(ob['candidate'])
+                    if ob['found']:
+                        a[i] = len(ob['candidate'])
+                    else:
+                        a[i] = len(ob['candidate'])+1
         return torch.from_numpy(a).cuda()

-    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
+    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None, found=None):
         """
         Interface between Panoramic view and Egocentric view
         It will convert the action panoramic view action a_t to equivalent egocentric view actions for the simulator
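Illustrative sketch of the new supervision rule (toy observations, not the simulator's): when the teacher action is "stay here", the target index is the <STOP> slot if the episode's found flag is True, and the <NOT_FOUND> slot otherwise.

import numpy as np

ignoreid = -100                     # stand-in for args.ignoreid

def toy_teacher_action(obs, ended):
    a = np.zeros(len(obs), dtype=np.int64)
    for i, ob in enumerate(obs):
        if ended[i]:
            a[i] = ignoreid
            continue
        for k, candidate in enumerate(ob['candidate']):
            if candidate['viewpointId'] == ob['teacher']:
                a[i] = k            # move toward the next viewpoint
                break
        else:                       # teacher says "stay here"
            a[i] = len(ob['candidate']) if ob['found'] else len(ob['candidate']) + 1
    return a

ob = {'candidate': [{'viewpointId': 'v1'}], 'teacher': 'v0', 'viewpoint': 'v0', 'found': False}
print(toy_teacher_action([ob], ended=[False]))    # [2] -> the <NOT_FOUND> slot (1 candidate + 1)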
@@ -205,7 +210,7 @@ class Seq2SeqAgent(BaseAgent):

         for i, idx in enumerate(perm_idx):
             action = a_t[i]
-            if action != -1:            # -1 is the <stop> action
+            if action != -1 and action != -2:            # -1 is the <stop> action
                 select_candidate = perm_obs[i]['candidate'][action]
                 src_point = perm_obs[i]['viewIndex']
                 trg_point = select_candidate['pointId']
@@ -228,6 +233,10 @@ class Seq2SeqAgent(BaseAgent):
                 # print("action: {} view_index: {}".format(action, state.viewIndex))
                 if traj is not None:
                     traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
+            elif action == -1 or action == -2:
+                if found is not None:
+                    found[i] = action
+

     def rollout(self, train_ml=None, train_rl=True, reset=True):
         """
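A reduced sketch of the bookkeeping added here (simulator calls omitted, names shortened): terminal actions are never sent to the simulator; instead their code is remembered per batch element so later steps can tell a <STOP> (-1) from a <NOT_FOUND> (-2).

def toy_make_equiv_action(a_t, found):
    for i, action in enumerate(a_t):
        if action != -1 and action != -2:
            pass                      # here the real code turns/steps the simulator toward the chosen candidate
        elif found is not None:
            found[i] = action         # remember whether this item ended with <STOP> or <NOT_FOUND>

found = [None, None]
toy_make_equiv_action([-1, -2], found)
print(found)    # [-1, -2]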
@@ -246,7 +255,7 @@ class Seq2SeqAgent(BaseAgent):
             obs = np.array(self.env.reset())
         else:
             obs = np.array(self.env._get_obs())

         batch_size = len(obs)

         # Language input
@@ -270,6 +279,8 @@ class Seq2SeqAgent(BaseAgent):
             'instr_id': ob['instr_id'],
             'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
         } for ob in perm_obs]

+        found = [ None for _ in range(len(perm_obs)) ]
+
         # Init the reward shaping
         last_dist = np.zeros(batch_size, np.float32)
@@ -293,6 +304,15 @@ class Seq2SeqAgent(BaseAgent):
         for t in range(self.episode_len):

             input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)

+            '''
+            # show feature
+            for index, feat in enumerate(candidate_feat):
+                for ff in feat:
+                    print(ff)
+                print(candidate_leng[index])
+                print()
+            '''
+

             # the first [CLS] token, initialized by the language BERT, serves
@@ -324,9 +344,22 @@ class Seq2SeqAgent(BaseAgent):

             # Supervised training
             target = self._teacher_action(perm_obs, ended)
+            for i in perm_obs:
+                print(i['found'], end=' ')
             ml_loss += self.criterion(logit, target)

+
+            '''
+            for index, mask in enumerate(candidate_mask):
+                print(mask)
+                print(candidate_leng[index])
+                print(logit[index])
+                print(target[index])
+                print("\n\n")
+            '''
+
             # Determine next model inputs
+
             if self.feedback == 'teacher':
                 a_t = target                 # teacher forcing
             elif self.feedback == 'argmax':
@@ -344,15 +377,24 @@ class Seq2SeqAgent(BaseAgent):
             else:
                 print(self.feedback)
                 sys.exit('Invalid feedback option')

             # Prepare environment action
             # NOTE: Env action is in the perm_obs space
             cpu_a_t = a_t.cpu().numpy()
             for i, next_id in enumerate(cpu_a_t):
-                if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]:    # The last action is <end>
-                    cpu_a_t[i] = -1             # Change the <end> and ignore action to -1
+                if next_id == (args.ignoreid) or ended[i]:
+                    cpu_a_t[i] = found[i]
+                elif next_id == (candidate_leng[i]-2):
+                    cpu_a_t[i] = -1
+                elif next_id == (candidate_leng[i]-1):
+                    cpu_a_t[i] = -2
+
+
+            print(cpu_a_t)

             # Make action and get the new state
-            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
+            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found)
+            print(self.feedback, found)
             obs = np.array(self.env._get_obs())
             perm_obs = obs[perm_idx]            # Perm the obs for the resu

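Standalone sketch of the remapping above (toy values; ignoreid stands in for args.ignoreid): the model's argmax index is translated so the environment loop sees -1 for <STOP> and -2 for <NOT_FOUND>, while already-ended or ignored items keep whatever terminal code was recorded in found.

import numpy as np

ignoreid = -100                                  # stand-in for args.ignoreid

def remap_actions(cpu_a_t, candidate_leng, ended, found):
    for i, next_id in enumerate(cpu_a_t):
        if next_id == ignoreid or ended[i]:
            cpu_a_t[i] = found[i]                # keep the terminal action recorded earlier
        elif next_id == candidate_leng[i] - 2:
            cpu_a_t[i] = -1                      # second-to-last slot is <STOP>
        elif next_id == candidate_leng[i] - 1:
            cpu_a_t[i] = -2                      # last slot is <NOT_FOUND>
    return cpu_a_t

cpu_a_t = np.array([1, 3, 4, ignoreid])
candidate_leng = [5, 5, 5, 5]                    # 3 real candidates + <STOP> + <NOT_FOUND>
ended = [False, False, False, True]
found = [None, None, None, -2]
print(remap_actions(cpu_a_t, candidate_leng, ended, found))    # [ 1 -1 -2 -2]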
@@ -376,6 +418,20 @@ class Seq2SeqAgent(BaseAgent):
                         if action_idx == -1:                              # If the action now is end
                             if dist[i] < 3.0:                             # Correct
                                 reward[i] = 2.0 + ndtw_score[i] * 2.0
+                                if ob['found']:
+                                    reward[i] += 1
+                                else:
+                                    reward[i] -= 2
+                            else:                                         # Incorrect
+                                reward[i] = -2.0
+
+                        elif action_idx == -2:
+                            if dist[i] < 3.0:
+                                reward[i] = 2.0 + ndtw_score[i] * 2.0
+                                if ob['found']:
+                                    reward[i] -= 2
+                                else:
+                                    reward[i] += 1
                             else:                                         # Incorrect
                                 reward[i] = -2.0
                         else:                                             # The action is not end
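Condensed restatement of the reward branch above (a sketch, not the commit's code): within 3 m of the goal the base reward is 2.0 + 2.0 * nDTW, then bumped by +1 when the predicted stop/not-found matches the ground-truth found flag and by -2 when it does not; stopping farther than 3 m still yields -2.0 either way.

def terminal_reward(action_idx, dist, ndtw, found):
    # action_idx: -1 -> agent chose <STOP>, -2 -> agent chose <NOT_FOUND>
    if dist >= 3.0:                              # stopped too far from the goal
        return -2.0
    reward = 2.0 + ndtw * 2.0
    agrees = (action_idx == -1) == found         # prediction matches the ground-truth found flag
    return reward + (1 if agrees else -2)

print(terminal_reward(-1, 1.5, 0.8, found=True))     # 4.6  (correct stop, object was there)
print(terminal_reward(-2, 1.5, 0.8, found=True))     # 1.6  (claimed not-found although it was there)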
@@ -399,6 +455,7 @@ class Seq2SeqAgent(BaseAgent):
             # Update the finished actions
             # -1 means ended or ignored (already ended)
             ended[:] = np.logical_or(ended, (cpu_a_t == -1))
+            ended[:] = np.logical_or(ended, (cpu_a_t == -2))

             # Early exit if all ended
             if ended.all():
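The two logical_or updates mark an episode finished on either terminal code; an equivalent single-call form (illustrative alternative, not what the commit uses) is:

import numpy as np

cpu_a_t = np.array([2, -1, -2, 0])
ended = np.array([False, False, False, True])
ended[:] = np.logical_or(ended, np.isin(cpu_a_t, (-1, -2)))
print(ended)    # [False  True  True  True]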
@@ -476,6 +533,7 @@ class Seq2SeqAgent(BaseAgent):
         else:
             self.losses.append(self.loss.item() / self.episode_len)    # This argument is useless.

+        print('\n')
         return traj

     def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
@@ -127,6 +127,7 @@ class R2RBatch():
                 new_item = dict(item)
                 new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
                 new_item['instructions'] = instr
+                new_item['found'] = item['found'][j]

                 ''' BERT tokenizer '''
                 instr_tokens = tokenizer.tokenize(instr)
@@ -328,6 +329,7 @@ class R2RBatch():
             # [visual_feature, angle_feature] for views
             feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)

+
             obs.append({
                 'instr_id' : item['instr_id'],
                 'scan' : state.scanId,
@@ -341,7 +343,8 @@ class R2RBatch():
                 'instructions' : item['instructions'],
                 'teacher' : self._shortest_path_action(state, item['path'][-1]),
                 'gt_path' : item['path'],
-                'path_id' : item['path_id']
+                'path_id' : item['path_id'],
+                'found': item['found']
             })
             if 'instr_encoding' in item:
                 obs[-1]['instr_encoding'] = item['instr_encoding']
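The loader changes assume each dataset entry now carries a found field alongside its instructions; given the item['found'][j] indexing in the first R2RBatch hunk, it is presumably one flag per instruction, which then surfaces as ob['found'] in the agent. A hypothetical entry (keys taken from the diff and the usual R2R format, values invented for illustration):

item = {
    "path_id": 4332,
    "scan": "example_scan",
    "path": ["vp_start", "vp_mid", "vp_goal"],
    "instructions": [
        "Walk into the kitchen and stop by the red mug.",
        "Go to the kitchen; the mug you are asked about is not there.",
    ],
    "found": [True, False],          # one flag per instruction, consumed as item['found'][j]
}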