feat: not-found handling (not fixed yet)
parent 857c7e8e10
commit d2f18c1c61
@@ -35,12 +35,12 @@ class BaseAgent(object):
         self.losses = [] # For learning agents
 
     def write_results(self):
-        output = [{'instr_id':k, 'trajectory': v} for k,v in self.results.items()]
+        output = [{'instr_id':k, 'trajectory': v[0], 'found': v[1]} for k,v in self.results.items()]
         with open(self.results_path, 'w') as f:
             json.dump(output, f)
 
     def get_results(self):
-        output = [{'instr_id': k, 'trajectory': v} for k, v in self.results.items()]
+        output = [{'instr_id': k, 'trajectory': v[0], 'found': v[1]} for k, v in self.results.items()]
         return output
 
     def rollout(self, **args):
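
Note: after this change each entry in self.results stores a (path, found) tuple, so the dumped JSON gains a per-instruction 'found' field. A minimal standalone sketch of the new output shape (instruction ids and paths below are made up for illustration):

    import json

    # Hypothetical results dict mirroring the new (trajectory, found) storage.
    results = {
        '1234_0': ([('vp_start', 0.0, 0.0), ('vp_goal', 1.57, 0.0)], True),
        '5678_1': ([('vp_start', 0.0, 0.0)], None),
    }
    output = [{'instr_id': k, 'trajectory': v[0], 'found': v[1]} for k, v in results.items()]
    print(json.dumps(output, indent=2))
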
@@ -61,17 +61,21 @@ class BaseAgent(object):
         if iters is not None:
             # For each time, it will run the first 'iters' iterations. (It was shuffled before)
             for i in range(iters):
-                for traj in self.rollout(**kwargs):
+                trajs, found = self.rollout(**kwargs)
+                print(found)
+                for index, traj in enumerate(trajs):
                     self.loss = 0
-                    self.results[traj['instr_id']] = traj['path']
+                    self.results[traj['instr_id']] = (traj['path'], found[index])
         else:   # Do a full round
             while True:
-                for traj in self.rollout(**kwargs):
+                trajs, found = self.rollout(**kwargs)
+                print("FOUND: ", found)
+                for index, traj in enumerate(trajs):
                     if traj['instr_id'] in self.results:
                         looped = True
                     else:
                         self.loss = 0
-                        self.results[traj['instr_id']] = traj['path']
+                        self.results[traj['instr_id']] = (traj['path'], found[index])
                 if looped:
                     break
 
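
Note: BaseAgent.test now expects rollout() to return a pair — the trajectory list plus a parallel per-episode found list. A hedged sketch of that contract (the stub below stands in for the real rollout; contents are illustrative):

    def fake_rollout():
        # Shapes mirror the new return value of rollout().
        trajs = [{'instr_id': '1234_0', 'path': [('vp_a', 0.0, 0.0)]},
                 {'instr_id': '5678_1', 'path': [('vp_c', 0.0, 0.0)]}]
        found = [True, None]   # filled (or left None) during make_equiv_action
        return trajs, found

    results = {}
    trajs, found = fake_rollout()
    for index, traj in enumerate(trajs):
        results[traj['instr_id']] = (traj['path'], found[index])
    print(results)
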
@@ -155,7 +159,9 @@ class Seq2SeqAgent(BaseAgent):
         for i, ob in enumerate(obs):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
-            candidate_feat[i, -1, :] = np.ones((self.feature_size + args.angle_feat_size))
+            # append the not-found token
+            candidate_feat[i, len(ob['candidate'])+1, :] = np.ones((self.feature_size + args.angle_feat_size))
+
 
         return torch.from_numpy(candidate_feat).cuda(), candidate_leng
 
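
Note: the hunk above reserves an extra all-ones row for a not-found token next to the <end> slot. A standalone sketch of the layout this appears to aim for, assuming rows 0..K-1 hold the K candidates and two trailing rows hold <end> and not-found (sizes are illustrative, not the repo's):

    import numpy as np

    feature_size, angle_feat_size = 2048, 128
    K = 3                                              # stand-in candidate count
    candidate_feat = np.zeros((K + 2, feature_size + angle_feat_size), np.float32)
    candidate_feat[:K] = 0.5                           # stand-in candidate features
    candidate_feat[K + 1] = 1.0                        # not-found token, as in the hunk
    print(candidate_feat.sum(axis=1))                  # the <end> row at index K stays zero
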
@@ -188,12 +194,13 @@ class Seq2SeqAgent(BaseAgent):
             else:   # Stop here
                 assert ob['teacher'] == ob['viewpoint']   # The teacher action should be "STAY HERE"
                 if ob['swap']:  # the instruction was swapped, so the target should be not-found
-                    a[i] = len(ob['candidate'])
-                else:  # STOP
                     a[i] = len(ob['candidate'])-1
+                else:  # STOP
+                    a[i] = len(ob['candidate'])-2
+        print(" ", a)
         return torch.from_numpy(a).cuda()
 
-    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
+    def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None, found=None):
         """
         Interface between Panoramic view and Egocentric view
         It will convert the action panoramic view action a_t to equivalent egocentric view actions for the simulator
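
Note: with K = len(ob['candidate']), the committed teacher targets are K-1 when the instruction was swapped and K-2 for an ordinary stop. A tiny sketch of that branch in isolation (the helper name and stub observations are mine, not the repo's):

    import numpy as np

    def teacher_stop_target(ob):
        # Mirrors the branch above, with the target indices as committed.
        K = len(ob['candidate'])
        return K - 1 if ob['swap'] else K - 2

    obs = [{'candidate': [0, 1, 2], 'swap': True},
           {'candidate': [0, 1, 2], 'swap': False}]
    print(np.array([teacher_stop_target(ob) for ob in obs]))   # [2 1]
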
@@ -209,7 +216,7 @@ class Seq2SeqAgent(BaseAgent):
 
         for i, idx in enumerate(perm_idx):
             action = a_t[i]
-            print('action: ', action)
+            # print('action: ', action)
             if action != -1 and action != -2:            # -1 is the <stop> action
                 select_candidate = perm_obs[i]['candidate'][action]
                 src_point = perm_obs[i]['viewIndex']
@@ -233,11 +240,17 @@ class Seq2SeqAgent(BaseAgent):
                 # print("action: {} view_index: {}".format(action, state.viewIndex))
                 if traj is not None:
                     traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
+                else:
+                    found[i] = action
+
+
 
+            '''
             elif action == -1:
                 print('<STOP>')
             elif action == -2:
                 print('<NOT_FOUND>')
+            '''
 
     def rollout(self, train_ml=None, train_rl=True, reset=True):
         """
@@ -247,6 +260,7 @@ class Seq2SeqAgent(BaseAgent):
 
         :return:
         """
+        print("ROLLOUT!!!")
         if self.feedback == 'teacher' or self.feedback == 'argmax':
             train_rl = False
 
@@ -283,6 +297,9 @@ class Seq2SeqAgent(BaseAgent):
             'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
         } for ob in perm_obs]
 
+        found = [None for _ in range(len(perm_obs))]
+
+
         # Init the reward shaping
         last_dist = np.zeros(batch_size, np.float32)
         last_ndtw = np.zeros(batch_size, np.float32)
@@ -306,6 +323,7 @@ class Seq2SeqAgent(BaseAgent):
 
             input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
 
+
             # the first [CLS] token, initialized by the language BERT, serves
             # as the agent's state passing through time steps
             if (t >= 1) or (args.vlnbert=='prevalent'):
@@ -337,9 +355,10 @@ class Seq2SeqAgent(BaseAgent):
             # Supervised training
             target = self._teacher_action(perm_obs, ended)
             for i, d in enumerate(target):
-                print(perm_obs[i]['swap'], perm_obs[i]['instructions'])
-                print(d)
+                # print(perm_obs[i]['swap'], perm_obs[i]['instructions'])
+                # print(d)
                 _, at_t = logit.max(1)
+                '''
                 if at_t[i].item() == candidate_leng[i]-1:
                     print("-2")
                 elif at_t[i].item() == candidate_leng[i]-2:
@@ -347,14 +366,19 @@ class Seq2SeqAgent(BaseAgent):
                 else:
                     print(at_t[i].item())
                 print()
+                '''
             ml_loss += self.criterion(logit, target)
 
+        a_predict = None
         # Determine next model inputs
         if self.feedback == 'teacher':
             a_t = target                 # teacher forcing
+            _, a_predict = logit.max(1)
+            a_predict = a_predict.detach()
         elif self.feedback == 'argmax':
             _, a_t = logit.max(1)        # student forcing - argmax
             a_t = a_t.detach()
+            a_predict = a_t.detach()
             log_probs = F.log_softmax(logit, 1)                              # Calculate the log_prob here
             policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1)))   # Gather the log_prob for each batch
         elif self.feedback == 'sample':
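
Note: a_predict keeps the model's own argmax even when a_t is teacher-forced, so the executed action and the model's prediction can diverge. A minimal PyTorch sketch of the same bookkeeping (logits and targets are made up):

    import torch

    logit = torch.tensor([[0.1, 2.0, 0.3],
                          [1.5, 0.2, 0.1]])
    target = torch.tensor([0, 2])            # teacher actions

    a_t = target                             # executed: teacher forcing
    _, a_predict = logit.max(1)              # remembered: the model's own choice
    a_predict = a_predict.detach()
    print(a_t.tolist(), a_predict.tolist())  # [0, 2] [1, 0]
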
@@ -362,23 +386,39 @@ class Seq2SeqAgent(BaseAgent):
             c = torch.distributions.Categorical(probs)
             self.logs['entropy'].append(c.entropy().sum().item())      # For log
             entropys.append(c.entropy())                               # For optimization
-            a_t = c.sample().detach()
+            new_c = c.sample()
+            a_t = new_c.detach()
+            a_predict = new_c.detach()
             policy_log_probs.append(c.log_prob(a_t))
         else:
-            print(self.feedback)
+            # print(self.feedback)
             sys.exit('Invalid feedback option')
 
         # Prepare environment action
         # NOTE: Env action is in the perm_obs space
         cpu_a_t = a_t.cpu().numpy()
         for i, next_id in enumerate(cpu_a_t):
-            if next_id == (candidate_leng[i]-2) or next_id == args.ignoreid or ended[i]:    # The last action is <end>
+            if next_id == args.ignoreid or ended[i]:
+                if found[i] == True:
+                    cpu_a_t[i] = -1         # Change the <end> and ignore action to -1
+                else:
+                    cpu_a_t[i] = -2
+            elif next_id == (candidate_leng[i]-2):
                 cpu_a_t[i] = -1             # Change the <end> and ignore action to -1
             elif next_id == (candidate_leng[i]-1):
                 cpu_a_t[i] = -2
 
+
+        cpu_a_predict = a_predict.cpu().numpy()
+        for i, next_id in enumerate(cpu_a_predict):
+            if next_id == (candidate_leng[i]-2):
+                cpu_a_predict[i] = -1       # Change the <end> and ignore action to -1
+            elif next_id == (candidate_leng[i]-1):
+                cpu_a_predict[i] = -2
+
         # Make action and get the new state
-        self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
+        print(cpu_a_t)
+        self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found=found)
         obs = np.array(self.env._get_obs())
         perm_obs = obs[perm_idx]                    # Perm the obs for the resu
 
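
Note: the remapping above reads, with L = candidate_leng[i]: slot L-2 becomes the -1 <stop> action and slot L-1 becomes the new -2 not-found action, while ignored or already-ended entries fall back on the found flag. A self-contained sketch of that remapping (the ignore id and batch values are stand-ins, not the repo's):

    import numpy as np

    IGNORE_ID = -100                         # stand-in for args.ignoreid
    candidate_leng = [5, 5, 5]
    found = [True, None, None]
    ended = np.array([True, False, False])

    cpu_a_t = np.array([IGNORE_ID, 3, 4])
    for i, next_id in enumerate(cpu_a_t):
        if next_id == IGNORE_ID or ended[i]:
            cpu_a_t[i] = -1 if found[i] == True else -2
        elif next_id == (candidate_leng[i] - 2):
            cpu_a_t[i] = -1                  # <stop>
        elif next_id == (candidate_leng[i] - 1):
            cpu_a_t[i] = -2                  # <not found>
    print(cpu_a_t)                           # [-1 -1 -2]
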
@@ -445,7 +485,7 @@ class Seq2SeqAgent(BaseAgent):
             if ended.all():
                 break
 
-        print()
+        # print()
 
         if train_rl:
             # Last action in A2C
@@ -517,8 +557,9 @@ class Seq2SeqAgent(BaseAgent):
             self.losses.append(0.)
         else:
             self.losses.append(self.loss.item() / self.episode_len)    # This argument is useless.
+        print("\n\n")
 
-        return traj
+        return traj, found
 
     def test(self, use_dropout=False, feedback='argmax', allow_cheat=False, iters=None):
         ''' Evaluate once on each instruction in the current environment '''
@@ -199,7 +199,8 @@ def train_val(test_only=False):
    else:
        featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
    # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
-    val_env_names = ['val_train_seen']
+    # val_env_names = ['val_train_seen']
+    val_env_names = ['val_unseen']
 
    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
    from collections import OrderedDict