# Generated from the yuanbiao/python_templates template.
# copyright ziqi-jin
|
|
|
|
import torch.nn as nn
|
|
from .segment_anything_ori.modeling.sam import Sam
|
|
from .utils import fix_params
|
|
|
|
|
|
class BasePromptEncodeAdapter(nn.Module):
    """Thin adapter around the prompt encoder of an existing SAM model.

    The adapter keeps a reference to ``ori_sam.prompt_encoder``. It is the
    same module object, not a copy. Prompts are forwarded to it unchanged.

    Args:
        ori_sam: ``Sam`` instance whose ``prompt_encoder`` is reused.
        fix: When true, ``fix_params`` is applied to the wrapped prompt
            encoder. NOTE(review): presumably this freezes its parameters;
            confirm against ``.utils.fix_params``.
    """

    def __init__(self, ori_sam: Sam, fix=False):
        # Module bookkeeping must be initialised before a submodule is
        # assigned as an attribute.
        super().__init__()
        encoder = ori_sam.prompt_encoder
        self.sam_prompt_encoder = encoder
        if fix:
            # ``encoder`` is the very object stored on ``self`` above.
            fix_params(encoder)

    def forward(self, points=None, boxes=None, masks=None):
        """Run the wrapped SAM prompt encoder on the given prompts.

        Args:
            points: Point prompts, or ``None`` if absent.
            boxes: Box prompts, or ``None`` if absent.
            masks: Mask prompts, or ``None`` if absent.

        Returns:
            The ``(sparse_embeddings, dense_embeddings)`` pair produced by
            the wrapped encoder. Unpacking raises if it does not return
            exactly two values.
        """
        sparse, dense = self.sam_prompt_encoder(points, boxes, masks)
        return sparse, dense
|