Skip to content

Commit

Permalink
coding style
Browse files Browse the repository at this point in the history
  • Loading branch information
uecker committed Sep 13, 2024
1 parent c83e544 commit dbe0f06
Showing 1 changed file with 82 additions and 42 deletions.
124 changes: 82 additions & 42 deletions src/networks/nlinvnet.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023. Institute of Biomedical Imaging. TU Graz.
/* Copyright 2023-2024. Institute of Biomedical Imaging. TU Graz.
* All rights reserved. Use of this source code is governed by
* a BSD-style license which can be found in the LICENSE file.
*
Expand Down Expand Up @@ -145,7 +145,7 @@ void nlinvnet_init(struct nlinvnet_s* nlinvnet, int N,
nlinvnet->iter_conf->tol = nlinvnet->cgtol;

if (NULL == get_loss_from_option())
nlinvnet->train_loss->weighting_mse = 1.;
nlinvnet->train_loss->weighting_mse = 1.;

if (NULL == get_val_loss_from_option())
nlinvnet->valid_loss = &loss_image_valid;
Expand All @@ -164,34 +164,34 @@ void nlinvnet_init(struct nlinvnet_s* nlinvnet, int N,

long tcol_dims[N];
md_copy_dims(N, tcol_dims, col_dims);

for (int i = 0; i < 3; i++)
if ((0 < nlinvnet->senssize) && 1 < tcol_dims[i])
if ((0 < nlinvnet->senssize) && (1 < tcol_dims[i]))
tcol_dims[i] = nlinvnet->senssize;

nlinvnet->model = noir2_net_config_create(N, trj_dims, wgh_dims, bas_dims, basis, NULL, NULL, ksp_dims, cim_dims, img_dims, tcol_dims, TIME_FLAG, &model_conf);
}

static nn_t nlinvnet_sort_args_F(nn_t net)
{

const char* data_names[] =
{
"ref",
"ksp",
"pat",
"trj",
"loss_mask",
"prev_frames"
};
const char* data_names[] = {
"ref",
"ksp",
"pat",
"trj",
"loss_mask",
"prev_frames"
};

int N = nn_get_nr_named_in_args(net);
const char* sorted_names[N + (int)ARRAY_SIZE(data_names) + 3];

nn_get_in_names_copy(N, sorted_names + ARRAY_SIZE(data_names) + 3, net);

for (int i = 0; i < (int)ARRAY_SIZE(data_names); i++)
sorted_names[i] = data_names[i];

sorted_names[ARRAY_SIZE(data_names)] = "lam";
sorted_names[ARRAY_SIZE(data_names) + 0] = "lam";
sorted_names[ARRAY_SIZE(data_names) + 1] = "lam_sens";
sorted_names[ARRAY_SIZE(data_names) + 2] = "alp";

Expand Down Expand Up @@ -287,6 +287,7 @@ static nn_t nlinvnet_network_create(const struct nlinvnet_s* nlinvnet, int N, co

network = nn_chain2_FF(network, 0, NULL, nn_from_linop_F(linop_slice_one_create(N, TIME_DIM, pos, nn_generic_codomain(network, 0, NULL)->dims)), 0, NULL);
network = nn_chain2_FF(network, 0, NULL, nn_from_linop_F(linop_reshape2_create(N, BATCH_FLAG | TIME_FLAG , _img_dims, nn_generic_codomain(network, 0, NULL)->dims)), 0, NULL);

} else {

network = network_create(nlinvnet->network, N, _img_dims, N, _img_dims, status);
Expand Down Expand Up @@ -328,12 +329,14 @@ static nn_t nlinvnet_get_network_step(const struct nlinvnet_s* nlinvnet, struct
for (int i = 0; i < N_in_names; i++) {

network = nn_append_singleton_dim_in_F(network, 0, in_names[i]);

xfree(in_names[i]);
}

for (int i = 0; i < N_out_names; i++) {

network = nn_append_singleton_dim_out_F(network, 0, out_names[i]);

xfree(out_names[i]);
}

Expand All @@ -345,6 +348,7 @@ static nn_t nlinvnet_get_network_step(const struct nlinvnet_s* nlinvnet, struct
network = nn_chain2_FF(network, 0, NULL, nn_from_linop_F(linop_slice_one_create(N, COIL_DIM, 0, img_dims)), 0, NULL);
network = nn_chain2_FF(nn_from_nlop_F(nlop_stack_create(N, img_dims, img_one_dims, img_one_dims, COIL_DIM)), 0, NULL, network, 0, NULL);
network = nn_set_input_name_F(network, 0, "ref_img");

} else {

auto dummy = nn_from_nlop_F(nlop_del_out_create(N, img_one_dims));
Expand All @@ -368,6 +372,7 @@ static nn_t nlinvnet_get_network_step(const struct nlinvnet_s* nlinvnet, struct
nn_shift = nn_set_input_name_F(nn_shift, 0, "ref_col");

join = nn_chain2_FF(nn_shift, 0, NULL, join, 1, NULL);

} else {

auto dom = nn_generic_domain(join, 1, NULL);
Expand Down Expand Up @@ -542,6 +547,7 @@ static nn_t nlinvnet_create(const struct nlinvnet_s* nlinvnet, struct noir2_net_

for (int i = 0; i < N_in_names; i++)
xfree(in_names[i]);

for (int i = 0; i < N_out_names; i++)
xfree(out_names[i]);
}
Expand Down Expand Up @@ -610,6 +616,7 @@ static nn_t nlinvnet_create(const struct nlinvnet_s* nlinvnet, struct noir2_net_
if (0 > nlinvnet->scaling) {

nlop_scale = nlop_norm_znorm_create(N, cim_dims, BATCH_FLAG);

} else {

complex float one[1] = { 1. };
Expand Down Expand Up @@ -646,13 +653,15 @@ static nn_t nlinvnet_create(const struct nlinvnet_s* nlinvnet, struct noir2_net_
result = nn_chain2_swap_FF(nn_from_nlop_F(nlop_adj), 0, NULL , result, 0, NULL);
result = nn_set_input_name_F(result, 0, "ksp");
result = nn_set_input_name_F(result, 0, "pat");

if (nlinvnet->conf->noncart)
result = nn_set_input_name_F(result, 0, "trj");

// normalize output
auto cod = nn_generic_codomain(result, 0, NULL);
long cdims[2] = { cod->dims[0] * cod->dims[1] / sdims[BATCH_DIM], sdims[BATCH_DIM]};
long tdims[2] = { cod->dims[0], cod->dims[1]};

result = nn_reshape_out_F(result, 0, NULL, 2, cdims);
result = nn_chain2_FF(result, 0, NULL, nn_from_nlop_F(nlop_tenmul_create(2, cdims, cdims, (long[2]){ 1, sdims[BATCH_DIM] })), 0, NULL);
result = nn_link_F(result, 0, "scale_sqrt", 0, NULL);
Expand Down Expand Up @@ -688,6 +697,7 @@ static nn_t nlinvnet_apply_op_create(const struct nlinvnet_s* nlinvnet, int Nb)
N_weights++;

complex float zero[1] = { 0 };

if (nlinvnet->weights->N + 1 == N_weights)
nn_apply = nn_set_input_const_F(nn_apply, 0, "lam_sens", 1, MD_DIMS(1), true, zero);

Expand Down Expand Up @@ -738,11 +748,13 @@ static nn_t nlinvnet_train_loss_create(const struct nlinvnet_s* nlinvnet, int Nb
long time = nlop_generic_domain(nlop_reg, 1)->dims[TIME_DIM];

int N = nlop_generic_domain(nlop_reg, 1)->N;

long tdims[N];
md_singleton_dims(N, tdims);
tdims[TIME_DIM] = time;

complex float mask[time];

for (int i = 0; i < time; i++) {

if ((i >= nlinvnet->time_mask[0]) && ((-1 == nlinvnet->time_mask[1]) || (i < nlinvnet->time_mask[1])))
Expand Down Expand Up @@ -812,6 +824,7 @@ static nn_t nlinvnet_train_loss_create(const struct nlinvnet_s* nlinvnet, int Nb
long time = tdims[TIME_DIM];

complex float mask[time];

for (int i = 0; i < time; i++) {

if ((i >= nlinvnet->time_mask[0]) && ((-1 == nlinvnet->time_mask[1]) || (i < nlinvnet->time_mask[1])))
Expand Down Expand Up @@ -856,6 +869,7 @@ void train_nlinvnet(struct nlinvnet_s* nlinvnet, int Nb, struct named_data_list_
{
auto ref_iov = named_data_list_get_iovec(train_data, "ref");
long Nt = ref_iov->dims[BATCH_DIM];

iovec_free(ref_iov);

Nb = MIN(Nb, Nt);
Expand Down Expand Up @@ -892,16 +906,22 @@ void train_nlinvnet(struct nlinvnet_s* nlinvnet, int Nb, struct named_data_list_
long pat_dims[N];
md_copy_dims(N, pat_dims, dom->dims);

const complex float* use_reco = NULL;
long use_reco_dims[DIMS];
unsigned long use_reco_nontriv = 0UL;

if (NULL != nlinvnet->use_reco_file) {

long use_reco_dims[DIMS];
const complex float* use_reco = load_cfl(nlinvnet->use_reco_file, DIMS, use_reco_dims);
nlop_rand_split = nlop_rand_split_fixed_create(N, pat_dims, nlinvnet->ksp_shared_dims, BATCH_FLAG | TIME_FLAG, nlinvnet->ksp_split, md_nontriv_dims(DIMS, use_reco_dims), use_reco, nlinvnet->ksp_leaky);
unmap_cfl(DIMS, use_reco_dims, use_reco);
} else {
use_reco = load_cfl(nlinvnet->use_reco_file, DIMS, use_reco_dims);

nlop_rand_split = nlop_rand_split_fixed_create(N, pat_dims, nlinvnet->ksp_shared_dims, BATCH_FLAG | TIME_FLAG, nlinvnet->ksp_split, 0, NULL, nlinvnet->ksp_leaky);
use_reco_nontriv = md_nontriv_dims(DIMS, use_reco_dims);
}

nlop_rand_split = nlop_rand_split_fixed_create(N, pat_dims, nlinvnet->ksp_shared_dims, BATCH_FLAG | TIME_FLAG,
nlinvnet->ksp_split, use_reco_nontriv, use_reco, nlinvnet->ksp_leaky);

if (NULL != use_reco)
unmap_cfl(DIMS, use_reco_dims, use_reco);

auto split_op = nn_from_nlop_F(nlop_rand_split);
split_op = nn_set_output_name_F(split_op, 0, "pat_trn");
Expand All @@ -921,11 +941,15 @@ void train_nlinvnet(struct nlinvnet_s* nlinvnet, int Nb, struct named_data_list_

nlinvnet->weights = nn_weights_create_from_nn(nn_train);
nn_init(nn_train, nlinvnet->weights);

} else {

auto tmp_weights = nn_weights_create_from_nn(nn_train);

nn_weights_copy(tmp_weights, nlinvnet->weights);

nn_weights_free(nlinvnet->weights);

nlinvnet->weights = tmp_weights;
}

Expand Down Expand Up @@ -966,39 +990,44 @@ void train_nlinvnet(struct nlinvnet_s* nlinvnet, int Nb, struct named_data_list_

switch (in_type[i]) {

case IN_BATCH_GENERATOR:

src[i] = NULL;
break;

case IN_BATCH:
case IN_UNDEFINED:
error("Intype of arg %d not supported!\n", i);
break;

case IN_OPTIMIZE:
case IN_STATIC:
case IN_BATCHNORM:
{
auto iov_weight = nlinvnet->weights->iovs[weight_index];
auto iov_train_op = nlop_generic_domain(nn_get_nlop(nn_train), i);
assert(md_check_equal_dims(iov_weight->N, iov_weight->dims, iov_train_op->dims, ~0UL));
src[i] = (float*)nlinvnet->weights->tensors[weight_index];
weight_index++;
}
case IN_BATCH_GENERATOR:

src[i] = NULL;
break;

case IN_BATCH:
case IN_UNDEFINED:
error("Intype of arg %d not supported!\n", i);
break;

case IN_OPTIMIZE:
case IN_STATIC:
case IN_BATCHNORM:

auto iov_weight = nlinvnet->weights->iovs[weight_index];
auto iov_train_op = nlop_generic_domain(nn_get_nlop(nn_train), i);
assert(md_check_equal_dims(iov_weight->N, iov_weight->dims, iov_train_op->dims, ~0UL));
src[i] = (float*)nlinvnet->weights->tensors[weight_index];
weight_index++;
break;
}
}

int num_monitors = 0;
const struct monitor_value_s* value_monitors[3];

if (NULL != valid_data) {

auto nn_validation_loss = nlinvnet_valid_create(nlinvnet, valid_data);

const char* val_names[nn_get_nr_out_args(nn_validation_loss)];

for (int i = 0; i < nn_get_nr_out_args(nn_validation_loss); i++)
val_names[i] = nn_get_out_name_from_arg_index(nn_validation_loss, i, false);

value_monitors[num_monitors] = monitor_iter6_nlop_create(nn_get_nlop(nn_validation_loss), false, nn_get_nr_out_args(nn_validation_loss), val_names);
nn_free(nn_validation_loss);

num_monitors += 1;
}

Expand All @@ -1014,16 +1043,20 @@ void train_nlinvnet(struct nlinvnet_s* nlinvnet, int Nb, struct named_data_list_
lams[i] = lam;

auto destack_lambda = nlop_from_linop_F(linop_identity_create(2, MD_DIMS(1, num_lambda)));

for (int i = num_lambda - 1; 0 < i; i--)
destack_lambda = nlop_chain2_FF(destack_lambda, 0, nlop_destack_create(2, MD_DIMS(1, i), MD_DIMS(1, 1), MD_DIMS(1, i + 1), 1), 0);

for(int i = 0; i < index_lambda; i++)
destack_lambda = nlop_combine_FF(nlop_del_out_create(1, MD_DIMS(1)), destack_lambda);

for(int i = index_lambda + 1; i < NI; i++)
destack_lambda = nlop_combine_FF(destack_lambda, nlop_del_out_create(1, MD_DIMS(1)));

value_monitors[num_monitors] = monitor_iter6_nlop_create(destack_lambda, true, num_lambda, lams);

nlop_free(destack_lambda);

num_monitors += 1;
}

Expand All @@ -1039,16 +1072,20 @@ void train_nlinvnet(struct nlinvnet_s* nlinvnet, int Nb, struct named_data_list_
lams[i] = lam;

auto destack_lambda = nlop_from_linop_F(linop_identity_create(2, MD_DIMS(1, num_lambda)));

for (int i = num_lambda - 1; 0 < i; i--)
destack_lambda = nlop_chain2_FF(destack_lambda, 0, nlop_destack_create(2, MD_DIMS(1, i), MD_DIMS(1, 1), MD_DIMS(1, i + 1), 1), 0);

for(int i = 0; i < index_lambda; i++)
destack_lambda = nlop_combine_FF(nlop_del_out_create(1, MD_DIMS(1)), destack_lambda);

for(int i = index_lambda + 1; i < NI; i++)
destack_lambda = nlop_combine_FF(destack_lambda, nlop_del_out_create(1, MD_DIMS(1)));

value_monitors[num_monitors] = monitor_iter6_nlop_create(destack_lambda, true, num_lambda, lams);

nlop_free(destack_lambda);

num_monitors += 1;
}

Expand Down Expand Up @@ -1089,7 +1126,8 @@ void apply_nlinvnet(struct nlinvnet_s* nlinvnet, int N,
const struct nlop_s* nlop_apply = nlop_optimize_graph(nlop_clone(nn_apply->nlop));

nn_debug(DP_INFO, nn_apply);
unsigned long batch_flags = md_nontriv_dims(N, img_dims) & (~(md_nontriv_dims(N, nn_generic_codomain(nn_apply, 0, "img")->dims)));
unsigned long batch_flags = md_nontriv_dims(N, img_dims)
& ~md_nontriv_dims(N, nn_generic_codomain(nn_apply, 0, "img")->dims);

nn_free(nn_apply);

Expand All @@ -1102,10 +1140,12 @@ void apply_nlinvnet(struct nlinvnet_s* nlinvnet, int N,

long col_dims2[N];
md_select_dims(N, ~COIL_FLAG, col_dims2, col_dims);

complex float* tmp = md_alloc_sameplace(N, col_dims2, CFL_SIZE, img);
md_zrss(N, col_dims, COIL_FLAG, tmp, col);

md_zrss(N, col_dims, COIL_FLAG, tmp, col);
md_zmul2(N, img_dims, MD_STRIDES(N, img_dims, CFL_SIZE), img, MD_STRIDES(N, img_dims, CFL_SIZE), img, MD_STRIDES(N, col_dims2, CFL_SIZE), tmp);

md_free(tmp);
}
}
Expand Down

0 comments on commit dbe0f06

Please sign in to comment.