First
This commit is contained in:
5
.idea/.gitignore
generated
vendored
Normal file
5
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
10
.idea/AutoAnno.iml
generated
Normal file
10
.idea/AutoAnno.iml
generated
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||||
|
</content>
|
||||||
|
<orderEntry type="jdk" jdkName="Python 3.11 (AutoAnno)" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.11 (AutoAnno)" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (AutoAnno)" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/AutoAnno.iml" filepath="$PROJECT_DIR$/.idea/AutoAnno.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
129
autoanno.py
Normal file
129
autoanno.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from xml.dom import minidom
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# ========== 配置部分 ==========
|
||||||
|
model_path = "epoch220.pt"
|
||||||
|
input_dir = "test_images/"
|
||||||
|
output_xml_dir = "annotations/"
|
||||||
|
temp_dir = "temp_rgb_images"
|
||||||
|
|
||||||
|
|
||||||
|
# ===== 自定义类别名称映射 =====
|
||||||
|
class_mapping = {
|
||||||
|
0: "EM14",
|
||||||
|
1: "EM18",
|
||||||
|
2: "EM17",
|
||||||
|
3: "EM170",
|
||||||
|
4: "EM19",
|
||||||
|
5: "EM190",
|
||||||
|
6: "EM20",
|
||||||
|
7: "EM200",
|
||||||
|
8: "EM201",
|
||||||
|
9: "EM202",
|
||||||
|
10: "EM203",
|
||||||
|
11: "EM180",
|
||||||
|
12: "EM181"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
os.makedirs(output_xml_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# ========== 加载模型 ==========
|
||||||
|
model = YOLO(model_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 灰度转RGB ==========
|
||||||
|
for file in os.listdir(input_dir):
|
||||||
|
if not file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff')):
|
||||||
|
continue
|
||||||
|
|
||||||
|
input_path = os.path.join(input_dir, file)
|
||||||
|
img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
|
if len(img.shape) == 2 or img.shape[2] == 1:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
temp_path = os.path.join(temp_dir, file)
|
||||||
|
cv2.imwrite(temp_path, img)
|
||||||
|
|
||||||
|
# ========== 执行检测 ==========
|
||||||
|
results = model.predict(source=temp_dir, conf=0.2, save=False,classes=[11])
|
||||||
|
|
||||||
|
# ========== 生成标准化VOC XML ==========
|
||||||
|
for result in results:
|
||||||
|
file_name = os.path.basename(result.path)
|
||||||
|
image_path = os.path.join(input_dir, file_name)
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
height, width, depth = img.shape
|
||||||
|
|
||||||
|
annotation = ET.Element("annotation")
|
||||||
|
|
||||||
|
folder = ET.SubElement(annotation, "folder")
|
||||||
|
folder.text = os.path.basename(os.path.dirname(image_path))
|
||||||
|
|
||||||
|
filename = ET.SubElement(annotation, "filename")
|
||||||
|
filename.text = file_name
|
||||||
|
|
||||||
|
path = ET.SubElement(annotation, "path")
|
||||||
|
path.text = os.path.abspath(image_path)
|
||||||
|
|
||||||
|
source = ET.SubElement(annotation, "source")
|
||||||
|
database = ET.SubElement(source, "database")
|
||||||
|
database.text = "Unknown"
|
||||||
|
|
||||||
|
size = ET.SubElement(annotation, "size")
|
||||||
|
ET.SubElement(size, "width").text = str(width)
|
||||||
|
ET.SubElement(size, "height").text = str(height)
|
||||||
|
ET.SubElement(size, "depth").text = str(depth)
|
||||||
|
|
||||||
|
segmented = ET.SubElement(annotation, "segmented")
|
||||||
|
segmented.text = "0"
|
||||||
|
|
||||||
|
# 遍历检测框
|
||||||
|
for box in result.boxes:
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
xyxy = box.xyxy[0].tolist()
|
||||||
|
xmin, ymin, xmax, ymax = map(int, xyxy)
|
||||||
|
|
||||||
|
obj = ET.SubElement(annotation, "object")
|
||||||
|
|
||||||
|
name = ET.SubElement(obj, "name")
|
||||||
|
name.text = class_mapping.get(cls, str(cls))
|
||||||
|
|
||||||
|
|
||||||
|
pose = ET.SubElement(obj, "pose")
|
||||||
|
pose.text = "Unspecified"
|
||||||
|
|
||||||
|
truncated = ET.SubElement(obj, "truncated")
|
||||||
|
truncated.text = "0"
|
||||||
|
|
||||||
|
difficult = ET.SubElement(obj, "difficult")
|
||||||
|
difficult.text = "0"
|
||||||
|
|
||||||
|
bndbox = ET.SubElement(obj, "bndbox")
|
||||||
|
ET.SubElement(bndbox, "xmin").text = str(xmin)
|
||||||
|
ET.SubElement(bndbox, "ymin").text = str(ymin)
|
||||||
|
ET.SubElement(bndbox, "xmax").text = str(xmax)
|
||||||
|
ET.SubElement(bndbox, "ymax").text = str(ymax)
|
||||||
|
|
||||||
|
# ===== 用 minidom 格式化输出 =====
|
||||||
|
rough_string = ET.tostring(annotation, 'utf-8')
|
||||||
|
reparsed = minidom.parseString(rough_string)
|
||||||
|
xml_str = reparsed.toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
|
||||||
|
|
||||||
|
# 去掉多余的第一行空行
|
||||||
|
xml_str = "\n".join([line for line in xml_str.splitlines() if line.strip()])
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
xml_path = os.path.join(output_xml_dir, os.path.splitext(file_name)[0] + ".xml")
|
||||||
|
with open(xml_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(xml_str)
|
||||||
|
|
||||||
|
print(f"[生成完成] {xml_path}")
|
||||||
279
autoannogui.py
Normal file
279
autoannogui.py
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import sys
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from xml.dom import minidom
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from PyQt5.QtWidgets import (
|
||||||
|
QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
|
||||||
|
QLineEdit, QLabel, QFileDialog, QProgressBar, QTextEdit, QSlider
|
||||||
|
)
|
||||||
|
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 后台任务线程 ==========
|
||||||
|
class DetectionThread(QThread):
|
||||||
|
update_model_progress = pyqtSignal(int)
|
||||||
|
update_xml_progress = pyqtSignal(int)
|
||||||
|
log_message = pyqtSignal(str)
|
||||||
|
finished_signal = pyqtSignal()
|
||||||
|
|
||||||
|
def __init__(self, model_path, input_dir, output_xml_dir, conf, classes):
|
||||||
|
super().__init__()
|
||||||
|
self.model_path = model_path
|
||||||
|
self.input_dir = input_dir
|
||||||
|
self.output_xml_dir = output_xml_dir
|
||||||
|
self.conf = conf
|
||||||
|
self.classes = classes
|
||||||
|
|
||||||
|
# 类别映射(可修改)
|
||||||
|
self.class_mapping = {
|
||||||
|
0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
|
||||||
|
5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
|
||||||
|
10: "EM203", 11: "EM180", 12: "EM181"
|
||||||
|
}
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
os.makedirs(self.output_xml_dir, exist_ok=True)
|
||||||
|
temp_dir = os.path.join(self.output_xml_dir, "temp_rgb_images")
|
||||||
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
self.log_message.emit(f"加载模型:{self.model_path}")
|
||||||
|
model = YOLO(self.model_path)
|
||||||
|
self.log_message.emit("模型加载完成 ✅")
|
||||||
|
|
||||||
|
# 读取图片列表
|
||||||
|
imgs = [f for f in os.listdir(self.input_dir)
|
||||||
|
if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff'))]
|
||||||
|
total_imgs = len(imgs)
|
||||||
|
if total_imgs == 0:
|
||||||
|
self.log_message.emit("输入目录中未找到图片 ❌")
|
||||||
|
self.finished_signal.emit()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 灰度转RGB并保存临时文件
|
||||||
|
self.log_message.emit("开始灰度图像转RGB...此过程时间较长,请等待")
|
||||||
|
for img_name in imgs:
|
||||||
|
img_path = os.path.join(self.input_dir, img_name)
|
||||||
|
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
||||||
|
if img is None:
|
||||||
|
continue
|
||||||
|
if len(img.shape) == 2 or img.shape[2] == 1:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
cv2.imwrite(os.path.join(temp_dir, img_name), img)
|
||||||
|
|
||||||
|
# ========= 模型检测阶段 =========
|
||||||
|
self.log_message.emit("开始执行YOLO检测...")
|
||||||
|
results = []
|
||||||
|
for i, file in enumerate(imgs, 1):
|
||||||
|
temp_path = os.path.join(temp_dir, file)
|
||||||
|
result = model.predict(
|
||||||
|
source=temp_path, conf=self.conf, save=False, classes=self.classes
|
||||||
|
)
|
||||||
|
results.extend(result)
|
||||||
|
progress = int(i / total_imgs * 100)
|
||||||
|
self.update_model_progress.emit(progress)
|
||||||
|
self.log_message.emit(f"[检测完成] {file}")
|
||||||
|
|
||||||
|
# ========= XML生成阶段 =========
|
||||||
|
self.log_message.emit("开始生成VOC格式XML...")
|
||||||
|
total_results = len(results)
|
||||||
|
for j, result in enumerate(results, 1):
|
||||||
|
file_name = os.path.basename(result.path)
|
||||||
|
image_path = os.path.join(self.input_dir, file_name)
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
if img is None:
|
||||||
|
continue
|
||||||
|
height, width, depth = img.shape
|
||||||
|
|
||||||
|
annotation = ET.Element("annotation")
|
||||||
|
|
||||||
|
folder = ET.SubElement(annotation, "folder")
|
||||||
|
folder.text = os.path.basename(os.path.dirname(image_path))
|
||||||
|
|
||||||
|
filename = ET.SubElement(annotation, "filename")
|
||||||
|
filename.text = file_name
|
||||||
|
|
||||||
|
path = ET.SubElement(annotation, "path")
|
||||||
|
path.text = os.path.abspath(image_path)
|
||||||
|
|
||||||
|
source = ET.SubElement(annotation, "source")
|
||||||
|
database = ET.SubElement(source, "database")
|
||||||
|
database.text = "Unknown"
|
||||||
|
|
||||||
|
size = ET.SubElement(annotation, "size")
|
||||||
|
ET.SubElement(size, "width").text = str(width)
|
||||||
|
ET.SubElement(size, "height").text = str(height)
|
||||||
|
ET.SubElement(size, "depth").text = str(depth)
|
||||||
|
|
||||||
|
segmented = ET.SubElement(annotation, "segmented")
|
||||||
|
segmented.text = "0"
|
||||||
|
|
||||||
|
for box in result.boxes:
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
xyxy = box.xyxy[0].tolist()
|
||||||
|
xmin, ymin, xmax, ymax = map(int, xyxy)
|
||||||
|
|
||||||
|
obj = ET.SubElement(annotation, "object")
|
||||||
|
ET.SubElement(obj, "name").text = self.class_mapping.get(cls, str(cls))
|
||||||
|
ET.SubElement(obj, "pose").text = "Unspecified"
|
||||||
|
ET.SubElement(obj, "truncated").text = "0"
|
||||||
|
ET.SubElement(obj, "difficult").text = "0"
|
||||||
|
|
||||||
|
bndbox = ET.SubElement(obj, "bndbox")
|
||||||
|
ET.SubElement(bndbox, "xmin").text = str(xmin)
|
||||||
|
ET.SubElement(bndbox, "ymin").text = str(ymin)
|
||||||
|
ET.SubElement(bndbox, "xmax").text = str(xmax)
|
||||||
|
ET.SubElement(bndbox, "ymax").text = str(ymax)
|
||||||
|
|
||||||
|
rough_string = ET.tostring(annotation, 'utf-8')
|
||||||
|
reparsed = minidom.parseString(rough_string)
|
||||||
|
xml_str = reparsed.toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
|
||||||
|
xml_str = "\n".join([line for line in xml_str.splitlines() if line.strip()])
|
||||||
|
|
||||||
|
xml_path = os.path.join(self.output_xml_dir, os.path.splitext(file_name)[0] + ".xml")
|
||||||
|
with open(xml_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(xml_str)
|
||||||
|
|
||||||
|
progress = int(j / total_results * 100)
|
||||||
|
self.update_xml_progress.emit(progress)
|
||||||
|
self.log_message.emit(f"[生成完成] {xml_path}")
|
||||||
|
|
||||||
|
self.log_message.emit("✅ 所有任务完成!")
|
||||||
|
except Exception as e:
|
||||||
|
self.log_message.emit(f"❌ 错误: {e}")
|
||||||
|
finally:
|
||||||
|
self.finished_signal.emit()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 主界面 ==========
|
||||||
|
class MainWindow(QWidget):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.setWindowTitle("YOLO 自动检测 + VOC生成工具")
|
||||||
|
self.resize(700, 600)
|
||||||
|
|
||||||
|
layout = QVBoxLayout()
|
||||||
|
|
||||||
|
# 模型路径
|
||||||
|
self.model_edit = self._add_path_selector(layout, "模型路径:")
|
||||||
|
# 输入文件夹
|
||||||
|
self.input_edit = self._add_path_selector(layout, "输入图片文件夹:")
|
||||||
|
# 输出文件夹
|
||||||
|
self.output_edit = self._add_path_selector(layout, "输出XML文件夹:")
|
||||||
|
|
||||||
|
# 置信度滑块
|
||||||
|
conf_layout = QHBoxLayout()
|
||||||
|
conf_layout.addWidget(QLabel("置信度阈值:"))
|
||||||
|
self.conf_slider = QSlider(Qt.Horizontal)
|
||||||
|
self.conf_slider.setRange(0, 100)
|
||||||
|
self.conf_slider.setValue(20)
|
||||||
|
self.conf_value = QLabel("0.2")
|
||||||
|
self.conf_slider.valueChanged.connect(lambda v: self.conf_value.setText(str(v / 100)))
|
||||||
|
conf_layout.addWidget(self.conf_slider)
|
||||||
|
conf_layout.addWidget(self.conf_value)
|
||||||
|
layout.addLayout(conf_layout)
|
||||||
|
|
||||||
|
# 控制按钮
|
||||||
|
btn_layout = QHBoxLayout()
|
||||||
|
self.btn_load = QPushButton("加载模型")
|
||||||
|
self.btn_start = QPushButton("开始检测")
|
||||||
|
btn_layout.addWidget(self.btn_load)
|
||||||
|
btn_layout.addWidget(self.btn_start)
|
||||||
|
layout.addLayout(btn_layout)
|
||||||
|
|
||||||
|
# 两个进度条
|
||||||
|
self.progress_model = QProgressBar()
|
||||||
|
self.progress_model.setFormat("模型检测进度:%p%")
|
||||||
|
self.progress_xml = QProgressBar()
|
||||||
|
self.progress_xml.setFormat("XML生成进度:%p%")
|
||||||
|
layout.addWidget(self.progress_model)
|
||||||
|
layout.addWidget(self.progress_xml)
|
||||||
|
|
||||||
|
# 日志输出
|
||||||
|
layout.addWidget(QLabel("日志输出:"))
|
||||||
|
self.log_output = QTextEdit()
|
||||||
|
self.log_output.setReadOnly(True)
|
||||||
|
layout.addWidget(self.log_output)
|
||||||
|
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
# 信号绑定
|
||||||
|
self.btn_load.clicked.connect(self.load_model)
|
||||||
|
self.btn_start.clicked.connect(self.start_detection)
|
||||||
|
|
||||||
|
self.model_loaded = False
|
||||||
|
|
||||||
|
def _add_path_selector(self, layout, label_text):
|
||||||
|
hlayout = QHBoxLayout()
|
||||||
|
label = QLabel(label_text)
|
||||||
|
edit = QLineEdit()
|
||||||
|
btn = QPushButton("浏览")
|
||||||
|
hlayout.addWidget(label)
|
||||||
|
hlayout.addWidget(edit)
|
||||||
|
hlayout.addWidget(btn)
|
||||||
|
layout.addLayout(hlayout)
|
||||||
|
btn.clicked.connect(lambda: self._select_path(edit, "选择文件夹" if "文件夹" in label_text else "选择模型文件"))
|
||||||
|
return edit
|
||||||
|
|
||||||
|
def _select_path(self, edit, title):
|
||||||
|
if "模型" in title:
|
||||||
|
path, _ = QFileDialog.getOpenFileName(self, title, "", "Model Files (*.pt *.onnx)")
|
||||||
|
else:
|
||||||
|
path = QFileDialog.getExistingDirectory(self, title)
|
||||||
|
if path:
|
||||||
|
edit.setText(path)
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
model_path = self.model_edit.text().strip()
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
self._log("❌ 模型文件不存在")
|
||||||
|
return
|
||||||
|
self.model_loaded = True
|
||||||
|
self._log("✅ 模型路径加载成功,可开始检测")
|
||||||
|
|
||||||
|
def start_detection(self):
|
||||||
|
if not self.model_loaded:
|
||||||
|
self._log("⚠️ 请先加载模型路径")
|
||||||
|
return
|
||||||
|
|
||||||
|
model_path = self.model_edit.text().strip()
|
||||||
|
input_dir = self.input_edit.text().strip()
|
||||||
|
output_xml_dir = self.output_edit.text().strip()
|
||||||
|
conf = float(self.conf_slider.value() / 100)
|
||||||
|
classes = [11] # 可修改为自定义选择
|
||||||
|
|
||||||
|
if not os.path.exists(input_dir):
|
||||||
|
self._log("❌ 输入文件夹不存在")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.progress_model.setValue(0)
|
||||||
|
self.progress_xml.setValue(0)
|
||||||
|
|
||||||
|
self.thread = DetectionThread(model_path, input_dir, output_xml_dir, conf, classes)
|
||||||
|
self.thread.update_model_progress.connect(self.progress_model.setValue)
|
||||||
|
self.thread.update_xml_progress.connect(self.progress_xml.setValue)
|
||||||
|
self.thread.log_message.connect(self._log)
|
||||||
|
self.thread.finished_signal.connect(self._finish)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
self.btn_start.setEnabled(False)
|
||||||
|
self._log("开始执行检测任务...")
|
||||||
|
|
||||||
|
def _log(self, msg):
|
||||||
|
self.log_output.append(msg)
|
||||||
|
self.log_output.verticalScrollBar().setValue(self.log_output.verticalScrollBar().maximum())
|
||||||
|
|
||||||
|
def _finish(self):
|
||||||
|
self.btn_start.setEnabled(True)
|
||||||
|
self._log("任务完成 ✅")
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 主程序入口 ==========
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app = QApplication(sys.argv)
|
||||||
|
window = MainWindow()
|
||||||
|
window.show()
|
||||||
|
sys.exit(app.exec_())
|
||||||
313
autoannoguinew.py
Normal file
313
autoannoguinew.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import sys
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from xml.dom import minidom
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from PyQt5.QtWidgets import (
|
||||||
|
QApplication, QWidget, QVBoxLayout, QHBoxLayout, QPushButton,
|
||||||
|
QLineEdit, QLabel, QFileDialog, QProgressBar, QTextEdit, QSlider,
|
||||||
|
QCheckBox, QScrollArea, QGroupBox, QGridLayout
|
||||||
|
)
|
||||||
|
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 后台任务线程 ==========
|
||||||
|
class DetectionThread(QThread):
|
||||||
|
update_rgb_progress = pyqtSignal(int)
|
||||||
|
update_model_progress = pyqtSignal(int)
|
||||||
|
update_xml_progress = pyqtSignal(int)
|
||||||
|
log_message = pyqtSignal(str)
|
||||||
|
finished_signal = pyqtSignal()
|
||||||
|
|
||||||
|
def __init__(self, model_path, input_dir, output_xml_dir, temp_dir, conf, classes):
|
||||||
|
super().__init__()
|
||||||
|
self.model_path = model_path
|
||||||
|
self.input_dir = input_dir
|
||||||
|
self.output_xml_dir = output_xml_dir
|
||||||
|
self.temp_dir = temp_dir
|
||||||
|
self.conf = conf
|
||||||
|
self.classes = classes
|
||||||
|
|
||||||
|
# 类别映射
|
||||||
|
self.class_mapping = {
|
||||||
|
0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
|
||||||
|
5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
|
||||||
|
10: "EM203", 11: "EM180", 12: "EM181"
|
||||||
|
}
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
try:
|
||||||
|
os.makedirs(self.output_xml_dir, exist_ok=True)
|
||||||
|
os.makedirs(self.temp_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
self.log_message.emit(f"加载模型:{self.model_path}")
|
||||||
|
model = YOLO(self.model_path)
|
||||||
|
self.log_message.emit("模型加载完成 ✅")
|
||||||
|
|
||||||
|
# 读取图片
|
||||||
|
imgs = [f for f in os.listdir(self.input_dir)
|
||||||
|
if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff'))]
|
||||||
|
total_imgs = len(imgs)
|
||||||
|
if total_imgs == 0:
|
||||||
|
self.log_message.emit("输入目录中未找到图片 ❌")
|
||||||
|
self.finished_signal.emit()
|
||||||
|
return
|
||||||
|
|
||||||
|
# ========= 灰度转RGB阶段 =========
|
||||||
|
self.log_message.emit("开始灰度图像转RGB...此过程时间较长,请等待")
|
||||||
|
for i, img_name in enumerate(imgs, 1):
|
||||||
|
img_path = os.path.join(self.input_dir, img_name)
|
||||||
|
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
||||||
|
if img is None:
|
||||||
|
continue
|
||||||
|
if len(img.shape) == 2 or img.shape[2] == 1:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
cv2.imwrite(os.path.join(self.temp_dir, img_name), img)
|
||||||
|
self.update_rgb_progress.emit(int(i / total_imgs * 100))
|
||||||
|
self.log_message.emit("RGB 转换完成 ✅")
|
||||||
|
|
||||||
|
# ========= 模型检测阶段 =========
|
||||||
|
self.log_message.emit("开始执行YOLO检测...")
|
||||||
|
results = []
|
||||||
|
for i, file in enumerate(imgs, 1):
|
||||||
|
temp_path = os.path.join(self.temp_dir, file)
|
||||||
|
result = model.predict(
|
||||||
|
source=temp_path, conf=self.conf, save=False,
|
||||||
|
classes=self.classes if self.classes else None
|
||||||
|
)
|
||||||
|
results.extend(result)
|
||||||
|
self.update_model_progress.emit(int(i / total_imgs * 100))
|
||||||
|
self.log_message.emit(f"[检测完成] {file}")
|
||||||
|
|
||||||
|
# ========= XML生成阶段 =========
|
||||||
|
self.log_message.emit("开始生成VOC格式XML...")
|
||||||
|
total_results = len(results)
|
||||||
|
for j, result in enumerate(results, 1):
|
||||||
|
file_name = os.path.basename(result.path)
|
||||||
|
image_path = os.path.join(self.input_dir, file_name)
|
||||||
|
img = cv2.imread(image_path)
|
||||||
|
if img is None:
|
||||||
|
continue
|
||||||
|
height, width, depth = img.shape
|
||||||
|
|
||||||
|
annotation = ET.Element("annotation")
|
||||||
|
|
||||||
|
folder = ET.SubElement(annotation, "folder")
|
||||||
|
folder.text = os.path.basename(os.path.dirname(image_path))
|
||||||
|
|
||||||
|
filename = ET.SubElement(annotation, "filename")
|
||||||
|
filename.text = file_name
|
||||||
|
|
||||||
|
path = ET.SubElement(annotation, "path")
|
||||||
|
path.text = os.path.abspath(image_path)
|
||||||
|
|
||||||
|
source = ET.SubElement(annotation, "source")
|
||||||
|
database = ET.SubElement(source, "database")
|
||||||
|
database.text = "Unknown"
|
||||||
|
|
||||||
|
size = ET.SubElement(annotation, "size")
|
||||||
|
ET.SubElement(size, "width").text = str(width)
|
||||||
|
ET.SubElement(size, "height").text = str(height)
|
||||||
|
ET.SubElement(size, "depth").text = str(depth)
|
||||||
|
|
||||||
|
segmented = ET.SubElement(annotation, "segmented")
|
||||||
|
segmented.text = "0"
|
||||||
|
|
||||||
|
for box in result.boxes:
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
xyxy = box.xyxy[0].tolist()
|
||||||
|
xmin, ymin, xmax, ymax = map(int, xyxy)
|
||||||
|
|
||||||
|
obj = ET.SubElement(annotation, "object")
|
||||||
|
ET.SubElement(obj, "name").text = self.class_mapping.get(cls, str(cls))
|
||||||
|
ET.SubElement(obj, "pose").text = "Unspecified"
|
||||||
|
ET.SubElement(obj, "truncated").text = "0"
|
||||||
|
ET.SubElement(obj, "difficult").text = "0"
|
||||||
|
|
||||||
|
bndbox = ET.SubElement(obj, "bndbox")
|
||||||
|
ET.SubElement(bndbox, "xmin").text = str(xmin)
|
||||||
|
ET.SubElement(bndbox, "ymin").text = str(ymin)
|
||||||
|
ET.SubElement(bndbox, "xmax").text = str(xmax)
|
||||||
|
ET.SubElement(bndbox, "ymax").text = str(ymax)
|
||||||
|
|
||||||
|
xml_str = minidom.parseString(ET.tostring(annotation, 'utf-8')).toprettyxml(indent="\t", encoding="utf-8").decode("utf-8")
|
||||||
|
xml_str = "\n".join([line for line in xml_str.splitlines() if line.strip()])
|
||||||
|
|
||||||
|
xml_path = os.path.join(self.output_xml_dir, os.path.splitext(file_name)[0] + ".xml")
|
||||||
|
with open(xml_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(xml_str)
|
||||||
|
|
||||||
|
self.update_xml_progress.emit(int(j / total_results * 100))
|
||||||
|
self.log_message.emit(f"[生成完成] {xml_path}")
|
||||||
|
|
||||||
|
self.log_message.emit("✅ 所有任务完成!")
|
||||||
|
except Exception as e:
|
||||||
|
self.log_message.emit(f"❌ 错误: {e}")
|
||||||
|
finally:
|
||||||
|
self.finished_signal.emit()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 主界面 ==========
|
||||||
|
class MainWindow(QWidget):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.setWindowTitle("YOLO 半自动标注工具(华东专用)")
|
||||||
|
self.resize(750, 750)
|
||||||
|
|
||||||
|
layout = QVBoxLayout()
|
||||||
|
|
||||||
|
# 路径部分
|
||||||
|
self.model_edit = self._add_path_selector(layout, "模型路径:")
|
||||||
|
self.input_edit = self._add_path_selector(layout, "输入图片文件夹:")
|
||||||
|
self.output_edit = self._add_path_selector(layout, "输出XML文件夹:")
|
||||||
|
self.temp_edit = self._add_path_selector(layout, "临时RGB图片文件夹:")
|
||||||
|
|
||||||
|
# 置信度
|
||||||
|
conf_layout = QHBoxLayout()
|
||||||
|
conf_layout.addWidget(QLabel("置信度阈值:"))
|
||||||
|
self.conf_slider = QSlider(Qt.Horizontal)
|
||||||
|
self.conf_slider.setRange(0, 100)
|
||||||
|
self.conf_slider.setValue(20)
|
||||||
|
self.conf_value = QLabel("0.2")
|
||||||
|
self.conf_slider.valueChanged.connect(lambda v: self.conf_value.setText(str(v / 100)))
|
||||||
|
conf_layout.addWidget(self.conf_slider)
|
||||||
|
conf_layout.addWidget(self.conf_value)
|
||||||
|
layout.addLayout(conf_layout)
|
||||||
|
|
||||||
|
# 类别选择区域
|
||||||
|
self.class_mapping = {
|
||||||
|
0: "EM14", 1: "EM18", 2: "EM17", 3: "EM170", 4: "EM19",
|
||||||
|
5: "EM190", 6: "EM20", 7: "EM200", 8: "EM201", 9: "EM202",
|
||||||
|
10: "EM203", 11: "EM180", 12: "EM181"
|
||||||
|
}
|
||||||
|
|
||||||
|
layout.addWidget(QLabel("选择要检测的类别(可多选):"))
|
||||||
|
self.class_checkboxes = {}
|
||||||
|
class_group = QGroupBox()
|
||||||
|
grid = QGridLayout()
|
||||||
|
for idx, (key, name) in enumerate(self.class_mapping.items()):
|
||||||
|
cb = QCheckBox(f"{key}: {name}")
|
||||||
|
self.class_checkboxes[key] = cb
|
||||||
|
grid.addWidget(cb, idx // 4, idx % 4)
|
||||||
|
class_group.setLayout(grid)
|
||||||
|
|
||||||
|
scroll = QScrollArea()
|
||||||
|
scroll.setWidget(class_group)
|
||||||
|
scroll.setWidgetResizable(True)
|
||||||
|
scroll.setFixedHeight(120)
|
||||||
|
layout.addWidget(scroll)
|
||||||
|
|
||||||
|
# 控制按钮
|
||||||
|
btn_layout = QHBoxLayout()
|
||||||
|
self.btn_load = QPushButton("加载模型")
|
||||||
|
self.btn_start = QPushButton("开始检测")
|
||||||
|
btn_layout.addWidget(self.btn_load)
|
||||||
|
btn_layout.addWidget(self.btn_start)
|
||||||
|
layout.addLayout(btn_layout)
|
||||||
|
|
||||||
|
# 三个进度条
|
||||||
|
self.progress_rgb = QProgressBar()
|
||||||
|
self.progress_rgb.setFormat("RGB转换进度:%p%")
|
||||||
|
self.progress_model = QProgressBar()
|
||||||
|
self.progress_model.setFormat("模型检测进度:%p%")
|
||||||
|
self.progress_xml = QProgressBar()
|
||||||
|
self.progress_xml.setFormat("XML生成进度:%p%")
|
||||||
|
layout.addWidget(self.progress_rgb)
|
||||||
|
layout.addWidget(self.progress_model)
|
||||||
|
layout.addWidget(self.progress_xml)
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
layout.addWidget(QLabel("日志输出:"))
|
||||||
|
self.log_output = QTextEdit()
|
||||||
|
self.log_output.setReadOnly(True)
|
||||||
|
layout.addWidget(self.log_output)
|
||||||
|
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
# 信号绑定
|
||||||
|
self.btn_load.clicked.connect(self.load_model)
|
||||||
|
self.btn_start.clicked.connect(self.start_detection)
|
||||||
|
self.model_loaded = False
|
||||||
|
|
||||||
|
def _add_path_selector(self, layout, label_text):
|
||||||
|
hlayout = QHBoxLayout()
|
||||||
|
label = QLabel(label_text)
|
||||||
|
edit = QLineEdit()
|
||||||
|
btn = QPushButton("浏览")
|
||||||
|
hlayout.addWidget(label)
|
||||||
|
hlayout.addWidget(edit)
|
||||||
|
hlayout.addWidget(btn)
|
||||||
|
layout.addLayout(hlayout)
|
||||||
|
btn.clicked.connect(lambda: self._select_path(edit, "选择文件夹" if "文件夹" in label_text else "选择模型文件"))
|
||||||
|
return edit
|
||||||
|
|
||||||
|
def _select_path(self, edit, title):
|
||||||
|
if "模型" in title:
|
||||||
|
path, _ = QFileDialog.getOpenFileName(self, title, "", "Model Files (*.pt *.onnx)")
|
||||||
|
else:
|
||||||
|
path = QFileDialog.getExistingDirectory(self, title)
|
||||||
|
if path:
|
||||||
|
edit.setText(path)
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
path = self.model_edit.text().strip()
|
||||||
|
if not os.path.exists(path):
|
||||||
|
self._log("❌ 模型文件不存在")
|
||||||
|
return
|
||||||
|
self.model_loaded = True
|
||||||
|
self._log("✅ 模型路径加载成功,可开始检测")
|
||||||
|
|
||||||
|
def start_detection(self):
|
||||||
|
if not self.model_loaded:
|
||||||
|
self._log("⚠️ 请先加载模型路径")
|
||||||
|
return
|
||||||
|
|
||||||
|
model_path = self.model_edit.text().strip()
|
||||||
|
input_dir = self.input_edit.text().strip()
|
||||||
|
output_xml_dir = self.output_edit.text().strip()
|
||||||
|
temp_dir = self.temp_edit.text().strip()
|
||||||
|
conf = float(self.conf_slider.value() / 100)
|
||||||
|
|
||||||
|
selected_classes = [cid for cid, cb in self.class_checkboxes.items() if cb.isChecked()]
|
||||||
|
|
||||||
|
self._log(f"选中类别:{selected_classes if selected_classes else '全部'}")
|
||||||
|
|
||||||
|
if not os.path.exists(input_dir):
|
||||||
|
self._log("❌ 输入文件夹不存在")
|
||||||
|
return
|
||||||
|
if not temp_dir:
|
||||||
|
self._log("⚠️ 请选择临时RGB图片文件夹")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.progress_rgb.setValue(0)
|
||||||
|
self.progress_model.setValue(0)
|
||||||
|
self.progress_xml.setValue(0)
|
||||||
|
|
||||||
|
self.thread = DetectionThread(model_path, input_dir, output_xml_dir, temp_dir, conf, selected_classes)
|
||||||
|
self.thread.update_rgb_progress.connect(self.progress_rgb.setValue)
|
||||||
|
self.thread.update_model_progress.connect(self.progress_model.setValue)
|
||||||
|
self.thread.update_xml_progress.connect(self.progress_xml.setValue)
|
||||||
|
self.thread.log_message.connect(self._log)
|
||||||
|
self.thread.finished_signal.connect(self._finish)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
self.btn_start.setEnabled(False)
|
||||||
|
self._log("开始执行检测任务...")
|
||||||
|
|
||||||
|
def _log(self, msg):
|
||||||
|
self.log_output.append(msg)
|
||||||
|
self.log_output.verticalScrollBar().setValue(self.log_output.verticalScrollBar().maximum())
|
||||||
|
|
||||||
|
def _finish(self):
|
||||||
|
self.btn_start.setEnabled(True)
|
||||||
|
self._log("任务完成 ✅")
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 主程序入口 ==========
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app = QApplication(sys.argv)
|
||||||
|
window = MainWindow()
|
||||||
|
window.show()
|
||||||
|
sys.exit(app.exec_())
|
||||||
44
autoannoguinew.spec
Normal file
44
autoannoguinew.spec
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# -*- mode: python ; coding: utf-8 -*-
|
||||||
|
|
||||||
|
|
||||||
|
a = Analysis(
|
||||||
|
['autoannoguinew.py'],
|
||||||
|
pathex=['D:\PythonProject\AutoAnno\.venv\Lib\site-packages'],
|
||||||
|
binaries=[],
|
||||||
|
datas=[],
|
||||||
|
hiddenimports=['os','cv2','ultralytics','PyQt5','sys','xml.etree.ElementTree','xml.dom'],
|
||||||
|
hookspath=[],
|
||||||
|
hooksconfig={},
|
||||||
|
runtime_hooks=[],
|
||||||
|
excludes=[],
|
||||||
|
noarchive=False,
|
||||||
|
optimize=0,
|
||||||
|
)
|
||||||
|
pyz = PYZ(a.pure)
|
||||||
|
|
||||||
|
exe = EXE(
|
||||||
|
pyz,
|
||||||
|
a.scripts,
|
||||||
|
[],
|
||||||
|
exclude_binaries=True,
|
||||||
|
name='autoannoguinew',
|
||||||
|
debug=False,
|
||||||
|
bootloader_ignore_signals=False,
|
||||||
|
strip=False,
|
||||||
|
upx=True,
|
||||||
|
console=True,
|
||||||
|
disable_windowed_traceback=False,
|
||||||
|
argv_emulation=False,
|
||||||
|
target_arch=None,
|
||||||
|
codesign_identity=None,
|
||||||
|
entitlements_file=None,
|
||||||
|
)
|
||||||
|
coll = COLLECT(
|
||||||
|
exe,
|
||||||
|
a.binaries,
|
||||||
|
a.datas,
|
||||||
|
strip=False,
|
||||||
|
upx=True,
|
||||||
|
upx_exclude=[],
|
||||||
|
name='autoannoguinew',
|
||||||
|
)
|
||||||
BIN
epoch220(1).pt
Normal file
BIN
epoch220(1).pt
Normal file
Binary file not shown.
68
function-gpu.yaml
Normal file
68
function-gpu.yaml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
metadata:
|
||||||
|
name: emdetector
|
||||||
|
namespace: cvat
|
||||||
|
annotations:
|
||||||
|
name: EM14 v1
|
||||||
|
type: detector
|
||||||
|
framework: pytorch
|
||||||
|
spec: |
|
||||||
|
[
|
||||||
|
{ "id": 0, "name": "EM14", "type": "rectangle" },
|
||||||
|
{ "id": 1, "name": "EM18", "type": "rectangle" },
|
||||||
|
{ "id": 2, "name": "EM17", "type": "rectangle" },
|
||||||
|
{ "id": 3, "name": "EM170", "type": "rectangle" },
|
||||||
|
{ "id": 4, "name": "EM19", "type": "rectangle" },
|
||||||
|
{ "id": 5, "name": "EM190", "type": "rectangle" },
|
||||||
|
{ "id": 6, "name": "EM20", "type": "rectangle" },
|
||||||
|
{ "id": 7, "name": "EM200", "type": "rectangle" },
|
||||||
|
{ "id": 8, "name": "EM201", "type": "rectangle" },
|
||||||
|
{ "id": 9, "name": "EM202", "type": "rectangle" },
|
||||||
|
{ "id": 10, "name": "EM203", "type": "rectangle" }
|
||||||
|
]
|
||||||
|
|
||||||
|
spec:
|
||||||
|
description: 工位检测
|
||||||
|
runtime: "python:3.9"
|
||||||
|
handler: main:handler
|
||||||
|
eventTimeout: 30s
|
||||||
|
|
||||||
|
build:
|
||||||
|
image: cvat.pth.yolo8.emdetector:latest-gpu
|
||||||
|
baseImage: python:3.9
|
||||||
|
directives:
|
||||||
|
preCopy:
|
||||||
|
- kind: ENV
|
||||||
|
value: DEBIAN_FRONTEND=noninteractive
|
||||||
|
- kind: RUN
|
||||||
|
value: apt-get update && apt-get install -y libgl1 libglib2.0-0 && apt-get clean
|
||||||
|
- kind: RUN
|
||||||
|
value: pip install ultralytics torch torchvision opencv-python-headless && pip cache purge
|
||||||
|
|
||||||
|
triggers:
|
||||||
|
myHttpTrigger:
|
||||||
|
numWorkers: 1
|
||||||
|
kind: 'http'
|
||||||
|
workerAvailabilityTimeoutMilliseconds: 10000
|
||||||
|
attributes:
|
||||||
|
# Set value from the calculation of tracking of 100 objects at the same time on a 4k image
|
||||||
|
maxRequestBodySize: 268435456 # 256MB
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
- volume:
|
||||||
|
name: model-volume
|
||||||
|
hostPath:
|
||||||
|
path: /DATA/wjl/cvat/models/best.pt # 可选:如果使用HostPath挂载模型
|
||||||
|
volumeMount:
|
||||||
|
name: model-volume
|
||||||
|
mountPath: /opt/nuclio/best.pt
|
||||||
|
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
nvidia.com/gpu: 1
|
||||||
|
|
||||||
|
platform:
|
||||||
|
attributes:
|
||||||
|
restartPolicy:
|
||||||
|
name: always
|
||||||
|
maximumRetryCount: 3
|
||||||
|
mountMode: volume
|
||||||
38
localdetect.py
Normal file
38
localdetect.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
model = YOLO("best.pt")
|
||||||
|
|
||||||
|
# 输入文件夹路径
|
||||||
|
input_dir = "test_images/"
|
||||||
|
temp_dir = "temp_rgb_images"
|
||||||
|
|
||||||
|
# 创建临时文件夹用于保存转换后的RGB图片
|
||||||
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 扫描文件夹,将灰度图转换为RGB
|
||||||
|
for file in os.listdir(input_dir):
|
||||||
|
if file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff')):
|
||||||
|
path = os.path.join(input_dir, file)
|
||||||
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
|
# 如果是灰度图,转换为3通道
|
||||||
|
if len(img.shape) == 2 or img.shape[2] == 1:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
# 保存到临时目录
|
||||||
|
cv2.imwrite(os.path.join(temp_dir, file), img)
|
||||||
|
|
||||||
|
# 使用转换后的图像文件夹进行检测
|
||||||
|
results = model.predict(source=temp_dir, conf=0.2, save=True)
|
||||||
|
|
||||||
|
# 输出检测信息
|
||||||
|
for result in results:
|
||||||
|
boxes = result.boxes # 检测框
|
||||||
|
for box in boxes:
|
||||||
|
cls = int(box.cls[0])
|
||||||
|
conf = float(box.conf[0])
|
||||||
|
xyxy = box.xyxy[0].tolist()
|
||||||
|
print(f"类别: {model.names[cls]}, 置信度: {conf:.2f}, 坐标: {xyxy}")
|
||||||
24
main.py
Normal file
24
main.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import json
|
||||||
|
import base64
|
||||||
|
# from PIL import Image
|
||||||
|
import io
|
||||||
|
from model_handler import ModelHandler
|
||||||
|
|
||||||
|
def init_context(context):
|
||||||
|
context.logger.info("Init context... 0%")
|
||||||
|
|
||||||
|
context.logger.info("Initializing EMDetection model...")
|
||||||
|
context.user_data.model_handler = ModelHandler()
|
||||||
|
|
||||||
|
context.logger.info("Init context...100%")
|
||||||
|
|
||||||
|
def handler(context, event):
|
||||||
|
context.logger.info("Run EMDetection model")
|
||||||
|
data = event.body
|
||||||
|
image_data = base64.b64decode(data["image"])
|
||||||
|
threshold = float(data.get("threshold", 0.5))
|
||||||
|
|
||||||
|
results = context.user_data.model_handler.infer(image_data, threshold)
|
||||||
|
|
||||||
|
return context.Response(body=json.dumps(results), headers={},
|
||||||
|
content_type='application/json', status_code=200)
|
||||||
33
model_handler.py
Normal file
33
model_handler.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import io
|
||||||
|
from PIL import Image
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
|
||||||
|
class ModelHandler:
|
||||||
|
def __init__(self):
|
||||||
|
"""加载 YOLOv11 模型"""
|
||||||
|
self.model = YOLO("/opt/nuclio/best.pt") # 确保路径正确
|
||||||
|
|
||||||
|
def infer(self, image_data, threshold=0.3):
|
||||||
|
"""
|
||||||
|
执行推理
|
||||||
|
:param image_data: 图片的二进制数据
|
||||||
|
:param threshold: 置信度阈值(默认0.3)
|
||||||
|
:return: 符合阈值的检测结果
|
||||||
|
"""
|
||||||
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
results = self.model(image)
|
||||||
|
|
||||||
|
detections = []
|
||||||
|
for result in results:
|
||||||
|
for box in result.boxes.data.tolist():
|
||||||
|
x1, y1, x2, y2, score, class_id = box
|
||||||
|
if score >= threshold: # 过滤低置信度目标
|
||||||
|
detections.append({
|
||||||
|
"confidence": score,
|
||||||
|
"label": self.model.names[int(class_id)],
|
||||||
|
"points": [x1, y1, x2, y2],
|
||||||
|
"type": "rectangle",
|
||||||
|
})
|
||||||
|
|
||||||
|
return detections
|
||||||
16
runanno.py
Normal file
16
runanno.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import subprocess
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 获取当前脚本所在的目录(即 AutoAnno 目录)
|
||||||
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
# 组装相对路径
|
||||||
|
python_exe = os.path.join(base_dir, ".venv", "Scripts", "python.exe")
|
||||||
|
script_path = os.path.join(base_dir, "autoannoguinew.py")
|
||||||
|
|
||||||
|
# 后台静默执行
|
||||||
|
subprocess.Popen(
|
||||||
|
[python_exe, script_path],
|
||||||
|
creationflags=subprocess.CREATE_NO_WINDOW,
|
||||||
|
cwd=base_dir
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user