diff --git a/cmd/metricscollector/v1beta1/file-metricscollector/main.go b/cmd/metricscollector/v1beta1/file-metricscollector/main.go index a385d20d2ca..8f7c113366f 100644 --- a/cmd/metricscollector/v1beta1/file-metricscollector/main.go +++ b/cmd/metricscollector/v1beta1/file-metricscollector/main.go @@ -44,7 +44,6 @@ import ( "fmt" "os" "path/filepath" - "regexp" "strconv" "strings" "time" @@ -141,21 +140,15 @@ func printMetricsFile(mFile string) { } func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, fileFormat commonv1beta1.FileFormat) { + // First metric is objective in metricNames array. + objMetric := strings.Split(*metricNames, ";")[0] + objType := commonv1beta1.ObjectiveType(*objectiveType) - // metricStartStep is the dict where key = metric name, value = start step. - // We should apply early stopping rule only if metric is reported at least "start_step" times. - metricStartStep := make(map[string]int) - for _, stopRule := range stopRules { - if stopRule.StartStep != 0 { - metricStartStep[stopRule.Name] = stopRule.StartStep - } + rules, err := filemc.NewRuleSet(objMetric, objType, stopRules) + if err != nil { + klog.Fatalf("NewRuleSet failed: %v", err) } - // For objective metric we calculate best optimal value from the recorded metrics. - // This is workaround for Median Stop algorithm. - // TODO (andreyvelich): Think about it, maybe define latest, max or min strategy type in stop-rule as well ? - var optimalObjValue *float64 - // Check that metric file exists. checkMetricFile(mFile) @@ -171,6 +164,11 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, f klog.Fatalf("Failed to create new Process from pid %v, error: %v", mainProcPid, err) } + // Get list of regular expressions from filters. + metricRegList := filemc.GetFilterRegexpList(filters) + + liveRuleMetrics := rules.LiveMetrics() + // Start watch log lines. 
t, _ := tail.TailFile(mFile, tail.Config{Follow: true}) for line := range t.Lines { @@ -180,14 +178,10 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, f switch fileFormat { case commonv1beta1.TextFormat: - // Get list of regural expressions from filters. - var metricRegList []*regexp.Regexp - metricRegList = filemc.GetFilterRegexpList(filters) - // Check if log line contains metric from stop rules. isRuleLine := false - for _, rule := range stopRules { - if strings.Contains(logText, rule.Name) { + for _, name := range liveRuleMetrics { + if strings.Contains(logText, name) { isRuleLine = true break } @@ -211,53 +205,46 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, f klog.Fatalf("Unable to parse value %v to float for metric %v", metricValue, metricName) } - // stopRules contains array of EarlyStoppingRules that has not been reached yet. - // After rule is reached we delete appropriate element from the array. - for idx, rule := range stopRules { - if metricName != rule.Name { + // liveRuleMetrics contains the names of the EarlyStoppingRules that have not been reached yet. + for _, name := range liveRuleMetrics { + if metricName != name { continue } - stopRules, optimalObjValue = updateStopRules(stopRules, optimalObjValue, metricValue, metricStartStep, rule, idx) + err = rules.UpdateMetric(name, metricValue) + if err != nil { + klog.Fatalf("Unable to UpdateMetric %s %v", name, err) + } } } } case commonv1beta1.JsonFormat: - var logJsonObj map[string]interface{} + var logJsonObj map[string]any if err = json.Unmarshal([]byte(logText), &logJsonObj); err != nil { klog.Fatalf("Failed to unmarshal logs in %v format, log: %s, error: %v", commonv1beta1.JsonFormat, logText, err) } - // Check if log line contains metric from stop rules. 
- isRuleLine := false - for _, rule := range stopRules { - if _, exist := logJsonObj[rule.Name]; exist { - isRuleLine = true - break - } - } - // If log line doesn't contain appropriate metric, continue track file. - if !isRuleLine { - continue - } - // stopRules contains array of EarlyStoppingRules that has not been reached yet. - // After rule is reached we delete appropriate element from the array. - for idx, rule := range stopRules { - value, exist := logJsonObj[rule.Name].(string) + // liveRuleMetrics contains the names of the EarlyStoppingRules that have not been reached yet. + for _, name := range liveRuleMetrics { + value, exist := logJsonObj[name].(string) if !exist { continue } metricValue, err := strconv.ParseFloat(strings.TrimSpace(value), 64) if err != nil { - klog.Fatalf("Unable to parse value %v to float for metric %v", metricValue, rule.Name) + klog.Fatalf("Unable to parse value %v to float for metric %v", metricValue, name) + } + err = rules.UpdateMetric(name, metricValue) + if err != nil { + klog.Fatalf("Unable to UpdateMetric %s %v", name, err) } - stopRules, optimalObjValue = updateStopRules(stopRules, optimalObjValue, metricValue, metricStartStep, rule, idx) } default: klog.Fatalf("Format must be set to %v or %v", commonv1beta1.TextFormat, commonv1beta1.JsonFormat) } - // If stopRules array is empty, Trial is early stopped. - if len(stopRules) == 0 { + liveRuleMetrics = rules.LiveMetrics() + // If liveRuleMetrics array is empty, Trial is early stopped. 
+ if len(liveRuleMetrics) == 0 { klog.Info("Training container is early stopped") isEarlyStopped = true @@ -329,67 +316,6 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, f } } -func updateStopRules( - stopRules []commonv1beta1.EarlyStoppingRule, - optimalObjValue *float64, - metricValue float64, - metricStartStep map[string]int, - rule commonv1beta1.EarlyStoppingRule, - ruleIdx int, -) ([]commonv1beta1.EarlyStoppingRule, *float64) { - - // First metric is objective in metricNames array. - objMetric := strings.Split(*metricNames, ";")[0] - objType := commonv1beta1.ObjectiveType(*objectiveType) - - // Calculate optimalObjValue. - if rule.Name == objMetric { - if optimalObjValue == nil { - optimalObjValue = &metricValue - } else if objType == commonv1beta1.ObjectiveTypeMaximize && metricValue > *optimalObjValue { - optimalObjValue = &metricValue - } else if objType == commonv1beta1.ObjectiveTypeMinimize && metricValue < *optimalObjValue { - optimalObjValue = &metricValue - } - // Assign best optimal value to metric value. - metricValue = *optimalObjValue - } - - // Reduce steps if appropriate metric is reported. - // Once rest steps are empty we apply early stopping rule. - if _, ok := metricStartStep[rule.Name]; ok { - metricStartStep[rule.Name]-- - if metricStartStep[rule.Name] != 0 { - return stopRules, optimalObjValue - } - } - - ruleValue, err := strconv.ParseFloat(rule.Value, 64) - if err != nil { - klog.Fatalf("Unable to parse value %v to float for rule metric %v", rule.Value, rule.Name) - } - - // Metric value can be equal, less or greater than stop rule. - // Deleting suitable stop rule from the array. 
- if rule.Comparison == commonv1beta1.ComparisonTypeEqual && metricValue == ruleValue { - return deleteStopRule(stopRules, ruleIdx), optimalObjValue - } else if rule.Comparison == commonv1beta1.ComparisonTypeLess && metricValue < ruleValue { - return deleteStopRule(stopRules, ruleIdx), optimalObjValue - } else if rule.Comparison == commonv1beta1.ComparisonTypeGreater && metricValue > ruleValue { - return deleteStopRule(stopRules, ruleIdx), optimalObjValue - } - return stopRules, optimalObjValue -} - -func deleteStopRule(stopRules []commonv1beta1.EarlyStoppingRule, idx int) []commonv1beta1.EarlyStoppingRule { - if idx >= len(stopRules) { - klog.Fatalf("Index %v out of range stopRules: %v", idx, stopRules) - } - stopRules[idx] = stopRules[len(stopRules)-1] - stopRules[len(stopRules)-1] = commonv1beta1.EarlyStoppingRule{} - return stopRules[:len(stopRules)-1] -} - func main() { flag.Var(&stopRules, "stop-rule", "The list of early stopping stop rules") flag.Parse() diff --git a/pkg/metricscollector/v1beta1/file-metricscollector/rules.go b/pkg/metricscollector/v1beta1/file-metricscollector/rules.go new file mode 100644 index 00000000000..2578f3d46a9 --- /dev/null +++ b/pkg/metricscollector/v1beta1/file-metricscollector/rules.go @@ -0,0 +1,147 @@ +package sidecarmetricscollector + +import ( + "fmt" + "math" + "strconv" + + commonv1beta1 "github.com/kubeflow/katib/pkg/apis/controller/common/v1beta1" +) + +type RuleSet struct { + spec []commonv1beta1.EarlyStoppingRule + status []struct { + pruner earlyStoppingPruner + reach bool + } +} + +func NewRuleSet( + objMetric string, + objType commonv1beta1.ObjectiveType, + spec []commonv1beta1.EarlyStoppingRule, +) (*RuleSet, error) { + s := &RuleSet{ + spec: spec, + status: make([]struct { + pruner earlyStoppingPruner + reach bool + }, len(spec)), + } + + for i, rule := range spec { + pruner, err := defaultFactory(rule) + if err != nil { + return nil, err + } + if objMetric == rule.Name { + pruner = &objPruner{ + objType: 
objType, + optimalObjValue: math.NaN(), + sub: pruner, + } + } + s.status[i].pruner = pruner + } + + return s, nil +} + +func (s *RuleSet) LiveMetrics() []string { + ls := make([]string, 0, len(s.spec)) + for i, rule := range s.spec { + if !s.status[i].reach { + ls = append(ls, rule.Name) + } + } + return ls +} + +func (s *RuleSet) UpdateMetric(name string, metricValue float64) error { + for i := range s.spec { + rule := &s.spec[i] + status := &s.status[i] + if rule.Name != name || status.reach { + continue + } + + reach, err := status.pruner.Pruner(metricValue) + if err != nil { + return err + } + if reach { + status.reach = true + } + } + return nil +} + +type earlyStoppingPruner interface { + Pruner(metricValue float64) (bool, error) +} + +func defaultFactory(rule commonv1beta1.EarlyStoppingRule) (earlyStoppingPruner, error) { + r := rule + switch rule.Comparison { + case commonv1beta1.ComparisonTypeGreater, commonv1beta1.ComparisonTypeLess, commonv1beta1.ComparisonTypeEqual: + value, err := strconv.ParseFloat(r.Value, 64) + if err != nil { + return nil, fmt.Errorf("unable to parse value to float for rule metric %s: %w", r.Name, err) + } + return &basicPruner{ + target: value, + startStep: r.StartStep, + cmp: r.Comparison, + }, nil + default: + return nil, fmt.Errorf("unknown rule comparison: %s", r.Comparison) + } +} + +type basicPruner struct { + target float64 + step int + startStep int + cmp commonv1beta1.ComparisonType +} + +func (p *basicPruner) Pruner(metricValue float64) (bool, error) { + p.step++ + if p.startStep > 0 && p.step < p.startStep { + return false, nil + } + switch p.cmp { + case commonv1beta1.ComparisonTypeLess: + return metricValue < p.target, nil + case commonv1beta1.ComparisonTypeGreater: + return metricValue > p.target, nil + case commonv1beta1.ComparisonTypeEqual: + return metricValue == p.target, nil + default: + return false, fmt.Errorf("unknown rule comparison: %s", p.cmp) + } +} + +type objPruner struct { + objType 
commonv1beta1.ObjectiveType + optimalObjValue float64 + sub earlyStoppingPruner +} + +func (p *objPruner) Pruner(metricValue float64) (bool, error) { + // For objective metric we calculate best optimal value from the recorded metrics. + // This is workaround for Median Stop algorithm. + // TODO (andreyvelich): Think about it, maybe define latest, max or min strategy type in stop-rule as well ? + + if math.IsNaN(p.optimalObjValue) { + p.optimalObjValue = metricValue + } else if p.objType == commonv1beta1.ObjectiveTypeMaximize && metricValue > p.optimalObjValue { + p.optimalObjValue = metricValue + } else if p.objType == commonv1beta1.ObjectiveTypeMinimize && metricValue < p.optimalObjValue { + p.optimalObjValue = metricValue + } + // Assign best optimal value to metric value. + metricValue = p.optimalObjValue + + return p.sub.Pruner(metricValue) +} diff --git a/pkg/metricscollector/v1beta1/file-metricscollector/rules_test.go b/pkg/metricscollector/v1beta1/file-metricscollector/rules_test.go new file mode 100644 index 00000000000..6a933e8c24d --- /dev/null +++ b/pkg/metricscollector/v1beta1/file-metricscollector/rules_test.go @@ -0,0 +1,137 @@ +package sidecarmetricscollector + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + commonv1beta1 "github.com/kubeflow/katib/pkg/apis/controller/common/v1beta1" +) + +func TestRuleSet(t *testing.T) { + testCases := []struct { + name string + objMetric string + objType commonv1beta1.ObjectiveType + spec []commonv1beta1.EarlyStoppingRule + action func(t *testing.T, s *RuleSet) + }{ + { + name: "simple", + objMetric: "obj", + objType: commonv1beta1.ObjectiveTypeMinimize, + spec: []commonv1beta1.EarlyStoppingRule{ + { + Name: "a", + Value: "0.2", + Comparison: commonv1beta1.ComparisonTypeGreater, + StartStep: 2, + }, + { + Name: "b", + Value: "0.5", + Comparison: commonv1beta1.ComparisonTypeLess, + StartStep: 3, + }, + { + Name: "c", + Value: "1", + Comparison: commonv1beta1.ComparisonTypeEqual, + StartStep: 0, + }, + }, 
+ action: func(t *testing.T, s *RuleSet) { + diff(t, []string{"a", "b", "c"}, s.LiveMetrics()) + err := s.UpdateMetric("c", 1) + if err != nil { + t.Error(err) + } + diff(t, []string{"a", "b"}, s.LiveMetrics()) + err = s.UpdateMetric("a", 1) + if err != nil { + t.Error(err) + } + err = s.UpdateMetric("b", 0) + if err != nil { + t.Error(err) + } + err = s.UpdateMetric("b", 0) + if err != nil { + t.Error(err) + } + diff(t, []string{"a", "b"}, s.LiveMetrics()) + err = s.UpdateMetric("a", 0.1) + if err != nil { + t.Error(err) + } + diff(t, []string{"a", "b"}, s.LiveMetrics()) + err = s.UpdateMetric("a", 0.21) + if err != nil { + t.Error(err) + } + diff(t, []string{"b"}, s.LiveMetrics()) + err = s.UpdateMetric("b", 0.2) + if err != nil { + t.Error(err) + } + diff(t, []string{}, s.LiveMetrics()) + }, + }, + { + name: "obj", + objMetric: "obj", + objType: commonv1beta1.ObjectiveTypeMaximize, + spec: []commonv1beta1.EarlyStoppingRule{ + { + Name: "obj", + Value: "0.8", + Comparison: commonv1beta1.ComparisonTypeGreater, + StartStep: 2, + }, + { + Name: "a", + Value: "0.5", + Comparison: commonv1beta1.ComparisonTypeLess, + StartStep: 2, + }, + }, + action: func(t *testing.T, s *RuleSet) { + diff(t, []string{"obj", "a"}, s.LiveMetrics()) + err := s.UpdateMetric("obj", 1) + if err != nil { + t.Error(err) + } + err = s.UpdateMetric("a", 0.6) + if err != nil { + t.Error(err) + } + diff(t, []string{"obj", "a"}, s.LiveMetrics()) + err = s.UpdateMetric("obj", 0.7) + if err != nil { + t.Error(err) + } + err = s.UpdateMetric("a", 0.6) + if err != nil { + t.Error(err) + } + diff(t, []string{"a"}, s.LiveMetrics()) + }, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + s, err := NewRuleSet(tt.objMetric, tt.objType, tt.spec) + if err != nil { + t.Fatalf("failed to NewRuleSet: %v", err) + } + + tt.action(t, s) + }) + } +} + +func diff(t *testing.T, want, got any) { + t.Helper() + if diff := cmp.Diff(want, got); len(diff) != 0 { + t.Errorf("Unexpected error 
(-want,+got):\n%s", diff) + } +}