feat: adversarial test & load swap in observations

This commit is contained in:
Ting-Jun Wang 2023-11-04 21:29:17 +08:00
parent 832c6368dd
commit 7329f7fa0a
Signed by: snsd0805
GPG Key ID: 48D331A3D6160354
3 changed files with 34 additions and 4 deletions

21
data/adversarial.py Normal file
View File

@ -0,0 +1,21 @@
import json
import sys
import random


def add_adversarial_swaps(data):
    """Randomly mark instructions for adversarial swapping, in place.

    For each entry in *data*, every instruction is flipped with probability
    0.5: flipped instructions get the marker text 'This is swap.' appended
    and a True recorded in a parallel boolean list stored under 'swap'.

    Args:
        data: list of dicts, each with an 'instructions' list of strings.

    Returns:
        The same list, mutated in place (returned for convenience).
    """
    for d in data:
        swaps = []
        # NOTE: depends on the global `random` state; seed externally for
        # reproducible augmentation.
        for index in range(len(d['instructions'])):
            if random.random() > 0.5:
                swaps.append(True)
                d['instructions'][index] += 'This is swap.'
            else:
                swaps.append(False)
        d['swap'] = swaps
    return data


if __name__ == '__main__':
    # Usage: python adversarial.py <dataset.json>
    # Rewrites the JSON file in place with 'swap' flags added.
    with open(sys.argv[1]) as fp:
        data = json.load(fp)
    add_adversarial_swaps(data)
    print(data)
    with open(sys.argv[1], 'w') as fp:
        json.dump(data, fp)

View File

@ -252,7 +252,10 @@ class Seq2SeqAgent(BaseAgent):
# Language input
sentence, language_attention_mask, token_type_ids, \
seq_lengths, perm_idx = self._sort_batch(obs)
print("perm_index:", perm_idx)
perm_obs = obs[perm_idx]
''' Language BERT '''
language_inputs = {'mode': 'language',
@ -294,9 +297,6 @@ class Seq2SeqAgent(BaseAgent):
input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs)
print("input_a_t: ", input_a_t.shape)
print("candidate_feat: ", candidate_feat.shape)
print("candidate_leng: ", candidate_leng)
# the first [CLS] token, initialized by the language BERT, serves
# as the agent's state passing through time steps
@ -322,11 +322,13 @@ class Seq2SeqAgent(BaseAgent):
# Mask outputs where agent can't move forward
# Here the logit is [b, max_candidate]
# (8, max(candidate))
candidate_mask = utils.length2mask(candidate_leng)
logit.masked_fill_(candidate_mask, -float('inf'))
# Supervised training
target = self._teacher_action(perm_obs, ended)
print("target: ", target.shape)
ml_loss += self.criterion(logit, target)
# Determine next model inputs

View File

@ -1,6 +1,8 @@
''' Batched Room-to-Room navigation environment '''
import sys
from networkx.algorithms import swap
sys.path.append('buildpy36')
sys.path.append('Matterport_Simulator/build/')
import MatterSim
@ -14,6 +16,7 @@ import os
import random
import networkx as nx
from param import args
import time
from utils import load_datasets, load_nav_graphs, pad_instr_tokens
from IPython import embed
@ -127,6 +130,7 @@ class R2RBatch():
new_item = dict(item)
new_item['instr_id'] = '%s_%d' % (item['path_id'], j)
new_item['instructions'] = instr
new_item['swap'] = item['swap'][j]
''' BERT tokenizer '''
instr_tokens = tokenizer.tokenize(instr)
@ -136,10 +140,12 @@ class R2RBatch():
if new_item['instr_encoding'] is not None: # Filter the wrong data
self.data.append(new_item)
scans.append(item['scan'])
except:
continue
print("split {} has {} datas in the file.".format(split, max_len))
if name is None:
self.name = splits[0] if len(splits) > 0 else "FAKE"
else:
@ -341,7 +347,8 @@ class R2RBatch():
'instructions' : item['instructions'],
'teacher' : self._shortest_path_action(state, item['path'][-1]),
'gt_path' : item['path'],
'path_id' : item['path_id']
'path_id' : item['path_id'],
'swap': item['swap']
})
if 'instr_encoding' in item:
obs[-1]['instr_encoding'] = item['instr_encoding']