# copyright ziqi-jin
import torch

from .extend_sam import BaseExtendSam, SemanticSam
from .runner import BaseRunner, SemRunner
# from .optimizer import BaseOptimizer
from .scheduler import WarmupMultiStepLR
from .utils import get_opt_pamams
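
# Registries that map the names used in config files to the available
# schedulers, models, optimizers and runners.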
AVAI_SCH = ["single_step", "multi_step", "warmup_multi_step", "cosine", "linear"]
AVAI_MODEL = {'base_sam': BaseExtendSam, 'sem_sam': SemanticSam}
# AVAI_OPT = {'base_opt': BaseOptimizer, 'sgd': torch.optim.SGD, 'adam': torch.optim.Adam}
AVAI_OPT = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW}
AVAI_RUNNER = {'base_runner': BaseRunner, 'sem_runner': SemRunner}
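
# The registries can be extended before calling the factories below, e.g.
# (MyRunner is a hypothetical user-defined class, not part of this module):
#   AVAI_RUNNER['my_runner'] = MyRunner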


def get_model(model_name, **kwargs):
    if model_name not in AVAI_MODEL:
        raise ValueError(
            "Unsupported model name '{}', please implement it first. "
            "Available models: {}".format(model_name, list(AVAI_MODEL.keys()))
        )
    # Instantiate the registered model class and move it to the GPU.
    return AVAI_MODEL[model_name](**kwargs).cuda()
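
# Example (a minimal sketch; the keyword arguments accepted by SemanticSam are
# defined in .extend_sam, and the names below are assumptions):
#   model = get_model('sem_sam', ckpt_path='sam_vit_b.pth', class_num=21)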


def get_optimizer(opt_name, **kwargs):
    if opt_name not in AVAI_OPT:
        raise ValueError(
            "Unsupported optimizer name '{}', please implement it first. "
            "Available optimizers: {}".format(opt_name, list(AVAI_OPT.keys()))
        )
    # Drop None-valued kwargs so the optimizer falls back to its own defaults
    # for those options.
    return AVAI_OPT[opt_name](**{k: v for k, v in kwargs.items() if v is not None})
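
# Example (sketch): weight_decay=None is filtered out above, so AdamW keeps its
# default weight decay.
#   optimizer = get_optimizer('adamw', params=model.parameters(), lr=1e-4,
#                             weight_decay=None)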


def get_runner(runner_name):
    if runner_name not in AVAI_RUNNER:
        raise ValueError(
            "Unsupported runner name '{}', please implement it first. "
            "Available runners: {}".format(runner_name, list(AVAI_RUNNER.keys()))
        )
    # Note: the runner class itself is returned, not an instance.
    return AVAI_RUNNER[runner_name]
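
# Example (sketch; SemRunner's constructor arguments are defined in .runner,
# and the ones shown here are assumptions):
#   runner_cls = get_runner('sem_runner')
#   runner = runner_cls(model, optimizer, scheduler)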


def get_scheduler(
        optimizer,
        lr_scheduler="single_step",
        stepsize=1,
        gamma=0.1,
        warmup_factor=0.01,
        warmup_steps=10,
        max_epoch=1,
        n_epochs_init=50,
        n_epochs_decay=50,
):
    """A function wrapper for building a learning rate scheduler.

    Args:
        optimizer (Optimizer): an Optimizer.
        lr_scheduler (str, optional): learning rate scheduler method. Default is
            "single_step".
        stepsize (int or list, optional): step size to decay learning rate.
            When ``lr_scheduler`` is "single_step", ``stepsize`` should be an integer.
            When ``lr_scheduler`` is "multi_step", ``stepsize`` is a list. Default is 1.
        gamma (float, optional): decay rate. Default is 0.1.
        warmup_factor (float, optional): warmup factor passed to ``WarmupMultiStepLR``
            (only used by "warmup_multi_step"). Default is 0.01.
        warmup_steps (int, optional): number of warmup iterations passed to
            ``WarmupMultiStepLR`` as ``warmup_iters`` (only used by
            "warmup_multi_step"). Default is 10.
        max_epoch (int, optional): maximum epoch (for cosine annealing). Default is 1.
        n_epochs_init (int, optional): number of epochs to keep the initial learning
            rate (only used by "linear"). Default is 50.
        n_epochs_decay (int, optional): number of epochs over which the learning rate
            is decayed linearly (only used by "linear"). Default is 50.

    Examples::
        >>> # Decay learning rate every 20 epochs.
        >>> scheduler = get_scheduler(
        >>>     optimizer, lr_scheduler='single_step', stepsize=20
        >>> )
        >>> # Decay learning rate at 30, 50 and 55 epochs.
        >>> scheduler = get_scheduler(
        >>>     optimizer, lr_scheduler='multi_step', stepsize=[30, 50, 55]
        >>> )
    """
    if lr_scheduler not in AVAI_SCH:
        raise ValueError(
            "Unsupported scheduler: {}. Must be one of {}".format(
                lr_scheduler, AVAI_SCH
            )
        )

    if lr_scheduler == "single_step":
        if isinstance(stepsize, list):
            stepsize = stepsize[-1]

        if not isinstance(stepsize, int):
            raise TypeError(
                "For single_step lr_scheduler, stepsize must "
                "be an integer, but got {}".format(type(stepsize))
            )

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=stepsize, gamma=gamma
        )

    elif lr_scheduler == "multi_step":
        if not isinstance(stepsize, list):
            raise TypeError(
                "For multi_step lr_scheduler, stepsize must "
                "be a list, but got {}".format(type(stepsize))
            )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=stepsize, gamma=gamma
        )

    elif lr_scheduler == "warmup_multi_step":
        if not isinstance(stepsize, list):
            raise TypeError(
                "For warmup_multi_step lr_scheduler, stepsize must "
                "be a list, but got {}".format(type(stepsize))
            )

        scheduler = WarmupMultiStepLR(
            optimizer,
            milestones=stepsize,
            gamma=gamma,
            warmup_factor=warmup_factor,
            warmup_iters=warmup_steps,
        )

    elif lr_scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, int(max_epoch)
        )

    elif lr_scheduler == "linear":
        def lambda_rule(epoch):
            # Keep the base learning rate for the first n_epochs_init epochs,
            # then decay it linearly toward zero over the next n_epochs_decay
            # epochs.
            lr_l = 1.0 - max(0, epoch - n_epochs_init) / float(n_epochs_decay + 1)
            return lr_l

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda_rule
        )

    return scheduler
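
# End-to-end sketch of how these factories are typically combined. The runner
# constructor arguments are assumptions; the real signatures live in .runner.
#   model = get_model('sem_sam')
#   optimizer = get_optimizer('adamw', params=model.parameters(), lr=1e-4)
#   scheduler = get_scheduler(optimizer, lr_scheduler='cosine', max_epoch=50)
#   runner = get_runner('sem_runner')(model, optimizer, scheduler)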