diff --git a/data/adversarial.py b/data/adversarial.py new file mode 100644 index 0000000..cf55985 --- /dev/null +++ b/data/adversarial.py @@ -0,0 +1,21 @@ +import json +import sys +import random + +with open(sys.argv[1]) as fp: + data = json.load(fp) + +for _, d in enumerate(data): + swaps = [] + for index, ins in enumerate(d['instructions']): + p = random.random() + if p > 0.5: + swaps.append(True) + d['instructions'][index] += 'This is swap.' + else: + swaps.append(False) + d['swap'] = swaps +print(data) + +with open(sys.argv[1], 'w') as fp: + json.dump(data, fp) diff --git a/r2r_src/agent.py b/r2r_src/agent.py index 7afbaa4..53fffb6 100644 --- a/r2r_src/agent.py +++ b/r2r_src/agent.py @@ -252,7 +252,10 @@ class Seq2SeqAgent(BaseAgent): # Language input sentence, language_attention_mask, token_type_ids, \ seq_lengths, perm_idx = self._sort_batch(obs) + + print("perm_index:", perm_idx) perm_obs = obs[perm_idx] + ''' Language BERT ''' language_inputs = {'mode': 'language', @@ -294,9 +297,6 @@ class Seq2SeqAgent(BaseAgent): input_a_t, candidate_feat, candidate_leng = self.get_input_feat(perm_obs) - print("input_a_t: ", input_a_t.shape) - print("candidate_feat: ", candidate_feat.shape) - print("candidate_leng: ", candidate_leng) # the first [CLS] token, initialized by the language BERT, serves # as the agent's state passing through time steps @@ -322,11 +322,13 @@ class Seq2SeqAgent(BaseAgent): # Mask outputs where agent can't move forward # Here the logit is [b, max_candidate] + # (8, max(candidate)) candidate_mask = utils.length2mask(candidate_leng) logit.masked_fill_(candidate_mask, -float('inf')) # Supervised training target = self._teacher_action(perm_obs, ended) + print("target: ", target.shape) ml_loss += self.criterion(logit, target) # Determine next model inputs diff --git a/r2r_src/env.py b/r2r_src/env.py index 5b36e62..d35174b 100644 --- a/r2r_src/env.py +++ b/r2r_src/env.py @@ -1,6 +1,8 @@ ''' Batched Room-to-Room navigation environment ''' import sys + +from networkx.algorithms import swap sys.path.append('buildpy36') sys.path.append('Matterport_Simulator/build/') import MatterSim @@ -14,6 +16,7 @@ import os import random import networkx as nx from param import args +import time from utils import load_datasets, load_nav_graphs, pad_instr_tokens from IPython import embed @@ -127,6 +130,7 @@ class R2RBatch(): new_item = dict(item) new_item['instr_id'] = '%s_%d' % (item['path_id'], j) new_item['instructions'] = instr + new_item['swap'] = item['swap'][j] ''' BERT tokenizer ''' instr_tokens = tokenizer.tokenize(instr) @@ -136,10 +140,12 @@ class R2RBatch(): if new_item['instr_encoding'] is not None: # Filter the wrong data self.data.append(new_item) scans.append(item['scan']) + except: continue print("split {} has {} datas in the file.".format(split, max_len)) + if name is None: self.name = splits[0] if len(splits) > 0 else "FAKE" else: @@ -341,7 +347,8 @@ class R2RBatch(): 'instructions' : item['instructions'], 'teacher' : self._shortest_path_action(state, item['path'][-1]), 'gt_path' : item['path'], - 'path_id' : item['path_id'] + 'path_id' : item['path_id'], + 'swap': item['swap'] }) if 'instr_encoding' in item: obs[-1]['instr_encoding'] = item['instr_encoding']