# laser_weeding/extend_sam/mask_decoder_adapter.py
# @copyright ziqi-jin
import torch
import torch.nn as nn
from torch.nn import functional as F

from .segment_anything_ori.modeling.sam import Sam
from .segment_anything_ori.modeling.mask_decoder import MaskDecoder
from .utils import fix_params
from .mask_decoder_heads import SemSegHead
from .mask_decoder_neck import MaskDecoderNeck


class BaseMaskDecoderAdapter(MaskDecoder):
    '''
    multimask_output (bool): If True, the model will return three masks.
    For ambiguous input prompts (such as a single click), this will often
    produce better masks than a single prediction. If only a single
    mask is needed, the model's predicted quality score can be used
    to select the best mask. For non-ambiguous prompts, such as multiple
    input prompts, multimask_output=False can give better results.
    '''

    def __init__(self, ori_sam: Sam, fix=False):
        super(BaseMaskDecoderAdapter, self).__init__(
            transformer_dim=ori_sam.mask_decoder.transformer_dim,
            transformer=ori_sam.mask_decoder.transformer)
        self.sam_mask_decoder = ori_sam.mask_decoder
        # Optionally freeze the pretrained SAM decoder parameters
        # (this could also be moved into the runner).
        if fix:
            fix_params(self.sam_mask_decoder)

    def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True):
        low_res_masks, iou_predictions = self.sam_mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=prompt_adapter.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output)
        return low_res_masks, iou_predictions


class SemMaskDecoderAdapter(BaseMaskDecoderAdapter):
    def __init__(self, ori_sam: Sam, fix=False, class_num=20):
        super(SemMaskDecoderAdapter, self).__init__(ori_sam, fix)
        self.decoder_neck = MaskDecoderNeck(
            transformer_dim=self.sam_mask_decoder.transformer_dim,
            transformer=self.sam_mask_decoder.transformer,
            num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs)
        self.decoder_head = SemSegHead(
            transformer_dim=self.sam_mask_decoder.transformer_dim,
            num_multimask_outputs=self.sam_mask_decoder.num_multimask_outputs,
            iou_head_depth=self.sam_mask_decoder.iou_head_depth,
            iou_head_hidden_dim=self.sam_mask_decoder.iou_head_hidden_dim,
            class_num=class_num)
        # Pair the parameters of the original mask decoder with the new
        # neck/head modules so they start from the pretrained SAM weights.
        self.pair_params(self.decoder_neck)
        self.pair_params(self.decoder_head)

    def forward(self, image_embeddings, prompt_adapter, sparse_embeddings, dense_embeddings, multimask_output=True,
                scale=1):
        src, iou_token_out, mask_tokens_out, src_shape = self.decoder_neck(
            image_embeddings=image_embeddings,
            image_pe=prompt_adapter.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output)
        masks, iou_pred = self.decoder_head(src, iou_token_out, mask_tokens_out, src_shape, mask_scale=scale)
        return masks, iou_pred

    def pair_params(self, target_model: nn.Module):
        # Copy every parameter of target_model whose name also exists in the
        # original SAM mask decoder from its pretrained state dict.
        src_dict = self.sam_mask_decoder.state_dict()
        for name, value in target_model.named_parameters():
            if name in src_dict.keys():
                value.data.copy_(src_dict[name].data)
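
        # A sanity check one might run after construction; a sketch that
        # assumes MaskDecoderNeck keeps SAM's `iou_token` parameter (name
        # taken from the upstream MaskDecoder, not verified here):
        #   adapter = SemMaskDecoderAdapter(ori_sam)
        #   assert torch.equal(adapter.decoder_neck.iou_token.weight,
        #                      adapter.sam_mask_decoder.iou_token.weight)

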
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        # num_layers - 1 hidden layers of width hidden_dim, followed by a
        # final projection to output_dim.
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)  # F.sigmoid is deprecated
        return x
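

# A minimal smoke test sketching how the adapter fits together, intended to
# be run as a module (python -m ...) from the package root. It assumes the
# vendored segment_anything_ori package re-exports the upstream
# `sam_model_registry`; the shim below stands in for the project's
# prompt-encoder adapter, since forward() only reads its
# `sam_prompt_encoder` attribute.
if __name__ == '__main__':
    from extend_sam.segment_anything_ori import sam_model_registry  # assumed export

    ori_sam = sam_model_registry['vit_b'](checkpoint=None)  # random weights

    class _PromptAdapterShim:
        # Stand-in exposing the one attribute the decoder adapters read.
        def __init__(self, sam: Sam):
            self.sam_prompt_encoder = sam.prompt_encoder

    adapter = SemMaskDecoderAdapter(ori_sam, fix=True, class_num=20)
    prompt_adapter = _PromptAdapterShim(ori_sam)

    with torch.no_grad():
        image_embeddings = ori_sam.image_encoder(torch.randn(1, 3, 1024, 1024))
        sparse, dense = ori_sam.prompt_encoder(points=None, boxes=None, masks=None)
        masks, iou_pred = adapter(image_embeddings, prompt_adapter, sparse, dense)
    print(masks.shape, iou_pred.shape)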