laser_weeding/extend_sam/runner.py

227 lines
9.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datasets import Iterator
from .utils import Average_Meter, Timer, print_and_save_log, mIoUOnline, get_numpy_from_tensor, save_model, write_log, \
check_folder, one_hot_embedding_3d
import torch
import cv2
import torch.nn.functional as F
import os
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap, BoundaryNorm
import numpy as np
class BaseRunner():
    """Shared training state for all runners.

    Holds the model, optimizer, losses, data loaders and scheduler, counts
    visible GPUs from ``CUDA_VISIBLE_DEVICES``, and wraps the model in
    ``nn.DataParallel`` when more than one GPU is available.
    """

    def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler):
        self.optimizer = optimizer
        self.losses = losses  # dict: loss name -> loss callable
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        self.scheduler = scheduler
        self.train_timer = Timer()
        self.eval_timer = Timer()
        # CUDA_VISIBLE_DEVICES is a comma-separated device list; assume a
        # single device ('0') when it is not set.
        use_gpu = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
        self.the_number_of_gpu = len(use_gpu.split(','))
        # Input resolution of the SAM image encoder; used to up-sample
        # predicted masks back to the network input size during training.
        self.original_size = self.model.img_adapter.sam_img_encoder.img_size
        if self.the_number_of_gpu > 1:
            self.model = nn.DataParallel(self.model)
class SemRunner(BaseRunner):
    """Runner for semantic segmentation: training loop, periodic validation
    (online mIoU) and qualitative visualization of predictions."""

    def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler):
        super().__init__(model, optimizer, losses, train_loader, val_loader, scheduler)
        # Status tags used when writing log entries for each phase.
        self.exist_status = ['train', 'eval', 'test']
def train(self, cfg):
    """Run the optimization loop for ``cfg.max_iter`` iterations.

    Every ``cfg.log_iter`` steps the averaged losses are written to the log;
    every ``cfg.eval_iter`` steps the model is evaluated on the validation
    set, the best-mIoU checkpoint is kept, and a loss-curve snapshot is
    saved to the working directory.  The final weights are always saved.

    Args:
        cfg: experiment configuration (folders, max_iter, log/eval intervals,
             tensorboard flag, per-loss settings).
    """
    train_meter = Average_Meter(list(self.losses.keys()) + ['total_loss'])
    train_iterator = Iterator(self.train_loader)
    best_valid_mIoU = -1
    model_path = f"{cfg.model_folder}/{cfg.experiment_name}/model.pth"
    log_path = f"{cfg.log_folder}/{cfg.experiment_name}/log_file.txt"
    check_folder(model_path)
    check_folder(log_path)
    writer = None
    if cfg.use_tensorboard:
        tensorboard_dir = f"{cfg.tensorboard_folder}/{cfg.experiment_name}/tensorboard/"
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(tensorboard_dir)
    # train
    train_losses = []
    for iteration in range(cfg.max_iter):
        images, labels, _ = train_iterator.get()
        images, labels = images.cuda(), labels.cuda().long()
        labels = labels.squeeze(1)  # drop channel dim -> (B, H, W)
        masks_pred, iou_pred = self.model(images)
        # Up-sample logits to the encoder input resolution before the loss.
        masks_pred = F.interpolate(masks_pred, self.original_size, mode="bilinear", align_corners=False)
        total_loss = torch.zeros(1).cuda()
        loss_dict = {}
        # _compute_loss accumulates into total_loss in place.
        self._compute_loss(total_loss, loss_dict, masks_pred, labels, cfg)
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        loss_dict['total_loss'] = total_loss.item()
        train_losses.append(total_loss.item())
        train_meter.add(loss_dict)
        # Periodic logging of the running loss averages.
        if (iteration + 1) % cfg.log_iter == 0:
            write_log(iteration=iteration, log_path=log_path, log_data=train_meter.get(clear=True),
                      status=self.exist_status[0],
                      writer=writer, timer=self.train_timer)
        # Periodic validation + checkpointing of the best model.
        if (iteration + 1) % cfg.eval_iter == 0:
            mIoU, _ = self._eval()
            if best_valid_mIoU == -1 or best_valid_mIoU < mIoU:
                best_valid_mIoU = mIoU
                save_model(self.model, model_path, parallel=self.the_number_of_gpu > 1)
                print_and_save_log(f"saved model in {model_path}", path=log_path)
            log_data = {'mIoU': mIoU, 'best_valid_mIoU': best_valid_mIoU}
            write_log(iteration=iteration, log_path=log_path, log_data=log_data, status=self.exist_status[1],
                      writer=writer, timer=self.eval_timer)
            # Save a snapshot of the training-loss curve so far.
            # NOTE(review): indentation was lost in the source; the per-eval
            # placement (rather than per-iteration) is assumed — confirm.
            plt.figure(figsize=(8, 5))
            plt.plot(train_losses, label="Train Loss")
            plt.xlabel("Iteration")
            plt.ylabel("Loss")
            plt.title(f"Loss Curve up to Iter {iteration}")
            plt.legend()
            # Save the figure to the working directory.
            save_path = os.path.join('./', f"loss_iter_{iteration}.png")
            plt.savefig(save_path, dpi=150)
            plt.close()  # close the current figure to free memory
    # Always keep the final weights, regardless of validation score.
    save_model(self.model, model_path, is_final=True, parallel=self.the_number_of_gpu > 1)
    if writer is not None:
        writer.close()
def test(self):
    """Test-set loop placeholder; intentionally a no-op for now."""
def _eval(self):
    """Evaluate the current model on the validation loader.

    Switches to eval mode, accumulates per-image predictions into an online
    mIoU metric (saving a visualization image per sample as a side effect),
    then restores train mode.

    Returns:
        The ``(mIoU, second_metric)`` pair from ``mIoUOnline.get(clear=True)``.
    """
    self.model.eval()
    self.eval_timer.start()
    class_names = self.val_loader.dataset.class_names
    eval_metric = mIoUOnline(class_names=class_names)
    with torch.no_grad():
        for images, labels, img_paths in self.val_loader:
            images = images.cuda()
            labels = labels.cuda()
            masks_pred, iou_pred = self.model(images)
            predictions = torch.argmax(masks_pred, dim=1)
            for batch_index in range(images.size()[0]):
                pred_mask = get_numpy_from_tensor(predictions[batch_index])
                gt_mask = get_numpy_from_tensor(labels[batch_index].squeeze(0))
                h, w = pred_mask.shape
                # Nearest-neighbour resize so integer class ids stay intact.
                gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
                src_img = cv2.imread(img_paths[batch_index])
                self.construct_visualize_segmentation(src_img, pred_mask, gt_mask,
                                                      os.path.basename(img_paths[batch_index]))
                eval_metric.add(pred_mask, gt_mask)
    self.model.train()
    return eval_metric.get(clear=True)
def _compute_loss(self, total_loss, loss_dict, mask_pred, labels, cfg):
    """Accumulate every configured loss into ``total_loss`` (in place).

    ``total_loss`` must be a 1-element tensor; it is updated with ``+=`` so
    the caller's reference observes the weighted sum.  Each loss's scalar
    value is also recorded in ``loss_dict`` under the loss name.

    Because the inputs of the losses differ, adding a new loss may require
    modifying the label-preparation logic in this function.
    """
    loss_cfg = cfg.losses
    for loss_name, loss_fn in self.losses.items():
        real_labels = labels
        # Some losses expect one-hot targets instead of class-index maps.
        if loss_cfg[loss_name].label_one_hot:
            class_num = cfg.model.params.class_num
            real_labels = one_hot_embedding_3d(real_labels, class_num=class_num)
        tmp_loss = loss_fn(mask_pred, real_labels)
        loss_dict[loss_name] = tmp_loss.item()
        # In-place add keeps the caller's tensor reference valid.
        total_loss += loss_cfg[loss_name].weight * tmp_loss
def visualize_segmentation(self, image, pred_mask, gt_mask, save_path=None):
    """Save a 4-panel figure (image, prediction, ground truth, legend)
    for the 2-class case: 0 = background, 1 = lettuce.

    Args:
        image: array of shape (H, W, C) or (C, H, W); transposed if needed.
        pred_mask: integer mask with values in {0, 1}.
        gt_mask: integer mask with values in {0, 1}.
        save_path: output file name under './outputs'; defaults to
            'segmentation.png' when omitted.
    """
    # If the image is (C, H, W), convert it to (H, W, C) for matplotlib.
    if image.ndim == 3 and image.shape[0] == 3 and image.shape[1] != 3:
        image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    # Discrete colormap for the segmentation: 0 = black, 1 = green.
    cmap = ListedColormap(["black", "green"])
    # Boundaries [0, 1, 2] map the two class values onto the two colors.
    norm = BoundaryNorm([0, 1, 2], cmap.N)
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    axes[0].imshow(image.astype(np.uint8))
    axes[0].set_title("Original Image")
    axes[0].axis("off")
    axes[1].imshow(pred_mask, cmap=cmap, norm=norm)
    axes[1].set_title("Predicted Mask")
    axes[1].axis("off")
    axes[2].imshow(gt_mask, cmap=cmap, norm=norm)
    axes[2].set_title("Ground Truth Mask")
    axes[2].axis("off")
    legend_patches = [
        mpatches.Patch(color="black", label="Background (0)"),
        mpatches.Patch(color="green", label="lettuce (1)"),
    ]
    axes[3].legend(handles=legend_patches, loc='center', fontsize=10)
    axes[3].set_title("Classes Legend")
    axes[3].axis("off")
    plt.tight_layout()
    # Guard against the None default, which would make os.path.join raise
    # TypeError, and make sure the output directory exists before saving.
    if save_path is None:
        save_path = "segmentation.png"
    os.makedirs('./outputs', exist_ok=True)
    plt.savefig(os.path.join('./outputs', save_path), dpi=200, bbox_inches='tight')
    plt.close()
def construct_visualize_segmentation(self, image, pred_mask, gt_mask, save_path=None):
    """Save a 4-panel figure (image, prediction, ground truth, legend)
    for the 3-class case: 0 = background, 1 = lettuce, 2 = weed.

    Args:
        image: array of shape (H, W, C) or (C, H, W); transposed if needed.
        pred_mask: integer mask with values in {0, 1, 2}.
        gt_mask: integer mask with values in {0, 1, 2}.
        save_path: output file name under './outputs'; defaults to
            'segmentation.png' when omitted.
    """
    # If the image is (C, H, W), convert it to (H, W, C) for matplotlib.
    if image.ndim == 3 and image.shape[0] == 3 and image.shape[1] != 3:
        image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    # Discrete colormap: 0 = black, 1 = green, 2 = red.
    cmap = ListedColormap(["black", "green", "red"])
    # Boundaries [0, 1, 2, 3] map the three class values onto the colors.
    norm = BoundaryNorm([0, 1, 2, 3], cmap.N)
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    axes[0].imshow(image.astype(np.uint8))
    axes[0].set_title("Original Image")
    axes[0].axis("off")
    axes[1].imshow(pred_mask, cmap=cmap, norm=norm)
    axes[1].set_title("Predicted Mask")
    axes[1].axis("off")
    axes[2].imshow(gt_mask, cmap=cmap, norm=norm)
    axes[2].set_title("Ground Truth Mask")
    axes[2].axis("off")
    legend_patches = [
        mpatches.Patch(color="black", label="Background (0)"),
        mpatches.Patch(color="green", label="lettuce (1)"),
        mpatches.Patch(color="red", label="weed (2)"),
    ]
    axes[3].legend(handles=legend_patches, loc='center', fontsize=10)
    axes[3].set_title("Classes Legend")
    axes[3].axis("off")
    plt.tight_layout()
    # Guard against the None default, which would make os.path.join raise
    # TypeError, and make sure the output directory exists before saving.
    if save_path is None:
        save_path = "segmentation.png"
    os.makedirs('./outputs', exist_ok=True)
    plt.savefig(os.path.join('./outputs', save_path), dpi=200, bbox_inches='tight')
    plt.close()