Skip to content

Commit 503ea1c

Browse files
authored
Merge pull request #179 from decargroup/bugfix/164-pykoopscore_trajectory-does-not-work-with-multioutput=raw_values
Make `KoopmanPipeline.score_trajectory()` compatible with `multioutput='raw_values'` regression metric keyword argument
2 parents 8c798d6 + f14d589 commit 503ea1c

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

pykoop/koopman_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3425,7 +3425,7 @@ def score_trajectory(
34253425
else:
34263426
score = regression_metric(**regression_metric_args)
34273427
# Return error score if score is not finite
3428-
if not np.isfinite(score):
3428+
if not np.all(np.isfinite(score)):
34293429
if isinstance(error_score, str):
34303430
raise ValueError(
34313431
'Prediction diverged or error occured while scoring.')

tests/test_koopman_pipeline.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,8 @@ class TestKoopmanPipelineScore:
323323

324324
@pytest.mark.parametrize(
325325
'X_predicted, X_expected, n_steps, discount_factor, '
326-
'regression_metric, error_score, min_samples, episode_feature, '
327-
'score_exp',
326+
'regression_metric, regression_metric_kw, error_score, min_samples, '
327+
'episode_feature, score_exp',
328328
[
329329
(
330330
np.array([
@@ -338,11 +338,32 @@ class TestKoopmanPipelineScore:
338338
None,
339339
1,
340340
'neg_mean_squared_error',
341+
None,
341342
np.nan,
342343
1,
343344
False,
344345
0,
345346
),
347+
(
348+
np.array([
349+
[1, 2, 3, 4],
350+
[2, 3, 3, 2],
351+
]).T,
352+
np.array([
353+
[1, 2, 3, 4],
354+
[2, 3, 3, 2],
355+
]).T,
356+
None,
357+
1,
358+
'neg_mean_squared_error',
359+
{
360+
'multioutput': 'raw_values',
361+
},
362+
np.nan,
363+
1,
364+
False,
365+
np.array([0, 0]),
366+
),
346367
(
347368
np.array([
348369
[1, 2],
@@ -355,6 +376,27 @@ class TestKoopmanPipelineScore:
355376
None,
356377
1,
357378
'neg_mean_squared_error',
379+
{
380+
'multioutput': 'raw_values',
381+
},
382+
np.nan,
383+
1,
384+
False,
385+
-np.array([2**2, 1]),
386+
),
387+
(
388+
np.array([
389+
[1, 2],
390+
[2, 3],
391+
]).T,
392+
np.array([
393+
[1, 4],
394+
[2, 2],
395+
]).T,
396+
None,
397+
1,
398+
'neg_mean_squared_error',
399+
None,
358400
np.nan,
359401
1,
360402
False,
@@ -370,6 +412,7 @@ class TestKoopmanPipelineScore:
370412
None,
371413
1,
372414
'neg_mean_squared_error',
415+
None,
373416
np.nan,
374417
1,
375418
False,
@@ -385,6 +428,7 @@ class TestKoopmanPipelineScore:
385428
None,
386429
1,
387430
'neg_mean_absolute_error',
431+
None,
388432
np.nan,
389433
1,
390434
False,
@@ -400,6 +444,7 @@ class TestKoopmanPipelineScore:
400444
None,
401445
1,
402446
'neg_mean_squared_error',
447+
None,
403448
np.nan,
404449
2,
405450
False,
@@ -415,6 +460,7 @@ class TestKoopmanPipelineScore:
415460
2,
416461
1,
417462
'neg_mean_squared_error',
463+
None,
418464
np.nan,
419465
1,
420466
False,
@@ -430,6 +476,7 @@ class TestKoopmanPipelineScore:
430476
None,
431477
0.5,
432478
'neg_mean_squared_error',
479+
None,
433480
np.nan,
434481
1,
435482
False,
@@ -448,6 +495,7 @@ class TestKoopmanPipelineScore:
448495
None,
449496
1,
450497
'neg_mean_squared_error',
498+
None,
451499
np.nan,
452500
1,
453501
True,
@@ -465,6 +513,7 @@ class TestKoopmanPipelineScore:
465513
1,
466514
1,
467515
'neg_mean_squared_error',
516+
None,
468517
np.nan,
469518
1,
470519
True,
@@ -482,6 +531,7 @@ class TestKoopmanPipelineScore:
482531
None,
483532
0.5,
484533
'neg_mean_squared_error',
534+
None,
485535
np.nan,
486536
1,
487537
True,
@@ -499,6 +549,7 @@ class TestKoopmanPipelineScore:
499549
1,
500550
0.5,
501551
'neg_mean_squared_error',
552+
None,
502553
np.nan,
503554
1,
504555
True,
@@ -516,6 +567,7 @@ class TestKoopmanPipelineScore:
516567
None,
517568
1,
518569
'neg_mean_squared_error',
570+
None,
519571
np.nan,
520572
1,
521573
False,
@@ -533,6 +585,7 @@ class TestKoopmanPipelineScore:
533585
None,
534586
1,
535587
'neg_mean_squared_error',
588+
None,
536589
-100,
537590
1,
538591
False,
@@ -550,6 +603,7 @@ class TestKoopmanPipelineScore:
550603
None,
551604
1,
552605
'neg_mean_squared_error',
606+
None,
553607
'raise',
554608
1,
555609
False,
@@ -565,6 +619,7 @@ class TestKoopmanPipelineScore:
565619
None,
566620
1,
567621
'neg_mean_squared_error',
622+
None,
568623
-100,
569624
1,
570625
False,
@@ -580,6 +635,7 @@ class TestKoopmanPipelineScore:
580635
None,
581636
1,
582637
'neg_mean_squared_error',
638+
None,
583639
'raise',
584640
1,
585641
False,
@@ -596,6 +652,7 @@ class TestKoopmanPipelineScore:
596652
None,
597653
1,
598654
'neg_mean_squared_error',
655+
None,
599656
-100,
600657
1,
601658
False,
@@ -610,6 +667,7 @@ def test_score_trajectory(
610667
n_steps,
611668
discount_factor,
612669
regression_metric,
670+
regression_metric_kw,
613671
error_score,
614672
min_samples,
615673
episode_feature,
@@ -623,6 +681,7 @@ def test_score_trajectory(
623681
n_steps=n_steps,
624682
discount_factor=discount_factor,
625683
regression_metric=regression_metric,
684+
regression_metric_kw=regression_metric_kw,
626685
error_score=error_score,
627686
min_samples=min_samples,
628687
episode_feature=episode_feature,
@@ -634,6 +693,7 @@ def test_score_trajectory(
634693
n_steps=n_steps,
635694
discount_factor=discount_factor,
636695
regression_metric=regression_metric,
696+
regression_metric_kw=regression_metric_kw,
637697
error_score=error_score,
638698
min_samples=min_samples,
639699
episode_feature=episode_feature,

0 commit comments

Comments
 (0)