laser_weeding/datasets/__init__.py


import glob
import os

from sklearn.model_selection import train_test_split
from torchvision.datasets import VOCSegmentation

from .detection import BaseDetectionDataset
from .instance_seg import BaseInstanceDataset
from .semantic_seg import BaseSemanticDataset, VOCSemanticDataset, TorchVOCSegmentation, LettuceSegDataset, construct_LettuceSegDataset
from .transforms import get_transforms

segment_datasets = {
    'base_ins': BaseInstanceDataset, 'base_sem': BaseSemanticDataset,
    'voc_sem': VOCSemanticDataset, 'torch_voc_sem': TorchVOCSegmentation,
    'lettuce_sem': construct_LettuceSegDataset,
}
det_dataset = {'base_det': BaseDetectionDataset}


def get_lettuce_dataset():
    """Build train/val lettuce segmentation datasets from a hardcoded image directory."""
    # all_image_paths = sorted(glob.glob(os.path.join('/home/sweetai/large_model/sam_finetune/lettuce_data', "*.JPG")))
    all_image_paths = sorted(glob.glob(os.path.join('/home/sweetai/large_model/sam_finetune_multi_class/weed_data_bak', "*.jpg")))
    # JPG_image_paths = sorted(glob.glob(os.path.join('/home/sweetai/large_model/sam_finetune/lettuce_data', "*.JPG")))
    # jpg_image_paths = sorted(glob.glob(os.path.join('/home/sweetai/large_model/sam_finetune/lettuce_data', "*.jpg")))
    # all_image_paths = JPG_image_paths + jpg_image_paths
    train_image_paths, val_image_paths = train_test_split(
        all_image_paths, test_size=0.2, random_state=42
    )
    print(f"Number of training images: {len(train_image_paths)}")
    print(f"Number of validation images: {len(val_image_paths)}")
    # train_dataset = LettuceSegDataset(train_image_paths, width=1024, height=1024)
    # val_dataset = LettuceSegDataset(val_image_paths, width=1024, height=1024)
    train_dataset = construct_LettuceSegDataset(train_image_paths, width=1024, height=1024)
    val_dataset = construct_LettuceSegDataset(val_image_paths, width=1024, height=1024)
    return train_dataset, val_dataset
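
# Usage sketch (assumes the hardcoded weed_data_bak directory above exists);
# the returned datasets plug straight into torch.utils.data.DataLoader:
#
#   train_ds, val_ds = get_lettuce_dataset()
#   train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2, shuffle=True)
#   val_loader = torch.utils.data.DataLoader(val_ds, batch_size=1, shuffle=False)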


def get_dataset(cfg):
    """Instantiate a dataset from a config exposing `name`, `params`, and transform fields."""
    name = cfg.name
    assert name in segment_datasets or name in det_dataset, \
        '{name} is not supported, please implement it first.'.format(name=name)
    # TODO customized dataset params:
    # customized dataset params example:
    # if xxx:
    #     param1 = cfg.xxx
    #     param2 = cfg.xxx
    #     return name_dict[name](path, model, param1, param2, ...)
    transform = get_transforms(cfg.transforms)
    if name in det_dataset:
        return det_dataset[name](**cfg.params, transform=transform)
    # Segmentation datasets additionally take a transform for the label masks.
    target_transform = get_transforms(cfg.target_transforms)
    return segment_datasets[name](**cfg.params, transform=transform, target_transform=target_transform)
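
# Usage sketch (hypothetical config, not taken from this repo): `cfg` only needs
# attribute access to `name`, `params`, `transforms` and, for segmentation
# datasets, `target_transforms`; the keys inside `params` are illustrative and
# depend on the chosen dataset class. For example, with omegaconf:
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({
#       'name': 'base_sem',
#       'params': {'root': '/path/to/data'},   # assumed constructor kwargs
#       'transforms': [...],                   # whatever get_transforms expects
#       'target_transforms': [...],
#   })
#   dataset = get_dataset(cfg)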


class Iterator:
    """Wraps a data loader so iteration restarts automatically once it is exhausted."""

    def __init__(self, loader):
        self.loader = loader
        self.init()

    def init(self):
        self.iterator = iter(self.loader)

    def get(self):
        # Return the next batch; on StopIteration, rebuild the iterator and retry once.
        try:
            data = next(self.iterator)
        except StopIteration:
            self.init()
            data = next(self.iterator)
        return data
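
# Usage sketch (any iterable works, since Iterator only relies on iter()/next();
# torch.utils.data.DataLoader is the intended case). Handy for step-based
# training loops that run longer than a single epoch:
#
#   batches = Iterator(train_loader)
#   for step in range(num_steps):
#       batch = batches.get()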