# [Python] view plain text / copy code  (forum listing header, kept as a comment so the file parses)
# detect_custom.py
import sys
import argparse
import cv2
import numpy as np
import onnxruntime as ort
import torch
from pathlib import Path
# ====== Fast-path configuration ======
# Put the local YOLOv5 checkout first on the import path so that the
# `utils.general` import below resolves to it. NOTE(review): this is a
# machine-specific absolute path — confirm it exists on the target machine.
sys.path.insert(0, r"C:\Users\x2719\Desktop\易\大漠yolo\yolo\yolov5-7.0")
from utils.general import non_max_suppression, scale_boxes
# Hardware acceleration setup
cv2.setUseOptimized(True) # enable OpenCV's optimized code paths
cv2.setNumThreads(4) # cap OpenCV at 4 threads
def load_model(onnx_path):
    """Create a CPU-only ONNX Runtime inference session for ``onnx_path``.

    The session is configured with 4 intra-op threads and the CPU memory
    arena enabled.

    Args:
        onnx_path: Path of the ``.onnx`` model file.

    Returns:
        The ``ort.InferenceSession`` built with ``CPUExecutionProvider``.
    """
    sess_opts = ort.SessionOptions()
    sess_opts.enable_cpu_mem_arena = True  # let ORT pool its CPU allocations
    sess_opts.intra_op_num_threads = 4  # worker threads for parallel compute
    session = ort.InferenceSession(
        onnx_path,
        sess_options=sess_opts,
        providers=['CPUExecutionProvider'],
    )
    return session
def detect(img_path, model_path, conf=0.4, iou=0.4, imgsz=640):
    """Run a YOLOv5 ONNX model on a single image and return its detections.

    Args:
        img_path: Path of the image file. The bytes are read with
            ``np.fromfile`` and decoded in memory, which lets non-ASCII
            paths work where ``cv2.imread`` would fail.
        model_path: Path of the ONNX model; loaded through ``load_model``.
        conf: Confidence threshold forwarded to ``non_max_suppression``.
        iou: IoU threshold forwarded to ``non_max_suppression``.
        imgsz: Side length of the square network input. Must match the
            size the model was exported with (640 by default).

    Returns:
        A ``(detections, shape)`` tuple. ``detections`` is the NumPy array
        produced by NMS for the image, with the first four columns mapped
        back to original-image pixel coordinates (``save_results`` unpacks
        each row as ``x1, y1, x2, y2, conf, cls``). ``shape`` is the shape
        of the decoded image.

    Raises:
        ValueError: If the file is empty or cannot be decoded as an image.
    """
    raw = np.fromfile(img_path, dtype=np.uint8)
    # IMREAD_COLOR always yields an 8-bit, 3-channel BGR image. With
    # IMREAD_UNCHANGED a grayscale/alpha/16-bit file would produce a 2-D or
    # 4-channel array and break the HWC -> CHW conversion / model input.
    # An empty buffer is rejected up front: imdecode raises on it.
    img0 = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
    if img0 is None:
        raise ValueError("图像读取失败: 可能路径含中文或特殊字符")

    # Letterbox: resize so the longer side equals imgsz, then pad to a square.
    h, w = img0.shape[:2]
    scale = imgsz / max(h, w)
    # round + clamp: int() truncation could give imgsz-1 from float error,
    # or a 0-pixel side for extreme aspect ratios (cv2.resize rejects 0).
    new_w = min(imgsz, max(1, round(w * scale)))
    new_h = min(imgsz, max(1, round(h * scale)))
    img = cv2.resize(img0, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    pad_w, pad_h = imgsz - new_w, imgsz - new_h
    left, top = pad_w // 2, pad_h // 2
    img = cv2.copyMakeBorder(img, top, pad_h - top, left, pad_w - left,
                             cv2.BORDER_CONSTANT, value=(114, 114, 114))

    # BGR -> RGB (YOLOv5 models expect RGB), HWC -> CHW, add the batch axis,
    # then scale to [0, 1] as contiguous float32.
    blob = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1)[None], dtype=np.float32)
    blob /= 255.0

    # Inference
    session = load_model(model_path)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: blob})

    # Post-processing (no gradients needed)
    with torch.no_grad():
        pred = torch.tensor(outputs[0], device='cpu')
        pred = non_max_suppression(pred, conf, iou, classes=None)[0]
        if pred is not None and len(pred):
            # Pass the exact gain and (x, y) padding that were applied above;
            # otherwise scale_boxes would have to guess the padding.
            pred[:, :4] = scale_boxes(
                (imgsz, imgsz), pred[:, :4], img0.shape[:2],
                ratio_pad=((scale, scale), (left, top)),
            ).round()
    return pred.numpy() if pred is not None else np.array([]), img0.shape
def save_results(detections, output_path, img_shape):
    """Write detections to ``output_path`` as normalized label lines.

    Each detection becomes one line of the form
    ``<class> <x_center> <y_center> <width> <height> <confidence>``, where
    the four box values are divided by the image width/height and every
    float is printed with six decimals.

    Args:
        detections: Iterable of 6-item rows ``(x1, y1, x2, y2, conf, cls)``
            with corner coordinates in pixels.
        output_path: Destination text file; it is overwritten.
        img_shape: Image shape whose first two entries are height and width.
    """
    img_h, img_w = img_shape[:2]

    def _to_line(row):
        # Corner box -> normalized centre/size representation.
        left, top, right, bottom, score, label = row
        cx = ((left + right) / 2) / img_w
        cy = ((top + bottom) / 2) / img_h
        box_w = (right - left) / img_w
        box_h = (bottom - top) / img_h
        return f"{int(label)} {cx:.6f} {cy:.6f} {box_w:.6f} {box_h:.6f} {score:.6f}\n"

    # Build every line first so the file is only opened once all rows format cleanly.
    lines = [_to_line(row) for row in detections]
    with open(output_path, 'w', buffering=8192) as out:  # 8 KB write buffer
        out.writelines(lines)
if __name__ == "__main__":
    # Minimal command-line interface: model, image and output paths are
    # mandatory; the two thresholds default to 0.4.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--img', required=True)
    parser.add_argument('--output', required=True)
    parser.add_argument('--conf', type=float, default=0.4)
    parser.add_argument('--iou', type=float, default=0.4)
    args = parser.parse_args()
    # Deliberately no try/except wrapper: failures surface as a traceback.
    detections, img_shape = detect(args.img, args.model, args.conf, args.iou)
    if detections.size:
        save_results(detections, args.output, img_shape)
        print(f"检测完成: {len(detections)}个目标")
    else:
        # Truncate rather than touch(): touch() would leave the detections
        # of a previous run in an already-existing output file, making
        # "no targets" indistinguishable from the old result.
        Path(args.output).write_text("")
        print("未检测到目标")