generated from yuanbiao/python_templates
train.py
This commit is contained in:
parent e4377404c5
commit 09e9854707
@@ -0,0 +1,63 @@
'''
@copyright ziqi-jin
'''
import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from datasets import get_dataset, get_lettuce_dataset
from losses import get_losses
from extend_sam import get_model, get_optimizer, get_scheduler, get_opt_pamams, get_runner
import random
import numpy as np
import torch
def set_seed(seed: int):
    """
    Fix the random seeds used during training so that results are relatively reproducible.
    """
    random.seed(seed)                 # Python's built-in random
    np.random.seed(seed)              # NumPy
    torch.manual_seed(seed)           # PyTorch, CPU
    torch.cuda.manual_seed(seed)      # PyTorch, current GPU
    torch.cuda.manual_seed_all(seed)  # PyTorch, all GPUs

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

supported_tasks = ['detection', 'semantic_seg', 'instance_seg']
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', default='semantic_seg', type=str)
parser.add_argument('--cfg', default=None, type=str)
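
# The YAML config is expected to expose at least the fields read below:
#   train: bs, num_workers, drop_last, losses, opt_name, scheduler_name, runner_name,
#          model.sam_name, model.params,
#          opt_params.{lr_list, group_keys, wd_list, lr_default, momentum, wd_default}
#   val:   bs, num_workers, drop_last
#   test:  need_test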
if __name__ == '__main__':
    args = parser.parse_args()
    task_name = args.task_name
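    # Load the config from --cfg when it is given; otherwise fall back to the default per-task config file.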
    if args.cfg is not None:
        config = OmegaConf.load(args.cfg)
    else:
        assert task_name in supported_tasks, "Please input a supported task name."
        config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name))

    train_cfg = config.train
    val_cfg = config.val
    test_cfg = config.test

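    # Build the train/val datasets (get_lettuce_dataset is project-specific) and wrap them in DataLoaders.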
    train_dataset, val_dataset = get_lettuce_dataset()
    train_loader = DataLoader(train_dataset, batch_size=train_cfg.bs, shuffle=True, num_workers=train_cfg.num_workers,
                              drop_last=train_cfg.drop_last)
    val_loader = DataLoader(val_dataset, batch_size=val_cfg.bs, shuffle=False, num_workers=val_cfg.num_workers,
                            drop_last=val_cfg.drop_last)
    losses = get_losses(losses=train_cfg.losses)
    # get the adapted model according to the model name
    model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)
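    # Assemble per-group optimizer settings (learning rate / weight decay per parameter group),
    # then build the optimizer and learning-rate scheduler from the config.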
    opt_params = get_opt_pamams(model, lr_list=train_cfg.opt_params.lr_list, group_keys=train_cfg.opt_params.group_keys,
                                wd_list=train_cfg.opt_params.wd_list)
    optimizer = get_optimizer(opt_name=train_cfg.opt_name, params=opt_params, lr=train_cfg.opt_params.lr_default,
                              momentum=train_cfg.opt_params.momentum, weight_decay=train_cfg.opt_params.wd_default)
    scheduler = get_scheduler(optimizer=optimizer, lr_scheduler=train_cfg.scheduler_name)
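    # The runner ties together model, optimizer, losses, data loaders and scheduler, and drives training and testing.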
    runner = get_runner(train_cfg.runner_name)(model, optimizer, losses, train_loader, val_loader, scheduler)
    # train_step
    runner.train(train_cfg)
    if test_cfg.need_test:
        runner.test(test_cfg)