First
This commit is contained in:
5
.idea/.gitignore
generated
vendored
Normal file
5
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# 基于编辑器的 HTTP 客户端请求
|
||||
/httpRequests/
|
||||
10
.idea/AutoAnno.iml
generated
Normal file
10
.idea/AutoAnno.iml
generated
Normal file
@@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.11 (AutoAnno)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.11 (AutoAnno)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (AutoAnno)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/AutoAnno.iml" filepath="$PROJECT_DIR$/.idea/AutoAnno.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
129
autoanno.py
Normal file
129
autoanno.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import os
|
||||
import cv2
|
||||
import xml.etree.ElementTree as ET
|
||||
from xml.dom import minidom
|
||||
from ultralytics import YOLO
|
||||
|
||||
# ========== 配置部分 ==========
|
||||
model_path = "epoch220.pt"        # YOLO weights used for pre-annotation
input_dir = "test_images/"        # folder containing the images to annotate
output_xml_dir = "annotations/"   # folder that receives one VOC XML per image
temp_dir = "temp_rgb_images"      # scratch folder for the 3-channel copies fed to YOLO

# ===== Custom mapping from class id to the label name written to the XML =====
class_mapping = {
    0: "EM14",
    1: "EM18",
    2: "EM17",
    3: "EM170",
    4: "EM19",
    5: "EM190",
    6: "EM20",
    7: "EM200",
    8: "EM201",
    9: "EM202",
    10: "EM203",
    11: "EM180",
    12: "EM181",
}

os.makedirs(temp_dir, exist_ok=True)
os.makedirs(output_xml_dir, exist_ok=True)

# ========== Load the model ==========
model = YOLO(model_path)

# ========== Convert grayscale images to 3-channel copies ==========
for file in os.listdir(input_dir):
    if not file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff')):
        continue

    input_path = os.path.join(input_dir, file)
    img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)

    # cv2.imread signals an unreadable/corrupt file by returning None instead
    # of raising; skip such files so one bad image does not abort the batch.
    if img is None:
        print(f"[跳过] 无法读取图片: {input_path}")
        continue

    # Single-channel input (2-D array or an explicit 1-channel axis) is
    # expanded to three channels; everything else is copied unchanged.
    if len(img.shape) == 2 or img.shape[2] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    temp_path = os.path.join(temp_dir, file)
    cv2.imwrite(temp_path, img)

# ========== Run detection ==========
# Only class id 11 is requested from the model.
results = model.predict(source=temp_dir, conf=0.2, save=False, classes=[11])

# ========== Generate normalised VOC XML ==========
for result in results:
    file_name = os.path.basename(result.path)
    # The XML describes the original image, not the temporary copy.
    image_path = os.path.join(input_dir, file_name)
    img = cv2.imread(image_path)
    if img is None:
        # The temp folder may hold files that have no readable counterpart in
        # input_dir (for example leftovers of an earlier run).
        print(f"[跳过] 无法读取原始图片: {image_path}")
        continue
    height, width, depth = img.shape

    annotation = ET.Element("annotation")

    folder = ET.SubElement(annotation, "folder")
    folder.text = os.path.basename(os.path.dirname(image_path))

    filename = ET.SubElement(annotation, "filename")
    filename.text = file_name

    path = ET.SubElement(annotation, "path")
    path.text = os.path.abspath(image_path)

    source = ET.SubElement(annotation, "source")
    database = ET.SubElement(source, "database")
    database.text = "Unknown"

    size = ET.SubElement(annotation, "size")
    ET.SubElement(size, "width").text = str(width)
    ET.SubElement(size, "height").text = str(height)
    ET.SubElement(size, "depth").text = str(depth)

    segmented = ET.SubElement(annotation, "segmented")
    segmented.text = "0"

    # One <object> element per detected box.
    for box in result.boxes:
        cls = int(box.cls[0])
        xyxy = box.xyxy[0].tolist()
        # Coordinates are truncated to whole pixels.
        xmin, ymin, xmax, ymax = map(int, xyxy)

        obj = ET.SubElement(annotation, "object")

        name = ET.SubElement(obj, "name")
        # Fall back to the numeric id when the class has no configured name.
        name.text = class_mapping.get(cls, str(cls))

        pose = ET.SubElement(obj, "pose")
        pose.text = "Unspecified"

        truncated = ET.SubElement(obj, "truncated")
        truncated.text = "0"

        difficult = ET.SubElement(obj, "difficult")
        difficult.text = "0"

        bndbox = ET.SubElement(obj, "bndbox")
        ET.SubElement(bndbox, "xmin").text = str(xmin)
        ET.SubElement(bndbox, "ymin").text = str(ymin)
        ET.SubElement(bndbox, "xmax").text = str(xmax)
        ET.SubElement(bndbox, "ymax").text = str(ymax)

    # ===== Pretty-print through minidom =====
    rough_string = ET.tostring(annotation, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    xml_str = reparsed.toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")

    # Drop whitespace-only lines from the pretty-printed text.
    xml_str = "\n".join([line for line in xml_str.splitlines() if line.strip()])

    # Save next to the other annotations, named after the image.
    xml_path = os.path.join(output_xml_dir, os.path.splitext(file_name)[0] + ".xml")
    with open(xml_path, "w", encoding="utf-8") as f:
        f.write(xml_str)

    print(f"[生成完成] {xml_path}")
|
||||
279
autoannogui.py
Normal file
279
autoannogui.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import os
|
||||
import cv2
|
||||
import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
from xml.dom import minidom
|
||||
from ultralytics import YOLO
|
||||
from PyQt5.QtWidgets import (
|
||||
QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
|
||||
QLineEdit, QLabel, QFileDialog, QProgressBar, QTextEdit, QSlider
|
||||
)
|
||||
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||
|
||||
|
||||
# ========== 后台任务线程 ==========
|
||||
class DetectionThread(QThread):
    """Background worker that runs the whole pre-annotation pipeline.

    The thread writes 3-channel copies of the input images to a temporary
    folder, runs YOLO detection on every copy and writes one VOC-style XML
    file per image. Progress and log text are reported through Qt signals;
    ``finished_signal`` is emitted exactly once when ``run`` ends, whether it
    succeeded, found nothing to do, or failed.
    """

    update_model_progress = pyqtSignal(int)  # detection progress in percent
    update_xml_progress = pyqtSignal(int)    # XML generation progress in percent
    log_message = pyqtSignal(str)            # one line of log text
    finished_signal = pyqtSignal()           # emitted once when run() returns

    def __init__(self, model_path, input_dir, output_xml_dir, conf, classes):
        """Store the job parameters.

        :param model_path: path of the YOLO weights file
        :param input_dir: folder containing the images to annotate
        :param output_xml_dir: folder that receives the XML files (created if
            missing); the temporary image copies are stored in a sub-folder
        :param conf: confidence threshold passed to ``model.predict``
        :param classes: class ids passed to ``model.predict``
        """
        super().__init__()
        self.model_path = model_path
        self.input_dir = input_dir
        self.output_xml_dir = output_xml_dir
        self.conf = conf
        self.classes = classes

        # Class id -> label name written to the XML (editable).
        self.class_mapping = {
            0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
            5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
            10: "EM203", 11: "EM180", 12: "EM181"
        }

    def run(self):
        """Execute conversion, detection and XML generation in sequence."""
        try:
            os.makedirs(self.output_xml_dir, exist_ok=True)
            temp_dir = os.path.join(self.output_xml_dir, "temp_rgb_images")
            os.makedirs(temp_dir, exist_ok=True)

            # Load the model
            self.log_message.emit(f"加载模型:{self.model_path}")
            model = YOLO(self.model_path)
            self.log_message.emit("模型加载完成 ✅")

            # Collect the image files of the input folder
            imgs = [f for f in os.listdir(self.input_dir)
                    if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff'))]
            if not imgs:
                self.log_message.emit("输入目录中未找到图片 ❌")
                # finished_signal is emitted by the finally clause below;
                # emitting it here as well would notify the GUI twice.
                return

            # Convert grayscale images and store the temporary copies
            self.log_message.emit("开始灰度图像转RGB...此过程时间较长,请等待")
            converted = []  # names whose temporary copy really exists
            for img_name in imgs:
                img_path = os.path.join(self.input_dir, img_name)
                img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
                if img is None:
                    self.log_message.emit(f"[跳过] 无法读取图片: {img_name}")
                    continue
                if len(img.shape) == 2 or img.shape[2] == 1:
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                if not cv2.imwrite(os.path.join(temp_dir, img_name), img):
                    self.log_message.emit(f"[跳过] 无法写入临时图片: {img_name}")
                    continue
                converted.append(img_name)

            if not converted:
                self.log_message.emit("没有可处理的图片 ❌")
                return

            # ========= Detection stage =========
            # Only images with a temporary copy are predicted; a skipped image
            # would otherwise make predict() fail on a missing file and abort
            # the whole run.
            self.log_message.emit("开始执行YOLO检测...")
            results = []
            total_imgs = len(converted)
            for i, file in enumerate(converted, 1):
                temp_path = os.path.join(temp_dir, file)
                result = model.predict(
                    source=temp_path, conf=self.conf, save=False, classes=self.classes
                )
                results.extend(result)
                self.update_model_progress.emit(int(i / total_imgs * 100))
                self.log_message.emit(f"[检测完成] {file}")

            # ========= XML generation stage =========
            self.log_message.emit("开始生成VOC格式XML...")
            total_results = len(results)
            for j, result in enumerate(results, 1):
                xml_path = self._write_voc_xml(result)
                # Progress advances for skipped results too, so the bar can
                # reach 100 %.
                self.update_xml_progress.emit(int(j / total_results * 100))
                if xml_path is not None:
                    self.log_message.emit(f"[生成完成] {xml_path}")

            self.log_message.emit("✅ 所有任务完成!")
        except Exception as e:
            # Top-level boundary of the worker: report the failure to the GUI
            # log instead of letting the exception escape the thread.
            self.log_message.emit(f"❌ 错误: {e}")
        finally:
            self.finished_signal.emit()

    def _write_voc_xml(self, result):
        """Write the VOC XML file for one detection result.

        :param result: one item returned by ``model.predict``; its ``path``
            and ``boxes`` attributes are read
        :return: path of the written XML file, or ``None`` when the original
            image could not be read and nothing was written
        """
        file_name = os.path.basename(result.path)
        # The XML describes the original image, not the temporary copy.
        image_path = os.path.join(self.input_dir, file_name)
        img = cv2.imread(image_path)
        if img is None:
            return None
        height, width, depth = img.shape

        annotation = ET.Element("annotation")
        ET.SubElement(annotation, "folder").text = os.path.basename(os.path.dirname(image_path))
        ET.SubElement(annotation, "filename").text = file_name
        ET.SubElement(annotation, "path").text = os.path.abspath(image_path)

        source = ET.SubElement(annotation, "source")
        ET.SubElement(source, "database").text = "Unknown"

        size = ET.SubElement(annotation, "size")
        ET.SubElement(size, "width").text = str(width)
        ET.SubElement(size, "height").text = str(height)
        ET.SubElement(size, "depth").text = str(depth)

        ET.SubElement(annotation, "segmented").text = "0"

        for box in result.boxes:
            cls = int(box.cls[0])
            # Coordinates are truncated to whole pixels.
            xmin, ymin, xmax, ymax = map(int, box.xyxy[0].tolist())

            obj = ET.SubElement(annotation, "object")
            # Fall back to the numeric id when the class has no configured name.
            ET.SubElement(obj, "name").text = self.class_mapping.get(cls, str(cls))
            ET.SubElement(obj, "pose").text = "Unspecified"
            ET.SubElement(obj, "truncated").text = "0"
            ET.SubElement(obj, "difficult").text = "0"

            bndbox = ET.SubElement(obj, "bndbox")
            ET.SubElement(bndbox, "xmin").text = str(xmin)
            ET.SubElement(bndbox, "ymin").text = str(ymin)
            ET.SubElement(bndbox, "xmax").text = str(xmax)
            ET.SubElement(bndbox, "ymax").text = str(ymax)

        # Pretty-print through minidom, then drop whitespace-only lines.
        rough_string = ET.tostring(annotation, 'utf-8')
        reparsed = minidom.parseString(rough_string)
        xml_str = reparsed.toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
        xml_str = "\n".join([line for line in xml_str.splitlines() if line.strip()])

        xml_path = os.path.join(self.output_xml_dir, os.path.splitext(file_name)[0] + ".xml")
        with open(xml_path, "w", encoding="utf-8") as f:
            f.write(xml_str)
        return xml_path
|
||||
|
||||
|
||||
# ========== 主界面 ==========
|
||||
class MainWindow(QWidget):
    """Main window: collects the job settings and drives a DetectionThread."""

    def __init__(self):
        """Build the widgets and connect the button signals."""
        super().__init__()
        self.setWindowTitle("YOLO 自动检测 + VOC生成工具")
        self.resize(700, 600)

        layout = QVBoxLayout()

        # Model path
        self.model_edit = self._add_path_selector(layout, "模型路径:")
        # Input folder
        self.input_edit = self._add_path_selector(layout, "输入图片文件夹:")
        # Output folder
        self.output_edit = self._add_path_selector(layout, "输出XML文件夹:")

        # Confidence slider: integer 0-100, shown and used as value / 100
        conf_layout = QHBoxLayout()
        conf_layout.addWidget(QLabel("置信度阈值:"))
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(20)
        self.conf_value = QLabel("0.2")
        self.conf_slider.valueChanged.connect(lambda v: self.conf_value.setText(str(v / 100)))
        conf_layout.addWidget(self.conf_slider)
        conf_layout.addWidget(self.conf_value)
        layout.addLayout(conf_layout)

        # Control buttons
        btn_layout = QHBoxLayout()
        self.btn_load = QPushButton("加载模型")
        self.btn_start = QPushButton("开始检测")
        btn_layout.addWidget(self.btn_load)
        btn_layout.addWidget(self.btn_start)
        layout.addLayout(btn_layout)

        # Two progress bars
        self.progress_model = QProgressBar()
        self.progress_model.setFormat("模型检测进度:%p%")
        self.progress_xml = QProgressBar()
        self.progress_xml.setFormat("XML生成进度:%p%")
        layout.addWidget(self.progress_model)
        layout.addWidget(self.progress_xml)

        # Log output
        layout.addWidget(QLabel("日志输出:"))
        self.log_output = QTextEdit()
        self.log_output.setReadOnly(True)
        layout.addWidget(self.log_output)

        self.setLayout(layout)

        # Signal wiring
        self.btn_load.clicked.connect(self.load_model)
        self.btn_start.clicked.connect(self.start_detection)

        # Set by load_model() once the model path has been checked.
        self.model_loaded = False

    def _add_path_selector(self, layout, label_text):
        """Append a "label + line edit + browse button" row to ``layout``.

        :param layout: layout that receives the new row
        :param label_text: caption of the row; when it contains "文件夹" the
            browse button opens a folder dialog, otherwise a model-file dialog
        :return: the created ``QLineEdit``
        """
        hlayout = QHBoxLayout()
        label = QLabel(label_text)
        edit = QLineEdit()
        btn = QPushButton("浏览")
        hlayout.addWidget(label)
        hlayout.addWidget(edit)
        hlayout.addWidget(btn)
        layout.addLayout(hlayout)
        btn.clicked.connect(lambda: self._select_path(edit, "选择文件夹" if "文件夹" in label_text else "选择模型文件"))
        return edit

    def _select_path(self, edit, title):
        """Open a file or folder dialog and copy the chosen path into ``edit``.

        :param edit: line edit that receives the selection
        :param title: dialog title; a title containing "模型" selects the
            model-file dialog, any other title the folder dialog
        """
        if "模型" in title:
            path, _ = QFileDialog.getOpenFileName(self, title, "", "Model Files (*.pt *.onnx)")
        else:
            path = QFileDialog.getExistingDirectory(self, title)
        # An empty path means the dialog was cancelled; keep the old text.
        if path:
            edit.setText(path)

    def load_model(self):
        """Check that the model path exists and unlock detection.

        The model itself is loaded later by the worker thread.
        """
        model_path = self.model_edit.text().strip()
        if not os.path.exists(model_path):
            self._log("❌ 模型文件不存在")
            return
        self.model_loaded = True
        self._log("✅ 模型路径加载成功,可开始检测")

    def start_detection(self):
        """Validate the settings and start the background detection thread."""
        if not self.model_loaded:
            self._log("⚠️ 请先加载模型路径")
            return

        model_path = self.model_edit.text().strip()
        input_dir = self.input_edit.text().strip()
        output_xml_dir = self.output_edit.text().strip()
        conf = float(self.conf_slider.value() / 100)
        classes = [11]  # may be replaced by a user-defined selection

        # The worker lists this path with os.listdir, so it must be a folder.
        if not os.path.isdir(input_dir):
            self._log("❌ 输入文件夹不存在")
            return
        # The worker creates the output folder itself, but an empty path
        # cannot be created and would only surface as an error in the thread.
        if not output_xml_dir:
            self._log("⚠️ 请选择输出XML文件夹")
            return

        self.progress_model.setValue(0)
        self.progress_xml.setValue(0)

        # Keep a reference on self so the thread object outlives this method.
        self.thread = DetectionThread(model_path, input_dir, output_xml_dir, conf, classes)
        self.thread.update_model_progress.connect(self.progress_model.setValue)
        self.thread.update_xml_progress.connect(self.progress_xml.setValue)
        self.thread.log_message.connect(self._log)
        self.thread.finished_signal.connect(self._finish)
        self.thread.start()

        # Re-enabled in _finish(); prevents starting a second job meanwhile.
        self.btn_start.setEnabled(False)
        self._log("开始执行检测任务...")

    def _log(self, msg):
        """Append ``msg`` to the log view and scroll to the newest line."""
        self.log_output.append(msg)
        self.log_output.verticalScrollBar().setValue(self.log_output.verticalScrollBar().maximum())

    def _finish(self):
        """Slot for the worker's ``finished_signal``: re-enable the start button."""
        self.btn_start.setEnabled(True)
        self._log("任务完成 ✅")
|
||||
|
||||
|
||||
# ========== 主程序入口 ==========
|
||||
if __name__ == "__main__":
    # Create the Qt application, show the main window and run the event
    # loop; the loop's return code becomes the process exit status.
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
|
||||
313
autoannoguinew.py
Normal file
313
autoannoguinew.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import os
|
||||
import cv2
|
||||
import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
from xml.dom import minidom
|
||||
from ultralytics import YOLO
|
||||
from PyQt5.QtWidgets import (
|
||||
QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
|
||||
QLineEdit, QLabel, QFileDialog, QProgressBar, QTextEdit, QSlider,
|
||||
QCheckBox, QScrollArea, QGroupBox, QGridLayout
|
||||
)
|
||||
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||
|
||||
|
||||
# ========== 后台任务线程 ==========
|
||||
class DetectionThread(QThread):
    """Background worker that runs the whole pre-annotation pipeline.

    The thread writes 3-channel copies of the input images to ``temp_dir``,
    runs YOLO detection on every copy and writes one VOC-style XML file per
    image. Progress and log text are reported through Qt signals;
    ``finished_signal`` is emitted exactly once when ``run`` ends, whether it
    succeeded, found nothing to do, or failed.
    """

    update_rgb_progress = pyqtSignal(int)    # conversion progress in percent
    update_model_progress = pyqtSignal(int)  # detection progress in percent
    update_xml_progress = pyqtSignal(int)    # XML generation progress in percent
    log_message = pyqtSignal(str)            # one line of log text
    finished_signal = pyqtSignal()           # emitted once when run() returns

    def __init__(self, model_path, input_dir, output_xml_dir, temp_dir, conf, classes):
        """Store the job parameters.

        :param model_path: path of the YOLO weights file
        :param input_dir: folder containing the images to annotate
        :param output_xml_dir: folder that receives the XML files (created if missing)
        :param temp_dir: folder for the temporary image copies (created if missing)
        :param conf: confidence threshold passed to ``model.predict``
        :param classes: class ids passed to ``model.predict``; an empty list
            is passed on as ``None``
        """
        super().__init__()
        self.model_path = model_path
        self.input_dir = input_dir
        self.output_xml_dir = output_xml_dir
        self.temp_dir = temp_dir
        self.conf = conf
        self.classes = classes

        # Class id -> label name written to the XML.
        self.class_mapping = {
            0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
            5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
            10: "EM203", 11: "EM180", 12: "EM181"
        }

    def run(self):
        """Execute conversion, detection and XML generation in sequence."""
        try:
            os.makedirs(self.output_xml_dir, exist_ok=True)
            os.makedirs(self.temp_dir, exist_ok=True)

            # Load the model
            self.log_message.emit(f"加载模型:{self.model_path}")
            model = YOLO(self.model_path)
            self.log_message.emit("模型加载完成 ✅")

            # Collect the image files of the input folder
            imgs = [f for f in os.listdir(self.input_dir)
                    if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff'))]
            total_imgs = len(imgs)
            if total_imgs == 0:
                self.log_message.emit("输入目录中未找到图片 ❌")
                # finished_signal is emitted by the finally clause below;
                # emitting it here as well would notify the GUI twice.
                return

            # ========= Grayscale -> 3-channel conversion stage =========
            self.log_message.emit("开始灰度图像转RGB...此过程时间较长,请等待")
            converted = []  # names whose temporary copy really exists
            for i, img_name in enumerate(imgs, 1):
                img_path = os.path.join(self.input_dir, img_name)
                img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
                if img is None:
                    self.log_message.emit(f"[跳过] 无法读取图片: {img_name}")
                else:
                    if len(img.shape) == 2 or img.shape[2] == 1:
                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                    if cv2.imwrite(os.path.join(self.temp_dir, img_name), img):
                        converted.append(img_name)
                    else:
                        self.log_message.emit(f"[跳过] 无法写入临时图片: {img_name}")
                # Progress counts every visited file, skipped ones included.
                self.update_rgb_progress.emit(int(i / total_imgs * 100))
            self.log_message.emit("RGB 转换完成 ✅")

            if not converted:
                self.log_message.emit("没有可处理的图片 ❌")
                return

            # ========= Detection stage =========
            # Only images with a temporary copy are predicted; a skipped image
            # would otherwise make predict() fail on a missing file and abort
            # the whole run.
            self.log_message.emit("开始执行YOLO检测...")
            results = []
            total_converted = len(converted)
            for i, file in enumerate(converted, 1):
                temp_path = os.path.join(self.temp_dir, file)
                result = model.predict(
                    source=temp_path, conf=self.conf, save=False,
                    classes=self.classes if self.classes else None
                )
                results.extend(result)
                self.update_model_progress.emit(int(i / total_converted * 100))
                self.log_message.emit(f"[检测完成] {file}")

            # ========= XML generation stage =========
            self.log_message.emit("开始生成VOC格式XML...")
            total_results = len(results)
            for j, result in enumerate(results, 1):
                xml_path = self._write_voc_xml(result)
                # Progress advances for skipped results too, so the bar can
                # reach 100 %.
                self.update_xml_progress.emit(int(j / total_results * 100))
                if xml_path is not None:
                    self.log_message.emit(f"[生成完成] {xml_path}")

            self.log_message.emit("✅ 所有任务完成!")
        except Exception as e:
            # Top-level boundary of the worker: report the failure to the GUI
            # log instead of letting the exception escape the thread.
            self.log_message.emit(f"❌ 错误: {e}")
        finally:
            self.finished_signal.emit()

    def _write_voc_xml(self, result):
        """Write the VOC XML file for one detection result.

        :param result: one item returned by ``model.predict``; its ``path``
            and ``boxes`` attributes are read
        :return: path of the written XML file, or ``None`` when the original
            image could not be read and nothing was written
        """
        file_name = os.path.basename(result.path)
        # The XML describes the original image, not the temporary copy.
        image_path = os.path.join(self.input_dir, file_name)
        img = cv2.imread(image_path)
        if img is None:
            return None
        height, width, depth = img.shape

        annotation = ET.Element("annotation")
        ET.SubElement(annotation, "folder").text = os.path.basename(os.path.dirname(image_path))
        ET.SubElement(annotation, "filename").text = file_name
        ET.SubElement(annotation, "path").text = os.path.abspath(image_path)

        source = ET.SubElement(annotation, "source")
        ET.SubElement(source, "database").text = "Unknown"

        size = ET.SubElement(annotation, "size")
        ET.SubElement(size, "width").text = str(width)
        ET.SubElement(size, "height").text = str(height)
        ET.SubElement(size, "depth").text = str(depth)

        ET.SubElement(annotation, "segmented").text = "0"

        for box in result.boxes:
            cls = int(box.cls[0])
            # Coordinates are truncated to whole pixels.
            xmin, ymin, xmax, ymax = map(int, box.xyxy[0].tolist())

            obj = ET.SubElement(annotation, "object")
            # Fall back to the numeric id when the class has no configured name.
            ET.SubElement(obj, "name").text = self.class_mapping.get(cls, str(cls))
            ET.SubElement(obj, "pose").text = "Unspecified"
            ET.SubElement(obj, "truncated").text = "0"
            ET.SubElement(obj, "difficult").text = "0"

            bndbox = ET.SubElement(obj, "bndbox")
            ET.SubElement(bndbox, "xmin").text = str(xmin)
            ET.SubElement(bndbox, "ymin").text = str(ymin)
            ET.SubElement(bndbox, "xmax").text = str(xmax)
            ET.SubElement(bndbox, "ymax").text = str(ymax)

        # Pretty-print through minidom, then drop whitespace-only lines.
        xml_str = minidom.parseString(ET.tostring(annotation, 'utf-8')).toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
        xml_str = "\n".join([line for line in xml_str.splitlines() if line.strip()])

        xml_path = os.path.join(self.output_xml_dir, os.path.splitext(file_name)[0] + ".xml")
        with open(xml_path, "w", encoding="utf-8") as f:
            f.write(xml_str)
        return xml_path
|
||||
|
||||
|
||||
# ========== 主界面 ==========
|
||||
class MainWindow(QWidget):
    """Main window: collects the job settings and drives a DetectionThread."""

    def __init__(self):
        """Build the widgets and connect the button signals."""
        super().__init__()
        self.setWindowTitle("YOLO 半自动标注工具(华东专用)")
        self.resize(750, 750)

        layout = QVBoxLayout()

        # Path rows
        self.model_edit = self._add_path_selector(layout, "模型路径:")
        self.input_edit = self._add_path_selector(layout, "输入图片文件夹:")
        self.output_edit = self._add_path_selector(layout, "输出XML文件夹:")
        self.temp_edit = self._add_path_selector(layout, "临时RGB图片文件夹:")

        # Confidence slider: integer 0-100, shown and used as value / 100
        conf_layout = QHBoxLayout()
        conf_layout.addWidget(QLabel("置信度阈值:"))
        self.conf_slider = QSlider(Qt.Horizontal)
        self.conf_slider.setRange(0, 100)
        self.conf_slider.setValue(20)
        self.conf_value = QLabel("0.2")
        self.conf_slider.valueChanged.connect(lambda v: self.conf_value.setText(str(v / 100)))
        conf_layout.addWidget(self.conf_slider)
        conf_layout.addWidget(self.conf_value)
        layout.addLayout(conf_layout)

        # Class selection area: one check box per class id
        self.class_mapping = {
            0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
            5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
            10: "EM203", 11: "EM180", 12: "EM181"
        }

        layout.addWidget(QLabel("选择要检测的类别(可多选):"))
        self.class_checkboxes = {}  # class id -> QCheckBox
        class_group = QGroupBox()
        grid = QGridLayout()
        for idx, (key, name) in enumerate(self.class_mapping.items()):
            cb = QCheckBox(f"{key}: {name}")
            self.class_checkboxes[key] = cb
            # Four check boxes per row.
            grid.addWidget(cb, idx // 4, idx % 4)
        class_group.setLayout(grid)

        scroll = QScrollArea()
        scroll.setWidget(class_group)
        scroll.setWidgetResizable(True)
        scroll.setFixedHeight(120)
        layout.addWidget(scroll)

        # Control buttons
        btn_layout = QHBoxLayout()
        self.btn_load = QPushButton("加载模型")
        self.btn_start = QPushButton("开始检测")
        btn_layout.addWidget(self.btn_load)
        btn_layout.addWidget(self.btn_start)
        layout.addLayout(btn_layout)

        # Three progress bars
        self.progress_rgb = QProgressBar()
        self.progress_rgb.setFormat("RGB转换进度:%p%")
        self.progress_model = QProgressBar()
        self.progress_model.setFormat("模型检测进度:%p%")
        self.progress_xml = QProgressBar()
        self.progress_xml.setFormat("XML生成进度:%p%")
        layout.addWidget(self.progress_rgb)
        layout.addWidget(self.progress_model)
        layout.addWidget(self.progress_xml)

        # Log output
        layout.addWidget(QLabel("日志输出:"))
        self.log_output = QTextEdit()
        self.log_output.setReadOnly(True)
        layout.addWidget(self.log_output)

        self.setLayout(layout)

        # Signal wiring
        self.btn_load.clicked.connect(self.load_model)
        self.btn_start.clicked.connect(self.start_detection)
        # Set by load_model() once the model path has been checked.
        self.model_loaded = False

    def _add_path_selector(self, layout, label_text):
        """Append a "label + line edit + browse button" row to ``layout``.

        :param layout: layout that receives the new row
        :param label_text: caption of the row; when it contains "文件夹" the
            browse button opens a folder dialog, otherwise a model-file dialog
        :return: the created ``QLineEdit``
        """
        hlayout = QHBoxLayout()
        label = QLabel(label_text)
        edit = QLineEdit()
        btn = QPushButton("浏览")
        hlayout.addWidget(label)
        hlayout.addWidget(edit)
        hlayout.addWidget(btn)
        layout.addLayout(hlayout)
        btn.clicked.connect(lambda: self._select_path(edit, "选择文件夹" if "文件夹" in label_text else "选择模型文件"))
        return edit

    def _select_path(self, edit, title):
        """Open a file or folder dialog and copy the chosen path into ``edit``.

        :param edit: line edit that receives the selection
        :param title: dialog title; a title containing "模型" selects the
            model-file dialog, any other title the folder dialog
        """
        if "模型" in title:
            path, _ = QFileDialog.getOpenFileName(self, title, "", "Model Files (*.pt *.onnx)")
        else:
            path = QFileDialog.getExistingDirectory(self, title)
        # An empty path means the dialog was cancelled; keep the old text.
        if path:
            edit.setText(path)

    def load_model(self):
        """Check that the model path exists and unlock detection.

        The model itself is loaded later by the worker thread.
        """
        path = self.model_edit.text().strip()
        if not os.path.exists(path):
            self._log("❌ 模型文件不存在")
            return
        self.model_loaded = True
        self._log("✅ 模型路径加载成功,可开始检测")

    def start_detection(self):
        """Validate the settings and start the background detection thread."""
        if not self.model_loaded:
            self._log("⚠️ 请先加载模型路径")
            return

        model_path = self.model_edit.text().strip()
        input_dir = self.input_edit.text().strip()
        output_xml_dir = self.output_edit.text().strip()
        temp_dir = self.temp_edit.text().strip()
        conf = float(self.conf_slider.value() / 100)

        # No box checked means "all classes" (the worker passes None on).
        selected_classes = [cid for cid, cb in self.class_checkboxes.items() if cb.isChecked()]

        self._log(f"选中类别:{selected_classes if selected_classes else '全部'}")

        # The worker lists this path with os.listdir, so it must be a folder.
        if not os.path.isdir(input_dir):
            self._log("❌ 输入文件夹不存在")
            return
        # The worker creates the output folder itself, but an empty path
        # cannot be created and would only surface as an error in the thread.
        if not output_xml_dir:
            self._log("⚠️ 请选择输出XML文件夹")
            return
        if not temp_dir:
            self._log("⚠️ 请选择临时RGB图片文件夹")
            return

        self.progress_rgb.setValue(0)
        self.progress_model.setValue(0)
        self.progress_xml.setValue(0)

        # Keep a reference on self so the thread object outlives this method.
        self.thread = DetectionThread(model_path, input_dir, output_xml_dir, temp_dir, conf, selected_classes)
        self.thread.update_rgb_progress.connect(self.progress_rgb.setValue)
        self.thread.update_model_progress.connect(self.progress_model.setValue)
        self.thread.update_xml_progress.connect(self.progress_xml.setValue)
        self.thread.log_message.connect(self._log)
        self.thread.finished_signal.connect(self._finish)
        self.thread.start()

        # Re-enabled in _finish(); prevents starting a second job meanwhile.
        self.btn_start.setEnabled(False)
        self._log("开始执行检测任务...")

    def _log(self, msg):
        """Append ``msg`` to the log view and scroll to the newest line."""
        self.log_output.append(msg)
        self.log_output.verticalScrollBar().setValue(self.log_output.verticalScrollBar().maximum())

    def _finish(self):
        """Slot for the worker's ``finished_signal``: re-enable the start button."""
        self.btn_start.setEnabled(True)
        self._log("任务完成 ✅")
|
||||
|
||||
|
||||
# ========== 主程序入口 ==========
|
||||
if __name__ == "__main__":
    # Create the Qt application, show the main window and run the event
    # loop; the loop's return code becomes the process exit status.
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
|
||||
44
autoannoguinew.spec
Normal file
44
autoannoguinew.spec
Normal file
@@ -0,0 +1,44 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
|
||||
|
||||
# PyInstaller build description for autoannoguinew.py. Analysis, PYZ, EXE and
# COLLECT are not imported here; they are expected to be provided by
# PyInstaller when it executes the spec file.
a = Analysis(
    ['autoannoguinew.py'],
    # Raw string: the Windows path contains backslashes that would otherwise
    # be parsed as (invalid) escape sequences such as "\P" and "\s". The
    # resulting path value is unchanged.
    pathex=[r'D:\PythonProject\AutoAnno\.venv\Lib\site-packages'],
    binaries=[],
    datas=[],
    hiddenimports=['os', 'cv2', 'ultralytics', 'PyQt5', 'sys', 'xml.etree.ElementTree', 'xml.dom'],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    noarchive=False,
    optimize=0,
)
pyz = PYZ(a.pure)

exe = EXE(
    pyz,
    a.scripts,
    [],
    exclude_binaries=True,
    name='autoannoguinew',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
coll = COLLECT(
    exe,
    a.binaries,
    a.datas,
    strip=False,
    upx=True,
    upx_exclude=[],
    name='autoannoguinew',
)
|
||||
BIN
epoch220(1).pt
Normal file
BIN
epoch220(1).pt
Normal file
Binary file not shown.
68
function-gpu.yaml
Normal file
68
function-gpu.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
metadata:
|
||||
name: emdetector
|
||||
namespace: cvat
|
||||
annotations:
|
||||
name: EM14 v1
|
||||
type: detector
|
||||
framework: pytorch
|
||||
spec: |
|
||||
[
|
||||
{ "id": 0, "name": "EM14", "type": "rectangle" },
|
||||
{ "id": 1, "name": "EM18", "type": "rectangle" },
|
||||
{ "id": 2, "name": "EM17", "type": "rectangle" },
|
||||
{ "id": 3, "name": "EM170", "type": "rectangle" },
|
||||
{ "id": 4, "name": "EM19", "type": "rectangle" },
|
||||
{ "id": 5, "name": "EM190", "type": "rectangle" },
|
||||
{ "id": 6, "name": "EM20", "type": "rectangle" },
|
||||
{ "id": 7, "name": "EM200", "type": "rectangle" },
|
||||
{ "id": 8, "name": "EM201", "type": "rectangle" },
|
||||
{ "id": 9, "name": "EM202", "type": "rectangle" },
|
||||
{ "id": 10, "name": "EM203", "type": "rectangle" }
|
||||
]
|
||||
|
||||
spec:
|
||||
description: 工位检测
|
||||
runtime: "python:3.9"
|
||||
handler: main:handler
|
||||
eventTimeout: 30s
|
||||
|
||||
build:
|
||||
image: cvat.pth.yolo8.emdetector:latest-gpu
|
||||
baseImage: python:3.9
|
||||
directives:
|
||||
preCopy:
|
||||
- kind: ENV
|
||||
value: DEBIAN_FRONTEND=noninteractive
|
||||
- kind: RUN
|
||||
value: apt-get update && apt-get install -y libgl1 libglib2.0-0 && apt-get clean
|
||||
- kind: RUN
|
||||
value: pip install ultralytics torch torchvision opencv-python-headless && pip cache purge
|
||||
|
||||
triggers:
|
||||
myHttpTrigger:
|
||||
numWorkers: 1
|
||||
kind: 'http'
|
||||
workerAvailabilityTimeoutMilliseconds: 10000
|
||||
attributes:
|
||||
# Set value from the calculation of tracking of 100 objects at the same time on a 4k image
|
||||
maxRequestBodySize: 268435456 # 256MB
|
||||
|
||||
volumes:
|
||||
- volume:
|
||||
name: model-volume
|
||||
hostPath:
|
||||
path: /DATA/wjl/cvat/models/best.pt # 可选:如果使用HostPath挂载模型
|
||||
volumeMount:
|
||||
name: model-volume
|
||||
mountPath: /opt/nuclio/best.pt
|
||||
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 1
|
||||
|
||||
platform:
|
||||
attributes:
|
||||
restartPolicy:
|
||||
name: always
|
||||
maximumRetryCount: 3
|
||||
mountMode: volume
|
||||
38
localdetect.py
Normal file
38
localdetect.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import os
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
# 加载模型
|
||||
model = YOLO("best.pt")

# Input folder and scratch folder for the converted copies
input_dir = "test_images/"
temp_dir = "temp_rgb_images"

# Create the temporary folder that stores the converted images
os.makedirs(temp_dir, exist_ok=True)

# Scan the input folder and convert grayscale images to 3 channels
for file in os.listdir(input_dir):
    if not file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff')):
        continue

    path = os.path.join(input_dir, file)
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    # cv2.imread signals an unreadable/corrupt file by returning None instead
    # of raising; skip such files so one bad image does not abort the run.
    if img is None:
        print(f"[跳过] 无法读取图片: {path}")
        continue

    # Single-channel input (2-D array or an explicit 1-channel axis) is
    # expanded to three channels.
    if len(img.shape) == 2 or img.shape[2] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # Store the copy in the temporary folder
    cv2.imwrite(os.path.join(temp_dir, file), img)

# Run detection on the folder of converted images
results = model.predict(source=temp_dir, conf=0.2, save=True)

# Print class, confidence and box coordinates of every detection
for result in results:
    boxes = result.boxes  # detected boxes of this image
    for box in boxes:
        cls = int(box.cls[0])
        conf = float(box.conf[0])
        xyxy = box.xyxy[0].tolist()
        print(f"类别: {model.names[cls]}, 置信度: {conf:.2f}, 坐标: {xyxy}")
|
||||
24
main.py
Normal file
24
main.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import json
|
||||
import base64
|
||||
# from PIL import Image
|
||||
import io
|
||||
from model_handler import ModelHandler
|
||||
|
||||
def init_context(context):
    """Create the model handler and store it on ``context.user_data``.

    :param context: object providing ``logger`` and ``user_data``; the created
        ``ModelHandler`` is stored as ``context.user_data.model_handler``
    """
    log = context.logger.info
    log("Init context... 0%")

    log("Initializing EMDetection model...")
    setattr(context.user_data, "model_handler", ModelHandler())

    log("Init context...100%")
|
||||
|
||||
def handler(context, event):
    """Run detection on the image carried by ``event`` and return JSON.

    :param context: object providing ``logger``, ``user_data.model_handler``
        and the ``Response`` factory
    :param event: request whose ``body`` is a mapping with a base64-encoded
        ``"image"`` entry and an optional ``"threshold"`` entry (default 0.5)
    :return: ``context.Response`` with the JSON-encoded detections, status 200
    """
    context.logger.info("Run EMDetection model")

    payload = event.body
    raw_image = base64.b64decode(payload["image"])
    min_confidence = float(payload.get("threshold", 0.5))

    detections = context.user_data.model_handler.infer(raw_image, min_confidence)
    response_body = json.dumps(detections)

    return context.Response(
        body=response_body,
        headers={},
        content_type='application/json',
        status_code=200,
    )
|
||||
33
model_handler.py
Normal file
33
model_handler.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import io
|
||||
from PIL import Image
|
||||
from ultralytics import YOLO
|
||||
|
||||
|
||||
class ModelHandler:
    """Wraps the YOLO model and converts its output to plain dictionaries."""

    def __init__(self):
        """Load the YOLO model weights."""
        self.model = YOLO("/opt/nuclio/best.pt")  # make sure this path is correct

    def infer(self, image_data, threshold=0.3):
        """Run inference on one encoded image.

        :param image_data: binary content of the image file
        :param threshold: confidence threshold (default 0.3)
        :return: list of dictionaries with the keys ``confidence``, ``label``,
            ``points`` (``[x1, y1, x2, y2]``) and ``type`` for every detection
            whose score is at least ``threshold``
        """
        image = Image.open(io.BytesIO(image_data))
        # Pass the threshold to the model as well: without it the model
        # applies its own default confidence cut-off first, so a lower
        # ``threshold`` could never return the low-score boxes it asks for.
        results = self.model(image, conf=threshold)

        detections = []
        for result in results:
            for box in result.boxes.data.tolist():
                x1, y1, x2, y2, score, class_id = box
                if score >= threshold:  # drop low-confidence detections
                    detections.append({
                        "confidence": score,
                        "label": self.model.names[int(class_id)],
                        "points": [x1, y1, x2, y2],
                        "type": "rectangle",
                    })

        return detections
|
||||
16
runanno.py
Normal file
16
runanno.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
# 获取当前脚本所在的目录(即 AutoAnno 目录)
|
||||
# Folder that contains this launcher script (the AutoAnno directory).
base_dir = os.path.dirname(os.path.abspath(__file__))

# Interpreter of the bundled virtual environment and the GUI script, both
# resolved relative to the launcher's folder.
python_exe, script_path = (
    os.path.join(base_dir, *parts)
    for parts in ((".venv", "Scripts", "python.exe"), ("autoannoguinew.py",))
)

# Start the GUI in the background without opening a console window.
launch_command = [python_exe, script_path]
subprocess.Popen(
    launch_command,
    cwd=base_dir,
    creationflags=subprocess.CREATE_NO_WINDOW,
)
|
||||
Reference in New Issue
Block a user