"""YOLO semi-automatic annotation tool.

A small PyQt5 GUI that runs an Ultralytics YOLO model over a folder of
images and writes one Pascal VOC XML annotation file per image.

Pipeline (executed in a background thread):
    1. Copy every input image into a temp folder, expanding single-channel
       (grayscale) images to 3-channel BGR.
    2. Run YOLO prediction on each converted image.
    3. Write a VOC-format XML file for each image that was predicted on.
"""

import os
import sys
import xml.etree.ElementTree as ET
from xml.dom import minidom

import cv2
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (
    QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
    QLineEdit, QLabel, QFileDialog, QProgressBar, QTextEdit, QSlider,
    QCheckBox, QScrollArea, QGroupBox, QGridLayout
)
from ultralytics import YOLO

# Model class id -> label name written into the XML files and shown in the UI.
# Single source of truth shared by the worker thread and the main window.
CLASS_MAPPING = {
    0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19", 5: "EM190",
    6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202", 10: "EM203",
    11: "EM180", 12: "EM181"
}

# File extensions (lower-case) that are treated as input images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')


# ========== Background worker thread ==========
class DetectionThread(QThread):
    """Run conversion, detection and XML export without blocking the GUI.

    Signals:
        update_rgb_progress(int): percentage of the grayscale->BGR stage.
        update_model_progress(int): percentage of the detection stage.
        update_xml_progress(int): percentage of the XML export stage.
        log_message(str): human-readable status line for the log view.
        finished_signal(): emitted exactly once when ``run`` ends, whether
            it succeeded, found no images, or failed with an exception.
    """

    update_rgb_progress = pyqtSignal(int)
    update_model_progress = pyqtSignal(int)
    update_xml_progress = pyqtSignal(int)
    log_message = pyqtSignal(str)
    finished_signal = pyqtSignal()

    def __init__(self, model_path, input_dir, output_xml_dir, temp_dir, conf, classes):
        """Store the job parameters.

        Args:
            model_path: path of the YOLO weights file passed to ``YOLO()``.
            input_dir: folder that is scanned (non-recursively) for images.
            output_xml_dir: folder the XML files are written to (created
                if missing).
            temp_dir: folder the converted images are written to (created
                if missing).
            conf: confidence threshold forwarded to ``model.predict``.
            classes: list of class ids to keep; an empty list means "all".
        """
        super().__init__()
        self.model_path = model_path
        self.input_dir = input_dir
        self.output_xml_dir = output_xml_dir
        self.temp_dir = temp_dir
        self.conf = conf
        self.classes = classes
        # Class mapping (copied so per-instance edits cannot leak globally).
        self.class_mapping = dict(CLASS_MAPPING)

    def run(self):
        """Execute the three pipeline stages; never raises into Qt."""
        try:
            os.makedirs(self.output_xml_dir, exist_ok=True)
            os.makedirs(self.temp_dir, exist_ok=True)

            # Load the model.
            self.log_message.emit(f"加载模型:{self.model_path}")
            model = YOLO(self.model_path)
            self.log_message.emit("模型加载完成 ✅")

            # Collect the input images.
            imgs = [
                f for f in os.listdir(self.input_dir)
                if f.lower().endswith(IMAGE_EXTENSIONS)
            ]
            if not imgs:
                self.log_message.emit("输入目录中未找到图片 ❌")
                # finished_signal is emitted by the ``finally`` clause;
                # emitting it here as well would fire it twice.
                return

            converted = self._convert_images(imgs)
            detections = self._detect(model, converted)
            self._write_annotations(detections)

            self.log_message.emit("✅ 所有任务完成!")
        except Exception as e:
            # Top-level boundary of the worker thread: report and stop.
            self.log_message.emit(f"❌ 错误: {e}")
        finally:
            self.finished_signal.emit()

    def _convert_images(self, imgs):
        """Write a BGR copy of every readable image into ``temp_dir``.

        Returns the file names that were actually written, so later stages
        never try to open a temp file that does not exist.
        """
        self.log_message.emit("开始灰度图像转RGB...此过程时间较长,请等待")
        total = len(imgs)
        converted = []
        for i, img_name in enumerate(imgs, 1):
            img = cv2.imread(os.path.join(self.input_dir, img_name), cv2.IMREAD_UNCHANGED)
            if img is None:
                self.log_message.emit(f"[跳过] 无法读取图片:{img_name}")
            else:
                if img.ndim == 2 or img.shape[2] == 1:
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                if cv2.imwrite(os.path.join(self.temp_dir, img_name), img):
                    converted.append(img_name)
                else:
                    self.log_message.emit(f"[跳过] 无法写入临时图片:{img_name}")
            # Progress advances for skipped files too, so the bar reaches 100.
            self.update_rgb_progress.emit(int(i / total * 100))
        self.log_message.emit("RGB 转换完成 ✅")
        return converted

    def _detect(self, model, files):
        """Run YOLO on each converted file.

        Returns a list of ``(file_name, boxes)`` where ``boxes`` is a list of
        ``(class_id, xmin, ymin, xmax, ymax)`` integer tuples. Only these
        small tuples are kept; the result objects returned by ``predict``
        are dropped after each image.
        """
        self.log_message.emit("开始执行YOLO检测...")
        total = len(files)
        detections = []
        for i, file in enumerate(files, 1):
            temp_path = os.path.join(self.temp_dir, file)
            predictions = model.predict(
                source=temp_path,
                conf=self.conf,
                save=False,
                classes=self.classes or None  # empty selection -> all classes
            )
            boxes = []
            for result in predictions:
                for box in result.boxes:
                    xmin, ymin, xmax, ymax = map(int, box.xyxy[0].tolist())
                    boxes.append((int(box.cls[0]), xmin, ymin, xmax, ymax))
            detections.append((file, boxes))
            self.update_model_progress.emit(int(i / total * 100))
            self.log_message.emit(f"[检测完成] {file}")
        return detections

    def _write_annotations(self, detections):
        """Write one VOC XML file per detected image into ``output_xml_dir``."""
        self.log_message.emit("开始生成VOC格式XML...")
        total = len(detections)
        for j, (file_name, boxes) in enumerate(detections, 1):
            image_path = os.path.join(self.input_dir, file_name)
            # Re-read the original input image for the <size> element.
            img = cv2.imread(image_path)
            if img is None:
                self.log_message.emit(f"[跳过] 无法读取图片:{file_name}")
            else:
                height, width, depth = img.shape
                xml_str = self._build_voc_xml(image_path, width, height, depth, boxes)
                xml_path = os.path.join(
                    self.output_xml_dir, os.path.splitext(file_name)[0] + ".xml"
                )
                with open(xml_path, "w", encoding="utf-8") as f:
                    f.write(xml_str)
                self.log_message.emit(f"[生成完成] {xml_path}")
            self.update_xml_progress.emit(int(j / total * 100))

    def _build_voc_xml(self, image_path, width, height, depth, boxes):
        """Return the pretty-printed Pascal VOC XML text for one image."""
        annotation = ET.Element("annotation")
        ET.SubElement(annotation, "folder").text = os.path.basename(os.path.dirname(image_path))
        ET.SubElement(annotation, "filename").text = os.path.basename(image_path)
        ET.SubElement(annotation, "path").text = os.path.abspath(image_path)

        source = ET.SubElement(annotation, "source")
        ET.SubElement(source, "database").text = "Unknown"

        size = ET.SubElement(annotation, "size")
        ET.SubElement(size, "width").text = str(width)
        ET.SubElement(size, "height").text = str(height)
        ET.SubElement(size, "depth").text = str(depth)

        ET.SubElement(annotation, "segmented").text = "0"

        for cls, xmin, ymin, xmax, ymax in boxes:
            obj = ET.SubElement(annotation, "object")
            # Unknown class ids fall back to the numeric id as the name.
            ET.SubElement(obj, "name").text = self.class_mapping.get(cls, str(cls))
            ET.SubElement(obj, "pose").text = "Unspecified"
            ET.SubElement(obj, "truncated").text = "0"
            ET.SubElement(obj, "difficult").text = "0"
            bndbox = ET.SubElement(obj, "bndbox")
            ET.SubElement(bndbox, "xmin").text = str(xmin)
            ET.SubElement(bndbox, "ymin").text = str(ymin)
            ET.SubElement(bndbox, "xmax").text = str(xmax)
            ET.SubElement(bndbox, "ymax").text = str(ymax)

        pretty = minidom.parseString(
            ET.tostring(annotation, 'utf-8')
        ).toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
        # Drop the whitespace-only lines from the pretty-printed output.
        return "\n".join(line for line in pretty.splitlines() if line.strip())


# ========== Main window ==========
class MainWindow(QWidget):
    """Main window: path pickers, threshold, class filter, progress, log."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("YOLO 半自动标注工具(华东专用)")
        self.resize(750, 750)

        layout = QVBoxLayout()

        # Path selectors.
        self.model_edit = self._add_path_selector(layout, "模型路径:")
        self.input_edit = self._add_path_selector(layout, "输入图片文件夹:")
        self.output_edit = self._add_path_selector(layout, "输出XML文件夹:")
        self.temp_edit = self._add_path_selector(layout, "临时RGB图片文件夹:")

        # Confidence threshold: slider 0..100 maps to 0.00..1.00.
        conf_layout = QHBoxLayout()
        conf_layout.addWidget(QLabel("置信度阈值:"))
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(20)
        self.conf_value = QLabel("0.2")
        self.conf_slider.valueChanged.connect(
            lambda v: self.conf_value.setText(str(v / 100))
        )
        conf_layout.addWidget(self.conf_slider)
        conf_layout.addWidget(self.conf_value)
        layout.addLayout(conf_layout)

        # Class selection area: one checkbox per class, 4 per row.
        self.class_mapping = dict(CLASS_MAPPING)
        layout.addWidget(QLabel("选择要检测的类别(可多选):"))
        self.class_checkboxes = {}
        class_group = QGroupBox()
        grid = QGridLayout()
        for idx, (key, name) in enumerate(self.class_mapping.items()):
            cb = QCheckBox(f"{key}: {name}")
            self.class_checkboxes[key] = cb
            grid.addWidget(cb, idx // 4, idx % 4)
        class_group.setLayout(grid)
        scroll = QScrollArea()
        scroll.setWidget(class_group)
        scroll.setWidgetResizable(True)
        scroll.setFixedHeight(120)
        layout.addWidget(scroll)

        # Control buttons.
        btn_layout = QHBoxLayout()
        self.btn_load = QPushButton("加载模型")
        self.btn_start = QPushButton("开始检测")
        btn_layout.addWidget(self.btn_load)
        btn_layout.addWidget(self.btn_start)
        layout.addLayout(btn_layout)

        # Three progress bars, one per pipeline stage.
        self.progress_rgb = QProgressBar()
        self.progress_rgb.setFormat("RGB转换进度:%p%")
        self.progress_model = QProgressBar()
        self.progress_model.setFormat("模型检测进度:%p%")
        self.progress_xml = QProgressBar()
        self.progress_xml.setFormat("XML生成进度:%p%")
        layout.addWidget(self.progress_rgb)
        layout.addWidget(self.progress_model)
        layout.addWidget(self.progress_xml)

        # Log view.
        layout.addWidget(QLabel("日志输出:"))
        self.log_output = QTextEdit()
        self.log_output.setReadOnly(True)
        layout.addWidget(self.log_output)

        self.setLayout(layout)

        # Signal wiring.
        self.btn_load.clicked.connect(self.load_model)
        self.btn_start.clicked.connect(self.start_detection)

        self.model_loaded = False
        # Reference to the running DetectionThread. Not named ``thread``:
        # that would shadow the inherited QObject.thread() method.
        self.worker = None

    def _add_path_selector(self, layout, label_text):
        """Add a "label / line edit / browse button" row; return the edit.

        The browse button opens a folder dialog when ``label_text`` contains
        "文件夹", otherwise a model-file dialog.
        """
        hlayout = QHBoxLayout()
        label = QLabel(label_text)
        edit = QLineEdit()
        btn = QPushButton("浏览")
        hlayout.addWidget(label)
        hlayout.addWidget(edit)
        hlayout.addWidget(btn)
        layout.addLayout(hlayout)

        title = "选择文件夹" if "文件夹" in label_text else "选择模型文件"
        btn.clicked.connect(lambda: self._select_path(edit, title))
        return edit

    def _select_path(self, edit, title):
        """Open a file or folder dialog and put the chosen path in ``edit``.

        A title containing "模型" selects a model file; any other title
        selects a directory. A cancelled dialog leaves ``edit`` untouched.
        """
        if "模型" in title:
            path, _ = QFileDialog.getOpenFileName(
                self, title, "", "Model Files (*.pt *.onnx)"
            )
        else:
            path = QFileDialog.getExistingDirectory(self, title)
        if path:
            edit.setText(path)

    def load_model(self):
        """Validate the model path and unlock ``start_detection``.

        The model itself is loaded later, inside the worker thread.
        """
        path = self.model_edit.text().strip()
        # isfile (not exists): a directory is not a usable model file.
        if not os.path.isfile(path):
            self._log("❌ 模型文件不存在")
            return
        self.model_loaded = True
        self._log("✅ 模型路径加载成功,可开始检测")

    def start_detection(self):
        """Validate the form and launch a ``DetectionThread``."""
        if not self.model_loaded:
            self._log("⚠️ 请先加载模型路径")
            return

        model_path = self.model_edit.text().strip()
        input_dir = self.input_edit.text().strip()
        output_xml_dir = self.output_edit.text().strip()
        temp_dir = self.temp_edit.text().strip()
        conf = self.conf_slider.value() / 100

        selected_classes = [
            cid for cid, cb in self.class_checkboxes.items() if cb.isChecked()
        ]
        self._log(f"选中类别:{selected_classes if selected_classes else '全部'}")

        if not os.path.isdir(input_dir):
            self._log("❌ 输入文件夹不存在")
            return
        if not output_xml_dir:
            # Same check as for the temp folder; an empty path would only
            # fail later inside the worker when the folder is created.
            self._log("⚠️ 请选择输出XML文件夹")
            return
        if not temp_dir:
            self._log("⚠️ 请选择临时RGB图片文件夹")
            return

        self.progress_rgb.setValue(0)
        self.progress_model.setValue(0)
        self.progress_xml.setValue(0)

        self.worker = DetectionThread(
            model_path, input_dir, output_xml_dir, temp_dir, conf, selected_classes
        )
        self.worker.update_rgb_progress.connect(self.progress_rgb.setValue)
        self.worker.update_model_progress.connect(self.progress_model.setValue)
        self.worker.update_xml_progress.connect(self.progress_xml.setValue)
        self.worker.log_message.connect(self._log)
        self.worker.finished_signal.connect(self._finish)

        # Disable before starting so a second job cannot be queued.
        self.btn_start.setEnabled(False)
        self.worker.start()
        self._log("开始执行检测任务...")

    def _log(self, msg):
        """Append ``msg`` to the log view and scroll to the bottom."""
        self.log_output.append(msg)
        scrollbar = self.log_output.verticalScrollBar()
        scrollbar.setValue(scrollbar.maximum())

    def _finish(self):
        """Re-enable the start button once the worker has ended."""
        self.btn_start.setEnabled(True)
        self._log("任务完成 ✅")


# ========== Program entry point ==========
if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())