# laser_weeding/datasets/semantic_seg.py

import json
import os

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import VOCSegmentation, VisionDataset


class BaseSemanticDataset(VisionDataset):
    """
    To train the segmentation task on a custom dataset, arrange the image
    and mask files in the following structure:

    ├── data
    │   ├── my_dataset
    │   │   ├── img
    │   │   │   ├── train
    │   │   │   │   ├── xxx{img_suffix}
    │   │   │   │   ├── yyy{img_suffix}
    │   │   │   │   ├── zzz{img_suffix}
    │   │   │   ├── val
    │   │   ├── ann
    │   │   │   ├── train
    │   │   │   │   ├── xxx{ann_suffix}
    │   │   │   │   ├── yyy{ann_suffix}
    │   │   │   │   ├── zzz{ann_suffix}
    │   │   │   ├── val
    """

    def __init__(self, metainfo, dataset_dir, transform, target_transform,
                 image_set='train',
                 img_suffix='.jpg',
                 ann_suffix='.png',
                 data_prefix: dict = dict(img_path='img', ann_path='ann'),
                 return_dict=False):
        """
        :param metainfo: metadata of the original dataset, e.g. class_names
        :param dataset_dir: path of your dataset, e.g. data/my_dataset/ in the structure tree above
        :param image_set: 'train' or 'val'
        :param img_suffix: suffix of your image files
        :param ann_suffix: suffix of your annotation files
        :param data_prefix: data folder names; for the tree above: img_path='img', ann_path='ann'
        :param return_dict: return a dict instead of a tuple (img, ann)
        """
        super().__init__(root=dataset_dir, transform=transform,
                         target_transform=target_transform)
        self.class_names = metainfo['class_names']
        self.img_path = os.path.join(dataset_dir, data_prefix['img_path'], image_set)
        self.ann_path = os.path.join(dataset_dir, data_prefix['ann_path'], image_set)
        print('img_folder_name: {}, ann_folder_name: {}'.format(self.img_path, self.ann_path))
        # Collect image basenames (suffix stripped from the end only) so the
        # matching annotation file can be located by name.
        self.img_names = [img_name[:-len(img_suffix)] for img_name in os.listdir(self.img_path)
                          if img_name.endswith(img_suffix)]
        self.img_suffix = img_suffix
        self.ann_suffix = ann_suffix
        self.return_dict = return_dict

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_path, self.img_names[index] + self.img_suffix))
        ann = Image.open(os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
        # VisionDataset combines transform and target_transform into self.transforms.
        if self.transforms is not None:
            img, ann = self.transforms(img, ann)
        ann = np.array(ann)
        if self.return_dict:
            data = dict(img_name=self.img_names[index], img=img, ann=ann,
                        img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix),
                        ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
            return data
        return img, ann

    def __len__(self):
        return len(self.img_names)
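

# A minimal usage sketch for BaseSemanticDataset, assuming the directory
# layout from the docstring above. The dataset path and class names are
# hypothetical placeholders, not part of this repository:
#
#     dataset = BaseSemanticDataset(
#         metainfo=dict(class_names=['background', 'lettuce', 'weed']),
#         dataset_dir='data/my_dataset',
#         transform=None,
#         target_transform=None,
#         image_set='train',
#     )
#     img, ann = dataset[0]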


class VOCSemanticDataset(Dataset):
    def __init__(self, root_dir, domain, transform, with_id=False, with_mask=False):
        super().__init__()
        self.root_dir = root_dir
        self.image_dir = os.path.join(self.root_dir, 'JPEGImages')
        self.xml_dir = os.path.join(self.root_dir, 'Annotations')
        self.mask_dir = os.path.join(self.root_dir, 'SegmentationClass')
        # Image IDs come from a split file such as ./data/train.txt.
        with open('./data/%s.txt' % domain) as f:
            self.image_id_list = [image_id.strip() for image_id in f.readlines()]
        self.transform = transform
        self.with_id = with_id
        self.with_mask = with_mask
        self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                            'bus', 'car', 'cat', 'chair', 'cow',
                            'diningtable', 'dog', 'horse', 'motorbike', 'person',
                            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

    def __len__(self):
        return len(self.image_id_list)

    def get_image(self, image_id):
        image = Image.open(os.path.join(self.image_dir, image_id + '.jpg')).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def get_mask(self, image_id):
        # Not every image has a segmentation mask; return None when missing.
        mask_path = os.path.join(self.mask_dir, image_id + '.png')
        if os.path.isfile(mask_path):
            mask = Image.open(mask_path)
        else:
            mask = None
        return mask

    def __getitem__(self, index):
        image_id = self.image_id_list[index]
        data_list = [self.get_image(image_id)]
        if self.with_id:
            data_list.append(image_id)
        if self.with_mask:
            data_list.append(self.get_mask(image_id))
        return data_list
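

# A usage sketch for VOCSemanticDataset, assuming a PASCAL-VOC-style root
# directory and a split file ./data/train.txt; both paths are placeholders:
#
#     dataset = VOCSemanticDataset(root_dir='data/VOCdevkit/VOC2012',
#                                  domain='train', transform=None,
#                                  with_id=True, with_mask=True)
#     image, image_id, mask = dataset[0]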


class TorchVOCSegmentation(VOCSegmentation):
    def __init__(self, root, year='2012', image_set='train', download=False,
                 transform=None, target_transform=None):
        super().__init__(root=root, year=year, image_set=image_set, download=download,
                         transform=transform, target_transform=target_transform)
        self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                            'bus', 'car', 'cat', 'chair', 'cow',
                            'diningtable', 'dog', 'horse', 'motorbike', 'person',
                            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        target = np.array(target)
        return img, target
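

# A usage sketch for TorchVOCSegmentation; torchvision can download PASCAL
# VOC 2012 to the given root (the 'data' path is a placeholder):
#
#     dataset = TorchVOCSegmentation(root='data', year='2012',
#                                    image_set='val', download=True)
#     img, target = dataset[0]  # target is a numpy array of class indices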


class LettuceSegDataset(Dataset):
    def __init__(self,
                 file_list,
                 transform=None,
                 # image_suffix=".JPG",
                 image_suffix=".jpg",
                 label_suffix=".json",
                 width=None,
                 height=None):
        super().__init__()
        self.file_list = file_list
        self.transform = transform  # currently stored but not applied in __getitem__
        self.image_suffix = image_suffix
        self.label_suffix = label_suffix
        self.width = width
        self.height = height
        # self.class_names = ['background', 'lettuce']
        self.class_names = ['background', 'lettuce', 'weed']

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        image_path = self.file_list[idx]
        # Replace only the trailing suffix, so a suffix string occurring
        # elsewhere in the path is left untouched.
        json_path = image_path[:-len(self.image_suffix)] + self.label_suffix
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError('Failed to read image: %s' % image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape
        mask = np.zeros((h, w), dtype=np.uint8)
        # Rasterize labelme-style polygon annotations into a binary mask:
        # only 'weed' polygons are filled (value 1).
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for shape in data.get("shapes", []):
            label_name = shape["label"]
            polygon = np.array(shape["points"], dtype=np.int32).reshape((-1, 1, 2))
            # if label_name == "lettuce":
            if label_name == "weed":
                cv2.fillPoly(mask, [polygon], 1)
        if self.width is not None and self.height is not None:
            image = cv2.resize(image, (self.width, self.height))
            # Nearest-neighbour interpolation keeps mask values as class indices.
            mask = cv2.resize(mask, (self.width, self.height), interpolation=cv2.INTER_NEAREST)
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
        mask = torch.from_numpy(mask[np.newaxis, ...]).long()
        return image, mask, image_path
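

# A usage sketch for LettuceSegDataset, assuming labelme-style JSON files
# sit next to the images; the glob pattern is a hypothetical placeholder:
#
#     import glob
#     files = sorted(glob.glob('data/lettuce/*.jpg'))
#     dataset = LettuceSegDataset(files, width=512, height=512)
#     image, mask, path = dataset[0]  # mask is 1 inside 'weed' polygons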


class construct_LettuceSegDataset(Dataset):
    def __init__(self,
                 file_list,
                 transform=None,
                 # image_suffix=".JPG",
                 image_suffix=".jpg",
                 label_suffix=".json",
                 width=None,
                 height=None):
        super().__init__()
        self.file_list = file_list
        self.transform = transform  # currently stored but not applied in __getitem__
        self.image_suffix = image_suffix
        self.label_suffix = label_suffix
        self.width = width
        self.height = height
        # Add the 'weed' class.
        self.class_names = ['background', 'lettuce', 'weed']
        # Mapping from class name to mask index.
        self.class_to_idx = {
            'background': 0,
            'lettuce': 1,
            'weed': 2
        }

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        image_path = self.file_list[idx]
        # Replace only the trailing image suffix to get the annotation path.
        json_path = image_path[:-len(self.image_suffix)] + self.label_suffix
        # Read the image and convert it from BGR to RGB.
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError('Failed to read image: %s' % image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape
        # Create the mask, initialized to the background class (0).
        mask = np.zeros((h, w), dtype=np.uint8)
        # Load the annotation file.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Process every annotated shape.
        for shape in data.get("shapes", []):
            label_name = shape["label"]
            polygon = np.array(shape["points"], dtype=np.int32).reshape((-1, 1, 2))
            # Fill each polygon with its class index.
            if label_name in self.class_to_idx:
                cv2.fillPoly(mask, [polygon], self.class_to_idx[label_name])
        # Resize if requested; nearest-neighbour keeps class indices intact.
        if self.width is not None and self.height is not None:
            image = cv2.resize(image, (self.width, self.height))
            mask = cv2.resize(mask, (self.width, self.height),
                              interpolation=cv2.INTER_NEAREST)
        # Convert to tensors.
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
        mask = torch.from_numpy(mask[np.newaxis, ...]).long()
        return image, mask, image_path
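

if __name__ == '__main__':
    # Minimal smoke test for construct_LettuceSegDataset, assuming labelme
    # JSON annotations next to the images; the glob pattern below is a
    # hypothetical placeholder, not part of this repository.
    import glob

    files = sorted(glob.glob('data/lettuce/*.jpg'))
    if files:
        dataset = construct_LettuceSegDataset(files, width=512, height=512)
        image, mask, path = dataset[0]
        # image: (3, H, W) float tensor; mask: (1, H, W) long tensor with
        # values in {0: background, 1: lettuce, 2: weed}.
        print(image.shape, mask.shape, path)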