Files
autoanno/autoannogui.py
wangjialiang e0626adfb6 First
2025-11-12 17:04:47 +08:00

280 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
import sys
import tempfile
import xml.etree.ElementTree as ET
from xml.dom import minidom

import cv2
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (
    QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
    QLineEdit, QLabel, QFileDialog, QProgressBar, QTextEdit, QSlider
)
from ultralytics import YOLO
# ========== Background worker thread ==========
class DetectionThread(QThread):
    """Run YOLO detection on a folder of images and write one Pascal VOC XML per image.

    The work is done in three stages on this worker thread:

    1. every image in ``input_dir`` is read with ``cv2.IMREAD_UNCHANGED``;
       single-channel (grayscale) images are expanded to 3-channel BGR, and
       each readable image is written to a per-run scratch folder created
       inside ``output_xml_dir``;
    2. ``model.predict`` is run on each scratch copy
       (progress reported through ``update_model_progress``);
    3. one VOC annotation file per image is written to ``output_xml_dir``
       (progress reported through ``update_xml_progress``).

    Signals:
        update_model_progress(int): detection progress, 0-100.
        update_xml_progress(int): XML generation progress, 0-100.
        log_message(str): human-readable status line for the UI log.
        finished_signal(): emitted exactly once when ``run`` returns, whether
            it succeeded, found nothing to do, or failed.
    """

    update_model_progress = pyqtSignal(int)
    update_xml_progress = pyqtSignal(int)
    log_message = pyqtSignal(str)
    finished_signal = pyqtSignal()

    # Lower-case file extensions treated as images in ``input_dir``.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
    # Prefix of the per-run scratch folder created inside ``output_xml_dir``.
    TEMP_DIR_PREFIX = "temp_rgb_images_"

    def __init__(self, model_path, input_dir, output_xml_dir, conf, classes):
        """
        Args:
            model_path: weights file passed to ``YOLO(...)``.
            input_dir: folder scanned (non-recursively) for images.
            output_xml_dir: folder receiving the ``.xml`` files; created if missing.
            conf: confidence threshold forwarded to ``model.predict``.
            classes: class-id filter forwarded to ``model.predict``.
        """
        super().__init__()
        self.model_path = model_path
        self.input_dir = input_dir
        self.output_xml_dir = output_xml_dir
        self.conf = conf
        self.classes = classes
        # Model class id -> label written into <object><name> (edit to match the model).
        # Ids missing from this table are written as their numeric string.
        self.class_mapping = {
            0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
            5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
            10: "EM203", 11: "EM180", 12: "EM181"
        }

    def run(self):
        """Worker entry point, executed on the thread by ``QThread.start``.

        Every exception is caught and reported through ``log_message``. The
        scratch folder is always removed, and ``finished_signal`` is always
        emitted exactly once (from the ``finally`` clause only).
        """
        temp_dir = None
        try:
            os.makedirs(self.output_xml_dir, exist_ok=True)
            # Unique per run: never collides with, and never deletes, a folder
            # the user already keeps in the output directory.
            temp_dir = tempfile.mkdtemp(prefix=self.TEMP_DIR_PREFIX, dir=self.output_xml_dir)

            # Load the model.
            self.log_message.emit(f"加载模型:{self.model_path}")
            model = YOLO(self.model_path)
            self.log_message.emit("模型加载完成 ✅")

            imgs = self._list_images()
            if not imgs:
                # No explicit finished_signal here: ``finally`` emits it.
                self.log_message.emit("输入目录中未找到图片 ❌")
                return

            # Grayscale -> 3-channel conversion into the scratch folder.
            self.log_message.emit("开始灰度图像转RGB...此过程时间较长,请等待")
            prepared = self._prepare_images(imgs, temp_dir)
            if not prepared:
                self.log_message.emit("没有可用于检测的图片 ❌")
                return

            # ========= Detection stage =========
            self.log_message.emit("开始执行YOLO检测...")
            detections = self._detect(model, prepared, temp_dir)

            # ========= XML generation stage =========
            self.log_message.emit("开始生成VOC格式XML...")
            self._write_all_xml(detections)

            self.log_message.emit("✅ 所有任务完成!")
        except Exception as e:
            self.log_message.emit(f"❌ 错误: {e}")
        finally:
            if temp_dir is not None:
                shutil.rmtree(temp_dir, ignore_errors=True)
            self.finished_signal.emit()

    def _list_images(self):
        """Return the sorted names of entries in ``input_dir`` that have an image extension."""
        return sorted(
            name for name in os.listdir(self.input_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def _prepare_images(self, imgs, temp_dir):
        """Write a 3-channel copy of every readable image in ``imgs`` into ``temp_dir``.

        Returns the names (a subset of ``imgs``, in the same order) whose copy
        was written. Images that cannot be read or written are logged and
        skipped, so that a missing scratch file cannot abort the detection stage.
        """
        prepared = []
        for img_name in imgs:
            img = cv2.imread(os.path.join(self.input_dir, img_name), cv2.IMREAD_UNCHANGED)
            if img is None:
                self.log_message.emit(f"[跳过] 无法读取图片: {img_name}")
                continue
            # IMREAD_UNCHANGED keeps grayscale as H x W (or H x W x 1); expand it
            # to BGR. Images that already have several channels are copied as-is.
            if img.ndim == 2 or img.shape[2] == 1:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            try:
                written = cv2.imwrite(os.path.join(temp_dir, img_name), img)
            except cv2.error:
                written = False
            if not written:
                self.log_message.emit(f"[跳过] 无法写入临时图片: {img_name}")
                continue
            prepared.append(img_name)
        return prepared

    def _detect(self, model, prepared, temp_dir):
        """Run ``model.predict`` on every prepared image in ``temp_dir``.

        Returns a list of ``(file_name, (height, width), boxes)`` tuples, where
        ``boxes`` is a list of ``(class_id, xmin, ymin, xmax, ymax)`` ints.
        Only these plain values are kept, rather than the ultralytics
        ``Results`` objects (which also reference the decoded image), so memory
        does not grow with full images as more files are processed.
        """
        detections = []
        total = len(prepared)
        for i, file_name in enumerate(prepared, 1):
            results = model.predict(
                source=os.path.join(temp_dir, file_name),
                conf=self.conf, save=False, classes=self.classes,
            )
            for result in results:
                boxes = []
                for box in result.boxes:
                    # Box corners are truncated to int pixel coordinates.
                    xmin, ymin, xmax, ymax = map(int, box.xyxy[0].tolist())
                    boxes.append((int(box.cls[0]), xmin, ymin, xmax, ymax))
                # orig_shape is the (height, width) of the image the boxes refer to.
                detections.append((file_name, tuple(result.orig_shape[:2]), boxes))
            self.update_model_progress.emit(int(i / total * 100))
            self.log_message.emit(f"[检测完成] {file_name}")
        return detections

    def _write_all_xml(self, detections):
        """Write ``<stem>.xml`` into ``output_xml_dir`` for each detection entry, reporting progress."""
        total = len(detections)
        for j, (file_name, (height, width), boxes) in enumerate(detections, 1):
            xml_path = os.path.join(self.output_xml_dir, os.path.splitext(file_name)[0] + ".xml")
            with open(xml_path, "w", encoding="utf-8") as f:
                f.write(self._build_voc_xml(file_name, width, height, boxes))
            self.update_xml_progress.emit(int(j / total * 100))
            self.log_message.emit(f"[生成完成] {xml_path}")

    def _build_voc_xml(self, file_name, width, height, boxes):
        """Return a tab-indented Pascal VOC annotation document for one image.

        ``boxes`` holds ``(class_id, xmin, ymin, xmax, ymax)`` tuples; class ids
        are mapped to names through ``self.class_mapping``.
        """
        image_path = os.path.join(self.input_dir, file_name)
        annotation = ET.Element("annotation")
        ET.SubElement(annotation, "folder").text = os.path.basename(os.path.dirname(image_path))
        ET.SubElement(annotation, "filename").text = file_name
        ET.SubElement(annotation, "path").text = os.path.abspath(image_path)
        source = ET.SubElement(annotation, "source")
        ET.SubElement(source, "database").text = "Unknown"
        size = ET.SubElement(annotation, "size")
        ET.SubElement(size, "width").text = str(width)
        ET.SubElement(size, "height").text = str(height)
        # Always 3: grayscale inputs are expanded to BGR before detection, and
        # cv2.imread's default IMREAD_COLOR decoding always yields 3 channels.
        ET.SubElement(size, "depth").text = "3"
        ET.SubElement(annotation, "segmented").text = "0"
        for cls, xmin, ymin, xmax, ymax in boxes:
            obj = ET.SubElement(annotation, "object")
            ET.SubElement(obj, "name").text = self.class_mapping.get(cls, str(cls))
            ET.SubElement(obj, "pose").text = "Unspecified"
            ET.SubElement(obj, "truncated").text = "0"
            ET.SubElement(obj, "difficult").text = "0"
            bndbox = ET.SubElement(obj, "bndbox")
            ET.SubElement(bndbox, "xmin").text = str(xmin)
            ET.SubElement(bndbox, "ymin").text = str(ymin)
            ET.SubElement(bndbox, "xmax").text = str(xmax)
            ET.SubElement(bndbox, "ymax").text = str(ymax)
        rough_string = ET.tostring(annotation, 'utf-8')
        pretty = minidom.parseString(rough_string).toprettyxml(
            indent="\t", encoding="utf-8"
        ).decode("utf-8")
        # toprettyxml leaves whitespace-only lines behind; drop them.
        return "\n".join(line for line in pretty.splitlines() if line.strip())
# ========== Main window ==========
class MainWindow(QWidget):
    """Main window of the YOLO auto-annotation tool.

    The user picks a model file, an input image folder and an output XML
    folder, sets the confidence threshold, and starts a DetectionThread. The
    window shows the thread's two progress bars and its log messages.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("YOLO 自动检测 + VOC生成工具")
        self.resize(700, 600)
        layout = QVBoxLayout()

        # Path selectors: model file, input image folder, output XML folder.
        self.model_edit = self._add_path_selector(layout, "模型路径:")
        self.input_edit = self._add_path_selector(layout, "输入图片文件夹:")
        self.output_edit = self._add_path_selector(layout, "输出XML文件夹")

        # Confidence slider: integer 0-100, displayed and used as value / 100.
        conf_layout = QHBoxLayout()
        conf_layout.addWidget(QLabel("置信度阈值:"))
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(20)
        self.conf_value = QLabel("0.2")
        self.conf_slider.valueChanged.connect(lambda v: self.conf_value.setText(str(v / 100)))
        conf_layout.addWidget(self.conf_slider)
        conf_layout.addWidget(self.conf_value)
        layout.addLayout(conf_layout)

        # Control buttons.
        btn_layout = QHBoxLayout()
        self.btn_load = QPushButton("加载模型")
        self.btn_start = QPushButton("开始检测")
        btn_layout.addWidget(self.btn_load)
        btn_layout.addWidget(self.btn_start)
        layout.addLayout(btn_layout)

        # One progress bar per DetectionThread stage (detection, XML generation).
        self.progress_model = QProgressBar()
        self.progress_model.setFormat("模型检测进度:%p%")
        self.progress_xml = QProgressBar()
        self.progress_xml.setFormat("XML生成进度%p%")
        layout.addWidget(self.progress_model)
        layout.addWidget(self.progress_xml)

        # Log output.
        layout.addWidget(QLabel("日志输出:"))
        self.log_output = QTextEdit()
        self.log_output.setReadOnly(True)
        layout.addWidget(self.log_output)
        self.setLayout(layout)

        # Signal wiring.
        self.btn_load.clicked.connect(self.load_model)
        self.btn_start.clicked.connect(self.start_detection)
        self.model_loaded = False
        # Worker of the current / most recent run. Deliberately not named
        # ``thread`` so it does not shadow QObject.thread().
        self._worker = None

    def _add_path_selector(self, layout, label_text):
        """Add a "label + line edit + browse button" row to ``layout`` and return the line edit.

        Labels containing "文件夹" (folder) open a directory picker; the others
        open a model-file picker.
        """
        hlayout = QHBoxLayout()
        label = QLabel(label_text)
        edit = QLineEdit()
        btn = QPushButton("浏览")
        hlayout.addWidget(label)
        hlayout.addWidget(edit)
        hlayout.addWidget(btn)
        layout.addLayout(hlayout)
        btn.clicked.connect(lambda: self._select_path(edit, "选择文件夹" if "文件夹" in label_text else "选择模型文件"))
        return edit

    def _select_path(self, edit, title):
        """Show a file dialog (model file if ``title`` mentions "模型", else a folder) and store the choice in ``edit``."""
        if "模型" in title:
            path, _ = QFileDialog.getOpenFileName(self, title, "", "Model Files (*.pt *.onnx)")
        else:
            path = QFileDialog.getExistingDirectory(self, title)
        if path:  # empty when the dialog was cancelled
            edit.setText(path)

    def load_model(self):
        """Validate the model path; detection can start only after this succeeds.

        The model is not loaded here; DetectionThread loads it.
        """
        model_path = self.model_edit.text().strip()
        # isfile, not exists: a directory is not a usable weights file.
        if not os.path.isfile(model_path):
            self.model_loaded = False
            self._log("❌ 模型文件不存在")
            return
        self.model_loaded = True
        self._log("✅ 模型路径加载成功,可开始检测")

    def start_detection(self):
        """Validate the inputs and start a DetectionThread wired to the progress bars and log."""
        if not self.model_loaded:
            self._log("⚠️ 请先加载模型路径")
            return
        model_path = self.model_edit.text().strip()
        input_dir = self.input_edit.text().strip()
        output_xml_dir = self.output_edit.text().strip()
        conf = self.conf_slider.value() / 100
        classes = [11]  # class ids to detect; change to select others
        # The model path may have been edited after it was validated.
        if not os.path.isfile(model_path):
            self._log("❌ 模型文件不存在")
            return
        if not os.path.isdir(input_dir):
            self._log("❌ 输入文件夹不存在")
            return
        if not output_xml_dir:
            self._log("❌ 请选择输出XML文件夹")
            return
        self.progress_model.setValue(0)
        self.progress_xml.setValue(0)
        self._worker = DetectionThread(model_path, input_dir, output_xml_dir, conf, classes)
        self._worker.update_model_progress.connect(self.progress_model.setValue)
        self._worker.update_xml_progress.connect(self.progress_xml.setValue)
        self._worker.log_message.connect(self._log)
        self._worker.finished_signal.connect(self._finish)
        # Disable before starting so a second run cannot be launched meanwhile.
        self.btn_start.setEnabled(False)
        self._worker.start()
        self._log("开始执行检测任务...")

    def _log(self, msg):
        """Append ``msg`` to the log view and scroll to the bottom."""
        self.log_output.append(msg)
        self.log_output.verticalScrollBar().setValue(self.log_output.verticalScrollBar().maximum())

    def _finish(self):
        """Slot for DetectionThread.finished_signal: re-enable the start button."""
        self.btn_start.setEnabled(True)
        self._log("任务完成 ✅")
# ========== Program entry point ==========
def main():
    """Create the Qt application, show the main window and return the event-loop exit code."""
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    return qt_app.exec_()


if __name__ == "__main__":
    sys.exit(main())