Skip to content

Commit

Permalink
Subject: Add R2Score metric. (#8169) (#8353)
Browse files Browse the repository at this point in the history
Body:
FEATURE

Co-authored-by: Matthew Soulanille <msoulanille@google.com>
  • Loading branch information
lukonik and mattsoulanille authored Aug 22, 2024
1 parent 0677375 commit 936b448
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 4 deletions.
19 changes: 19 additions & 0 deletions tfjs-layers/src/exports_metrics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,22 @@ export function MSE(yTrue: Tensor, yPred: Tensor): Tensor {
export function mse(yTrue: Tensor, yPred: Tensor): Tensor {
return losses.meanSquaredError(yTrue, yPred);
}

/**
* Computes R2 score.
*
* ```js
* const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
* const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
* const r2Score = tf.metrics.r2Score(yTrue, yPred);
* r2Score.print();
* ```
* @param yTrue Truth Tensor.
* @param yPred Prediction Tensor.
* @return R2 score Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
export function r2Score(yTrue: Tensor, yPred: Tensor): Tensor {
return metrics.r2Score(yTrue, yPred);
}
12 changes: 9 additions & 3 deletions tfjs-layers/src/metrics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ import {Tensor, tidy} from '@tensorflow/tfjs-core';

import * as K from './backend/tfjs_backend';
import {NotImplementedError, ValueError} from './errors';
import {categoricalCrossentropy as categoricalCrossentropyLoss, cosineProximity, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredError, sparseCategoricalCrossentropy as sparseCategoricalCrossentropyLoss} from './losses';
import {binaryCrossentropy as lossBinaryCrossentropy} from './losses';
import {lossesMap} from './losses';
import {binaryCrossentropy as lossBinaryCrossentropy, categoricalCrossentropy as categoricalCrossentropyLoss, cosineProximity, lossesMap, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredError, sparseCategoricalCrossentropy as sparseCategoricalCrossentropyLoss} from './losses';
import {LossOrMetricFn} from './types';
import * as util from './utils/generic_utils';

Expand Down Expand Up @@ -112,6 +110,14 @@ export function sparseTopKCategoricalAccuracy(
throw new NotImplementedError();
}

export function r2Score(yTrue: Tensor, yPred: Tensor): Tensor {
return tidy(() => {
const sumSquaresResiduals = yTrue.sub(yPred).square().sum();
const sumSquares = yTrue.sub(yTrue.mean()).square().sum();
return tfc.scalar(1).sub(sumSquaresResiduals.div(sumSquares));
});
}

// Aliases.
export const mse = meanSquaredError;
export const MSE = meanSquaredError;
Expand Down
23 changes: 22 additions & 1 deletion tfjs-layers/src/metrics_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {scalar, Tensor, tensor, tensor1d, tensor2d} from '@tensorflow/tfjs-core'

import {setEpsilon} from './backend/common';
import * as tfl from './index';
import {binaryAccuracy, categoricalAccuracy, get, getLossOrMetricName} from './metrics';
import {binaryAccuracy, categoricalAccuracy, get, getLossOrMetricName, r2Score} from './metrics';
import {LossOrMetricFn} from './types';
import {describeMathCPUAndGPU, describeMathCPUAndWebGL2, expectTensorsClose} from './utils/test_utils';

Expand Down Expand Up @@ -283,6 +283,27 @@ describeMathCPUAndGPU('recall metric', () => {
});
});

describeMathCPUAndGPU('r2Score', () => {
it('1D', () => {
const yTrue = tensor1d([3, -0.5, 2, 7, 4.2, 8.5, 1.3, 2.8, 6.7, 9.0]);
const yPred = tensor1d([2.5, 0.0, 2.1, 7.8, 4.0, 8.2, 1.4, 2.9, 6.5, 9.1]);
const score = r2Score(yTrue, yPred);
expectTensorsClose(score, scalar(0.985));
});
it('2D', () => {
const yTrue = tensor2d([
[3, 2.5], [-0.5, 3.2], [2, 1.9], [7, 5.1], [4.2, 3.8], [8.5, 7.4],
[1.3, 0.6], [2.8, 2.1], [6.7, 5.3], [9.0, 8.7]
]);
const yPred = tensor2d([
[2.7, 2.3], [0.0, 3.1], [2.1, 1.8], [6.8, 5.0], [4.1, 3.7], [8.4, 7.2],
[1.4, 0.7], [2.9, 2.2], [6.6, 5.2], [9.2, 8.9]
]);
const score = r2Score(yTrue, yPred);
expectTensorsClose(score, scalar(0.995));
});
});

describe('metrics.get', () => {
it('valid name, not alias', () => {
expect(get('binaryAccuracy') === get('categoricalAccuracy')).toEqual(false);
Expand Down

0 comments on commit 936b448

Please sign in to comment.