diff --git a/config/semantic_seg.yaml b/config/semantic_seg.yaml new file mode 100644 index 0000000..13b6a80 --- /dev/null +++ b/config/semantic_seg.yaml @@ -0,0 +1,95 @@ +train: + experiment_name: 'semantic_sam' + + # Model + model: + sam_name: 'sem_sam' + params: + # Fix the a part of parameters in SAM + fix_img_en: True + fix_prompt_en: True + fix_mask_de: False + ckpt_path: '/home/sweet/trt-finetune-anything/sam_ckpts/sam_vit_b_16.pth' + # class_num: 2 + class_num: 3 # [background, lettuce, weed] [0, 1, 2] + model_type: 'vit_b' # type should be in [vit_h, vit_b, vit_l, default] + + # Dataset + dataset: + name: 'torch_voc_sem' + params: + root: '/data/jinziqi/DATASETS/' + year: '2012' + image_set: 'train' + transforms: + resize: + params: + size: [1024, 1024] + to_tensor: + params: ~ + target_transforms: + resize: + params: + size: [1024, 1024] + + # Losses + losses: + ce: + weight: 0.5 + params: # ~ means None type, the initial params of loss could be identified here + ignore_index: 255 + label_one_hot: False + + # Optimizer + opt_params: + lr_default: 1e-3 + wd_default: 1e-4 + momentum: 0.9 + lr_list: [ 1e-2, ] + group_keys: [ [ 'mask_adapter.decoder_head', ], ] + wd_list: [ 0.0, ] + opt_name: 'sgd' # 'sgd' + scheduler_name: 'cosine' + + # Runner + max_iter: 100000 + log_iter: 20 + eval_iter: 100 + runner_name: 'sem_runner' + # Dataloader + bs: 2 # 8 + num_workers: 2 + drop_last: True + # Logger + use_tensorboard: True + tensorboard_folder: './experiment/tensorboard' + log_folder: './experiment/log' + model_folder: './experiment/model' + +val: + # Dataset + dataset: + name: 'torch_voc_sem' + params: + root: '/data/jinziqi/DATASETS/' + year: '2012' + image_set: 'train' + transforms: + resize: + params: + size: [1024, 1024] + to_tensor: + params: ~ + target_transforms: + resize: + params: + size: [1024, 1024] + + bs: 2 + num_workers: 2 + drop_last: True + + +test: + need_test: False + diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..1624873 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,60 @@ +from .detection import BaseDetectionDataset +from .instance_seg import BaseInstanceDataset +from .semantic_seg import BaseSemanticDataset, VOCSemanticDataset, TorchVOCSegmentation, LettuceSegDataset, construct_LettuceSegDataset +from .transforms import get_transforms +from torchvision.datasets import VOCSegmentation +from sklearn.model_selection import train_test_split +import glob +import os +segment_datasets = {'base_ins': BaseInstanceDataset, 'base_sem': BaseSemanticDataset, + 'voc_sem': VOCSemanticDataset, 'torch_voc_sem': TorchVOCSegmentation, 'lettuce_sem':construct_LettuceSegDataset} +det_dataset = {'base_det': BaseDetectionDataset, } +def get_lettuce_dataset(): + # all_image_paths = sorted(glob.glob(os.path.join('/home/sweetai/large_model/sam_finetune/lettuce_data', "*.JPG"))) + all_image_paths = sorted(glob.glob(os.path.join('/home/sweetai/large_model/sam_finetune_multi_class/weed_data_bak', "*.jpg"))) + # JPG_image_paths = sorted(glob.glob(os.path.join('/home/sweetai/large_model/sam_finetune/lettuce_data', "*.JPG"))) + # jpg_image_paths = sorted(glob.glob(os.path.join('/home/sweetai/large_model/sam_finetune/lettuce_data', "*.jpg"))) + # all_image_paths = JPG_image_paths + jpg_image_paths + train_image_paths, val_image_paths = train_test_split( + all_image_paths, test_size=0.2, random_state=42 + ) + print(f"训练集数量: {len(train_image_paths)}") + print(f"测试集数量: {len(val_image_paths)}") + # train_dataset = LettuceSegDataset(train_image_paths, width=1024, height=1024) + # val_dataset = LettuceSegDataset(val_image_paths, width=1024, height=1024) + train_dataset = construct_LettuceSegDataset(train_image_paths, width=1024, height=1024) + val_dataset = construct_LettuceSegDataset(val_image_paths, width=1024, height=1024) + return train_dataset,val_dataset + +def get_dataset(cfg): + name = cfg.name + assert name in segment_datasets or name in det_dataset, \ + print('{name} is not supported, please implement it first.'.format(name=name)) + # TODO customized dataset params: + # customized dataset params example: + # if xxx: + # param1 = cfg.xxx + # param2 = cfg.xxx + # return name_dict[name](path, model, param1, param2, ...) + transform = get_transforms(cfg.transforms) + if name in det_dataset: + return det_dataset[name](**cfg.params, transform=transform) + target_transform = get_transforms(cfg.target_transforms) + return segment_datasets[name](**cfg.params, transform=transform, target_transform=target_transform) + +class Iterator: + def __init__(self, loader): + self.loader = loader + self.init() + + def init(self): + self.iterator = iter(self.loader) + + def get(self): + try: + data = next(self.iterator) + except StopIteration: + self.init() + data = next(self.iterator) + + return data diff --git a/datasets/detection.py b/datasets/detection.py new file mode 100644 index 0000000..881e3b6 --- /dev/null +++ b/datasets/detection.py @@ -0,0 +1,9 @@ +from torch.utils.data import Dataset + + +class BaseDetectionDataset(Dataset): + def __init__(self): + assert False, print('BaseDetectionDataset is not Unimplemented.') + + def __getitem__(self, item): + pass diff --git a/datasets/instance_seg.py b/datasets/instance_seg.py new file mode 100644 index 0000000..bf86b42 --- /dev/null +++ b/datasets/instance_seg.py @@ -0,0 +1,9 @@ +from torch.utils.data import Dataset + + +class BaseInstanceDataset(Dataset): + def __init__(self): + assert False, print("Unimplement Dataset.") + + def __getitem__(self, item): + pass diff --git a/datasets/matting.py b/datasets/matting.py new file mode 100644 index 0000000..581c7ad --- /dev/null +++ b/datasets/matting.py @@ -0,0 +1,103 @@ +import os +from PIL import Image +from torch.utils.data import Dataset +from torchvision.datasets import VisionDataset +import numpy as np + +class BaseMattingDataset(VisionDataset): + """ + if you want to customize a new dataset to train the matting task, + the img and mask file need be arranged as this sturcture. + ├── data + │ ├── my_dataset + │ │ ├── img + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── trimap + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann + │ │ │ ├── train + │ │ │ │ ├── xxx{ann_suffix} + │ │ │ │ ├── yyy{ann_suffix} + │ │ │ │ ├── zzz{ann_suffix} + │ │ │ ├── val + """ + + def __init__(self, metainfo, dataset_dir, transform, target_transform, + trimap_transform=None, + image_set='train', + img_suffix='.jpg', + ann_suffix='.png', + trimap_suffix=None, + data_prefix: dict = dict(img_path='img', ann_path='ann', trimap_path='trimap_pth'), + return_dict=False): + ''' + + :param metainfo: meta data in original dataset, e.g. class_names + :param dataset_dir: the path of your dataset, e.g. data/my_dataset/ by the stucture tree above + :param image_set: 'train' or 'val' + :param img_suffix: your image suffix + :param ann_suffix: your annotation suffix + :param data_prefix: data folder name, as the tree shows above, the data_prefix of my_dataset: img_path='img' , ann_path='ann' + :param return_dict: return dict() or tuple(img, ann) + ''' + super(BaseMattingDataset, self).__init__(root=dataset_dir, transform=transform, + target_transform=target_transform) + + self.class_names = metainfo['class_names'] + self.img_path = os.path.join(dataset_dir, data_prefix['img_path'], image_set) + self.ann_path = os.path.join(dataset_dir, data_prefix['ann_path'], image_set) + + print('img_folder_name: {img_folder_name}, ann_folder_name: {ann_folder_name}'.format( + img_folder_name=self.img_path, ann_folder_name=self.ann_path)) + self.img_names = [img_name.split(img_suffix)[0] for img_name in os.listdir(self.img_path) if + img_name.endswith(img_suffix)] + + self.has_trimap = trimap_suffix is not None + if self.has_trimap: + self.trimap_path = os.path.join(dataset_dir, data_prefix['trimap_pth'], image_set) + print('trimap_folder_name: {trimap_folder_name}'.format(trimap_folder_name=self.trimap_path)) + self.img_suffix = img_suffix + self.ann_suffix = ann_suffix + self.return_dict = return_dict + self.trimap_transform = trimap_transform + + def __getitem__(self, index): + img = Image.open(os.path.join(self.img_path, self.img_names[index] + self.img_suffix)) + ann = Image.open(os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix)) + if self.transforms is not None: + img, ann = self.transforms(img, ann) + ann = np.array(ann) + if self.has_trimap: + ## return for self.has_trimpa==True + trimap = Image.open(os.path.join(self.trimap_path, self.img_names[index] + self.trimap_suffix)) + if self.trimap_transform: + trimap = self.trimap_transform(trimap) + else: + print("Warnning: you may need set transform function for trimap input") + if self.return_dict: + data = dict(img_name=self.img_names[index], img=img, ann=ann, trimap=trimap, + img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix), + ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix), + trimap_path=os.path.join(self.trimap_path, self.img_names[index] + self.trimap_suffix)) + return data + return img, ann, trimap + else: + ## return for self.has_trimpa==False + if self.return_dict: + data = dict(img_name=self.img_names[index], img=img, ann=ann, + img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix), + ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix)) + return data + return img, ann + + def __len__(self): + return len(self.img_names) + diff --git a/datasets/semantic_seg.py b/datasets/semantic_seg.py new file mode 100644 index 0000000..514b5ac --- /dev/null +++ b/datasets/semantic_seg.py @@ -0,0 +1,265 @@ +import os +from PIL import Image +from torch.utils.data import Dataset +from torchvision.datasets import VOCSegmentation, VisionDataset +import numpy as np +import cv2 +import json +import torch + +class BaseSemanticDataset(VisionDataset): + """ + if you want to customize a new dataset to train the segmentation task, + the img and mask file need be arranged as this sturcture. + ├── data + │ ├── my_dataset + │ │ ├── img + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann + │ │ │ ├── train + │ │ │ │ ├── xxx{ann_suffix} + │ │ │ │ ├── yyy{ann_suffix} + │ │ │ │ ├── zzz{ann_suffix} + │ │ │ ├── val + """ + + def __init__(self, metainfo, dataset_dir, transform, target_transform, + image_set='train', + img_suffix='.jpg', + ann_suffix='.png', + data_prefix: dict = dict(img_path='img', ann_path='ann'), + return_dict=False): + ''' + + :param metainfo: meta data in original dataset, e.g. class_names + :param dataset_dir: the path of your dataset, e.g. data/my_dataset/ by the stucture tree above + :param image_set: 'train' or 'val' + :param img_suffix: your image suffix + :param ann_suffix: your annotation suffix + :param data_prefix: data folder name, as the tree shows above, the data_prefix of my_dataset: img_path='img' , ann_path='ann' + :param return_dict: return dict() or tuple(img, ann) + ''' + super(BaseSemanticDataset, self).__init__(root=dataset_dir, transform=transform, + target_transform=target_transform) + + self.class_names = metainfo['class_names'] + self.img_path = os.path.join(dataset_dir, data_prefix['img_path'], image_set) + self.ann_path = os.path.join(dataset_dir, data_prefix['ann_path'], image_set) + print('img_folder_name: {img_folder_name}, ann_folder_name: {ann_folder_name}'.format( + img_folder_name=self.img_path, ann_folder_name=self.ann_path)) + self.img_names = [img_name.split(img_suffix)[0] for img_name in os.listdir(self.img_path) if + img_name.endswith(img_suffix)] + self.img_suffix = img_suffix + self.ann_suffix = ann_suffix + self.return_dict = return_dict + + def __getitem__(self, index): + img = Image.open(os.path.join(self.img_path, self.img_names[index] + self.img_suffix)) + ann = Image.open(os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix)) + if self.transforms is not None: + img, ann = self.transforms(img, ann) + ann = np.array(ann) + + if self.return_dict: + data = dict(img_name=self.img_names[index], img=img, ann=ann, + img_path=os.path.join(self.img_path, self.img_names[index] + self.img_suffix), + ann_path=os.path.join(self.ann_path, self.img_names[index] + self.ann_suffix)) + return data + return img, ann + + def __len__(self): + return len(self.img_names) + + +class VOCSemanticDataset(Dataset): + def __init__(self, root_dir, domain, transform, with_id=False, with_mask=False): + super(VOCSemanticDataset, self).__init__() + self.root_dir = root_dir + + self.image_dir = self.root_dir + 'JPEGImages/' + self.xml_dir = self.root_dir + 'Annotations/' + self.mask_dir = self.root_dir + 'SegmentationClass/' + + self.image_id_list = [image_id.strip() for image_id in open('./data/%s.txt' % domain).readlines()] + self.transform = transform + self.with_id = with_id + self.with_mask = with_mask + self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', + 'bus', 'car', 'cat', 'chair', 'cow', + 'diningtable', 'dog', 'horse', 'motorbike', 'person', + 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] + + def __len__(self): + return len(self.image_id_list) + + def get_image(self, image_id): + image = Image.open(self.image_dir + image_id + '.jpg').convert('RGB') + if self.transform is not None: + image = self.transform(image) + return image + + def get_mask(self, image_id): + mask_path = self.mask_dir + image_id + '.png' + if os.path.isfile(mask_path): + mask = Image.open(mask_path) + else: + mask = None + return mask + + def __getitem__(self, index): + image_id = self.image_id_list[index] + + data_list = [self.get_image(image_id)] + + if self.with_id: + data_list.append(image_id) + + if self.with_mask: + data_list.append(self.get_mask(image_id)) + + return data_list + + +class TorchVOCSegmentation(VOCSegmentation): + def __init__(self, root, year='2012', image_set='train', download=False, transform=None, target_transform=None): + super(TorchVOCSegmentation, self).__init__(root=root, year=year, image_set=image_set, download=download, + transform=transform, target_transform=target_transform) + self.class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', + 'bus', 'car', 'cat', 'chair', 'cow', + 'diningtable', 'dog', 'horse', 'motorbike', 'person', + 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] + + def __getitem__(self, index: int): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is the image segmentation. + """ + img = Image.open(self.images[index]).convert('RGB') + target = Image.open(self.masks[index]) + + if self.transforms is not None: + img, target = self.transforms(img, target) + + target = np.array(target) + return img, target + +class LettuceSegDataset(Dataset): + def __init__(self, + file_list, + transform=None, + # image_suffix=".JPG", + image_suffix=".jpg", + label_suffix=".json", + width=None, + height=None): + super().__init__() + self.file_list = file_list + self.transform = transform + self.image_suffix = image_suffix + self.label_suffix = label_suffix + self.width = width + self.height = height + # self.class_names = ['background', 'lettuce'] + self.class_names = ['background', 'lettuce', 'weed'] + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, idx): + image_path = self.file_list[idx] + json_path = image_path.replace(self.image_suffix, self.label_suffix) + + image = cv2.imread(image_path, cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + h, w, _ = image.shape + mask = np.zeros((h, w), dtype=np.uint8) + + with open(json_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + for shape in data.get("shapes", []): + label_name = shape["label"] + polygon = np.array(shape["points"], dtype=np.int32).reshape((-1,1,2)) + # if label_name == "lettuce": + if label_name == "weed": + cv2.fillPoly(mask, [polygon], 1) + + if self.width is not None and self.height is not None: + image = cv2.resize(image, (self.width, self.height)) + mask = cv2.resize(mask, (self.width, self.height), interpolation=cv2.INTER_NEAREST) + + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + mask = torch.from_numpy(mask[np.newaxis, ...]).long() + return image, mask, image_path + +class construct_LettuceSegDataset(Dataset): + def __init__(self, + file_list, + transform=None, + # image_suffix=".JPG", + image_suffix=".jpg", + label_suffix=".json", + width=None, + height=None): + super().__init__() + self.file_list = file_list + self.transform = transform + self.image_suffix = image_suffix + self.label_suffix = label_suffix + self.width = width + self.height = height + # 添加 'weed' 类别 + self.class_names = ['background', 'lettuce', 'weed'] + # 定义类别到索引的映射 + self.class_to_idx = { + 'background': 0, + 'lettuce': 1, + 'weed': 2 + } + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, idx): + image_path = self.file_list[idx] + json_path = image_path.replace(self.image_suffix, self.label_suffix) + + # 读取并转换图像 + image = cv2.imread(image_path, cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + h, w, _ = image.shape + + # 创建掩码,初始化为背景类别(0) + mask = np.zeros((h, w), dtype=np.uint8) + + # 读取标注文件 + with open(json_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + # 处理所有标注的形状 + for shape in data.get("shapes", []): + label_name = shape["label"] + polygon = np.array(shape["points"], dtype=np.int32).reshape((-1,1,2)) + + # 根据类别填充不同的值 + if label_name in self.class_to_idx: + cv2.fillPoly(mask, [polygon], self.class_to_idx[label_name]) + + # 如果需要调整大小 + if self.width is not None and self.height is not None: + image = cv2.resize(image, (self.width, self.height)) + mask = cv2.resize(mask, (self.width, self.height), + interpolation=cv2.INTER_NEAREST) + + # 转换为张量 + image = torch.from_numpy(image.transpose(2, 0, 1)).float() + mask = torch.from_numpy(mask[np.newaxis, ...]).long() + + return image, mask, image_path \ No newline at end of file diff --git a/datasets/transforms.py b/datasets/transforms.py new file mode 100644 index 0000000..b3ad489 --- /dev/null +++ b/datasets/transforms.py @@ -0,0 +1,25 @@ +import torchvision.transforms as T +from omegaconf.dictconfig import DictConfig +import torch.nn as nn + +AVIAL_TRANSFORM = {'resize': T.Resize, 'to_tensor': T.ToTensor} + + +def get_transforms(transforms: DictConfig): + T_list = [] + for t_name in transforms.keys(): + assert t_name in AVIAL_TRANSFORM, "{T_name} is not supported transform, please implement it and add it to " \ + "AVIAL_TRANSFORM first.".format(T_name=t_name) + if transforms[t_name].params is not None: + T_list.append(AVIAL_TRANSFORM[t_name](**transforms[t_name].params)) + else: + T_list.append(AVIAL_TRANSFORM[t_name]()) + return T.Compose(T_list) + + +class CustomTransform(nn.Module): + def __init__(self): + pass + + def forward(self): + pass diff --git a/experiment/log/semantic_sam/log_file.txt b/experiment/log/semantic_sam/log_file.txt new file mode 100644 index 0000000..3a93933 --- /dev/null +++ b/experiment/log/semantic_sam/log_file.txt @@ -0,0 +1,312 @@ +iteration : 19, ce : 1.752346932888031, total_loss : 0.8761734664440155, time : 4 +iteration : 39, ce : 0.43678617626428606, total_loss : 0.21839308813214303, time : 4 +iteration : 59, ce : 0.22363422363996505, total_loss : 0.11181711181998252, time : 3 +iteration : 79, ce : 0.1457903351634741, total_loss : 0.07289516758173704, time : 3 +iteration : 99, ce : 0.13958008363842964, total_loss : 0.06979004181921482, time : 3 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 99, mIoU : 86.78554050143299, best_valid_mIoU : 86.78554050143299, time : 60 +iteration : 119, ce : 0.11194387599825859, total_loss : 0.05597193799912929, time : 64 +iteration : 139, ce : 0.13336510248482228, total_loss : 0.06668255124241114, time : 4 +iteration : 159, ce : 0.1312786541879177, total_loss : 0.06563932709395885, time : 3 +iteration : 179, ce : 0.12188598942011594, total_loss : 0.06094299471005797, time : 3 +iteration : 199, ce : 0.11663060411810874, total_loss : 0.05831530205905437, time : 4 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 199, mIoU : 89.69729538718849, best_valid_mIoU : 89.69729538718849, time : 59 +iteration : 219, ce : 0.1029241617769003, total_loss : 0.05146208088845015, time : 63 +iteration : 239, ce : 0.12284142505377531, total_loss : 0.061420712526887654, time : 4 +iteration : 259, ce : 0.11378218494355678, total_loss : 0.05689109247177839, time : 3 +iteration : 279, ce : 0.11434118337929249, total_loss : 0.05717059168964624, time : 4 +iteration : 299, ce : 0.11516949944198132, total_loss : 0.05758474972099066, time : 3 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 299, mIoU : 90.45565191317206, best_valid_mIoU : 90.45565191317206, time : 60 +iteration : 319, ce : 0.11855659522116184, total_loss : 0.05927829761058092, time : 64 +iteration : 339, ce : 0.10724063962697983, total_loss : 0.053620319813489914, time : 4 +iteration : 359, ce : 0.10985201448202134, total_loss : 0.05492600724101067, time : 3 +iteration : 379, ce : 0.09583570621907711, total_loss : 0.047917853109538555, time : 4 +iteration : 399, ce : 0.10013786368072033, total_loss : 0.05006893184036017, time : 4 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 399, mIoU : 90.56240643946504, best_valid_mIoU : 90.56240643946504, time : 62 +iteration : 419, ce : 0.09867587257176638, total_loss : 0.04933793628588319, time : 67 +iteration : 439, ce : 0.10385298412293195, total_loss : 0.051926492061465976, time : 4 +iteration : 459, ce : 0.10420440025627613, total_loss : 0.052102200128138064, time : 4 +iteration : 479, ce : 0.09501391369849443, total_loss : 0.04750695684924722, time : 3 +iteration : 499, ce : 0.08710500337183476, total_loss : 0.04355250168591738, time : 3 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 499, mIoU : 91.15427718021175, best_valid_mIoU : 91.15427718021175, time : 57 +iteration : 519, ce : 0.10585700683295726, total_loss : 0.05292850341647863, time : 61 +iteration : 539, ce : 0.10250990018248558, total_loss : 0.05125495009124279, time : 4 +iteration : 559, ce : 0.10396505333483219, total_loss : 0.051982526667416096, time : 4 +iteration : 579, ce : 0.1050586698576808, total_loss : 0.0525293349288404, time : 3 +iteration : 599, ce : 0.10245818942785263, total_loss : 0.05122909471392632, time : 4 +iteration : 599, mIoU : 89.07530754179987, best_valid_mIoU : 91.15427718021175, time : 56 +iteration : 619, ce : 0.10242737587541342, total_loss : 0.05121368793770671, time : 60 +iteration : 639, ce : 0.11541221737861633, total_loss : 0.057706108689308165, time : 3 +iteration : 659, ce : 0.08750226031988859, total_loss : 0.04375113015994429, time : 4 +iteration : 679, ce : 0.10053982809185982, total_loss : 0.05026991404592991, time : 4 +iteration : 699, ce : 0.10735129471868277, total_loss : 0.05367564735934138, time : 4 +iteration : 699, mIoU : 90.4541198653171, best_valid_mIoU : 91.15427718021175, time : 65 +iteration : 719, ce : 0.08845936376601457, total_loss : 0.044229681883007285, time : 69 +iteration : 739, ce : 0.09841484446078538, total_loss : 0.04920742223039269, time : 4 +iteration : 19, ce : 1.6118955180048942, total_loss : 0.8059477590024471, time : 10 +iteration : 39, ce : 0.6238165870308876, total_loss : 0.3119082935154438, time : 10 +iteration : 59, ce : 0.5466279983520508, total_loss : 0.2733139991760254, time : 10 +iteration : 79, ce : 0.3118415541946888, total_loss : 0.1559207770973444, time : 10 +iteration : 99, ce : 0.22329067587852477, total_loss : 0.11164533793926239, time : 10 +iteration : 19, ce : 1.6118955209851265, total_loss : 0.8059477604925632, time : 10 +iteration : 39, ce : 0.6238171197474003, total_loss : 0.31190855987370014, time : 10 +iteration : 59, ce : 0.546618615090847, total_loss : 0.2733093075454235, time : 10 +iteration : 79, ce : 0.31183895096182823, total_loss : 0.15591947548091412, time : 10 +iteration : 99, ce : 0.22327864803373815, total_loss : 0.11163932401686907, time : 10 +iteration : 19, ce : 1.6118965715169906, total_loss : 0.8059482857584953, time : 10 +iteration : 39, ce : 0.6238209880888462, total_loss : 0.3119104940444231, time : 10 +iteration : 59, ce : 0.5466369971632957, total_loss : 0.27331849858164786, time : 10 +iteration : 79, ce : 0.31184642761945724, total_loss : 0.15592321380972862, time : 10 +iteration : 99, ce : 0.22328564003109933, total_loss : 0.11164282001554966, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 99, mIoU : 76.15534017940087, best_valid_mIoU : 76.15534017940087, time : 43 +iteration : 119, ce : 0.1746383562684059, total_loss : 0.08731917813420295, time : 53 +iteration : 139, ce : 0.1467339999973774, total_loss : 0.0733669999986887, time : 10 +iteration : 159, ce : 0.12853854857385158, total_loss : 0.06426927428692579, time : 10 +iteration : 179, ce : 0.12929687201976775, total_loss : 0.06464843600988388, time : 10 +iteration : 199, ce : 0.12117353715002536, total_loss : 0.06058676857501268, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 199, mIoU : 89.77212887720785, best_valid_mIoU : 89.77212887720785, time : 44 +iteration : 219, ce : 0.12305688261985778, total_loss : 0.06152844130992889, time : 54 +iteration : 239, ce : 0.11886226013302803, total_loss : 0.059431130066514015, time : 10 +iteration : 259, ce : 0.13031740970909594, total_loss : 0.06515870485454797, time : 10 +iteration : 279, ce : 0.1261220879852772, total_loss : 0.0630610439926386, time : 10 +iteration : 299, ce : 0.11300399377942086, total_loss : 0.05650199688971043, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 299, mIoU : 90.50859210757906, best_valid_mIoU : 90.50859210757906, time : 44 +iteration : 319, ce : 0.13202827107161283, total_loss : 0.06601413553580641, time : 54 +iteration : 339, ce : 0.10633355155587196, total_loss : 0.05316677577793598, time : 10 +iteration : 359, ce : 0.11914260871708393, total_loss : 0.059571304358541965, time : 10 +iteration : 379, ce : 0.10447845719754696, total_loss : 0.05223922859877348, time : 10 +iteration : 399, ce : 0.10292214751243592, total_loss : 0.05146107375621796, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 399, mIoU : 91.10752328387134, best_valid_mIoU : 91.10752328387134, time : 44 +iteration : 419, ce : 0.11132022961974145, total_loss : 0.05566011480987072, time : 54 +iteration : 439, ce : 0.11224379669874907, total_loss : 0.05612189834937453, time : 10 +iteration : 459, ce : 0.09896511361002922, total_loss : 0.04948255680501461, time : 10 +iteration : 479, ce : 0.09913789071142673, total_loss : 0.049568945355713365, time : 10 +iteration : 499, ce : 0.1061447437852621, total_loss : 0.05307237189263105, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 499, mIoU : 91.4555279922262, best_valid_mIoU : 91.4555279922262, time : 44 +iteration : 519, ce : 0.12007921487092972, total_loss : 0.06003960743546486, time : 54 +iteration : 539, ce : 0.10178584884852171, total_loss : 0.050892924424260855, time : 10 +iteration : 559, ce : 0.11588475815951824, total_loss : 0.05794237907975912, time : 10 +iteration : 579, ce : 0.09687731992453337, total_loss : 0.048438659962266685, time : 10 +iteration : 599, ce : 0.10488986857235431, total_loss : 0.05244493428617716, time : 10 +iteration : 599, mIoU : 91.27345932141915, best_valid_mIoU : 91.4555279922262, time : 43 +iteration : 619, ce : 0.10749562252312898, total_loss : 0.05374781126156449, time : 53 +iteration : 639, ce : 0.12049341723322868, total_loss : 0.06024670861661434, time : 10 +iteration : 659, ce : 0.1019530326128006, total_loss : 0.0509765163064003, time : 10 +iteration : 679, ce : 0.09267976079136134, total_loss : 0.04633988039568067, time : 10 +iteration : 699, ce : 0.10727790277451277, total_loss : 0.053638951387256384, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 699, mIoU : 91.86473461538597, best_valid_mIoU : 91.86473461538597, time : 44 +iteration : 719, ce : 0.08669246602803468, total_loss : 0.04334623301401734, time : 54 +iteration : 739, ce : 0.11301723029464483, total_loss : 0.056508615147322416, time : 10 +iteration : 759, ce : 0.09895374476909638, total_loss : 0.04947687238454819, time : 10 +iteration : 779, ce : 0.10607226043939591, total_loss : 0.053036130219697955, time : 10 +iteration : 799, ce : 0.09932096153497696, total_loss : 0.04966048076748848, time : 10 +iteration : 799, mIoU : 90.76968357127359, best_valid_mIoU : 91.86473461538597, time : 44 +iteration : 819, ce : 0.09316363576799631, total_loss : 0.04658181788399816, time : 54 +iteration : 839, ce : 0.10023958738893271, total_loss : 0.050119793694466355, time : 10 +iteration : 859, ce : 0.09565037228167057, total_loss : 0.04782518614083529, time : 10 +iteration : 879, ce : 0.10756326355040073, total_loss : 0.053781631775200366, time : 10 +iteration : 19, ce : 1.4597653925418854, total_loss : 0.7298826962709427, time : 10 +iteration : 39, ce : 0.7552479837089777, total_loss : 0.37762399185448886, time : 10 +iteration : 59, ce : 0.5629064556211233, total_loss : 0.28145322781056165, time : 10 +iteration : 79, ce : 0.3730143416672945, total_loss : 0.18650717083364726, time : 10 +iteration : 99, ce : 0.2633991166949272, total_loss : 0.1316995583474636, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 99, mIoU : 45.60033837143256, best_valid_mIoU : 45.60033837143256, time : 41 +iteration : 119, ce : 0.21824998827651143, total_loss : 0.10912499413825572, time : 52 +iteration : 139, ce : 0.24961501657962798, total_loss : 0.12480750828981399, time : 10 +iteration : 159, ce : 0.24858376793563366, total_loss : 0.12429188396781683, time : 10 +iteration : 179, ce : 0.19755737725645303, total_loss : 0.09877868862822652, time : 10 +iteration : 199, ce : 0.1627231553196907, total_loss : 0.08136157765984535, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 199, mIoU : 68.10482110317365, best_valid_mIoU : 68.10482110317365, time : 42 +iteration : 219, ce : 0.11887449622154236, total_loss : 0.05943724811077118, time : 52 +iteration : 239, ce : 0.08812281191349029, total_loss : 0.044061405956745146, time : 10 +iteration : 259, ce : 0.08320211656391621, total_loss : 0.041601058281958106, time : 10 +iteration : 279, ce : 0.0784364627674222, total_loss : 0.0392182313837111, time : 10 +iteration : 299, ce : 0.08906380720436573, total_loss : 0.044531903602182864, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 299, mIoU : 85.81882754923308, best_valid_mIoU : 85.81882754923308, time : 42 +iteration : 319, ce : 0.06990350810810923, total_loss : 0.034951754054054616, time : 52 +iteration : 339, ce : 0.06432983251288533, total_loss : 0.03216491625644267, time : 10 +iteration : 359, ce : 0.07002460584044456, total_loss : 0.03501230292022228, time : 10 +iteration : 379, ce : 0.0986309826374054, total_loss : 0.0493154913187027, time : 10 +iteration : 399, ce : 0.10583090535365045, total_loss : 0.05291545267682522, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 399, mIoU : 88.4803214947598, best_valid_mIoU : 88.4803214947598, time : 42 +iteration : 419, ce : 0.06240550028160215, total_loss : 0.031202750140801074, time : 52 +iteration : 439, ce : 0.08079780722036958, total_loss : 0.04039890361018479, time : 10 +iteration : 459, ce : 0.05519880400970578, total_loss : 0.02759940200485289, time : 10 +iteration : 479, ce : 0.05341993579640984, total_loss : 0.02670996789820492, time : 10 +iteration : 499, ce : 0.06429818458855152, total_loss : 0.03214909229427576, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 499, mIoU : 89.647283352165, best_valid_mIoU : 89.647283352165, time : 42 +iteration : 519, ce : 0.05206382665783167, total_loss : 0.026031913328915836, time : 53 +iteration : 539, ce : 0.07656390639021993, total_loss : 0.03828195319510996, time : 10 +iteration : 559, ce : 0.06323499651625752, total_loss : 0.03161749825812876, time : 10 +iteration : 579, ce : 0.05692016114480793, total_loss : 0.028460080572403967, time : 10 +iteration : 599, ce : 0.06588180274702608, total_loss : 0.03294090137351304, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 599, mIoU : 89.88607999691865, best_valid_mIoU : 89.88607999691865, time : 42 +iteration : 619, ce : 0.0627711565233767, total_loss : 0.03138557826168835, time : 52 +iteration : 639, ce : 0.04458166812546551, total_loss : 0.022290834062732755, time : 10 +iteration : 659, ce : 0.05658222446218133, total_loss : 0.028291112231090664, time : 10 +iteration : 679, ce : 0.04089747462421656, total_loss : 0.02044873731210828, time : 10 +iteration : 699, ce : 0.07494942974299193, total_loss : 0.037474714871495965, time : 10 +iteration : 699, mIoU : 87.9944159611403, best_valid_mIoU : 89.88607999691865, time : 42 +iteration : 719, ce : 0.06946341348811984, total_loss : 0.03473170674405992, time : 52 +iteration : 739, ce : 0.09376809406094253, total_loss : 0.04688404703047126, time : 10 +iteration : 759, ce : 0.06281587863340973, total_loss : 0.031407939316704866, time : 10 +iteration : 779, ce : 0.049504976719617844, total_loss : 0.024752488359808922, time : 10 +iteration : 799, ce : 0.06230988763272762, total_loss : 0.03115494381636381, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 799, mIoU : 90.37318020387082, best_valid_mIoU : 90.37318020387082, time : 42 +iteration : 819, ce : 0.06486173206940293, total_loss : 0.03243086603470147, time : 53 +iteration : 839, ce : 0.05320575626101345, total_loss : 0.026602878130506723, time : 10 +iteration : 859, ce : 0.05594585915096104, total_loss : 0.02797292957548052, time : 10 +iteration : 879, ce : 0.04406192186288536, total_loss : 0.02203096093144268, time : 10 +iteration : 899, ce : 0.05902999769896269, total_loss : 0.029514998849481344, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 899, mIoU : 90.77366247407701, best_valid_mIoU : 90.77366247407701, time : 42 +iteration : 919, ce : 0.05046119377948344, total_loss : 0.02523059688974172, time : 53 +iteration : 939, ce : 0.055241983570158484, total_loss : 0.027620991785079242, time : 10 +iteration : 959, ce : 0.06541968509554863, total_loss : 0.032709842547774315, time : 10 +iteration : 979, ce : 0.056352639896795155, total_loss : 0.028176319948397578, time : 10 +iteration : 999, ce : 0.04117121635936201, total_loss : 0.020585608179681004, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 999, mIoU : 90.95911907670035, best_valid_mIoU : 90.95911907670035, time : 43 +iteration : 1019, ce : 0.04268322917632759, total_loss : 0.021341614588163794, time : 53 +iteration : 1039, ce : 0.07534589348360896, total_loss : 0.03767294674180448, time : 10 +iteration : 1059, ce : 0.06266294419765472, total_loss : 0.03133147209882736, time : 10 +iteration : 1079, ce : 0.040896324906498194, total_loss : 0.020448162453249097, time : 10 +iteration : 1099, ce : 0.05627818256616592, total_loss : 0.02813909128308296, time : 10 +iteration : 1099, mIoU : 90.65854459999014, best_valid_mIoU : 90.95911907670035, time : 42 +iteration : 1119, ce : 0.05832021026872099, total_loss : 0.029160105134360494, time : 52 +iteration : 1139, ce : 0.05653570280410349, total_loss : 0.028267851402051746, time : 10 +iteration : 1159, ce : 0.0540118848439306, total_loss : 0.0270059424219653, time : 10 +iteration : 1179, ce : 0.059156589978374544, total_loss : 0.029578294989187272, time : 10 +iteration : 1199, ce : 0.0586970953270793, total_loss : 0.02934854766353965, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 1199, mIoU : 91.30907206844861, best_valid_mIoU : 91.30907206844861, time : 42 +iteration : 1219, ce : 0.046731388312764466, total_loss : 0.023365694156382233, time : 52 +iteration : 1239, ce : 0.04605461307801306, total_loss : 0.02302730653900653, time : 10 +iteration : 1259, ce : 0.05333511717617512, total_loss : 0.02666755858808756, time : 10 +iteration : 1279, ce : 0.058591234497725964, total_loss : 0.029295617248862982, time : 10 +iteration : 1299, ce : 0.044012406887486574, total_loss : 0.022006203443743287, time : 10 +iteration : 1299, mIoU : 90.44798195565556, best_valid_mIoU : 91.30907206844861, time : 42 +iteration : 1319, ce : 0.03853592430241406, total_loss : 0.01926796215120703, time : 52 +iteration : 1339, ce : 0.04643560294061899, total_loss : 0.023217801470309496, time : 10 +iteration : 1359, ce : 0.05803217897191644, total_loss : 0.02901608948595822, time : 10 +iteration : 1379, ce : 0.06334102130495012, total_loss : 0.03167051065247506, time : 10 +iteration : 1399, ce : 0.08214310212060809, total_loss : 0.041071551060304044, time : 10 +iteration : 1399, mIoU : 90.18570528496528, best_valid_mIoU : 91.30907206844861, time : 42 +iteration : 1419, ce : 0.043989807507023214, total_loss : 0.021994903753511607, time : 52 +iteration : 1439, ce : 0.05312715098261833, total_loss : 0.026563575491309166, time : 10 +iteration : 1459, ce : 0.05344270861241966, total_loss : 0.02672135430620983, time : 10 +iteration : 1479, ce : 0.04879952352494001, total_loss : 0.024399761762470006, time : 10 +iteration : 1499, ce : 0.05729071167297661, total_loss : 0.028645355836488307, time : 10 +iteration : 1499, mIoU : 91.04520387661324, best_valid_mIoU : 91.30907206844861, time : 41 +iteration : 1519, ce : 0.03750903834588826, total_loss : 0.01875451917294413, time : 52 +iteration : 1539, ce : 0.04227787498384714, total_loss : 0.02113893749192357, time : 10 +iteration : 1559, ce : 0.043323819525539875, total_loss : 0.021661909762769938, time : 10 +iteration : 1579, ce : 0.039240577118471266, total_loss : 0.019620288559235633, time : 10 +iteration : 1599, ce : 0.05065902634523809, total_loss : 0.025329513172619045, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 1599, mIoU : 91.3858830174658, best_valid_mIoU : 91.3858830174658, time : 42 +iteration : 1619, ce : 0.042945701791904864, total_loss : 0.021472850895952432, time : 53 +iteration : 19, ce : 0.8456825375556946, total_loss : 0.4228412687778473, time : 10 +iteration : 39, ce : 0.4564362831413746, total_loss : 0.2282181415706873, time : 10 +iteration : 59, ce : 0.5016216538846493, total_loss : 0.25081082694232465, time : 10 +iteration : 79, ce : 0.1747898418456316, total_loss : 0.0873949209228158, time : 10 +iteration : 99, ce : 0.17478084242902697, total_loss : 0.08739042121451349, time : 10 +iteration : 19, ce : 0.8456787191331386, total_loss : 0.4228393595665693, time : 10 +iteration : 39, ce : 0.45643005296587946, total_loss : 0.22821502648293973, time : 10 +iteration : 59, ce : 0.5015822313725948, total_loss : 0.2507911156862974, time : 10 +iteration : 79, ce : 0.1747533490881324, total_loss : 0.0873766745440662, time : 10 +iteration : 99, ce : 0.17461899896152316, total_loss : 0.08730949948076158, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 99, mIoU : 70.67741422260869, best_valid_mIoU : 70.67741422260869, time : 48 +iteration : 119, ce : 0.17073935624212028, total_loss : 0.08536967812106014, time : 58 +iteration : 139, ce : 0.13068198719993235, total_loss : 0.06534099359996617, time : 10 +iteration : 159, ce : 0.08914234582334757, total_loss : 0.044571172911673784, time : 10 +iteration : 179, ce : 0.1461833517998457, total_loss : 0.07309167589992285, time : 10 +iteration : 199, ce : 0.11983967162668704, total_loss : 0.05991983581334352, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 199, mIoU : 71.61286606894889, best_valid_mIoU : 71.61286606894889, time : 48 +iteration : 219, ce : 0.11703812740743161, total_loss : 0.058519063703715804, time : 58 +iteration : 239, ce : 0.126152902841568, total_loss : 0.063076451420784, time : 10 +iteration : 259, ce : 0.11562294252216816, total_loss : 0.05781147126108408, time : 10 +iteration : 279, ce : 0.09897459410130978, total_loss : 0.04948729705065489, time : 10 +iteration : 299, ce : 0.12509905751794576, total_loss : 0.06254952875897288, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 299, mIoU : 80.60803678405016, best_valid_mIoU : 80.60803678405016, time : 49 +iteration : 319, ce : 0.09828957775607705, total_loss : 0.049144788878038526, time : 59 +iteration : 339, ce : 0.08943888759240508, total_loss : 0.04471944379620254, time : 10 +iteration : 359, ce : 0.06585264699533581, total_loss : 0.03292632349766791, time : 10 +iteration : 379, ce : 0.09908102322369813, total_loss : 0.04954051161184907, time : 10 +iteration : 399, ce : 0.06755148817319423, total_loss : 0.033775744086597115, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 399, mIoU : 83.12914567271814, best_valid_mIoU : 83.12914567271814, time : 48 +iteration : 419, ce : 0.07650328939780593, total_loss : 0.038251644698902965, time : 59 +iteration : 439, ce : 0.12737306039780377, total_loss : 0.06368653019890189, time : 10 +iteration : 459, ce : 0.09591937027871608, total_loss : 0.04795968513935804, time : 10 +iteration : 479, ce : 0.10537517564371228, total_loss : 0.05268758782185614, time : 10 +iteration : 499, ce : 0.08678049889858812, total_loss : 0.04339024944929406, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 499, mIoU : 86.71075383127464, best_valid_mIoU : 86.71075383127464, time : 48 +iteration : 519, ce : 0.065172965452075, total_loss : 0.0325864827260375, time : 58 +iteration : 539, ce : 0.07847554692998529, total_loss : 0.039237773464992645, time : 10 +iteration : 559, ce : 0.11086338181048631, total_loss : 0.05543169090524316, time : 10 +iteration : 579, ce : 0.11131466701626777, total_loss : 0.055657333508133885, time : 10 +iteration : 599, ce : 0.09892227221280336, total_loss : 0.04946113610640168, time : 10 +iteration : 599, mIoU : 83.43759111648548, best_valid_mIoU : 86.71075383127464, time : 49 +iteration : 619, ce : 0.060881080804392695, total_loss : 0.030440540402196348, time : 59 +iteration : 639, ce : 0.06826045289635659, total_loss : 0.03413022644817829, time : 10 +iteration : 659, ce : 0.08870259951800108, total_loss : 0.04435129975900054, time : 10 +iteration : 679, ce : 0.11652187979780138, total_loss : 0.05826093989890069, time : 10 +iteration : 699, ce : 0.07042193752713502, total_loss : 0.03521096876356751, time : 10 +iteration : 699, mIoU : 83.23095958793452, best_valid_mIoU : 86.71075383127464, time : 48 +iteration : 719, ce : 0.0663936838041991, total_loss : 0.03319684190209955, time : 58 +iteration : 739, ce : 0.06597791106905788, total_loss : 0.03298895553452894, time : 10 +iteration : 759, ce : 0.06343856947496533, total_loss : 0.031719284737482666, time : 10 +iteration : 779, ce : 0.09711240408942104, total_loss : 0.04855620204471052, time : 10 +iteration : 799, ce : 0.07680428037419915, total_loss : 0.03840214018709957, time : 10 +iteration : 799, mIoU : 81.43817219158842, best_valid_mIoU : 86.71075383127464, time : 48 +iteration : 819, ce : 0.07191853327676653, total_loss : 0.03595926663838327, time : 58 +iteration : 839, ce : 0.08352819001302123, total_loss : 0.04176409500651061, time : 10 +iteration : 859, ce : 0.07599039357155561, total_loss : 0.037995196785777806, time : 10 +iteration : 879, ce : 0.10239242473617197, total_loss : 0.05119621236808598, time : 10 +iteration : 899, ce : 0.07294631809927524, total_loss : 0.03647315904963762, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 899, mIoU : 88.7458099910793, best_valid_mIoU : 88.7458099910793, time : 48 +iteration : 919, ce : 0.0750368535052985, total_loss : 0.03751842675264925, time : 59 +iteration : 939, ce : 0.07469065655022859, total_loss : 0.037345328275114296, time : 10 +iteration : 959, ce : 0.09910964691080153, total_loss : 0.049554823455400764, time : 10 +iteration : 979, ce : 0.07515906286425889, total_loss : 0.03757953143212944, time : 10 +iteration : 999, ce : 0.04880384765565395, total_loss : 0.024401923827826976, time : 10 +iteration : 999, mIoU : 87.94771838428038, best_valid_mIoU : 88.7458099910793, time : 48 +iteration : 1019, ce : 0.0596143594942987, total_loss : 0.02980717974714935, time : 59 +iteration : 1039, ce : 0.07017137254588306, total_loss : 0.03508568627294153, time : 10 +iteration : 1059, ce : 0.04953117328695953, total_loss : 0.024765586643479765, time : 10 +iteration : 1079, ce : 0.06962326178327202, total_loss : 0.03481163089163601, time : 10 +iteration : 1099, ce : 0.1142594498116523, total_loss : 0.05712972490582615, time : 10 +saved model in ./experiment/model/semantic_sam/model.pth +iteration : 1099, mIoU : 90.2042178204663, best_valid_mIoU : 90.2042178204663, time : 49 +iteration : 1119, ce : 0.06606756038963794, total_loss : 0.03303378019481897, time : 59 +iteration : 1139, ce : 0.08854892421513796, total_loss : 0.04427446210756898, time : 10 +iteration : 1159, ce : 0.07104225680232049, total_loss : 0.03552112840116024, time : 10 +iteration : 1179, ce : 0.08063898030668497, total_loss : 0.040319490153342484, time : 10 +iteration : 1199, ce : 0.06300937542691827, total_loss : 0.031504687713459135, time : 10 +iteration : 1199, mIoU : 87.81258921438592, best_valid_mIoU : 90.2042178204663, time : 48 +iteration : 1219, ce : 0.06756734945811331, total_loss : 0.033783674729056655, time : 58 +iteration : 1239, ce : 0.08060045670717955, total_loss : 0.040300228353589776, time : 10 +iteration : 1259, ce : 0.06920733847655355, total_loss : 0.03460366923827678, time : 10 diff --git a/extend_sam/__init__.py b/extend_sam/__init__.py new file mode 100644 index 0000000..feaf8dd --- /dev/null +++ b/extend_sam/__init__.py @@ -0,0 +1,127 @@ +# copyright ziqi-jin +import torch +from .extend_sam import BaseExtendSam, SemanticSam +from .runner import BaseRunner, SemRunner +# from .optimizer import BaseOptimizer +from .scheduler import WarmupMultiStepLR +from .utils import get_opt_pamams + +AVAI_SCH = ["single_step", "multi_step", "warmup_multi_step", "cosine", "linear"] +AVAI_MODEL = {'base_sam': BaseExtendSam, 'sem_sam': SemanticSam} +# AVAI_OPT = {'base_opt': BaseOptimizer, 'sgd': torch.optim.SGD, 'adam': torch.optim.Adam} +AVAI_OPT = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW} +AVAI_RUNNER = {'base_runner': BaseRunner, 'sem_runner': SemRunner} + + +def get_model(model_name, **kwargs): + if model_name not in AVAI_MODEL: + print('not supported model name, please implement it first.') + return AVAI_MODEL[model_name](**kwargs).cuda() + + +def get_optimizer(opt_name, **kwargs): + if opt_name not in AVAI_OPT: + print('not supported optimizer name, please implement it first.') + return AVAI_OPT[opt_name](**{k: v for k, v in kwargs.items() if v is not None}) + + +def get_runner(runner_name): + if runner_name not in AVAI_RUNNER: + print('not supported runner name, please implement it first.') + return AVAI_RUNNER[runner_name] + + +def get_scheduler( + optimizer, + lr_scheduler="single_step", + stepsize=1, + gamma=0.1, + warmup_factor=0.01, + warmup_steps=10, + max_epoch=1, + n_epochs_init=50, + n_epochs_decay=50, + +): + """A function wrapper for building a learning rate scheduler. + Args: + optimizer (Optimizer): an Optimizer. + lr_scheduler (str, optional): learning rate scheduler method. Default is + single_step. + stepsize (int or list, optional): step size to decay learning rate. + When ``lr_scheduler`` is "single_step", ``stepsize`` should be an integer. + When ``lr_scheduler`` is "multi_step", ``stepsize`` is a list. Default is 1. + gamma (float, optional): decay rate. Default is 0.1. + max_epoch (int, optional): maximum epoch (for cosine annealing). Default is 1. + Examples:: + >>> # Decay learning rate by every 20 epochs. + >>> scheduler = get_scheduler( + >>> optimizer, lr_scheduler='single_step', stepsize=20 + >>> ) + >>> # Decay learning rate at 30, 50 and 55 epochs. + >>> scheduler = get_scheduler( + >>> optimizer, lr_scheduler='multi_step', stepsize=[30, 50, 55] + >>> ) + """ + if lr_scheduler not in AVAI_SCH: + raise ValueError( + "Unsupported scheduler: {}. Must be one of {}".format( + lr_scheduler, AVAI_SCH + ) + ) + + if lr_scheduler == "single_step": + if isinstance(stepsize, list): + stepsize = stepsize[-1] + + if not isinstance(stepsize, int): + raise TypeError( + "For single_step lr_scheduler, stepsize must " + "be an integer, but got {}".format(type(stepsize)) + ) + + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, step_size=stepsize, gamma=gamma + ) + + elif lr_scheduler == "multi_step": + if not isinstance(stepsize, list): + raise TypeError( + "For multi_step lr_scheduler, stepsize must " + "be a list, but got {}".format(type(stepsize)) + ) + + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=stepsize, gamma=gamma + ) + + elif lr_scheduler == "warmup_multi_step": + if not isinstance(stepsize, list): + raise TypeError( + "For warmup multi_step lr_scheduler, stepsize must " + "be a list, but got {}".format(type(stepsize)) + ) + + scheduler = WarmupMultiStepLR( + optimizer, + milestones=stepsize, + gamma=gamma, + warmup_factor=warmup_factor, + warmup_iters=warmup_steps, + ) + + elif lr_scheduler == "cosine": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, int(max_epoch) + ) + + elif lr_scheduler == "linear": + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch - n_epochs_init) / float(n_epochs_decay + 1) + return lr_l + + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=lambda_rule + ) + + return scheduler diff --git a/extend_sam/extend_sam.py b/extend_sam/extend_sam.py new file mode 100644 index 0000000..9a4cfde --- /dev/null +++ b/extend_sam/extend_sam.py @@ -0,0 +1,48 @@ +# copyright ziqi-jin +import torch +import torch.nn as nn +from .segment_anything_ori import sam_model_registry +from .image_encoder_adapter import BaseImgEncodeAdapter +from .mask_decoder_adapter import BaseMaskDecoderAdapter, SemMaskDecoderAdapter +from .prompt_encoder_adapter import BasePromptEncodeAdapter + + +class BaseExtendSam(nn.Module): + + def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, model_type='vit_b'): + super(BaseExtendSam, self).__init__() + assert model_type in ['default', 'vit_b', 'vit_l', 'vit_h'], print( + "Wrong model_type, SAM only can be built as vit_b, vot_l, vit_h and default ") + self.ori_sam = sam_model_registry[model_type](ckpt_path) + self.img_adapter = BaseImgEncodeAdapter(self.ori_sam, fix=fix_img_en) + self.prompt_adapter = BasePromptEncodeAdapter(self.ori_sam, fix=fix_prompt_en) + self.mask_adapter = BaseMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de) + + def forward(self, img): + x = self.img_adapter(img) + points = None + boxes = None + masks = None + + sparse_embeddings, dense_embeddings = self.prompt_adapter( + points=points, + boxes=boxes, + masks=masks, + ) + multimask_output = True + low_res_masks, iou_predictions = self.mask_adapter( + image_embeddings=x, + prompt_adapter=self.prompt_adapter, + sparse_embeddings=sparse_embeddings, + dense_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + return low_res_masks, iou_predictions + + +class SemanticSam(BaseExtendSam): + + def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, class_num=20, model_type='vit_b'): + super().__init__(ckpt_path=ckpt_path, fix_img_en=fix_img_en, fix_prompt_en=fix_prompt_en, + fix_mask_de=fix_mask_de, model_type=model_type) + self.mask_adapter = SemMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de, class_num=class_num) diff --git a/extend_sam/image_encoder_adapter.py b/extend_sam/image_encoder_adapter.py new file mode 100644 index 0000000..5f4e90b --- /dev/null +++ b/extend_sam/image_encoder_adapter.py @@ -0,0 +1,16 @@ +import torch.nn as nn +from .segment_anything_ori.modeling.sam import Sam +from .utils import fix_params + + +class BaseImgEncodeAdapter(nn.Module): + + def __init__(self, ori_sam: Sam, fix=False): + super(BaseImgEncodeAdapter, self).__init__() + self.sam_img_encoder = ori_sam.image_encoder + if fix: + fix_params(self.sam_img_encoder) + + def forward(self, x): + x = self.sam_img_encoder(x) + return x diff --git a/extend_sam/mask_decoder_adapter.py b/extend_sam/mask_decoder_adapter.py new file mode 100644 index 0000000..add328c --- /dev/null +++ b/extend_sam/mask_decoder_adapter.py @@ -0,0 +1,97 @@ +# @copyright ziqi-jin + +import torch.nn as nn +import torch +from .segment_anything_ori.modeling.sam import Sam +from .utils import fix_params +from .segment_anything_ori.modeling.mask_decoder import MaskDecoder +from typing import List, Tuple +from torch.nn import functional as F +from .mask_decoder_heads import SemSegHead +from .mask_decoder_neck import MaskDecoderNeck + + +class BaseMaskDecoderAdapter(MaskDecoder): + ''' + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + ''' + + # is fix and load params + def __init__(self, ori_sam: Sam, fix=False): + super(BaseMaskDecoderAdapter, self).__init__(transformer_dim=ori_sam.mask_decoder.transformer_dim, + transformer=ori_sam.mask_decoder.transformer) + self.sam_mask_decoder = ori_sam.mask_decoder + if fix: + fix_params(self.sam_mask_decoder) # move to runner to implement + + def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True): + low_res_masks, iou_predictions = self.sam_mask_decoder(image_embeddings=image_embeddings, + image_pe=prompt_adapter.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, ) + return low_res_masks, iou_predictions + + +class SemMaskDecoderAdapter(BaseMaskDecoderAdapter): + def __init__(self, ori_sam: Sam, fix=False, class_num=20): + super(SemMaskDecoderAdapter, self).__init__(ori_sam, fix) + self.decoder_neck = MaskDecoderNeck(transformer_dim=self.sam_mask_decoder.transformer_dim, + transformer=self.sam_mask_decoder.transformer, + num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs) + self.decoder_head = SemSegHead(transformer_dim=self.sam_mask_decoder.transformer_dim, + num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs, + iou_head_depth=self.sam_mask_decoder.iou_head_depth, + iou_head_hidden_dim=self.sam_mask_decoder.iou_head_hidden_dim, + class_num=class_num) + # pair the params between ori mask_decoder and new mask_decoder_adapter + self.pair_params(self.decoder_neck) + self.pair_params(self.decoder_head) + + def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True, + scale=1): + src, iou_token_out, mask_tokens_out, src_shape = self.decoder_neck(image_embeddings=image_embeddings, + image_pe=prompt_adapter.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, ) + masks, iou_pred = self.decoder_head(src, iou_token_out, mask_tokens_out, src_shape, mask_scale=scale) + return masks, iou_pred + + def pair_params(self, target_model: nn.Module): + src_dict = self.sam_mask_decoder.state_dict() + for name, value in target_model.named_parameters(): + if name in src_dict.keys(): + value.data.copy_(src_dict[name].data) + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/extend_sam/mask_decoder_heads.py b/extend_sam/mask_decoder_heads.py new file mode 100644 index 0000000..15b6e94 --- /dev/null +++ b/extend_sam/mask_decoder_heads.py @@ -0,0 +1,228 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .segment_anything_ori.modeling.common import LayerNorm2d + + +class OriHead(nn.Module): + + def __init__( + self, + *, + transformer_dim: int, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + tranformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + + self.num_multimask_outputs = num_multimask_outputs + + self.num_mask_tokens = num_multimask_outputs + 1 + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + src: torch.Tensor, + iou_token_out: torch.Tensor, + mask_tokens_out: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + b, c, h, w = src.shape + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + # Select the correct mask or masks for outptu + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + +class SemSegHead(nn.Module): + + def __init__( + self, + *, + transformer_dim: int, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + class_num: int = 20, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + tranformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.num_multimask_outputs = num_multimask_outputs + self.num_mask_tokens = num_multimask_outputs + 1 + self.class_num = class_num + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for _ in range(self.class_num) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + src: torch.Tensor, + iou_token_out: torch.Tensor, + mask_tokens_out: torch.Tensor, + src_shape, + mask_scale=1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + src (torch.Tensor): The tensor contains image embedding and sparse prompt embedding + iou_token_out (torch.Tensor): Tokens of iou prediction from neck module + mask_tokens_out (torch.Tensor): Tokens of mask prediction form neck module + mask_scale (int): Original SAM output 3 masks which is from local to global as default + This Class use one of three mask tokens to transform it into class-ware + semantic segmentation prediction + + Returns: + torch.Tensor: batched predicted semantic masks + torch.Tensor: batched predictions of mask quality + """ + b, c, h, w = src_shape + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.class_num): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, mask_scale, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # B N H W, N is num of category + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) # B N H W, N is num of category + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/extend_sam/mask_decoder_neck.py b/extend_sam/mask_decoder_neck.py new file mode 100644 index 0000000..d0ff9ba --- /dev/null +++ b/extend_sam/mask_decoder_neck.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type +from .segment_anything_ori.modeling.common import LayerNorm2d + +''' +This file save the mask_decoder's neck class, +which is the former part of original mask decoder of SAM. +Then the mask_decoder_heads can be used with the neck. +''' + + +class MaskDecoderNeck(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + tranformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: The tensor contains image embedding and sparse prompt embedding + torch.Tensor: Tokens of iou prediction + torch.Tensor: Tokens of mask prediction + """ + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + src_shape = src.shape + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :] + + return src, iou_token_out, mask_tokens_out, src_shape diff --git a/extend_sam/prompt_encoder_adapter.py b/extend_sam/prompt_encoder_adapter.py new file mode 100644 index 0000000..e92f7b3 --- /dev/null +++ b/extend_sam/prompt_encoder_adapter.py @@ -0,0 +1,19 @@ +# copyright ziqi-jin + +import torch.nn as nn +from .segment_anything_ori.modeling.sam import Sam +from .utils import fix_params + + +class BasePromptEncodeAdapter(nn.Module): + + def __init__(self, ori_sam: Sam, fix=False): + super(BasePromptEncodeAdapter, self).__init__() + + self.sam_prompt_encoder = ori_sam.prompt_encoder + if fix: + fix_params(self.sam_prompt_encoder) + + def forward(self, points=None, boxes=None, masks=None): + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(points, boxes, masks) + return sparse_embeddings, dense_embeddings diff --git a/extend_sam/runner.py b/extend_sam/runner.py new file mode 100644 index 0000000..e41c725 --- /dev/null +++ b/extend_sam/runner.py @@ -0,0 +1,227 @@ +from datasets import Iterator +from .utils import Average_Meter, Timer, print_and_save_log, mIoUOnline, get_numpy_from_tensor, save_model, write_log, \ + check_folder, one_hot_embedding_3d +import torch +import cv2 +import torch.nn.functional as F +import os +import torch.nn as nn +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from matplotlib.colors import ListedColormap, BoundaryNorm +import numpy as np + +class BaseRunner(): + def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler): + self.optimizer = optimizer + self.losses = losses + self.train_loader = train_loader + self.val_loader = val_loader + self.model = model + self.scheduler = scheduler + self.train_timer = Timer() + self.eval_timer = Timer() + try: + use_gpu = os.environ['CUDA_VISIBLE_DEVICES'] + except KeyError: + use_gpu = '0' + self.the_number_of_gpu = len(use_gpu.split(',')) + self.original_size = self.model.img_adapter.sam_img_encoder.img_size + if self.the_number_of_gpu > 1: + self.model = nn.DataParallel(self.model) + + +class SemRunner(BaseRunner): + # def __init__(self, **kwargs): + # super().__init__(kwargs) + + def __init__(self, model, optimizer, losses, train_loader, val_loader, scheduler): + super().__init__(model, optimizer, losses, train_loader, val_loader, scheduler) + self.exist_status = ['train', 'eval', 'test'] + + def train(self, cfg): + # initial identify + train_meter = Average_Meter(list(self.losses.keys()) + ['total_loss']) + train_iterator = Iterator(self.train_loader) + best_valid_mIoU = -1 + model_path = "{cfg.model_folder}/{cfg.experiment_name}/model.pth".format(cfg=cfg) + log_path = "{cfg.log_folder}/{cfg.experiment_name}/log_file.txt".format(cfg=cfg) + check_folder(model_path) + check_folder(log_path) + writer = None + if cfg.use_tensorboard is True: + tensorboard_dir = "{cfg.tensorboard_folder}/{cfg.experiment_name}/tensorboard/".format(cfg=cfg) + from torch.utils.tensorboard import SummaryWriter + writer = SummaryWriter(tensorboard_dir) + # train + train_losses = [] + for iteration in range(cfg.max_iter): + images, labels,_ = train_iterator.get() + images, labels = images.cuda(), labels.cuda().long() + labels = labels.squeeze(1) + masks_pred, iou_pred = self.model(images) + masks_pred = F.interpolate(masks_pred, self.original_size, mode="bilinear", align_corners=False) + + total_loss = torch.zeros(1).cuda() + loss_dict = {} + self._compute_loss(total_loss, loss_dict, masks_pred, labels, cfg) + self.optimizer.zero_grad() + total_loss.backward() + self.optimizer.step() + self.scheduler.step() + loss_dict['total_loss'] = total_loss.item() + train_losses.append(total_loss.item()) + train_meter.add(loss_dict) + + # log + if (iteration + 1) % cfg.log_iter == 0: + write_log(iteration=iteration, log_path=log_path, log_data=train_meter.get(clear=True), + status=self.exist_status[0], + writer=writer, timer=self.train_timer) + # eval + if (iteration + 1) % cfg.eval_iter == 0: + mIoU, _ = self._eval() + if best_valid_mIoU == -1 or best_valid_mIoU < mIoU: + best_valid_mIoU = mIoU + save_model(self.model, model_path, parallel=self.the_number_of_gpu > 1) + print_and_save_log("saved model in {model_path}".format(model_path=model_path), path=log_path) + log_data = {'mIoU': mIoU, 'best_valid_mIoU': best_valid_mIoU} + write_log(iteration=iteration, log_path=log_path, log_data=log_data, status=self.exist_status[1], + writer=writer, timer=self.eval_timer) + + plt.figure(figsize=(8, 5)) + plt.plot(train_losses, label="Train Loss") + plt.xlabel("Iteration") + plt.ylabel("Loss") + plt.title(f"Loss Curve up to Iter {iteration}") + plt.legend() + # 保存 + save_path = os.path.join('./', f"loss_iter_{iteration}.png") + plt.savefig(save_path, dpi=150) + plt.close() # 关闭当前 figure,释放内存 + # final process + save_model(self.model, model_path, is_final=True, parallel=self.the_number_of_gpu > 1) + if writer is not None: + writer.close() + + def test(self): + pass + + def _eval(self): + self.model.eval() + self.eval_timer.start() + class_names = self.val_loader.dataset.class_names + eval_metric = mIoUOnline(class_names=class_names) + with torch.no_grad(): + for index, (images, labels, img_paths) in enumerate(self.val_loader): + images = images.cuda() + labels = labels.cuda() + masks_pred, iou_pred = self.model(images) + predictions = torch.argmax(masks_pred, dim=1) + for batch_index in range(images.size()[0]): + pred_mask = get_numpy_from_tensor(predictions[batch_index]) + gt_mask = get_numpy_from_tensor(labels[batch_index].squeeze(0)) + h, w = pred_mask.shape + gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST) + + src_img = cv2.imread(img_paths[batch_index]) + # self.visualize_segmentation(src_img,pred_mask,gt_mask,os.path.basename(img_paths[batch_index])) + self.construct_visualize_segmentation(src_img,pred_mask,gt_mask,os.path.basename(img_paths[batch_index])) + eval_metric.add(pred_mask, gt_mask) + self.model.train() + return eval_metric.get(clear=True) + + def _compute_loss(self, total_loss, loss_dict, mask_pred, labels, cfg): + """ + Due to the inputs of losses are different, so if you want to add new losses, + you may need to modify the process in this function + """ + loss_cfg = cfg.losses + for index, item in enumerate(self.losses.items()): + # item -> (key: loss_name, val: loss) + real_labels = labels + if loss_cfg[item[0]].label_one_hot: + class_num = cfg.model.params.class_num + real_labels = one_hot_embedding_3d(real_labels, class_num=class_num) + tmp_loss = item[1](mask_pred, real_labels) + loss_dict[item[0]] = tmp_loss.item() + total_loss += loss_cfg[item[0]].weight * tmp_loss + + + def visualize_segmentation(self,image, pred_mask, gt_mask, save_path=None): + # 如果图像是 (C, H, W),需要先变成 (H, W, C) + if image.ndim == 3 and image.shape[0] == 3 and image.shape[1] != 3: + image = np.transpose(image, (1, 2, 0)) # (C, H, W) -> (H, W, C) + + # 定义用于显示 segmentation 的离散颜色映射:0=黑色, 1=绿色 + cmap = ListedColormap(["black", "green"]) + # 对应 0 和 1 两种类别,分界点给 [0,1,2] + norm = BoundaryNorm([0, 1, 2], cmap.N) + + # 创建图像 + fig, axes = plt.subplots(1, 4, figsize=(16, 4)) + + axes[0].imshow(image.astype(np.uint8)) + axes[0].set_title("Original Image") + axes[0].axis("off") + + im_pred = axes[1].imshow(pred_mask, cmap=cmap, norm=norm) + axes[1].set_title("Predicted Mask") + axes[1].axis("off") + + im_gt = axes[2].imshow(gt_mask, cmap=cmap, norm=norm) + axes[2].set_title("Ground Truth Mask") + axes[2].axis("off") + + legend_patches = [ + mpatches.Patch(color="black", label="Background (0)"), + mpatches.Patch(color="green", label="lettuce (1)"), + ] + axes[3].legend(handles=legend_patches, loc='center', fontsize=10) + axes[3].set_title("Classes Legend") + axes[3].axis("off") + + # 调整布局 + plt.tight_layout() + plt.savefig(os.path.join('./outputs',save_path), dpi=200, bbox_inches='tight') + plt.close() + + + def construct_visualize_segmentation(self, image, pred_mask, gt_mask, save_path=None): + # 如果图像是 (C, H, W),需要先变成 (H, W, C) + if image.ndim == 3 and image.shape[0] == 3 and image.shape[1] != 3: + image = np.transpose(image, (1, 2, 0)) # (C, H, W) -> (H, W, C) + + # 定义用于显示 segmentation 的离散颜色映射:0=黑色, 1=绿色, 2=红色 + cmap = ListedColormap(["black", "green", "red"]) + # 对应 0,1,2 三种类别,分界点给 [0,1,2,3] + norm = BoundaryNorm([0, 1, 2, 3], cmap.N) + + # 创建图像 + fig, axes = plt.subplots(1, 4, figsize=(16, 4)) + + axes[0].imshow(image.astype(np.uint8)) + axes[0].set_title("Original Image") + axes[0].axis("off") + + im_pred = axes[1].imshow(pred_mask, cmap=cmap, norm=norm) + axes[1].set_title("Predicted Mask") + axes[1].axis("off") + + im_gt = axes[2].imshow(gt_mask, cmap=cmap, norm=norm) + axes[2].set_title("Ground Truth Mask") + axes[2].axis("off") + + legend_patches = [ + mpatches.Patch(color="black", label="Background (0)"), + mpatches.Patch(color="green", label="lettuce (1)"), + mpatches.Patch(color="red", label="weed (2)"), + ] + axes[3].legend(handles=legend_patches, loc='center', fontsize=10) + axes[3].set_title("Classes Legend") + axes[3].axis("off") + + # 调整布局 + plt.tight_layout() + plt.savefig(os.path.join('./outputs',save_path), dpi=200, bbox_inches='tight') + plt.close() \ No newline at end of file diff --git a/extend_sam/scheduler.py b/extend_sam/scheduler.py new file mode 100644 index 0000000..1b98308 --- /dev/null +++ b/extend_sam/scheduler.py @@ -0,0 +1,75 @@ +# Modified from https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/optim/lr_scheduler.py # noqa +# and https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/solver/lr_scheduler.py + +from bisect import bisect_right +from typing import List + +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class WarmupMultiStepLR(_LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + milestones: List[int], + gamma: float = 0.1, + warmup_factor: float = 0.001, + warmup_iters: int = 1000, + warmup_method: str = "linear", + last_epoch: int = -1, + **kwargs, + ): + if not list(milestones) == sorted(milestones): + raise ValueError( + "Milestones should be a list of" " increasing integers. Got {}", + milestones, + ) + self.milestones = milestones + self.gamma = gamma + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + warmup_factor = _get_warmup_factor_at_iter( + self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor + ) + return [ + base_lr + * warmup_factor + * self.gamma ** bisect_right(self.milestones, self.last_epoch) + for base_lr in self.base_lrs + ] + + def _compute_values(self) -> List[float]: + # The new interface + return self.get_lr() + + +def _get_warmup_factor_at_iter( + method: str, iter: int, warmup_iters: int, warmup_factor: float +) -> float: + """ + Return the learning rate warmup factor at a specific iteration. + See https://arxiv.org/abs/1706.02677 for more details. + Args: + method (str): warmup method; either "constant" or "linear". + iter (int): iteration at which to calculate the warmup factor. + warmup_iters (int): the number of warmup iterations. + warmup_factor (float): the base warmup factor (the meaning changes according + to the method used). + Returns: + float: the effective warmup factor at the given iteration. + """ + if iter >= warmup_iters: + return 1.0 + + if method == "constant": + return warmup_factor + elif method == "linear": + alpha = iter / warmup_iters + return warmup_factor * (1 - alpha) + alpha + else: + raise ValueError("Unknown warmup method: {}".format(method)) diff --git a/extend_sam/segment_anything_ori/__init__.py b/extend_sam/segment_anything_ori/__init__.py new file mode 100644 index 0000000..9f34225 --- /dev/null +++ b/extend_sam/segment_anything_ori/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# modified by ziqi-jin + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .modeling.sam import Sam +from .predictor import SamPredictor +from .automatic_mask_generator import SamAutomaticMaskGenerator diff --git a/extend_sam/segment_anything_ori/automatic_mask_generator.py b/extend_sam/segment_anything_ori/automatic_mask_generator.py new file mode 100644 index 0000000..2326497 --- /dev/null +++ b/extend_sam/segment_anything_ori/automatic_mask_generator.py @@ -0,0 +1,372 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crops_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crops_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros(len(data["boxes"])), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros(len(data["boxes"])), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros(len(boxes)), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/extend_sam/segment_anything_ori/build_sam.py b/extend_sam/segment_anything_ori/build_sam.py new file mode 100644 index 0000000..fead7d8 --- /dev/null +++ b/extend_sam/segment_anything_ori/build_sam.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# modified by ziqi-jin + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/extend_sam/segment_anything_ori/modeling/__init__.py b/extend_sam/segment_anything_ori/modeling/__init__.py new file mode 100644 index 0000000..38e9062 --- /dev/null +++ b/extend_sam/segment_anything_ori/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/extend_sam/segment_anything_ori/modeling/common.py b/extend_sam/segment_anything_ori/modeling/common.py new file mode 100644 index 0000000..2bf1523 --- /dev/null +++ b/extend_sam/segment_anything_ori/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/extend_sam/segment_anything_ori/modeling/image_encoder.py b/extend_sam/segment_anything_ori/modeling/image_encoder.py new file mode 100644 index 0000000..a6ad9ad --- /dev/null +++ b/extend_sam/segment_anything_ori/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/extend_sam/segment_anything_ori/modeling/mask_decoder.py b/extend_sam/segment_anything_ori/modeling/mask_decoder.py new file mode 100644 index 0000000..674c7cf --- /dev/null +++ b/extend_sam/segment_anything_ori/modeling/mask_decoder.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + tranformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for outptu + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/extend_sam/segment_anything_ori/modeling/prompt_encoder.py b/extend_sam/segment_anything_ori/modeling/prompt_encoder.py new file mode 100644 index 0000000..c3143f4 --- /dev/null +++ b/extend_sam/segment_anything_ori/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/extend_sam/segment_anything_ori/modeling/sam.py b/extend_sam/segment_anything_ori/modeling/sam.py new file mode 100644 index 0000000..eba0ec8 --- /dev/null +++ b/extend_sam/segment_anything_ori/modeling/sam.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# modified by ziqi-jin + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input promts, + C is determiend by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/extend_sam/segment_anything_ori/modeling/transformer.py b/extend_sam/segment_anything_ori/modeling/transformer.py new file mode 100644 index 0000000..f1a2812 --- /dev/null +++ b/extend_sam/segment_anything_ori/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attenion layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/extend_sam/segment_anything_ori/predictor.py b/extend_sam/segment_anything_ori/predictor.py new file mode 100644 index 0000000..217d060 --- /dev/null +++ b/extend_sam/segment_anything_ori/predictor.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from extend_sam.segment_anything_ori.modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks = masks[0].detach().cpu().numpy() + iou_predictions = iou_predictions[0].detach().cpu().numpy() + low_res_masks = low_res_masks[0].detach().cpu().numpy() + return masks, iou_predictions, low_res_masks + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/extend_sam/segment_anything_ori/utils/__init__.py b/extend_sam/segment_anything_ori/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/extend_sam/segment_anything_ori/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/extend_sam/segment_anything_ori/utils/amg.py b/extend_sam/segment_anything_ori/utils/amg.py new file mode 100644 index 0000000..3a13777 --- /dev/null +++ b/extend_sam/segment_anything_ori/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecesary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/extend_sam/segment_anything_ori/utils/onnx.py b/extend_sam/segment_anything_ori/utils/onnx.py new file mode 100644 index 0000000..493950a --- /dev/null +++ b/extend_sam/segment_anything_ori/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/extend_sam/segment_anything_ori/utils/transforms.py b/extend_sam/segment_anything_ori/utils/transforms.py new file mode 100644 index 0000000..3ad3466 --- /dev/null +++ b/extend_sam/segment_anything_ori/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/extend_sam/utils.py b/extend_sam/utils.py new file mode 100644 index 0000000..a3f416d --- /dev/null +++ b/extend_sam/utils.py @@ -0,0 +1,234 @@ +''' +@copyright ziqi-jin +''' +import time +import numpy as np +import torch +import torch.nn.functional as F +import os.path as osp +import os + + +def fix_params(model): + for name, param in model.named_parameters(): + param.requires_grad = False + + +def load_params(model, params): + pass + + +def get_opt_pamams(model, lr_list, group_keys, wd_list): + ''' + + :param model: model + :param lr_list: list, contain the lr for each params group + :param wd_list: list, contain the weight decay for each params group + :param group_keys: list of list, according to the sub list to divide params to different groups + :return: list of dict + ''' + assert len(lr_list) == len(group_keys), "lr_list should has the same length as group_keys" + assert len(lr_list) == len(wd_list), "lr_list should has the same length as wd_list" + params_group = [[] for _ in range(len(lr_list))] + for name, value in model.named_parameters(): + for index, g_keys in enumerate(group_keys): + for g_key in g_keys: + if g_key in name: + params_group[index].append(value) + return [{'params': params_group[i], 'lr': lr_list[i], 'weight_decay': wd_list[i]} for i in range(len(lr_list))] + + +class Timer: + + def __init__(self): + self.start_time = 0.0 + self.end_time = 0.0 + + self.start() + + def start(self): + self.start_time = time.time() + + def end(self, ms=False, clear=False): + self.end_time = time.time() + + if ms: + duration = int((self.end_time - self.start_time) * 1000) + else: + duration = int(self.end_time - self.start_time) + + if clear: + self.start() + + return duration + + +class Average_Meter: + def __init__(self, keys): + self.keys = keys + self.clear() + + def add(self, dic): + for key, value in dic.items(): + self.data_dic[key].append(value) + + def get(self, keys=None, clear=False): + if keys is None: + keys = self.keys + + dataset = {} + for key in keys: + dataset[key] = float(np.mean(self.data_dic[key])) + + if clear: + self.clear() + + return dataset + + def clear(self): + self.data_dic = {key: [] for key in self.keys} + + +def print_and_save_log(message, path): + print(message) + + with open(path, 'a+') as f: + f.write(message + '\n') + + +class mIoUOnline: + def __init__(self, class_names): + self.class_names = class_names + self.class_num = len(self.class_names) + + self.clear() + + def get_data(self, pred_mask, gt_mask): + obj_mask = gt_mask < 255 + correct_mask = (pred_mask == gt_mask) * obj_mask + + P_list, T_list, TP_list = [], [], [] + for i in range(self.class_num): + P_list.append(np.sum((pred_mask == i) * obj_mask)) + T_list.append(np.sum((gt_mask == i) * obj_mask)) + TP_list.append(np.sum((gt_mask == i) * correct_mask)) + + return (P_list, T_list, TP_list) + + def add_using_data(self, data): + P_list, T_list, TP_list = data + for i in range(self.class_num): + self.P[i] += P_list[i] + self.T[i] += T_list[i] + self.TP[i] += TP_list[i] + + def add(self, pred_mask, gt_mask): + obj_mask = gt_mask < 255 + correct_mask = (pred_mask == gt_mask) * obj_mask + + for i in range(self.class_num): + self.P[i] += np.sum((pred_mask == i) * obj_mask) + self.T[i] += np.sum((gt_mask == i) * obj_mask) + self.TP[i] += np.sum((gt_mask == i) * correct_mask) + + def get(self, detail=False, clear=True): + IoU_dic = {} + IoU_list = [] + + FP_list = [] # over activation + FN_list = [] # under activation + + for i in range(self.class_num): + IoU = self.TP[i] / (self.T[i] + self.P[i] - self.TP[i] + 1e-10) * 100 + FP = (self.P[i] - self.TP[i]) / (self.T[i] + self.P[i] - self.TP[i] + 1e-10) + FN = (self.T[i] - self.TP[i]) / (self.T[i] + self.P[i] - self.TP[i] + 1e-10) + + IoU_dic[self.class_names[i]] = IoU + + IoU_list.append(IoU) + FP_list.append(FP) + FN_list.append(FN) + + mIoU = np.mean(np.asarray(IoU_list)) + mIoU_foreground = np.mean(np.asarray(IoU_list)[1:]) + + FP = np.mean(np.asarray(FP_list)) + FN = np.mean(np.asarray(FN_list)) + + if clear: + self.clear() + + if detail: + return mIoU, mIoU_foreground, IoU_dic, FP, FN + else: + return mIoU, mIoU_foreground + + def clear(self): + self.TP = [] + self.P = [] + self.T = [] + + for _ in range(self.class_num): + self.TP.append(0) + self.P.append(0) + self.T.append(0) + + +def get_numpy_from_tensor(tensor): + return tensor.cpu().detach().numpy() + + +def save_model(model, model_path, parallel=False, is_final=False): + if is_final: + model_path_split = model_path.split('.') + model_path = model_path_split[0] + "_final.pth" + if parallel: + torch.save(model.module.state_dict(), model_path) + else: + torch.save(model.state_dict(), model_path) + + +def write_log(iteration, log_path, log_data, status, writer, timer): + log_data['iteration'] = iteration + log_data['time'] = timer.end(clear=True) + message = "iteration : {val}, ".format(val=log_data['iteration']) + for key, value in log_data.items(): + if key == 'iteration': + continue + message += "{key} : {val}, ".format(key=key, val=value) + message = message[:-2] # + '\n' + print_and_save_log(message, log_path) + # visualize + if writer is not None: + for key, value in log_data.items(): + writer.add_scalar("{status}/{key}".format(status=status, key=key), value, iteration) + + +def check_folder(file_path, is_folder=False): + ''' + + :param file_path: the path of file, default input is a complete file name with dir path. + :param is_folder: if the input is a dir, not a file_name, is_folder should be True + :return: no return, this function will check and judge whether need to make dirs. + ''' + if is_folder: + if not osp.exists(is_folder): + os.makedirs(file_path) + + else: + splits = file_path.split("/") + folder_name = "/".join(splits[:-1]) + if not osp.exists(folder_name): + os.makedirs(folder_name) + + +def one_hot_embedding_3d(labels, class_num=21): + ''' + + :param real_labels: B H W + :param class_num: N + :return: B N H W + ''' + one_hot_labels = labels.clone() + one_hot_labels[one_hot_labels == 255] = 0 # 0 is background + return F.one_hot(one_hot_labels, num_classes=class_num).permute(0, 3, 1, 2).contiguous().float() diff --git a/losses/__init__.py b/losses/__init__.py new file mode 100644 index 0000000..cbf9ccf --- /dev/null +++ b/losses/__init__.py @@ -0,0 +1,16 @@ +import torch.nn as nn +from .losses import CustormLoss + +AVAI_LOSS = {'ce': nn.CrossEntropyLoss, 'multi_label_soft_margin': nn.MultiLabelSoftMarginLoss, + 'test_custom': CustormLoss, 'mse': nn.MSELoss} + + +def get_losses(losses): + loss_dict = {} + for name in losses: + assert name in AVAI_LOSS, print('{name} is not supported, please implement it first.'.format(name=name)) + if losses[name].params is not None: + loss_dict[name] = AVAI_LOSS[name](**losses[name].params) + else: + loss_dict[name] = AVAI_LOSS[name]() + return loss_dict diff --git a/losses/losses.py b/losses/losses.py new file mode 100644 index 0000000..38e35cf --- /dev/null +++ b/losses/losses.py @@ -0,0 +1,14 @@ +''' +@copyright ziqi-jin +You can create custom loss function in this file, then import the created loss in ./__init__.py and add the loss into AVAI_LOSS +''' +import torch.nn as nn + + +# example +class CustormLoss(nn.Module): + def __init__(self): + pass + + def forward(self, x, y): + pass \ No newline at end of file