109 lines
3.3 KiB
Python
109 lines
3.3 KiB
Python
import os
|
|
import json
|
|
import time
|
|
|
|
from data_utils import construct_reverie_instrs
|
|
from utils.logger import write_to_record_file
|
|
|
|
from utils.data import ImageObservationsDB
|
|
from parser import parse_args
|
|
from env import REVERIENavBatch
|
|
from agent import NavGPTAgent
|
|
|
|
def build_dataset(args, data_limit=100):
|
|
|
|
feat_db = ImageObservationsDB(args.obs_dir, args.obs_summary_dir, args.obj_dir)
|
|
print(feat_db)
|
|
|
|
dataset_class = REVERIENavBatch
|
|
|
|
val_env_names = [args.val_env_name]
|
|
|
|
val_envs = {}
|
|
for split in val_env_names:
|
|
val_instr_data = construct_reverie_instrs(
|
|
args.anno_dir, args.dataset, [split]
|
|
)
|
|
val_env = dataset_class(
|
|
feat_db, val_instr_data, args.connectivity_dir, args.navigable_dir,
|
|
batch_size=args.batch_size, seed=args.seed, name=split, data_limit=data_limit
|
|
) # evaluation using all objects
|
|
val_envs[split] = val_env
|
|
|
|
return val_envs
|
|
|
|
|
|
def valid(args, val_envs):
|
|
|
|
agent = NavGPTAgent(next(iter(val_envs.values())), args)
|
|
|
|
with open(os.path.join(args.log_dir, 'validation_args.json'), 'w') as outf:
|
|
json.dump(vars(args), outf, indent=4)
|
|
record_file = os.path.join(args.log_dir, 'valid.txt')
|
|
write_to_record_file(str(args) + '\n\n', record_file)
|
|
|
|
for env_name, env in val_envs.items():
|
|
prefix = 'submit'
|
|
if os.path.exists(os.path.join(args.pred_dir, "%s_%s.json" % (prefix, env_name))):
|
|
continue
|
|
agent.env = env
|
|
|
|
start_time = time.time()
|
|
agent.test(iters=args.iters)
|
|
print(env_name, 'cost time: %.2fs' % (time.time() - start_time))
|
|
# Get the results
|
|
preds = agent.get_results(detailed_output=False)
|
|
# Record llm output details
|
|
if args.detailed_output:
|
|
preds_detail = agent.get_results(detailed_output=True)
|
|
|
|
json.dump(
|
|
preds_detail,
|
|
open(os.path.join(args.log_dir, "detail_%s.json" % (env_name)), 'w'),
|
|
sort_keys=True, indent=4, separators=(',', ': ')
|
|
)
|
|
|
|
if 'test' not in env_name:
|
|
score_summary, _ = env.eval_metrics(preds)
|
|
loss_str = "Env name: %s" % env_name
|
|
for metric, val in score_summary.items():
|
|
loss_str += ', %s: %.2f' % (metric, val)
|
|
write_to_record_file(loss_str+'\n', record_file)
|
|
|
|
json.dump(
|
|
preds,
|
|
open(os.path.join(args.pred_dir, "%s_%s.json" % (prefix, env_name)), 'w'),
|
|
sort_keys=True, indent=4, separators=(',', ': ')
|
|
)
|
|
|
|
|
|
def valid_from_file(args, val_envs):
|
|
|
|
agent = NavGPTAgent(next(iter(val_envs.values())), args)
|
|
with open(args.valid_file, 'r') as f:
|
|
preds = json.load(f)
|
|
|
|
for env_name, env in val_envs.items():
|
|
agent.env = env
|
|
valid_list = [preds]
|
|
for valid_pred in valid_list:
|
|
score_summary, _ = env.eval_metrics(valid_pred)
|
|
loss_str = "Env name: %s, length %d" % (env_name, len(valid_pred))
|
|
for metric, val in score_summary.items():
|
|
loss_str += ', %s: %.2f' % (metric, val)
|
|
print(loss_str)
|
|
|
|
def main():
|
|
args = parse_args()
|
|
|
|
val_envs = build_dataset(args, data_limit=100)
|
|
|
|
if args.valid_file is not None:
|
|
valid_from_file(args, val_envs)
|
|
else:
|
|
valid(args, val_envs)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|