Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ test/.test_coverage.dart

.DS_Store
*/**/.DS_Store
flutter-sdk/
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 16.17.15
- Added `MetricType.logLoss` and `LogLossMetric` for evaluating probabilistic
binary classifiers

## 16.17.13
- Added Decision Tree web demo using Web Assembly

Expand Down
28 changes: 28 additions & 0 deletions lib/src/metric/classification/log_loss_metric.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import 'package:ml_algo/src/helpers/validate_matrix_columns.dart';
import 'package:ml_algo/src/metric/metric.dart';
import 'package:ml_linalg/matrix.dart';
import 'dart:math' as math;

class LogLossMetric implements Metric {
const LogLossMetric({this.eps = 1e-15});

final double eps;

double _clip(double p) => p < eps ? eps : (p > 1.0 - eps ? 1.0 - eps : p);

@override
double getScore(Matrix predictedLabels, Matrix origLabels) {
validateMatrixColumns([predictedLabels, origLabels]);

final preds = predictedLabels.toVector();
final orig = origLabels.toVector();

var sum = 0.0;
for (var i = 0; i < preds.length; i++) {
final p = _clip(preds[i]);
final y = orig[i];
sum += y == 1 ? -math.log(p) : -math.log(1.0 - p);
}
return sum / preds.length;
}
}
4 changes: 4 additions & 0 deletions lib/src/metric/metric_factory_impl.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'package:ml_algo/src/metric/classification/accuracy.dart';
import 'package:ml_algo/src/metric/classification/precision.dart';
import 'package:ml_algo/src/metric/classification/recall.dart';
import 'package:ml_algo/src/metric/classification/log_loss_metric.dart';
import 'package:ml_algo/src/metric/metric.dart';
import 'package:ml_algo/src/metric/metric_factory.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
Expand Down Expand Up @@ -28,6 +29,9 @@ class MetricFactoryImpl implements MetricFactory {
case MetricType.recall:
return const RecallMetric();

case MetricType.logLoss:
return const LogLossMetric();

default:
throw UnsupportedError('Unsupported metric type $type');
}
Expand Down
3 changes: 3 additions & 0 deletions lib/src/metric/metric_type.dart
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,7 @@ enum MetricType {
/// better the prediction's quality is. The metric produces scores within the
/// range [0, 1]
recall,

/// Binary cross-entropy (a.k.a. log-loss)
logLoss,
}
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms, Machine learning models performance evaluation functionality
version: 16.17.13
version: 16.17.15
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChristianKleineidam I'm sorry, you got me wrong - I meant 16.18.0 (16.17.13 -> 16.18.0). We update minor version (number in the middle) when we add something new. No worries, I'll merge the branch, I'll fix the version before publishing

homepage: https://github.com/gyrdym/ml_algo

environment:
Expand Down
28 changes: 28 additions & 0 deletions test/metric/classification/log_loss_metric_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import 'package:ml_algo/src/metric/classification/log_loss_metric.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:test/test.dart';

void main() {
group('LogLossMetric', () {
const metric = LogLossMetric();

test('perfect predictions → loss ≈ 0', () {
final yTrue = Matrix.column([1, 0, 1, 0]);
final yPred = Matrix.column([1.0, 0.0, 1.0, 0.0]);
expect(metric.getScore(yPred, yTrue), closeTo(0.0, 1e-12));
});

test('typical predictions', () {
final yTrue = Matrix.column([1, 0]);
final yPred = Matrix.column([0.9, 0.1]);
expect(metric.getScore(yPred, yTrue),
closeTo(0.10536051565782628, 1e-6)); // -ln(0.9)
});

test('probabilities are clipped', () {
final yTrue = Matrix.column([1, 0]);
final yPred = Matrix.column([0.0, 1.0]);
expect(metric.getScore(yPred, yTrue).isFinite, isTrue);
});
});
}
5 changes: 5 additions & 0 deletions test/metric/metric_factory_impl_test.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'package:ml_algo/src/metric/classification/accuracy.dart';
import 'package:ml_algo/src/metric/classification/precision.dart';
import 'package:ml_algo/src/metric/classification/recall.dart';
import 'package:ml_algo/src/metric/classification/log_loss_metric.dart';
import 'package:ml_algo/src/metric/metric_factory_impl.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/metric/regression/mape.dart';
Expand Down Expand Up @@ -31,5 +32,9 @@ void main() {
test('should create RecallMetric instance', () {
expect(factory.createByType(MetricType.recall), isA<RecallMetric>());
});

test('should create LogLossMetric instance', () {
expect(factory.createByType(MetricType.logLoss), isA<LogLossMetric>());
});
});
}