Files
autoanno/autoannoguinew.py
wangjialiang e0626adfb6 First
2025-11-12 17:04:47 +08:00

314 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import sys
import xml.etree.ElementTree as ET
from xml.dom import minidom
from ultralytics import YOLO
from PyQt5.QtWidgets import (
QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
QLineEdit, QLabel, QFileDialog, QProgressBar, QTextEdit, QSlider,
QCheckBox, QScrollArea, QGroupBox, QGridLayout
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
# ========== 后台任务线程 ==========
class DetectionThread(QThread):
    """Background worker that auto-annotates a folder of images with a YOLO model.

    The job runs in three phases, each reporting to its own progress signal:

    1. Copy every input image into ``temp_dir``, expanding single-channel
       images to three channels on the way.
    2. Run YOLO prediction on each converted image.
    3. Write one VOC-style XML annotation per image into ``output_xml_dir``.

    Signals:
        update_rgb_progress(int): percentage done of phase 1.
        update_model_progress(int): percentage done of phase 2.
        update_xml_progress(int): percentage done of phase 3.
        log_message(str): human-readable status line for the UI log.
        finished_signal(): emitted exactly once when ``run`` ends, whether it
            succeeded, found nothing to do, or failed.
    """

    update_rgb_progress = pyqtSignal(int)
    update_model_progress = pyqtSignal(int)
    update_xml_progress = pyqtSignal(int)
    log_message = pyqtSignal(str)
    finished_signal = pyqtSignal()

    # File-name suffixes (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

    def __init__(self, model_path, input_dir, output_xml_dir, temp_dir, conf, classes):
        """Store the job parameters; no work is done until the thread starts.

        Args:
            model_path: path handed to ``YOLO(...)``.
            input_dir: folder whose image files are annotated.
            output_xml_dir: folder receiving the XML files (created if missing).
            temp_dir: folder receiving the converted images (created if missing).
            conf: confidence threshold forwarded to ``model.predict``.
            classes: class ids forwarded to ``model.predict``; an empty
                list is forwarded as ``None``.
        """
        super().__init__()
        self.model_path = model_path
        self.input_dir = input_dir
        self.output_xml_dir = output_xml_dir
        self.temp_dir = temp_dir
        self.conf = conf
        self.classes = classes
        # Class id -> label written into each <object><name>. Ids missing from
        # this table fall back to the id itself rendered as a string.
        self.class_mapping = {
            0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
            5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
            10: "EM203", 11: "EM180", 12: "EM181"
        }

    def run(self):
        """Execute the three phases; every failure is reported via ``log_message``."""
        try:
            os.makedirs(self.output_xml_dir, exist_ok=True)
            os.makedirs(self.temp_dir, exist_ok=True)

            self.log_message.emit(f"加载模型:{self.model_path}")
            model = YOLO(self.model_path)
            self.log_message.emit("模型加载完成 ✅")

            imgs = self._list_images()
            if not imgs:
                self.log_message.emit("输入目录中未找到图片 ❌")
                # finished_signal is emitted by the ``finally`` clause below;
                # emitting it here as well would notify listeners twice.
                return

            converted = self._convert_to_bgr(imgs)
            detections = self._detect(model, converted)
            self._write_annotations(detections)
            self.log_message.emit("✅ 所有任务完成!")
        except Exception as e:
            # Thread boundary: report the error to the UI instead of letting
            # the exception escape the worker thread.
            self.log_message.emit(f"❌ 错误: {e}")
        finally:
            self.finished_signal.emit()

    def _list_images(self):
        """Return the sorted names of entries in ``input_dir`` with an image suffix."""
        return sorted(
            name for name in os.listdir(self.input_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def _convert_to_bgr(self, imgs):
        """Write a 3-channel copy of each image to ``temp_dir``.

        Returns the names that were actually written, so later phases never
        look for a temp file that does not exist.
        """
        self.log_message.emit("开始灰度图像转RGB...此过程时间较长,请等待")
        total = len(imgs)
        converted = []
        for i, img_name in enumerate(imgs, 1):
            img = cv2.imread(os.path.join(self.input_dir, img_name), cv2.IMREAD_UNCHANGED)
            if img is None:
                self.log_message.emit(f"[跳过] 无法读取图片:{img_name}")
            else:
                if img.ndim == 2 or img.shape[2] == 1:
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                if cv2.imwrite(os.path.join(self.temp_dir, img_name), img):
                    converted.append(img_name)
                else:
                    self.log_message.emit(f"[跳过] 无法写入临时图片:{img_name}")
            # Report progress for skipped files too so the bar reaches 100%.
            self.update_rgb_progress.emit(int(i / total * 100))
        self.log_message.emit("RGB 转换完成 ✅")
        return converted

    def _detect(self, model, imgs):
        """Run prediction on each converted image.

        Returns a list of ``(file_name, boxes)`` where ``boxes`` is a list of
        ``(class_id, xmin, ymin, xmax, ymax)`` integer tuples. Only these
        values are kept; the result objects themselves are dropped after each
        image instead of being accumulated for the whole folder.
        """
        self.log_message.emit("开始执行YOLO检测...")
        total = len(imgs)
        detections = []
        for i, file in enumerate(imgs, 1):
            temp_path = os.path.join(self.temp_dir, file)
            results = model.predict(
                source=temp_path, conf=self.conf, save=False,
                classes=self.classes or None
            )
            for result in results:
                boxes = [
                    (int(box.cls[0]), *map(int, box.xyxy[0].tolist()))
                    for box in result.boxes
                ]
                detections.append((os.path.basename(result.path), boxes))
            self.update_model_progress.emit(int(i / total * 100))
            self.log_message.emit(f"[检测完成] {file}")
        return detections

    def _write_annotations(self, detections):
        """Write one XML file per ``(file_name, boxes)`` entry into ``output_xml_dir``."""
        self.log_message.emit("开始生成VOC格式XML...")
        total = len(detections)
        for j, (file_name, boxes) in enumerate(detections, 1):
            image_path = os.path.join(self.input_dir, file_name)
            # The image is re-read from the input folder only for its size.
            img = cv2.imread(image_path)
            if img is None:
                self.log_message.emit(f"[跳过] 无法读取图片:{file_name}")
            else:
                height, width, depth = img.shape
                annotation = self._build_annotation(image_path, width, height, depth, boxes)
                xml_path = os.path.join(
                    self.output_xml_dir, os.path.splitext(file_name)[0] + ".xml"
                )
                with open(xml_path, "w", encoding="utf-8") as f:
                    f.write(self._to_pretty_xml(annotation))
                self.log_message.emit(f"[生成完成] {xml_path}")
            self.update_xml_progress.emit(int(j / total * 100))

    def _build_annotation(self, image_path, width, height, depth, boxes):
        """Build the ``<annotation>`` element tree for one image."""
        annotation = ET.Element("annotation")
        ET.SubElement(annotation, "folder").text = os.path.basename(os.path.dirname(image_path))
        ET.SubElement(annotation, "filename").text = os.path.basename(image_path)
        ET.SubElement(annotation, "path").text = os.path.abspath(image_path)
        source = ET.SubElement(annotation, "source")
        ET.SubElement(source, "database").text = "Unknown"
        size = ET.SubElement(annotation, "size")
        ET.SubElement(size, "width").text = str(width)
        ET.SubElement(size, "height").text = str(height)
        ET.SubElement(size, "depth").text = str(depth)
        ET.SubElement(annotation, "segmented").text = "0"
        for cls, xmin, ymin, xmax, ymax in boxes:
            obj = ET.SubElement(annotation, "object")
            ET.SubElement(obj, "name").text = self.class_mapping.get(cls, str(cls))
            ET.SubElement(obj, "pose").text = "Unspecified"
            ET.SubElement(obj, "truncated").text = "0"
            ET.SubElement(obj, "difficult").text = "0"
            bndbox = ET.SubElement(obj, "bndbox")
            ET.SubElement(bndbox, "xmin").text = str(xmin)
            ET.SubElement(bndbox, "ymin").text = str(ymin)
            ET.SubElement(bndbox, "xmax").text = str(xmax)
            ET.SubElement(bndbox, "ymax").text = str(ymax)
        return annotation

    @staticmethod
    def _to_pretty_xml(annotation):
        """Serialize an element as tab-indented XML text without blank lines."""
        pretty = minidom.parseString(ET.tostring(annotation, 'utf-8')).toprettyxml(
            indent="\t", encoding="utf-8"
        ).decode("utf-8")
        return "\n".join(line for line in pretty.splitlines() if line.strip())
# ========== 主界面 ==========
class MainWindow(QWidget):
    """Main window of the semi-automatic YOLO annotation tool.

    Lets the user pick the model file and the input / output / temp folders,
    choose a confidence threshold and an optional subset of classes, and then
    runs the job on a ``DetectionThread`` while showing three progress bars
    and a log view.
    """

    def __init__(self):
        """Build all widgets and wire up their signals."""
        super().__init__()
        self.setWindowTitle("YOLO 半自动标注工具(华东专用)")
        self.resize(750, 750)
        layout = QVBoxLayout()

        # Path selectors.
        self.model_edit = self._add_path_selector(layout, "模型路径:")
        self.input_edit = self._add_path_selector(layout, "输入图片文件夹:")
        self.output_edit = self._add_path_selector(layout, "输出XML文件夹")
        self.temp_edit = self._add_path_selector(layout, "临时RGB图片文件夹")

        # Confidence threshold: the slider holds 0..100 and is shown as 0.0..1.0.
        conf_layout = QHBoxLayout()
        conf_layout.addWidget(QLabel("置信度阈值:"))
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(20)
        # Derive the initial label from the slider so the two cannot disagree.
        self.conf_value = QLabel(str(self.conf_slider.value() / 100))
        self.conf_slider.valueChanged.connect(lambda v: self.conf_value.setText(str(v / 100)))
        conf_layout.addWidget(self.conf_slider)
        conf_layout.addWidget(self.conf_value)
        layout.addLayout(conf_layout)

        # Class selection: one checkbox per class id, laid out four per row.
        self.class_mapping = {
            0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
            5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
            10: "EM203", 11: "EM180", 12: "EM181"
        }
        layout.addWidget(QLabel("选择要检测的类别(可多选):"))
        self.class_checkboxes = {}
        class_group = QGroupBox()
        grid = QGridLayout()
        for idx, (key, name) in enumerate(self.class_mapping.items()):
            cb = QCheckBox(f"{key}: {name}")
            self.class_checkboxes[key] = cb
            grid.addWidget(cb, idx // 4, idx % 4)
        class_group.setLayout(grid)
        scroll = QScrollArea()
        scroll.setWidget(class_group)
        scroll.setWidgetResizable(True)
        scroll.setFixedHeight(120)
        layout.addWidget(scroll)

        # Control buttons.
        btn_layout = QHBoxLayout()
        self.btn_load = QPushButton("加载模型")
        self.btn_start = QPushButton("开始检测")
        btn_layout.addWidget(self.btn_load)
        btn_layout.addWidget(self.btn_start)
        layout.addLayout(btn_layout)

        # One progress bar per worker phase.
        self.progress_rgb = QProgressBar()
        self.progress_rgb.setFormat("RGB转换进度%p%")
        self.progress_model = QProgressBar()
        self.progress_model.setFormat("模型检测进度:%p%")
        self.progress_xml = QProgressBar()
        self.progress_xml.setFormat("XML生成进度%p%")
        layout.addWidget(self.progress_rgb)
        layout.addWidget(self.progress_model)
        layout.addWidget(self.progress_xml)

        # Log view.
        layout.addWidget(QLabel("日志输出:"))
        self.log_output = QTextEdit()
        self.log_output.setReadOnly(True)
        layout.addWidget(self.log_output)
        self.setLayout(layout)

        # Signal wiring.
        self.btn_load.clicked.connect(self.load_model)
        self.btn_start.clicked.connect(self.start_detection)
        self.model_loaded = False
        # Reference to the running worker; kept on the instance so the thread
        # object is not garbage-collected while it is still executing.
        self.detection_thread = None

    def _add_path_selector(self, layout, label_text):
        """Append a "label / line edit / browse button" row and return the line edit.

        Labels containing "文件夹" open a folder picker; any other label opens
        the model-file picker.
        """
        hlayout = QHBoxLayout()
        label = QLabel(label_text)
        edit = QLineEdit()
        btn = QPushButton("浏览")
        hlayout.addWidget(label)
        hlayout.addWidget(edit)
        hlayout.addWidget(btn)
        layout.addLayout(hlayout)
        btn.clicked.connect(lambda: self._select_path(edit, "选择文件夹" if "文件夹" in label_text else "选择模型文件"))
        return edit

    def _select_path(self, edit, title):
        """Open the dialog implied by ``title`` and copy the chosen path into ``edit``.

        A title containing "模型" opens a file dialog filtered to ``*.pt`` and
        ``*.onnx``; any other title opens a directory dialog. A cancelled
        dialog leaves ``edit`` unchanged.
        """
        if "模型" in title:
            path, _ = QFileDialog.getOpenFileName(self, title, "", "Model Files (*.pt *.onnx)")
        else:
            path = QFileDialog.getExistingDirectory(self, title)
        if path:
            edit.setText(path)

    def load_model(self):
        """Check that the model path exists and unlock the start button logic.

        The model itself is loaded later, inside the worker thread.
        """
        path = self.model_edit.text().strip()
        if not os.path.exists(path):
            self._log("❌ 模型文件不存在")
            return
        self.model_loaded = True
        self._log("✅ 模型路径加载成功,可开始检测")

    def start_detection(self):
        """Validate the form, then launch a ``DetectionThread`` for the job."""
        if not self.model_loaded:
            self._log("⚠️ 请先加载模型路径")
            return
        model_path = self.model_edit.text().strip()
        input_dir = self.input_edit.text().strip()
        output_xml_dir = self.output_edit.text().strip()
        temp_dir = self.temp_edit.text().strip()
        conf = self.conf_slider.value() / 100
        selected_classes = [cid for cid, cb in self.class_checkboxes.items() if cb.isChecked()]
        self._log(f"选中类别:{selected_classes if selected_classes else '全部'}")

        # The path may have been edited since "加载模型" was clicked.
        if not os.path.exists(model_path):
            self.model_loaded = False
            self._log("❌ 模型文件不存在")
            return
        # Must be a directory: the worker lists its contents.
        if not os.path.isdir(input_dir):
            self._log("❌ 输入文件夹不存在")
            return
        # An empty output path would make the worker's makedirs call fail.
        if not output_xml_dir:
            self._log("⚠️ 请选择输出XML文件夹")
            return
        if not temp_dir:
            self._log("⚠️ 请选择临时RGB图片文件夹")
            return

        self.progress_rgb.setValue(0)
        self.progress_model.setValue(0)
        self.progress_xml.setValue(0)

        # Stored as ``detection_thread`` rather than ``thread`` so the
        # inherited QObject.thread() method is not shadowed on this instance.
        self.detection_thread = DetectionThread(
            model_path, input_dir, output_xml_dir, temp_dir, conf, selected_classes
        )
        self.detection_thread.update_rgb_progress.connect(self.progress_rgb.setValue)
        self.detection_thread.update_model_progress.connect(self.progress_model.setValue)
        self.detection_thread.update_xml_progress.connect(self.progress_xml.setValue)
        self.detection_thread.log_message.connect(self._log)
        self.detection_thread.finished_signal.connect(self._finish)

        # Disable the button before starting so a second job cannot be queued.
        self.btn_start.setEnabled(False)
        self._log("开始执行检测任务...")
        self.detection_thread.start()

    def _log(self, msg):
        """Append ``msg`` to the log view and scroll to the bottom."""
        self.log_output.append(msg)
        scroll_bar = self.log_output.verticalScrollBar()
        scroll_bar.setValue(scroll_bar.maximum())

    def _finish(self):
        """Re-enable the start button once the worker reports it has ended."""
        self.btn_start.setEnabled(True)
        self._log("任务完成 ✅")
# ========== 主程序入口 ==========
if __name__ == "__main__":
    # Create the Qt application, show the main window, and hand the event
    # loop's return value back to the OS as the process exit code.
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)