import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from datasets import get_dataset, get_lettuce_dataset
from losses import get_losses
from extend_sam import get_model, get_optimizer, get_scheduler, get_opt_pamams, get_runner
import random
import numpy as np
import torch
import os
import glob
from torchvision import transforms
from PIL import Image
import cv2
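
# Optional sketch, not part of the original flow: random, numpy and torch are imported
# above but never seeded. A minimal helper like this (hypothetical name) could be called
# once at startup if reproducible inference runs are ever needed.
def set_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
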
supported_tasks = ['detection', 'semantic_seg', 'instance_seg']

parser = argparse.ArgumentParser()
parser.add_argument('--task_name', default='semantic_seg', type=str)
parser.add_argument('--cfg', default=None, type=str)

if __name__ == '__main__':
    args = parser.parse_args()
    task_name = args.task_name
    if args.cfg is not None:
        config = OmegaConf.load(args.cfg)
    else:
        assert task_name in supported_tasks, "Please input a supported task name."
        config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name))
    test_cfg = config.test
    train_cfg = config.train

    model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)
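
    # Hedged illustration (hypothetical values): the only config keys this script touches
    # are `test`, `train`, `train.model.sam_name` and `train.model.params`, so a minimal
    # ./config/semantic_seg.yaml might look like:
    #
    #   train:
    #     model:
    #       sam_name: sem_sam   # hypothetical; must match a name known to extend_sam.get_model
    #       params: {}
    #   test: {}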

    # Using the images under ./weed_data as an example
    test_data_path = '../test_data'
    output_dir = './result'

    model.load_state_dict(torch.load('experiment/model/semantic_sam/lam_vit_b_01ec64.pth'))
    model.eval()
    model.cuda()
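
    # Hedged sketch (assumption, defined but never called here): the torch.load()/.cuda()
    # calls above assume a CUDA device is available. A device-aware loader could be swapped
    # in for CPU-only machines; `load_checkpoint_on` is a hypothetical helper name.
    def load_checkpoint_on(model_to_load, ckpt_path, device):
        state_dict = torch.load(ckpt_path, map_location=device)
        model_to_load.load_state_dict(state_dict)
        return model_to_load.to(device).eval()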

    preprocess = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
    ])
    # Only used by the (commented-out) PIL path in the loop below; ToTensor scales to [0, 1],
    # whereas the cv2 path that is actually used keeps pixel values in [0, 255].
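
    # Hedged sketch (assumption, not called by the original loop below): besides the
    # per-class binary masks written in the loop, a single color-coded label map is often
    # handy for quick inspection. `save_color_map` is a hypothetical helper name; colors
    # are assigned deterministically per class id.
    def save_color_map(pred_map, out_path):
        rng = np.random.RandomState(0)
        palette = rng.randint(0, 255, size=(int(pred_map.max()) + 1, 3), dtype=np.uint8)
        color_map = palette[pred_map]  # (H, W, 3) lookup by class id
        cv2.imwrite(out_path, cv2.cvtColor(color_map, cv2.COLOR_RGB2BGR))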

    for img_path in glob.glob(os.path.join(test_data_path, '*.jpg')):
        print(img_path)

        # Read the image and convert it to RGB
        # img = Image.open(img_path).convert("RGB")
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Preprocess the image and add a batch dimension
        # input_tensor = preprocess(img).unsqueeze(0).cuda()
        img = cv2.resize(img, (1024, 1024))
        input_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().cuda().unsqueeze(0)

        with torch.no_grad():
            output, _ = model(input_tensor)
        pred = torch.argmax(output, dim=1)
        pred = pred.squeeze(0).cpu().numpy()

        img_name = os.path.splitext(os.path.basename(img_path))[0]
        sub_output_dir = os.path.join(output_dir, img_name)
        os.makedirs(sub_output_dir, exist_ok=True)

        unique_labels = np.unique(pred)
        for label in unique_labels:
            # Build a binary mask: 255 where the pixel belongs to this class, 0 otherwise
            binary_mask = (pred == label).astype(np.uint8) * 255
            mask_filename = os.path.join(sub_output_dir, f'class_{label}.png')
            cv2.imwrite(mask_filename, binary_mask)

        print(f"Processed {img_path}, saved masks to {sub_output_dir}")