Compare commits
No commits in common. "857c7e8e10502cb331e190b7dc464a039a28e51e" and "832c6368ddcacd6db2d271df2abec9124c1eb519" have entirely different histories.
857c7e8e10
...
832c6368dd
@ -1,21 +0,0 @@
|
|||||||
import json
|
|
||||||
import sys
|
|
||||||
import random
|
|
||||||
|
|
||||||
with open(sys.argv[1]) as fp:
|
|
||||||
data = json.load(fp)
|
|
||||||
|
|
||||||
for _, d in enumerate(data):
|
|
||||||
swaps = []
|
|
||||||
for index, ins in enumerate(d['instructions']):
|
|
||||||
p = random.random()
|
|
||||||
if p > 0.5:
|
|
||||||
swaps.append(True)
|
|
||||||
d['instructions'][index] += 'This is swap.'
|
|
||||||
else:
|
|
||||||
swaps.append(False)
|
|
||||||
d['swap'] = swaps
|
|
||||||
print(data)
|
|
||||||
|
|
||||||
with open(sys.argv[1], 'w') as fp:
|
|
||||||
json.dump(data, fp)
|
|
||||||
@ -147,7 +147,7 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
|
return Variable(torch.from_numpy(features), requires_grad=False).cuda()
|
||||||
|
|
||||||
def _candidate_variable(self, obs):
|
def _candidate_variable(self, obs):
|
||||||
candidate_leng = [len(ob['candidate']) + 2 for ob in obs] # +1 is for the end
|
candidate_leng = [len(ob['candidate']) + 1 for ob in obs] # +1 is for the end
|
||||||
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
candidate_feat = np.zeros((len(obs), max(candidate_leng), self.feature_size + args.angle_feat_size), dtype=np.float32)
|
||||||
|
|
||||||
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
|
# Note: The candidate_feat at len(ob['candidate']) is the feature for the END
|
||||||
@ -155,7 +155,6 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
for i, ob in enumerate(obs):
|
for i, ob in enumerate(obs):
|
||||||
for j, cc in enumerate(ob['candidate']):
|
for j, cc in enumerate(ob['candidate']):
|
||||||
candidate_feat[i, j, :] = cc['feature']
|
candidate_feat[i, j, :] = cc['feature']
|
||||||
candidate_feat[i, -1, :] = np.ones((self.feature_size + args.angle_feat_size))
|
|
||||||
|
|
||||||
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
|
return torch.from_numpy(candidate_feat).cuda(), candidate_leng
|
||||||
|
|
||||||
@ -187,10 +186,7 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
break
|
break
|
||||||
else: # Stop here
|
else: # Stop here
|
||||||
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
|
assert ob['teacher'] == ob['viewpoint'] # The teacher action should be "STAY HERE"
|
||||||
if ob['swap']: # instruction 有被換過,所以要 not found
|
a[i] = len(ob['candidate'])
|
||||||
a[i] = len(ob['candidate'])
|
|
||||||
else: # STOP
|
|
||||||
a[i] = len(ob['candidate'])-1
|
|
||||||
return torch.from_numpy(a).cuda()
|
return torch.from_numpy(a).cuda()
|
||||||
|
|
||||||
def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
|
def make_equiv_action(self, a_t, perm_obs, perm_idx=None, traj=None):
|
||||||
@ -209,8 +205,7 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
|
|
||||||
for i, idx in enumerate(perm_idx):
|
for i, idx in enumerate(perm_idx):
|
||||||
action = a_t[i]
|
action = a_t[i]
|
||||||
print('action: ', action)
|
if action != -1: # -1 is the <stop> action
|
||||||
if action != -1 and action != -2: # -1 is the <stop> action
|
|
||||||
select_candidate = perm_obs[i]['candidate'][action]
|
select_candidate = perm_obs[i]['candidate'][action]
|
||||||
src_point = perm_obs[i]['viewIndex']
|
src_point = perm_obs[i]['viewIndex']
|
||||||
trg_point = select_candidate['pointId']
|
trg_point = select_candidate['pointId']
|
||||||
@ -233,11 +228,6 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
# print("action: {} view_index: {}".format(action, state.viewIndex))
|
# print("action: {} view_index: {}".format(action, state.viewIndex))
|
||||||
if traj is not None:
|
if traj is not None:
|
||||||
traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
|
traj[i]['path'].append((state.location.viewpointId, state.heading, state.elevation))
|
||||||
elif action == -1:
|
|
||||||
print('<STOP>')
|
|
||||||
elif action == -2:
|
|
||||||
print('<NOT_FOUND>')
|
|
||||||
|
|
||||||
|
|
||||||
def rollout(self, train_ml=None, train_rl=True, reset=True):
|
def rollout(self, train_ml=None, train_rl=True, reset=True):
|
||||||
"""
|
"""
|
||||||
@ -262,9 +252,7 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
# Language input
|
# Language input
|
||||||
sentence, language_attention_mask, token_type_ids, \
|
sentence, language_attention_mask, token_type_ids, \
|
||||||
seq_lengths, perm_idx = self._sort_batch(obs)
|
seq_lengths, perm_idx = self._sort_batch(obs)
|
||||||
|
|
||||||
perm_obs = obs[perm_idx]
|
perm_obs = obs[perm_idx]
|
||||||
|
|
||||||
|
|
||||||
''' Language BERT '''
|
''' Language BERT '''
|
||||||
language_inputs = {'mode': 'language',
|
language_inputs = {'mode': 'language',
|
||||||
@ -306,6 +294,10 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
|
|
||||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||||
|
|
||||||
|
print("input_a_t: ", input_a_t.shape)
|
||||||
|
print("candidate_feat: ", candidate_feat.shape)
|
||||||
|
print("candidate_leng: ", candidate_leng)
|
||||||
|
|
||||||
# the first [CLS] token, initialized by the language BERT, serves
|
# the first [CLS] token, initialized by the language BERT, serves
|
||||||
# as the agent's state passing through time steps
|
# as the agent's state passing through time steps
|
||||||
if (t >= 1) or (args.vlnbert=='prevalent'):
|
if (t >= 1) or (args.vlnbert=='prevalent'):
|
||||||
@ -330,23 +322,11 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
|
|
||||||
# Mask outputs where agent can't move forward
|
# Mask outputs where agent can't move forward
|
||||||
# Here the logit is [b, max_candidate]
|
# Here the logit is [b, max_candidate]
|
||||||
# (8, max(candidate))
|
|
||||||
candidate_mask = utils.length2mask(candidate_leng)
|
candidate_mask = utils.length2mask(candidate_leng)
|
||||||
logit.masked_fill_(candidate_mask, -float('inf'))
|
logit.masked_fill_(candidate_mask, -float('inf'))
|
||||||
|
|
||||||
# Supervised training
|
# Supervised training
|
||||||
target = self._teacher_action(perm_obs, ended)
|
target = self._teacher_action(perm_obs, ended)
|
||||||
for i, d in enumerate(target):
|
|
||||||
print(perm_obs[i]['swap'], perm_obs[i]['instructions'])
|
|
||||||
print(d)
|
|
||||||
_, at_t = logit.max(1)
|
|
||||||
if at_t[i].item() == candidate_leng[i]-1:
|
|
||||||
print("-2")
|
|
||||||
elif at_t[i].item() == candidate_leng[i]-2:
|
|
||||||
print("-1")
|
|
||||||
else:
|
|
||||||
print(at_t[i].item())
|
|
||||||
print()
|
|
||||||
ml_loss += self.criterion(logit, target)
|
ml_loss += self.criterion(logit, target)
|
||||||
|
|
||||||
# Determine next model inputs
|
# Determine next model inputs
|
||||||
@ -367,15 +347,12 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
else:
|
else:
|
||||||
print(self.feedback)
|
print(self.feedback)
|
||||||
sys.exit('Invalid feedback option')
|
sys.exit('Invalid feedback option')
|
||||||
|
|
||||||
# Prepare environment action
|
# Prepare environment action
|
||||||
# NOTE: Env action is in the perm_obs space
|
# NOTE: Env action is in the perm_obs space
|
||||||
cpu_a_t = a_t.cpu().numpy()
|
cpu_a_t = a_t.cpu().numpy()
|
||||||
for i, next_id in enumerate(cpu_a_t):
|
for i, next_id in enumerate(cpu_a_t):
|
||||||
if next_id == (candidate_leng[i]-2) or next_id == args.ignoreid or ended[i]: # The last action is <end>
|
if next_id == (candidate_leng[i]-1) or next_id == args.ignoreid or ended[i]: # The last action is <end>
|
||||||
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
|
cpu_a_t[i] = -1 # Change the <end> and ignore action to -1
|
||||||
elif next_id == (candidate_leng[i]-1):
|
|
||||||
cpu_a_t[i] = -2
|
|
||||||
|
|
||||||
# Make action and get the new state
|
# Make action and get the new state
|
||||||
self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
|
self.make_equiv_action(cpu_a_t, perm_obs, perm_idx, traj)
|
||||||
@ -402,22 +379,8 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
if action_idx == -1: # If the action now is end
|
if action_idx == -1: # If the action now is end
|
||||||
if dist[i] < 3.0: # Correct
|
if dist[i] < 3.0: # Correct
|
||||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
||||||
if ob['swap']:
|
|
||||||
reward[i] -= 2
|
|
||||||
else:
|
|
||||||
reward[i] += 1
|
|
||||||
else: # Incorrect
|
else: # Incorrect
|
||||||
reward[i] = -2.0
|
reward[i] = -2.0
|
||||||
elif action_idx == -2: # NOT_FOUND reward 設定在這裏
|
|
||||||
if dist[i] < 3.0:
|
|
||||||
reward[i] = 2.0 + ndtw_score[i] * 2.0
|
|
||||||
if ob['swap']:
|
|
||||||
reward[i] += 3 # 偵測到錯誤 instruction,多加一分
|
|
||||||
else:
|
|
||||||
reward[i] -= 2
|
|
||||||
else: # Incorrect
|
|
||||||
reward[i] = -2.0
|
|
||||||
reward[i] += 1 # distance > 3, 確實沒找到東西,從扣二變成扣一
|
|
||||||
else: # The action is not end
|
else: # The action is not end
|
||||||
# Path fidelity rewards (distance & nDTW)
|
# Path fidelity rewards (distance & nDTW)
|
||||||
reward[i] = - (dist[i] - last_dist[i])
|
reward[i] = - (dist[i] - last_dist[i])
|
||||||
@ -439,7 +402,6 @@ class Seq2SeqAgent(BaseAgent):
|
|||||||
# Update the finished actions
|
# Update the finished actions
|
||||||
# -1 means ended or ignored (already ended)
|
# -1 means ended or ignored (already ended)
|
||||||
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
|
ended[:] = np.logical_or(ended, (cpu_a_t == -1))
|
||||||
ended[:] = np.logical_or(ended, (cpu_a_t == -2))
|
|
||||||
|
|
||||||
# Early exit if all ended
|
# Early exit if all ended
|
||||||
if ended.all():
|
if ended.all():
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
''' Batched Room-to-Room navigation environment '''
|
''' Batched Room-to-Room navigation environment '''
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from networkx.algorithms import swap
|
|
||||||
sys.path.append('buildpy36')
|
sys.path.append('buildpy36')
|
||||||
sys.path.append('Matterport_Simulator/build/')
|
sys.path.append('Matterport_Simulator/build/')
|
||||||
import MatterSim
|
import MatterSim
|
||||||
@ -16,7 +14,6 @@ import os
|
|||||||
import random
|
import random
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from param import args
|
from param import args
|
||||||
import time
|
|
||||||
|
|
||||||
from utils import load_datasets, load_nav_graphs, pad_instr_tokens
|
from utils import load_datasets, load_nav_graphs, pad_instr_tokens
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
@ -130,7 +127,6 @@ class R2RBatch():
|
|||||||
new_item = dict(item)
|
new_item = dict(item)
|
||||||
new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
|
new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
|
||||||
new_item['instructions'] = instr
|
new_item['instructions'] = instr
|
||||||
new_item['swap'] = item['swap'][j]
|
|
||||||
|
|
||||||
''' BERT tokenizer '''
|
''' BERT tokenizer '''
|
||||||
instr_tokens = tokenizer.tokenize(instr)
|
instr_tokens = tokenizer.tokenize(instr)
|
||||||
@ -140,12 +136,10 @@ class R2RBatch():
|
|||||||
if new_item['instr_encoding'] is not None: # Filter the wrong data
|
if new_item['instr_encoding'] is not None: # Filter the wrong data
|
||||||
self.data.append(new_item)
|
self.data.append(new_item)
|
||||||
scans.append(item['scan'])
|
scans.append(item['scan'])
|
||||||
|
|
||||||
except:
|
except:
|
||||||
continue
|
continue
|
||||||
print("split {} has {} datas in the file.".format(split, max_len))
|
print("split {} has {} datas in the file.".format(split, max_len))
|
||||||
|
|
||||||
|
|
||||||
if name is None:
|
if name is None:
|
||||||
self.name = splits[0] if len(splits) > 0 else "FAKE"
|
self.name = splits[0] if len(splits) > 0 else "FAKE"
|
||||||
else:
|
else:
|
||||||
@ -347,8 +341,7 @@ class R2RBatch():
|
|||||||
'instructions' : item['instructions'],
|
'instructions' : item['instructions'],
|
||||||
'teacher' : self._shortest_path_action(state, item['path'][-1]),
|
'teacher' : self._shortest_path_action(state, item['path'][-1]),
|
||||||
'gt_path' : item['path'],
|
'gt_path' : item['path'],
|
||||||
'path_id' : item['path_id'],
|
'path_id' : item['path_id']
|
||||||
'swap': item['swap']
|
|
||||||
})
|
})
|
||||||
if 'instr_encoding' in item:
|
if 'instr_encoding' in item:
|
||||||
obs[-1]['instr_encoding'] = item['instr_encoding']
|
obs[-1]['instr_encoding'] = item['instr_encoding']
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user