# 280 lines, 11 KiB, Python
import os
import sys
import tempfile
import xml.etree.ElementTree as ET
from xml.dom import minidom

import cv2
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (
    QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
    QLineEdit, QLabel, QFileDialog, QProgressBar, QTextEdit, QSlider
)
from ultralytics import YOLO
|
||
|
||
|
||
# ========== 后台任务线程 ==========
|
||
class DetectionThread(QThread):
    """Worker thread: run YOLO detection on a folder and write VOC-style XML.

    ``run`` works in three stages:

    1. Every image found in ``input_dir`` is copied into a temporary
       directory created inside ``output_xml_dir``; single-channel images
       are expanded to 3-channel BGR first.
    2. The model is run on each temporary copy.
    3. One XML annotation file per detected image is written to
       ``output_xml_dir``.

    Signals:
        update_model_progress(int): percentage of images passed through
            the model.
        update_xml_progress(int): percentage of XML files written.
        log_message(str): one line of text for the UI log.
        finished_signal(): emitted exactly once when ``run`` exits,
            whether it succeeded, failed, or found nothing to do.
    """

    update_model_progress = pyqtSignal(int)
    update_xml_progress = pyqtSignal(int)
    log_message = pyqtSignal(str)
    finished_signal = pyqtSignal()

    # File suffixes (compared case-insensitively) treated as input images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

    def __init__(self, model_path, input_dir, output_xml_dir, conf, classes):
        """Store the job parameters; no work is done until the thread starts.

        Args:
            model_path: Path of the model file handed to ``YOLO``.
            input_dir: Directory that is scanned for images.
            output_xml_dir: Directory that receives the XML files (created
                if missing).
            conf: Confidence threshold forwarded to ``model.predict``.
            classes: Class filter forwarded to ``model.predict``.
        """
        super().__init__()
        self.model_path = model_path
        self.input_dir = input_dir
        self.output_xml_dir = output_xml_dir
        self.conf = conf
        self.classes = classes

        # Class index -> label written into <name>; edit to match the model.
        self.class_mapping = {
            0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
            5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
            10: "EM203", 11: "EM180", 12: "EM181",
        }

    def run(self):
        """Execute the whole pipeline and report through the signals.

        Any exception is caught and reported via ``log_message``.
        ``finished_signal`` is emitted only from the ``finally`` clause so
        that every exit path (including the early returns) fires it once.
        """
        try:
            os.makedirs(self.output_xml_dir, exist_ok=True)

            # Load the model.
            self.log_message.emit(f"加载模型:{self.model_path}")
            model = YOLO(self.model_path)
            self.log_message.emit("模型加载完成 ✅")

            # Collect the image list (sorted so logs are reproducible).
            imgs = sorted(
                name for name in os.listdir(self.input_dir)
                if name.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            if not imgs:
                self.log_message.emit("输入目录中未找到图片 ❌")
                return

            # The temporary copies are only needed while the model runs, so
            # the directory is removed as soon as detection is finished
            # (and also when an exception unwinds through the ``with``).
            with tempfile.TemporaryDirectory(
                prefix="temp_rgb_images_", dir=self.output_xml_dir
            ) as temp_dir:
                self.log_message.emit("开始灰度图像转RGB...此过程时间较长,请等待")
                converted = self._convert_images(imgs, temp_dir)
                if not converted:
                    self.log_message.emit("没有可处理的图片 ❌")
                    return

                self.log_message.emit("开始执行YOLO检测...")
                detections = self._detect(model, converted, temp_dir)

            self.log_message.emit("开始生成VOC格式XML...")
            self._write_annotations(detections)

            self.log_message.emit("✅ 所有任务完成!")
        except Exception as e:
            self.log_message.emit(f"❌ 错误: {e}")
        finally:
            self.finished_signal.emit()

    def _convert_images(self, image_names, temp_dir):
        """Copy images into ``temp_dir`` as BGR; return the names that worked.

        Images that cannot be decoded or written are logged and left out of
        the returned list, so later stages never look for a temporary file
        that was not created.
        """
        converted = []
        for name in image_names:
            img = cv2.imread(os.path.join(self.input_dir, name), cv2.IMREAD_UNCHANGED)
            if img is None:
                self.log_message.emit(f"[跳过] 无法读取图片: {name}")
                continue
            if img.ndim == 2 or img.shape[2] == 1:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            if not cv2.imwrite(os.path.join(temp_dir, name), img):
                self.log_message.emit(f"[跳过] 无法写入临时图片: {name}")
                continue
            converted.append(name)
        return converted

    def _detect(self, model, image_names, temp_dir):
        """Run the model on each temporary image.

        Returns:
            A list of ``(file_name, (height, width), boxes)`` records, where
            ``boxes`` is a list of ``(cls, xmin, ymin, xmax, ymax)`` integer
            tuples. Only these plain values are kept, so the result objects
            (and whatever they reference) can be released after each image.
        """
        detections = []
        total = len(image_names)
        for i, name in enumerate(image_names, 1):
            results = model.predict(
                source=os.path.join(temp_dir, name),
                conf=self.conf, save=False, classes=self.classes,
            )
            for result in results:
                boxes = []
                for box in result.boxes:
                    cls = int(box.cls[0])
                    xmin, ymin, xmax, ymax = map(int, box.xyxy[0].tolist())
                    boxes.append((cls, xmin, ymin, xmax, ymax))
                # Size of the image the boxes refer to, as reported by the
                # model result (height, width). Using it keeps <size>
                # consistent with the coordinates and avoids decoding the
                # file again in the XML stage.
                height, width = result.orig_shape[:2]
                detections.append((name, (int(height), int(width)), boxes))

            self.update_model_progress.emit(int(i / total * 100))
            self.log_message.emit(f"[检测完成] {name}")
        return detections

    def _write_annotations(self, detections):
        """Write one XML file per detection record and report progress."""
        total = len(detections)
        for j, (file_name, (height, width), boxes) in enumerate(detections, 1):
            annotation = self._build_annotation(file_name, width, height, boxes)
            xml_path = os.path.join(
                self.output_xml_dir, os.path.splitext(file_name)[0] + ".xml"
            )
            with open(xml_path, "w", encoding="utf-8") as f:
                f.write(self._to_pretty_xml(annotation))

            self.update_xml_progress.emit(int(j / total * 100))
            self.log_message.emit(f"[生成完成] {xml_path}")

    def _build_annotation(self, file_name, width, height, boxes):
        """Build the ``<annotation>`` element tree for one image.

        ``folder`` and ``path`` refer to the original file in ``input_dir``,
        not to the temporary copy.
        """
        image_path = os.path.join(self.input_dir, file_name)

        annotation = ET.Element("annotation")
        ET.SubElement(annotation, "folder").text = os.path.basename(
            os.path.dirname(image_path)
        )
        ET.SubElement(annotation, "filename").text = file_name
        ET.SubElement(annotation, "path").text = os.path.abspath(image_path)

        source = ET.SubElement(annotation, "source")
        ET.SubElement(source, "database").text = "Unknown"

        size = ET.SubElement(annotation, "size")
        ET.SubElement(size, "width").text = str(width)
        ET.SubElement(size, "height").text = str(height)
        # The previous implementation took depth from a default-flag
        # cv2.imread, which yields a 3-channel image; keep reporting 3.
        ET.SubElement(size, "depth").text = "3"

        ET.SubElement(annotation, "segmented").text = "0"

        for cls, xmin, ymin, xmax, ymax in boxes:
            obj = ET.SubElement(annotation, "object")
            # Unknown class indices fall back to the bare number.
            ET.SubElement(obj, "name").text = self.class_mapping.get(cls, str(cls))
            ET.SubElement(obj, "pose").text = "Unspecified"
            ET.SubElement(obj, "truncated").text = "0"
            ET.SubElement(obj, "difficult").text = "0"

            bndbox = ET.SubElement(obj, "bndbox")
            ET.SubElement(bndbox, "xmin").text = str(xmin)
            ET.SubElement(bndbox, "ymin").text = str(ymin)
            ET.SubElement(bndbox, "xmax").text = str(xmax)
            ET.SubElement(bndbox, "ymax").text = str(ymax)

        return annotation

    @staticmethod
    def _to_pretty_xml(annotation):
        """Serialize ``annotation`` as tab-indented XML without blank lines."""
        rough_string = ET.tostring(annotation, 'utf-8')
        reparsed = minidom.parseString(rough_string)
        xml_str = reparsed.toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
        return "\n".join(line for line in xml_str.splitlines() if line.strip())
|
||
|
||
|
||
# ========== 主界面 ==========
|
||
class MainWindow(QWidget):
    """Main window: path pickers, confidence slider, progress bars and a log.

    The user picks a model file, an input image folder and an output XML
    folder, confirms the model path with the load button, and then starts a
    ``DetectionThread`` whose signals drive the progress bars and the log.
    """

    def __init__(self):
        """Build the widgets and wire up the button signals."""
        super().__init__()
        self.setWindowTitle("YOLO 自动检测 + VOC生成工具")
        self.resize(700, 600)

        layout = QVBoxLayout()

        # Model path, input folder and output folder selectors.
        self.model_edit = self._add_path_selector(layout, "模型路径:")
        self.input_edit = self._add_path_selector(layout, "输入图片文件夹:")
        self.output_edit = self._add_path_selector(layout, "输出XML文件夹:")

        # Confidence slider: integer 0-100, shown and used as value / 100.
        conf_layout = QHBoxLayout()
        conf_layout.addWidget(QLabel("置信度阈值:"))
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(20)
        self.conf_value = QLabel("0.2")
        self.conf_slider.valueChanged.connect(
            lambda v: self.conf_value.setText(str(v / 100))
        )
        conf_layout.addWidget(self.conf_slider)
        conf_layout.addWidget(self.conf_value)
        layout.addLayout(conf_layout)

        # Control buttons.
        btn_layout = QHBoxLayout()
        self.btn_load = QPushButton("加载模型")
        self.btn_start = QPushButton("开始检测")
        btn_layout.addWidget(self.btn_load)
        btn_layout.addWidget(self.btn_start)
        layout.addLayout(btn_layout)

        # One progress bar per worker stage.
        self.progress_model = QProgressBar()
        self.progress_model.setFormat("模型检测进度:%p%")
        self.progress_xml = QProgressBar()
        self.progress_xml.setFormat("XML生成进度:%p%")
        layout.addWidget(self.progress_model)
        layout.addWidget(self.progress_xml)

        # Log output.
        layout.addWidget(QLabel("日志输出:"))
        self.log_output = QTextEdit()
        self.log_output.setReadOnly(True)
        layout.addWidget(self.log_output)

        self.setLayout(layout)

        # Signal wiring.
        self.btn_load.clicked.connect(self.load_model)
        self.btn_start.clicked.connect(self.start_detection)

        # True only while the path last checked by load_model was a file.
        self.model_loaded = False
        # Reference to the running worker. Deliberately not named ``thread``:
        # that would shadow the ``thread()`` method inherited from QObject.
        self.detection_thread = None

    def _add_path_selector(self, layout, label_text):
        """Append a "label / line edit / browse button" row to ``layout``.

        The browse button opens a folder dialog when ``label_text`` contains
        "文件夹" and a model-file dialog otherwise.

        Returns:
            The ``QLineEdit`` that holds the chosen path.
        """
        hlayout = QHBoxLayout()
        label = QLabel(label_text)
        edit = QLineEdit()
        btn = QPushButton("浏览")
        hlayout.addWidget(label)
        hlayout.addWidget(edit)
        hlayout.addWidget(btn)
        layout.addLayout(hlayout)
        btn.clicked.connect(
            lambda: self._select_path(
                edit, "选择文件夹" if "文件夹" in label_text else "选择模型文件"
            )
        )
        return edit

    def _select_path(self, edit, title):
        """Open a file or folder dialog and copy the selection into ``edit``.

        A title containing "模型" selects the model-file dialog; any other
        title selects the folder dialog. A cancelled dialog leaves ``edit``
        unchanged.
        """
        if "模型" in title:
            path, _ = QFileDialog.getOpenFileName(
                self, title, "", "Model Files (*.pt *.onnx)"
            )
        else:
            path = QFileDialog.getExistingDirectory(self, title)
        if path:
            edit.setText(path)

    def load_model(self):
        """Check that the model path is an existing file and record the result.

        The model itself is loaded later, inside the worker thread. The flag
        is updated on failure too, so an earlier successful check cannot
        vouch for a path that is no longer valid.
        """
        model_path = self.model_edit.text().strip()
        self.model_loaded = os.path.isfile(model_path)
        if not self.model_loaded:
            self._log("❌ 模型文件不存在")
            return
        self._log("✅ 模型路径加载成功,可开始检测")

    def start_detection(self):
        """Validate the inputs and launch the detection worker."""
        if not self.model_loaded:
            self._log("⚠️ 请先加载模型路径")
            return

        model_path = self.model_edit.text().strip()
        input_dir = self.input_edit.text().strip()
        output_xml_dir = self.output_edit.text().strip()
        conf = self.conf_slider.value() / 100
        classes = [11]  # hard-coded class filter; change to customise

        # The path may have been edited after load_model ran, so re-check it.
        if not os.path.isfile(model_path):
            self.model_loaded = False
            self._log("❌ 模型文件不存在")
            return

        if not os.path.isdir(input_dir):
            self._log("❌ 输入文件夹不存在")
            return

        # An empty output path would only fail later inside the worker.
        if not output_xml_dir:
            self._log("❌ 请选择输出XML文件夹")
            return

        self.progress_model.setValue(0)
        self.progress_xml.setValue(0)

        worker = DetectionThread(model_path, input_dir, output_xml_dir, conf, classes)
        worker.update_model_progress.connect(self.progress_model.setValue)
        worker.update_xml_progress.connect(self.progress_xml.setValue)
        worker.log_message.connect(self._log)
        worker.finished_signal.connect(self._finish)
        self.detection_thread = worker

        # Disable the button before the worker starts so a second click
        # cannot launch a parallel job.
        self.btn_start.setEnabled(False)
        self._log("开始执行检测任务...")
        worker.start()

    def _log(self, msg):
        """Append ``msg`` to the log view and scroll to the bottom."""
        self.log_output.append(msg)
        scroll_bar = self.log_output.verticalScrollBar()
        scroll_bar.setValue(scroll_bar.maximum())

    def _finish(self):
        """Re-enable the start button once the worker reports it is done."""
        self.btn_start.setEnabled(True)
        self._log("任务完成 ✅")
|
||
|
||
|
||
# ========== 主程序入口 ==========
|
||
if __name__ == "__main__":
    # Create the Qt application, show the main window, and hand the event
    # loop's return code back to the shell.
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
|