laser_weeding/infer.py

64 lines
2.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from datasets import get_dataset,get_lettuce_dataset
from losses import get_losses
from extend_sam import get_model, get_optimizer, get_scheduler, get_opt_pamams, get_runner
import random
import numpy as np
import torch
import os
import glob
from torchvision import transforms
from PIL import Image
import cv2
# Task names that have a default config file at ./config/<task_name>.yaml.
supported_tasks = [
    'detection',
    'semantic_seg',
    'instance_seg',
]

# Command-line interface: --cfg points at an explicit config file; when it is
# omitted, --task_name selects the default config for that task.
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', type=str, default='semantic_seg')
parser.add_argument('--cfg', type=str, default=None)
# Hard-coded locations used by this script, relative to the working directory.
TEST_DATA_DIR = '../test_data'
OUTPUT_DIR = './result'
CHECKPOINT_PATH = 'experiment/model/semantic_sam/lam_vit_b_01ec64.pth'
# Side length (pixels) of the square every image is resized to before inference.
INPUT_SIZE = 1024


def _load_config(args):
    """Return the OmegaConf config selected by the parsed CLI arguments.

    An explicit ``--cfg`` path takes precedence; otherwise the default
    ``./config/<task_name>.yaml`` for the requested task is loaded.

    Raises:
        ValueError: if ``--cfg`` is not given and ``--task_name`` is not one
            of ``supported_tasks``.
    """
    if args.cfg is not None:
        return OmegaConf.load(args.cfg)
    # A real exception instead of ``assert`` so the check survives ``python -O``.
    if args.task_name not in supported_tasks:
        raise ValueError("Please input the supported task name.")
    return OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name))


def _load_input_tensor(img_path, device):
    """Read ``img_path`` and return it as a model input tensor on ``device``.

    The image is converted from BGR to RGB, resized to
    ``INPUT_SIZE x INPUT_SIZE`` and laid out as a float tensor of shape
    ``(1, 3, INPUT_SIZE, INPUT_SIZE)``. Pixel values are left in the 0-255
    range; NOTE(review): any normalization is presumably done inside the
    model - confirm.

    Returns:
        The input tensor, or ``None`` if OpenCV could not decode the file.
    """
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (INPUT_SIZE, INPUT_SIZE))
    # HWC -> CHW, then add the batch dimension.
    chw = torch.from_numpy(img.transpose(2, 0, 1)).float()
    return chw.to(device).unsqueeze(0)


def _save_class_masks(pred, out_dir):
    """Write one binary PNG per class label present in ``pred``.

    For every distinct value in ``pred`` a file ``class_<label>.png`` is
    written to ``out_dir`` (created if missing): 255 where the pixel belongs
    to that class, 0 elsewhere.
    """
    os.makedirs(out_dir, exist_ok=True)
    for label in np.unique(pred):
        binary_mask = (pred == label).astype(np.uint8) * 255
        mask_filename = os.path.join(out_dir, f'class_{label}.png')
        # cv2.imwrite reports failure through its return value only.
        if not cv2.imwrite(mask_filename, binary_mask):
            print(f"Warning: failed to write {mask_filename}")


def main():
    """Run segmentation on every ``*.jpg`` in ``TEST_DATA_DIR``.

    Builds the model described by the config's ``train.model`` section, loads
    the checkpoint at ``CHECKPOINT_PATH``, and for each image writes the
    per-class binary masks to ``OUTPUT_DIR/<image name>/``.
    """
    args = parser.parse_args()
    config = _load_config(args)
    train_cfg = config.train
    model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)

    # Use the GPU when there is one, but keep the script usable on CPU-only hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # Loading onto the CPU first works regardless of the device the
    # checkpoint was saved from.
    state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    # Run on the images under the test data directory; sorted for a stable order.
    img_paths = sorted(glob.glob(os.path.join(TEST_DATA_DIR, '*.jpg')))
    if not img_paths:
        print(f"No .jpg images found in {TEST_DATA_DIR}")

    for img_path in img_paths:
        print(img_path)
        input_tensor = _load_input_tensor(img_path, device)
        if input_tensor is None:
            print(f"Skipping {img_path}: could not be read as an image")
            continue

        with torch.no_grad():
            output, _ = model(input_tensor)
        # Per-pixel class index: argmax over dim 1, then drop the batch dimension.
        pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

        img_name = os.path.splitext(os.path.basename(img_path))[0]
        sub_output_dir = os.path.join(OUTPUT_DIR, img_name)
        _save_class_masks(pred, sub_output_dir)
        print(f"Processed {img_path}, saved masks to {sub_output_dir}")


if __name__ == '__main__':
    main()