Skip to content

Commit

Permalink
[#66550] model: API improvements
Browse files Browse the repository at this point in the history
Signed-off-by: Mikolaj Klikowicz <mklikowicz@antmicro.com>
  • Loading branch information
mikolaj-klikowicz authored and glatosinski committed Nov 29, 2024
1 parent 0ca31d7 commit 0ff8479
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 53 deletions.
35 changes: 3 additions & 32 deletions demo_app/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,6 @@ void postprocess_output(uint8_t *data_in, float *data_out, size_t model_output_s
*/
void format_output(uint8_t *buffer, const size_t buffer_size, float *model_output);

/**
* Initialize main loader table
*/
status_t prepare_main_ldr_table()
{
static struct msg_loader msg_loader_iospec = MSG_LOADER_BUF((uint8_t *)(&g_model_struct), sizeof(MlModel));
g_ldr_tables[0][LOADER_TYPE_IOSPEC] = &msg_loader_iospec;
return STATUS_OK;
}

int main(void)
{
status_t status = STATUS_OK;
Expand All @@ -88,8 +78,6 @@ int main(void)
int64_t timer_start = 0;
int64_t timer_end = 0;

prepare_main_ldr_table();

do
{
// initialize model
Expand All @@ -100,25 +88,12 @@ int main(void)
break;
}

// retrieve loaders
struct msg_loader *msg_loader_model = g_ldr_tables[1][LOADER_TYPE_MODEL];
struct msg_loader *msg_loader_data = g_ldr_tables[1][LOADER_TYPE_DATA];
struct msg_loader *msg_loader_iospec = g_ldr_tables[0][LOADER_TYPE_IOSPEC];

// load model structure
msg_loader_iospec->reset(msg_loader_iospec, 0);
status = msg_loader_iospec->save(msg_loader_iospec, (uint8_t *)(&model_struct), sizeof(MlModel));
BREAK_ON_ERROR_LOG(status, "iospec loader failed: %d", status);

status = model_load_struct();
status = model_load_struct((uint8_t *)&model_struct, sizeof(MlModel));
BREAK_ON_ERROR_LOG(status, "Model struct load error 0x%x (%s)", status, get_status_str(status));

// load model weights
msg_loader_model->reset(msg_loader_model, 0);
status = msg_loader_model->save(msg_loader_model, (uint8_t *)model_data, model_data_len);
BREAK_ON_ERROR_LOG(status, "Model loader failed: %d", status);

status = model_load_weights();
status = model_load_weights(model_data, model_data_len);
BREAK_ON_ERROR_LOG(status, "Model weights load error 0x%x (%s)", status, get_status_str(status));

// allocate buffer for input
Expand All @@ -135,11 +110,7 @@ int main(void)
{
preprocess_input((float *)data[batch_index], model_input, model_input_size);

msg_loader_data->reset(msg_loader_data, 0);
status = msg_loader_data->save(msg_loader_data, model_input, model_input_size);
BREAK_ON_ERROR_LOG(status, "Data loader failed: %d", status);

status = model_load_input();
status = model_load_input(model_input, model_input_size);
BREAK_ON_ERROR_LOG(status, "Model input load error 0x%x (%s)", status, get_status_str(status));

status = model_run();
Expand Down
12 changes: 9 additions & 3 deletions include/kenning_inference_lib/core/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ status_t model_init();
*
* @returns status of the model
*/
status_t model_load_struct();
status_t model_load_struct(const uint8_t *model_struct_data, const size_t data_size);

/**
* Loads model weights from given buffer
Expand All @@ -67,7 +67,7 @@ status_t model_load_struct();
*
* @returns status of the model
*/
status_t model_load_weights();
status_t model_load_weights(const uint8_t *model_weights_data, const size_t data_size);

/**
* Calculates model input size based on data from model struct
Expand All @@ -86,7 +86,13 @@ status_t model_get_input_size(size_t *model_input_size);
*
* @returns status of the model
*/
status_t model_load_input();
status_t model_load_input(const uint8_t *model_input, const size_t model_input_size);

status_t model_load_struct_from_loader();

status_t model_load_weights_from_loader();

status_t model_load_input_from_loader();

/**
* Runs model inference
Expand Down
6 changes: 3 additions & 3 deletions lib/kenning_inference_lib/core/callbacks.c
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ status_t data_callback(message_hdr_t *hdr, resp_message_t *resp) // TODO

VALIDATE_HEADER(MESSAGE_TYPE_DATA, hdr);

status = model_load_input();
status = model_load_input_from_loader();

CHECK_STATUS_LOG(status, resp, "model_load_input returned 0x%x (%s)", status, get_status_str(status));

Expand All @@ -126,7 +126,7 @@ status_t model_callback(message_hdr_t *hdr, resp_message_t *resp)

VALIDATE_HEADER(MESSAGE_TYPE_MODEL, hdr);

status = model_load_weights();
status = model_load_weights_from_loader();

CHECK_STATUS_LOG(status, resp, "model_load_weights returned 0x%x (%s)", status, get_status_str(status));

Expand Down Expand Up @@ -222,7 +222,7 @@ status_t iospec_callback(message_hdr_t *hdr, resp_message_t *resp)

VALIDATE_HEADER(MESSAGE_TYPE_IOSPEC, hdr);

status = model_load_struct();
status = model_load_struct_from_loader();

CHECK_STATUS_LOG(status, resp, "model_load_struct returned 0x%x (%s)", status, get_status_str(status));

Expand Down
17 changes: 5 additions & 12 deletions lib/kenning_inference_lib/core/inference_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,19 @@ int reset_runtime_alloc(struct msg_loader *ldr, size_t n)
return 0;
}

#endif // defined(CONFIG_LLEXT)

status_t prepare_main_ldr_table()
status_t prepare_llext_loader()
{
#if defined(CONFIG_LLEXT)
static struct msg_loader msg_loader_llext = {.save = buf_save,
.save_one = buf_save_one,
.reset = reset_runtime_alloc,
.written = 0,
.max_size = 0,
.addr = NULL};
g_ldr_tables[0][LOADER_TYPE_RUNTIME] = &msg_loader_llext;
#endif

static struct msg_loader msg_loader_iospec = MSG_LOADER_BUF((uint8_t *)(&g_model_struct), sizeof(MlModel));
g_ldr_tables[0][LOADER_TYPE_IOSPEC] = &msg_loader_iospec;
return STATUS_OK;
}

#endif // defined(CONFIG_LLEXT)

status_t init_server()
{
status_t status = STATUS_OK;
Expand All @@ -90,10 +84,9 @@ status_t init_server()
#if !defined(CONFIG_LLEXT)
status = model_init();
CHECK_INIT_STATUS_RET(status, "model_init returned 0x%x (%s)", status, get_status_str(status));
#else
prepare_llext_loader();
#endif

prepare_main_ldr_table();

LOG_INF("Inference server started");
return STATUS_OK;
}
Expand Down
51 changes: 48 additions & 3 deletions lib/kenning_inference_lib/core/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,30 @@ MODEL_STATE model_get_state() { return g_model_state; }

void model_reset_state() { g_model_state = MODEL_STATE_UNINITIALIZED; }

status_t prepare_iospec_loader()
{
static struct msg_loader msg_loader_iospec = MSG_LOADER_BUF((uint8_t *)(&g_model_struct), sizeof(MlModel));
g_ldr_tables[0][LOADER_TYPE_IOSPEC] = &msg_loader_iospec;
return STATUS_OK;
}

status_t model_init()
{
status_t status = STATUS_OK;

status = runtime_init();
RETURN_ON_ERROR(status, status);

if (STATUS_OK == status)
{
g_model_state = MODEL_STATE_INITIALIZED;
}

status = prepare_iospec_loader();
return status;
}

status_t model_load_struct()
status_t model_load_struct_from_loader()
{
status_t status = STATUS_OK;

Expand Down Expand Up @@ -122,7 +131,7 @@ status_t model_load_struct()
return status;
}

status_t model_load_weights()
status_t model_load_weights_from_loader()
{
status_t status = STATUS_OK;

Expand Down Expand Up @@ -164,7 +173,7 @@ status_t model_get_input_size(size_t *model_input_size)
return status;
}

status_t model_load_input()
status_t model_load_input_from_loader()
{
status_t status = STATUS_OK;

Expand All @@ -189,6 +198,42 @@ status_t model_load_input()
return status;
}

status_t model_load_weights(const uint8_t *model_weights_data, const size_t data_size)
{
status_t status = STATUS_OK;
struct msg_loader *msg_loader_model = g_ldr_tables[1][LOADER_TYPE_MODEL];

msg_loader_model->reset(msg_loader_model, 0);
status = msg_loader_model->save(msg_loader_model, (uint8_t *)model_weights_data, data_size);
RETURN_ON_ERROR_LOG(status, status, "Model loader failed: %d", status);

return model_load_weights_from_loader();
}

status_t model_load_struct(const uint8_t *model_struct_data, const size_t data_size)
{
status_t status = STATUS_OK;
struct msg_loader *msg_loader_iospec = g_ldr_tables[0][LOADER_TYPE_IOSPEC];

msg_loader_iospec->reset(msg_loader_iospec, 0);
status = msg_loader_iospec->save(msg_loader_iospec, model_struct_data, data_size);
RETURN_ON_ERROR_LOG(status, status, "iospec loader failed: %d", status);

return model_load_struct_from_loader();
}

status_t model_load_input(const uint8_t *model_input, const size_t model_input_size)
{
status_t status = STATUS_OK;
struct msg_loader *msg_loader_data = g_ldr_tables[1][LOADER_TYPE_DATA];

msg_loader_data->reset(msg_loader_data, 0);
status = msg_loader_data->save(msg_loader_data, model_input, model_input_size);
RETURN_ON_ERROR_LOG(status, status, "Data loader failed: %d", status);

return model_load_input_from_loader();
}

status_t model_run()
{
status_t status = STATUS_OK;
Expand Down

0 comments on commit 0ff8479

Please sign in to comment.