diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..653f067
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,63 @@
+import argparse
+from omegaconf import OmegaConf
+from extend_sam import get_model
+import numpy as np
+import torch
+import os
+import glob
+from torchvision import transforms
+from PIL import Image
+import cv2
+
+supported_tasks = ['detection', 'semantic_seg', 'instance_seg']
+parser = argparse.ArgumentParser()
+parser.add_argument('--task_name', default='semantic_seg', type=str)
+parser.add_argument('--cfg', default=None, type=str)
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    task_name = args.task_name
+    if args.cfg is not None:
+        config = OmegaConf.load(args.cfg)
+    else:
+        assert task_name in supported_tasks, "Please input a supported task name."
+        config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name))
+    test_cfg = config.test
+    train_cfg = config.train
+    model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)
+    # The images under ./weed_data are used as an example here.
+    test_data_path = '../test_data'
+    output_dir = './result'
+    model.load_state_dict(torch.load('experiment/model/semantic_sam/lam_vit_b_01ec64.pth'))
+    model.eval()
+    model.cuda()
+    preprocess = transforms.Compose([
+        transforms.Resize((1024, 1024)),
+        transforms.ToTensor(),
+    ])
+    for img_path in glob.glob(os.path.join(test_data_path, '*.jpg')):
+        print(img_path)
+        # Read the image and convert it to RGB.
+        # img = Image.open(img_path).convert("RGB")
+        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        # Preprocess the image and add a batch dimension.
+        # Note: unlike the ToTensor() path above, this keeps pixel values in [0, 255].
+        # input_tensor = preprocess(img).unsqueeze(0).cuda()
+        img = cv2.resize(img, (1024, 1024))
+        input_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().cuda().unsqueeze(0)
+        with torch.no_grad():
+            output, _ = model(input_tensor)
+        # Per-pixel class prediction: (1, C, H, W) -> (H, W)
+        pred = torch.argmax(output, dim=1)
+        pred = pred.squeeze(0).cpu().numpy()
+        img_name = os.path.splitext(os.path.basename(img_path))[0]
+        sub_output_dir = os.path.join(output_dir, img_name)
+        os.makedirs(sub_output_dir, exist_ok=True)
+        unique_labels = np.unique(pred)
+        for label in unique_labels:
+            # Build a binary mask: 255 where the pixel belongs to this class, 0 otherwise.
+            binary_mask = (pred == label).astype(np.uint8) * 255
+            mask_filename = os.path.join(sub_output_dir, f'class_{label}.png')
+            cv2.imwrite(mask_filename, binary_mask)
+        print(f"Processed {img_path}, saved masks to {sub_output_dir}")
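
A minimal usage sketch, assuming the config files under ./config/ and the checkpoint at experiment/model/semantic_sam/lam_vit_b_01ec64.pth referenced in the script are present locally (both paths come from the diff above and are not verified here):

    python infer.py --task_name semantic_seg
    python infer.py --cfg ./config/semantic_seg.yaml

For each *.jpg found under test_data_path, the script writes one subdirectory under ./result containing a binary mask PNG (class_<label>.png) per class present in the prediction.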