Files
autoanno/autoanno.py
wangjialiang e0626adfb6 First
2025-11-12 17:04:47 +08:00

130 lines
3.5 KiB
Python

import os
import cv2
import xml.etree.ElementTree as ET
from xml.dom import minidom
from ultralytics import YOLO
# ========== Configuration ==========
# Weights file handed to YOLO, the directory of images to annotate, the
# directory that receives the generated XML files, and a scratch directory
# for the image copies that detection runs on.
model_path = "epoch220.pt"
input_dir = "test_images/"
output_xml_dir = "annotations/"
temp_dir = "temp_rgb_images"

# ===== Custom class-name mapping =====
# Each label's position in the tuple is the class index it is looked up by.
class_mapping = dict(enumerate((
    "EM14",   # 0
    "EM18",   # 1
    "EM17",   # 2
    "EM170",  # 3
    "EM19",   # 4
    "EM190",  # 5
    "EM20",   # 6
    "EM200",  # 7
    "EM201",  # 8
    "EM202",  # 9
    "EM203",  # 10
    "EM180",  # 11
    "EM181",  # 12
)))

# Create the working directories up front; existing ones are left as they are.
for _directory in (temp_dir, output_xml_dir):
    os.makedirs(_directory, exist_ok=True)
# ========== Load model ==========
# Build the YOLO model from the weights file named by model_path.
model = YOLO(model_path)
# ========== Grayscale -> 3-channel conversion ==========
# Copy every supported image from input_dir into temp_dir, expanding
# single-channel images to 3-channel BGR on the way. Detection is later run
# on temp_dir rather than on input_dir.
for file in os.listdir(input_dir):
    # Only files with a recognised image extension are processed.
    if not file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff')):
        continue
    input_path = os.path.join(input_dir, file)
    # IMREAD_UNCHANGED keeps the stored channel count, so grayscale files can
    # be told apart from colour ones.
    img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals an unreadable/corrupt file by returning None
        # instead of raising; skip it rather than crash on img.shape.
        print(f"[跳过] 无法读取图像: {input_path}")
        continue
    if img.ndim == 2 or img.shape[2] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    temp_path = os.path.join(temp_dir, file)
    # cv2.imwrite reports failure through its return value; surface it so a
    # missing temp image is not silently dropped from detection.
    if not cv2.imwrite(temp_path, img):
        print(f"[警告] 无法写入临时图像: {temp_path}")
# ========== Run detection ==========
# Predict on every image in temp_dir. Detections are limited to class
# index 11 (labelled "EM180" in class_mapping) with a confidence threshold
# of 0.2; save=False asks the model not to write its own output files.
predict_options = {
    "source": temp_dir,
    "conf": 0.2,
    "save": False,
    "classes": [11],
}
results = model.predict(**predict_options)
# ========== Generate Pascal VOC style XML ==========
def _add_text_child(parent, tag, text):
    """Append a ``<tag>`` child containing *text* to *parent* and return it."""
    child = ET.SubElement(parent, tag)
    child.text = text
    return child


def _pretty_xml(root):
    """Return *root* serialized as tab-indented XML text with blank lines removed."""
    rough_string = ET.tostring(root, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    xml_str = reparsed.toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
    # Drop any whitespace-only lines left by the pretty-printer.
    return "\n".join(line for line in xml_str.splitlines() if line.strip())


# One XML file is written per prediction result, named after the image.
for result in results:
    file_name = os.path.basename(result.path)
    # Detection ran on the copy in temp_dir; the annotation refers to the
    # original image of the same name in input_dir.
    image_path = os.path.join(input_dir, file_name)
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None when the file is missing or unreadable,
        # e.g. a leftover image in temp_dir that is no longer in input_dir.
        # Skip it instead of crashing on img.shape.
        print(f"[跳过] 无法读取图像: {image_path}")
        continue
    # Read with cv2.imread's default flags, so the shape unpacks into three
    # values (rows, columns, channels).
    height, width, depth = img.shape

    annotation = ET.Element("annotation")
    _add_text_child(annotation, "folder", os.path.basename(os.path.dirname(image_path)))
    _add_text_child(annotation, "filename", file_name)
    _add_text_child(annotation, "path", os.path.abspath(image_path))
    source = ET.SubElement(annotation, "source")
    _add_text_child(source, "database", "Unknown")
    size = ET.SubElement(annotation, "size")
    _add_text_child(size, "width", str(width))
    _add_text_child(size, "height", str(height))
    _add_text_child(size, "depth", str(depth))
    _add_text_child(annotation, "segmented", "0")

    # One <object> element per detected box.
    for box in result.boxes:
        cls = int(box.cls[0])
        # Corner coordinates are truncated to integers.
        xmin, ymin, xmax, ymax = map(int, box.xyxy[0].tolist())
        obj = ET.SubElement(annotation, "object")
        # Class indices missing from class_mapping fall back to the index itself.
        _add_text_child(obj, "name", class_mapping.get(cls, str(cls)))
        _add_text_child(obj, "pose", "Unspecified")
        _add_text_child(obj, "truncated", "0")
        _add_text_child(obj, "difficult", "0")
        bndbox = ET.SubElement(obj, "bndbox")
        _add_text_child(bndbox, "xmin", str(xmin))
        _add_text_child(bndbox, "ymin", str(ymin))
        _add_text_child(bndbox, "xmax", str(xmax))
        _add_text_child(bndbox, "ymax", str(ymax))

    # Pretty-print via minidom and save next to the other annotations.
    xml_str = _pretty_xml(annotation)
    xml_path = os.path.join(output_xml_dir, os.path.splitext(file_name)[0] + ".xml")
    with open(xml_path, "w", encoding="utf-8") as f:
        f.write(xml_str)
    print(f"[生成完成] {xml_path}")