2026-02-20 18:44:43 +08:00
|
|
|
|
package database
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"database/sql"
|
|
|
|
|
|
"fmt"
|
|
|
|
|
|
"sync"
|
|
|
|
|
|
|
|
|
|
|
|
"sunhpc/pkg/config"
|
|
|
|
|
|
"sunhpc/pkg/logger"
|
|
|
|
|
|
|
|
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-02-23 18:54:54 +08:00
|
|
|
|
// =========================================================
|
|
|
|
|
|
// 全局变量
|
|
|
|
|
|
// =========================================================
|
2026-02-20 18:44:43 +08:00
|
|
|
|
var (
|
2026-02-23 18:54:54 +08:00
|
|
|
|
dbInstance *sql.DB
|
2026-02-20 18:44:43 +08:00
|
|
|
|
dbOnce sync.Once
|
2026-02-23 18:54:54 +08:00
|
|
|
|
dbMutex sync.RWMutex
|
2026-02-20 18:44:43 +08:00
|
|
|
|
dbErr error
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-03-06 00:29:34 +08:00
|
|
|
|
// =========================================================
|
|
|
|
|
|
// 封装数据库函数使用Go实现
|
|
|
|
|
|
// =========================================================
|
|
|
|
|
|
// MapCategory - 根据类别名称查ID
|
|
|
|
|
|
// 查询方式: globalID, err := db.MapCategory(conn, "global")
|
|
|
|
|
|
func MapCategory(conn *sql.DB, catname string) (int, error) {
|
|
|
|
|
|
var id int
|
|
|
|
|
|
query := "select id from categories where name = ?"
|
|
|
|
|
|
logger.Debugf("查询SQL: %s", query)
|
|
|
|
|
|
logger.Debugf("查询类别ID: %s", catname)
|
|
|
|
|
|
err := conn.QueryRow(query, catname).Scan(&id)
|
|
|
|
|
|
if err == sql.ErrNoRows {
|
|
|
|
|
|
logger.Debugf("未找到类别 %s, 返回ID=0", catname)
|
|
|
|
|
|
return 0, nil // 无匹配返回0
|
|
|
|
|
|
}
|
|
|
|
|
|
logger.Debugf("查询到类别 %s, ID=%d", catname, id)
|
|
|
|
|
|
return id, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// MapCategoryIndex - 根据类别名称 + 索引名称查ID
|
|
|
|
|
|
// 调用方式: linuxOSID, err := db.MapCategoryIndex(conn, "os", "linux")
|
|
|
|
|
|
func MapCategoryIndex(conn *sql.DB, catindexName, categoryIndex string) (int, error) {
|
|
|
|
|
|
var id int
|
|
|
|
|
|
query := `
|
|
|
|
|
|
select index_id from vmapCategoryIndex
|
|
|
|
|
|
where categoryName = ? and categoryIndex = ?
|
|
|
|
|
|
`
|
|
|
|
|
|
logger.Debugf("查询SQL: %s", query)
|
|
|
|
|
|
logger.Debugf("查询索引ID: %s, 类别: %s", catindexName, categoryIndex)
|
|
|
|
|
|
err := conn.QueryRow(query, catindexName, categoryIndex).Scan(&id)
|
|
|
|
|
|
if err == sql.ErrNoRows {
|
|
|
|
|
|
logger.Debugf("未找到索引 %s, 返回ID=0", catindexName)
|
|
|
|
|
|
return 0, nil // 无匹配返回0
|
|
|
|
|
|
}
|
|
|
|
|
|
logger.Debugf("查询到索引 %s, ID=%d", catindexName, id)
|
|
|
|
|
|
return id, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ResolveFirewalls - 解析指定主机的防火墙规则
|
|
|
|
|
|
// 返回解析后的防火墙规则(fwresolved表数据),临时表使用后自动清理
|
|
|
|
|
|
// 调用方式: rows, err := db.ResolveFirewalls(conn, "compute-0-1", "default")
|
|
|
|
|
|
func ResolveFirewalls(conn *sql.DB, hostname, chainname string) (*sql.Rows, error) {
|
|
|
|
|
|
// 步骤1: 创建临时表 fresolved1
|
|
|
|
|
|
_, err := conn.Exec(`
|
|
|
|
|
|
DROP TABLE IF EXISTS fresolved1;
|
|
|
|
|
|
CREATE TEMPORARY TABLE fresolved1 AS
|
|
|
|
|
|
SELECT
|
|
|
|
|
|
? AS hostname,
|
|
|
|
|
|
? AS Resolver,
|
|
|
|
|
|
f.*,
|
|
|
|
|
|
r.precedence
|
|
|
|
|
|
FROM
|
|
|
|
|
|
resolvechain r
|
|
|
|
|
|
inner join hostselections hs on r.category = hs.category and r.name = ?
|
|
|
|
|
|
inner join firewalls f on hs.category = f.category and hs.selection = f.catindex
|
|
|
|
|
|
where hs.host = ?;
|
|
|
|
|
|
`, hostname, chainname, chainname, hostname)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("Create temporary table fresolved1 failed: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 步骤2:创建临时表 fresolved2
|
|
|
|
|
|
_, err = conn.Exec(`
|
|
|
|
|
|
DROP TABLE IF EXISTS fresolved2;
|
|
|
|
|
|
CREATE TEMPORARY TABLE fresolved2 AS
|
|
|
|
|
|
SELECT
|
|
|
|
|
|
*
|
|
|
|
|
|
FROM
|
|
|
|
|
|
fresolved1;
|
|
|
|
|
|
`)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("Create temporary table fresolved2 failed: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 步骤3:创建最终结果表 fwresolved
|
|
|
|
|
|
_, err = conn.Exec(`
|
|
|
|
|
|
DROP TABLE IF EXISTS fwresolved;
|
|
|
|
|
|
CREATE TEMPORARY TABLE fwresolved AS
|
|
|
|
|
|
SELECT
|
|
|
|
|
|
r1.*,
|
|
|
|
|
|
cat.name AS categoryName
|
|
|
|
|
|
FROM
|
|
|
|
|
|
fresolved1 r1
|
|
|
|
|
|
inner join (
|
|
|
|
|
|
select Rulename, MAX(precedence) as precedence
|
|
|
|
|
|
from fresolved2
|
|
|
|
|
|
group by Rulename
|
|
|
|
|
|
) AS r2 on r1.Rulename = r2.Rulename and r1.precedence = r2.precedence
|
|
|
|
|
|
inner join categories cat on r1.category = cat.id;
|
|
|
|
|
|
`)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("Create temporary table fwresolved failed: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 步骤4:查询结果并返回
|
|
|
|
|
|
rows, err := conn.Query("SELECT * FROM fwresolved")
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("Query fwresolved failed: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
return rows, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-23 18:54:54 +08:00
|
|
|
|
// =========================================================
|
|
|
|
|
|
// GetDB - 获取数据库连接(单例模式)
|
|
|
|
|
|
// =========================================================
|
|
|
|
|
|
func GetDB() (*sql.DB, error) {
|
2026-03-06 00:29:34 +08:00
|
|
|
|
logger.Debug("获取数据库连接...")
|
|
|
|
|
|
|
2026-02-20 18:44:43 +08:00
|
|
|
|
dbOnce.Do(func() {
|
2026-02-23 18:54:54 +08:00
|
|
|
|
if dbInstance != nil {
|
|
|
|
|
|
return
|
2026-02-20 18:44:43 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-23 18:54:54 +08:00
|
|
|
|
// 确保配置已加载
|
|
|
|
|
|
cfg, err := config.LoadConfig()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
dbErr = fmt.Errorf("加载配置失败: %w", err)
|
2026-02-20 18:44:43 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 构建DSN
|
2026-02-23 18:54:54 +08:00
|
|
|
|
logger.Debugf("DSN: %s", cfg.Database.DSN)
|
2026-02-20 18:44:43 +08:00
|
|
|
|
|
|
|
|
|
|
// 打开SQLite 连接
|
2026-02-23 18:54:54 +08:00
|
|
|
|
sqlDB, err := sql.Open("sqlite3", cfg.Database.DSN)
|
2026-02-20 18:44:43 +08:00
|
|
|
|
if err != nil {
|
2026-02-23 18:54:54 +08:00
|
|
|
|
dbErr = fmt.Errorf("数据库打开失败: %w", err)
|
2026-02-20 18:44:43 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 设置连接池参数
|
2026-02-23 18:54:54 +08:00
|
|
|
|
sqlDB.SetMaxOpenConns(10) // 最大打开连接数
|
|
|
|
|
|
sqlDB.SetMaxIdleConns(5) // 保持空闲连接
|
2026-02-20 18:44:43 +08:00
|
|
|
|
sqlDB.SetConnMaxLifetime(0) // 禁用连接生命周期超时
|
|
|
|
|
|
sqlDB.SetConnMaxIdleTime(0) // 禁用空闲连接超时
|
|
|
|
|
|
|
|
|
|
|
|
// 测试数据库连接
|
|
|
|
|
|
if err := sqlDB.Ping(); err != nil {
|
|
|
|
|
|
sqlDB.Close()
|
2026-02-23 18:54:54 +08:00
|
|
|
|
dbErr = fmt.Errorf("数据库连接失败: %w", err)
|
2026-02-20 18:44:43 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-23 18:54:54 +08:00
|
|
|
|
logger.Debug("数据库连接成功")
|
|
|
|
|
|
dbInstance = sqlDB
|
2026-02-20 18:44:43 +08:00
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
if dbErr != nil {
|
|
|
|
|
|
return nil, dbErr
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return dbInstance, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-02-23 18:54:54 +08:00
|
|
|
|
func InitTables(db *sql.DB, force bool) error {
|
2026-02-20 18:44:43 +08:00
|
|
|
|
|
|
|
|
|
|
// ✅ 调用 schema.go 中的函数
|
2026-03-06 00:29:34 +08:00
|
|
|
|
//for _, ddl := range CreateTableStatements() {
|
|
|
|
|
|
for _, ddl := range BaseTables() {
|
2026-02-23 18:54:54 +08:00
|
|
|
|
logger.Debugf("执行: %s", ddl)
|
|
|
|
|
|
if _, err := db.Exec(ddl); err != nil {
|
2026-02-20 18:44:43 +08:00
|
|
|
|
return fmt.Errorf("数据表创建失败: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-02-23 18:54:54 +08:00
|
|
|
|
logger.Info("数据库表创建成功")
|
2026-02-20 18:44:43 +08:00
|
|
|
|
/*
|
|
|
|
|
|
使用sqlite3命令 测试数据库是否存在表
|
|
|
|
|
|
✅ 查询所有表
|
|
|
|
|
|
sqlite3 /var/lib/sunhpc/sunhpc.db
|
|
|
|
|
|
.tables # 查看所有表
|
|
|
|
|
|
select * from sqlite_master where type='table'; # 查看表定义
|
|
|
|
|
|
PRAGMA integrity_check; # 检查数据库完整性
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
2026-03-06 00:29:34 +08:00
|
|
|
|
// 添加基础数据
|
|
|
|
|
|
if err := InitBaseData(db); err != nil {
|
|
|
|
|
|
return fmt.Errorf("初始化基础数据失败: %w", err)
|
2026-02-20 18:44:43 +08:00
|
|
|
|
}
|
2026-03-06 00:29:34 +08:00
|
|
|
|
logger.Info("基础数据初始化成功")
|
2026-02-20 18:44:43 +08:00
|
|
|
|
return nil
|
|
|
|
|
|
}
|
2026-02-23 18:54:54 +08:00
|
|
|
|
|
|
|
|
|
|
func CloseDB() error {
|
|
|
|
|
|
dbMutex.Lock()
|
|
|
|
|
|
defer dbMutex.Unlock()
|
|
|
|
|
|
|
|
|
|
|
|
if dbInstance == nil {
|
|
|
|
|
|
if err := dbInstance.Close(); err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
dbInstance = nil
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 使用事务回滚测试
|
|
|
|
|
|
func RunTestWithRollback(db *sql.DB, testFunc func(*sql.Tx) error) error {
|
|
|
|
|
|
tx, err := db.Begin()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 执行测试
|
|
|
|
|
|
if err := testFunc(tx); err != nil {
|
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 回滚事务,所有更改(包括 ID 递增)都会撤销
|
|
|
|
|
|
return tx.Rollback()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 使用示例
|
|
|
|
|
|
func TestNodeInsert(db *sql.DB) error {
|
|
|
|
|
|
logger.Debug("测试数据插入...")
|
|
|
|
|
|
return RunTestWithRollback(db, func(tx *sql.Tx) error {
|
|
|
|
|
|
// 插入测试数据
|
|
|
|
|
|
logger.Debug("执行插入测试数据...")
|
|
|
|
|
|
|
|
|
|
|
|
_, err := tx.Exec(`
|
|
|
|
|
|
INSERT INTO nodes (name, cpus, rack, rank)
|
|
|
|
|
|
VALUES (?, ?, ?, ?)
|
|
|
|
|
|
`, "test-node", 64, 1, 1)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 验证插入
|
|
|
|
|
|
var count int
|
|
|
|
|
|
logger.Debug("执行查询测试数据...")
|
|
|
|
|
|
err = tx.QueryRow(`
|
|
|
|
|
|
SELECT COUNT(*) FROM nodes WHERE name = ?
|
|
|
|
|
|
`, "test-node").Scan(&count)
|
|
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
logger.Infof("测试数据插入成功,共 %d 条", count)
|
|
|
|
|
|
|
|
|
|
|
|
// 不需要手动删除,回滚会自动撤销
|
|
|
|
|
|
return nil
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
2026-03-06 00:29:34 +08:00
|
|
|
|
|
|
|
|
|
|
// =========================================================
|
|
|
|
|
|
// 带事务执行 SQL 语句,自动提交/回滚
|
|
|
|
|
|
// =========================================================
|
|
|
|
|
|
|
|
|
|
|
|
// 执行单条SQL语句,带事务管理
|
|
|
|
|
|
func ExecSingleWithTransaction(sqlStr string) error {
|
|
|
|
|
|
// 复用批量函数,将单条SQL语句包装为数组执行
|
|
|
|
|
|
return ExecWithTransaction([]string{sqlStr})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 批量执行 DDL 语句,带事务管理
|
|
|
|
|
|
func ExecWithTransaction(ddl []string) error {
|
|
|
|
|
|
conn, err := GetDB()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
logger.Errorf("获取数据库连接失败: %v", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 开始事务
|
|
|
|
|
|
tx, err := conn.Begin()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
logger.Errorf("开始事务失败: %v", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 延迟处理:如果函数异常,回滚事务
|
|
|
|
|
|
defer func() {
|
|
|
|
|
|
if r := recover(); r != nil {
|
|
|
|
|
|
// 捕获 panic 并回滚事务
|
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
|
logger.Errorf("事务执行中发生 panic: %v", r)
|
|
|
|
|
|
}
|
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
|
|
// 遍历执行 DDL 语句
|
|
|
|
|
|
for idx, sql := range ddl {
|
|
|
|
|
|
logger.Debugf("执行 DDL 语句 %d: %s", idx+1, sql)
|
|
|
|
|
|
_, err = tx.Exec(sql)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
// 执行失败时,回滚事务
|
|
|
|
|
|
rollbackErr := tx.Rollback()
|
|
|
|
|
|
if rollbackErr != nil {
|
|
|
|
|
|
logger.Errorf("执行失败: 回滚失败: %v (原错误: %v, SQL: %s)", rollbackErr, err, sql)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
logger.Errorf("执行失败: 回滚事务: %v, SQL: %s", err, sql)
|
|
|
|
|
|
}
|
|
|
|
|
|
logger.Errorf("执行 %d 条, 失败: %w (SQL: %s)", idx+1, err, sql)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 所有SQL语句执行成功,提交事务
|
|
|
|
|
|
logger.Info("所有SQL语句执行成功,提交事务")
|
|
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
|
|
|
|
logger.Errorf("提交事务失败: %w", err)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
logger.Debugf("成功执行 %d 条 SQL 语句, 事务已提交.", len(ddl))
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|