Skip to content

Commit

Permalink
fix(lib): Lower tensor memory usage (#228)
Browse files Browse the repository at this point in the history
  • Loading branch information
isair authored Mar 25, 2021
1 parent 80ef53c commit 0399060
Showing 1 changed file with 29 additions and 25 deletions.
54 changes: 29 additions & 25 deletions src/loadCsv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,33 +76,37 @@ const loadCsv = (
);
}

let features = tf.tensor(tables.features);
let testFeatures = tf.tensor(tables.testFeatures);

const labels = tf.tensor(tables.labels);
const testLabels = tf.tensor(tables.testLabels);

if (columnsToStandardise.length > 0) {
const result = standardise(
return tf.tidy(() => {
let features = tf.tensor(tables.features);
let testFeatures = tf.tensor(tables.testFeatures);

const labels = tf.tensor(tables.labels);
const testLabels = tf.tensor(tables.testLabels);

if (columnsToStandardise.length > 0) {
const result = standardise(
features,
testFeatures,
featureColumnNames.map((c) => columnsToStandardise.includes(c))
);
features = result.features;
testFeatures = result.testFeatures;
}

if (prependOnes) {
features = tf.ones([features.shape[0], 1]).concat(features, 1);
testFeatures = tf
.ones([testFeatures.shape[0], 1])
.concat(testFeatures, 1);
}

return {
features,
labels,
testFeatures,
featureColumnNames.map((c) => columnsToStandardise.includes(c))
);
features = result.features;
testFeatures = result.testFeatures;
}

if (prependOnes) {
features = tf.ones([features.shape[0], 1]).concat(features, 1);
testFeatures = tf.ones([testFeatures.shape[0], 1]).concat(testFeatures, 1);
}

return {
features,
labels,
testFeatures,
testLabels,
};
testLabels,
};
});
};

export default loadCsv;

0 comments on commit 0399060

Please sign in to comment.