feat: adversarial test & load swap in observations
This commit is contained in:
parent
832c6368dd
commit
7329f7fa0a
21
data/adversarial.py
Normal file
21
data/adversarial.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""Randomly tag instructions in a JSON dataset as "swapped".

Usage: python adversarial.py DATASET.json

The dataset file is expected to hold a JSON list of objects, each with an
'instructions' list of strings.  Every instruction is independently tagged
with probability ~0.5: tagged instructions get SWAP_MARKER appended, and a
parallel list of booleans is stored under the 'swap' key of each object.
The file is rewritten in place.
"""
import json
import sys
import random

# Text appended verbatim (no separator) to every instruction chosen for a swap.
SWAP_MARKER = 'This is swap.'


def mark_swaps(data, rng=random):
    """Tag instructions in *data* in place and record the decisions.

    Args:
        data: iterable of dicts, each holding an 'instructions' list of str.
        rng: any object exposing ``random()`` returning a float; defaults to
            the ``random`` module.  Pass ``random.Random(seed)`` for a
            reproducible run.

    Returns:
        The same *data* object, for convenience.
    """
    for entry in data:
        instructions = entry['instructions']
        flags = []
        # One RNG draw per instruction, in order, so results line up with
        # the instruction indices.
        for index in range(len(instructions)):
            swapped = rng.random() > 0.5
            if swapped:
                instructions[index] += SWAP_MARKER
            flags.append(swapped)
        entry['swap'] = flags
    return data


def main(argv=None):
    """Load the dataset named on the command line, tag it, and write it back.

    Args:
        argv: argument vector; defaults to ``sys.argv``.  ``argv[1]`` is the
            path of the JSON file to rewrite in place.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit('usage: {} DATASET.json'.format(argv[0] if argv else 'adversarial.py'))
    path = argv[1]

    with open(path) as fp:
        data = json.load(fp)

    mark_swaps(data)
    print(data)

    with open(path, 'w') as fp:
        json.dump(data, fp)


if __name__ == '__main__':
    main()
|
||||
@ -252,7 +252,10 @@ class Seq2SeqAgent(BaseAgent):
|
||||
# Language input
|
||||
sentence, language_attention_mask, token_type_ids, \
|
||||
seq_lengths, perm_idx = self._sort_batch(obs)
|
||||
|
||||
print("perm_index:", perm_idx)
|
||||
perm_obs = obs[perm_idx]
|
||||
|
||||
|
||||
''' Language BERT '''
|
||||
language_inputs = {'mode': 'language',
|
||||
@ -294,9 +297,6 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
|
||||
|
||||
print("input_a_t: ", input_a_t.shape)
|
||||
print("candidate_feat: ", candidate_feat.shape)
|
||||
print("candidate_leng: ", candidate_leng)
|
||||
|
||||
# the first [CLS] token, initialized by the language BERT, serves
|
||||
# as the agent's state passing through time steps
|
||||
@ -322,11 +322,13 @@ class Seq2SeqAgent(BaseAgent):
|
||||
|
||||
# Mask outputs where agent can't move forward
|
||||
# Here the logit is [b, max_candidate]
|
||||
# (8, max(candidate))
|
||||
candidate_mask = utils.length2mask(candidate_leng)
|
||||
logit.masked_fill_(candidate_mask, -float('inf'))
|
||||
|
||||
# Supervised training
|
||||
target = self._teacher_action(perm_obs, ended)
|
||||
print("target: ", target.shape)
|
||||
ml_loss += self.criterion(logit, target)
|
||||
|
||||
# Determine next model inputs
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
''' Batched Room-to-Room navigation environment '''
|
||||
|
||||
import sys
|
||||
|
||||
from networkx.algorithms import swap
|
||||
sys.path.append('buildpy36')
|
||||
sys.path.append('Matterport_Simulator/build/')
|
||||
import MatterSim
|
||||
@ -14,6 +16,7 @@ import os
|
||||
import random
|
||||
import networkx as nx
|
||||
from param import args
|
||||
import time
|
||||
|
||||
from utils import load_datasets, load_nav_graphs, pad_instr_tokens
|
||||
from IPython import embed
|
||||
@ -127,6 +130,7 @@ class R2RBatch():
|
||||
new_item = dict(item)
|
||||
new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
|
||||
new_item['instructions'] = instr
|
||||
new_item['swap'] = item['swap'][j]
|
||||
|
||||
''' BERT tokenizer '''
|
||||
instr_tokens = tokenizer.tokenize(instr)
|
||||
@ -136,10 +140,12 @@ class R2RBatch():
|
||||
if new_item['instr_encoding'] is not None: # Filter the wrong data
|
||||
self.data.append(new_item)
|
||||
scans.append(item['scan'])
|
||||
|
||||
except:
|
||||
continue
|
||||
print("split {} has {} datas in the file.".format(split, max_len))
|
||||
|
||||
|
||||
if name is None:
|
||||
self.name = splits[0] if len(splits) > 0 else "FAKE"
|
||||
else:
|
||||
@ -341,7 +347,8 @@ class R2RBatch():
|
||||
'instructions' : item['instructions'],
|
||||
'teacher' : self._shortest_path_action(state, item['path'][-1]),
|
||||
'gt_path' : item['path'],
|
||||
'path_id' : item['path_id']
|
||||
'path_id' : item['path_id'],
|
||||
'swap': item['swap']
|
||||
})
|
||||
if 'instr_encoding' in item:
|
||||
obs[-1]['instr_encoding'] = item['instr_encoding']
|
||||
|
||||
Loading…
Reference in New Issue
Block a user