laser_weeding/extend_sam/extend_sam.py

# copyright ziqi-jin
import torch
import torch.nn as nn
from .segment_anything_ori import sam_model_registry
from .image_encoder_adapter import BaseImgEncodeAdapter
from .mask_decoder_adapter import BaseMaskDecoderAdapter, SemMaskDecoderAdapter
from .prompt_encoder_adapter import BasePromptEncodeAdapter


class BaseExtendSam(nn.Module):
    """Wraps the original SAM model with adapters around its encoder and decoder stages."""

    def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, model_type='vit_b'):
        super(BaseExtendSam, self).__init__()
        assert model_type in ['default', 'vit_b', 'vit_l', 'vit_h'], \
            "Wrong model_type: SAM can only be built as vit_b, vit_l, vit_h or default."
        self.ori_sam = sam_model_registry[model_type](ckpt_path)
        # Each adapter wraps the corresponding SAM module and can be frozen via its `fix` flag.
        self.img_adapter = BaseImgEncodeAdapter(self.ori_sam, fix=fix_img_en)
        self.prompt_adapter = BasePromptEncodeAdapter(self.ori_sam, fix=fix_prompt_en)
        self.mask_adapter = BaseMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de)

    def forward(self, img):
        # Encode the input image with the (optionally frozen) SAM image encoder.
        x = self.img_adapter(img)
        # No point, box, or mask prompts are supplied: run SAM prompt-free.
        points = None
        boxes = None
        masks = None
        sparse_embeddings, dense_embeddings = self.prompt_adapter(
            points=points,
            boxes=boxes,
            masks=masks,
        )
        multimask_output = True
        low_res_masks, iou_predictions = self.mask_adapter(
            image_embeddings=x,
            prompt_adapter=self.prompt_adapter,
            sparse_embeddings=sparse_embeddings,
            dense_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        return low_res_masks, iou_predictions


class SemanticSam(BaseExtendSam):
    """BaseExtendSam variant whose mask decoder predicts `class_num` semantic classes."""

    def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, class_num=20, model_type='vit_b'):
        super().__init__(ckpt_path=ckpt_path, fix_img_en=fix_img_en, fix_prompt_en=fix_prompt_en,
                         fix_mask_de=fix_mask_de, model_type=model_type)
        # Replace the base mask-decoder adapter with a semantic one that outputs class logits.
        self.mask_adapter = SemMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de, class_num=class_num)
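

# Minimal usage sketch (not part of the original file). The checkpoint filename,
# class count, and freezing choices below are assumptions for illustration; SAM's
# ViT image encoder expects a preprocessed 1024x1024 RGB input, and the model
# returns low-resolution per-class masks plus IoU predictions.
if __name__ == '__main__':
    model = SemanticSam(ckpt_path='sam_vit_b_01ec64.pth',  # hypothetical checkpoint path
                        fix_img_en=True, fix_prompt_en=True, fix_mask_de=False,
                        class_num=2, model_type='vit_b')
    dummy = torch.randn(1, 3, 1024, 1024)  # batch of one preprocessed image
    with torch.no_grad():
        low_res_masks, iou_predictions = model(dummy)
    print(low_res_masks.shape, iou_predictions.shape)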