generated from yuanbiao/python_templates
inference.py
This commit is contained in:
parent
44adc2fbd3
commit
3ce7a9704f
|
|
@ -0,0 +1,64 @@
|
||||||
|
import argparse
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from datasets import get_dataset,get_lettuce_dataset
|
||||||
|
from losses import get_losses
|
||||||
|
from extend_sam import get_model, get_optimizer, get_scheduler, get_opt_pamams, get_runner
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
supported_tasks = ['detection', 'semantic_seg', 'instance_seg']
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task_name', default='semantic_seg', type=str)
|
||||||
|
parser.add_argument('--cfg', default=None, type=str)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parser.parse_args()
|
||||||
|
task_name = args.task_name
|
||||||
|
if args.cfg is not None:
|
||||||
|
config = OmegaConf.load(args.cfg)
|
||||||
|
else:
|
||||||
|
assert task_name in supported_tasks, "Please input the supported task name."
|
||||||
|
config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name))
|
||||||
|
test_cfg = config.test
|
||||||
|
train_cfg = config.train
|
||||||
|
model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)
|
||||||
|
#我以./weed_data下的图片为例
|
||||||
|
test_data_path = '../test_data'
|
||||||
|
output_dir = './result'
|
||||||
|
model.load_state_dict(torch.load('experiment/model/semantic_sam/lam_vit_b_01ec64.pth'))
|
||||||
|
model.eval()
|
||||||
|
model.cuda()
|
||||||
|
preprocess = transforms.Compose([
|
||||||
|
transforms.Resize((1024, 1024)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
])
|
||||||
|
for img_path in glob.glob(os.path.join(test_data_path,'*.jpg')):
|
||||||
|
print(img_path)
|
||||||
|
# 读取图片并转换为 RGB
|
||||||
|
#img = Image.open(img_path).convert("RGB")
|
||||||
|
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
# 对图片进行预处理,并增加 batch 维度
|
||||||
|
#input_tensor = preprocess(img).unsqueeze(0).cuda()
|
||||||
|
img = cv2.resize(img, (1024, 1024))
|
||||||
|
input_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().cuda().unsqueeze(0)
|
||||||
|
with torch.no_grad():
|
||||||
|
output,_ = model(input_tensor)
|
||||||
|
pred = torch.argmax(output, dim=1)
|
||||||
|
pred = pred.squeeze(0).cpu().numpy()
|
||||||
|
img_name = os.path.splitext(os.path.basename(img_path))[0]
|
||||||
|
sub_output_dir = os.path.join(output_dir, img_name)
|
||||||
|
os.makedirs(sub_output_dir, exist_ok=True)
|
||||||
|
unique_labels = np.unique(pred)
|
||||||
|
for label in unique_labels:
|
||||||
|
# 生成二值 mask:像素属于该类别则为 255,否则为 0
|
||||||
|
binary_mask = (pred == label).astype(np.uint8) * 255
|
||||||
|
mask_filename = os.path.join(sub_output_dir, f'class_{label}.png')
|
||||||
|
cv2.imwrite(mask_filename, binary_mask)
|
||||||
|
print(f"Processed {img_path}, saved masks to {sub_output_dir}")
|
||||||
Loading…
Reference in New Issue