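"""Training and evaluation runners for a SAM-based semantic segmentation model.

BaseRunner wires up the model, optimizer, losses, data loaders and (multi-)GPU setup;
SemRunner adds the training loop, periodic mIoU evaluation, checkpointing and
visualization of the predicted masks (background / lettuce / weed).
"""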
import os

import cv2
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.colors import ListedColormap, BoundaryNorm

from datasets import Iterator
from .utils import Average_Meter, Timer, print_and_save_log, mIoUOnline, get_numpy_from_tensor, save_model, write_log, \
    check_folder, one_hot_embedding_3d


class BaseRunner:
    def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler):
        self.optimizer = optimizer
        self.losses = losses
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.scheduler = scheduler
        self.train_timer = Timer()
        self.eval_timer = Timer()
        # Count the visible GPUs; default to a single GPU when CUDA_VISIBLE_DEVICES is unset.
        try:
            use_gpu = os.environ['CUDA_VISIBLE_DEVICES']
        except KeyError:
            use_gpu = '0'
        self.the_number_of_gpu = len(use_gpu.split(','))
        # Target size for upsampling predictions, taken from the SAM image encoder
        # before the model is (possibly) wrapped in DataParallel.
        self.original_size = self.model.img_adapter.sam_img_encoder.img_size
        if self.the_number_of_gpu > 1:
            self.model = nn.DataParallel(self.model)


class SemRunner(BaseRunner):
    # def __init__(self, **kwargs):
    #     super().__init__(**kwargs)

    def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler):
        super().__init__(model, optimizer, losses, train_loader, val_loader, scheduler)
        self.exist_status = ['train', 'eval', 'test']

    def train(self, cfg):
        # initial setup
        train_meter = Average_Meter(list(self.losses.keys()) + ['total_loss'])
        train_iterator = Iterator(self.train_loader)
        best_valid_mIoU = -1
        model_path = "{cfg.model_folder}/{cfg.experiment_name}/model.pth".format(cfg=cfg)
        log_path = "{cfg.log_folder}/{cfg.experiment_name}/log_file.txt".format(cfg=cfg)
        check_folder(model_path)
        check_folder(log_path)
        writer = None
        if cfg.use_tensorboard:
            tensorboard_dir = "{cfg.tensorboard_folder}/{cfg.experiment_name}/tensorboard/".format(cfg=cfg)
            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter(tensorboard_dir)

        # train
        train_losses = []
        for iteration in range(cfg.max_iter):
            images, labels, _ = train_iterator.get()
            images, labels = images.cuda(), labels.cuda().long()
            labels = labels.squeeze(1)
            masks_pred, iou_pred = self.model(images)
            masks_pred = F.interpolate(masks_pred, self.original_size, mode="bilinear", align_corners=False)

            total_loss = torch.zeros(1).cuda()
            loss_dict = {}
            # _compute_loss accumulates the weighted losses into total_loss in place.
            self._compute_loss(total_loss, loss_dict, masks_pred, labels, cfg)
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            loss_dict['total_loss'] = total_loss.item()
            train_losses.append(total_loss.item())
            train_meter.add(loss_dict)

            # log
            if (iteration + 1) % cfg.log_iter == 0:
                write_log(iteration=iteration, log_path=log_path, log_data=train_meter.get(clear=True),
                          status=self.exist_status[0],
                          writer=writer, timer=self.train_timer)
            # eval
            if (iteration + 1) % cfg.eval_iter == 0:
                mIoU, _ = self._eval()
                if best_valid_mIoU == -1 or best_valid_mIoU < mIoU:
                    best_valid_mIoU = mIoU
                    save_model(self.model, model_path, parallel=self.the_number_of_gpu > 1)
                    print_and_save_log("saved model in {model_path}".format(model_path=model_path), path=log_path)
                log_data = {'mIoU': mIoU, 'best_valid_mIoU': best_valid_mIoU}
                write_log(iteration=iteration, log_path=log_path, log_data=log_data, status=self.exist_status[1],
                          writer=writer, timer=self.eval_timer)

                plt.figure(figsize=(8, 5))
                plt.plot(train_losses, label="Train Loss")
                plt.xlabel("Iteration")
                plt.ylabel("Loss")
                plt.title(f"Loss Curve up to Iter {iteration}")
                plt.legend()
                # save the loss curve
                save_path = os.path.join('./', f"loss_iter_{iteration}.png")
                plt.savefig(save_path, dpi=150)
                plt.close()  # close the current figure to free memory
        # final process
        save_model(self.model, model_path, is_final=True, parallel=self.the_number_of_gpu > 1)
        if writer is not None:
            writer.close()

    def test(self):
        pass

    def _eval(self):
        self.model.eval()
        self.eval_timer.start()
        class_names = self.val_loader.dataset.class_names
        eval_metric = mIoUOnline(class_names=class_names)
        with torch.no_grad():
            for index, (images, labels, img_paths) in enumerate(self.val_loader):
                images = images.cuda()
                labels = labels.cuda()
                masks_pred, iou_pred = self.model(images)
                predictions = torch.argmax(masks_pred, dim=1)
                for batch_index in range(images.size()[0]):
                    pred_mask = get_numpy_from_tensor(predictions[batch_index])
                    gt_mask = get_numpy_from_tensor(labels[batch_index].squeeze(0))
                    h, w = pred_mask.shape
                    gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)

                    # cv2 loads BGR; convert to RGB so matplotlib shows the true colors.
                    src_img = cv2.cvtColor(cv2.imread(img_paths[batch_index]), cv2.COLOR_BGR2RGB)
                    # self.visualize_segmentation(src_img, pred_mask, gt_mask, os.path.basename(img_paths[batch_index]))
                    self.construct_visualize_segmentation(src_img, pred_mask, gt_mask,
                                                          os.path.basename(img_paths[batch_index]))
                    eval_metric.add(pred_mask, gt_mask)
        self.model.train()
        return eval_metric.get(clear=True)

    def _compute_loss(self, total_loss, loss_dict, mask_pred, labels, cfg):
        """
        The losses take different inputs, so if you want to add a new loss
        you may need to modify the logic in this function.
        """
        loss_cfg = cfg.losses
        for loss_name, loss_fn in self.losses.items():
            real_labels = labels
            if loss_cfg[loss_name].label_one_hot:
                class_num = cfg.model.params.class_num
                real_labels = one_hot_embedding_3d(real_labels, class_num=class_num)
            tmp_loss = loss_fn(mask_pred, real_labels)
            loss_dict[loss_name] = tmp_loss.item()
            # In-place add so the caller's total_loss tensor is updated.
            total_loss += loss_cfg[loss_name].weight * tmp_loss

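    # A minimal sketch of the loss configuration `_compute_loss` assumes (illustrative
    # names and values only; the real entries live in the project's config files):
    #
    #   losses:
    #     ce_loss:                 # key must match an entry in self.losses
    #       weight: 1.0            # multiplier applied to this loss term
    #       label_one_hot: false   # if true, labels are one-hot encoded first
    #     dice_loss:
    #       weight: 0.5
    #       label_one_hot: true
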
    def visualize_segmentation(self, image, pred_mask, gt_mask, save_path=None):
        # If the image is (C, H, W), convert it to (H, W, C) first.
        if image.ndim == 3 and image.shape[0] == 3 and image.shape[1] != 3:
            image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)

        # Discrete colormap for the segmentation masks: 0 = black, 1 = green.
        cmap = ListedColormap(["black", "green"])
        # Two classes (0 and 1), so the boundaries are [0, 1, 2].
        norm = BoundaryNorm([0, 1, 2], cmap.N)

        # Build the figure.
        fig, axes = plt.subplots(1, 4, figsize=(16, 4))

        axes[0].imshow(image.astype(np.uint8))
        axes[0].set_title("Original Image")
        axes[0].axis("off")

        axes[1].imshow(pred_mask, cmap=cmap, norm=norm)
        axes[1].set_title("Predicted Mask")
        axes[1].axis("off")

        axes[2].imshow(gt_mask, cmap=cmap, norm=norm)
        axes[2].set_title("Ground Truth Mask")
        axes[2].axis("off")

        legend_patches = [
            mpatches.Patch(color="black", label="Background (0)"),
            mpatches.Patch(color="green", label="lettuce (1)"),
        ]
        axes[3].legend(handles=legend_patches, loc='center', fontsize=10)
        axes[3].set_title("Classes Legend")
        axes[3].axis("off")

        # Tidy up the layout and save.
        plt.tight_layout()
        os.makedirs('./outputs', exist_ok=True)
        plt.savefig(os.path.join('./outputs', save_path), dpi=200, bbox_inches='tight')
        plt.close()

    def construct_visualize_segmentation(self, image, pred_mask, gt_mask, save_path=None):
        # If the image is (C, H, W), convert it to (H, W, C) first.
        if image.ndim == 3 and image.shape[0] == 3 and image.shape[1] != 3:
            image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)

        # Discrete colormap for the segmentation masks: 0 = black, 1 = green, 2 = red.
        cmap = ListedColormap(["black", "green", "red"])
        # Three classes (0, 1, 2), so the boundaries are [0, 1, 2, 3].
        norm = BoundaryNorm([0, 1, 2, 3], cmap.N)

        # Build the figure.
        fig, axes = plt.subplots(1, 4, figsize=(16, 4))

        axes[0].imshow(image.astype(np.uint8))
        axes[0].set_title("Original Image")
        axes[0].axis("off")

        axes[1].imshow(pred_mask, cmap=cmap, norm=norm)
        axes[1].set_title("Predicted Mask")
        axes[1].axis("off")

        axes[2].imshow(gt_mask, cmap=cmap, norm=norm)
        axes[2].set_title("Ground Truth Mask")
        axes[2].axis("off")

        legend_patches = [
            mpatches.Patch(color="black", label="Background (0)"),
            mpatches.Patch(color="green", label="lettuce (1)"),
            mpatches.Patch(color="red", label="weed (2)"),
        ]
        axes[3].legend(handles=legend_patches, loc='center', fontsize=10)
        axes[3].set_title("Classes Legend")
        axes[3].axis("off")

        # Tidy up the layout and save.
        plt.tight_layout()
        os.makedirs('./outputs', exist_ok=True)
        plt.savefig(os.path.join('./outputs', save_path), dpi=200, bbox_inches='tight')
        plt.close()
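
# A rough usage sketch, assuming the project supplies a SAM-adapter model, a dict of
# losses, data loaders and a config object (the names below are illustrative only):
#
#   runner = SemRunner(model, optimizer, losses, train_loader, val_loader, scheduler)
#   runner.train(cfg)  # runs cfg.max_iter iterations, evaluating every cfg.eval_iter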