generated from yuanbiao/python_templates
61 lines
2.8 KiB
Python
61 lines
2.8 KiB
Python
from .detection import BaseDetectionDataset
|
|
from .instance_seg import BaseInstanceDataset
|
|
from .semantic_seg import BaseSemanticDataset, VOCSemanticDataset, TorchVOCSegmentation, LettuceSegDataset, construct_LettuceSegDataset
|
|
from .transforms import get_transforms
|
|
from torchvision.datasets import VOCSegmentation
|
|
from sklearn.model_selection import train_test_split
|
|
import glob
|
|
import os
|
|
# Registry of segmentation dataset constructors, keyed by the ``cfg.name``
# value that ``get_dataset`` looks up. Values are called with
# ``**cfg.params, transform=..., target_transform=...``.
segment_datasets = dict(
    base_ins=BaseInstanceDataset,
    base_sem=BaseSemanticDataset,
    voc_sem=VOCSemanticDataset,
    torch_voc_sem=TorchVOCSegmentation,
    lettuce_sem=construct_LettuceSegDataset,
)

# Registry of detection dataset constructors, keyed by ``cfg.name``.
# Values are called with ``**cfg.params, transform=...``.
det_dataset = dict(
    base_det=BaseDetectionDataset,
)
|
|
def get_lettuce_dataset(
    data_dir='/home/sweetai/large_model/sam_finetune_multi_class/weed_data_bak',
    patterns='*.jpg',
    test_size=0.2,
    random_state=42,
    width=1024,
    height=1024,
):
    """Build train/validation lettuce segmentation datasets from an image folder.

    Images in ``data_dir`` matching ``patterns`` are collected and sorted.
    They are split with ``sklearn.model_selection.train_test_split``, and each
    split is wrapped with ``construct_LettuceSegDataset``. Calling with no
    arguments reproduces the previous hard-coded behaviour.

    Args:
        data_dir: Directory holding the images. Only its top level is matched;
            it is not searched recursively.
        patterns: A single glob pattern, or an iterable of patterns such as
            ``("*.jpg", "*.JPG")`` to pick up several extensions at once.
        test_size: Passed to ``train_test_split``. A float is the fraction of
            images put in the validation split; an int is an absolute count.
        random_state: Seed for the split, so the partition is reproducible.
        width: Forwarded to ``construct_LettuceSegDataset``.
        height: Forwarded to ``construct_LettuceSegDataset``.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple.

    Raises:
        ValueError: If no file in ``data_dir`` matches ``patterns``.
    """
    if isinstance(patterns, str):
        patterns = (patterns,)
    patterns = tuple(patterns)

    # Collect into a set first so a path matched by several patterns (e.g.
    # "*.jpg" and "*.JPG" on a case-insensitive filesystem) is only counted
    # once. Then sort so the split is deterministic.
    all_image_paths = sorted({
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(data_dir, pattern))
    })
    if not all_image_paths:
        # Fail early with an actionable message instead of sklearn's generic
        # "n_samples=0" error.
        raise ValueError(
            f"No images matching {list(patterns)} found in {data_dir!r}"
        )

    train_image_paths, val_image_paths = train_test_split(
        all_image_paths, test_size=test_size, random_state=random_state
    )

    # Report the split sizes (training set / validation set counts).
    print(f"训练集数量: {len(train_image_paths)}")
    print(f"测试集数量: {len(val_image_paths)}")

    train_dataset = construct_LettuceSegDataset(train_image_paths, width=width, height=height)
    val_dataset = construct_LettuceSegDataset(val_image_paths, width=width, height=height)
    return train_dataset, val_dataset
|
|
|
|
def get_dataset(cfg):
    """Instantiate the dataset registered under ``cfg.name``.

    Detection datasets (keys of ``det_dataset``) are called with
    ``**cfg.params`` and ``transform``. Segmentation datasets (keys of
    ``segment_datasets``) additionally receive ``target_transform``, which is
    built from ``cfg.target_transforms``.

    Args:
        cfg: Config object exposing ``name``, ``params`` (a mapping unpacked as
            keyword arguments) and ``transforms``. Segmentation datasets also
            need ``target_transforms``. Both transform fields are passed to
            ``get_transforms``.

    Returns:
        The constructed dataset object.

    Raises:
        AssertionError: If ``cfg.name`` is not a registered dataset name.
    """
    name = cfg.name
    if name not in segment_datasets and name not in det_dataset:
        # Previously this was ``assert cond, print(...)``. ``print`` returns
        # None, so the AssertionError carried no message, and ``assert`` is
        # stripped entirely under ``python -O``. Raising explicitly fixes
        # both problems. AssertionError is kept so existing callers'
        # ``except`` clauses still match.
        raise AssertionError(
            '{name} is not supported, please implement it first.'.format(name=name)
        )

    # TODO customized dataset params:
    # customized dataset params example:
    # if xxx:
    #     param1 = cfg.xxx
    #     param2 = cfg.xxx
    #     return name_dict[name](path, model, param1, param2, ...)

    transform = get_transforms(cfg.transforms)
    if name in det_dataset:
        # Detection datasets take no target transform.
        return det_dataset[name](**cfg.params, transform=transform)

    target_transform = get_transforms(cfg.target_transforms)
    return segment_datasets[name](**cfg.params, transform=transform, target_transform=target_transform)
|
|
|
|
class Iterator:
    """Endlessly cycle over a re-iterable ``loader``.

    Each ``get()`` returns the next item. When the current pass is exhausted,
    a fresh iterator is created from ``loader`` and its first item is
    returned. If ``loader`` is empty, ``StopIteration`` propagates from
    ``get()``.
    """

    def __init__(self, loader):
        self.loader = loader
        self.init()

    def init(self):
        """Start a new pass by creating a fresh iterator over ``self.loader``."""
        self.iterator = iter(self.loader)

    def get(self):
        """Return the next item, restarting from the beginning on exhaustion."""
        exhausted = object()  # private sentinel; the loader can never yield it
        item = next(self.iterator, exhausted)
        if item is exhausted:
            self.init()
            # No default here: an empty loader must still raise StopIteration.
            item = next(self.iterator)
        return item
|