import json
import os

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import VOCSegmentation, VisionDataset


class BaseSemanticDataset(VisionDataset):
    """
    To train the segmentation task on a custom dataset, the image and mask
    files need to be arranged in this structure:

    ├── data
    │   ├── my_dataset
    │   │   ├── img
    │   │   │   ├── train
    │   │   │   │   ├── xxx{img_suffix}
    │   │   │   │   ├── yyy{img_suffix}
    │   │   │   │   ├── zzz{img_suffix}
    │   │   │   ├── val
    │   │   ├── ann
    │   │   │   ├── train
    │   │   │   │   ├── xxx{ann_suffix}
    │   │   │   │   ├── yyy{ann_suffix}
    │   │   │   │   ├── zzz{ann_suffix}
    │   │   │   ├── val
    """

    def __init__(self,
                 metainfo,
                 dataset_dir,
                 transform,
                 target_transform,
                 image_set='train',
                 img_suffix='.jpg',
                 ann_suffix='.png',
                 data_prefix: dict = dict(img_path='img', ann_path='ann'),
                 return_dict=False):
        """
        :param metainfo: metadata of the original dataset, e.g. class_names
        :param dataset_dir: the path of your dataset, e.g. data/my_dataset/
            following the structure tree above
        :param image_set: 'train' or 'val'
        :param img_suffix: your image suffix
        :param ann_suffix: your annotation suffix
        :param data_prefix: data folder names; for the tree above:
            img_path='img', ann_path='ann'
        :param return_dict: return a dict instead of the tuple (img, ann)
        """
        super(BaseSemanticDataset, self).__init__(root=dataset_dir, transform=transform,
                                                  target_transform=target_transform)

        self.class_names = metainfo['class_names']

        self.img_path = os.path.join(dataset_dir, data_prefix['img_path'], image_set)
        self.ann_path = os.path.join(dataset_dir, data_prefix['ann_path'], image_set)
        print('img_folder_name: {img_folder_name}, ann_folder_name: {ann_folder_name}'.format(
            img_folder_name=self.img_path, ann_folder_name=self.ann_path))
        # Collect image name stems so image/annotation pairs can be matched by name.
        self.img_names = [img_name.split(img_suffix)[0]
                          for img_name in os.listdir(self.img_path)
                          if img_name.endswith(img_suffix)]
        self.img_suffix = img_suffix
        self.ann_suffix = ann_suffix
        self.return_dict = return_dict

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_path, self.img_names[index] + self.img_suffix))
        ann = Image.open(os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))

        # self.transforms is assembled by VisionDataset from transform/target_transform.
        if self.transforms is not None:
            img, ann = self.transforms(img, ann)
        ann = np.array(ann)
        if self.return_dict:
            data = dict(img_name=self.img_names[index],
                        img=img,
                        ann=ann,
                        img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix),
                        ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix))
            return data
        return img, ann

    def __len__(self):
        return len(self.img_names)
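
# A minimal usage sketch for BaseSemanticDataset. The dataset path, suffixes
# and class names below are illustrative assumptions, not part of this module:
#
#   metainfo = dict(class_names=['background', 'lettuce', 'weed'])
#   train_set = BaseSemanticDataset(metainfo, 'data/my_dataset/',
#                                   transform=None, target_transform=None,
#                                   image_set='train',
#                                   img_suffix='.jpg', ann_suffix='.png')
#   img, ann = train_set[0]  # PIL image and numpy mask when no transforms are set
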
class VOCSemanticDataset(Dataset):
    def __init__(self, root_dir, domain, transform, with_id=False, with_mask=False):
        super(VOCSemanticDataset, self).__init__()
        self.root_dir = root_dir
        self.image_dir = self.root_dir + 'JPEGImages/'
        self.xml_dir = self.root_dir + 'Annotations/'
        self.mask_dir = self.root_dir + 'SegmentationClass/'

        # './data/{domain}.txt' lists one image id per line (e.g. a train/val split).
        with open('./data/%s.txt' % domain) as f:
            self.image_id_list = [image_id.strip() for image_id in f.readlines()]

        self.transform = transform
        self.with_id = with_id
        self.with_mask = with_mask

        self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                            'bus', 'car', 'cat', 'chair', 'cow',
                            'diningtable', 'dog', 'horse', 'motorbike', 'person',
                            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

    def __len__(self):
        return len(self.image_id_list)

    def get_image(self, image_id):
        image = Image.open(self.image_dir + image_id + '.jpg').convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def get_mask(self, image_id):
        mask_path = self.mask_dir + image_id + '.png'
        if os.path.isfile(mask_path):
            mask = Image.open(mask_path)
        else:
            mask = None
        return mask

    def __getitem__(self, index):
        image_id = self.image_id_list[index]
        data_list = [self.get_image(image_id)]

        if self.with_id:
            data_list.append(image_id)
        if self.with_mask:
            data_list.append(self.get_mask(image_id))
        return data_list


class TorchVOCSegmentation(VOCSegmentation):
    def __init__(self, root, year='2012', image_set='train', download=False, transform=None, target_transform=None):
        super(TorchVOCSegmentation, self).__init__(root=root, year=year, image_set=image_set, download=download,
                                                   transform=transform, target_transform=target_transform)
        self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                            'bus', 'car', 'cat', 'chair', 'cow',
                            'diningtable', 'dog', 'horse', 'motorbike', 'person',
                            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the segmentation mask.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        target = np.array(target)
        return img, target


class LettuceSegDataset(Dataset):
    def __init__(self, file_list, transform=None,
                 # image_suffix=".JPG",
                 image_suffix=".jpg",
                 label_suffix=".json",
                 width=None, height=None):
        super().__init__()
        self.file_list = file_list
        self.transform = transform  # stored but not applied in __getitem__
        self.image_suffix = image_suffix
        self.label_suffix = label_suffix
        self.width = width
        self.height = height
        # self.class_names = ['background', 'lettuce']
        self.class_names = ['background', 'lettuce', 'weed']

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        image_path = self.file_list[idx]
        json_path = image_path.replace(self.image_suffix, self.label_suffix)

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape

        # Binary mask: pixels inside "weed" polygons become 1, everything else stays 0.
        mask = np.zeros((h, w), dtype=np.uint8)
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for shape in data.get("shapes", []):
            label_name = shape["label"]
            polygon = np.array(shape["points"], dtype=np.int32).reshape((-1, 1, 2))
            # if label_name == "lettuce":
            if label_name == "weed":
                cv2.fillPoly(mask, [polygon], 1)

        if self.width is not None and self.height is not None:
            image = cv2.resize(image, (self.width, self.height))
            # Nearest-neighbour interpolation keeps label values intact.
            mask = cv2.resize(mask, (self.width, self.height), interpolation=cv2.INTER_NEAREST)

        image = torch.from_numpy(image.transpose(2, 0, 1)).float()  # HWC -> CHW
        mask = torch.from_numpy(mask[np.newaxis, ...]).long()

        return image, mask, image_path
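
# LettuceSegDataset expects one LabelMe-style JSON file per image. A minimal
# example of the fields it actually reads (all other keys are ignored); the
# label names and coordinates here are illustrative:
#
#   {
#     "shapes": [
#       {"label": "weed",    "points": [[10, 10], [50, 12], [30, 60]]},
#       {"label": "lettuce", "points": [[80, 80], [120, 85], [100, 140]]}
#     ]
#   }
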
class construct_LettuceSegDataset(Dataset):
    def __init__(self, file_list, transform=None,
                 # image_suffix=".JPG",
                 image_suffix=".jpg",
                 label_suffix=".json",
                 width=None, height=None):
        super().__init__()
        self.file_list = file_list
        self.transform = transform  # stored but not applied in __getitem__
        self.image_suffix = image_suffix
        self.label_suffix = label_suffix
        self.width = width
        self.height = height
        # Add the 'weed' class
        self.class_names = ['background', 'lettuce', 'weed']
        # Map class names to mask indices
        self.class_to_idx = {
            'background': 0,
            'lettuce': 1,
            'weed': 2
        }

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        image_path = self.file_list[idx]
        json_path = image_path.replace(self.image_suffix, self.label_suffix)

        # Read the image and convert it to RGB
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape

        # Create the mask, initialized to the background class (0)
        mask = np.zeros((h, w), dtype=np.uint8)

        # Read the annotation file
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Process every annotated shape
        for shape in data.get("shapes", []):
            label_name = shape["label"]
            polygon = np.array(shape["points"], dtype=np.int32).reshape((-1, 1, 2))
            # Fill each polygon with the index of its class
            if label_name in self.class_to_idx:
                cv2.fillPoly(mask, [polygon], self.class_to_idx[label_name])

        # Resize if target dimensions were given
        if self.width is not None and self.height is not None:
            image = cv2.resize(image, (self.width, self.height))
            mask = cv2.resize(mask, (self.width, self.height), interpolation=cv2.INTER_NEAREST)

        # Convert to tensors
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
        mask = torch.from_numpy(mask[np.newaxis, ...]).long()

        return image, mask, image_path
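

# A minimal smoke-test sketch for construct_LettuceSegDataset. The glob
# pattern, image size, and batch size are assumptions for illustration;
# adjust them to your data layout.
if __name__ == '__main__':
    import glob
    from torch.utils.data import DataLoader

    file_list = sorted(glob.glob('data/lettuce/*.jpg'))  # hypothetical image folder
    dataset = construct_LettuceSegDataset(file_list, width=512, height=512)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    for images, masks, paths in loader:
        # images: (B, 3, 512, 512) float tensor; masks: (B, 1, 512, 512) long tensor
        print(images.shape, masks.shape, paths)
        break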