# laser_weeding/datasets/matting.py
#
# 104 lines
# 5.0 KiB
# Python

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import VisionDataset
import numpy as np
class BaseMattingDataset(VisionDataset):
    """
    Folder-based dataset for the matting task.

    If you want to customize a new dataset to train the matting task,
    the img, trimap and annotation files need to be arranged in this structure.
    ├── data
    │ ├── my_dataset
    │ │ ├── img
    │ │ │ ├── train
    │ │ │ │ ├── xxx{img_suffix}
    │ │ │ │ ├── yyy{img_suffix}
    │ │ │ │ ├── zzz{img_suffix}
    │ │ │ ├── val
    │ │ ├── trimap
    │ │ │ ├── train
    │ │ │ │ ├── xxx{trimap_suffix}
    │ │ │ │ ├── yyy{trimap_suffix}
    │ │ │ │ ├── zzz{trimap_suffix}
    │ │ │ ├── val
    │ │ ├── ann
    │ │ │ ├── train
    │ │ │ │ ├── xxx{ann_suffix}
    │ │ │ │ ├── yyy{ann_suffix}
    │ │ │ │ ├── zzz{ann_suffix}
    │ │ │ ├── val

    Samples are discovered by listing the image folder; the annotation (and
    optional trimap) of a sample share the image's base name and differ only
    in folder and suffix.
    """

    def __init__(self, metainfo, dataset_dir, transform, target_transform,
                 trimap_transform=None,
                 image_set='train',
                 img_suffix='.jpg',
                 ann_suffix='.png',
                 trimap_suffix=None,
                 data_prefix: dict = None,
                 return_dict=False):
        '''
        :param metainfo: meta data in original dataset, must contain the key 'class_names'
        :param dataset_dir: the path of your dataset, e.g. data/my_dataset/ by the structure tree above
        :param transform: callable applied to the image (forwarded to VisionDataset)
        :param target_transform: callable applied to the annotation (forwarded to VisionDataset)
        :param trimap_transform: optional callable applied to the trimap; only used when trimap_suffix is set
        :param image_set: 'train' or 'val'
        :param img_suffix: your image suffix
        :param ann_suffix: your annotation suffix
        :param trimap_suffix: your trimap suffix; None (default) disables trimap loading
        :param data_prefix: data folder names, as the tree shows above. Recognised keys:
            'img_path' (default 'img'), 'ann_path' (default 'ann') and
            'trimap_path' (default 'trimap'; the legacy key 'trimap_pth' is also accepted)
        :param return_dict: return dict() or tuple(img, ann[, trimap])
        '''
        super().__init__(root=dataset_dir, transform=transform,
                         target_transform=target_transform)
        self.class_names = metainfo['class_names']

        # Resolve folder names without ever touching the caller's dict; missing
        # keys fall back to the folder names shown in the class docstring.
        prefix = dict(data_prefix) if data_prefix else {}
        img_dir = prefix.get('img_path', 'img')
        ann_dir = prefix.get('ann_path', 'ann')

        self.img_path = os.path.join(dataset_dir, img_dir, image_set)
        self.ann_path = os.path.join(dataset_dir, ann_dir, image_set)
        print('img_folder_name: {img_folder_name}, ann_folder_name: {ann_folder_name}'.format(
            img_folder_name=self.img_path, ann_folder_name=self.ann_path))

        # Strip the suffix from the END of the file name only (a suffix that
        # also occurs in the middle of the name must not truncate it), and sort
        # because os.listdir() returns entries in arbitrary order, which would
        # make the index -> sample mapping non-reproducible.
        stem_len = len(img_suffix)
        self.img_names = sorted(
            file_name[:len(file_name) - stem_len]
            for file_name in os.listdir(self.img_path)
            if file_name.endswith(img_suffix)
        )

        self.has_trimap = trimap_suffix is not None
        self.trimap_path = None
        if self.has_trimap:
            # 'trimap_path' is the documented key; 'trimap_pth' is kept for
            # callers that relied on the historical spelling.
            trimap_dir = prefix.get('trimap_path', prefix.get('trimap_pth', 'trimap'))
            self.trimap_path = os.path.join(dataset_dir, trimap_dir, image_set)
            print('trimap_folder_name: {trimap_folder_name}'.format(trimap_folder_name=self.trimap_path))

        self.img_suffix = img_suffix
        self.ann_suffix = ann_suffix
        # Needed by __getitem__ to build the trimap file name.
        self.trimap_suffix = trimap_suffix
        self.return_dict = return_dict
        self.trimap_transform = trimap_transform

    def __getitem__(self, index):
        '''
        Load the sample at ``index``.

        :param index: position in the sorted list of image names
        :return: without trimap: ``(img, ann)``; with trimap: ``(img, ann, trimap)``.
            When ``return_dict`` is True a dict is returned instead, holding the
            same objects plus 'img_name' and the file path of each loaded item.
            ``ann`` is always converted to a numpy array.
        '''
        name = self.img_names[index]
        img_file = os.path.join(self.img_path, name + self.img_suffix)
        ann_file = os.path.join(self.ann_path, name + self.ann_suffix)

        img = Image.open(img_file)
        ann = Image.open(ann_file)
        # self.transforms is set by VisionDataset from transform/target_transform.
        if self.transforms is not None:
            img, ann = self.transforms(img, ann)
        ann = np.array(ann)

        if not self.has_trimap:
            if self.return_dict:
                return dict(img_name=name, img=img, ann=ann,
                            img_path=img_file, ann_path=ann_file)
            return img, ann

        trimap_file = os.path.join(self.trimap_path, name + self.trimap_suffix)
        trimap = Image.open(trimap_file)
        if self.trimap_transform:
            trimap = self.trimap_transform(trimap)
        else:
            print("Warnning: you may need set transform function for trimap input")

        if self.return_dict:
            return dict(img_name=name, img=img, ann=ann, trimap=trimap,
                        img_path=img_file, ann_path=ann_file,
                        trimap_path=trimap_file)
        return img, ann, trimap

    def __len__(self):
        '''Return the number of images found in the image folder.'''
        return len(self.img_names)