# adversarial_VLNDUET/map_nav_src/soon/agent_obj.py
# Origin: Shizhe Chen, commit 747cf0587b ("init"), 2021-11-24 13:29:08 +01:00
# 241 lines, 11 KiB, Python

import json
import os
import sys
import numpy as np
import random
import math
import time
from collections import defaultdict
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from utils.distributed import is_default_gpu
from utils.ops import pad_tensors, gen_seq_masks
from torch.nn.utils.rnn import pad_sequence
from reverie.agent_obj import GMapObjectNavAgent
from models.graph_utils import GraphMap
from models.model import VLNBert, Critic
class SoonGMapObjectNavAgent(GMapObjectNavAgent):
    def get_results(self):
        """Format the accumulated per-instruction results for SOON evaluation.

        Returns a list of dicts, one per instruction, each carrying the
        navigation path and the predicted object heading/elevation (wrapped
        in single-element lists, as the evaluation format expects).
        """
        formatted = []
        for instr_id, res in self.results.items():
            direction = res['pred_obj_direction']
            formatted.append({
                'instr_id': instr_id,
                'trajectory': {
                    'path': res['path'],
                    'obj_heading': [direction[0]],
                    'obj_elevation': [direction[1]],
                },
            })
        return formatted
    def rollout(self, train_ml=None, train_rl=False, reset=True):
        """Run one batched navigation episode over the topological graph.

        Args:
            train_ml: weight for the supervised (imitation) losses; when not
                None, navigation and object-grounding losses are accumulated
                into ``self.loss`` and logged.
            train_rl: unused in this method; kept for interface compatibility
                with the base agent's rollout signature.
            reset: if True, reset the environment first; otherwise continue
                from the environment's current observations.

        Returns:
            traj: list (one entry per episode) of dicts with keys
                ``instr_id``, ``path`` (list of viewpoint sub-paths),
                ``pred_obj_direction`` ([heading, elevation] or None) and
                ``details`` (per-node stop/grounding scores when
                ``args.detailed_output`` is set).
        """
        if reset:  # Reset env
            obs = self.env.reset()
        else:
            obs = self.env._get_obs()
        self._update_scanvp_cands(obs)

        batch_size = len(obs)
        # build graph: keep the start viewpoint
        gmaps = [GraphMap(ob['viewpoint']) for ob in obs]
        for i, ob in enumerate(obs):
            gmaps[i].update_graph(ob)

        # Record the navigation path
        traj = [{
            'instr_id': ob['instr_id'],
            'path': [[ob['viewpoint']]],
            'pred_obj_direction': None,  # filled in when the episode ends
            'details': {},
        } for ob in obs]

        # Language input: txt_ids, txt_masks
        language_inputs = self._language_variable(obs)
        txt_embeds = self.vln_bert('language', language_inputs)

        # Initialization the tracking state
        ended = np.array([False] * batch_size)
        just_ended = np.array([False] * batch_size)

        # Init the logs
        masks = []
        entropys = []
        ml_loss = 0.
        og_loss = 0.

        for t in range(self.args.max_action_len):
            # Stamp the step index on each active agent's current node.
            for i, gmap in enumerate(gmaps):
                if not ended[i]:
                    gmap.node_step_ids[obs[i]['viewpoint']] = t + 1

            # graph representation: encode the current panorama and average
            # the view embeddings (mask-weighted) into one node embedding
            pano_inputs = self._panorama_feature_variable(obs)
            pano_embeds, pano_masks = self.vln_bert('panorama', pano_inputs)
            avg_pano_embeds = torch.sum(pano_embeds * pano_masks.unsqueeze(2), 1) / \
                              torch.sum(pano_masks, 1, keepdim=True)

            for i, gmap in enumerate(gmaps):
                if not ended[i]:
                    # update visited node (rewrite: replace any stale embed)
                    i_vp = obs[i]['viewpoint']
                    gmap.update_node_embed(i_vp, avg_pano_embeds[i], rewrite=True)
                    # update unvisited nodes with their single-view embedding
                    for j, i_cand_vp in enumerate(pano_inputs['cand_vpids'][i]):
                        if not gmap.graph.visited(i_cand_vp):
                            gmap.update_node_embed(i_cand_vp, pano_embeds[i, j])

            # navigation policy
            nav_inputs = self._nav_gmap_variable(obs, gmaps)
            nav_inputs.update(
                self._nav_vp_variable(
                    obs, gmaps, pano_embeds, pano_inputs['cand_vpids'],
                    pano_inputs['view_lens'], pano_inputs['obj_lens'],
                    pano_inputs['nav_types'],
                )
            )
            nav_inputs.update({
                'txt_embeds': txt_embeds,
                'txt_masks': language_inputs['txt_masks'],
            })
            nav_outs = self.vln_bert('navigation', nav_inputs)
            # Pick which logits drive the action: local (current panorama
            # candidates), global (all graph nodes), or the fused head.
            if self.args.fusion == 'local':
                nav_logits = nav_outs['local_logits']
                nav_vpids = nav_inputs['vp_cand_vpids']
            elif self.args.fusion == 'global':
                nav_logits = nav_outs['global_logits']
                nav_vpids = nav_inputs['gmap_vpids']
            else:
                nav_logits = nav_outs['fused_logits']
                nav_vpids = nav_inputs['gmap_vpids']
            nav_probs = torch.softmax(nav_logits, 1)
            obj_logits = nav_outs['obj_logits']

            # update graph
            for i, gmap in enumerate(gmaps):
                if not ended[i]:
                    i_vp = obs[i]['viewpoint']
                    # update i_vp: stop and object grounding scores.
                    # NOTE(review): object tokens presumably start right after
                    # the view tokens; the +1 looks like it skips a leading
                    # stop/[CLS] token — confirm against _nav_vp_variable.
                    i_objids = obs[i]['obj_ids']
                    i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:]
                    gmap.node_stop_scores[i_vp] = {
                        'stop': nav_probs[i, 0].data.item(),  # index 0 = STOP action
                        'og': i_objids[torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None,
                        'og_direction': obs[i]['obj_directions'][torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None,
                        'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
                    }

            if train_ml is not None:
                # Supervised training
                nav_targets = self._teacher_action(
                    obs, nav_vpids, ended,
                    visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None
                )
                ml_loss += self.criterion(nav_logits, nav_targets)
                if self.args.fusion in ['avg', 'dynamic'] and self.args.loss_nav_3:
                    # add global and local losses
                    ml_loss += self.criterion(nav_outs['global_logits'], nav_targets)  # global
                    local_nav_targets = self._teacher_action(
                        obs, nav_inputs['vp_cand_vpids'], ended, visited_masks=None
                    )
                    ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets)  # local
                # object grounding
                obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'])
                og_loss += self.criterion(obj_logits, obj_targets)

            # Determinate the next navigation viewpoint
            if self.feedback == 'teacher':
                # NOTE(review): relies on nav_targets computed above, so
                # 'teacher' feedback implicitly requires train_ml is not None.
                a_t = nav_targets  # teacher forcing
            elif self.feedback == 'argmax':
                _, a_t = nav_logits.max(1)  # student forcing - argmax
                a_t = a_t.detach()
            elif self.feedback == 'sample':
                c = torch.distributions.Categorical(nav_probs)
                self.logs['entropy'].append(c.entropy().sum().item())  # For log
                entropys.append(c.entropy())  # For optimization
                a_t = c.sample().detach()
            elif self.feedback == 'expl_sample':
                # Greedy action, except with prob > expl_max_ratio pick a
                # uniformly random reachable (unvisited) candidate instead.
                _, a_t = nav_probs.max(1)
                rand_explores = np.random.rand(batch_size, ) > self.args.expl_max_ratio  # hyper-param
                if self.args.fusion == 'local':
                    cpu_nav_masks = nav_inputs['vp_nav_masks'].data.cpu().numpy()
                else:
                    cpu_nav_masks = (nav_inputs['gmap_masks'] * nav_inputs['gmap_visited_masks'].logical_not()).data.cpu().numpy()
                for i in range(batch_size):
                    if rand_explores[i]:
                        cand_a_t = np.arange(len(cpu_nav_masks[i]))[cpu_nav_masks[i]]
                        a_t[i] = np.random.choice(cand_a_t)
            else:
                print(self.feedback)
                sys.exit('Invalid feedback option')

            # Determine stop actions: during training stop when the agent
            # stands on the ground-truth path's final viewpoint; at inference
            # stop when the policy selects action index 0.
            if self.feedback == 'teacher' or self.feedback == 'sample':  # in training
                a_t_stop = [ob['viewpoint'] == ob['gt_path'][-1] for ob in obs]
            else:
                a_t_stop = a_t == 0

            # Prepare environment action
            cpu_a_t = []
            for i in range(batch_size):
                # None = stop; also forced on the last step or when no
                # candidate viewpoint is left.
                if a_t_stop[i] or ended[i] or nav_inputs['no_vp_left'][i] or (t == self.args.max_action_len - 1):
                    cpu_a_t.append(None)
                    just_ended[i] = True
                else:
                    cpu_a_t.append(nav_vpids[i][a_t[i]])

            # Make action and get the new state
            self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
            for i in range(batch_size):
                if (not ended[i]) and just_ended[i]:
                    # On stopping, back-track to the visited node with the
                    # highest recorded stop probability and report its
                    # grounded object direction.
                    stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
                    for k, v in gmaps[i].node_stop_scores.items():
                        if v['stop'] > stop_score['stop']:
                            stop_score = v
                            stop_node = k
                    if stop_node is not None and obs[i]['viewpoint'] != stop_node:
                        traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
                    # NOTE(review): assumes node_stop_scores is non-empty so
                    # stop_score has been replaced by a real entry (which
                    # carries 'og_direction') — the current viewpoint is
                    # always scored above, so this should hold.
                    traj[i]['pred_obj_direction'] = stop_score['og_direction']
                    if self.args.detailed_output:
                        for k, v in gmaps[i].node_stop_scores.items():
                            traj[i]['details'][k] = {
                                'stop_prob': float(v['stop']),
                                'obj_ids': [str(x) for x in v['og_details']['objids']],
                                'obj_logits': v['og_details']['logits'].tolist(),
                            }

            # new observation and update graph
            obs = self.env._get_obs()
            self._update_scanvp_cands(obs)
            for i, ob in enumerate(obs):
                if not ended[i]:
                    gmaps[i].update_graph(ob)
            ended[:] = np.logical_or(ended, np.array([x is None for x in cpu_a_t]))

            # Early exit if all ended
            if ended.all():
                break

        if train_ml is not None:
            # Normalize the accumulated step losses by batch size and scale
            # by the IL weight before adding them to the shared loss.
            ml_loss = ml_loss * train_ml / batch_size
            og_loss = og_loss * train_ml / batch_size
            self.loss += ml_loss
            self.loss += og_loss
            self.logs['IL_loss'].append(ml_loss.item())
            self.logs['OG_loss'].append(og_loss.item())

        return traj