from datasets import Iterator
from .utils import (
    Average_Meter,
    Timer,
    print_and_save_log,
    mIoUOnline,
    get_numpy_from_tensor,
    save_model,
    write_log,
    check_folder,
    one_hot_embedding_3d,
)
import os

import cv2
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.colors import ListedColormap, BoundaryNorm


class BaseRunner:
    """Shared plumbing for training runners.

    Stores the model, optimizer, losses, data loaders and scheduler, infers
    the GPU count from ``CUDA_VISIBLE_DEVICES``, and wraps the model in
    ``nn.DataParallel`` when more than one GPU is visible.
    """

    def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler):
        self.optimizer = optimizer
        self.losses = losses
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.scheduler = scheduler
        self.train_timer = Timer()
        self.eval_timer = Timer()
        # GPU count comes from the visible-devices list; default to one GPU.
        try:
            use_gpu = os.environ['CUDA_VISIBLE_DEVICES']
        except KeyError:
            use_gpu = '0'
        self.the_number_of_gpu = len(use_gpu.split(','))
        # Target resolution of the SAM image encoder. Read it BEFORE the
        # model is possibly wrapped in DataParallel (the wrapper hides
        # attributes behind ``.module``).
        self.original_size = self.model.img_adapter.sam_img_encoder.img_size
        if self.the_number_of_gpu > 1:
            self.model = nn.DataParallel(self.model)


class SemRunner(BaseRunner):
    """Training / evaluation loop for semantic segmentation."""

    def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler):
        super().__init__(model, optimizer, losses, train_loader, val_loader, scheduler)
        # Status labels used when writing log entries.
        self.exist_status = ['train', 'eval', 'test']

    def train(self, cfg):
        """Run the full training loop described by ``cfg``.

        Expects ``cfg`` to provide ``max_iter``, ``log_iter``, ``eval_iter``,
        ``use_tensorboard``, folder/experiment names, and per-loss settings
        (see ``_compute_loss``). Saves the best model (by validation mIoU)
        and the final model, and periodically dumps a loss-curve PNG.
        """
        train_meter = Average_Meter(list(self.losses.keys()) + ['total_loss'])
        train_iterator = Iterator(self.train_loader)
        best_valid_mIoU = -1
        model_path = f"{cfg.model_folder}/{cfg.experiment_name}/model.pth"
        log_path = f"{cfg.log_folder}/{cfg.experiment_name}/log_file.txt"
        check_folder(model_path)
        check_folder(log_path)

        writer = None
        if cfg.use_tensorboard is True:
            tensorboard_dir = f"{cfg.tensorboard_folder}/{cfg.experiment_name}/tensorboard/"
            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter(tensorboard_dir)

        # -- training loop ------------------------------------------------
        train_losses = []
        for iteration in range(cfg.max_iter):
            images, labels, _ = train_iterator.get()
            images, labels = images.cuda(), labels.cuda().long()
            labels = labels.squeeze(1)

            masks_pred, iou_pred = self.model(images)
            # Upsample logits back to the encoder's input resolution so they
            # align with the ground-truth masks.
            masks_pred = F.interpolate(masks_pred, self.original_size,
                                       mode="bilinear", align_corners=False)

            total_loss = torch.zeros(1).cuda()
            loss_dict = {}
            self._compute_loss(total_loss, loss_dict, masks_pred, labels, cfg)

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            loss_dict['total_loss'] = total_loss.item()
            train_losses.append(total_loss.item())
            train_meter.add(loss_dict)

            # -- periodic logging ----------------------------------------
            if (iteration + 1) % cfg.log_iter == 0:
                write_log(iteration=iteration, log_path=log_path,
                          log_data=train_meter.get(clear=True),
                          status=self.exist_status[0],
                          writer=writer, timer=self.train_timer)

            # -- periodic evaluation + checkpointing ---------------------
            if (iteration + 1) % cfg.eval_iter == 0:
                mIoU, _ = self._eval()
                if best_valid_mIoU == -1 or best_valid_mIoU < mIoU:
                    best_valid_mIoU = mIoU
                    save_model(self.model, model_path,
                               parallel=self.the_number_of_gpu > 1)
                    print_and_save_log(f"saved model in {model_path}",
                                       path=log_path)
                log_data = {'mIoU': mIoU, 'best_valid_mIoU': best_valid_mIoU}
                write_log(iteration=iteration, log_path=log_path,
                          log_data=log_data, status=self.exist_status[1],
                          writer=writer, timer=self.eval_timer)

                # Dump the loss curve so progress can be eyeballed offline.
                plt.figure(figsize=(8, 5))
                plt.plot(train_losses, label="Train Loss")
                plt.xlabel("Iteration")
                plt.ylabel("Loss")
                plt.title(f"Loss Curve up to Iter {iteration}")
                plt.legend()
                save_path = os.path.join('./', f"loss_iter_{iteration}.png")
                plt.savefig(save_path, dpi=150)
                plt.close()  # close the figure to release memory

        # -- final checkpoint ---------------------------------------------
        save_model(self.model, model_path, is_final=True,
                   parallel=self.the_number_of_gpu > 1)
        if writer is not None:
            writer.close()

    def test(self):
        pass

    def _eval(self):
        """Evaluate on the validation loader and return ``(mIoU, ...)``.

        Ground-truth masks are nearest-neighbour resized to the prediction
        resolution before scoring. Also writes a side-by-side visualization
        per image. Restores the model to train mode before returning.
        """
        self.model.eval()
        self.eval_timer.start()
        class_names = self.val_loader.dataset.class_names
        eval_metric = mIoUOnline(class_names=class_names)

        with torch.no_grad():
            for images, labels, img_paths in self.val_loader:
                images = images.cuda()
                labels = labels.cuda()
                masks_pred, iou_pred = self.model(images)
                predictions = torch.argmax(masks_pred, dim=1)

                for batch_index in range(images.size()[0]):
                    pred_mask = get_numpy_from_tensor(predictions[batch_index])
                    gt_mask = get_numpy_from_tensor(labels[batch_index].squeeze(0))
                    # Resize GT to the prediction's spatial size so both masks
                    # are compared on the same grid.
                    h, w = pred_mask.shape
                    gt_mask = cv2.resize(gt_mask, (w, h),
                                         interpolation=cv2.INTER_NEAREST)
                    src_img = cv2.imread(img_paths[batch_index])
                    # self.visualize_segmentation(src_img, pred_mask, gt_mask, os.path.basename(img_paths[batch_index]))
                    self.construct_visualize_segmentation(
                        src_img, pred_mask, gt_mask,
                        os.path.basename(img_paths[batch_index]))
                    eval_metric.add(pred_mask, gt_mask)

        self.model.train()
        return eval_metric.get(clear=True)

    def _compute_loss(self, total_loss, loss_dict, mask_pred, labels, cfg):
        """Accumulate every configured loss into ``total_loss`` (in place).

        Because each loss may take differently-shaped targets, losses whose
        config sets ``label_one_hot`` receive one-hot encoded labels. Adding
        a new loss with a different input contract means extending this
        method. Per-loss scalar values are recorded in ``loss_dict``.
        """
        loss_cfg = cfg.losses
        for loss_name, loss_fn in self.losses.items():
            real_labels = labels
            if loss_cfg[loss_name].label_one_hot:
                class_num = cfg.model.params.class_num
                real_labels = one_hot_embedding_3d(real_labels, class_num=class_num)
            tmp_loss = loss_fn(mask_pred, real_labels)
            loss_dict[loss_name] = tmp_loss.item()
            total_loss += loss_cfg[loss_name].weight * tmp_loss

    def _render_segmentation(self, image, pred_mask, gt_mask, colors,
                             legend_labels, save_path):
        """Render image / prediction / ground-truth / legend as one figure.

        ``colors`` is the ordered list of class colors (index == class id);
        ``legend_labels`` the matching legend captions. The figure is saved
        under ``./outputs/<save_path>``.
        """
        # If the image is channel-first (C, H, W), convert to (H, W, C).
        if image.ndim == 3 and image.shape[0] == 3 and image.shape[1] != 3:
            image = np.transpose(image, (1, 2, 0))

        # Discrete colormap: class id i -> colors[i]; boundaries 0..len+1.
        cmap = ListedColormap(colors)
        norm = BoundaryNorm(list(range(len(colors) + 1)), cmap.N)

        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        axes[0].imshow(image.astype(np.uint8))
        axes[0].set_title("Original Image")
        axes[0].axis("off")

        axes[1].imshow(pred_mask, cmap=cmap, norm=norm)
        axes[1].set_title("Predicted Mask")
        axes[1].axis("off")

        axes[2].imshow(gt_mask, cmap=cmap, norm=norm)
        axes[2].set_title("Ground Truth Mask")
        axes[2].axis("off")

        legend_patches = [mpatches.Patch(color=c, label=lbl)
                          for c, lbl in zip(colors, legend_labels)]
        axes[3].legend(handles=legend_patches, loc='center', fontsize=10)
        axes[3].set_title("Classes Legend")
        axes[3].axis("off")

        plt.tight_layout()
        # Make sure the output directory exists before saving.
        os.makedirs('./outputs', exist_ok=True)
        plt.savefig(os.path.join('./outputs', save_path), dpi=200,
                    bbox_inches='tight')
        plt.close()

    def visualize_segmentation(self, image, pred_mask, gt_mask, save_path=None):
        """Visualize a binary (background/lettuce) segmentation result."""
        self._render_segmentation(
            image, pred_mask, gt_mask,
            colors=["black", "green"],
            legend_labels=["Background (0)", "lettuce (1)"],
            save_path=save_path)

    def construct_visualize_segmentation(self, image, pred_mask, gt_mask,
                                         save_path=None):
        """Visualize a 3-class (background/lettuce/weed) segmentation result."""
        self._render_segmentation(
            image, pred_mask, gt_mask,
            colors=["black", "green", "red"],
            legend_labels=["Background (0)", "lettuce (1)", "weed (2)"],
            save_path=save_path)