Skip to content

Commit 66cb98a

Browse files
add Binary Log-Loss (Cross-Entropy) metric (#262)
1 parent e415c86 commit 66cb98a

File tree

8 files changed

+74
-1
lines changed

8 files changed

+74
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ test/.test_coverage.dart
1414

1515
.DS_Store
1616
*/**/.DS_Store
17+
flutter-sdk/

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 16.17.15
4+
- Added `MetricType.logLoss` and `LogLossMetric` for evaluating probabilistic
5+
binary classifiers
6+
37
## 16.17.13
48
- Added Decision Tree web demo using Web Assembly
59

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import 'package:ml_algo/src/helpers/validate_matrix_columns.dart';
2+
import 'package:ml_algo/src/metric/metric.dart';
3+
import 'package:ml_linalg/matrix.dart';
4+
import 'dart:math' as math;
5+
6+
class LogLossMetric implements Metric {
7+
const LogLossMetric({this.eps = 1e-15});
8+
9+
final double eps;
10+
11+
double _clip(double p) => p < eps ? eps : (p > 1.0 - eps ? 1.0 - eps : p);
12+
13+
@override
14+
double getScore(Matrix predictedLabels, Matrix origLabels) {
15+
validateMatrixColumns([predictedLabels, origLabels]);
16+
17+
final preds = predictedLabels.toVector();
18+
final orig = origLabels.toVector();
19+
20+
var sum = 0.0;
21+
for (var i = 0; i < preds.length; i++) {
22+
final p = _clip(preds[i]);
23+
final y = orig[i];
24+
sum += y == 1 ? -math.log(p) : -math.log(1.0 - p);
25+
}
26+
return sum / preds.length;
27+
}
28+
}

lib/src/metric/metric_factory_impl.dart

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import 'package:ml_algo/src/metric/classification/accuracy.dart';
22
import 'package:ml_algo/src/metric/classification/precision.dart';
33
import 'package:ml_algo/src/metric/classification/recall.dart';
4+
import 'package:ml_algo/src/metric/classification/log_loss_metric.dart';
45
import 'package:ml_algo/src/metric/metric.dart';
56
import 'package:ml_algo/src/metric/metric_factory.dart';
67
import 'package:ml_algo/src/metric/metric_type.dart';
@@ -28,6 +29,9 @@ class MetricFactoryImpl implements MetricFactory {
2829
case MetricType.recall:
2930
return const RecallMetric();
3031

32+
case MetricType.logLoss:
33+
return const LogLossMetric();
34+
3135
default:
3236
throw UnsupportedError('Unsupported metric type $type');
3337
}

lib/src/metric/metric_type.dart

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,7 @@ enum MetricType {
100100
/// better the prediction's quality is. The metric produces scores within the
101101
/// range [0, 1]
102102
recall,
103+
104+
/// Binary cross-entropy (a.k.a. log-loss)
105+
logLoss,
103106
}

pubspec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: ml_algo
22
description: Machine learning algorithms, Machine learning models performance evaluation functionality
3-
version: 16.17.13
3+
version: 16.17.15
44
homepage: https://github.com/gyrdym/ml_algo
55

66
environment:
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import 'package:ml_algo/src/metric/classification/log_loss_metric.dart';
2+
import 'package:ml_linalg/matrix.dart';
3+
import 'package:test/test.dart';
4+
5+
void main() {
6+
group('LogLossMetric', () {
7+
const metric = LogLossMetric();
8+
9+
test('perfect predictions → loss ≈ 0', () {
10+
final yTrue = Matrix.column([1, 0, 1, 0]);
11+
final yPred = Matrix.column([1.0, 0.0, 1.0, 0.0]);
12+
expect(metric.getScore(yPred, yTrue), closeTo(0.0, 1e-12));
13+
});
14+
15+
test('typical predictions', () {
16+
final yTrue = Matrix.column([1, 0]);
17+
final yPred = Matrix.column([0.9, 0.1]);
18+
expect(metric.getScore(yPred, yTrue),
19+
closeTo(0.10536051565782628, 1e-6)); // -ln(0.9)
20+
});
21+
22+
test('probabilities are clipped', () {
23+
final yTrue = Matrix.column([1, 0]);
24+
final yPred = Matrix.column([0.0, 1.0]);
25+
expect(metric.getScore(yPred, yTrue).isFinite, isTrue);
26+
});
27+
});
28+
}

test/metric/metric_factory_impl_test.dart

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import 'package:ml_algo/src/metric/classification/accuracy.dart';
22
import 'package:ml_algo/src/metric/classification/precision.dart';
33
import 'package:ml_algo/src/metric/classification/recall.dart';
4+
import 'package:ml_algo/src/metric/classification/log_loss_metric.dart';
45
import 'package:ml_algo/src/metric/metric_factory_impl.dart';
56
import 'package:ml_algo/src/metric/metric_type.dart';
67
import 'package:ml_algo/src/metric/regression/mape.dart';
@@ -31,5 +32,9 @@ void main() {
3132
test('should create RecallMetric instance', () {
3233
expect(factory.createByType(MetricType.recall), isA<RecallMetric>());
3334
});
35+
36+
test('should create LogLossMetric instance', () {
37+
expect(factory.createByType(MetricType.logLoss), isA<LogLossMetric>());
38+
});
3439
});
3540
}

0 commit comments

Comments
 (0)