generated from yuanbiao/python_templates
export onnx model
This commit is contained in:
parent
6062510fb4
commit
44adc2fbd3
|
|
@ -0,0 +1,106 @@
|
|||
import argparse
|
||||
import time
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
from extend_sam import get_model
|
||||
import onnxruntime
|
||||
import numpy as np
|
||||
import os
|
||||
import glob
|
||||
import cv2
|
||||
|
||||
supported_tasks = ['detection', 'semantic_seg', 'instance_seg']
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task_name', default='semantic_seg', type=str)
|
||||
parser.add_argument('--cfg', default=None, type=str)
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
task_name = args.task_name
|
||||
if args.cfg is not None:
|
||||
config = OmegaConf.load(args.cfg)
|
||||
else:
|
||||
assert task_name in supported_tasks, "Please input the supported task name."
|
||||
config = OmegaConf.load("./config/{task_name}.yaml".format(task_name=args.task_name))
|
||||
|
||||
train_cfg = config.train
|
||||
model = get_model(model_name=train_cfg.model.sam_name, **train_cfg.model.params)
|
||||
|
||||
# 加载模型权重
|
||||
model.load_state_dict(torch.load('experiment/model/semantic_sam/model.pth', map_location='cpu'))
|
||||
model.eval()
|
||||
|
||||
# 准备示例输入
|
||||
dummy_input = torch.randn(1, 3, 1024, 1024)
|
||||
|
||||
# 确保模型和输入在同一设备上
|
||||
device = torch.device('cpu')
|
||||
model = model.to(device)
|
||||
dummy_input = dummy_input.to(device)
|
||||
|
||||
# 导出ONNX模型
|
||||
torch.onnx.export(
|
||||
model, # 要转换的模型
|
||||
dummy_input, # 模型的输入
|
||||
# "semantic_sam.onnx", # 导出的ONNX文件名
|
||||
"best_multi.onnx", # 导出的ONNX文件名
|
||||
export_params=True, # 存储训练好的参数权重
|
||||
opset_version=13, # ONNX算子集版本
|
||||
do_constant_folding=True, # 是否执行常量折叠优化
|
||||
input_names=['input'], # 输入节点的名称
|
||||
output_names=['output'], # 输出节点的名称
|
||||
dynamic_axes={ # 动态尺寸
|
||||
'input': {0: 'batch_size'},
|
||||
'output': {0: 'batch_size'}
|
||||
}
|
||||
)
|
||||
|
||||
# print("ONNX model exported successfully to semantic_sam.onnx")
|
||||
print("ONNX model exported successfully to best_multi.onnx")
|
||||
|
||||
# 加载ONNX模型进行推理
|
||||
# ort_session = onnxruntime.InferenceSession("semantic_sam.onnx")
|
||||
ort_session = onnxruntime.InferenceSession("./best_multi.onnx")
|
||||
|
||||
# 设置输入输出路径
|
||||
test_data_path = '../test_data'
|
||||
output_dir = './result2'
|
||||
|
||||
# 遍历测试图片进行推理
|
||||
for img_path in glob.glob(os.path.join(test_data_path,'*.jpg')):
|
||||
print(img_path)
|
||||
# 读取图片并转换为RGB
|
||||
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 预处理图片
|
||||
img = cv2.resize(img, (1024, 1024))
|
||||
input_tensor = img.transpose(2, 0, 1).astype(np.float32)
|
||||
input_tensor = np.expand_dims(input_tensor, axis=0)
|
||||
|
||||
# ONNX推理
|
||||
start_time = time.time() # Start time measurement
|
||||
ort_inputs = {ort_session.get_inputs()[0].name: input_tensor}
|
||||
output = ort_session.run(None, ort_inputs)[0]
|
||||
end_time = time.time() # End time measurement
|
||||
|
||||
# 计算推理时间
|
||||
inference_time = end_time - start_time
|
||||
print(f"Inference time for {img_path}: {inference_time:.4f} seconds")
|
||||
|
||||
# 后处理
|
||||
pred = np.argmax(output, axis=1)
|
||||
pred = pred.squeeze(0)
|
||||
|
||||
# 保存结果
|
||||
img_name = os.path.splitext(os.path.basename(img_path))[0]
|
||||
sub_output_dir = os.path.join(output_dir, img_name)
|
||||
os.makedirs(sub_output_dir, exist_ok=True)
|
||||
|
||||
unique_labels = np.unique(pred)
|
||||
for label in unique_labels:
|
||||
binary_mask = (pred == label).astype(np.uint8) * 255
|
||||
mask_filename = os.path.join(sub_output_dir, f'class_{label}.png')
|
||||
cv2.imwrite(mask_filename, binary_mask)
|
||||
|
||||
print(f"Processed {img_path}, saved masks to {sub_output_dir}")
|
||||
Loading…
Reference in New Issue