# 苹果MLX AI训练应用1 **Repository Path**: nbsstudio/apple-mlx-ai-training-app-1 ## Basic Information - **Project Name**: 苹果MLX AI训练应用1 - **Description**: 用来训练一下识别动物、物品之类的,构建模型的基本操作,基于苹果的MLX框架,实用苹果的芯片进行加速 - **Primary Language**: Python - **License**: WTFPL - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2025-08-02 - **Last Updated**: 2025-08-02 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 猫咪识别系统 (基于MLX) 这个项目使用Apple的MLX框架实现了一个简单的猫咪识别系统。该系统使用卷积神经网络(CNN)来区分猫咪和非猫咪图像。 ## 环境要求 - macOS 15.0 或更高版本 (MLX需要Apple Silicon) - Python 3.9 或更高版本 - Conda 环境 - 安装的依赖库: mlx, numpy, pillow, tqdm ## 环境设置 1. 创建并激活conda环境: ```bash conda create -n ai python=3.9 -y conda activate ai ``` 2. 安装依赖库: ```bash pip install mlx numpy pillow tqdm ``` ## 数据集准备 1. 创建一个数据集目录,结构如下: ``` dataset/ ├── cats/ │ ├── cat1.jpg │ ├── cat2.jpg │ └── ... └── non_cats/ ├── non_cat1.jpg ├── non_cat2.jpg └── ... ``` 2. 在`cats`目录中放入猫咪图像,在`non_cats`目录中放入非猫咪图像。 3. 图像尺寸会被自动调整为64x64像素,并转换为灰度图。 ## 训练模型 1. 运行训练脚本: ```bash python cat_recognition_mlx.py --data_dir /path/to/your/dataset ``` 你可以通过命令行参数指定数据集路径,而不需要修改代码。 2. 可用的命令行参数: - `--data_dir`: 数据集路径 (必填) - `--batch_size`: 批次大小,默认为32 - `--epochs`: 训练轮数,默认为10 - `--learning_rate`: 学习率,默认为0.001 - `--image_size`: 图像尺寸,默认为(64 64) 3. 示例: ```bash python cat_recognition_mlx.py --data_dir /Users/cody/Documents/Dataset --batch_size 64 --epochs 20 ``` ## 模型评估 训练过程中会自动在验证集上评估模型性能,并打印出训练损失、训练准确率、验证损失和验证准确率。 ### 评估函数实现 以下是训练代码中使用的评估逻辑: ```python # 验证过程 if len(val_loader) > 0: model.eval() val_loss = 0.0 correct = 0 total = 0 with mx.no_grad(): for images, labels in val_loader: images = images.transpose(0, 2, 3, 1) outputs = model.forward(images) loss = criterion(outputs, labels) val_loss += loss.item() predicted = (outputs > 0.5).astype(mx.float32) total += labels.shape[0] correct += (predicted == labels).sum().item() val_loss /= len(val_loader) val_acc = correct / total print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}") else: print("跳过验证: 验证集大小为0") ``` ## 保存模型 训练完成后,模型参数会通过`mlx.utils.tree_flatten`和`mx.savez`保存为`cat_classifier_model.npz`文件。代码如下: ```python from mlx.utils import tree_flatten # 展平参数 flat_params = tree_flatten(model.parameters()) # 保存模型参数 mx.savez("cat_classifier_model.npz", **dict(flat_params)) print("模型参数保存成功!") ``` ### 代码说明 1. `tree_flatten`函数将模型参数从嵌套字典结构展平为键值对列表 2. `dict(flat_params)`将展平的参数转换为字典 3. `mx.savez`将参数字典保存为NumPy压缩文件格式 ## 使用模型进行预测 我们提供了一个专门的推理脚本 `inference.py`,用于加载训练好的模型并进行猫咪识别。 ### 推理脚本使用方法 1. **基本使用**: ```bash python inference.py --image_path /path/to/your/image.jpg ``` 2. **可选参数**: - `--model_path`: 模型参数文件路径,默认为 `cat_classifier_model.npz` - `--image_size`: 图像大小,默认为 `(64 64)`,需要与训练时使用的图像大小保持一致 3. **示例**: ```bash python inference.py --image_path /Users/cody/Documents/Dataset/test/cat.jpg --model_path cat_classifier_model.npz --image_size 64 64 ``` ### 使用测试脚本 我们还提供了一个便捷的测试脚本 `test_inference.sh`,用法如下: 1. 给脚本添加执行权限: ```bash chmod +x test_inference.sh ``` 2. 运行脚本,指定图像路径: ```bash ./test_inference.sh /path/to/your/image.jpg ``` ### 推理脚本说明 `inference.py` 脚本包含以下主要功能: 1. **模型加载**: 加载与训练时相同结构的 `CatClassifier` 模型 2. **参数导入**: 导入训练好的模型参数 3. **图像预处理**: 调整图像大小、转换为灰度图、归一化并调整维度顺序 4. **预测**: 在无梯度模式下进行前向传播,输出猫咪概率 5. **结果解析**: 判断图像是否为猫咪并输出结果 ## 代码结构 - `cat_recognition_mlx.py`: 主代码文件,包含数据集加载、模型定义、训练和评估函数 - `inference.py`: 推理脚本,用于加载模型并进行猫咪识别 - `test_inference.sh`: 测试脚本,提供便捷的推理功能 - `simple_training.py`: 简单训练示例,展示MLX的基本使用方法 - `run_training.sh`: 训练脚本 (如果存在) - `README.md`: 项目说明文档 ## 注意事项 1. MLX框架目前仅支持Apple Silicon芯片(M1/M2/M3等)。 2. 确保你的数据集包含足够多的猫咪和非猫咪图像,以获得较好的训练效果。 3. 可以通过调整模型结构和训练参数来提高识别准确率。 4. 训练时间取决于你的数据集大小和硬件性能。 ## 参考资料 - [MLX官方文档](https://ml-explore.github.io/mlx/) - [Apple MLX: 机器学习框架](https://github.com/ml-explore/mlx)