Add stratified k-fold cross-validation #121

@cardmagic

Summary

Implement stratified k-fold cross-validation for evaluating classifier performance.

Motivation

From classifier-reborn#151:

Cross-validation is essential for:

  • Evaluating classifier accuracy before deployment
  • Comparing different classifier types (Bayes vs LogisticRegression)
  • Tuning hyperparameters
  • Detecting overfitting

Stratified k-fold ensures each fold maintains approximately the same class distribution as the original dataset. This matters for imbalanced datasets, where a purely random split can leave a fold with few or no examples of a minority class.
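To make the stratification step concrete, here is a minimal sketch of one way to build the folds. `stratified_folds` is a hypothetical helper written for illustration, not an existing method of the gem; it shuffles each class separately and deals its documents round-robin across the folds, so per-class counts in any two folds differ by at most one.

```ruby
# Hypothetical helper (not part of the Classifier gem): split labeled
# documents into k folds while preserving per-class proportions.
# data is a hash of label => array of documents, as in the proposed API.
def stratified_folds(data, k:, seed: 42)
  rng = Random.new(seed)
  # Each fold is a hash of label => documents assigned to that fold.
  folds = Array.new(k) { Hash.new { |h, label| h[label] = [] } }
  data.each do |label, docs|
    # Shuffle within the class, then deal documents round-robin.
    docs.shuffle(random: rng).each_with_index do |doc, i|
      folds[i % k][label] << doc
    end
  end
  folds
end

# Imbalanced toy dataset: 10 spam documents, 40 ham documents.
data = {
  spam: (1..10).map { |i| "spam #{i}" },
  ham:  (1..40).map { |i| "ham #{i}" }
}

folds = stratified_folds(data, k: 5)
folds.each_with_index do |fold, i|
  puts "Fold #{i}: spam=#{fold[:spam].size} ham=#{fold[:ham].size}"
end
# Every fold holds 2 spam and 8 ham documents, matching the 1:4 ratio.
```

The fixed `seed` is an assumption added here so that splits are reproducible between runs; a real implementation might expose it as an option.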

Proposed API

# Basic k-fold cross-validation
results = Classifier::CrossValidation.kfold(
  classifier_class: Classifier::Bayes,
  categories: [:spam, :ham],
  data: {
    spam: spam_documents,
    ham: ham_documents
  },
  k: 5,  # 5-fold
  stratified: true  # maintain class proportions
)

# Results
results.accuracy      # => 0.94
results.precision     # => {spam: 0.92, ham: 0.96}
results.recall        # => {spam: 0.95, ham: 0.93}
results.f1_score      # => {spam: 0.93, ham: 0.94}
results.confusion_matrix
# =>        spam  ham     (rows: actual, columns: predicted)
#    spam   [95,   5]
#    ham    [ 7,  93]

# Per-fold results
results.folds.each do |fold|
  puts "Fold #{fold.index}: accuracy=#{fold.accuracy}"
end
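For reference, the metrics above can all be derived from a confusion matrix. The sketch below assumes the matrix is a nested hash whose outer keys are actual labels and whose inner keys are predicted labels (the reading under which the recall values shown above follow from the example matrix). The `metrics` function and the hash layout are illustrative assumptions, not the proposed return types. Computed this way from the pooled matrix, precision comes out near 0.93 for spam and 0.95 for ham, so the figures in the API example should be read as illustrative rather than exact.

```ruby
# Assumption: outer keys are actual labels, inner keys are predicted labels.
matrix = {
  spam: { spam: 95, ham: 5 },
  ham:  { spam: 7,  ham: 93 }
}

# Hypothetical helper: derive accuracy and per-class precision, recall,
# and F1 from a confusion matrix laid out as above.
def metrics(matrix)
  labels = matrix.keys
  total = matrix.values.sum { |row| row.values.sum }
  correct = labels.sum { |label| matrix[label][label] }

  per_class = labels.to_h do |label|
    true_positives = matrix[label][label]
    # Column sum: everything the classifier predicted as this label.
    predicted = labels.sum { |row_label| matrix[row_label][label] }
    # Row sum: everything that actually carries this label.
    actual = matrix[label].values.sum

    precision = predicted.zero? ? 0.0 : true_positives.to_f / predicted
    recall = actual.zero? ? 0.0 : true_positives.to_f / actual
    f1 = (precision + recall).zero? ? 0.0 : 2 * precision * recall / (precision + recall)
    [label, { precision: precision, recall: recall, f1: f1 }]
  end

  { accuracy: correct.to_f / total, per_class: per_class }
end

result = metrics(matrix)
puts "accuracy=#{result[:accuracy]}"
result[:per_class].each do |label, m|
  puts "#{label}: precision=#{m[:precision].round(3)} recall=#{m[:recall].round(3)} f1=#{m[:f1].round(3)}"
end
```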

Additional Features

  • Support for different split strategies (random, stratified)
  • Optional progress callback for long-running validations
  • Integration with streaming training for large datasets
  • Support for all classifier types
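The progress callback and "all classifier types" items could fit together roughly as follows. This is a sketch under stated assumptions: `cross_validate` is a hypothetical function, the callback is modeled as an optional Ruby block, and any class responding to `train(label, doc)` and `classify(doc)` is accepted. `MajorityClassifier` is a stand-in written for this example so the code runs without the gem; it is not one of the library's classifiers.

```ruby
# Stand-in classifier (hypothetical): always predicts the most frequent
# label seen during training. Used only to keep the example runnable.
class MajorityClassifier
  def initialize
    @counts = Hash.new(0)
  end

  def train(label, _doc)
    @counts[label] += 1
  end

  def classify(_doc)
    @counts.max_by { |_label, count| count }&.first
  end
end

# folds: array of { label => [docs] } hashes, one hash per fold.
# If a block is given, it is called after each fold with
# (fold_index, total_folds, fold_accuracy).
# Returns the array of per-fold accuracies.
def cross_validate(classifier_class, folds)
  folds.each_index.map do |i|
    model = classifier_class.new

    # Train on every fold except the held-out fold i.
    folds.each_with_index do |fold, j|
      next if j == i
      fold.each { |label, docs| docs.each { |doc| model.train(label, doc) } }
    end

    # Evaluate on the held-out fold.
    tested = 0
    correct = 0
    folds[i].each do |label, docs|
      docs.each do |doc|
        tested += 1
        correct += 1 if model.classify(doc) == label
      end
    end

    accuracy = tested.zero? ? 0.0 : correct.to_f / tested
    yield(i, folds.size, accuracy) if block_given?
    accuracy
  end
end

# Five stratified folds with a 1:4 spam-to-ham ratio.
folds = Array.new(5) { { spam: ["s"] * 2, ham: ["h"] * 8 } }

accuracies = cross_validate(MajorityClassifier, folds) do |i, total, acc|
  puts "Fold #{i + 1}/#{total}: accuracy=#{acc}"
end
```

The majority-class baseline reaches 0.8 accuracy on every fold without ever predicting spam, which is one reason to report per-class precision and recall alongside accuracy on imbalanced data.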

Related
