laser_weeding/train.py

63 lines
2.7 KiB
Python

'''
@copyright ziqi-jin
'''
import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from datasets import get_dataset,get_lettuce_dataset
from losses import get_losses
from extend_sam import get_model, get_optimizer, get_scheduler, get_opt_pamams, get_runner
import random
import numpy as np
import torch
def set_seed(seed: int):
    """
    Fix the random seeds used during training so that results are
    relatively reproducible.

    The same ``seed`` is handed to Python's built-in ``random``, to NumPy and
    to PyTorch (CPU, current GPU, all GPUs). cuDNN is then switched to
    deterministic mode and its benchmark mode is turned off.
    """
    # Seeding order: Python built-in random, NumPy, PyTorch CPU,
    # PyTorch current GPU, PyTorch all GPUs.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed all random number generators once, as soon as the module is loaded.
set_seed(42)

# Task names accepted when no explicit --cfg path is supplied.
supported_tasks = [
    'detection',
    'semantic_seg',
    'instance_seg',
]

# Command-line interface: --task_name names the task, --cfg is an optional
# path to a config file.
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', type=str, default='semantic_seg')
parser.add_argument('--cfg', type=str, default=None)
if __name__ == '__main__':
    # Script entry point: load the config, build the data loaders, losses,
    # model, optimizer, scheduler and runner, then train and optionally test.
    args = parser.parse_args()
    task_name = args.task_name

    if args.cfg is not None:
        # An explicit config path is used as-is; --task_name is not consulted.
        config = OmegaConf.load(args.cfg)
    else:
        # Validate explicitly instead of using `assert`: assertions are
        # removed when Python runs with -O, which would let an unsupported
        # task name fall through to the config load below.
        if task_name not in supported_tasks:
            parser.error("Please input the supported task name.")
        config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=task_name))

    train_cfg = config.train
    val_cfg = config.val
    test_cfg = config.test

    # The lettuce dataset helper takes no arguments and returns the
    # (train, val) pair.
    train_dataset, val_dataset = get_lettuce_dataset()
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg.bs,
        shuffle=True,
        num_workers=train_cfg.num_workers,
        drop_last=train_cfg.drop_last,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=val_cfg.bs,
        shuffle=False,
        num_workers=val_cfg.num_workers,
        drop_last=val_cfg.drop_last,
    )

    losses = get_losses(losses=train_cfg.losses)

    # Get the adapted model according to the configured model name.
    model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)

    # Per-group optimizer parameters built from the lr / weight-decay lists
    # and group keys in the config.
    opt_params = get_opt_pamams(
        model,
        lr_list=train_cfg.opt_params.lr_list,
        group_keys=train_cfg.opt_params.group_keys,
        wd_list=train_cfg.opt_params.wd_list,
    )
    optimizer = get_optimizer(
        opt_name=train_cfg.opt_name,
        params=opt_params,
        lr=train_cfg.opt_params.lr_default,
        momentum=train_cfg.opt_params.momentum,
        weight_decay=train_cfg.opt_params.wd_default,
    )
    scheduler = get_scheduler(optimizer=optimizer, lr_scheduler=train_cfg.scheduler_name)

    # get_runner returns a runner class/factory, which is instantiated here.
    runner = get_runner(train_cfg.runner_name)(
        model, optimizer, losses, train_loader, val_loader, scheduler
    )

    # Training step.
    runner.train(train_cfg)

    # Testing runs only when the test config asks for it.
    if test_cfg.need_test:
        runner.test(test_cfg)