Skip to content

Commit 1c804f5

Browse files
Add Ranking AutoML Sample (#852)
* Initial add of project * Update ranking sample * Get sample working * Updates based on feedback * Add refitting on validation and test data sets * Update console headers * Iteration print improvements * Correct validationData * Printing NDCG@1,3,10 & DCG@10 * Printing NDCG@1,3,10 & DCG@10 * Add readme * Update based on feedback * Use new DcgTruncation property * Update to latest AutoML package * Review feedback * Wording for 1st refit step * Update to include original label in output Co-authored-by: Justin Ormont <[email protected]>
1 parent 2990008 commit 1c804f5

File tree

8 files changed

+559
-7
lines changed

8 files changed

+559
-7
lines changed

samples/csharp/common/AutoML/ConsoleHelper.cs

+26-7
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,26 @@ public static void PrintBinaryClassificationMetrics(string name, BinaryClassific
4444
public static void PrintMulticlassClassificationMetrics(string name, MulticlassClassificationMetrics metrics)
4545
{
4646
Console.WriteLine($"************************************************************");
47-
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
47+
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
4848
Console.WriteLine($"*-----------------------------------------------------------");
49-
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
50-
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
49+
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value from 0 and 1, where closer to 1.0 is better");
50+
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value from 0 and 1, where closer to 1.0 is better");
5151
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
5252
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
5353
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
5454
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
5555
Console.WriteLine($"************************************************************");
5656
}
5757

58+
public static void PrintRankingMetrics(string name, RankingMetrics metrics, uint optimizationMetricTruncationLevel)
59+
{
60+
Console.WriteLine($"************************************************************");
61+
Console.WriteLine($"* Metrics for {name} ranking model ");
62+
Console.WriteLine($"*-----------------------------------------------------------");
63+
Console.WriteLine($" Normalized Discounted Cumulative Gain (NDCG@{optimizationMetricTruncationLevel}) = {metrics?.NormalizedDiscountedCumulativeGains?[(int)optimizationMetricTruncationLevel - 1] ?? double.NaN:0.####}, a value from 0 and 1, where closer to 1.0 is better");
64+
Console.WriteLine($" Discounted Cumulative Gain (DCG@{optimizationMetricTruncationLevel}) = {metrics?.DiscountedCumulativeGains?[(int)optimizationMetricTruncationLevel - 1] ?? double.NaN:0.####}");
65+
}
66+
5867
public static void ShowDataViewInConsole(MLContext mlContext, IDataView dataView, int numberOfRows = 4)
5968
{
6069
string msg = string.Format("Show data in DataView: Showing {0} rows with the columns", numberOfRows.ToString());
@@ -89,6 +98,11 @@ internal static void PrintIterationMetrics(int iteration, string trainerName, Re
8998
CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,8:F4} {metrics?.MeanAbsoluteError ?? double.NaN,13:F2} {metrics?.MeanSquaredError ?? double.NaN,12:F2} {metrics?.RootMeanSquaredError ?? double.NaN,8:F2} {runtimeInSeconds.Value,9:F1}", Width);
9099
}
91100

101+
internal static void PrintIterationMetrics(int iteration, string trainerName, RankingMetrics metrics, double? runtimeInSeconds)
102+
{
103+
CreateRow($"{iteration,-4} {trainerName,-15} {metrics?.NormalizedDiscountedCumulativeGains[0] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[2] ?? double.NaN,9:F4} {metrics?.NormalizedDiscountedCumulativeGains[9] ?? double.NaN,9:F4} {metrics?.DiscountedCumulativeGains[9] ?? double.NaN,9:F4} {runtimeInSeconds.Value,9:F1}", Width);
104+
}
105+
92106
internal static void PrintIterationException(Exception ex)
93107
{
94108
Console.WriteLine($"Exception during AutoML iteration: {ex}");
@@ -109,6 +123,11 @@ internal static void PrintRegressionMetricsHeader()
109123
CreateRow($"{"",-4} {"Trainer",-35} {"RSquared",8} {"Absolute-loss",13} {"Squared-loss",12} {"RMS-loss",8} {"Duration",9}", Width);
110124
}
111125

126+
internal static void PrintRankingMetricsHeader()
127+
{
128+
CreateRow($"{"",-4} {"Trainer",-15} {"NDCG@1",9} {"NDCG@3",9} {"NDCG@10",9} {"DCG@10",9} {"Duration",9}", Width);
129+
}
130+
112131
private static void CreateRow(string message, int width)
113132
{
114133
Console.WriteLine("|" + message.PadRight(width - 2) + "|");
@@ -239,10 +258,10 @@ private void AppendTableRow(ICollection<string[]> tableRows,
239258

240259
tableRows.Add(new[]
241260
{
242-
columnName,
243-
GetColumnDataType(columnName),
244-
columnPurpose
245-
});
261+
columnName,
262+
GetColumnDataType(columnName),
263+
columnPurpose
264+
});
246265
}
247266

248267
private void AppendTableRows(ICollection<string[]> tableRows,

samples/csharp/common/AutoML/ProgressHandlers.cs

+23
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,27 @@ public void Report(RunDetail<RegressionMetrics> iterationResult)
8181
}
8282
}
8383
}
84+
85+
public class RankingExperimentProgressHandler : IProgress<RunDetail<RankingMetrics>>
86+
{
87+
private int _iterationIndex;
88+
89+
public void Report(RunDetail<RankingMetrics> iterationResult)
90+
{
91+
if (_iterationIndex++ == 0)
92+
{
93+
ConsoleHelper.PrintRankingMetricsHeader();
94+
}
95+
96+
if (iterationResult.Exception != null)
97+
{
98+
ConsoleHelper.PrintIterationException(iterationResult.Exception);
99+
}
100+
else
101+
{
102+
ConsoleHelper.PrintIterationMetrics(_iterationIndex, iterationResult.TrainerName,
103+
iterationResult.ValidationMetrics, iterationResult.RuntimeInSeconds);
104+
}
105+
}
106+
}
84107
}

0 commit comments

Comments
 (0)