samples: tflite-micro: update samples for latest tflite-micro

tflite-micro now uses MicroPrintf instead of MicroErrorReporter; update
the samples to use it. AllOpsResolver was also removed from the latest
tflite-micro; use MicroMutableOpResolver and register only the kernels
each sample's model actually uses.

Signed-off-by: Ryan McClelland <ryanmcclelland@meta.com>
XenuIsWatching committed Jul 6, 2023
1 parent 33a9794 commit fc65780
Showing 12 changed files with 54 additions and 81 deletions.
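
The change boils down to two API swaps, sketched below for reference. This is an editor's illustration, not code from the commit; the status argument and the FullyConnected kernel are assumptions standing in for whatever a given sample logs and registers.

#include <tensorflow/lite/c/c_api_types.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>

void MigrationSketch(TfLiteStatus status)
{
	/* Before: TF_LITE_REPORT_ERROR(error_reporter, "failed: %d", status);
	 * After: free-standing printf-style logging, no reporter object.
	 */
	MicroPrintf("failed: %d", status);

	/* Before: static tflite::AllOpsResolver resolver;
	 * After: register only the kernels the model needs; the template
	 * argument caps how many registrations the resolver can hold.
	 */
	static tflite::MicroMutableOpResolver<1> resolver;
	resolver.AddFullyConnected();
}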
1 change: 1 addition & 0 deletions samples/modules/tflite-micro/hello_world/prj.conf
@@ -13,5 +13,6 @@
# limitations under the License.
# ==============================================================================
CONFIG_CPP=y
CONFIG_STD_CPP17=y
CONFIG_TENSORFLOW_LITE_MICRO=y
CONFIG_MAIN_STACK_SIZE=2048
34 changes: 12 additions & 22 deletions samples/modules/tflite-micro/hello_world/src/main_functions.cpp
@@ -16,18 +16,17 @@

#include "main_functions.h"

#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include "constants.h"
#include "model.hpp"
#include "output_handler.hpp"
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/system_setup.h>
#include <tensorflow/lite/schema/schema_generated.h>

/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
tflite::ErrorReporter *error_reporter = nullptr;
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *input = nullptr;
@@ -41,40 +40,32 @@ namespace {
/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
/* Set up logging. Google style is to avoid globals or statics because of
* lifetime uncertainty, but since this has a trivial destructor it's okay.
* NOLINTNEXTLINE(runtime-global-variables)
*/
static tflite::MicroErrorReporter micro_error_reporter;

error_reporter = &micro_error_reporter;

/* Map the model into a usable data structure. This doesn't involve any
* copying or parsing, it's a very lightweight operation.
*/
model = tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
MicroPrintf("Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}

/* This pulls in all the operation implementations we need.
/* This pulls in the operation implementations we need.
* NOLINTNEXTLINE(runtime-global-variables)
*/
static tflite::AllOpsResolver resolver;
static tflite::MicroMutableOpResolver <1> resolver;
resolver.AddFullyConnected();

/* Build an interpreter to run the model with. */
static tflite::MicroInterpreter static_interpreter(
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
model, resolver, tensor_arena, kTensorArenaSize);
interpreter = &static_interpreter;

/* Allocate memory from the tensor_arena for the model's tensors. */
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
MicroPrintf("AllocateTensors() failed");
return;
}

@@ -106,8 +97,7 @@ void loop(void)
/* Run inference, and report any error */
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x: %f\n",
static_cast < double > (x));
MicroPrintf("Invoke failed on x: %f\n", static_cast < double > (x));
return;
}

@@ -119,7 +109,7 @@ void loop(void)
/* Output the results. A custom HandleOutput function can be implemented
* for each supported hardware target.
*/
HandleOutput(error_reporter, x, y);
HandleOutput(x, y);

/* Increment the inference_counter, and reset it if we have reached
* the total number per cycle
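
Editor's note: the template argument in MicroMutableOpResolver<1> above is the resolver's capacity, i.e. the maximum number of Add*() registrations it can hold. A hedged sketch, assuming a hypothetical model that also needed softmax and reshape kernels (this sample needs only the one):

#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>

namespace {
/* Capacity 3: one slot per Add*() call below. */
tflite::MicroMutableOpResolver<3> resolver;
}

void RegisterOps(void)
{
	resolver.AddFullyConnected();
	resolver.AddSoftmax();
	resolver.AddReshape();
}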
samples/modules/tflite-micro/hello_world/src/output_handler.cpp
@@ -16,11 +16,10 @@

#include "output_handler.hpp"

void HandleOutput(tflite::ErrorReporter *error_reporter, float x_value,
float y_value)
void HandleOutput(float x_value, float y_value)
{
/* Log the current X and Y values */
TF_LITE_REPORT_ERROR(error_reporter, "x_value: %f, y_value: %f\n",
static_cast < double > (x_value),
static_cast < double > (y_value));
MicroPrintf("x_value: %f, y_value: %f\n",
static_cast < double > (x_value),
static_cast < double > (y_value));
}
samples/modules/tflite-micro/hello_world/src/output_handler.hpp
@@ -18,10 +18,9 @@
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_

#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>

/* Called by the main loop to produce some output based on the x and y values */
void HandleOutput(tflite::ErrorReporter *error_reporter, float x_value,
float y_value);
void HandleOutput(float x_value, float y_value);

#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_ */
1 change: 1 addition & 0 deletions samples/modules/tflite-micro/magic_wand/prj.conf
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
CONFIG_CPP=y
CONFIG_STD_CPP17=y
CONFIG_NEWLIB_LIBC_FLOAT_PRINTF=y
CONFIG_SENSOR=y
CONFIG_NETWORKING=n
samples/modules/tflite-micro/magic_wand/src/accelerometer_handler.cpp
@@ -33,34 +33,32 @@ float bufz[BUFLEN] = { 0.0f };

bool initial = true;

TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter)
TfLiteStatus SetupAccelerometer()
{
if (!device_is_ready(sensor)) {
printk("%s: device not ready.\n", sensor->name);
return kTfLiteApplicationError;
}

if (sensor == NULL) {
TF_LITE_REPORT_ERROR(error_reporter,
"Failed to get accelerometer, name: %s\n",
sensor->name);
MicroPrintf("Failed to get accelerometer, name: %s\n",
sensor->name);
} else {
TF_LITE_REPORT_ERROR(error_reporter, "Got accelerometer, name: %s\n",
sensor->name);
MicroPrintf("Got accelerometer, name: %s\n",
sensor->name);
}
return kTfLiteOk;
}

bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
int length)
bool ReadAccelerometer(float *input, int length)
{
int rc;
struct sensor_value accel[3];
int samples_count;

rc = sensor_sample_fetch(sensor);
if (rc < 0) {
TF_LITE_REPORT_ERROR(error_reporter, "Fetch failed\n");
MicroPrintf("Fetch failed\n");
return false;
}
/* Skip if there is no data */
@@ -72,7 +70,7 @@ bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
for (int i = 0; i < samples_count; i++) {
rc = sensor_channel_get(sensor, SENSOR_CHAN_ACCEL_XYZ, accel);
if (rc < 0) {
TF_LITE_REPORT_ERROR(error_reporter, "ERROR: Update failed: %d\n", rc);
MicroPrintf("ERROR: Update failed: %d\n", rc);
return false;
}
bufx[begin_index] = (float)sensor_value_to_double(&accel[0]);
samples/modules/tflite-micro/magic_wand/src/accelerometer_handler.hpp
@@ -20,11 +20,10 @@
#define kChannelNumber 3

#include <tensorflow/lite/c/c_api_types.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>

extern int begin_index;
extern TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter);
extern bool ReadAccelerometer(tflite::ErrorReporter *error_reporter,
float *input, int length);
extern TfLiteStatus SetupAccelerometer();
extern bool ReadAccelerometer(float *input, int length);

#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_ACCELEROMETER_HANDLER_H_ */
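
Editor's note: with the reporter parameter gone, callers of this handler API simplify accordingly. A hedged usage sketch (not from the commit; the buffer sizing mirrors the 128 x 3 input tensor that main_functions.cpp checks):

#include "accelerometer_handler.hpp"
#include <tensorflow/lite/micro/micro_log.h>

void PollAccelerometer(void)
{
	/* 128 samples x kChannelNumber (3) floats, illustrative sizing. */
	static float input[384];

	if (SetupAccelerometer() != kTfLiteOk) {
		MicroPrintf("Set up failed");
		return;
	}
	if (!ReadAccelerometer(input, 384)) {
		return; /* No new data yet; try again on the next pass. */
	}
}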
33 changes: 11 additions & 22 deletions samples/modules/tflite-micro/magic_wand/src/main_functions.cpp
@@ -21,14 +21,13 @@
#include "gesture_predictor.hpp"
#include "magic_wand_model_data.hpp"
#include "output_handler.hpp"
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/schema/schema_generated.h>

/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
tflite::ErrorReporter *error_reporter = nullptr;
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *model_input = nullptr;
@@ -45,22 +44,14 @@ namespace {
/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
/* Set up logging. Google style is to avoid globals or statics because of
* lifetime uncertainty, but since this has a trivial destructor it's okay.
*/
static tflite::MicroErrorReporter micro_error_reporter; /* NOLINT */

error_reporter = &micro_error_reporter;

/* Map the model into a usable data structure. This doesn't involve any
* copying or parsing, it's a very lightweight operation.
*/
model = tflite::GetModel(g_magic_wand_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
MicroPrintf("Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}

@@ -79,7 +70,7 @@ void setup(void)

/* Build an interpreter to run the model with. */
static tflite::MicroInterpreter static_interpreter(
model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
model, micro_op_resolver, tensor_arena, kTensorArenaSize);
interpreter = &static_interpreter;

/* Allocate memory from the tensor_arena for the model's tensors. */
@@ -91,24 +82,23 @@ void setup(void)
(model_input->dims->data[1] != 128) ||
(model_input->dims->data[2] != kChannelNumber) ||
(model_input->type != kTfLiteFloat32)) {
TF_LITE_REPORT_ERROR(error_reporter,
"Bad input tensor parameters in model");
MicroPrintf("Bad input tensor parameters in model");
return;
}

input_length = model_input->bytes / sizeof(float);

TfLiteStatus setup_status = SetupAccelerometer(error_reporter);
TfLiteStatus setup_status = SetupAccelerometer();
if (setup_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Set up failed\n");
MicroPrintf("Set up failed\n");
}
}

void loop(void)
{
/* Attempt to read new data from the accelerometer. */
bool got_data =
ReadAccelerometer(error_reporter, model_input->data.f, input_length);
ReadAccelerometer(model_input->data.f, input_length);

/* If there was no new data, wait until next time. */
if (!got_data) {
@@ -118,13 +108,12 @@ void loop(void)
/* Run inference, and report any error */
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on index: %d\n",
begin_index);
MicroPrintf("Invoke failed on index: %d\n", begin_index);
return;
}
/* Analyze the results to obtain a prediction */
int gesture_index = PredictGesture(interpreter->output(0)->data.f);

/* Produce an output */
HandleOutput(error_reporter, gesture_index);
HandleOutput(gesture_index);
}
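
Editor's note: MicroPrintf takes an ordinary printf-style format string, so the samples keep promoting float arguments to double for %f, as varargs require. A minimal sketch (names are illustrative; on Zephyr, %f output also depends on libc support, hence CONFIG_NEWLIB_LIBC_FLOAT_PRINTF in magic_wand's prj.conf):

#include <tensorflow/lite/micro/micro_log.h>

void LogGesture(int gesture_index, float score)
{
	/* The explicit cast mirrors the style used throughout these samples. */
	MicroPrintf("gesture: %d, score: %f", gesture_index,
		    static_cast<double>(score));
}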
11 changes: 4 additions & 7 deletions samples/modules/tflite-micro/magic_wand/src/output_handler.cpp
@@ -16,24 +16,21 @@

#include "output_handler.hpp"

void HandleOutput(tflite::ErrorReporter *error_reporter, int kind)
void HandleOutput(int kind)
{
/* light (red: wing, blue: ring, green: slope) */
if (kind == 0) {
TF_LITE_REPORT_ERROR(
error_reporter,
MicroPrintf(
"WING:\n\r* * *\n\r * * * "
"*\n\r * * * *\n\r * * * *\n\r * * "
"* *\n\r * *\n\r");
} else if (kind == 1) {
TF_LITE_REPORT_ERROR(
error_reporter,
MicroPrintf(
"RING:\n\r *\n\r * *\n\r * *\n\r "
" * *\n\r * *\n\r * *\n\r "
" *\n\r");
} else if (kind == 2) {
TF_LITE_REPORT_ERROR(
error_reporter,
MicroPrintf(
"SLOPE:\n\r *\n\r *\n\r *\n\r *\n\r "
"*\n\r *\n\r *\n\r * * * * * * * *\n\r");
}
samples/modules/tflite-micro/magic_wand/src/output_handler.hpp
@@ -18,8 +18,8 @@
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_OUTPUT_HANDLER_H_

#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>

void HandleOutput(tflite::ErrorReporter *error_reporter, int kind);
void HandleOutput(int kind);

#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_OUTPUT_HANDLER_H_ */
1 change: 1 addition & 0 deletions samples/modules/tflite-micro/tflm_ethosu/prj.conf
@@ -1,6 +1,7 @@
#application default configuration
# include TFLM based on CMSIS NN optimization and ETHOSU acceleration
CONFIG_CPP=y
CONFIG_STD_CPP17=y
CONFIG_TENSORFLOW_LITE_MICRO=y
CONFIG_ARM_ETHOS_U=y
CONFIG_HEAP_MEM_POOL_SIZE=16384
samples/modules/tflite-micro/tflm_ethosu/src/inference_process.cpp
@@ -6,9 +6,9 @@

#include "inference_process.hpp"

#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h>
#include <tensorflow/lite/micro/micro_error_reporter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_profiler.h>
#include <tensorflow/lite/schema/schema_generated.h>
@@ -118,11 +108,10 @@ bool InferenceProcess::runJob(InferenceJob &job)
}

/* Create the TFL micro interpreter */
tflite::AllOpsResolver resolver;
tflite::MicroErrorReporter errorReporter;
tflite::MicroMutableOpResolver <1> resolver;
resolver.AddEthosU();

tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize,
&errorReporter);
tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize);

/* Allocate tensors */
TfLiteStatus allocate_status = interpreter.AllocateTensors();
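
Editor's note: across all three samples, the MicroInterpreter constructor simply drops its trailing ErrorReporter* argument. An end-to-end sketch of the post-migration setup path, under stated assumptions (arena size, model_data, and the single Ethos-U kernel are illustrative):

#include <cstddef>
#include <cstdint>

#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_log.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/schema/schema_generated.h>

namespace {
constexpr size_t kTensorArenaSize = 16 * 1024; /* illustrative */
uint8_t tensor_arena[kTensorArenaSize];
}

TfLiteStatus RunOnce(const void *model_data)
{
	const tflite::Model *model = tflite::GetModel(model_data);

	if (model->version() != TFLITE_SCHEMA_VERSION) {
		MicroPrintf("Unsupported schema version %d", model->version());
		return kTfLiteError;
	}

	static tflite::MicroMutableOpResolver<1> resolver;
	resolver.AddEthosU();

	/* Four arguments now; no tflite::ErrorReporter at the end. */
	tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
					     kTensorArenaSize);

	if (interpreter.AllocateTensors() != kTfLiteOk) {
		MicroPrintf("AllocateTensors() failed");
		return kTfLiteError;
	}

	return interpreter.Invoke();
}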
