Compare commits
3 Commits: adversaria ... main

| Author | SHA1 | Date |
|---|---|---|
|  | d2f18c1c61 |  |
|  | 857c7e8e10 |  |
|  | 7329f7fa0a |  |
@@ -1,42 +0,0 @@
-import json
-import os
-import re
-
-def remove_non_ascii(text):
-    return re.sub(r'[^\x00-\x7F]', ' ', text)
-
-
-for file in ['train', 'val_unseen', 'val_seen', 'train_seen', 'test', 'val_train_seen']:
-    print(file)
-    if os.path.isfile('data/adversarial/reverie_{}_fnf.json'.format(file)):
-        with open('data/adversarial/reverie_{}_fnf.json'.format(file)) as fp:
-            data = json.load(fp)
-
-        result = {}
-        for i in data:
-            instruction_id = i['path_id']
-            if instruction_id not in result:
-                result[instruction_id] = {
-                    'distance': float(i['distance']),
-                    'scan': i['scan'],
-                    'path_id': int(i['path_id']),
-                    'path': i['path'],
-                    'heading': float(i['heading']),
-                    'instructions': [remove_non_ascii(i['instruction'])],
-                    'found': [i['found']],
-                    'id': i['id'],
-                    'objId': i['objId']
-                }
-            else:
-                result[instruction_id]['instructions'].append(remove_non_ascii(i['instruction']))
-                result[instruction_id]['found'].append(i['found'])
-
-        output = []
-        for k, item in result.items():
-            output.append(item)
-    else:
-        output = []
-
-    with open('data/adversarial/R2R_{}.json'.format(file), 'w') as fp:
-        json.dump(output, fp)
data/adversarial.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+import json
+import sys
+import random
+
+with open(sys.argv[1]) as fp:
+    data = json.load(fp)
+
+for _, d in enumerate(data):
+    swaps = []
+    for index, ins in enumerate(d['instructions']):
+        p = random.random()
+        if p > 0.5:
+            swaps.append(True)
+            d['instructions'][index] += 'This is swap.'
+        else:
+            swaps.append(False)
+    d['swap'] = swaps
+print(data)
+
+with open(sys.argv[1], 'w') as fp:
+    json.dump(data, fp)
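The new script rewrites the JSON split named on its command line in place: for each item it flips a coin per instruction, appends the literal marker 'This is swap.' to roughly half of them, and records the matching booleans in a per-item 'swap' list. A minimal invocation might look like the line below; the split path is only an assumption, and any R2R-style JSON whose items carry an 'instructions' list should work.

# hypothetical example; point the script at the split you want to corrupt
python data/adversarial.py data/adversarial/R2R_val_unseen.json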
r2r_src/agent.py (140 lines changed)
@@ -61,14 +61,16 @@ class BaseAgent(object):
         if iters is not None:
             # For each time, it will run the first 'iters' iterations. (It was shuffled before)
             for i in range(iters):
-                traj, found = self.rollout(**kwargs)
-                for index, traj in enumerate(traj):
+                trajs, found = self.rollout(**kwargs)
+                print(found)
+                for index, traj in enumerate(trajs):
                     self.loss = 0
                     self.results[traj['instr_id']] = (traj['path'], found[index])
         else:   # Do a full round
             while True:
-                traj, found = self.rollout(**kwargs)
-                for index, traj in enumerate(traj):
+                trajs, found = self.rollout(**kwargs)
+                print("FOUND: ", found)
+                for index, traj in enumerate(trajs):
                     if traj['instr_id'] in self.results:
                         looped = True
                     else:
@@ -157,8 +159,9 @@ class Seq2SeqAgent(BaseAgent):
         for i, ob in enumerate(obs):
             for j, cc in enumerate(ob['candidate']):
                 candidate_feat[i, j, :] = cc['feature']
             candidate_feat[i, len(ob['candidate']), :] = np.zeros(self.feature_size+args.angle_feat_size, dtype=np.float32)   # <STOP>
-            candidate_feat[i, len(ob['candidate'])+1, :] = np.ones(self.feature_size+args.angle_feat_size, dtype=np.float32)   # <NOT_FOUND>
+            # add the <NOT_FOUND> token
+            candidate_feat[i, len(ob['candidate'])+1, :] = np.ones((self.feature_size + args.angle_feat_size))

         return torch.from_numpy(candidate_feat).cuda(), candidate_leng

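For orientation, the candidate feature tensor built here reserves two extra slots after the real navigation candidates: an all-zeros <STOP> vector and an all-ones <NOT_FOUND> vector. The sketch below summarizes the assumed slot layout; the candidate_leng bookkeeping is inferred from this hunk together with the remapping later in rollout(), and is not stated explicitly in the compare.

# Assumed layout for observation i (sketch, not authoritative):
#   candidate_feat[i, 0 .. len(candidate)-1]  -> real navigable candidates
#   candidate_feat[i, len(candidate)]         -> <STOP>      (np.zeros feature)
#   candidate_feat[i, len(candidate)+1]       -> <NOT_FOUND> (np.ones feature)
# In rollout(), slot candidate_leng[i]-2 is later mapped to the symbolic action -1 (<STOP>)
# and slot candidate_leng[i]-1 to -2 (<NOT_FOUND>).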
@@ -190,10 +193,11 @@ class Seq2SeqAgent(BaseAgent):
                         break
                 else:   # Stop here
                     assert ob['teacher'] == ob['viewpoint']     # The teacher action should be "STAY HERE"
-                    if ob['found']:
-                        a[i] = len(ob['candidate'])
-                    else:
-                        a[i] = len(ob['candidate'])+1
+                    if ob['swap']:   # the instruction was swapped, so the teacher action is <NOT_FOUND>
+                        a[i] = len(ob['candidate'])-1
+                    else:            # <STOP>
+                        a[i] = len(ob['candidate'])-2
+        print(" ", a)
         return torch.from_numpy(a).cuda()

     def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None, found=None):
@@ -212,6 +216,7 @@ class Seq2SeqAgent(BaseAgent):

         for i, idx in enumerate(perm_idx):
             action = a_t[i]
+            # print('action: ', action)
             if action != -1 and action != -2:            # -1 is the <stop> action
                 select_candidate = perm_obs[i]['candidate'][action]
                 src_point = perm_obs[i]['viewIndex']
@@ -235,11 +240,18 @@ class Seq2SeqAgent(BaseAgent):
                 # print("action: {} view_index: {}".format(action, state.viewIndex))
                 if traj is not None:
                     traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
-            elif action == -1 or action == -2:
-                if found is not None:
-                    found[i] = action
+            else:
+                found[i] = action
+
+            '''
+            elif action == -1:
+                print('<STOP>')
+            elif action == -2:
+                print('<NOT_FOUND>')
+            '''

     def rollout(self, train_ml=None, train_rl=True, reset=True):
         """
         :param train_ml:    The weight to train with maximum likelihood
@@ -248,6 +260,7 @@ class Seq2SeqAgent(BaseAgent):

         :return:
         """
+        print("ROLLOUT!!!")
         if self.feedback == 'teacher' or self.feedback == 'argmax':
             train_rl = False

@@ -263,8 +276,10 @@ class Seq2SeqAgent(BaseAgent):
         # Language input
         sentence, language_attention_mask, token_type_ids, \
             seq_lengths, perm_idx = self._sort_batch(obs)

         perm_obs = obs[perm_idx]
+
+
         ''' Language BERT '''
         language_inputs = {'mode':     'language',
                            'sentence': sentence,
@@ -282,7 +297,8 @@ class Seq2SeqAgent(BaseAgent):
                  'path': [(ob['viewpoint'], ob['heading'], ob['elevation'])],
                 } for ob in perm_obs]

-        found = [ None for _ in range(len(perm_obs)) ]
+        found = [None for _ in range(len(perm_obs))]
+
         # Init the reward shaping
         last_dist = np.zeros(batch_size, np.float32)
@@ -307,15 +323,6 @@ class Seq2SeqAgent(BaseAgent):

             input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)

-            '''
-            # show feature
-            for index, feat in enumerate(candidate_feat):
-                for ff in feat:
-                    print(ff)
-                print(candidate_leng[index])
-                print()
-            '''
-
             # the first [CLS] token, initialized by the language BERT, serves
             # as the agent's state passing through time steps
@@ -341,30 +348,37 @@ class Seq2SeqAgent(BaseAgent):

             # Mask outputs where agent can't move forward
             # Here the logit is [b, max_candidate]
+            # (8, max(candidate))
             candidate_mask = utils.length2mask(candidate_leng)
             logit.masked_fill_(candidate_mask, -float('inf'))

             # Supervised training
             target = self._teacher_action(perm_obs, ended)
+            for i, d in enumerate(target):
+                # print(perm_obs[i]['swap'], perm_obs[i]['instructions'])
+                # print(d)
+                _, at_t = logit.max(1)
+                '''
+                if at_t[i].item() == candidate_leng[i]-1:
+                    print("-2")
+                elif at_t[i].item() == candidate_leng[i]-2:
+                    print("-1")
+                else:
+                    print(at_t[i].item())
+                print()
+                '''
             ml_loss += self.criterion(logit, target)

-            '''
-            for index, mask in enumerate(candidate_mask):
-                print(mask)
-                print(candidate_leng[index])
-                print(logit[index])
-                print(target[index])
-                print("\n\n")
-            '''
+            a_predict = None

             # Determine next model inputs

             if self.feedback == 'teacher':
                 a_t = target                 # teacher forcing
+                _, a_predict = logit.max(1)
+                a_predict = a_predict.detach()
             elif self.feedback == 'argmax':
                 _, a_t = logit.max(1)        # student forcing - argmax
                 a_t = a_t.detach()
+                a_predict = a_t.detach()
                 log_probs = F.log_softmax(logit, 1)                              # Calculate the log_prob here
                 policy_log_probs.append(log_probs.gather(1, a_t.unsqueeze(1)))   # Gather the log_prob for each batch
             elif self.feedback == 'sample':
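The three feedback modes draw the next action from the same logits in different ways, and the new a_predict tensor keeps the model's own choice even when teacher forcing drives the environment. Below is a minimal, self-contained sketch of that selection logic; the function name and shapes are illustrative assumptions rather than the repository's API.

import torch
import torch.nn.functional as F

def select_action(logit, target, feedback):
    """Sketch of the feedback modes used above (illustrative only)."""
    if feedback == 'teacher':
        a_t = target                        # teacher forcing: follow the ground-truth action
        _, a_predict = logit.max(1)         # still remember what the model would have done
    elif feedback == 'argmax':
        _, a_t = logit.max(1)               # student forcing: greedy choice
        a_t = a_t.detach()
        a_predict = a_t
    elif feedback == 'sample':
        c = torch.distributions.Categorical(F.softmax(logit, dim=1))
        a_t = c.sample().detach()           # sample for exploration
        a_predict = a_t
    else:
        raise ValueError('Invalid feedback option')
    return a_t, a_predict

# toy usage: batch of 2 observations, 5 candidate slots
logit = torch.randn(2, 5)
target = torch.tensor([1, 4])
print(select_action(logit, target, 'sample'))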
@@ -372,39 +386,42 @@ class Seq2SeqAgent(BaseAgent):
                 c = torch.distributions.Categorical(probs)
                 self.logs['entropy'].append(c.entropy().sum().item())      # For log
                 entropys.append(c.entropy())                                # For optimization
-                a_t = c.sample().detach()
+                new_c = c.sample()
+                a_t = new_c.detach()
+                a_predict = new_c.detach()
                 policy_log_probs.append(c.log_prob(a_t))
             else:
-                print(self.feedback)
+                # print(self.feedback)
                 sys.exit('Invalid feedback option')

             # Prepare environment action
             # NOTE: Env action is in the perm_obs space
             cpu_a_t = a_t.cpu().numpy()
             for i, next_id in enumerate(cpu_a_t):
-                if next_id == (args.ignoreid) or ended[i]:
-                    cpu_a_t[i] = found[i]
+                if next_id == args.ignoreid or ended[i]:
+                    if found[i] == True:
+                        cpu_a_t[i] = -1    # Change the <end> and ignore action to -1
+                    else:
+                        cpu_a_t[i] = -2
                 elif next_id == (candidate_leng[i]-2):
-                    cpu_a_t[i] = -1
+                    cpu_a_t[i] = -1        # Change the <end> and ignore action to -1
                 elif next_id == (candidate_leng[i]-1):
                     cpu_a_t[i] = -2

-            # Make action and get the new state
-            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found)
-            '''
-            print(self.feedback, end=' ')
-            print(cpu_a_t, end=' ')
-            for i in perm_obs:
-                print(i['found'], end=' ')
-            print(found)
-            print()
-            '''
+            cpu_a_predict = a_predict.cpu().numpy()
+            for i, next_id in enumerate(cpu_a_predict):
+                if next_id == (candidate_leng[i]-2):
+                    cpu_a_predict[i] = -1  # Change the <end> and ignore action to -1
+                elif next_id == (candidate_leng[i]-1):
+                    cpu_a_predict[i] = -2
+
+            # Make action and get the new state
+            print(cpu_a_t)
+            self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj, found=found)
             obs = np.array(self.env._get_obs())
             perm_obs = obs[perm_idx]                    # Perm the obs for the resu

-            '''
             if train_rl:
                 # Calculate the mask and reward
                 dist = np.zeros(batch_size, np.float32)
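Once the network picks a slot index, the loop above converts the two trailing slots into the symbolic actions that make_equiv_action understands: -1 for <STOP> and -2 for <NOT_FOUND>, which is also what ends up in the per-episode found list. A small sketch of that remapping, assuming candidate_leng[i] counts the real candidates plus the two extra tokens:

# Sketch only; the candidate_leng semantics are inferred from the surrounding hunks.
def to_symbolic(next_id, cand_len):
    if next_id == cand_len - 2:
        return -1          # <STOP>
    if next_id == cand_len - 1:
        return -2          # <NOT_FOUND>
    return next_id         # move to a real candidate viewpoint

assert to_symbolic(3, 5) == -1
assert to_symbolic(4, 5) == -2
assert to_symbolic(1, 5) == 1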
@@ -425,22 +442,22 @@ class Seq2SeqAgent(BaseAgent):
                     if action_idx == -1:                              # If the action now is end
                         if dist[i] < 3.0:                             # Correct
                             reward[i] = 2.0 + ndtw_score[i] * 2.0
-                            if ob['found']:
-                                reward[i] += 1
-                            else:
+                            if ob['swap']:
                                 reward[i] -= 2
+                            else:
+                                reward[i] += 1
                         else:                                         # Incorrect
                             reward[i] = -2.0
-                    elif action_idx == -2:
+                    elif action_idx == -2:                            # the <NOT_FOUND> reward is set here
                         if dist[i] < 3.0:
                             reward[i] = 2.0 + ndtw_score[i] * 2.0
-                            if ob['found']:
-                                reward[i] -= 2
+                            if ob['swap']:
+                                reward[i] += 3                        # detected the corrupted instruction, give an extra point
                             else:
-                                reward[i] += 1
+                                reward[i] -= 2
                         else:                                         # Incorrect
                             reward[i] = -2.0
+                            reward[i] += 1                            # distance > 3, so nothing was actually there; soften the penalty from -2 to -1
                     else:                                             # The action is not end
                         # Path fidelity rewards (distance & nDTW)
                         reward[i] = - (dist[i] - last_dist[i])
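Taken together, the new terminal rewards favor stopping on a clean instruction and declaring <NOT_FOUND> on a swapped one. The sketch below restates the rule table from this hunk as a standalone function; the nDTW term and the names are illustrative, not the repository's API.

def terminal_reward(action_idx, dist, ndtw, swapped):
    """Sketch of the terminal-step reward in this hunk (illustrative only)."""
    if action_idx == -1:                      # agent stopped
        if dist < 3.0:
            reward = 2.0 + ndtw * 2.0
            reward += -2 if swapped else +1   # stopping on a swapped instruction is penalized
        else:
            reward = -2.0
    elif action_idx == -2:                    # agent declared <NOT_FOUND>
        if dist < 3.0:
            reward = 2.0 + ndtw * 2.0
            reward += +3 if swapped else -2   # correctly flagging a swapped instruction pays off
        else:
            reward = -2.0 + 1                 # far away and nothing found: softened penalty
    else:
        raise ValueError('not a terminal action')
    return reward

# e.g. declaring <NOT_FOUND> near the goal on a swapped instruction:
print(terminal_reward(-2, 2.0, 0.5, swapped=True))   # 3.0 + 3 = 6.0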
@@ -458,7 +475,6 @@ class Seq2SeqAgent(BaseAgent):
                 masks.append(mask)
                 last_dist[:] = dist
                 last_ndtw[:] = ndtw_score
-            '''

             # Update the finished actions
             # -1 means ended or ignored (already ended)
|
|||||||
if ended.all():
|
if ended.all():
|
||||||
break
|
break
|
||||||
|
|
||||||
'''
|
# print()
|
||||||
|
|
||||||
if train_rl:
|
if train_rl:
|
||||||
# Last action in A2C
|
# Last action in A2C
|
||||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||||
@@ -480,6 +497,7 @@ class Seq2SeqAgent(BaseAgent):
             visual_attention_mask = torch.cat((language_attention_mask, visual_temp_mask), dim=-1)

             self.vln_bert.vln_bert.config.directions = max(candidate_leng)
+            ''' Visual BERT '''
             visual_inputs = {'mode':           'visual',
                              'sentence':       language_features,
                              'attention_mask': visual_attention_mask,
@@ -530,7 +548,6 @@ class Seq2SeqAgent(BaseAgent):

             self.loss += rl_loss
             self.logs['RL_loss'].append(rl_loss.item())
-        '''

         if train_ml is not None:
             self.loss += ml_loss * train_ml / batch_size
@@ -540,6 +557,7 @@ class Seq2SeqAgent(BaseAgent):
             self.losses.append(0.)
         else:
             self.losses.append(self.loss.item() / self.episode_len)    # This argument is useless.
+        print("\n\n")

         return traj, found

@@ -1,6 +1,8 @@
 ''' Batched Room-to-Room navigation environment '''

 import sys
+
+from networkx.algorithms import swap
 sys.path.append('buildpy36')
 sys.path.append('Matterport_Simulator/build/')
 import MatterSim
@@ -14,6 +16,7 @@ import os
 import random
 import networkx as nx
 from param import args
+import time

 from utils import load_datasets, load_nav_graphs, pad_instr_tokens
 from IPython import embed
@@ -127,7 +130,7 @@ class R2RBatch():
                     new_item = dict(item)
                     new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
                     new_item['instructions'] = instr
-                    new_item['found'] = item['found'][j]
+                    new_item['swap'] = item['swap'][j]

                     ''' BERT tokenizer '''
                     instr_tokens = tokenizer.tokenize(instr)
@@ -137,10 +140,12 @@ class R2RBatch():
                     if new_item['instr_encoding'] is not None:  # Filter the wrong data
                         self.data.append(new_item)
                         scans.append(item['scan'])
+
             except:
                 continue
         print("split {} has {} datas in the file.".format(split, max_len))

+
         if name is None:
             self.name = splits[0] if len(splits) > 0 else "FAKE"
         else:
@@ -329,7 +334,6 @@ class R2RBatch():
             # [visual_feature, angle_feature] for views
             feature = np.concatenate((feature, self.angle_feature[base_view_id]), -1)
-
             obs.append({
                 'instr_id' : item['instr_id'],
                 'scan' : state.scanId,
@@ -344,7 +348,7 @@ class R2RBatch():
                 'teacher' : self._shortest_path_action(state, item['path'][-1]),
                 'gt_path' : item['path'],
                 'path_id' : item['path_id'],
-                'found': item['found']
+                'swap': item['swap']
             })
             if 'instr_encoding' in item:
                 obs[-1]['instr_encoding'] = item['instr_encoding']
@@ -55,16 +55,11 @@ class Evaluation(object):
                 near_d = d
         return near_id

-    def _score_item(self, instr_id, path, predict_found):
+    def _score_item(self, instr_id, path):
         ''' Calculate error based on the final position in trajectory, and also
             the closest position (oracle stopping rule).
             The path contains [view_id, angle, vofv] '''
         gt = self.gt[instr_id.split('_')[-2]]
-        index = int(instr_id.split('_')[-1])
-
-        gt_instruction = gt['instructions'][index]
-        gt_found = gt['found'][index]
-
         start = gt['path'][0]
         assert start == path[0][0], 'Result trajectories should include the start position'
         goal = gt['path'][-1]
@@ -73,19 +68,6 @@ class Evaluation(object):
         self.scores['nav_errors'].append(self.distances[gt['scan']][final_position][goal])
         self.scores['oracle_errors'].append(self.distances[gt['scan']][nearest_position][goal])
         self.scores['trajectory_steps'].append(len(path)-1)
-
-        # <STOP> <NOT_FOUND> score
-        score = 0
-        if gt_found == True:
-            if predict_found == -1:
-                score = 1
-        else:
-            if predict_found == -2:
-                score = 1
-        self.scores['found_count'] += score
-
-
-
         distance = 0 # length of the path in meters
         prev = path[0]
         for curr in path[1:]:
@@ -99,7 +81,6 @@ class Evaluation(object):
     def score(self, output_file):
         ''' Evaluate each agent trajectory based on how close it got to the goal location '''
         self.scores = defaultdict(list)
-        self.scores['found_count'] = 0
         instr_ids = set(self.instr_ids)
         if type(output_file) is str:
             with open(output_file) as f:
@@ -109,14 +90,12 @@ class Evaluation(object):

         # print('result length', len(results))
         # print("RESULT:", results)
-        path_counter = 0
         for item in results:
             # Check against expected ids
             if item['instr_id'] in instr_ids:
                 # print("{} exist".format(item['instr_id']))
                 instr_ids.remove(item['instr_id'])
-                self._score_item(item['instr_id'], item['trajectory'], item['found'])
-                path_counter += 1
+                self._score_item(item['instr_id'], item['trajectory'])
             else:
                 print("{} not exist".format(item['instr_id']))
                 print(item)
@@ -129,8 +108,7 @@ class Evaluation(object):
             'nav_error': np.average(self.scores['nav_errors']),
             'oracle_error': np.average(self.scores['oracle_errors']),
             'steps': np.average(self.scores['trajectory_steps']),
-            'lengths': np.average(self.scores['trajectory_lengths']),
-            'found_score': self.scores['found_count'] / path_counter
+            'lengths': np.average(self.scores['trajectory_lengths'])
         }
         num_successes = len([i for i in self.scores['nav_errors'] if i < self.error_margin])
         score_summary['success_rate'] = float(num_successes)/float(len(self.scores['nav_errors']))
@@ -105,9 +105,6 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):

         # Run validation
         loss_str = "iter {}".format(iter)
-
-
-        save_results = []
         for env_name, (env, evaluator) in val_envs.items():
             listner.env = env

@@ -115,8 +112,6 @@ def train(train_env, tok, n_iters, log_every=2000, val_envs={}, aug_env=None):
             listner.test(use_dropout=False, feedback='argmax', iters=None)
             result = listner.get_results()
             score_summary, _ = evaluator.score(result)
-
-            print(score_summary)
             loss_str += ", %s " % env_name
             for metric, val in score_summary.items():
                 if metric in ['spl']:
@@ -200,11 +195,12 @@ def train_val(test_only=False):

     if test_only:
         featurized_scans = None
-        val_env_names = ['val_unseen']
+        val_env_names = ['val_train_seen']
     else:
         featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
         # val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']
-        val_env_names = ['train','val_unseen']
+        # val_env_names = ['val_train_seen']
+        val_env_names = ['val_unseen']

     train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
     from collections import OrderedDict