From 09e9854707b49b4b9250f2e4099acad7acfb34c6 Mon Sep 17 00:00:00 2001 From: jixingyue <18019376251@163.com> Date: Mon, 29 Sep 2025 13:10:43 +0800 Subject: [PATCH] train.py --- train.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 train.py diff --git a/train.py b/train.py new file mode 100644 index 0000000..d281130 --- /dev/null +++ b/train.py @@ -0,0 +1,63 @@ +''' +@copyright ziqi-jin +''' +import argparse +from omegaconf import OmegaConf +from torch.utils.data import DataLoader +from datasets import get_dataset,get_lettuce_dataset +from losses import get_losses +from extend_sam import get_model, get_optimizer, get_scheduler, get_opt_pamams, get_runner +import random +import numpy as np +import torch +def set_seed(seed: int): + """ + 固定训练过程中的随机种子,保证结果相对可复现。 + """ + random.seed(seed) # Python 内置的 random + np.random.seed(seed) # NumPy + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch当前 GPU + torch.cuda.manual_seed_all(seed) # PyTorch所有 GPU + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False +set_seed(42) + +supported_tasks = ['detection', 'semantic_seg', 'instance_seg'] +parser = argparse.ArgumentParser() +parser.add_argument('--task_name', default='semantic_seg', type=str) +parser.add_argument('--cfg', default=None, type=str) + +if __name__ == '__main__': + args = parser.parse_args() + task_name = args.task_name + if args.cfg is not None: + config = OmegaConf.load(args.cfg) + else: + assert task_name in supported_tasks, "Please input the supported task name." + config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name)) + + train_cfg = config.train + val_cfg = config.val + test_cfg = config.test + + train_dataset,val_dataset = get_lettuce_dataset() + train_loader = DataLoader(train_dataset, batch_size=train_cfg.bs, shuffle=True, num_workers=train_cfg.num_workers, + drop_last=train_cfg.drop_last) + val_loader = DataLoader(val_dataset, batch_size=val_cfg.bs, shuffle=False, num_workers=val_cfg.num_workers, + drop_last=val_cfg.drop_last) + losses = get_losses(losses=train_cfg.losses) + # according the model name to get the adapted model + model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params) + opt_params = get_opt_pamams(model, lr_list=train_cfg.opt_params.lr_list, group_keys=train_cfg.opt_params.group_keys, + wd_list=train_cfg.opt_params.wd_list) + optimizer = get_optimizer(opt_name=train_cfg.opt_name, params=opt_params, lr=train_cfg.opt_params.lr_default, + momentum=train_cfg.opt_params.momentum, weight_decay=train_cfg.opt_params.wd_default) + scheduler = get_scheduler(optimizer=optimizer, lr_scheduler=train_cfg.scheduler_name) + runner = get_runner(train_cfg.runner_name)(model, optimizer, losses, train_loader, val_loader, scheduler) + # train_step + runner.train(train_cfg) + if test_cfg.need_test: + runner.test(test_cfg) + \ No newline at end of file