adversarial_VLNDUET/pretrain_src/parser.py
Shizhe Chen 747cf0587b init
2021-11-24 13:29:08 +01:00

145 lines
3.9 KiB
Python

import argparse
import sys
import json
def load_parser():
    """Build the command-line argument parser.

    The parser declares model/checkpoint paths, training hyper-parameters,
    device options and distributed-training options, plus a required
    ``--config`` option pointing at a JSON config file.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    # NOTE: train tasks and val tasks cannot take command line arguments
    parser.add_argument('--vlnbert', choices=['cmt'])
    parser.add_argument(
        "--model_config", type=str, help="path to model structure config json"
    )
    parser.add_argument(
        "--checkpoint", default=None, type=str, help="path to model checkpoint (*.pt)"
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="The output directory where the model checkpoints will be written.",
    )

    # training parameters
    parser.add_argument(
        "--train_batch_size",
        default=4096,
        type=int,
        help="Total batch size for training. ",
    )
    parser.add_argument(
        "--val_batch_size",
        default=4096,
        type=int,
        help="Total batch size for validation. ",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=16,
        help="Number of updates steps to accumualte before "
        "performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        default=3e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--valid_steps", default=1000, type=int, help="Run validation every X steps"
    )
    parser.add_argument("--log_steps", default=1000, type=int)
    parser.add_argument(
        "--num_train_steps",
        default=100000,
        type=int,
        help="Total number of training updates to perform.",
    )
    parser.add_argument(
        "--optim",
        default="adamw",
        choices=["adam", "adamax", "adamw"],
        help="optimizer",
    )
    # type=float so that values given on the command line are parsed to
    # floats, matching the element type of the default list (without it
    # argparse would hand back strings such as ['0.9', '0.98']).
    parser.add_argument(
        "--betas",
        default=[0.9, 0.98],
        nargs="+",
        type=float,
        help="beta for adam optimizer",
    )
    parser.add_argument(
        "--dropout", default=0.1, type=float, help="tune dropout regularization"
    )
    parser.add_argument(
        "--weight_decay",
        default=0.01,
        type=float,
        help="weight decay (L2) regularization",
    )
    parser.add_argument(
        "--grad_norm",
        default=2.0,
        type=float,
        help="gradient clipping (-1 for no clipping)",
    )
    parser.add_argument(
        "--warmup_steps",
        default=10000,
        type=int,
        help="Number of training steps to perform linear " "learning rate warmup for.",
    )

    # device parameters
    parser.add_argument(
        "--seed", type=int, default=0, help="random seed for initialization"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--n_workers", type=int, default=4, help="number of data workers"
    )
    parser.add_argument("--pin_mem", action="store_true", help="pin memory")

    # distributed computing
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local rank for distributed training on gpus",
    )
    parser.add_argument(
        "--node_rank",
        type=int,
        default=0,
        help="Id of the node",
    )
    parser.add_argument(
        "--world_size",
        type=int,
        default=1,
        help="Number of GPUs across all nodes",
    )

    # can use config files
    parser.add_argument("--config", required=True, help="JSON config files")

    return parser
def parse_with_config(parser):
    """Parse ``sys.argv`` and overlay values from the JSON config file.

    After normal argparse parsing, the file named by ``--config`` is loaded
    as JSON and each of its top-level keys is set as an attribute on the
    resulting namespace, unless an option with that name was written
    explicitly on the command line (``--key value`` or ``--key=value``),
    in which case the command-line value is kept. Config keys that the
    parser does not declare are set on the namespace as well. The
    ``config`` attribute itself is removed before returning.

    Args:
        parser: an ``argparse.ArgumentParser`` that defines a ``--config``
            option.

    Returns:
        argparse.Namespace: the merged arguments, without ``config``.

    Raises:
        OSError: if the config file cannot be opened.
        json.JSONDecodeError: if the config file is not valid JSON.
    """
    args = parser.parse_args()
    if args.config is not None:
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(args.config) as config_file:
            config_args = json.load(config_file)
        # Option names are taken literally from the command-line text, so an
        # abbreviated option (argparse accepts unique prefixes by default) is
        # not recognized as an override of its full name.
        override_keys = {
            arg[2:].split("=")[0] for arg in sys.argv[1:] if arg.startswith("--")
        }
        for k, v in config_args.items():
            if k not in override_keys:
                setattr(args, k, v)
    del args.config
    return args