generated from yuanbiao/python_templates
265 lines
10 KiB
Python
265 lines
10 KiB
Python
import os
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset
|
|
from torchvision.datasets import VOCSegmentation, VisionDataset
|
|
import numpy as np
|
|
import cv2
|
|
import json
|
|
import torch
|
|
|
|
class BaseSemanticDataset(VisionDataset):
    """Generic image/annotation folder dataset for semantic segmentation.

    To train the segmentation task on a custom dataset, arrange the image and
    annotation files in the following structure::

        ├── data
        │   ├── my_dataset
        │   │   ├── img
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann
        │   │   │   ├── train
        │   │   │   │   ├── xxx{ann_suffix}
        │   │   │   │   ├── yyy{ann_suffix}
        │   │   │   │   ├── zzz{ann_suffix}
        │   │   │   ├── val

    Every image ``<name>{img_suffix}`` is paired with the annotation
    ``<name>{ann_suffix}`` in the matching split folder under ``ann``.
    """

    def __init__(self, metainfo, dataset_dir, transform, target_transform,
                 image_set='train',
                 img_suffix='.jpg',
                 ann_suffix='.png',
                 data_prefix=None,
                 return_dict=False):
        """
        :param metainfo: metadata of the original dataset; must contain the key
            ``'class_names'``.
        :param dataset_dir: root path of the dataset, e.g. ``data/my_dataset/`` for
            the structure tree above.
        :param transform: image transform, forwarded to ``VisionDataset``.
        :param target_transform: annotation transform, forwarded to ``VisionDataset``.
        :param image_set: split sub-folder name, e.g. 'train' or 'val'.
        :param img_suffix: file suffix of the images.
        :param ann_suffix: file suffix of the annotations.
        :param data_prefix: folder names under ``dataset_dir``; as the tree above
            shows, defaults to ``dict(img_path='img', ann_path='ann')``.
        :param return_dict: return a dict if True, otherwise a tuple ``(img, ann)``.
        """
        super(BaseSemanticDataset, self).__init__(root=dataset_dir, transform=transform,
                                                  target_transform=target_transform)
        # Build the default per call instead of sharing one mutable default dict.
        if data_prefix is None:
            data_prefix = dict(img_path='img', ann_path='ann')

        self.class_names = metainfo['class_names']
        self.img_path = os.path.join(dataset_dir, data_prefix['img_path'], image_set)
        self.ann_path = os.path.join(dataset_dir, data_prefix['ann_path'], image_set)
        print('img_folder_name: {img_folder_name}, ann_folder_name: {ann_folder_name}'.format(
            img_folder_name=self.img_path, ann_folder_name=self.ann_path))
        # Strip only the trailing suffix: ``name.split(suffix)[0]`` truncated names
        # containing the suffix elsewhere (e.g. ``a.jpg.b.jpg``). Sorting makes the
        # index -> sample mapping independent of the os.listdir order.
        self.img_names = sorted(
            img_name[:len(img_name) - len(img_suffix)]
            for img_name in os.listdir(self.img_path)
            if img_name.endswith(img_suffix)
        )
        self.img_suffix = img_suffix
        self.ann_suffix = ann_suffix
        self.return_dict = return_dict

    def __getitem__(self, index):
        """Load the image/annotation pair at ``index``.

        The annotation is converted to a ``numpy.ndarray`` after the joint
        transforms run. Returns ``(img, ann)``, or a dict with ``img_name``,
        ``img``, ``ann``, ``img_path`` and ``ann_path`` when ``return_dict`` is set.
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.img_path, img_name + self.img_suffix)
        ann_path = os.path.join(self.ann_path, img_name + self.ann_suffix)

        # NOTE(review): the image is not converted to RGB here (unlike the VOC
        # datasets in this file); confirm downstream transforms handle its mode.
        img = Image.open(img_path)
        ann = Image.open(ann_path)
        if self.transforms is not None:
            img, ann = self.transforms(img, ann)
        ann = np.array(ann)

        if self.return_dict:
            return dict(img_name=img_name, img=img, ann=ann,
                        img_path=img_path, ann_path=ann_path)
        return img, ann

    def __len__(self):
        """Number of images found with ``img_suffix`` in the split folder."""
        return len(self.img_names)
|
|
|
|
|
|
class VOCSemanticDataset(Dataset):
    """PASCAL-VOC style dataset driven by an image-id list file.

    Image ids are read from ``./data/<domain>.txt``, relative to the current
    working directory, with one id per line. ``__getitem__`` returns a list whose
    first element is the image, followed by the id if ``with_id`` is set and by
    the mask (or ``None``) if ``with_mask`` is set.
    """

    def __init__(self, root_dir, domain, transform, with_id=False, with_mask=False):
        """
        :param root_dir: VOC root containing ``JPEGImages/``, ``Annotations/`` and
            ``SegmentationClass/``; a trailing path separator is optional.
        :param domain: id-list name; ids are read from ``./data/<domain>.txt``.
        :param transform: callable applied to the RGB PIL image, or None.
        :param with_id: also return the image id.
        :param with_mask: also return the mask (PIL image, or None if missing).
        """
        super(VOCSemanticDataset, self).__init__()
        self.root_dir = root_dir

        # Joining with a final '' keeps the trailing separator that the '+'
        # concatenations in get_image/get_mask rely on, without requiring
        # root_dir itself to end with '/'.
        self.image_dir = os.path.join(root_dir, 'JPEGImages', '')
        self.xml_dir = os.path.join(root_dir, 'Annotations', '')  # NOTE(review): unused in this class
        self.mask_dir = os.path.join(root_dir, 'SegmentationClass', '')

        # Close the list file deterministically and skip blank lines, which
        # would otherwise turn into the image id ''.
        with open('./data/%s.txt' % domain) as id_file:
            self.image_id_list = [line.strip() for line in id_file if line.strip()]
        self.transform = transform
        self.with_id = with_id
        self.with_mask = with_mask
        # The 20 PASCAL VOC object classes (no background entry).
        self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                            'bus', 'car', 'cat', 'chair', 'cow',
                            'diningtable', 'dog', 'horse', 'motorbike', 'person',
                            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

    def __len__(self):
        """Number of ids in the list file."""
        return len(self.image_id_list)

    def get_image(self, image_id):
        """Load ``<image_dir><image_id>.jpg`` as RGB and apply ``transform`` if set."""
        image = Image.open(self.image_dir + image_id + '.jpg').convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def get_mask(self, image_id):
        """Return the ``SegmentationClass`` PNG for ``image_id``, or None if absent."""
        mask_path = self.mask_dir + image_id + '.png'
        if os.path.isfile(mask_path):
            mask = Image.open(mask_path)
        else:
            mask = None
        return mask

    def __getitem__(self, index):
        """Return ``[image]`` plus the optional id and mask, in that order."""
        image_id = self.image_id_list[index]

        data_list = [self.get_image(image_id)]

        if self.with_id:
            data_list.append(image_id)

        if self.with_mask:
            # NOTE(review): the mask is a PIL image or None; confirm the
            # DataLoader uses a collate_fn that can batch these.
            data_list.append(self.get_mask(image_id))

        return data_list
|
|
|
|
|
|
class TorchVOCSegmentation(VOCSegmentation):
    """torchvision ``VOCSegmentation`` whose target is returned as a numpy array.

    The image is converted to RGB before the joint ``transforms`` are applied.
    """

    def __init__(self, root, year='2012', image_set='train', download=False, transform=None, target_transform=None):
        super().__init__(
            root=root,
            year=year,
            image_set=image_set,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )
        # The 20 PASCAL VOC object classes (no background entry).
        self.class_names = [
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
            'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
        ]

    def __getitem__(self, index: int):
        """Load sample ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation,
            converted to a ``numpy.ndarray`` after the transforms run.
        """
        rgb = Image.open(self.images[index]).convert('RGB')
        seg = Image.open(self.masks[index])

        joint_transform = self.transforms
        if joint_transform is not None:
            rgb, seg = joint_transform(rgb, seg)

        return rgb, np.array(seg)
|
|
|
|
class LettuceSegDataset(Dataset):
    """Binary segmentation dataset built from LabelMe-style JSON polygon files.

    Each image ``<stem>{image_suffix}`` has a sibling annotation
    ``<stem>{label_suffix}`` whose ``"shapes"`` entries carry a ``"label"`` and a
    list of ``"points"``. Polygons labelled ``target_label`` are rasterised with
    value 1 and everything else stays 0.
    """

    def __init__(self,
                 file_list,
                 transform=None,
                 image_suffix=".jpg",
                 label_suffix=".json",
                 width=None,
                 height=None,
                 target_label="weed"):
        """
        :param file_list: sequence of image paths.
        :param transform: stored as ``self.transform`` but not applied in
            ``__getitem__``.
        :param image_suffix: image suffix that is swapped for ``label_suffix`` to
            locate the annotation file.
        :param label_suffix: annotation file suffix.
        :param width: target width; image and mask are resized only when both
            ``width`` and ``height`` are given.
        :param height: target height (see ``width``).
        :param target_label: annotation label painted into the mask as value 1
            (previously hard-coded to "weed").
        """
        super().__init__()
        self.file_list = file_list
        # NOTE(review): transform is never applied by __getitem__; confirm intent.
        self.transform = transform
        self.image_suffix = image_suffix
        self.label_suffix = label_suffix
        self.width = width
        self.height = height
        self.target_label = target_label
        # NOTE(review): the mask only ever holds 0 or 1, yet three class names are
        # declared and index 1 is 'lettuce' while the default foreground label is
        # "weed"; confirm which mapping consumers of class_names expect.
        self.class_names = ['background', 'lettuce', 'weed']

    def __len__(self):
        """Number of image paths in ``file_list``."""
        return len(self.file_list)

    def _label_path(self, image_path):
        """Map an image path to its annotation path.

        Only the trailing ``image_suffix`` is replaced, so directories that
        happen to contain it (``/d/x.jpg_files/a.jpg``) are left intact. Paths
        not ending in the suffix fall back to plain ``str.replace``.
        """
        if image_path.endswith(self.image_suffix):
            return image_path[:len(image_path) - len(self.image_suffix)] + self.label_suffix
        return image_path.replace(self.image_suffix, self.label_suffix)

    def __getitem__(self, idx):
        """Return ``(image, mask, image_path)`` for sample ``idx``.

        ``image`` is a float tensor of shape (3, H, W) with raw 0-255 RGB values
        (not normalised). ``mask`` is a long tensor of shape (1, H, W) holding 1
        inside ``target_label`` polygons and 0 elsewhere. H and W are the file's
        size unless ``width``/``height`` are set.

        :raises FileNotFoundError: if OpenCV cannot read the image.
        """
        image_path = self.file_list[idx]
        json_path = self._label_path(image_path)

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports unreadable/missing files by returning None.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for shape in data.get("shapes", []):
            if shape["label"] != self.target_label:
                continue
            # Point coordinates are truncated to integer pixel positions.
            polygon = np.array(shape["points"], dtype=np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(mask, [polygon], 1)

        if self.width is not None and self.height is not None:
            image = cv2.resize(image, (self.width, self.height))
            # Nearest-neighbour keeps mask values as valid class ids.
            mask = cv2.resize(mask, (self.width, self.height), interpolation=cv2.INTER_NEAREST)

        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
        mask = torch.from_numpy(mask[np.newaxis, ...]).long()
        return image, mask, image_path
|
|
|
|
class construct_LettuceSegDataset(Dataset):
    """Multi-class segmentation dataset built from LabelMe-style JSON polygons.

    Each image ``<stem>{image_suffix}`` has a sibling annotation
    ``<stem>{label_suffix}``. Every polygon whose label is in ``class_to_idx``
    is filled with that class index; unknown labels are ignored. Shapes are
    painted in file order, so later polygons overwrite earlier ones.
    """

    def __init__(self,
                 file_list,
                 transform=None,
                 image_suffix=".jpg",
                 label_suffix=".json",
                 width=None,
                 height=None):
        """
        :param file_list: sequence of image paths.
        :param transform: stored as ``self.transform`` but not applied in
            ``__getitem__``.
        :param image_suffix: image suffix that is swapped for ``label_suffix`` to
            locate the annotation file.
        :param label_suffix: annotation file suffix.
        :param width: target width; image and mask are resized only when both
            ``width`` and ``height`` are given.
        :param height: target height (see ``width``).
        """
        super().__init__()
        self.file_list = file_list
        # NOTE(review): transform is never applied by __getitem__; confirm intent.
        self.transform = transform
        self.image_suffix = image_suffix
        self.label_suffix = label_suffix
        self.width = width
        self.height = height
        # Class list including the 'weed' category; list position == mask value.
        self.class_names = ['background', 'lettuce', 'weed']
        # Label name -> mask value, derived from class_names so they cannot drift.
        self.class_to_idx = {name: idx for idx, name in enumerate(self.class_names)}

    def __len__(self):
        """Number of image paths in ``file_list``."""
        return len(self.file_list)

    def _label_path(self, image_path):
        """Map an image path to its annotation path.

        Only the trailing ``image_suffix`` is replaced, so directories that
        happen to contain it (``/d/x.jpg_files/a.jpg``) are left intact. Paths
        not ending in the suffix fall back to plain ``str.replace``.
        """
        if image_path.endswith(self.image_suffix):
            return image_path[:len(image_path) - len(self.image_suffix)] + self.label_suffix
        return image_path.replace(self.image_suffix, self.label_suffix)

    def __getitem__(self, idx):
        """Return ``(image, mask, image_path)`` for sample ``idx``.

        ``image`` is a float tensor of shape (3, H, W) with raw 0-255 RGB values
        (not normalised). ``mask`` is a long tensor of shape (1, H, W) holding the
        ``class_to_idx`` value of each pixel (0 = background). H and W are the
        file's size unless ``width``/``height`` are set.

        :raises FileNotFoundError: if OpenCV cannot read the image.
        """
        image_path = self.file_list[idx]
        json_path = self._label_path(image_path)

        # Read the image and convert OpenCV's BGR channel order to RGB.
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports unreadable/missing files by returning None.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]

        # Mask starts as all background (class 0).
        mask = np.zeros((h, w), dtype=np.uint8)

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Fill each known-label polygon with its class index.
        for shape in data.get("shapes", []):
            class_idx = self.class_to_idx.get(shape["label"])
            if class_idx is None:
                continue
            # Point coordinates are truncated to integer pixel positions.
            polygon = np.array(shape["points"], dtype=np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(mask, [polygon], class_idx)

        # Optional resize; nearest-neighbour keeps mask values as valid class ids.
        if self.width is not None and self.height is not None:
            image = cv2.resize(image, (self.width, self.height))
            mask = cv2.resize(mask, (self.width, self.height),
                              interpolation=cv2.INTER_NEAREST)

        # Convert to tensors: image HWC -> CHW, mask gains a leading channel axis.
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
        mask = torch.from_numpy(mask[np.newaxis, ...]).long()

        return image, mask, image_path