# generated from yuanbiao/python_templates
# @copyright ziqi-jin
|
|
|
|
import torch.nn as nn
|
|
import torch
|
|
from .segment_anything_ori.modeling.sam import Sam
|
|
from .utils import fix_params
|
|
from .segment_anything_ori.modeling.mask_decoder import MaskDecoder
|
|
from typing import List, Tuple
|
|
from torch.nn import functional as F
|
|
from .mask_decoder_heads import SemSegHead
|
|
from .mask_decoder_neck import MaskDecoderNeck
|
|
|
|
|
|
class BaseMaskDecoderAdapter(MaskDecoder):
    """Adapter that wraps the pretrained SAM mask decoder behind a uniform interface.

    multimask_output (bool): If true, the model will return three masks.
    For ambiguous input prompts (such as a single click), this will often
    produce better masks than a single prediction. If only a single
    mask is needed, the model's predicted quality score can be used
    to select the best mask. For non-ambiguous prompts, such as multiple
    input prompts, multimask_output=False can give better results.
    """

    def __init__(self, ori_sam: Sam, fix=False):
        """Build the adapter around ``ori_sam``'s decoder; freeze its params when ``fix`` is True."""
        # Initialise the parent MaskDecoder with the pretrained decoder's
        # dimensions so the adapter is structurally compatible with it.
        super().__init__(
            transformer_dim=ori_sam.mask_decoder.transformer_dim,
            transformer=ori_sam.mask_decoder.transformer,
        )
        self.sam_mask_decoder = ori_sam.mask_decoder
        if fix:
            # Freeze the pretrained decoder's parameters
            # (could alternatively be done by the runner).
            fix_params(self.sam_mask_decoder)

    def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True):
        """Run the wrapped SAM decoder and return ``(low_res_masks, iou_predictions)``."""
        # The dense positional encoding comes from the prompt adapter's
        # wrapped SAM prompt encoder.
        dense_pe = prompt_adapter.sam_prompt_encoder.get_dense_pe()
        low_res_masks, iou_predictions = self.sam_mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=dense_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        return low_res_masks, iou_predictions
|
|
|
|
|
|
class SemMaskDecoderAdapter(BaseMaskDecoderAdapter):
    """Semantic-segmentation adapter that splits SAM's mask decoder into a
    shared neck (:class:`MaskDecoderNeck`) and a class-aware head
    (:class:`SemSegHead`), initialised from the pretrained decoder weights."""

    def __init__(self, ori_sam: Sam, fix=False, class_num=20):
        """Create neck/head modules sized to ``ori_sam``'s decoder.

        Args:
            ori_sam: the original SAM model whose decoder supplies weights.
            fix: freeze the wrapped pretrained decoder when True.
            class_num: number of semantic classes predicted by the head.
        """
        super().__init__(ori_sam, fix)
        self.decoder_neck = MaskDecoderNeck(
            transformer_dim=self.sam_mask_decoder.transformer_dim,
            transformer=self.sam_mask_decoder.transformer,
            num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs,
        )
        self.decoder_head = SemSegHead(
            transformer_dim=self.sam_mask_decoder.transformer_dim,
            num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs,
            iou_head_depth=self.sam_mask_decoder.iou_head_depth,
            iou_head_hidden_dim=self.sam_mask_decoder.iou_head_hidden_dim,
            class_num=class_num,
        )
        # Pair the params between the original mask_decoder and the new
        # mask_decoder_adapter so neck/head start from pretrained weights.
        self.pair_params(self.decoder_neck)
        self.pair_params(self.decoder_head)

    def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True,
                scale=1):
        """Run neck then head; returns ``(masks, iou_pred)``.

        ``scale`` is forwarded to the head as ``mask_scale`` (semantics
        defined by SemSegHead).
        """
        src, iou_token_out, mask_tokens_out, src_shape = self.decoder_neck(
            image_embeddings=image_embeddings,
            image_pe=prompt_adapter.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        masks, iou_pred = self.decoder_head(src, iou_token_out, mask_tokens_out, src_shape, mask_scale=scale)
        return masks, iou_pred

    def pair_params(self, target_model: nn.Module):
        """Copy weights from the pretrained decoder into ``target_model``
        for every parameter whose name matches the decoder's state dict."""
        src_dict = self.sam_mask_decoder.state_dict()
        for name, value in target_model.named_parameters():
            # Membership test directly on the dict (no redundant .keys()).
            if name in src_dict:
                value.data.copy_(src_dict[name].data)
|
|
|
|
|
|
# Lightly adapted from
|
|
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
|
class MLP(nn.Module):
    """Simple multi-layer perceptron: ``num_layers`` linear layers with ReLU
    between them and an optional sigmoid on the output.

    Lightly adapted from
    https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        """
        Args:
            input_dim: size of the input features.
            hidden_dim: size of every hidden layer.
            output_dim: size of the output features.
            num_layers: total number of linear layers (>= 1).
            sigmoid_output: apply a sigmoid to the final output when True.
        """
        super().__init__()
        self.num_layers = num_layers
        # Hidden sizes: num_layers - 1 copies of hidden_dim, so the chain is
        # input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim.
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        """Apply each layer with ReLU in between; no activation on the last
        layer unless ``sigmoid_output`` is set."""
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            # torch.sigmoid replaces the deprecated F.sigmoid alias.
            x = torch.sigmoid(x)
        return x
|