import os
import shutil
import sys
import tempfile
import xml.etree.ElementTree as ET
from xml.dom import minidom

import cv2
from ultralytics import YOLO
from PyQt5.QtWidgets import (
    QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
    QLineEdit, QLabel, QFileDialog, QProgressBar, QTextEdit, QSlider
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal


# ========== Background worker thread ==========
class DetectionThread(QThread):
    """Run YOLO detection on a folder of images and write VOC XML files.

    The work happens in three stages, all inside :meth:`run`:

    1. every image in ``input_dir`` is copied into a private temporary
       directory, with single-channel images expanded to three channels;
    2. the model is run on each converted image;
    3. one Pascal-VOC style XML file per image is written to
       ``output_xml_dir``.

    Signals:
        update_model_progress(int): detection progress, 0-100.
        update_xml_progress(int): XML generation progress, 0-100.
        log_message(str): human-readable status line.
        finished_signal(): emitted exactly once when ``run`` ends,
            whether it succeeded, failed, or found nothing to do.
    """

    update_model_progress = pyqtSignal(int)
    update_xml_progress = pyqtSignal(int)
    log_message = pyqtSignal(str)
    finished_signal = pyqtSignal()

    # File extensions (lower-case) treated as images in the input directory.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

    def __init__(self, model_path, input_dir, output_xml_dir, conf, classes):
        """Store the job parameters.

        Args:
            model_path: path handed to ``YOLO(...)``.
            input_dir: directory scanned (non-recursively) for images.
            output_xml_dir: directory that receives the XML files; created
                if missing.
            conf: confidence threshold forwarded to ``model.predict``.
            classes: class-id filter forwarded to ``model.predict``.
        """
        super().__init__()
        self.model_path = model_path
        self.input_dir = input_dir
        self.output_xml_dir = output_xml_dir
        self.conf = conf
        self.classes = classes

        # Class-id -> label mapping (editable). Ids missing from this table
        # are written to the XML as their numeric string.
        self.class_mapping = {
            0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
            5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
            10: "EM203", 11: "EM180", 12: "EM181"
        }

    def run(self):
        """Execute the whole pipeline; errors are reported via ``log_message``."""
        temp_dir = None
        try:
            os.makedirs(self.output_xml_dir, exist_ok=True)

            # Load the model.
            self.log_message.emit(f"加载模型:{self.model_path}")
            model = YOLO(self.model_path)
            self.log_message.emit("模型加载完成 ✅")

            # Collect the image file names.
            imgs = [
                f for f in os.listdir(self.input_dir)
                if f.lower().endswith(self.IMAGE_EXTENSIONS)
            ]
            if not imgs:
                self.log_message.emit("输入目录中未找到图片 ❌")
                # No explicit finished_signal here: the ``finally`` clause
                # below emits it, and emitting twice would run the GUI's
                # completion handler twice.
                return

            # A unique scratch directory avoids clobbering (or later
            # deleting) anything the user already keeps in the output folder.
            temp_dir = tempfile.mkdtemp(
                prefix="temp_rgb_images_", dir=self.output_xml_dir
            )

            converted = self._convert_images(imgs, temp_dir)
            detections = self._run_detection(model, imgs, converted, temp_dir)
            self._write_annotations(detections)

            self.log_message.emit("✅ 所有任务完成!")

        except Exception as e:
            # Top-level boundary of the worker thread: report, never raise
            # into Qt's thread machinery.
            self.log_message.emit(f"❌ 错误: {e}")
        finally:
            if temp_dir is not None:
                # Best-effort cleanup of the scratch images.
                shutil.rmtree(temp_dir, ignore_errors=True)
            self.finished_signal.emit()

    def _convert_images(self, imgs, temp_dir):
        """Write a 3-channel copy of each readable image into ``temp_dir``.

        Returns the set of file names that were written successfully, so
        later stages can skip the ones that could not be read or saved.
        """
        self.log_message.emit("开始灰度图像转RGB...此过程时间较长,请等待")
        converted = set()
        for img_name in imgs:
            img_path = os.path.join(self.input_dir, img_name)
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            if img is None:
                self.log_message.emit(f"[跳过] 无法读取图片: {img_name}")
                continue
            if len(img.shape) == 2 or img.shape[2] == 1:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            # imwrite reports failure through its return value, not an
            # exception, so it has to be checked explicitly.
            if cv2.imwrite(os.path.join(temp_dir, img_name), img):
                converted.add(img_name)
            else:
                self.log_message.emit(f"[跳过] 无法写入临时图片: {img_name}")
        return converted

    def _run_detection(self, model, imgs, converted, temp_dir):
        """Run the model on every converted image.

        Returns a list of ``(file_name, result)`` pairs. Progress is
        reported against the full image list so the bar still reaches 100%
        when some images were skipped.
        """
        self.log_message.emit("开始执行YOLO检测...")
        total_imgs = len(imgs)
        detections = []
        for i, file in enumerate(imgs, 1):
            if file in converted:
                predictions = model.predict(
                    source=os.path.join(temp_dir, file),
                    conf=self.conf,
                    save=False,
                    classes=self.classes
                )
                detections.extend((file, result) for result in predictions)
                self.log_message.emit(f"[检测完成] {file}")
            self.update_model_progress.emit(int(i / total_imgs * 100))
        return detections

    def _write_annotations(self, detections):
        """Write one VOC XML file per ``(file_name, result)`` pair."""
        self.log_message.emit("开始生成VOC格式XML...")
        total_results = len(detections)
        for j, (file_name, result) in enumerate(detections, 1):
            image_path = os.path.join(self.input_dir, file_name)
            # Default imread flags are kept from the original code; the
            # size/depth written to the XML come from this decoded array.
            img = cv2.imread(image_path)
            if img is not None:
                xml_str = self._build_annotation(
                    image_path, file_name, img.shape, result.boxes
                )
                xml_path = os.path.join(
                    self.output_xml_dir,
                    os.path.splitext(file_name)[0] + ".xml"
                )
                with open(xml_path, "w", encoding="utf-8") as f:
                    f.write(xml_str)
                self.log_message.emit(f"[生成完成] {xml_path}")
            # Emit even for skipped images so the bar is not left short.
            self.update_xml_progress.emit(int(j / total_results * 100))

    def _build_annotation(self, image_path, file_name, shape, boxes):
        """Return the pretty-printed VOC XML document for one image."""
        height, width, depth = shape

        annotation = ET.Element("annotation")
        ET.SubElement(annotation, "folder").text = os.path.basename(
            os.path.dirname(image_path)
        )
        ET.SubElement(annotation, "filename").text = file_name
        ET.SubElement(annotation, "path").text = os.path.abspath(image_path)

        source = ET.SubElement(annotation, "source")
        ET.SubElement(source, "database").text = "Unknown"

        size = ET.SubElement(annotation, "size")
        ET.SubElement(size, "width").text = str(width)
        ET.SubElement(size, "height").text = str(height)
        ET.SubElement(size, "depth").text = str(depth)

        ET.SubElement(annotation, "segmented").text = "0"

        for box in boxes:
            cls = int(box.cls[0])
            # Coordinates are truncated to integers.
            xmin, ymin, xmax, ymax = map(int, box.xyxy[0].tolist())

            obj = ET.SubElement(annotation, "object")
            ET.SubElement(obj, "name").text = self.class_mapping.get(cls, str(cls))
            ET.SubElement(obj, "pose").text = "Unspecified"
            ET.SubElement(obj, "truncated").text = "0"
            ET.SubElement(obj, "difficult").text = "0"

            bndbox = ET.SubElement(obj, "bndbox")
            ET.SubElement(bndbox, "xmin").text = str(xmin)
            ET.SubElement(bndbox, "ymin").text = str(ymin)
            ET.SubElement(bndbox, "xmax").text = str(xmax)
            ET.SubElement(bndbox, "ymax").text = str(ymax)

        rough_string = ET.tostring(annotation, 'utf-8')
        reparsed = minidom.parseString(rough_string)
        xml_str = reparsed.toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
        # Drop whitespace-only lines from the pretty-printed output.
        return "\n".join(line for line in xml_str.splitlines() if line.strip())


# ========== Main window ==========
class MainWindow(QWidget):
    """Form for choosing model/input/output paths and starting a detection job."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("YOLO 自动检测 + VOC生成工具")
        self.resize(700, 600)

        layout = QVBoxLayout()

        # Model path
        self.model_edit = self._add_path_selector(layout, "模型路径:")

        # Input image folder
        self.input_edit = self._add_path_selector(layout, "输入图片文件夹:")

        # Output XML folder
        self.output_edit = self._add_path_selector(layout, "输出XML文件夹:")

        # Confidence slider: integer 0-100, shown and used as value / 100.
        conf_layout = QHBoxLayout()
        conf_layout.addWidget(QLabel("置信度阈值:"))
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(20)
        self.conf_value = QLabel("0.2")
        self.conf_slider.valueChanged.connect(
            lambda v: self.conf_value.setText(str(v / 100))
        )
        conf_layout.addWidget(self.conf_slider)
        conf_layout.addWidget(self.conf_value)
        layout.addLayout(conf_layout)

        # Control buttons
        btn_layout = QHBoxLayout()
        self.btn_load = QPushButton("加载模型")
        self.btn_start = QPushButton("开始检测")
        btn_layout.addWidget(self.btn_load)
        btn_layout.addWidget(self.btn_start)
        layout.addLayout(btn_layout)

        # Two progress bars: detection and XML generation
        self.progress_model = QProgressBar()
        self.progress_model.setFormat("模型检测进度:%p%")
        self.progress_xml = QProgressBar()
        self.progress_xml.setFormat("XML生成进度:%p%")
        layout.addWidget(self.progress_model)
        layout.addWidget(self.progress_xml)

        # Log output
        layout.addWidget(QLabel("日志输出:"))
        self.log_output = QTextEdit()
        self.log_output.setReadOnly(True)
        layout.addWidget(self.log_output)

        self.setLayout(layout)

        # Signal wiring
        self.btn_load.clicked.connect(self.load_model)
        self.btn_start.clicked.connect(self.start_detection)

        self.model_loaded = False
        # Reference to the running DetectionThread; kept on the instance so
        # the thread object is not garbage-collected while it runs.
        self.worker = None

    def _add_path_selector(self, layout, label_text):
        """Append a "label / line edit / browse button" row and return the edit.

        The browse button opens a directory picker when ``label_text``
        contains "文件夹", otherwise a model-file picker.
        """
        hlayout = QHBoxLayout()
        label = QLabel(label_text)
        edit = QLineEdit()
        btn = QPushButton("浏览")
        hlayout.addWidget(label)
        hlayout.addWidget(edit)
        hlayout.addWidget(btn)
        layout.addLayout(hlayout)

        btn.clicked.connect(
            lambda: self._select_path(
                edit, "选择文件夹" if "文件夹" in label_text else "选择模型文件"
            )
        )
        return edit

    def _select_path(self, edit, title):
        """Open the picker matching ``title`` and copy the choice into ``edit``.

        A cancelled dialog yields an empty path and leaves ``edit`` untouched.
        """
        if "模型" in title:
            path, _ = QFileDialog.getOpenFileName(
                self, title, "", "Model Files (*.pt *.onnx)"
            )
        else:
            path = QFileDialog.getExistingDirectory(self, title)
        if path:
            edit.setText(path)

    def load_model(self):
        """Check that the model path exists and unlock detection.

        Only the path is validated here; the model itself is loaded inside
        the worker thread.
        """
        model_path = self.model_edit.text().strip()
        if not os.path.exists(model_path):
            self._log("❌ 模型文件不存在")
            return
        self.model_loaded = True
        self._log("✅ 模型路径加载成功,可开始检测")

    def start_detection(self):
        """Validate the form and launch a :class:`DetectionThread`."""
        if not self.model_loaded:
            self._log("⚠️ 请先加载模型路径")
            return

        model_path = self.model_edit.text().strip()
        input_dir = self.input_edit.text().strip()
        output_xml_dir = self.output_edit.text().strip()
        conf = float(self.conf_slider.value() / 100)
        classes = [11]  # can be replaced with a user-selected list

        if not os.path.exists(input_dir):
            self._log("❌ 输入文件夹不存在")
            return

        # An empty output path would only fail later inside the worker with
        # an opaque os.makedirs('') error, so reject it up front.
        if not output_xml_dir:
            self._log("❌ 请选择输出XML文件夹")
            return

        self.progress_model.setValue(0)
        self.progress_xml.setValue(0)

        # Stored as ``worker`` rather than ``thread``: an instance attribute
        # named ``thread`` would hide the inherited QObject.thread() method.
        self.worker = DetectionThread(
            model_path, input_dir, output_xml_dir, conf, classes
        )
        self.worker.update_model_progress.connect(self.progress_model.setValue)
        self.worker.update_xml_progress.connect(self.progress_xml.setValue)
        self.worker.log_message.connect(self._log)
        self.worker.finished_signal.connect(self._finish)

        # Disable the button before starting so a second job cannot be queued.
        self.btn_start.setEnabled(False)
        self.worker.start()
        self._log("开始执行检测任务...")

    def _log(self, msg):
        """Append ``msg`` to the log view and scroll to the bottom."""
        self.log_output.append(msg)
        scroll_bar = self.log_output.verticalScrollBar()
        scroll_bar.setValue(scroll_bar.maximum())

    def _finish(self):
        """Re-enable the start button once the worker reports completion."""
        self.btn_start.setEnabled(True)
        self._log("任务完成 ✅")


# ========== Program entry point ==========
if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())