# copyright ziqi-jin
import torch
import torch.nn as nn

from .segment_anything_ori import sam_model_registry
from .image_encoder_adapter import BaseImgEncodeAdapter
from .mask_decoder_adapter import BaseMaskDecoderAdapter, SemMaskDecoderAdapter
from .prompt_encoder_adapter import BasePromptEncodeAdapter


class BaseExtendSam(nn.Module):

    def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, model_type='vit_b'):
        super(BaseExtendSam, self).__init__()
        assert model_type in ['default', 'vit_b', 'vit_l', 'vit_h'], \
            "Wrong model_type, SAM can only be built as vit_b, vit_l, vit_h or default."
        # Build the original SAM, then wrap its three sub-modules in adapters.
        # Each fix_* flag is passed to the corresponding adapter to control whether
        # that SAM sub-module stays fixed during fine-tuning.
        self.ori_sam = sam_model_registry[model_type](ckpt_path)
        self.img_adapter = BaseImgEncodeAdapter(self.ori_sam, fix=fix_img_en)
        self.prompt_adapter = BasePromptEncodeAdapter(self.ori_sam, fix=fix_prompt_en)
        self.mask_adapter = BaseMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de)

    def forward(self, img):
        # Encode the image, run the prompt encoder with no prompts
        # (no points, boxes, or masks), then decode the masks.
        x = self.img_adapter(img)
        points = None
        boxes = None
        masks = None

        sparse_embeddings, dense_embeddings = self.prompt_adapter(
            points=points,
            boxes=boxes,
            masks=masks,
        )
        multimask_output = True
        low_res_masks, iou_predictions = self.mask_adapter(
            image_embeddings=x,
            prompt_adapter=self.prompt_adapter,
            sparse_embeddings=sparse_embeddings,
            dense_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        return low_res_masks, iou_predictions


class SemanticSam(BaseExtendSam):

    def __init__(self, ckpt_path=None, fix_img_en=False, fix_prompt_en=False, fix_mask_de=False, class_num=20, model_type='vit_b'):
        super().__init__(ckpt_path=ckpt_path, fix_img_en=fix_img_en, fix_prompt_en=fix_prompt_en,
                         fix_mask_de=fix_mask_de, model_type=model_type)
        # Replace the base mask decoder adapter with the semantic variant,
        # which predicts class_num classes.
        self.mask_adapter = SemMaskDecoderAdapter(self.ori_sam, fix=fix_mask_de, class_num=class_num)
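

# Usage sketch: a minimal example of constructing SemanticSam and running one
# forward pass. Assumptions (not taken from this module): the checkpoint path
# below is hypothetical, and the image adapter takes SAM's standard
# preprocessed input of shape (B, 3, 1024, 1024). Because of the relative
# imports above, this file has to be run as a module from the package root
# rather than as a standalone script.
if __name__ == '__main__':
    model = SemanticSam(
        ckpt_path='checkpoints/sam_vit_b.pth',  # hypothetical checkpoint path
        fix_img_en=True,       # fix_* flags control which SAM sub-modules stay fixed
        fix_prompt_en=True,
        fix_mask_de=False,
        class_num=20,
        model_type='vit_b',
    )
    dummy = torch.randn(1, 3, 1024, 1024)  # assumed SAM input resolution
    with torch.no_grad():
        low_res_masks, iou_predictions = model(dummy)
    print(low_res_masks.shape, iou_predictions.shape)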