|

分享源码
界面截图: |
- |
是否带模块: |
- |
备注说明: |
- |
from ultralytics import YOLO
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
# Load the trained weights through the Ultralytics YOLO class.
# (If the import fails, install the package first: pip install ultralytics)
model = YOLO('best.pt')

# Directory holding the input images. If the path contains backslashes,
# write it as a raw string (r'...') or double the backslashes.
img_dir = Path('images')

# Collect the images to process: every .jpg first, then .png, then .bmp.
img_files = [
    found
    for pattern in ('*.jpg', '*.png', '*.bmp')
    for found in img_dir.glob(pattern)
]

# Directory that receives the annotated result images (created if missing).
output_img_dir = Path('output_img')
output_img_dir.mkdir(exist_ok=True)
# Font used for the order numbers. Loaded once, outside the image loop.
try:
    font = ImageFont.truetype("arial.ttf", 32)
except OSError:
    # The font file could not be read: font=None lets draw.text fall back to
    # Pillow's built-in default font.
    font = None

# Run detection on every input image and save an annotated copy.
for img_path in img_files:
    results = model(img_path)
    for r in results:
        # Results without any detection are skipped; no image is saved for them.
        if not (hasattr(r, 'boxes') and r.boxes is not None and len(r.boxes) > 0):
            continue

        # Describe every detection as
        # (cx, cy, idx, [x1, y1, x2, y2], class_id), where (cx, cy) is the
        # centre of the box and idx is its position in r.boxes.
        all_targets = []
        for idx, box in enumerate(r.boxes):
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            cx = (x1 + x2) / 2
            cy = (y1 + y2) / 2
            class_id = int(box.cls[0].item()) if hasattr(box, 'cls') else -1
            all_targets.append((cx, cy, idx, [x1, y1, x2, y2], class_id))

        # Take the 3 detections with the largest centre y (the lowest ones in
        # the image) and order them left to right by centre x.
        bottom3 = sorted(all_targets, key=lambda t: -t[1])[:3]
        bottom3 = sorted(bottom3, key=lambda t: t[0])

        # For each bottom detection, pick the not-yet-used detection that lies
        # above it (smaller centre y), has the same class id, and whose centre
        # x is closest to the bottom detection's centre x.
        matched_top = []
        used_top_idx = set()
        for bx, by, _bidx, _bbox, bcls in bottom3:
            candidates = [
                t for t in all_targets
                if t[1] < by and t[4] == bcls and t[2] not in used_top_idx
            ]
            if not candidates:
                continue
            best_top = min(candidates, key=lambda t: abs(t[0] - bx))
            matched_top.append(best_top)
            used_top_idx.add(best_top[2])

        # Draw the detection boxes, then number the matched upper detections.
        # r.plot() yields a BGR-ordered array (per the Ultralytics docs), so
        # the channel axis is reversed to get the RGB order PIL expects.
        img_with_boxes = r.plot()
        img_pil = Image.fromarray(img_with_boxes[..., ::-1])
        draw = ImageDraw.Draw(img_pil)
        for order, (_cx, _cy, _idx, (x1, y1, _x2, _y2), _cls) in enumerate(matched_top, start=1):
            # The number is placed 30 px above the top-left corner of the box.
            draw.text((x1, y1 - 30), f"{order}", fill=(0, 0, 255), font=font)

        output_img_path = output_img_dir / img_path.name
        img_pil.save(output_img_path)
        print(f"{img_path.name} 检测结果图片已保存到 {output_img_path}")
接口代码
from ultralytics import YOLO
from flask import Flask, request, jsonify
import base64
import tempfile
import os
# Load the model once at import time so it is not reloaded on every request.
model = YOLO('best.pt')
def get_matched_points(img_path):
    """Detect targets in an image and return the centres of 3 matched targets.

    The 3 detections with the largest centre y (the lowest ones in the image)
    are taken and ordered left to right. For each of them, the detection that
    lies above it, has the same class id, has not been matched yet, and is
    closest in centre x is selected.

    Args:
        img_path: Path of the image file passed to the module-level ``model``.

    Returns:
        ``[(cx1, cy1), (cx2, cy2), (cx3, cy3)]`` with integer (truncated)
        centre coordinates of the matched upper detections, in the
        left-to-right order of their bottom counterparts. An empty list is
        returned when no result yields exactly 3 matches.
    """
    results = model(img_path)
    for r in results:
        if not (hasattr(r, 'boxes') and r.boxes is not None and len(r.boxes) > 0):
            continue

        # Each entry: (cx, cy, idx, [x1, y1, x2, y2], class_id).
        all_targets = []
        for idx, box in enumerate(r.boxes):
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            cx = (x1 + x2) / 2
            cy = (y1 + y2) / 2
            class_id = int(box.cls[0].item()) if hasattr(box, 'cls') else -1
            all_targets.append((cx, cy, idx, [x1, y1, x2, y2], class_id))

        # Lowest 3 detections (largest centre y), then sorted by centre x.
        bottom3 = sorted(all_targets, key=lambda t: -t[1])[:3]
        bottom3 = sorted(bottom3, key=lambda t: t[0])

        matched_top = []
        used_top_idx = set()  # r.boxes indices already assigned to a bottom target
        for bx, by, _bidx, _bbox, bcls in bottom3:
            candidates = [
                t for t in all_targets
                if t[1] < by and t[4] == bcls and t[2] not in used_top_idx
            ]
            if not candidates:
                continue
            best_top = min(candidates, key=lambda t: abs(t[0] - bx))
            matched_top.append(best_top)
            used_top_idx.add(best_top[2])

        if len(matched_top) == 3:
            return [(int(cx), int(cy)) for (cx, cy, *_) in matched_top]
    return []  # no result produced 3 matched points
app = Flask(__name__)


@app.route('/get_points', methods=['POST'])
def get_points():
    """Handle POST /get_points: run detection on a base64-encoded image.

    The request body is expected to be a JSON object with an ``img_base64``
    field holding the base64-encoded image bytes.

    Returns:
        * ``{"points": [...]}`` with the result of ``get_matched_points`` on
          success.
        * ``{"error": "No img_base64 provided"}`` with status 400 when the
          body is not a JSON object or the field is missing/empty.
        * ``{"error": "<message>"}`` with status 500 when decoding, writing
          the temporary file, inference, or cleanup raises.
    """
    # silent=True yields None instead of raising on a missing/invalid JSON
    # body; anything that is not a JSON object is treated as "no image".
    data = request.get_json(silent=True)
    img_base64 = data.get('img_base64', '') if isinstance(data, dict) else ''
    if not img_base64:
        return jsonify({'error': 'No img_base64 provided'}), 400

    # Decode the base64 payload and hand it to the model via a temporary file.
    try:
        img_bytes = base64.b64decode(img_base64)
        # delete=False keeps the file on disk after it is closed, so it can be
        # reopened by path for inference; it is removed explicitly below.
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.bmp')
        tmp_img_path = tmp_file.name
        try:
            with tmp_file:
                tmp_file.write(img_bytes)
            points = get_matched_points(tmp_img_path)
        finally:
            # Always remove the temporary file, even when writing or
            # inference fails, so failed requests do not leak files.
            os.remove(tmp_img_path)
        return jsonify({'points': points})
    except Exception as e:
        return jsonify({'error': str(e)}), 500
# Start the development server when this file is run as a script.
if __name__ == '__main__':
    # NOTE(review): debug=True turns on Flask's debugger and reloader, and
    # host='0.0.0.0' listens on all network interfaces. Confirm this
    # combination is only used on a trusted development machine.
    app.run(host='0.0.0.0', port=1214, debug=True)
测试代码
import requests
import base64
# Read a local image and encode its bytes as a base64 string.
with open("images/00ad74af33c2700b5db5b7bf6adbddf9.bmp", "rb") as f:
    img_base64 = base64.b64encode(f.read()).decode()

# Build the JSON request payload.
data = {
    "img_base64": img_base64
}

# Send the POST request to the local /get_points endpoint.
response = requests.post("http://127.0.0.1:1214/get_points", json=data)

# Print the decoded JSON response.
print(response.json())
模型下载地址:https://www.123865.com/s/oY0iVv-WXqyh
|
-
|