samples: tflite-micro: update samples to use MicroPrintf
tflite-micro now uses MicroPrintf instead of MicroErrorReporter. Update
the samples to use this function instead.

Signed-off-by: Ryan McClelland <ryanmcclelland@meta.com>
XenuIsWatching committed Apr 19, 2023
1 parent 01e4a21 commit 4abef35
Showing 9 changed files with 44 additions and 76 deletions.
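For context, upstream tflite-micro removed the ErrorReporter plumbing in favor of a free, printf-style logging function. A minimal sketch of the old and new styles, using only identifiers that appear in this diff (report_old/report_new are illustrative names, not part of the samples):

/* Old style: a MicroErrorReporter object is created once and threaded
 * through every call site that wants to log.
 */
#include <tensorflow/lite/micro/micro_error_reporter.h>

void report_old(void)
{
	static tflite::MicroErrorReporter micro_error_reporter;
	tflite::ErrorReporter *error_reporter = &micro_error_reporter;

	TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
}

/* New style: MicroPrintf() from micro_log.h is a free function, so no
 * reporter object or extra parameter is needed.
 */
#include "tensorflow/lite/micro/micro_log.h"

void report_new(void)
{
	MicroPrintf("AllocateTensors() failed");
}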
27 changes: 8 additions & 19 deletions samples/modules/tflite-micro/hello_world/src/main_functions.cpp
@@ -20,14 +20,13 @@
#include "constants.h"
#include "model.hpp"
#include "output_handler.hpp"
-#include <tensorflow/lite/micro/micro_error_reporter.h>
+#include "tensorflow/lite/micro/micro_log.h"
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/system_setup.h>
#include <tensorflow/lite/schema/schema_generated.h>

/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
-tflite::ErrorReporter *error_reporter = nullptr;
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *input = nullptr;
@@ -41,23 +40,14 @@ namespace {
/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
-/* Set up logging. Google style is to avoid globals or statics because of
- * lifetime uncertainty, but since this has a trivial destructor it's okay.
- * NOLINTNEXTLINE(runtime-global-variables)
- */
-static tflite::MicroErrorReporter micro_error_reporter;
-
-error_reporter = &micro_error_reporter;
-
/* Map the model into a usable data structure. This doesn't involve any
* copying or parsing, it's a very lightweight operation.
*/
model = tflite::GetModel(g_model);
if (model->version() != TFLITE_SCHEMA_VERSION) {
-TF_LITE_REPORT_ERROR(error_reporter,
-"Model provided is schema version %d not equal "
-"to supported version %d.",
-model->version(), TFLITE_SCHEMA_VERSION);
+MicroPrintf("Model provided is schema version %d not equal "
+"to supported version %d.",
+model->version(), TFLITE_SCHEMA_VERSION);
return;
}

@@ -68,13 +58,13 @@ void setup(void)

/* Build an interpreter to run the model with. */
static tflite::MicroInterpreter static_interpreter(
-model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
+model, resolver, tensor_arena, kTensorArenaSize);
interpreter = &static_interpreter;

/* Allocate memory from the tensor_arena for the model's tensors. */
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
MicroPrintf("AllocateTensors() failed");
return;
}

@@ -106,8 +96,7 @@ void loop(void)
/* Run inference, and report any error */
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x: %f\n",
static_cast < double > (x));
MicroPrintf("Invoke failed on x: %f\n", static_cast < double > (x));
return;
}

@@ -119,7 +108,7 @@
/* Output the results. A custom HandleOutput function can be implemented
* for each supported hardware target.
*/
-HandleOutput(error_reporter, x, y);
+HandleOutput(x, y);

/* Increment the inference_counter, and reset it if we have reached
* the total number per cycle
samples/modules/tflite-micro/hello_world/src/output_handler.cpp
@@ -16,11 +16,10 @@

#include "output_handler.hpp"

-void HandleOutput(tflite::ErrorReporter *error_reporter, float x_value,
-float y_value)
+void HandleOutput(float x_value, float y_value)
{
/* Log the current X and Y values */
TF_LITE_REPORT_ERROR(error_reporter, "x_value: %f, y_value: %f\n",
static_cast < double > (x_value),
static_cast < double > (y_value));
MicroPrintf("x_value: %f, y_value: %f\n",
static_cast < double > (x_value),
static_cast < double > (y_value));
}
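The static_cast < double > casts survive the migration for a reason: arguments to a variadic, printf-style call undergo default argument promotion, so a float is passed as a double and %f always consumes a double. The explicit cast just makes that promotion visible and keeps warnings such as -Wdouble-promotion quiet. A self-contained illustration (log_xy is a hypothetical name, not part of the samples):

#include "tensorflow/lite/micro/micro_log.h"

void log_xy(float x_value, float y_value)
{
	/* float promotes to double in a variadic call; the cast makes it explicit */
	MicroPrintf("x_value: %f, y_value: %f\n",
		    static_cast<double>(x_value),
		    static_cast<double>(y_value));
}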
samples/modules/tflite-micro/hello_world/src/output_handler.hpp
@@ -18,10 +18,9 @@
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_

#include <tensorflow/lite/c/common.h>
-#include <tensorflow/lite/micro/micro_error_reporter.h>
+#include "tensorflow/lite/micro/micro_log.h"

/* Called by the main loop to produce some output based on the x and y values */
-void HandleOutput(tflite::ErrorReporter *error_reporter, float x_value,
-float y_value);
+void HandleOutput(float x_value, float y_value);

#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_ */
samples/modules/tflite-micro/magic_wand/src/accelerometer_handler.cpp
@@ -33,34 +33,32 @@ float bufz[BUFLEN] = { 0.0f };

bool initial = true;

-TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter)
+TfLiteStatus SetupAccelerometer()
{
if (!device_is_ready(sensor)) {
printk("%s: device not ready.\n", sensor->name);
return kTfLiteApplicationError;
}

if (sensor == NULL) {
-TF_LITE_REPORT_ERROR(error_reporter,
-"Failed to get accelerometer, name: %s\n",
-sensor->name);
+MicroPrintf("Failed to get accelerometer, name: %s\n",
+sensor->name);
} else {
TF_LITE_REPORT_ERROR(error_reporter, "Got accelerometer, name: %s\n",
sensor->name);
MicroPrintf("Got accelerometer, name: %s\n",
sensor->name);
}
return kTfLiteOk;
}

-bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
-int length)
+bool ReadAccelerometer(float *input, int length)
{
int rc;
struct sensor_value accel[3];
int samples_count;

rc = sensor_sample_fetch(sensor);
if (rc < 0) {
TF_LITE_REPORT_ERROR(error_reporter, "Fetch failed\n");
MicroPrintf("Fetch failed\n");
return false;
}
/* Skip if there is no data */
@@ -72,7 +70,7 @@ bool ReadAccelerometer(tflite::ErrorReporter *error_reporter, float *input,
for (int i = 0; i < samples_count; i++) {
rc = sensor_channel_get(sensor, SENSOR_CHAN_ACCEL_XYZ, accel);
if (rc < 0) {
TF_LITE_REPORT_ERROR(error_reporter, "ERROR: Update failed: %d\n", rc);
MicroPrintf("ERROR: Update failed: %d\n", rc);
return false;
}
bufx[begin_index] = (float)sensor_value_to_double(&accel[0]);
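The handler above follows the standard Zephyr sensor flow: trigger a sample fetch, then read the accelerometer channel. A condensed sketch of that flow with the new logging (fetch_accel_sample is a hypothetical helper; dev stands for the handler's sensor device reference):

#include <zephyr/device.h>
#include <zephyr/drivers/sensor.h>
#include "tensorflow/lite/micro/micro_log.h"

static bool fetch_accel_sample(const struct device *dev,
			       struct sensor_value accel[3])
{
	/* Trigger a new reading, then pull X/Y/Z off the accelerometer channel */
	if (sensor_sample_fetch(dev) < 0) {
		MicroPrintf("Fetch failed\n");
		return false;
	}
	return sensor_channel_get(dev, SENSOR_CHAN_ACCEL_XYZ, accel) == 0;
}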
samples/modules/tflite-micro/magic_wand/src/accelerometer_handler.hpp
@@ -20,11 +20,10 @@
#define kChannelNumber 3

#include <tensorflow/lite/c/c_api_types.h>
-#include <tensorflow/lite/micro/micro_error_reporter.h>
+#include "tensorflow/lite/micro/micro_log.h"

extern int begin_index;
-extern TfLiteStatus SetupAccelerometer(tflite::ErrorReporter *error_reporter);
-extern bool ReadAccelerometer(tflite::ErrorReporter *error_reporter,
-float *input, int length);
+extern TfLiteStatus SetupAccelerometer();
+extern bool ReadAccelerometer(float *input, int length);

#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_ACCELEROMETER_HANDLER_H_ */
33 changes: 11 additions & 22 deletions samples/modules/tflite-micro/magic_wand/src/main_functions.cpp
@@ -21,14 +21,13 @@
#include "gesture_predictor.hpp"
#include "magic_wand_model_data.hpp"
#include "output_handler.hpp"
-#include <tensorflow/lite/micro/micro_error_reporter.h>
+#include "tensorflow/lite/micro/micro_log.h"
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <tensorflow/lite/schema/schema_generated.h>

/* Globals, used for compatibility with Arduino-style sketches. */
namespace {
-tflite::ErrorReporter *error_reporter = nullptr;
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *model_input = nullptr;
@@ -45,22 +44,14 @@ namespace {
/* The name of this function is important for Arduino compatibility. */
void setup(void)
{
-/* Set up logging. Google style is to avoid globals or statics because of
- * lifetime uncertainty, but since this has a trivial destructor it's okay.
- */
-static tflite::MicroErrorReporter micro_error_reporter; /* NOLINT */
-
-error_reporter = &micro_error_reporter;
-
/* Map the model into a usable data structure. This doesn't involve any
* copying or parsing, it's a very lightweight operation.
*/
model = tflite::GetModel(g_magic_wand_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
-TF_LITE_REPORT_ERROR(error_reporter,
-"Model provided is schema version %d not equal "
-"to supported version %d.",
-model->version(), TFLITE_SCHEMA_VERSION);
+MicroPrintf("Model provided is schema version %d not equal "
+"to supported version %d.",
+model->version(), TFLITE_SCHEMA_VERSION);
return;
}

@@ -79,7 +70,7 @@

/* Build an interpreter to run the model with. */
static tflite::MicroInterpreter static_interpreter(
-model, micro_op_resolver, tensor_arena, kTensorArenaSize, error_reporter);
+model, micro_op_resolver, tensor_arena, kTensorArenaSize);
interpreter = &static_interpreter;

/* Allocate memory from the tensor_arena for the model's tensors. */
@@ -91,24 +82,23 @@
(model_input->dims->data[1] != 128) ||
(model_input->dims->data[2] != kChannelNumber) ||
(model_input->type != kTfLiteFloat32)) {
-TF_LITE_REPORT_ERROR(error_reporter,
-"Bad input tensor parameters in model");
+MicroPrintf("Bad input tensor parameters in model");
return;
}

input_length = model_input->bytes / sizeof(float);

-TfLiteStatus setup_status = SetupAccelerometer(error_reporter);
+TfLiteStatus setup_status = SetupAccelerometer();
if (setup_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Set up failed\n");
MicroPrintf("Set up failed\n");
}
}

void loop(void)
{
/* Attempt to read new data from the accelerometer. */
bool got_data =
-ReadAccelerometer(error_reporter, model_input->data.f, input_length);
+ReadAccelerometer(model_input->data.f, input_length);

/* If there was no new data, wait until next time. */
if (!got_data) {
@@ -118,13 +108,12 @@
/* Run inference, and report any error */
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on index: %d\n",
begin_index);
MicroPrintf("Invoke failed on index: %d\n", begin_index);
return;
}
/* Analyze the results to obtain a prediction */
int gesture_index = PredictGesture(interpreter->output(0)->data.f);

/* Produce an output */
-HandleOutput(error_reporter, gesture_index);
+HandleOutput(gesture_index);
}
11 changes: 4 additions & 7 deletions samples/modules/tflite-micro/magic_wand/src/output_handler.cpp
@@ -16,24 +16,21 @@

#include "output_handler.hpp"

-void HandleOutput(tflite::ErrorReporter *error_reporter, int kind)
+void HandleOutput(int kind)
{
/* light (red: wing, blue: ring, green: slope) */
if (kind == 0) {
-TF_LITE_REPORT_ERROR(
-error_reporter,
+MicroPrintf(
"WING:\n\r* * *\n\r * * * "
"*\n\r * * * *\n\r * * * *\n\r * * "
"* *\n\r * *\n\r");
} else if (kind == 1) {
-TF_LITE_REPORT_ERROR(
-error_reporter,
+MicroPrintf(
"RING:\n\r *\n\r * *\n\r * *\n\r "
" * *\n\r * *\n\r * *\n\r "
" *\n\r");
} else if (kind == 2) {
-TF_LITE_REPORT_ERROR(
-error_reporter,
+MicroPrintf(
"SLOPE:\n\r *\n\r *\n\r *\n\r *\n\r "
"*\n\r *\n\r *\n\r * * * * * * * *\n\r");
}
samples/modules/tflite-micro/magic_wand/src/output_handler.hpp
@@ -18,8 +18,8 @@
#define TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_OUTPUT_HANDLER_H_

#include <tensorflow/lite/c/common.h>
-#include <tensorflow/lite/micro/micro_error_reporter.h>
+#include "tensorflow/lite/micro/micro_log.h"

-void HandleOutput(tflite::ErrorReporter *error_reporter, int kind);
+void HandleOutput(int kind);

#endif /* TENSORFLOW_LITE_MICRO_EXAMPLES_MAGIC_WAND_OUTPUT_HANDLER_H_ */
samples/modules/tflite-micro/tflm_ethosu/src/inference_process.cpp
@@ -8,7 +8,7 @@

#include <tensorflow/lite/micro/all_ops_resolver.h>
#include <tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h>
-#include <tensorflow/lite/micro/micro_error_reporter.h>
+#include "tensorflow/lite/micro/micro_log.h"
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_profiler.h>
#include <tensorflow/lite/schema/schema_generated.h>
@@ -119,10 +119,8 @@ bool InferenceProcess::runJob(InferenceJob &job)

/* Create the TFL micro interpreter */
tflite::AllOpsResolver resolver;
-tflite::MicroErrorReporter errorReporter;
-
-tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize,
-&errorReporter);
+tflite::MicroInterpreter interpreter(model, resolver, tensorArena, tensorArenaSize);

/* Allocate tensors */
TfLiteStatus allocate_status = interpreter.AllocateTensors();
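MicroPrintf() ultimately calls DebugLog(), and the debug_log_callback.h header included above lets an application choose where that output goes. A sketch of routing it to Zephyr's printk (the wiring is illustrative; this assumes the application registers the callback during init, which the diff does not show):

#include <zephyr/sys/printk.h>
#include <tensorflow/lite/micro/cortex_m_generic/debug_log_callback.h>

static void tflm_log(const char *s)
{
	printk("%s", s);
}

void init_tflm_logging(void)
{
	/* All subsequent MicroPrintf() output is delivered to tflm_log() */
	RegisterDebugLogCallback(tflm_log);
}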
