Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ test/.test_coverage.dart

.DS_Store
*/**/.DS_Store
flutter-sdk/
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 16.17.14
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChristianKleineidam could you please change the version here as well?

- Added `MetricType.logLoss` and `LogLossMetric` for evaluating probabilistic
binary classifiers

## 16.17.13
- Added Decision Tree web demo using Web Assembly

Expand Down
28 changes: 28 additions & 0 deletions lib/src/metric/classification/log_loss_metric.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import 'package:ml_algo/src/helpers/validate_matrix_columns.dart';
import 'package:ml_algo/src/metric/metric.dart';
import 'package:ml_linalg/matrix.dart';
import 'dart:math' as math;

class LogLossMetric implements Metric {
const LogLossMetric({this.eps = 1e-15});

final double eps;

double _clip(double p) => p < eps ? eps : (p > 1.0 - eps ? 1.0 - eps : p);

@override
double getScore(Matrix predictedLabels, Matrix origLabels) {
validateMatrixColumns([predictedLabels, origLabels]);

final preds = predictedLabels.toVector();
final orig = origLabels.toVector();

var sum = 0.0;
for (var i = 0; i < preds.length; i++) {
final p = _clip(preds[i]);
final y = orig[i];
sum += y == 1 ? -math.log(p) : -math.log(1.0 - p);
}
return sum / preds.length;
}
}
4 changes: 4 additions & 0 deletions lib/src/metric/metric_factory_impl.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'package:ml_algo/src/metric/classification/accuracy.dart';
import 'package:ml_algo/src/metric/classification/precision.dart';
import 'package:ml_algo/src/metric/classification/recall.dart';
import 'package:ml_algo/src/metric/classification/log_loss_metric.dart';
import 'package:ml_algo/src/metric/metric.dart';
import 'package:ml_algo/src/metric/metric_factory.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
Expand Down Expand Up @@ -28,6 +29,9 @@ class MetricFactoryImpl implements MetricFactory {
case MetricType.recall:
return const RecallMetric();

case MetricType.logLoss:
return const LogLossMetric();

default:
throw UnsupportedError('Unsupported metric type $type');
}
Expand Down
3 changes: 3 additions & 0 deletions lib/src/metric/metric_type.dart
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,7 @@ enum MetricType {
/// better the prediction's quality is. The metric produces scores within the
/// range [0, 1]
recall,

/// Binary cross-entropy (a.k.a. log-loss)
logLoss,
}
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms, Machine learning models performance evaluation functionality
version: 16.17.13
version: 16.17.14
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChristianKleineidam let's update minor version instead of patch, since you've added new functionality:

version: 16.18.0

homepage: https://github.com/gyrdym/ml_algo

environment:
Expand Down
28 changes: 28 additions & 0 deletions test/metric/classification/log_loss_metric_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import 'package:ml_algo/src/metric/classification/log_loss_metric.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:test/test.dart';

void main() {
group('LogLossMetric', () {
const metric = LogLossMetric();

test('perfect predictions → loss ≈ 0', () {
final yTrue = Matrix.column([1, 0, 1, 0]);
final yPred = Matrix.column([1.0, 0.0, 1.0, 0.0]);
expect(metric.getScore(yPred, yTrue), closeTo(0.0, 1e-12));
});

test('typical predictions', () {
final yTrue = Matrix.column([1, 0]);
final yPred = Matrix.column([0.9, 0.1]);
expect(metric.getScore(yPred, yTrue),
closeTo(0.10536051565782628, 1e-6)); // -ln(0.9)
});

test('probabilities are clipped', () {
final yTrue = Matrix.column([1, 0]);
final yPred = Matrix.column([0.0, 1.0]);
expect(metric.getScore(yPred, yTrue).isFinite, isTrue);
});
});
}
5 changes: 5 additions & 0 deletions test/metric/metric_factory_impl_test.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'package:ml_algo/src/metric/classification/accuracy.dart';
import 'package:ml_algo/src/metric/classification/precision.dart';
import 'package:ml_algo/src/metric/classification/recall.dart';
import 'package:ml_algo/src/metric/classification/log_loss_metric.dart';
import 'package:ml_algo/src/metric/metric_factory_impl.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/metric/regression/mape.dart';
Expand Down Expand Up @@ -31,5 +32,9 @@ void main() {
test('should create RecallMetric instance', () {
expect(factory.createByType(MetricType.recall), isA<RecallMetric>());
});

test('should create LogLossMetric instance', () {
expect(factory.createByType(MetricType.logLoss), isA<LogLossMetric>());
});
});
}