2829浏览
查看: 2829|回复: 0

[讨论] 如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...

[复制链接]

介绍

随着物联网和嵌入式系统的普及,如何在资源受限的硬件上实现高效的机器学习推理成为一个重要课题。行空板 K10 AI 硬件结合 TensorFlow Lite Micro (TFLM),为嵌入式设备上的 TinyML 开发 提供了轻量级的解决方案。从模型训练到最终部署,本文将详细讲解如何通过 K10 硬件 和 Arduino TensorFlow 开发环境,快速实现 嵌入式机器学习 应用。同时,这一教程也适用于类似的硬件平台,例如 ESP32 S3 AI 设备。

本文将结合 DFRobot 行空板K10 硬件 和 TensorFlow-Lite-Mirco 库,介绍如何从模型训练到最终部署,实现一个完整的 TinyML 开发过程,帮助开发者快速上手资源受限设备上的机器学习推理。代码下载链接:https://github.com/polamaxu/TFLM

行空板K10

行空板 K10是一款专为信息科技课程中的编程学习、物联网及AI项目教学需求而开发的学习板。

核心

  • MCU:ESP32-S3,32位双核处理器,主频为240MHz.

无线通信

  • Wi-Fi:支持2.4G Wi-Fi
  • 蓝牙:支持蓝牙5.0

显示模块

  • 显示屏:2.8英寸彩色LCD屏

板载元件

  • 摄像头:集成摄像头,支持图像采集和检测,可以用于人脸识别、物体检测等人工智能应用.
  • 麦克风:配备两个麦克风,能够捕捉声音信号,支持语音识别和语音交互功能.
  • 扬声器:内置扬声器,可以播放音频,用于语音提示和音乐播放等.
  • 传感器:包括数字环境光传感器、温湿度传感器和加速度传感器,能够检测环境参数和运动状态,为物联网应用提供数据支持.
  • RGB灯:有三个RGB灯,可以显示多种颜色
  • 实体按键:有两个实体按键,可以用于输入和控制操作.

Tensorflow vs TensorFlow-Lite-Mirco

TensorFlow 是一个开源机器学习框架。TensorFlow 提供了丰富的 API,使得开发者可以轻松地构建、训练和部署机器学习模型。
TensorFlow Lite Micro 是 TensorFlow 的一个子集,专门为微控制器等资源受限的设备设计。它允许开发者将机器学习模型部署到这些设备上,使得设备能够直接在本地进行数据处理和推理,而不需要依赖云端计算。

  • TensorFlow:适合在桌面或服务器环境中进行复杂的模型训练与开发,支持高度定制化和复杂操作。
  • TensorFlow Lite for Microcontrollers (TFLM):为嵌入式设备设计,优化了存储与计算效率,适用于资源受限环境中快速运行机器学习模型。
    如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图1

TinyML开发流程(软件Arduino IDE、硬件行空板K10)

数据采集:

  • 创建一个 Arduino Sketch 用于从传感器或其他输入设备收集数据,以形成模型训练所需的数据集。

定义与训练模型:

  • 在 TensorFlow 开发环境中(如 Google Colab)定义一个深度神经网络模型(DNN),并使用采集到的数据对模型进行训练。

模型转换与保存:

  • 将训练好的模型转换为 TensorFlow Lite 格式,并保存为 model.h 文件,其中包含模型的二进制表示(Flat Buffer 格式)。

部署推理代码:

  • 在 Arduino IDE 中准备推理代码,包含以下步骤:
  • 引入必要头文件(如 TensorFlow Lite Micro 和模型文件 model.h)。
  • 定义 TensorArena(内存缓冲区)。
  • 初始化模型。
  • 设置输入数据并运行推理。
  • 读取推理输出。

测试与优化:

  • 通过串口调试工具查看推理结果,并根据需要优化模型性能。

MicroTFLite介绍

MicroTFLite 是一个为 Arduino 设计的 TensorFlow Lite Micro 库,旨在简化在 Arduino 板上使用 TensorFlow Lite Micro 的过程。MicroTFLite库适用于各种任务,如分类、回归、预测。

1. 特点

  • Arduino 风格的 API:MicroTFLite 提供了典型的 Arduino 风格的 API,避免了在 Arduino 代码中使用指针或其他 C++ 语法构造,这使得它更适合 Arduino 开发者使用。
  • 支持量化数据和浮点数据:该库能够处理量化数据和原始浮点值,根据模型的元数据自动检测适当的处理方式。
  • 调试功能:提供了多个函数来帮助开发者了解模型部署的过程,并在调试模型问题时提供帮助。

2. 目标平台

MicroTFLite 适用于多种嵌入式设备,包括:

  • Arduino Nano 系列(如 Nano 33 BLE、Nano ESP32)
  • Arduino Nicla、Portenta 系列
  • ESP32 和 Arduino Giga R1 WiFi

3. MicroTFLite支持的算子

MicroTFLite 支持多种常见的机器学习算子:

  • 卷积与池化:
    • 二维卷积(CONV_2D)、深度可分离卷积(DEPTHWISE_CONV_2D)
    • 最大池化(MAX_POOL_2D)、平均池化(AVERAGE_POOL_2D)
  • 全连接层:FULLY_CONNECTED
  • 激活函数:
    • ReLU、ReLU6、TANH、SIGMOID

4. MicroTFLite API一览

  • 模型初始化和部署
    • 初始化模型:通过 ModelInit() 函数,可以初始化 TensorFlow Lite 模型和解释器。这包括加载模型数据和分配必要的内存空间(如 tensorArena).
    • 加载模型文件:支持从二进制文件(如model.h)中加载预训练的 TensorFlow Lite 模型,使其能够在 Arduino 设备上运行.
  • 输入和输出处理
    • 设置输入数据:使用 ModelSetInput() 函数可以将输入数据设置到模型的输入张量中。该函数还支持量化处理,能够根据模型的量化参数自动调整输入值.
    • 获取输出结果:通过 ModelGetOutput() 函数可以读取模型推理后的输出结果。这使得开发者能够获取模型的预测值或分类结果.
  • 推理和调试
    • 运行推理:调用 ModelRunInference() 函数可以启动模型的推理过程。该函数执行模型的前向传播,生成输出结果.
    • 打印张量信息:提供了多个函数来打印模型的张量信息,如 ModelPrintInputTensorDimensions()ModelPrintOutputTensorDimensions(),这些函数可以帮助开发者了解输入和输出张量的维度.
    • 调试和调试信息:ModelPrintTensorQuantizationParams() 函数可以打印输入和输出张量的量化参数,这对于调试量化模型非常有用。ModelPrintMetadata()函数可以打印模型的元数据信息,如描述和版本,帮助开发者了解模型的基本信息.

MicroTFLite安装

1. arduinoIDE 1.8.19 preferences安装行空板K10环境
https://downloadcd.dfrobot.com.cn/UNIHIKER/package_unihiker_index.json

如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图2

2. Library Manager搜索MicroTFLite安装
如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图3

3. 示例测试:通过传感器数据实时预测工业设备故障
ArduinoLite_preventive_maintenance项目使用的是一个前馈神经网络架构,包含输入层、两个 Dropout 层、一个隐藏层和输出层,处理二元分类任务。使用 binary_crossentropy 作为损失函数和 adam 优化器进行训练,支持类别权重调整和早停回调,确保在类别不平衡时仍能有效训练。示例中使用了随机数模拟的传感器数据,比如转速、温度、振动、电流。为了减少模型大小,此模型量化成int8。

这个程序的主要功能是:

  • 模拟采集设备的各项参数(转速、温度、振动、电流)
  • 使用这些参数进行故障预测
  • 同时通过预设的阈值进行实际故障判断
  • 比较预测结果和实际故障情况,统计预测准确度

量化作用和影响:TensorFlow Lite for Microcontrollers (TFLM) 的量化是将模型中的数据从浮点数(float32)转换为整数(int8)的过程。

  • 减少模型大小:量化将模型中的浮点数转换为整数,显著减少了模型的存储需求。
  • 提高计算效率:在资源受限的设备上,整数计算通常比浮点计算更高效。整数运算器(如 ARM Cortex-M 的 DSP 扩展)可以更快地执行整数乘法和加法等操作,从而提高模型的推理速度.
  • 精度损失:量化会引入一定的精度损失,因为整数表示无法完全精确地表示浮点数。然而,通过合理的量化策略和模型优化,可以在保持较高精度的同时实现显著的性能提升.
/* Copyright 2024 John O'Sullivan, TensorFlow Authors. All Rights Reserved.

这是一个使用 MicroTFLite 库运行 TensorFlow Lite 模型的示例程序
主要用于设备的预防性维护,通过传感器数据预测可能的故障

更多信息请参考库文档:
https://github.com/johnosbb/MicroTFLite

Licensed under the Apache License, Version 2.0 (the "License");
... 许可证信息 ...
==============================================================================*/

#include <MicroTFLite.h>
#include "model.h"

// 特征数据的统计信息(缩放和平衡前):
// 转速(RPM) - 平均值: 1603.866, 标准差: 195.843
// 温度(°C) - 平均值: 24.354, 标准差: 4.987
// 振动(g) - 平均值: 0.120, 标准差: 0.020
// 电流(A) - 平均值: 3.494, 标准差: 0.308
constexpr float tMean = 24.354f;      // 温度平均值
constexpr float rpmMean = 1603.866f;  // 转速平均值
constexpr float vMean = 0.120f;       // 振动平均值
constexpr float cMean = 3.494f;       // 电流平均值
constexpr float tStd = 4.987f;        // 温度标准差
constexpr float rpmStd = 195.843f;    // 转速标准差
constexpr float vStd = 0.020f;        // 振动标准差
constexpr float cStd = 0.308f;        // 电流标准差

// 定义故障条件的阈值
const float highTempThreshold = 30.0f;            // 温度过高阈值(摄氏度)
const float lowRpmThreshold = 1500.0f;            // 转速过低阈值
const float highVibrationThreshold = 0.60f;       // 振动过高阈值(g)
const float abnormalCurrentLowThreshold = 0.2f;   // 电流过低阈值(安培)
const float abnormalCurrentHighThreshold = 10.8f; // 电流过高阈值(安培)

// 预测统计计数器
int totalPredictions = 0;    
int truePositives = 0;       
int falsePositives = 0;      
int trueNegatives = 0;       
int falseNegatives = 0;      
float rollingAccuracy = 0.0f;// 滚动计算的准确率
bool showStatistics = false; // 是否显示统计信息

// 为TensorFlow Lite分配内存
constexpr int kTensorArenaSize = 4 * 1024;
alignas(16) uint8_t tensorArena[kTensorArenaSize];

void setup() {
    // 初始化串口通信并等待串口监视器打开
    Serial.begin(115200);
    while (!Serial);
    delay(5000);
    Serial.println("Preventative Maintenance Example.");
    Serial.println("Initializing TensorFlow Lite Micro Interpreter...");

    // 初始化TensorFlow Lite模型
    if (!ModelInit(model, tensorArena, kTensorArenaSize)) {
        Serial.println("Model initialization failed!");
        while (true);
    }
    Serial.println("Model initialization done.");
    ModelPrintMetadata();
    ModelPrintTensorQuantizationParams();
    ModelPrintTensorInfo();
}

// 使用Box-Muller变换生成正态分布的随机值
float GenerateRandomValue(float mean, float stddev) {
    float u1 = random(0, 10000) / 10000.0f;
    float u2 = random(0, 10000) / 10000.0f;
    // Box-Muller变换生成正态分布值
    float z0 = sqrt(-2.0f * log(u1)) * cos(2.0f * PI * u2);
    float value = mean + z0 * stddev;
    return value;
}

// 模拟读取传感器数据的函数
float ReadRpm() {
    return GenerateRandomValue(rpmMean, rpmStd);
}

float ReadVibration() {
    return GenerateRandomValue(vMean, vStd);
}

float ReadTemperature() {
    return GenerateRandomValue(tMean, tStd);
}

float ReadCurrent() {
    return GenerateRandomValue(cMean, cStd);
}

// 检查是否存在故障条件
bool CheckFailureConditions(float temperature, float rpm, float vibration, float current) {
    bool conditionMet = false;
    String failureReason = "";

    // 检查各项参数是否超过阈值
    if (temperature > highTempThreshold) {
        conditionMet = true;
        failureReason += "温度过高; ";
    }
    if (rpm < lowRpmThreshold) {
        conditionMet = true;
        failureReason += "转速过低; ";
    }
    if (vibration > highVibrationThreshold) {
        conditionMet = true;
        failureReason += "振动过大; ";
    }
    if (current < abnormalCurrentLowThreshold) {
        conditionMet = true;
        failureReason += "电流过低; ";
    }
    if (current > abnormalCurrentHighThreshold) {
        conditionMet = true;
        failureReason += "电流过高; ";
    }

    if (conditionMet) {
        Serial.print("注意:传感器读数表明可能存在故障。原因: ");
        Serial.println(failureReason);
    }

    return conditionMet;
}

void loop() {
    // 读取传感器数据(这里使用模拟数据)
    float rpm = ReadRpm();
    float temperature = ReadTemperature();
    float current = ReadCurrent();
    float vibration = ReadVibration();
    bool inputSetFailed = false;

    // 检查实际故障条件
    bool actualFailure = CheckFailureConditions(temperature, rpm, vibration, current);

    // 数据标准化处理
    float temperatureN = (temperature - tMean) / tStd;
    float rpmN = (rpm - rpmMean) / rpmStd;
    float currentN = (current - cMean) / cStd;
    float vibrationN = (vibration - vMean) / vStd;

    // 将数据输入到模型中
    if (!ModelSetInput(temperatureN, 0))
        inputSetFailed = true;
    if (!ModelSetInput(rpmN, 1))
        inputSetFailed = true;
    if (!ModelSetInput(currentN, 2))
        inputSetFailed = true;
    if (!ModelSetInput(vibrationN, 3))
        inputSetFailed = true;

    // 运行模型推理
    if (!ModelRunInference()) {
        Serial.println("模型推理失败!");
        return;
    }

    // 获取模型输出结果
    float prediction = ModelGetOutput(0);
    bool predictedFailure = (prediction > 0.50f); // 预测值>0.5表示可能发生故障

    // 更新预测统计信息
    if (predictedFailure && actualFailure) {
        truePositives++;
    } else if (predictedFailure && !actualFailure) {
        falsePositives++;
        showStatistics = true;
    } else if (!predictedFailure && actualFailure) {
        falseNegatives++;
        showStatistics = true;
    } else if (!predictedFailure && !actualFailure) {
        trueNegatives++;
    }

    totalPredictions++;

    // 计算滚动准确率
    rollingAccuracy = (float)(truePositives + trueNegatives) / totalPredictions * 100.0f;

    // 当出现假阳性或假阴性时显示统计信息
    if (showStatistics) {
        Serial.println("-----------------------------");
        Serial.print("预测置信度: ");
        Serial.println(prediction);
        Serial.print("预测是否故障: ");
        Serial.println(predictedFailure);
        Serial.print("实际是否故障: ");
        Serial.println(actualFailure);
        Serial.print("转速: ");
        Serial.print(rpm);
        Serial.print(" | 温度: ");
        Serial.print(temperature);
        Serial.print(" °C");
        Serial.print(" | 电流: ");
        Serial.print(current);
        Serial.print(" A");
        Serial.print(" | 振动: ");
        Serial.print(vibration);
        Serial.print(" m/s^2\n");
        Serial.println("总预测次数: " + String(totalPredictions) + 
                      ", 真阳性: " + String(truePositives) + 
                      ", 假阳性: " + String(falsePositives) + 
                      ", 真阴性: " + String(trueNegatives) + 
                      ", 假阴性: " + String(falseNegatives) + 
                      ", 准确率(%): " + String(rollingAccuracy));

        showStatistics = false;
    }

    delay(10000); // 等待10秒后进行下一次检测
}

如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图4

4. 电脑连接k10,上传程序:
如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图5

5. k10编译结果:程序占了很少的存储空间和动态内存
Sketch uses 539913 bytes (10%) of program storage space. Maximum is 5242880 bytes.
Global variables use 29740 bytes (9%) of dynamic memory, leaving 297940 bytes for local variables.

6. 串口监视器输出:

  • 转速(RPM):1493.15
  • 温度(Temperature):32.01°C
  • 电流(Current):3.52 A
  • 振动(Vibration):0.09 m/s^2
    7:59:07.295,串口输出指出传感器读数显示潜在的故障条件,原因是高温(High Temperature)。在 7:59:37.328,串口输出再次指出潜在的故障条件,原因包括高温和低转速。
    如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图6

自己训练模型并部署

1. PC上安装python环境
从 Python 官网下载并安装 Python 3.6 或更高版本:Python 官网

2. PC上安装TensorFlow和NumPy库
打开终端(Windows 上为命令提示符,macOS 或 Linux 上为终端),然后运行以下命令来安装所需的库:

# 安装 TensorFlow 2.x(默认为最新版本) 
pip install tensorflow  
# 安装 NumPy 
pip install numpy

3.  定义简单的回归网络,生成模型后再转换成C文件
目标:输入为1到5的数字,输出为其平方值
终端上运行test.py文件:

python test.py

生成的模型文件下载链接:https://github.com/polamaxu/TFLM
代码详情:

"""
环境要求:
- Python 3.6+
- TensorFlow 2.x
- NumPy

功能说明:
这个程序演示了一个简单的回归模型,用于预测数字的平方值。
包含模型训练、保存、转换为TFLite格式,以及生成C数组三个主要步骤。
"""

import tensorflow as tf
import numpy as np

# 1. 准备训练数据
# 创建一个简单的数据集:输入为1到5的数字,输出为其平方值
x_train = np.array([[1], [2], [3], [4], [5]], dtype=np.float32)
y_train = np.square(x_train)  # 计算平方值作为目标输出

# 2. 定义模型
# 使用Sequential API创建一个简单的模型
# Lambda层直接实现了平方运算,无需训练参数
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,), name='input_layer'),
    tf.keras.layers.Lambda(lambda x: x ** 2, name='square_layer')
])

# 3. 编译模型
# 注:由于使用Lambda层直接计算平方,实际上不需要训练
model.compile(optimizer='sgd', loss='mean_squared_error')

# 4. 测试模型性能
print("测试结果:")
for x in x_train:
    y_pred = model.predict([x])
    print(f"输入:{x[0]}, 预测输出:{y_pred[0][0]}")

# 5. 保存模型
# 将模型保存为HDF5格式
model.save('./saved_model.h5')

# 6. 转换为TFLite格式
# TFLite格式更适合在嵌入式设备上运行
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# 保存TFLite模型
with open('./model.tflite', 'wb') as f:
    f.write(tflite_model)

# 7. 将TFLite模型转换为C数组
def convert_tflite_to_c_array(tflite_model_path, c_file_path):
    """
    将TFLite模型文件转换为C语言数组格式

    参数:
        tflite_model_path: TFLite模型文件路径
        c_file_path: 输出的C头文件路径
    """
    with open(tflite_model_path, 'rb') as file:
        content = file.read()

    with open(c_file_path, 'w') as file:
        file.write('const unsigned char model_data[] = {')
        for i, val in enumerate(content):
            if i % 12 == 0:  # 每12个数字换行,提高可读性
                file.write('\n  ')
            file.write('0x{:02x}, '.format(val))
        file.write('\n};\n')
        file.write('const int model_data_len = {};\n'.format(len(content)))

if __name__ == '__main__':
    convert_tflite_to_c_array('./model.tflite', './model_data.h'

4. 将生成的模型文件上传至K10硬件,通过MicroTFLite库进行推理操作。
将训练生成的model.h与ino文件放在同一个文件夹中,在arduinoIDE软件中上传程序

#include <MicroTFLite.h>
#include "model.h" // Import model data

// Define memory area for storing intermediate tensors
constexpr int kTensorArenaSize = 2 * 1024; // Adjust based on model size
alignas(16) uint8_t tensorArena[kTensorArenaSize];

void setup() {
  // Initialize serial communication
  Serial.begin(9600);

  // Initialize model
  if (!ModelInit(model, tensorArena, kTensorArenaSize)) {
    Serial.println("Model initialization failed!");
    while (true); // Stop execution if initialization fails
  }

  Serial.println("Model initialization done.");
  ModelPrintMetadata();
  ModelPrintTensorInfo();
  ModelPrintInputTensorDimensions(); // Print input tensor dimensions
  ModelPrintOutputTensorDimensions(); // Print output tensor dimensions
}

void loop() {
  float input_value = 3.0; // Test input value

  // Display the input value
  Serial.print("Input value: ");
  Serial.println(input_value);

  // Set input value to the model
  if (!ModelSetInput(input_value, 0)) { // Set the first input
    Serial.println("Failed to set input value!");
    return;
  }

  // Run inference
  if (!ModelRunInference()) {
    Serial.println("Inference failed!");
    return;
  }

  // Get the model's output
  float prediction = ModelGetOutput(0);

  // Print the output result (predicted square value)
  Serial.print("Predicted output: ");
  Serial.println(prediction);

  delay(2000); // Run every 2 seconds
}

5. k10编译结果:程序占了很少的存储空间和动态内存
Sketch uses 530373 bytes (10%) of program storage space. Maximum is 5242880 bytes.
Global variables use 28492 bytes (8%) of dynamic memory, leaving 299188 bytes for local variables. Maximum is 327680 bytes.

6. k10输出:

如何通过Arduino IDE在行空板K10 (ESP32 S3) 上使用TensorFlow Li...图7

结论

MicroTFLite 库为嵌入式开发者提供了高效的工具,使得在资源受限平台(如 K10)上运行机器学习模型变得简单易行。从模型训练到部署推理,开发者可以快速完成 TinyML 项目,推动智能嵌入式设备的发展。
如果你有想要了解的,也请告诉我们,我们会推出更多 tinyml 项目教程。

参考

  1. https://github.com/tensorflow/tflite-micro
  2. https://github.com/johnosbb/MicroTFLite
  3. https://github.com/polamaxu/TFLM

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

为本项目制作心愿单
购买心愿单
心愿单 编辑
[[wsData.name]]

硬件清单

  • [[d.name]]
btnicon
我也要做!
点击进入购买页面
上海智位机器人股份有限公司 沪ICP备09038501号-4 备案 沪公网安备31011502402448

© 2013-2025 Comsenz Inc. Powered by Discuz! X3.4 Licensed

mail