add xml parse

This commit is contained in:
2026-03-19 23:49:16 +08:00
parent 8b279df333
commit 230d3d4204
38 changed files with 2736 additions and 9 deletions

View File

@@ -0,0 +1,133 @@
//! sunhpc config check 命令实现
//! 功能:加载配置文件 → 解析XML → 生成Linux Kickstart脚本
use anyhow::{Context, Result};
use clap::Args;
use std::env;
use std::path::PathBuf;
// 导入内部模块
use crate::internal::config::ksgen::KickstartGenerator;
use crate::internal::config::loader::ConfigLoader;
/// 配置检查命令参数
#[derive(Args, Debug)]
#[command(about = "检查配置文件并生成Kickstart脚本", long_about = None)]
pub struct CheckArgs {
/// 配置文件路径默认config/main.toml
#[arg(
short = 'c',
long = "config",
default_value = "config/main.toml",
help = "指定主配置文件路径TOML/YAML格式"
)]
pub config: PathBuf,
/// 输出格式text/json默认text
#[arg(
short = 'f',
long = "format",
default_value = "text",
help = "Kickstart输出格式text纯文本/json调试用"
)]
pub format: String,
/// 详细输出模式
#[arg(
long = "verbose",
help = "显示更多调试信息(配置加载过程、节点变量等)"
)]
pub verbose: bool,
/// 目标节点名称默认compute
#[arg(
short = 'n',
long = "node",
default_value = "compute",
help = "指定要生成Kickstart的节点名称对应配置中的nodes字段"
)]
pub node: String,
/// Linux发行版过滤可选覆盖配置中的值
#[arg(
long = "distro",
help = "指定Linux发行版如centos/ubuntu/rhel优先级高于配置文件"
)]
pub distro: Option<String>,
/// Linux版本过滤可选覆盖配置中的值
#[arg(
long = "version",
help = "指定Linux版本如8/9/22.04),优先级高于配置文件"
)]
pub version: Option<String>,
}
/// 执行check命令的核心逻辑
pub fn run(args: CheckArgs) -> Result<()> {
let cwd = get_current_working_dir()?;
// 1. 打印基础信息
println!("========================================");
println!("🔍 SUNHPC 配置检查工具Linux专属");
println!("========================================");
println!("配置文件路径: {}", args.config.display());
println!("当前路径: {}", cwd.display());
println!("目标节点: {}", args.node);
println!("输出格式: {}", args.format);
println!("详细模式: {}", if args.verbose { "开启" } else { "关闭" });
if let Some(distro) = &args.distro {
println!("指定发行版: {}", distro);
}
if let Some(version) = &args.version {
println!("指定版本: {}", version);
}
println!("----------------------------------------");
// 2. 加载配置文件TOML/YAML
let mut config_loader = ConfigLoader::new(&args.config)
.with_context(|| format!("加载配置文件失败: {}", args.config.display()))?;
// 覆盖发行版/版本配置(如果命令行指定)
if let Some(distro) = &args.distro {
config_loader.global.linux_distro = distro.clone();
}
if let Some(version) = &args.version {
config_loader.global.linux_version = version.clone();
}
// 详细模式:打印加载的配置信息
if args.verbose {
println!("📄 配置加载成功!");
println!(" - 全局变量数: {}", config_loader.global.variables.len());
println!(" - 节点数: {}", config_loader.global.nodes.len());
println!(" - XML路径数: {}", config_loader.global.xml_paths.len());
println!(" - Linux发行版: {}", config_loader.global.linux_distro);
println!(" - Linux版本: {}", config_loader.global.linux_version);
println!("----------------------------------------");
}
// 3. 初始化Kickstart生成器
let mut generator = KickstartGenerator::new(config_loader, &args.node)
.with_context(|| format!("初始化Kickstart生成器失败节点: {}", args.node))?;
// 4. 解析XML配置并构建Kickstart上下文
generator.parse()
.with_context(|| "解析XML配置文件失败")?;
// 5. 生成Kickstart脚本内容
let kickstart_content = generator.generate(&args.format)
.with_context(|| format!("生成Kickstart失败格式: {}", args.format))?;
// 6. 输出结果
println!("✅ Kickstart脚本生成成功");
println!("----------------------------------------");
println!("{}", kickstart_content);
Ok(())
}
/// 获取当前路径
fn get_current_working_dir() -> Result<PathBuf, std::io::Error> {
env::current_dir()
}

View File

@@ -0,0 +1,25 @@
use clap::Subcommand;
use anyhow::Result;
// 引入具体的命令逻辑文件
mod check;
mod sync;
/// Server 子命令集的具体命令
#[derive(Subcommand)]
pub enum ConfigCommands {
/// 检查配置文件语法正确性
Check(check::CheckArgs),
/// 分发配置文件到所有节点
Sync(sync::StopArgs),
}
/// 执行 server 命令集的入口函数
pub fn execute(cmd: ConfigCommands) -> Result<(), anyhow::Error> {
match cmd {
ConfigCommands::Check(args) => check::run(args),
ConfigCommands::Sync(args) => sync::run(args),
}
}

View File

@@ -0,0 +1,13 @@
use anyhow::Result;
#[derive(clap::Args)]
pub struct StopArgs {
/// 强制停止
#[arg(short, long)]
pub force: bool,
}
pub fn run(args: StopArgs) -> Result<()> {
println!("🛑 正在停止服务器... (force: {})", args.force);
Ok(())
}

49
src/commands/db/init.rs Normal file
View File

@@ -0,0 +1,49 @@
use clap::Args;
use anyhow::Result;
// 引入核心逻辑
use crate::internal::database::DbConfig;
use crate::internal::database::init_db;
#[derive(Args)]
pub struct InitArgs {
/// 数据库连接 URL
#[arg(short, long, env = "DATABASE_URL", default_value = "postgres://localhost/mydb")]
pub url: String,
/// 最大连接池大小
#[arg(long, default_value_t = 10)]
pub max_connections: u32,
/// 是否跳过确认提示
#[arg(long, short)]
pub yes: bool,
}
/// CLI 层的执行函数
pub async fn run(args: InitArgs) -> Result<()> {
if !args.yes {
print!("⚠️ 即将初始化数据库 '{}',数据可能会被清空。继续吗?(y/N): ", args.url);
use std::io::{self, Write};
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
if !input.trim().eq_ignore_ascii_case("y") {
println!("操作已取消。");
return Ok(());
}
}
// 1. 将 CLI 参数转换为 核心逻辑需要的结构体
let config = DbConfig {
url: args.url,
max_connections: args.max_connections,
};
// 2. 调用 internal 层的纯逻辑
// 注意:这里调用的是 async 函数
init_db(&config).await?;
Ok(())
}

View File

@@ -2,15 +2,20 @@ use clap::Subcommand;
use anyhow::Result;
mod migrate;
mod init;
#[derive(Subcommand)]
pub enum DbCommands {
/// 运行数据库迁移
Migrate(migrate::MigrateArgs),
/// 数据库初始化
Init(init::InitArgs),
}
pub fn execute(cmd: DbCommands) -> Result<()> {
pub async fn execute(cmd: DbCommands) -> Result<()> {
match cmd {
DbCommands::Migrate(args) => migrate::run(args),
DbCommands::Init(args) => init::run(args).await,
}
}

27
src/commands/db/status.rs Normal file
View File

@@ -0,0 +1,27 @@
// src/commands/db/status.rs
use clap::Args;
use anyhow::Result;
use crate::internal::database::check_connection;
#[derive(Args)]
pub struct StatusArgs {
/// 数据库 URL (默认从环境变量读取)
#[arg(short, long, env = "DATABASE_URL", default_value = "postgres://localhost/mydb")]
pub url: String,
}
pub async fn run(args: StatusArgs) -> Result<()> {
println!("🏓 正在检查数据库状态...");
match check_connection(&args.url).await? {
true => {
println!("✅ 数据库连接正常: {}", args.url);
Ok(())
}
false => {
eprintln!("❌ 无法连接到数据库: {}", args.url);
// 返回错误码
std::process::exit(1);
}
}
}

View File

@@ -1,6 +1,8 @@
use clap::Subcommand;
// 声明子模块,对应 src/commands/ 下的目录
pub mod config;
pub mod report;
pub mod server;
pub mod db;
// 未来扩展pub mod new_feature;
@@ -13,7 +15,15 @@ pub enum CliCommands {
#[command(subcommand)]
Server(server::ServerCommands),
/// 数据库管理相关命令 (db migrate, db seed)
/// 服务器配置文件解析相关命令
#[command(subcommand)]
Config(config::ConfigCommands),
/// 打印集群配置信息
#[command(subcommand)]
Report(report::ReportCommands),
/// 数据库管理相关命令 (db init, db migrate)
#[command(subcommand)]
Db(db::DbCommands),

View File

@@ -0,0 +1,21 @@
use clap::Subcommand;
use anyhow::Result;
// 引入具体的命令逻辑文件
mod nextip;
/// Server 子命令集的具体命令
#[derive(Subcommand)]
pub enum ReportCommands {
/// 生成下一个可用 IP 地址
#[command(name = "nextip")] // 强制子命令名为 nextip
NextIP(nextip::NextIPArgs), // 默认字母有大小写自动转换成驼峰命令法next-ip.
}
/// 执行 report nextip 命令集的入口函数
pub fn execute(cmd: ReportCommands) -> Result<(), anyhow::Error> {
match cmd {
ReportCommands::NextIP(args) => nextip::run(args),
}
}

View File

@@ -0,0 +1,71 @@
#[warn(unused_imports)]
use crate::internal::network::IPGenerator;
use anyhow::Result;
#[derive(clap::Args, Debug)]
pub struct NextIPArgs {
/// 网络地址(如 10.1.1.0
#[arg(short = 'n', long = "network", default_value = "10.1.1.0")]
network: String,
/// 子网掩码(如 255.255.255.128,不指定则自动推断)
#[arg(short = 'm', long = "netmask", default_value = "255.255.255.0")]
netmask: Option<String>,
/// IP 偏移步长(仅允许正数,从第一个可以用 IP 向后偏移)
#[arg(short = 's', long = "step", default_value = "1")]
step: u32,
}
pub fn run(args: NextIPArgs) -> Result<()> {
// 调用 IP 生成逻辑
generate_next_ip(args)?;
Ok(())
}
/// 核心逻辑:生成指定偏移的 IP 地址
fn generate_next_ip(opts: NextIPArgs) -> Result<()> {
// 创建 IP 生成器实例(用 anyhow 包装错误)
let mut generator = IPGenerator::new(&opts.network, opts.netmask.as_deref())
.map_err(|e| anyhow::anyhow!("Failed to create IP generator: {}", e))?;
// 获取网段最大可用步长,提前校验
let max_step = generator.get_max_available_step();
if opts.step > max_step {
return Err(anyhow::anyhow!(
"Step {} exceeds maximum available step {} for this network (mask: {})",
opts.step,
max_step,
generator.get_netmask_str()
));
}
// 获取初始 IP、新 IP、网络地址等信息
let initial_ip = generator.curr()
.map_err(|e| anyhow::anyhow!("Failed to get initial IP: {}", e))?;
let new_ip = generator.next(opts.step as i32)
.map_err(|e| anyhow::anyhow!("Failed to move IP address: {}", e))?;
let network_addr = generator.get_network();
let netmask_str = generator.get_netmask_str();
// ========== 核心:格式化对齐输出 ==========
// 定义左侧字段的宽度(根据最长字段调整,这里选 25 足够覆盖所有字段名)
const FIELD_WIDTH: usize = 30;
println!("{:>FIELD_WIDTH$}: {}", "Initial IP (first available)", initial_ip);
println!("{:>FIELD_WIDTH$}: {}", "IP after step", format!("{} (step: {})", new_ip, opts.step));
println!("{:>FIELD_WIDTH$}: {}", "Network address", network_addr);
println!("{:>FIELD_WIDTH$}: {}", "Subnet mask", netmask_str);
println!("{:>FIELD_WIDTH$}: {}", "Max available step", max_step);
println!("\nIP generation completed successfully");
// 扩展点:写入配置/同步到集群的业务逻辑
// 示例write_ip_to_config(&new_ip)?;
Ok(())
}

View File

@@ -0,0 +1,595 @@
//! Kickstart生成器模块
//! 功能:
//! 1. 基于配置和XML解析结果生成Linux Kickstart脚本
//! 2. 支持不同Linux发行版/版本的适配
//! 3. 输出纯文本或JSON格式的Kickstart内容
use anyhow::{Context, Result};
use serde::Serialize;
use std::collections::{HashMap, HashSet};
use std::collections::hash_map::RandomState;
use std::path::PathBuf;
use roxmltree::Node;
// 导入内部模块
use super::loader::ConfigLoader;
use super::xml_parser::XmlLoader;
/// Kickstart上下文存储生成脚本所需的所有数据
/// 对应gen.py的Generator_linux的ks字段
#[derive(Debug, Default, Serialize)]
pub struct KickstartContext {
/// 节点遍历顺序file, roll
pub order: Vec<(String, String)>,
/// debug信息列表
pub debug: Vec<String>,
/// main部分内容Kickstart核心配置
pub main: Vec<String>,
/// 启用的RPM包列表
pub rpms_on: Vec<String>,
/// 禁用的RPM包列表
pub rpms_off: Vec<String>,
/// pre脚本[参数, 内容]
pub pre: Vec<Vec<String>>,
/// post脚本[参数, 解释器, 内容]
pub post: Vec<Vec<String>>,
/// boot pre脚本
pub boot_pre: Vec<String>,
/// boot post脚本
pub boot_post: Vec<String>,
/// RCS文件配置路径: (所有者, 权限)
pub rcs_files: HashMap<PathBuf, (String, String)>,
}
/// 节点数据结构(用于避免借用冲突)
#[derive(Debug)]
struct NodeData {
/// 节点名称
pub node_name: String,
/// 文件属性
pub file: String,
/// roll属性
pub roll: String,
/// 节点文本内容
pub text: String,
/// 是否禁用package节点
pub disable: bool,
/// 是否元包package节点
pub meta_type: bool,
/// 解释器pre/post节点
pub interpreter: String,
/// 参数pre/post节点
pub arg: String,
/// 顺序boot节点
pub order: String,
/// 文件名file节点
pub file_name: String,
/// 所有者file节点
pub owner: String,
/// 权限file节点
pub perms: String,
/// main节点的子节点数据
pub children_data: Vec<MainChildData>,
}
/// main节点的子节点数据
#[derive(Debug)]
struct MainChildData {
/// 子节点名称
pub name: String,
/// 子节点文本
pub text: String,
/// partition属性clearpart节点
pub partition: String,
}
/// Linux Kickstart生成器
pub struct KickstartGenerator {
/// 配置加载器
config_loader: ConfigLoader,
/// XML加载器
xml_loader: XmlLoader,
/// Kickstart上下文
context: KickstartContext,
/// 目标节点名称
node_name: String,
}
impl KickstartGenerator {
/// 创建Kickstart生成器
/// 参数:
/// - config_loader: 配置加载器实例
/// - node_name: 目标节点名称
pub fn new(config_loader: ConfigLoader, node_name: &str) -> Result<Self> {
// 1. 获取节点变量用于XML过滤
let node_vars = config_loader.get_node_vars(node_name)
.with_context(|| format!("获取节点{}的变量失败", node_name))?;
// 2. 创建XML加载器使用节点变量作为过滤条件
let xml_loader = XmlLoader::new(node_vars, &config_loader.global.xml_paths)
.with_context(|| "初始化XML加载器失败")?;
// 3. 初始化上下文
let context = KickstartContext::default();
Ok(Self {
config_loader,
xml_loader,
context,
node_name: node_name.to_string(),
})
}
/// 解析XML并构建Kickstart上下文
pub fn parse(&mut self) -> Result<()> {
println!("📝 开始解析XML配置生成Kickstart...");
// 提取节点数据(避免借用冲突)
let node_data: Vec<NodeData> = self.xml_loader.iter_linux_nodes()
.map(|node| {
let node_name = node.tag_name().name().to_string();
let file = node.attribute("file").unwrap_or_default().to_string();
let roll = node.attribute("roll").unwrap_or_default().to_string();
let text = self.xml_loader.get_node_text(&node);
// 提取常用属性
let disable = node.attribute("disable").is_some();
let meta_type = node.attribute("type").map(|s| s == "meta").unwrap_or(false);
let interpreter = node.attribute("interpreter").unwrap_or("/bin/bash").to_string();
let arg = node.attribute("arg").unwrap_or_default().to_string();
let order = node.attribute("order").unwrap_or("pre").to_string();
let file_name = node.attribute("name").unwrap_or_default().to_string();
let owner = node.attribute("owner").unwrap_or("root").to_string();
let perms = node.attribute("perms").unwrap_or("0644").to_string();
// 对于main节点提取子节点数据
let children_data = if node_name == "main" {
Self::extract_main_children(&node, &self.xml_loader)
} else {
Vec::new()
};
NodeData {
node_name,
file,
roll,
text,
disable,
meta_type,
interpreter,
arg,
order,
file_name,
owner,
perms,
children_data,
}
})
.collect();
// 遍历节点数据此时不再持有xml_loader的借用
for data in node_data {
// 记录遍历顺序
self.context.order.push((data.file.clone(), data.roll.clone()));
// 根据节点类型处理
match data.node_name.as_str() {
"main" => self.handle_main_node_data(&data)?,
"package" => self.handle_package_node_data(&data)?,
"pre" => self.handle_pre_node_data(&data)?,
"post" => self.handle_post_node_data(&data)?,
"configure" => self.handle_configure_node_data(&data)?,
"boot" => self.handle_boot_node_data(&data)?,
"debug" => self.context.debug.push(data.text),
"file" => self.handle_file_node_data(&data)?,
_ => {
// 其他节点记录到debug
self.context.debug.push(format!(
"未处理的节点类型: {} (文本: {})",
data.node_name, data.text
));
}
}
}
println!("✅ XML解析完成共处理{}个节点", self.context.order.len());
Ok(())
}
/// 提取main节点的子节点数据
fn extract_main_children<'a, 'b>(node: &Node<'a, 'b>, xml_loader: &XmlLoader) -> Vec<MainChildData> {
let mut children_data = Vec::new();
for child in node.children() {
if child.node_type() != roxmltree::NodeType::Element {
continue;
}
let child_name = child.tag_name().name().to_string();
let child_text = xml_loader.get_node_text(&child);
let partition = child.attribute("partition").unwrap_or_default().to_string();
children_data.push(MainChildData {
name: child_name,
text: child_text,
partition,
});
}
children_data
}
/// 生成Kickstart脚本内容
/// 参数format - 输出格式text/json
pub fn generate(&self, format: &str) -> Result<String> {
match format.to_lowercase().as_str() {
"text" => self.generate_text(),
"json" => self.generate_json(),
_ => Err(anyhow::anyhow!("不支持的输出格式: {}仅支持text/json", format)),
}
}
/// 生成纯文本格式的Kickstart脚本
fn generate_text(&self) -> Result<String> {
let mut content = String::new();
// 1. 生成头部注释
content.push_str(&format!(
"# SUNHPC Kickstart脚本自动生成
# 节点: {}
# Linux发行版: {}
# Linux版本: {}
# 生成时间: autoGen
\n",
self.node_name,
self.config_loader.global.linux_distro,
self.config_loader.global.linux_version
));
// 2. 生成main部分Kickstart核心配置
content.push_str("# ========== 核心配置 ==========\n");
for line in &self.context.main {
content.push_str(line);
content.push('\n');
}
content.push('\n');
// 3. 生成包管理部分
content.push_str("# ========== 包配置 ==========\n");
if !self.context.rpms_on.is_empty() || !self.context.rpms_off.is_empty() {
content.push_str("%packages --ignoremissing\n");
// 启用的包
for rpm in &self.context.rpms_on {
content.push_str(rpm);
content.push('\n');
}
// 禁用的包(前缀-
for rpm in &self.context.rpms_off {
content.push_str(&format!("-{}\n", rpm));
}
content.push_str("%end\n\n");
}
// 4. 生成pre脚本部分
content.push_str("# ========== Pre安装脚本 ==========\n");
for pre in &self.context.pre {
content.push_str(&format!("%pre --log=/tmp/ks-pre.log {}\n", pre[0]));
content.push_str(&pre[1]);
content.push_str("\n%end\n\n");
}
// 5. 生成post脚本部分
content.push_str("# ========== Post安装脚本 ==========\n");
let log_path = "/mnt/sysimage/var/log/sunhpc-install.log";
for post in &self.context.post {
content.push_str(&format!("%post --log={} {}\n", log_path, post[0]));
content.push_str(&post[1]); // 解释器(#!/bin/bash
content.push_str("\n");
content.push_str(&post[2]); // 脚本内容
content.push_str("\n%end\n\n");
}
// 6. 生成boot脚本部分
content.push_str("# ========== Boot配置脚本 ==========\n");
content.push_str(&format!("%post --log={}\n", log_path));
content.push_str("cat >> /etc/sysconfig/sunhpc-pre << EOF\n");
for line in &self.context.boot_pre {
content.push_str(line);
content.push('\n');
}
content.push_str("EOF\n\n");
content.push_str("cat >> /etc/sysconfig/sunhpc-post << EOF\n");
for line in &self.context.boot_post {
content.push_str(line);
content.push('\n');
}
content.push_str("EOF\n%end\n");
// 7. 生成debug信息如果有
if !self.context.debug.is_empty() {
content.push_str("\n# ========== Debug信息 ==========\n");
content.push_str("# 以下为调试信息,实际部署时可删除\n");
for debug in &self.context.debug {
content.push_str(&format!("# {}\n", debug));
}
}
Ok(content)
}
/// 生成JSON格式的Kickstart上下文调试用
fn generate_json(&self) -> Result<String> {
let json = serde_json::to_string_pretty(&self.context)
.with_context(|| "将Kickstart上下文序列化为JSON失败")?;
Ok(json)
}
// ------------------------------
// 以下为节点处理方法按XML节点类型
// ------------------------------
/// 处理<main>节点Kickstart核心配置
fn handle_main_node_data(&mut self, data: &NodeData) -> Result<()> {
for child in &data.children_data {
// 根据不同子节点类型生成配置
match child.name.as_str() {
"clearpart" => {
// clearpart配置示例clearpart --all --initlabel
let mut clearpart_line = "clearpart".to_string();
if !child.partition.is_empty() {
clearpart_line.push_str(&format!(" --partition={}", child.partition));
}
if !child.text.is_empty() {
clearpart_line.push_str(&format!(" {}", child.text));
}
self.context.main.push(clearpart_line);
}
"bootloader" | "lilo" => {
// 引导加载器配置
self.context.main.push(format!("bootloader {}", child.text));
}
"lang" => {
// 语言配置
self.context.main.push(format!("lang {}", child.text));
}
"keyboard" => {
// 键盘配置
self.context.main.push(format!("keyboard {}", child.text));
}
"timezone" => {
// 时区配置
self.context.main.push(format!("timezone {}", child.text));
}
"rootpw" => {
// root密码配置
self.context.main.push(format!("rootpw {}", child.text));
}
"network" => {
// 网络配置
self.context.main.push(format!("network {}", child.text));
}
"part" | "volgroup" | "logvol" | "raid" => {
// 分区配置
self.context.main.push(format!("{} {}", child.name, child.text));
}
"reboot" | "poweroff" => {
// 安装后动作
self.context.main.push(child.name.clone());
}
_ => {
// 其他主节点:直接拼接
self.context.main.push(format!("{} {}", child.name, child.text));
}
}
}
Ok(())
}
/// 处理<package>节点RPM包配置
fn handle_package_node_data(&mut self, data: &NodeData) -> Result<()> {
let package_name = data.text.trim();
if package_name.is_empty() {
return Ok(());
}
// 构建包名(元包前缀@
let rpm_name = if data.meta_type {
format!("@{}", package_name)
} else {
package_name.to_string()
};
// 去重集合(避免重复添加)
let mut rpms_on_set: HashSet<String, RandomState> = self.context.rpms_on.clone()
.into_iter()
.collect();
let mut rpms_off_set: HashSet<String, RandomState> = self.context.rpms_off.clone()
.into_iter()
.collect();
if data.disable {
// 禁用的包:不在启用列表才添加到禁用列表
if !rpms_on_set.contains(&rpm_name) {
rpms_off_set.insert(rpm_name.clone());
}
rpms_on_set.remove(&rpm_name);
} else {
// 启用的包:添加到启用列表,移除禁用列表
rpms_on_set.insert(rpm_name.clone());
rpms_off_set.remove(&rpm_name);
}
// 将临时HashSet转回Vec覆盖原数据可变借用
self.context.rpms_on = rpms_on_set.into_iter().collect();
self.context.rpms_off = rpms_off_set.into_iter().collect();
Ok(())
}
/// 处理<pre>节点(安装前脚本)
fn handle_pre_node_data(&mut self, data: &NodeData) -> Result<()> {
// 构建pre脚本参数
let pre_args = format!(
"--interpreter {} {}",
data.interpreter,
data.arg
).trim().to_string();
// 添加到pre列表
self.context.pre.push(vec![pre_args, data.text.clone()]);
Ok(())
}
/// 处理<post>节点(安装后脚本)
fn handle_post_node_data(&mut self, data: &NodeData) -> Result<()> {
// 构建post脚本项参数、解释器、内容
let post_item = vec![
data.arg.clone(),
format!("#!{}", data.interpreter),
data.text.clone(),
];
// 添加到post列表
self.context.post.push(post_item);
Ok(())
}
/// 处理<configure>节点配置脚本转为post脚本
fn handle_configure_node_data(&mut self, data: &NodeData) -> Result<()> {
// configure节点等价于post节点
self.handle_post_node_data(data)
}
/// 处理<configure>节点配置脚本转为post脚本- 保留以兼容
#[allow(dead_code)]
fn handle_configure_node<'a, 'b>(&mut self, node: &Node<'a, 'b>) -> Result<()> {
// configure节点等价于post节点保留这个方法以兼容
let interpreter = node.attribute("interpreter").unwrap_or("/bin/bash");
let arg = node.attribute("arg").unwrap_or_default();
let content = self.xml_loader.get_node_text(node);
// 构建post脚本项参数、解释器、内容
let post_item = vec![
arg.to_string(),
format!("#!{}", interpreter),
content,
];
// 添加到post列表
self.context.post.push(post_item);
Ok(())
}
/// 处理<boot>节点(启动脚本)
fn handle_boot_node_data(&mut self, data: &NodeData) -> Result<()> {
match data.order.as_str() {
"pre" => self.context.boot_pre.push(data.text.clone()),
"post" => self.context.boot_post.push(data.text.clone()),
_ => {
// 未知顺序同时添加到pre和post
self.context.boot_pre.push(data.text.clone());
self.context.boot_post.push(data.text.clone());
}
}
Ok(())
}
/// 处理<debug>节点(调试信息)
#[allow(dead_code)]
fn handle_debug_node<'a, 'b>(&mut self, node: &Node<'a, 'b>) -> Result<()> {
// 保留这个方法以兼容
let content = self.xml_loader.get_node_text(node);
self.context.debug.push(content);
Ok(())
}
/// 处理<file>节点(文件配置)
fn handle_file_node_data(&mut self, data: &NodeData) -> Result<()> {
if data.file_name.is_empty() {
return Ok(());
}
// 记录RCS文件配置
let file_path = PathBuf::from(&data.file_name);
self.context.rcs_files.insert(file_path, (data.owner.clone(), data.perms.clone()));
// 生成文件创建脚本添加到post
let post_content = format!(
r#"
# 创建文件: {}
mkdir -p $(dirname {})
cat > {} << EOF
->{}
EOF
chown {} {}
chmod {} {}
"#,
data.file_name, data.file_name, data.file_name, data.text, data.owner, data.file_name, data.perms, data.file_name
);
// 添加到post脚本
self.context.post.push(vec![
"".to_string(),
"#!/bin/bash".to_string(),
post_content,
]);
Ok(())
}
/// 处理<file>节点(文件配置)
#[allow(dead_code)]
fn handle_file_node<'a, 'b>(&mut self, node: &Node<'a, 'b>) -> Result<()> {
// 提取文件属性
let file_name = node.attribute("name").unwrap_or_default();
let owner = node.attribute("owner").unwrap_or("root");
let perms = node.attribute("perms").unwrap_or("0644");
let content = self.xml_loader.get_node_text(node);
if file_name.is_empty() {
return Ok(());
}
// 记录RCS文件配置
let file_path = PathBuf::from(file_name);
self.context.rcs_files.insert(file_path, (owner.to_string(), perms.to_string()));
// 生成文件创建脚本添加到post
let post_content = format!(
r#"
# 创建文件: {}
mkdir -p $(dirname {})
cat > {} << EOF
{}
EOF
chown {} {}
chmod {} {}
"#,
file_name, file_name, file_name, content, owner, file_name, perms, file_name
);
// 添加到post脚本
self.context.post.push(vec![
"".to_string(),
"#!/bin/bash".to_string(),
post_content,
]);
Ok(())
}
}

View File

@@ -0,0 +1,322 @@
//! 配置加载器模块
//! 功能:
//! 1. 解析TOML/YAML格式的主配置+分组配置
//! 2. 处理节点继承关系(子节点合并父节点变量)
//! 3. 管理Linux发行版/版本配置
//! 4. 收集XML文件路径转为绝对路径
use anyhow::{Context, Result};
use serde::{Deserialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use yaml_rust::{Yaml, YamlLoader};
/// 全局配置结构体对应main.toml的根节点
#[derive(Debug, Deserialize, Clone)]
pub struct GlobalConfig {
/// 全局通用变量(所有节点共享)
#[serde(default)]
pub variables: HashMap<String, String>,
/// 节点配置key=节点名value=节点具体配置)
#[serde(default)]
pub nodes: HashMap<String, NodeConfig>,
/// XML配置文件/目录路径列表
#[serde(default)]
pub xml_paths: Vec<PathBuf>,
/// Linux发行版如centos/rhel/ubuntu
#[serde(default = "default_linux_distro")]
pub linux_distro: String,
/// Linux版本如8/9/22.04
#[serde(default = "default_linux_version")]
pub linux_version: String,
}
/// 节点配置结构体(支持继承)
#[derive(Debug, Deserialize, Clone)]
pub struct NodeConfig {
/// 继承的父节点名称(可选)
#[serde(default)]
pub inherit: Option<String>,
/// 节点架构x86_64/aarch64等
#[serde(default = "default_arch")]
pub arch: String,
/// Linux发行版覆盖全局配置
#[serde(default)]
pub linux_distro: String,
/// Linux版本覆盖全局配置
#[serde(default)]
pub linux_version: String,
/// 节点私有变量(覆盖全局变量)
#[serde(default)]
pub variables: HashMap<String, String>,
}
/// 配置加载器核心结构体
pub struct ConfigLoader {
/// 加载完成的全局配置
pub global: GlobalConfig,
/// 原始配置文件路径
#[allow(dead_code)]
config_path: PathBuf,
}
// 默认值函数
fn default_arch() -> String {
"x86_64".to_string()
}
fn default_linux_distro() -> String {
"centos".to_string()
}
fn default_linux_version() -> String {
"8".to_string()
}
impl ConfigLoader {
/// 创建配置加载器并加载配置文件
/// 参数config_path - 主配置文件路径TOML/YAML
pub fn new(config_path: &Path) -> Result<Self> {
// 1. 读取配置文件内容
let config_str = fs::read_to_string(config_path)
.with_context(|| format!("无法读取配置文件: {}", config_path.display()))?;
// 2. 解析配置优先TOML兼容YAML
let mut global: GlobalConfig = match toml::from_str(&config_str) {
Ok(cfg) => cfg,
Err(toml_err) => {
eprintln!("TOML解析失败: {}", toml_err);
// TOML解析失败尝试YAML
let yaml_docs = YamlLoader::load_from_str(&config_str)
.with_context(|| "配置文件不是有效的TOML或YAML格式")?;
if yaml_docs.is_empty() {
return Err(anyhow::anyhow!("YAML配置文件为空"))
.context("解析Yaml后未获取到任何信息");
}
Self::yaml_to_global(&yaml_docs[0])?
}
};
// 3. 合并分组配置groups目录下的compute/storage等
Self::merge_group_configs(&mut global, config_path.parent().unwrap())?;
// 4. 处理节点继承关系(子节点合并父节点变量)
Self::process_node_inheritance(&mut global)?;
// 5. 标准化XML路径转为绝对路径
global.xml_paths = global.xml_paths
.iter()
.map(|p| {
if p.is_absolute() {
p.clone()
} else {
config_path.parent().unwrap().join(p)
}
})
.collect();
Ok(Self {
global,
config_path: config_path.to_path_buf(),
})
}
/// 将YAML对象转换为GlobalConfig兼容逻辑
fn yaml_to_global(yaml: &Yaml) -> Result<GlobalConfig> {
let mut global = GlobalConfig {
variables: HashMap::new(),
nodes: HashMap::new(),
xml_paths: Vec::new(),
linux_distro: default_linux_distro(),
linux_version: default_linux_version(),
};
// 解析全局变量
if let Yaml::Hash(vars_hash) = &yaml["variables"] {
for (k, v) in vars_hash {
if let (Yaml::String(k_str), Yaml::String(v_str)) = (k, v) {
global.variables.insert(k_str.clone(), v_str.clone());
}
}
}
// 解析节点配置
if let Yaml::Hash(nodes_hash) = &yaml["nodes"] {
for (node_name, node_yaml) in nodes_hash {
let node_name_str = node_name
.as_str()
.with_context(|| "节点名称必须是字符串")?
.to_string();
// 解析单个节点配置
let inherit = node_yaml["inherit"].as_str().map(|s| s.to_string());
let arch = node_yaml["arch"].as_str().unwrap_or(&default_arch()).to_string();
let linux_distro = node_yaml["linux_distro"]
.as_str()
.unwrap_or(&default_linux_distro())
.to_string();
let linux_version = node_yaml["linux_version"]
.as_str()
.unwrap_or(&default_linux_version())
.to_string();
// 解析节点私有变量
let mut node_vars = HashMap::new();
if let Yaml::Hash(vars_hash) = &node_yaml["variables"] {
for (k, v) in vars_hash {
if let (Yaml::String(k_str), Yaml::String(v_str)) = (k, v) {
node_vars.insert(k_str.clone(), v_str.clone());
}
}
}
// 添加到节点列表
global.nodes.insert(
node_name_str,
NodeConfig {
inherit,
arch,
linux_distro,
linux_version,
variables: node_vars,
},
);
}
}
// 解析XML路径
if let Yaml::Array(xml_paths_arr) = &yaml["xml_paths"] {
for p in xml_paths_arr {
if let Yaml::String(p_str) = p {
global.xml_paths.push(PathBuf::from(p_str));
}
}
}
// 解析Linux发行版/版本
if let Yaml::String(distro) = &yaml["linux_distro"] {
global.linux_distro = distro.clone();
}
if let Yaml::String(version) = &yaml["linux_version"] {
global.linux_version = version.clone();
}
Ok(global)
}
/// 合并groups目录下的分组配置compute/storage/gpu/common
fn merge_group_configs(global: &mut GlobalConfig, parent_dir: &Path) -> Result<()> {
let groups_dir = parent_dir.join("groups");
if !groups_dir.exists() || !groups_dir.is_dir() {
// 没有groups目录直接返回
return Ok(());
}
// 遍历groups目录下的所有TOML/YAML文件
for entry in fs::read_dir(groups_dir)? {
let entry = entry?;
let path = entry.path();
// 只处理文件且后缀为toml/yaml/yml
if !path.is_file() {
continue;
}
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
if !["toml", "yaml", "yml"].contains(&ext) {
continue;
}
// 读取并解析分组配置
let group_str = fs::read_to_string(&path)?;
let group_config: GlobalConfig = if ext == "toml" {
toml::from_str(&group_str)?
} else {
let yaml_docs = YamlLoader::load_from_str(&group_str)?;
Self::yaml_to_global(&yaml_docs[0])?
};
// 合并配置(分组配置覆盖全局配置)
global.variables.extend(group_config.variables);
global.nodes.extend(group_config.nodes);
global.xml_paths.extend(group_config.xml_paths);
// 合并Linux发行版/版本(如果分组配置有值)
if !group_config.linux_distro.is_empty() {
global.linux_distro = group_config.linux_distro;
}
if !group_config.linux_version.is_empty() {
global.linux_version = group_config.linux_version;
}
}
Ok(())
}
/// 处理节点继承关系(子节点合并父节点的变量/配置)
fn process_node_inheritance(global: &mut GlobalConfig) -> Result<()> {
let nodes_clone = global.nodes.clone(); // 克隆一份用于读取父节点
for (node_name, node) in global.nodes.iter_mut() {
// 如果节点有继承的父节点
if let Some(parent_name) = &node.inherit {
let parent_node = nodes_clone
.get(parent_name)
.with_context(|| format!("节点{}继承的父节点{}不存在", node_name, parent_name))?;
// 1. 合并变量:父节点变量 → 子节点变量(子节点覆盖父节点)
let mut parent_vars = parent_node.variables.clone();
parent_vars.extend(node.variables.clone());
node.variables = parent_vars;
// 2. 继承架构(子节点未设置则使用父节点)
if node.arch == default_arch() && parent_node.arch != default_arch() {
node.arch = parent_node.arch.clone();
}
// 3. 继承Linux发行版子节点未设置则使用父节点
if node.linux_distro == default_linux_distro() && parent_node.linux_distro != default_linux_distro() {
node.linux_distro = parent_node.linux_distro.clone();
}
// 4. 继承Linux版本子节点未设置则使用父节点
if node.linux_version == default_linux_version() && parent_node.linux_version != default_linux_version() {
node.linux_version = parent_node.linux_version.clone();
}
}
}
Ok(())
}
/// 获取指定节点的完整变量(全局变量 + 节点私有变量)
/// 参数node_name - 节点名称
pub fn get_node_vars(&self, node_name: &str) -> Result<HashMap<String, String>> {
let node = self.global.nodes.get(node_name)
.with_context(|| format!("节点{}不存在于配置文件中", node_name))?;
// 1. 复制全局变量
let mut vars = self.global.variables.clone();
// 2. 合并节点私有变量(覆盖全局变量)
vars.extend(node.variables.clone());
// 3. 注入节点内置变量
vars.insert("arch".to_string(), node.arch.clone());
vars.insert("linux_distro".to_string(), node.linux_distro.clone());
vars.insert("linux_version".to_string(), node.linux_version.clone());
vars.insert("node_name".to_string(), node_name.to_string());
Ok(vars)
}
}

View File

@@ -0,0 +1,5 @@
// 配置相关模块:加载、解析、合并 TOML 配置文件
pub mod loader;
pub mod xml_parser;
pub mod ksgen;

View File

@@ -0,0 +1,226 @@
//! XML解析器模块
//! 功能:
//! 1. 加载指定目录/文件下的所有XML文件
//! 2. 实现NodeFilter逻辑过滤符合Linux发行版/版本/架构的节点)
//! 3. 提供XML节点遍历和文本提取功能
use anyhow::{Context, Result};
use roxmltree::{Document, Node, NodeType};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
/// XML节点过滤器复刻gen.py的NodeFilter
/// 用于过滤符合条件的XML节点架构/发行版/版本/阶段)
#[derive(Debug, Clone)]
pub struct XmlNodeFilter {
/// 过滤条件arch/linux_distro/linux_version等
filter_attrs: HashMap<String, String>,
/// 阶段过滤pre/post默认全部
phases: HashSet<String>,
}
impl XmlNodeFilter {
/// 创建过滤器
/// 参数filter_attrs - 过滤属性arch/linux_distro/linux_version
pub fn new(filter_attrs: HashMap<String, String>) -> Self {
// 默认包含pre和post阶段
let mut phases = HashSet::new();
phases.insert("pre".to_string());
phases.insert("post".to_string());
Self {
filter_attrs,
phases,
}
}
/// 设置阶段过滤(仅保留指定阶段)
/// 参数phases - 阶段列表(如["pre"]或["post"]
#[allow(dead_code)]
pub fn set_phases(&mut self, phases: &[&str]) {
self.phases = phases.iter().map(|s| s.to_string()).collect();
}
/// 检查节点是否符合过滤条件(核心逻辑)
/// 对应gen.py的isCorrectCond
pub fn is_node_match(&self, node: &Node) -> bool {
// 非元素节点直接过滤
if node.node_type() != NodeType::Element {
return false;
}
// 1. 提取节点属性
let node_arch = node.attribute("arch").unwrap_or_default();
let node_distro = node.attribute("linux_distro").unwrap_or_default();
let node_version = node.attribute("linux_version").unwrap_or_default();
let node_phase = node.attribute("phase").unwrap_or_default();
// 2. 阶段过滤(支持逗号分隔的多个阶段)
let node_phases: HashSet<String> = node_phase
.split(',')
.map(|s| s.trim().to_lowercase())
.filter(|s| !s.is_empty())
.collect();
// 如果节点指定了阶段,但与过滤器阶段无交集 → 过滤
if !node_phases.is_empty() && node_phases.intersection(&self.phases).next().is_none() {
return false;
}
// 3. 架构过滤节点指定了arch但不匹配 → 过滤)
if !node_arch.is_empty() && node_arch != self.filter_attrs.get("arch").map(|s| s.as_str()).unwrap_or("") {
return false;
}
// 4. 发行版过滤节点指定了linux_distro但不匹配 → 过滤)
if !node_distro.is_empty() && node_distro != self.filter_attrs.get("linux_distro").map(|s|s.as_str()).unwrap_or("") {
return false;
}
// 5. 版本过滤节点指定了linux_version但不匹配 → 过滤)
if !node_version.is_empty() && node_version != self.filter_attrs.get("linux_version").map(|s|s.as_str()).unwrap_or("") {
return false;
}
// 所有条件都满足
true
}
/// 检查节点是否是Linux Kickstart允许的主节点
/// 对应gen.py的MainNodeFilter_linux
pub fn is_linux_main_node(&self, node: &Node) -> bool {
if !self.is_node_match(node) {
return false;
}
// 允许的主节点列表Kickstart核心节点
let allowed_main_nodes = [
"kickstart", "include", "main", "auth", "clearpart", "device", "driverdisk",
"install", "nfs", "cdrom", "interactive", "harddrive", "url", "keyboard",
"lang", "langsupport", "lilo", "bootloader", "mouse", "network", "part",
"volgroup", "logvol", "raid", "reboot", "rootpw", "skipx", "text",
"timezone", "upgrade", "xconfig", "zerombr"
];
allowed_main_nodes.contains(&node.tag_name().name())
}
/// 检查节点是否是Linux Kickstart允许的其他节点
/// 对应gen.py的OtherNodeFilter_linux
pub fn is_linux_other_node(&self, node: &Node) -> bool {
if !self.is_node_match(node) {
return false;
}
// 允许的其他节点列表(包/脚本/配置相关)
let allowed_other_nodes = [
"attributes", "debug", "description", "package", "pre", "post",
"boot", "configure", "file"
];
allowed_other_nodes.contains(&node.tag_name().name())
}
}
/// XML加载器加载所有XML文件并提供遍历接口
#[derive(Debug)]
pub struct XmlLoader{
/// 加载的所有XML文档
pub docs: Vec<Document<'static>>,
/// 节点过滤器
pub filter: XmlNodeFilter,
}
impl XmlLoader {
/// 创建XML加载器
/// 参数:
/// - filter_attrs: 过滤属性arch/linux_distro/linux_version
/// - xml_paths: XML文件/目录路径列表
pub fn new(filter_attrs: HashMap<String, String>, xml_paths: &[PathBuf]) -> Result<Self> {
let mut docs = Vec::new();
// 遍历所有XML路径文件/目录)
for path in xml_paths {
if path.is_dir() {
// 目录:遍历所有.xml文件
docs.extend(Self::load_xml_dir(path)?);
} else if path.is_file() && path.extension().and_then(|e| e.to_str()) == Some("xml") {
// 文件:直接加载
let doc = Self::load_xml_file(path)?;
docs.push(doc);
}
}
// 创建过滤器
let filter = XmlNodeFilter::new(filter_attrs);
Ok(Self { docs, filter})
}
/// 加载目录下的所有XML文件
fn load_xml_dir(dir: &Path) -> Result<Vec<Document<'static>>> {
let mut docs = Vec::new();
for entry in fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
// 只处理.xml文件
if path.is_file() && path.extension().and_then(|e| e.to_str()) == Some("xml") {
let doc = Self::load_xml_file(&path)?;
docs.push(doc);
}
}
Ok(docs)
}
/// 加载单个XML文件
fn load_xml_file(path: &Path) -> Result<Document<'static>> {
// 读取文件内容
let xml_str = fs::read_to_string(path)
.with_context(|| format!("读取XML文件失败: {}", path.display()))?;
// 解析XML为DOM文档
let xml_static: &'static str = Box::leak(xml_str.into_boxed_str());
let doc = Document::parse(xml_static)
.with_context(|| format!("解析XML文件失败: {}", path.display()))?;
Ok(doc)
}
/// 遍历所有符合条件的Linux节点
pub fn iter_linux_nodes(&self) -> impl Iterator<Item = Node<'_, '_>> {
self.docs.iter()
.flat_map(|doc| doc.descendants()) // 遍历所有后代节点
.filter(|node| {
// 过滤出主节点或其他节点
self.filter.is_linux_main_node(node) || self.filter.is_linux_other_node(node)
})
}
/// 获取节点的所有文本内容(包括子节点)
pub fn get_node_text<'a, 'b>(&self, node: &Node<'a, 'b>) -> String {
let mut text = String::new();
for child in node.children() {
match child.node_type() {
// 文本节点:直接追加
NodeType::Text => {
if let Some(t) = child.text() {
text.push_str(t);
}
}
// 元素节点:递归获取文本
NodeType::Element => {
text.push_str(&self.get_node_text(&child));
}
// 其他节点:忽略
_ => {}
}
}
// 去除首尾空白,替换多空格为单空格
text.trim().replace(&['\n', '\r', '\t'][..], " ").replace(" ", " ")
}
}

View File

@@ -0,0 +1,52 @@
use anyhow::{Result};
use std::time::Duration;
/// 模拟数据库配置结构
#[derive(Debug, Clone)]
pub struct DbConfig {
pub url: String,
pub max_connections: u32,
}
/// 核心逻辑:初始化数据库
/// 这个函数不关心参数从哪里来CLI、配置文件、环境变量
pub async fn init_db(config: &DbConfig) -> Result<()> {
println!("🔌 正在连接数据库: {}", config.url);
// 模拟网络延迟
tokio::time::sleep(Duration::from_millis(500)).await;
// 模拟连接逻辑
if config.url.contains("error") {
anyhow::bail!("无法连接到数据库URL 无效");
}
println!("✅ 连接成功!最大连接数设置为: {}", config.max_connections);
// 模拟执行初始化脚本
run_migrations().await?;
println!("🎉 数据库初始化完成!");
Ok(())
}
/// 核心逻辑:检查连接状态
#[allow(dead_code)]
pub async fn check_connection(url: &str) -> Result<bool> {
println!("🏓 正在 Ping 数据库: {}", url);
tokio::time::sleep(Duration::from_millis(200)).await;
if url.contains("error") {
return Ok(false);
}
Ok(true)
}
/// 内部辅助函数:运行迁移
async fn run_migrations() -> Result<()> {
println!("🚀 正在运行数据迁移脚本...");
tokio::time::sleep(Duration::from_millis(300)).await;
println!(" - 创建表 users");
println!(" - 创建表 logs");
Ok(())
}

View File

@@ -0,0 +1,8 @@
// 公开具体的实现文件
pub mod database;
// 重新导出常用函数,方便 commands 层调用
// 例如use crate::internal::database::init_db;
pub use database::init_db;
//pub use database::check_connection;
pub use database::DbConfig;

5
src/internal/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
// 公开 database 模块供外部commands使用
pub mod database;
pub mod config;
pub mod network;

238
src/internal/network/ip.rs Normal file
View File

@@ -0,0 +1,238 @@
use std::fmt;
use std::error::Error;
// 自定义错误类型(实现 Error trait 方便上层处理)
#[derive(Debug)]
pub enum IPError {
InvalidAddressFormat,
OutOfRange(String),
NonUnicastAddress,
ParseError(String),
}
impl fmt::Display for IPError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IPError::InvalidAddressFormat => write!(f, "Invalid IP address format (must be x.x.x.x)"),
IPError::OutOfRange(msg) => write!(f, "IP address out of range: {}", msg),
IPError::NonUnicastAddress => write!(f, "Not a unicast IP address (invalid network type)"),
IPError::ParseError(msg) => write!(f, "IP parse error: {}", msg),
}
}
}
impl Error for IPError {}
// IP 地址核心结构体IPv4
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IPAddr(u32);
impl IPAddr {
/// 从字符串创建 IP 地址(如 "192.168.1.1"
pub fn from_str(s: &str) -> Result<Self, IPError> {
let parts: Vec<&str> = s.split('.').collect();
if parts.len() != 4 {
return Err(IPError::InvalidAddressFormat);
}
let mut octets = [0u8; 4];
for (i, part) in parts.iter().enumerate() {
octets[i] = part.parse().map_err(|e| {
IPError::ParseError(format!("Failed to parse octet {}: {}", part, e))
})?;
}
// 转换为 u32大端序第一个 octet 是最高位)
let addr = ((octets[0] as u32) << 24)
| ((octets[1] as u32) << 16)
| ((octets[2] as u32) << 8)
| (octets[3] as u32);
Ok(Self(addr))
}
/// 从 u32 数值创建 IP 地址
#[allow(dead_code)]
pub fn from_u32(addr: u32) -> Self {
Self(addr)
}
/// 获取 u32 格式的地址
pub fn to_u32(&self) -> u32 {
self.0
}
/// 位运算:与
pub fn bitwise_and(&self, other: &Self) -> Self {
Self(self.0 & other.0)
}
/// 位运算:或
#[allow(dead_code)]
pub fn bitwise_or(&self, other: &Self) -> Self {
Self(self.0 | other.0)
}
/// 位运算:非
pub fn bitwise_not(&self) -> Self {
Self(!self.0)
}
/// 地址加法(支持负数)
pub fn add(&self, n: i32) -> Self {
if n >= 0 {
Self(self.0 + n as u32)
} else {
Self(self.0 - n.abs() as u32)
}
}
}
// 实现 Display 以便格式化输出
impl fmt::Display for IPAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}.{}.{}.{}",
(self.0 >> 24) & 0xFF,
(self.0 >> 16) & 0xFF,
(self.0 >> 8) & 0xFF,
self.0 & 0xFF
)
}
}
// IP 生成器核心结构体
#[derive(Debug)]
pub struct IPGenerator {
network: IPAddr,
netmask: IPAddr,
addr: IPAddr,
}
impl IPGenerator {
/// 创建 IP 生成器实例
/// - network: 网络地址(如 "10.1.1.0"
/// - netmask: 子网掩码None 则自动推断 A/B/C 类)
pub fn new(network: &str, netmask: Option<&str>) -> Result<Self, IPError> {
let network_addr = IPAddr::from_str(network)?;
let netmask_addr = match netmask {
Some(addr_str) => IPAddr::from_str(addr_str)?,
None => Self::infer_netmask(&network_addr)?,
};
// 初始IP = 网络地址 + 1(第一个可用地址)
let initial_addr = network_addr.add(1);
// 效验初始 IP 是否在合法范围
Self::validate_ip_in_range(&initial_addr, &network_addr, &netmask_addr)?;
Ok(Self {
network: network_addr,
netmask: netmask_addr,
addr: initial_addr,
})
}
/// 自动推断子网掩码A/B/C 类)
fn infer_netmask(addr: &IPAddr) -> Result<IPAddr, IPError> {
let first_octet = (addr.0 >> 24) & 0xFF;
match first_octet {
// A 类: 0-127
0..=127 => Ok(IPAddr::from_str("255.0.0.0")?),
// B 类: 128-191
128..=191 => Ok(IPAddr::from_str("255.255.0.0")?),
// C 类: 192-223
192..=223 => Ok(IPAddr::from_str("255.255.255.0")?),
// 非单播地址
_ => Err(IPError::NonUnicastAddress),
}
}
// 校验 IP 是否在网段合法范围(非网络/广播地址)
fn validate_ip_in_range(ip: &IPAddr, network: &IPAddr, netmask: &IPAddr) -> Result<(), IPError> {
let inverted_netmask = netmask.bitwise_not();
let net_addr = network.bitwise_and(netmask).to_u32();
let ip_addr = ip.to_u32();
let broadcast_addr = net_addr | inverted_netmask.to_u32();
// 检查是否是网络地址
if ip_addr == net_addr {
return Err(IPError::OutOfRange("IP is network address".to_string()));
}
// 检查是否是广播地址
if ip_addr == broadcast_addr {
return Err(IPError::OutOfRange("IP is broadcast address".to_string()));
}
// 检查是否在网段内
if (ip_addr & netmask.to_u32()) != net_addr {
return Err(IPError::OutOfRange("IP is not in network range".to_string()));
}
Ok(())
}
// 获取当前 IP带校验
pub fn curr(&self) -> Result<IPAddr, IPError> {
Self::validate_ip_in_range(&self.addr, &self.network, &self.netmask)?;
Ok(self.addr)
}
// 动态计算最大可用步长(根据子网掩码)
pub fn get_max_available_step(&self) -> u32 {
let inverted_netmask = self.netmask.bitwise_not().to_u32();
// 可用 IP 数量 = 主机位总数 - 2排除网络/广播地址)
let total_available = inverted_netmask - 1; // 减 1 是因为初始 IP 已经是 +1 了
total_available as u32
}
// 获取子网掩码的字符串形式
pub fn get_netmask_str(&self) -> String {
format!("{}", self.netmask)
}
/// 获取网络地址(如 10.1.1.0/25 对应 10.1.1.0
pub fn get_network(&self) -> String {
format!("{}", self.addr.bitwise_and(&self.netmask))
}
// 偏移 IP仅支持正数步长
pub fn next(&mut self, n: i32) -> Result<IPAddr, IPError> {
// 强制转为正数(禁用负数偏移)
let step = if n < 0 { 0 } else { n as u32 };
let new_addr = self.addr.add(step as i32);
// 校验新 IP 是否合法
Self::validate_ip_in_range(&new_addr, &self.network, &self.netmask)?;
self.addr = new_addr;
Ok(self.addr)
}
/// 地址递减(等价于 next(-1)
#[allow(dead_code)]
pub fn dec(&mut self) -> Result<IPAddr, IPError> {
self.next(-1)
}
}
// 测试用例(可选)
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ip_from_str() {
let ip = IPAddr::from_str("10.1.1.0").unwrap();
assert_eq!(ip.to_u32(), 0x0a010100);
}
#[test]
fn test_ip_generator() {
let mut generator = IPGenerator::new("10.1.1.0", Some("255.255.255.128")).unwrap();
assert_eq!(generator.curr().unwrap().to_string(), "10.1.1.127");
generator.next(-126).unwrap();
assert_eq!(generator.curr().unwrap().to_string(), "10.1.1.1");
}
}

View File

@@ -0,0 +1,3 @@
pub mod ip;
pub use ip::IPGenerator;

View File

@@ -1,13 +1,12 @@
use clap::Parser;
use anyhow::Result;
// 引入 commands 模块
mod commands;
mod utils; // 假设有通用工具
mod commands; // 引入 commands 模块
mod internal; // 引入 internal 模块
mod utils; // 引入 utils 通用工具模块
use commands::CliCommands;
/// 我的超级 CLI 工具
#[derive(Parser)]
#[command(name = "sunhpc")]
#[command(author = "Qichao.Sun")]
@@ -23,7 +22,8 @@ struct Cli {
command: CliCommands,
}
fn main() -> Result<()> {
#[tokio::main]
async fn main() -> Result<()> {
let cli = Cli::parse();
if cli.debug {
@@ -33,7 +33,9 @@ fn main() -> Result<()> {
// 根据子命令分发逻辑
match cli.command {
CliCommands::Server(args) => commands::server::execute(args)?,
CliCommands::Db(args) => commands::db::execute(args)?,
CliCommands::Db(args) => commands::db::execute(args).await?,
CliCommands::Config(args) => commands::config::execute(args)?,
CliCommands::Report(args) => commands::report::execute(args)?,
// 未来扩展新命令时,只需在这里添加新的匹配臂
// CliCommands::NewFeature(args) => commands::new_feature::execute(args)?,
}