adversarial_VLNDUET/map_nav_src/reverie/agent_obj.py

import json
import os
import sys
import numpy as np
import random
import math
import time
from collections import defaultdict
import line_profiler
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from utils.distributed import is_default_gpu
from utils.ops import pad_tensors, gen_seq_masks
from torch.nn.utils.rnn import pad_sequence
from .agent_base import Seq2SeqAgent
from models.graph_utils import GraphMap
from models.model import VLNBert, Critic
from models.ops import pad_tensors_wgrad
class GMapObjectNavAgent(Seq2SeqAgent):
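    """Graph-map object navigation agent for REVERIE: navigates over a topological
    graph map with a VLN-BERT policy and grounds the target object at the stop node."""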
def _build_model(self):
self.vln_bert = VLNBert(self.args).cuda()
self.critic = Critic(self.args).cuda()
# buffer
self.scanvp_cands = {}
def _language_variable(self, obs):
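        ''' Pad instruction token ids to the batch max length and build boolean text masks. '''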
seq_lengths = [len(ob['instr_encoding']) for ob in obs]
seq_tensor = np.zeros((len(obs), max(seq_lengths)), dtype=np.int64)
        mask = np.zeros((len(obs), max(seq_lengths)), dtype=bool)
for i, ob in enumerate(obs):
seq_tensor[i, :seq_lengths[i]] = ob['instr_encoding']
mask[i, :seq_lengths[i]] = True
seq_tensor = torch.from_numpy(seq_tensor).long().cuda()
mask = torch.from_numpy(mask).cuda()
return {
'txt_ids': seq_tensor, 'txt_masks': mask
}
def _panorama_feature_variable(self, obs):
''' Extract precomputed features into variable. '''
batch_view_img_fts, batch_obj_img_fts, batch_loc_fts, batch_nav_types = [], [], [], []
batch_view_lens, batch_obj_lens = [], []
batch_cand_vpids, batch_objids = [], []
for i, ob in enumerate(obs):
view_img_fts, view_ang_fts, nav_types, cand_vpids = [], [], [], []
# cand views
used_viewidxs = set()
for j, cc in enumerate(ob['candidate']):
view_img_fts.append(cc['feature'][:self.args.image_feat_size])
view_ang_fts.append(cc['feature'][self.args.image_feat_size:])
nav_types.append(1)
cand_vpids.append(cc['viewpointId'])
used_viewidxs.add(cc['pointId'])
# non cand views
view_img_fts.extend([x[:self.args.image_feat_size] for k, x \
in enumerate(ob['feature']) if k not in used_viewidxs])
view_ang_fts.extend([x[self.args.image_feat_size:] for k, x \
in enumerate(ob['feature']) if k not in used_viewidxs])
nav_types.extend([0] * (36 - len(used_viewidxs)))
# combine cand views and noncand views
view_img_fts = np.stack(view_img_fts, 0) # (n_views, dim_ft)
view_ang_fts = np.stack(view_ang_fts, 0)
view_box_fts = np.array([[1, 1, 1]] * len(view_img_fts)).astype(np.float32)
view_loc_fts = np.concatenate([view_ang_fts, view_box_fts], 1)
# object
obj_loc_fts = np.concatenate([ob['obj_ang_fts'], ob['obj_box_fts']], 1)
nav_types.extend([2] * len(obj_loc_fts))
batch_view_img_fts.append(torch.from_numpy(view_img_fts))
batch_obj_img_fts.append(torch.from_numpy(ob['obj_img_fts']))
batch_loc_fts.append(torch.from_numpy(np.concatenate([view_loc_fts, obj_loc_fts], 0)))
batch_nav_types.append(torch.LongTensor(nav_types))
batch_cand_vpids.append(cand_vpids)
batch_objids.append(ob['obj_ids'])
batch_view_lens.append(len(view_img_fts))
batch_obj_lens.append(len(ob['obj_img_fts']))
# pad features to max_len
batch_view_img_fts = pad_tensors(batch_view_img_fts).cuda()
batch_obj_img_fts = pad_tensors(batch_obj_img_fts).cuda()
batch_loc_fts = pad_tensors(batch_loc_fts).cuda()
batch_nav_types = pad_sequence(batch_nav_types, batch_first=True, padding_value=0).cuda()
batch_view_lens = torch.LongTensor(batch_view_lens).cuda()
batch_obj_lens = torch.LongTensor(batch_obj_lens).cuda()
return {
'view_img_fts': batch_view_img_fts, 'obj_img_fts': batch_obj_img_fts,
'loc_fts': batch_loc_fts, 'nav_types': batch_nav_types,
'view_lens': batch_view_lens, 'obj_lens': batch_obj_lens,
'cand_vpids': batch_cand_vpids, 'obj_ids': batch_objids,
}
def _nav_gmap_variable(self, obs, gmaps):
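        ''' Collate global graph-map inputs: node embeddings, step ids, position features,
        pairwise graph distances and visited masks, with a [STOP] placeholder at index 0. '''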
# [stop] + gmap_vpids
batch_size = len(obs)
batch_gmap_vpids, batch_gmap_lens = [], []
batch_gmap_img_embeds, batch_gmap_step_ids, batch_gmap_pos_fts = [], [], []
batch_gmap_pair_dists, batch_gmap_visited_masks = [], []
batch_no_vp_left = []
for i, gmap in enumerate(gmaps):
visited_vpids, unvisited_vpids = [], []
for k in gmap.node_positions.keys():
if gmap.graph.visited(k):
visited_vpids.append(k)
else:
unvisited_vpids.append(k)
batch_no_vp_left.append(len(unvisited_vpids) == 0)
if self.args.enc_full_graph:
gmap_vpids = [None] + visited_vpids + unvisited_vpids
gmap_visited_masks = [0] + [1] * len(visited_vpids) + [0] * len(unvisited_vpids)
else:
gmap_vpids = [None] + unvisited_vpids
gmap_visited_masks = [0] * len(gmap_vpids)
gmap_step_ids = [gmap.node_step_ids.get(vp, 0) for vp in gmap_vpids]
gmap_img_embeds = [gmap.get_node_embed(vp) for vp in gmap_vpids[1:]]
gmap_img_embeds = torch.stack(
[torch.zeros_like(gmap_img_embeds[0])] + gmap_img_embeds, 0
) # cuda
gmap_pos_fts = gmap.get_pos_fts(
obs[i]['viewpoint'], gmap_vpids, obs[i]['heading'], obs[i]['elevation'],
)
gmap_pair_dists = np.zeros((len(gmap_vpids), len(gmap_vpids)), dtype=np.float32)
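            # pairwise shortest-path distances between graph nodes (index 0 is the [STOP] placeholder and stays 0)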
            for m in range(1, len(gmap_vpids)):
                for n in range(m+1, len(gmap_vpids)):
                    gmap_pair_dists[m, n] = gmap_pair_dists[n, m] = \
                        gmap.graph.distance(gmap_vpids[m], gmap_vpids[n])
batch_gmap_img_embeds.append(gmap_img_embeds)
batch_gmap_step_ids.append(torch.LongTensor(gmap_step_ids))
batch_gmap_pos_fts.append(torch.from_numpy(gmap_pos_fts))
batch_gmap_pair_dists.append(torch.from_numpy(gmap_pair_dists))
batch_gmap_visited_masks.append(torch.BoolTensor(gmap_visited_masks))
batch_gmap_vpids.append(gmap_vpids)
batch_gmap_lens.append(len(gmap_vpids))
# collate
batch_gmap_lens = torch.LongTensor(batch_gmap_lens)
batch_gmap_masks = gen_seq_masks(batch_gmap_lens).cuda()
batch_gmap_img_embeds = pad_tensors_wgrad(batch_gmap_img_embeds)
batch_gmap_step_ids = pad_sequence(batch_gmap_step_ids, batch_first=True).cuda()
batch_gmap_pos_fts = pad_tensors(batch_gmap_pos_fts).cuda()
batch_gmap_visited_masks = pad_sequence(batch_gmap_visited_masks, batch_first=True).cuda()
max_gmap_len = max(batch_gmap_lens)
gmap_pair_dists = torch.zeros(batch_size, max_gmap_len, max_gmap_len).float()
for i in range(batch_size):
gmap_pair_dists[i, :batch_gmap_lens[i], :batch_gmap_lens[i]] = batch_gmap_pair_dists[i]
gmap_pair_dists = gmap_pair_dists.cuda()
return {
'gmap_vpids': batch_gmap_vpids, 'gmap_img_embeds': batch_gmap_img_embeds,
'gmap_step_ids': batch_gmap_step_ids, 'gmap_pos_fts': batch_gmap_pos_fts,
'gmap_visited_masks': batch_gmap_visited_masks,
'gmap_pair_dists': gmap_pair_dists, 'gmap_masks': batch_gmap_masks,
'no_vp_left': batch_no_vp_left,
}
def _nav_vp_variable(self, obs, gmaps, pano_embeds, cand_vpids, view_lens, obj_lens, nav_types):
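        ''' Build local (current-viewpoint) inputs: panorama embeddings with a [STOP] token
        prepended and a [NOT FOUND] token appended, plus position features and nav/object masks. '''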
batch_size = len(obs)
# print("PANO shape", pano_embeds.shape)
# add [stop] token & [NOT FOUND] token
        # the [STOP] token goes first, the [NOT FOUND] token goes last
vp_img_embeds = torch.cat(
[torch.zeros_like(pano_embeds[:, :1]), pano_embeds, torch.ones_like(pano_embeds[:, :1])], 1
)
# print("SHAPE:", vp_img_embeds.shape)
batch_vp_pos_fts = []
for i, gmap in enumerate(gmaps):
cur_cand_pos_fts = gmap.get_pos_fts(
obs[i]['viewpoint'], cand_vpids[i],
obs[i]['heading'], obs[i]['elevation']
)
cur_start_pos_fts = gmap.get_pos_fts(
obs[i]['viewpoint'], [gmap.start_vp],
obs[i]['heading'], obs[i]['elevation']
)
# add [stop] token at beginning
vp_pos_fts = np.zeros((vp_img_embeds.size(1), 14), dtype=np.float32)
# print("vp_pos_fts:", vp_pos_fts.shape)
vp_pos_fts[:, :7] = cur_start_pos_fts
# print("vp_pos_fts[:, :7]:", vp_pos_fts[:, :7].shape)
# print("cur_start_pos_fts:", cur_start_pos_fts.shape)
vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:] = cur_cand_pos_fts
# print("vp_pos_fts[1:len(), 7:]:", vp_pos_fts[1:len(cur_cand_pos_fts)+1, 7:].shape)
# print("cur_cand_pos_fts:", cur_cand_pos_fts.shape)
batch_vp_pos_fts.append(torch.from_numpy(vp_pos_fts))
batch_vp_pos_fts = pad_tensors(batch_vp_pos_fts).cuda()
        # append masks for the [STOP] and [NOT FOUND] tokens:
        # [STOP] is grouped with the navigation candidates, [NOT FOUND] with the objects
vp_nav_masks = torch.cat([torch.ones(batch_size, 1).bool().cuda(), nav_types == 1, torch.zeros(batch_size, 1).bool().cuda()], 1)
vp_obj_masks = torch.cat([torch.zeros(batch_size, 1).bool().cuda(), nav_types == 2, torch.ones(batch_size, 1).bool().cuda()], 1)
# print('vp_nav_masks:', vp_nav_masks.shape)
# print('vp_obj_masks:', vp_obj_masks.shape)
vp_masks = gen_seq_masks(view_lens+obj_lens+2)
# print()
return {
'vp_img_embeds': vp_img_embeds,
'vp_pos_fts': batch_vp_pos_fts,
'vp_masks': vp_masks,
'vp_nav_masks': vp_nav_masks,
'vp_obj_masks': vp_obj_masks,
'vp_cand_vpids': [[None]+x for x in cand_vpids],
}
def _teacher_action(self, obs, vpids, ended, visited_masks=None):
"""
Extract teacher actions into variable.
:param obs: The observation.
:param ended: Whether the action seq is ended
:return:
"""
a = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs):
if ended[i]: # Just ignore this index
a[i] = self.args.ignoreid
else:
if ob['viewpoint'] == ob['gt_path'][-1]:
a[i] = 0 # Stop if arrived
else:
scan = ob['scan']
cur_vp = ob['viewpoint']
min_idx, min_dist = self.args.ignoreid, float('inf')
for j, vpid in enumerate(vpids[i]):
if j > 0 and ((visited_masks is None) or (not visited_masks[i][j])):
# dist = min([self.env.shortest_distances[scan][vpid][end_vp] for end_vp in ob['gt_end_vps']])
dist = self.env.shortest_distances[scan][vpid][ob['gt_path'][-1]] \
+ self.env.shortest_distances[scan][cur_vp][vpid]
if dist < min_dist:
min_dist = dist
min_idx = j
a[i] = min_idx
if min_idx == self.args.ignoreid:
print('scan %s: all vps are searched' % (scan))
return torch.from_numpy(a).cuda()
def _teacher_object(self, obs, ended, view_lens, obj_logits):
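        ''' Build object-grounding targets: the index of the ground-truth object among the panorama
        objects if it is visible and gt_found is True, the last ([NOT FOUND]) slot if gt_found is False,
        and ignoreid otherwise. '''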
targets = np.zeros(len(obs), dtype=np.int64)
for i, ob in enumerate(obs):
if ended[i]:
targets[i] = self.args.ignoreid
else:
i_vp = ob['viewpoint']
if i_vp not in ob['gt_end_vps']:
targets[i] = self.args.ignoreid
else:
i_objids = ob['obj_ids']
targets[i] = self.args.ignoreid
for j, obj_id in enumerate(i_objids):
if str(obj_id) == str(ob['gt_obj_id']):
                            if ob['gt_found']:  # the target object is actually visible here
                                targets[i] = j + view_lens[i] + 1
                            else:
                                targets[i] = len(obj_logits[i]) - 1  # point to the [NOT FOUND] slot (the last logit)
break
return torch.from_numpy(targets).cuda()
def make_equiv_action(self, a_t, gmaps, obs, traj=None):
"""
Interface between Panoramic view and Egocentric view
        It converts the panoramic-view action a_t into the equivalent egocentric-view actions for the simulator.
"""
for i, ob in enumerate(obs):
action = a_t[i]
if action is not None: # None is the <stop> action
traj[i]['path'].append(gmaps[i].graph.path(ob['viewpoint'], action))
if len(traj[i]['path'][-1]) == 1:
prev_vp = traj[i]['path'][-2][-1]
else:
prev_vp = traj[i]['path'][-1][-2]
viewidx = self.scanvp_cands['%s_%s'%(ob['scan'], prev_vp)][action]
heading = (viewidx % 12) * math.radians(30)
elevation = (viewidx // 12 - 1) * math.radians(30)
self.env.env.sims[i].newEpisode([ob['scan']], [action], [heading], [elevation])
def _update_scanvp_cands(self, obs):
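        ''' Cache, for each (scan, viewpoint) pair, the mapping from candidate viewpoint ids to panorama view indices. '''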
for ob in obs:
scan = ob['scan']
vp = ob['viewpoint']
scanvp = '%s_%s' % (scan, vp)
self.scanvp_cands.setdefault(scanvp, {})
for cand in ob['candidate']:
self.scanvp_cands[scanvp].setdefault(cand['viewpointId'], {})
self.scanvp_cands[scanvp][cand['viewpointId']] = cand['pointId']
# @profile
def rollout(self, train_ml=None, train_rl=False, reset=True):
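        ''' Run one batched episode: build graph maps step by step, predict navigation and
        object-grounding actions with the VLN-BERT policy, and accumulate IL and OG losses
        when train_ml is given. Returns the predicted trajectories. '''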
if reset: # Reset env
obs = self.env.reset()
else:
obs = self.env._get_obs()
self._update_scanvp_cands(obs)
batch_size = len(obs)
# build graph: keep the start viewpoint
gmaps = [GraphMap(ob['viewpoint']) for ob in obs] # input the start point
for i, ob in enumerate(obs):
gmaps[i].update_graph(ob)
# Record the navigation path
traj = [{
'instr_id': ob['instr_id'],
'path': [[ob['viewpoint']]],
'pred_objid': None,
'gt_objid': None,
'found': None,
'gt_found': None,
'details': {},
} for ob in obs]
# Language input: txt_ids, txt_masks
language_inputs = self._language_variable(obs)
txt_embeds = self.vln_bert('language', language_inputs)
        # Initialize the tracking state
ended = np.array([False] * batch_size)
just_ended = np.array([False] * batch_size)
# Init the logs
masks = []
entropys = []
ml_loss = 0.
og_loss = 0.
for t in range(self.args.max_action_len):
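            # record the step at which the current node is visited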
for i, gmap in enumerate(gmaps):
if not ended[i]:
gmap.node_step_ids[obs[i]['viewpoint']] = t + 1
# graph representation
pano_inputs = self._panorama_feature_variable(obs)
pano_embeds, pano_masks = self.vln_bert('panorama', pano_inputs)
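            # mask-weighted average of the panorama view/object embeddings -> a single embedding for the current node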
avg_pano_embeds = torch.sum(pano_embeds * pano_masks.unsqueeze(2), 1) / \
torch.sum(pano_masks, 1, keepdim=True)
for i, gmap in enumerate(gmaps):
if not ended[i]:
# update visited node
i_vp = obs[i]['viewpoint']
gmap.update_node_embed(i_vp, avg_pano_embeds[i], rewrite=True)
# update unvisited nodes
for j, i_cand_vp in enumerate(pano_inputs['cand_vpids'][i]):
if not gmap.graph.visited(i_cand_vp):
gmap.update_node_embed(i_cand_vp, pano_embeds[i, j])
# navigation policy
nav_inputs = self._nav_gmap_variable(obs, gmaps)
nav_inputs.update(
self._nav_vp_variable(
obs, gmaps, pano_embeds, pano_inputs['cand_vpids'],
pano_inputs['view_lens'], pano_inputs['obj_lens'],
pano_inputs['nav_types'],
)
)
nav_inputs.update({
'txt_embeds': txt_embeds,
'txt_masks': language_inputs['txt_masks'],
})
nav_outs = self.vln_bert('navigation', nav_inputs)
if self.args.fusion == 'local':
nav_logits = nav_outs['local_logits']
nav_vpids = nav_inputs['vp_cand_vpids']
elif self.args.fusion == 'global':
nav_logits = nav_outs['global_logits']
nav_vpids = nav_inputs['gmap_vpids']
else:
nav_logits = nav_outs['fused_logits']
nav_vpids = nav_inputs['gmap_vpids']
nav_probs = torch.softmax(nav_logits, 1)
obj_logits = nav_outs['obj_logits']
# update graph
for i, gmap in enumerate(gmaps):
if not ended[i]:
i_vp = obs[i]['viewpoint']
# update i_vp: stop and object grounding scores
i_objids = obs[i]['obj_ids']
                    i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:]  # the last logit is the [NOT FOUND] slot
if len(i_objids) > 0:
                        if torch.argmax(i_obj_logits) >= len(i_objids):  # [NOT FOUND] (the last slot) has the largest logit
og = -1
else:
og = i_objids[torch.argmax(i_obj_logits)]
else:
og = None
                    # og is the predicted object id if an object is found,
                    # -1 if [NOT FOUND] is predicted,
                    # and None if no object is visible from this viewpoint
gmap.node_stop_scores[i_vp] = {
'stop': nav_probs[i, 0].data.item(),
'og': og,
'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
}
if train_ml is not None:
# Supervised training
nav_targets = self._teacher_action(
obs, nav_vpids, ended,
visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None
)
# print(t, nav_logits, nav_targets)
ml_loss += self.criterion(nav_logits, nav_targets)
# print(t, 'ml_loss', ml_loss.item(), self.criterion(nav_logits, nav_targets).item())
if self.args.fusion in ['avg', 'dynamic'] and self.args.loss_nav_3:
# add global and local losses
ml_loss += self.criterion(nav_outs['global_logits'], nav_targets) # global
local_nav_targets = self._teacher_action(
obs, nav_inputs['vp_cand_vpids'], ended, visited_masks=None
)
ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets) # local
                # object grounding
obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'], obj_logits)
# print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
og_loss += self.criterion(obj_logits, obj_targets)
# print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
# print(t, 'og_loss', og_loss.item(), self.criterion(obj_logits, obj_targets).item())
            # Determine the next navigation viewpoint
if self.feedback == 'teacher':
a_t = nav_targets # teacher forcing
elif self.feedback == 'argmax':
_, a_t = nav_logits.max(1) # student forcing - argmax
a_t = a_t.detach()
elif self.feedback == 'sample':
c = torch.distributions.Categorical(nav_probs)
self.logs['entropy'].append(c.entropy().sum().item()) # For log
entropys.append(c.entropy()) # For optimization
a_t = c.sample().detach()
elif self.feedback == 'expl_sample':
_, a_t = nav_probs.max(1)
rand_explores = np.random.rand(batch_size, ) > self.args.expl_max_ratio # hyper-param
if self.args.fusion == 'local':
cpu_nav_masks = nav_inputs['vp_nav_masks'].data.cpu().numpy()
else:
cpu_nav_masks = (nav_inputs['gmap_masks'] * nav_inputs['gmap_visited_masks'].logical_not()).data.cpu().numpy()
for i in range(batch_size):
if rand_explores[i]:
cand_a_t = np.arange(len(cpu_nav_masks[i]))[cpu_nav_masks[i]]
a_t[i] = np.random.choice(cand_a_t)
else:
print(self.feedback)
sys.exit('Invalid feedback option')
# Determine stop actions
if self.feedback == 'teacher' or self.feedback == 'sample': # in training
# a_t_stop = [ob['viewpoint'] in ob['gt_end_vps'] for ob in obs]
a_t_stop = [ob['viewpoint'] == ob['gt_path'][-1] for ob in obs]
else:
a_t_stop = a_t == 0
# Prepare environment action
cpu_a_t = []
for i in range(batch_size):
if a_t_stop[i] or ended[i] or nav_inputs['no_vp_left'][i] or (t == self.args.max_action_len - 1):
cpu_a_t.append(None)
just_ended[i] = True
else:
cpu_a_t.append(nav_vpids[i][a_t[i]])
            original_gt_founds = [ob['gt_found'] for ob in obs]
# Make action and get the new state
self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
for i in range(batch_size):
traj[i]['gt_found'] = original_gt_founds[i]
if (not ended[i]) and just_ended[i]:
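                    # the episode just ended: back-trace to the graph node with the highest stop score
                    # and use its object-grounding prediction as the final answer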
stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
for k, v in gmaps[i].node_stop_scores.items():
if v['stop'] > stop_score['stop']:
stop_score = v
stop_node = k
if stop_node is not None and obs[i]['viewpoint'] != stop_node:
traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
traj[i]['pred_objid'] = stop_score['og']
if stop_score['og'] == -1:
traj[i]['found'] = False
else:
traj[i]['found'] = True
if self.args.detailed_output:
for k, v in gmaps[i].node_stop_scores.items():
traj[i]['details'][k] = {
'stop_prob': float(v['stop']),
'obj_ids': [str(x) for x in v['og_details']['objids']],
'obj_logits': v['og_details']['logits'].tolist(),
}
# new observation and update graph
obs = self.env._get_obs()
self._update_scanvp_cands(obs)
for i, ob in enumerate(obs):
if not ended[i]:
gmaps[i].update_graph(ob)
ended[:] = np.logical_or(ended, np.array([x is None for x in cpu_a_t]))
# Early exit if all ended
if ended.all():
break
if train_ml is not None:
ml_loss = ml_loss * train_ml / batch_size
og_loss = og_loss * train_ml / batch_size
self.loss += ml_loss
self.loss += og_loss
self.logs['IL_loss'].append(ml_loss.item())
self.logs['OG_loss'].append(og_loss.item())
'''
print("TRAJ:")
for i in traj:
print(" GT: {}, PREDICT: {}, SCORE: {}".format(i['gt_found'], i['found'], 1 if i['gt_found']==i['found'] else 0))
'''
return traj