130 lines
3.5 KiB
Python
130 lines
3.5 KiB
Python
import os
|
|
import cv2
|
|
import xml.etree.ElementTree as ET
|
|
from xml.dom import minidom
|
|
from ultralytics import YOLO
|
|
|
|
# ========== Configuration ==========
model_path = "epoch220.pt"          # weights file handed to YOLO(...)
input_dir = "test_images/"          # source images that get annotated
output_xml_dir = "annotations/"     # destination for the generated .xml files
temp_dir = "temp_rgb_images"        # 3-channel copies that are fed to the model

# ===== Custom class-id -> class-name mapping =====
# The position of each name in this list is the integer class id it maps to.
class_mapping = dict(enumerate([
    "EM14",    # 0
    "EM18",    # 1
    "EM17",    # 2
    "EM170",   # 3
    "EM19",    # 4
    "EM190",   # 5
    "EM20",    # 6
    "EM200",   # 7
    "EM201",   # 8
    "EM202",   # 9
    "EM203",   # 10
    "EM180",   # 11
    "EM181",   # 12
]))

# Make sure both working directories exist (no error if they already do).
for _work_dir in (temp_dir, output_xml_dir):
    os.makedirs(_work_dir, exist_ok=True)
|
|
|
|
# ========== Load the model ==========
# Construct the Ultralytics YOLO detector from the configured weights file.
weights_file = model_path
model = YOLO(weights_file)
|
|
|
|
|
|
|
|
|
|
# ========== Convert grayscale images to 3-channel BGR ==========
# Copy every supported image from input_dir into temp_dir. Single-channel
# images are expanded to 3 channels; all other images are written unchanged.
# NOTE(review): temp_dir is not emptied first, so files left over from a
# previous run are also picked up by the detection step.
for file in os.listdir(input_dir):
    if not file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
        continue

    input_path = os.path.join(input_dir, file)
    # IMREAD_UNCHANGED preserves the stored channel count, so grayscale images
    # can be detected below. Per the OpenCV docs it also ignores EXIF orientation.
    img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals unreadable/corrupt files by returning None rather
        # than raising; skip them instead of crashing on img.shape.
        print(f"[skip] unreadable image: {input_path}")
        continue

    if img.ndim == 2 or img.shape[2] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    temp_path = os.path.join(temp_dir, file)
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(temp_path, img):
        print(f"[warn] failed to write: {temp_path}")
|
|
|
|
# ========== Run detection ==========
# Predict on every image in temp_dir with a confidence threshold of 0.2,
# keeping only class id 11 (mapped to "EM180" by class_mapping).
# stream=True makes predict() return a generator: each image is processed as
# the XML loop consumes it, instead of every Results object (including its
# decoded image) being held in memory at once. A generator can only be
# iterated once.
results = model.predict(source=temp_dir, conf=0.2, save=False, classes=[11], stream=True)
|
|
|
|
# ========== Write standardized Pascal VOC XML ==========
# One XML file is written per prediction result into output_xml_dir, named
# after the image's stem.
# NOTE(review): images sharing a stem but differing in extension (a.jpg and
# a.png) both map to a.xml; the later result overwrites the earlier one.
for result in results:
    file_name = os.path.basename(result.path)
    image_path = os.path.join(input_dir, file_name)

    # temp_dir can still contain files from an earlier run whose originals are
    # no longer in input_dir; don't emit annotations for a missing image.
    if not os.path.isfile(image_path):
        print(f"[skip] source image not found: {image_path}")
        continue

    # Take the size from the image the model actually ran on, so <size> agrees
    # with the box coordinates. Re-decoding the original with cv2.imread's
    # default flag would apply EXIF orientation (the temp copy ignored it) and
    # could swap width/height for rotated photos.
    orig = result.orig_img
    height, width = orig.shape[:2]
    depth = orig.shape[2] if orig.ndim == 3 else 1

    annotation = ET.Element("annotation")
    ET.SubElement(annotation, "folder").text = os.path.basename(os.path.dirname(image_path))
    ET.SubElement(annotation, "filename").text = file_name
    ET.SubElement(annotation, "path").text = os.path.abspath(image_path)

    source = ET.SubElement(annotation, "source")
    ET.SubElement(source, "database").text = "Unknown"

    size = ET.SubElement(annotation, "size")
    ET.SubElement(size, "width").text = str(width)
    ET.SubElement(size, "height").text = str(height)
    ET.SubElement(size, "depth").text = str(depth)

    ET.SubElement(annotation, "segmented").text = "0"

    # One <object> per detected box.
    for box in result.boxes:
        cls = int(box.cls[0])
        # int() truncates the float pixel coordinates toward zero.
        xmin, ymin, xmax, ymax = map(int, box.xyxy[0].tolist())

        obj = ET.SubElement(annotation, "object")
        # Ids missing from class_mapping fall back to the numeric id as text.
        ET.SubElement(obj, "name").text = class_mapping.get(cls, str(cls))
        ET.SubElement(obj, "pose").text = "Unspecified"
        ET.SubElement(obj, "truncated").text = "0"
        ET.SubElement(obj, "difficult").text = "0"

        bndbox = ET.SubElement(obj, "bndbox")
        for tag, value in zip(("xmin", "ymin", "xmax", "ymax"), (xmin, ymin, xmax, ymax)):
            ET.SubElement(bndbox, tag).text = str(value)

    # ===== Pretty-print via minidom =====
    rough_string = ET.tostring(annotation, 'utf-8')
    pretty = minidom.parseString(rough_string).toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
    # toprettyxml can emit whitespace-only lines; drop every such line.
    xml_str = "\n".join(line for line in pretty.splitlines() if line.strip())

    # Save the file.
    xml_path = os.path.join(output_xml_dir, os.path.splitext(file_name)[0] + ".xml")
    with open(xml_path, "w", encoding="utf-8") as f:
        f.write(xml_str)

    print(f"[生成完成] {xml_path}")
|