import json
import os
import sys
import numpy as np
import random
import math
import time
from collections import defaultdict

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from utils.distributed import is_default_gpu
from utils.ops import pad_tensors, gen_seq_masks
from torch.nn.utils.rnn import pad_sequence

from reverie.agent_obj import GMapObjectNavAgent
from models.graph_utils import GraphMap
from models.model import VLNBert, Critic


class SoonGMapObjectNavAgent(GMapObjectNavAgent):
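    """Object-goal navigation agent for the SOON task.

    A thin specialization of GMapObjectNavAgent: the rollout additionally tracks
    the predicted direction (heading, elevation) of the grounded target object,
    and get_results() packs that direction into the SOON result format.
    """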

    def get_results(self):
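        """Convert cached per-instruction results into SOON-format predictions:
        the navigation path plus a single predicted object heading/elevation."""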
        output = [{'instr_id': k,
                   'trajectory': {
                       'path': v['path'],
                       'obj_heading': [v['pred_obj_direction'][0]],
                       'obj_elevation': [v['pred_obj_direction'][1]],
                   }} for k, v in self.results.items()]
        return output

    def rollout(self, train_ml=None, train_rl=False, reset=True):
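        """Roll out one batch of episodes and return the predicted trajectories.

        Args:
            train_ml: weight on the imitation (teacher-forced) losses; ``None``
                disables supervised loss computation.
            train_rl: unused in this implementation.
            reset: whether to reset the environment before rolling out.
        """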
        if reset:  # Reset env
            obs = self.env.reset()
        else:
            obs = self.env._get_obs()
        self._update_scanvp_cands(obs)

        batch_size = len(obs)
        # build graph: keep the start viewpoint
        gmaps = [GraphMap(ob['viewpoint']) for ob in obs]
        for i, ob in enumerate(obs):
            gmaps[i].update_graph(ob)

        # Record the navigation path
        traj = [{
            'instr_id': ob['instr_id'],
            'path': [[ob['viewpoint']]],
            'pred_obj_direction': None,
            'details': {},
        } for ob in obs]

        # Language input: txt_ids, txt_masks
        language_inputs = self._language_variable(obs)
        txt_embeds = self.vln_bert('language', language_inputs)

        # Initialize the tracking state
        ended = np.array([False] * batch_size)
        just_ended = np.array([False] * batch_size)

        # Init the logs
        masks = []
        entropys = []
        ml_loss = 0.
        og_loss = 0.

        for t in range(self.args.max_action_len):
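            # One decision step: encode the panorama, refresh node embeddings in the
            # topological map, score candidate viewpoints, then pick and execute an action.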
            for i, gmap in enumerate(gmaps):
                if not ended[i]:
                    gmap.node_step_ids[obs[i]['viewpoint']] = t + 1

            # graph representation
            pano_inputs = self._panorama_feature_variable(obs)
            pano_embeds, pano_masks = self.vln_bert('panorama', pano_inputs)
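            # masked average over the panorama views -> a single embedding for the
            # current viewpoint node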
            avg_pano_embeds = torch.sum(pano_embeds * pano_masks.unsqueeze(2), 1) / \
                              torch.sum(pano_masks, 1, keepdim=True)

            for i, gmap in enumerate(gmaps):
                if not ended[i]:
                    # update visited node
                    i_vp = obs[i]['viewpoint']
                    gmap.update_node_embed(i_vp, avg_pano_embeds[i], rewrite=True)
                    # update unvisited nodes
                    for j, i_cand_vp in enumerate(pano_inputs['cand_vpids'][i]):
                        if not gmap.graph.visited(i_cand_vp):
                            gmap.update_node_embed(i_cand_vp, pano_embeds[i, j])

            # navigation policy
            nav_inputs = self._nav_gmap_variable(obs, gmaps)
            nav_inputs.update(
                self._nav_vp_variable(
                    obs, gmaps, pano_embeds, pano_inputs['cand_vpids'],
                    pano_inputs['view_lens'], pano_inputs['obj_lens'],
                    pano_inputs['nav_types'],
                )
            )
            nav_inputs.update({
                'txt_embeds': txt_embeds,
                'txt_masks': language_inputs['txt_masks'],
            })
            nav_outs = self.vln_bert('navigation', nav_inputs)
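
            # choose which head drives the action distribution: local (candidates in the
            # current panorama), global (all map nodes), or the fused scores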
            if self.args.fusion == 'local':
                nav_logits = nav_outs['local_logits']
                nav_vpids = nav_inputs['vp_cand_vpids']
            elif self.args.fusion == 'global':
                nav_logits = nav_outs['global_logits']
                nav_vpids = nav_inputs['gmap_vpids']
            else:
                nav_logits = nav_outs['fused_logits']
                nav_vpids = nav_inputs['gmap_vpids']

            nav_probs = torch.softmax(nav_logits, 1)
            obj_logits = nav_outs['obj_logits']

            # update graph
            for i, gmap in enumerate(gmaps):
                if not ended[i]:
                    i_vp = obs[i]['viewpoint']
                    # update i_vp: stop and object grounding scores
                    i_objids = obs[i]['obj_ids']
                    i_obj_logits = obj_logits[i, pano_inputs['view_lens'][i]+1:]
                    gmap.node_stop_scores[i_vp] = {
                        'stop': nav_probs[i, 0].data.item(),
                        'og': i_objids[torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None,
                        'og_direction': obs[i]['obj_directions'][torch.argmax(i_obj_logits)] if len(i_objids) > 0 else None,
                        'og_details': {'objids': i_objids, 'logits': i_obj_logits[:len(i_objids)]},
                    }

            if train_ml is not None:
                # Supervised training
                nav_targets = self._teacher_action(
                    obs, nav_vpids, ended,
                    visited_masks=nav_inputs['gmap_visited_masks'] if self.args.fusion != 'local' else None
                )
                # print(t, nav_logits, nav_targets)
                ml_loss += self.criterion(nav_logits, nav_targets)
                # print(t, 'ml_loss', ml_loss.item(), self.criterion(nav_logits, nav_targets).item())
                if self.args.fusion in ['avg', 'dynamic'] and self.args.loss_nav_3:
                    # add global and local losses
                    ml_loss += self.criterion(nav_outs['global_logits'], nav_targets)   # global
                    local_nav_targets = self._teacher_action(
                        obs, nav_inputs['vp_cand_vpids'], ended, visited_masks=None
                    )
                    ml_loss += self.criterion(nav_outs['local_logits'], local_nav_targets)  # local
                # object grounding
                obj_targets = self._teacher_object(obs, ended, pano_inputs['view_lens'])
                # print(t, obj_targets[6], obj_logits[6], obs[6]['obj_ids'], pano_inputs['view_lens'][i], obs[6]['gt_obj_id'])
                og_loss += self.criterion(obj_logits, obj_targets)
                # print(F.cross_entropy(obj_logits, obj_targets, reduction='none'))
                # print(t, 'og_loss', og_loss.item(), self.criterion(obj_logits, obj_targets).item())

            # Determine the next navigation viewpoint
            if self.feedback == 'teacher':
                a_t = nav_targets                  # teacher forcing
            elif self.feedback == 'argmax':
                _, a_t = nav_logits.max(1)         # student forcing - argmax
                a_t = a_t.detach()
            elif self.feedback == 'sample':
                c = torch.distributions.Categorical(nav_probs)
                self.logs['entropy'].append(c.entropy().sum().item())  # For log
                entropys.append(c.entropy())                           # For optimization
                a_t = c.sample().detach()
            elif self.feedback == 'expl_sample':
                _, a_t = nav_probs.max(1)
                rand_explores = np.random.rand(batch_size, ) > self.args.expl_max_ratio  # hyper-param
                if self.args.fusion == 'local':
                    cpu_nav_masks = nav_inputs['vp_nav_masks'].data.cpu().numpy()
                else:
                    cpu_nav_masks = (nav_inputs['gmap_masks'] * nav_inputs['gmap_visited_masks'].logical_not()).data.cpu().numpy()
                for i in range(batch_size):
                    if rand_explores[i]:
                        cand_a_t = np.arange(len(cpu_nav_masks[i]))[cpu_nav_masks[i]]
                        a_t[i] = np.random.choice(cand_a_t)
            else:
                print(self.feedback)
                sys.exit('Invalid feedback option')

            # Determine stop actions
            if self.feedback == 'teacher' or self.feedback == 'sample':  # in training
                # a_t_stop = [ob['viewpoint'] in ob['gt_end_vps'] for ob in obs]
                a_t_stop = [ob['viewpoint'] == ob['gt_path'][-1] for ob in obs]
            else:
                a_t_stop = a_t == 0

            # Prepare environment action
            cpu_a_t = []
            for i in range(batch_size):
                if a_t_stop[i] or ended[i] or nav_inputs['no_vp_left'][i] or (t == self.args.max_action_len - 1):
                    cpu_a_t.append(None)
                    just_ended[i] = True
                else:
                    cpu_a_t.append(nav_vpids[i][a_t[i]])

            # Make action and get the new state
            self.make_equiv_action(cpu_a_t, gmaps, obs, traj)
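            # For trajectories that just ended, back off to the map node with the highest
            # stop score, append the graph path to it, and record the direction of the
            # object grounded at that node.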
            for i in range(batch_size):
                if (not ended[i]) and just_ended[i]:
                    stop_node, stop_score = None, {'stop': -float('inf'), 'og': None}
                    for k, v in gmaps[i].node_stop_scores.items():
                        if v['stop'] > stop_score['stop']:
                            stop_score = v
                            stop_node = k
                    if stop_node is not None and obs[i]['viewpoint'] != stop_node:
                        traj[i]['path'].append(gmaps[i].graph.path(obs[i]['viewpoint'], stop_node))
                    traj[i]['pred_obj_direction'] = stop_score['og_direction']
                    if self.args.detailed_output:
                        for k, v in gmaps[i].node_stop_scores.items():
                            traj[i]['details'][k] = {
                                'stop_prob': float(v['stop']),
                                'obj_ids': [str(x) for x in v['og_details']['objids']],
                                'obj_logits': v['og_details']['logits'].tolist(),
                            }

            # new observation and update graph
            obs = self.env._get_obs()
            self._update_scanvp_cands(obs)
            for i, ob in enumerate(obs):
                if not ended[i]:
                    gmaps[i].update_graph(ob)

            ended[:] = np.logical_or(ended, np.array([x is None for x in cpu_a_t]))

            # Early exit if all ended
            if ended.all():
                break

        if train_ml is not None:
            ml_loss = ml_loss * train_ml / batch_size
            og_loss = og_loss * train_ml / batch_size
            self.loss += ml_loss
            self.loss += og_loss
            self.logs['IL_loss'].append(ml_loss.item())
            self.logs['OG_loss'].append(og_loss.item())

        return traj
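

# Usage sketch (hypothetical; the constructor arguments below are assumptions --
# the agent is normally built and driven by the training/eval scripts in this repo):
#
#   agent = SoonGMapObjectNavAgent(args, env, rank=0)
#   agent.feedback = 'argmax'            # greedy decoding for evaluation
#   agent.rollout(train_ml=None)         # no supervised loss at eval time
#   preds = agent.get_results()          # SOON-format trajectories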