"""Run a YOLO model over a folder of images and write Pascal VOC XML files.

Pipeline:
1. Every supported image in ``input_dir`` is copied into ``temp_dir``;
   single-channel (grayscale) images are expanded to 3-channel BGR first.
2. The model is run on ``temp_dir``.
3. For every prediction result, one VOC-style ``.xml`` annotation file is
   written to ``output_xml_dir``.
"""

import os
import xml.etree.ElementTree as ET
from xml.dom import minidom

import cv2
from ultralytics import YOLO

# ========== Configuration ==========
model_path = "epoch220.pt"
input_dir = "test_images/"
output_xml_dir = "annotations/"
temp_dir = "temp_rgb_images"

# ===== Custom mapping from class id to label name =====
class_mapping = {
    0: "EM14",
    1: "EM18",
    2: "EM17",
    3: "EM170",
    4: "EM19",
    5: "EM190",
    6: "EM20",
    7: "EM200",
    8: "EM201",
    9: "EM202",
    10: "EM203",
    11: "EM180",
    12: "EM181",
}

# File extensions (compared case-insensitively) that are treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')


def _add_text(parent, tag, text):
    """Append a ``<tag>text</tag>`` child to *parent* and return the child."""
    child = ET.SubElement(parent, tag)
    child.text = text
    return child


def prepare_images(src_dir, dst_dir):
    """Write a copy of each supported image in *src_dir* into *dst_dir*.

    Images that load as single-channel are converted to 3-channel BGR before
    being written. Files that cannot be decoded or written are reported and
    skipped instead of aborting the whole run.

    Returns:
        The number of images successfully written to *dst_dir*.
    """
    written = 0
    for file in os.listdir(src_dir):
        if not file.lower().endswith(IMAGE_EXTENSIONS):
            continue

        input_path = os.path.join(src_dir, file)
        img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None.
            print(f"[跳过] 无法读取图像: {input_path}")
            continue

        if len(img.shape) == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        temp_path = os.path.join(dst_dir, file)
        if cv2.imwrite(temp_path, img):
            written += 1
        else:
            print(f"[跳过] 无法写入图像: {temp_path}")
    return written


def build_annotation(file_name, image_path, shape, boxes):
    """Build the VOC ``<annotation>`` element for one image.

    Args:
        file_name: Base name of the image file.
        image_path: Path of the image inside the input directory; used for
            the ``<folder>`` and ``<path>`` entries.
        shape: ``(height, width, depth)`` of the image.
        boxes: Iterable of detection boxes; each item must expose ``cls[0]``
            (class id) and ``xyxy[0]`` (xmin, ymin, xmax, ymax).

    Returns:
        The root ``xml.etree.ElementTree.Element``.
    """
    height, width, depth = shape

    annotation = ET.Element("annotation")
    _add_text(annotation, "folder",
              os.path.basename(os.path.dirname(image_path)))
    _add_text(annotation, "filename", file_name)
    _add_text(annotation, "path", os.path.abspath(image_path))

    source = ET.SubElement(annotation, "source")
    _add_text(source, "database", "Unknown")

    size = ET.SubElement(annotation, "size")
    _add_text(size, "width", str(width))
    _add_text(size, "height", str(height))
    _add_text(size, "depth", str(depth))

    _add_text(annotation, "segmented", "0")

    # One <object> element per detection box.
    for box in boxes:
        cls = int(box.cls[0])
        # Coordinates are truncated to integers, not rounded.
        xmin, ymin, xmax, ymax = map(int, box.xyxy[0].tolist())

        obj = ET.SubElement(annotation, "object")
        # Fall back to the numeric id when the class has no configured name.
        _add_text(obj, "name", class_mapping.get(cls, str(cls)))
        _add_text(obj, "pose", "Unspecified")
        _add_text(obj, "truncated", "0")
        _add_text(obj, "difficult", "0")

        bndbox = ET.SubElement(obj, "bndbox")
        _add_text(bndbox, "xmin", str(xmin))
        _add_text(bndbox, "ymin", str(ymin))
        _add_text(bndbox, "xmax", str(xmax))
        _add_text(bndbox, "ymax", str(ymax))

    return annotation


def to_pretty_xml(annotation):
    """Serialize *annotation* as tab-indented XML text without blank lines."""
    rough_string = ET.tostring(annotation, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    xml_str = reparsed.toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
    # Drop whitespace-only lines from the pretty-printed output.
    return "\n".join(line for line in xml_str.splitlines() if line.strip())


def main():
    """Convert the inputs, run detection and write one XML file per image."""
    os.makedirs(temp_dir, exist_ok=True)
    os.makedirs(output_xml_dir, exist_ok=True)

    # ========== Load the model ==========
    model = YOLO(model_path)

    # ========== Grayscale -> 3-channel copies ==========
    if prepare_images(input_dir, temp_dir) == 0:
        # Nothing usable was produced, so there is nothing to annotate.
        print(f"[跳过] 未找到可用图像: {input_dir}")
        return

    # ========== Run detection ==========
    # classes=[11] limits the predictions to class id 11.
    results = model.predict(source=temp_dir, conf=0.2, save=False, classes=[11])

    # ========== Write the VOC XML files ==========
    for result in results:
        file_name = os.path.basename(result.path)
        image_path = os.path.join(input_dir, file_name)

        img = cv2.imread(image_path)
        if img is None:
            # temp_dir may still contain files from an earlier run whose
            # originals are no longer present (or readable) in input_dir.
            print(f"[跳过] 无法读取图像: {image_path}")
            continue

        annotation = build_annotation(file_name, image_path, img.shape,
                                      result.boxes)
        xml_str = to_pretty_xml(annotation)

        # Save next to the other annotations, named after the image stem.
        xml_path = os.path.join(output_xml_dir,
                                os.path.splitext(file_name)[0] + ".xml")
        with open(xml_path, "w", encoding="utf-8") as f:
            f.write(xml_str)

        print(f"[生成完成] {xml_path}")


if __name__ == "__main__":
    main()