@@ -323,8 +323,8 @@ class TestKoopmanPipelineScore:
323323
324324 @pytest .mark .parametrize (
325325 'X_predicted, X_expected, n_steps, discount_factor, '
326- 'regression_metric, error_score, min_samples, episode_feature , '
327- 'score_exp' ,
326+ 'regression_metric, regression_metric_kw, error_score, min_samples , '
327+ 'episode_feature, score_exp' ,
328328 [
329329 (
330330 np .array ([
@@ -338,11 +338,32 @@ class TestKoopmanPipelineScore:
338338 None ,
339339 1 ,
340340 'neg_mean_squared_error' ,
341+ None ,
341342 np .nan ,
342343 1 ,
343344 False ,
344345 0 ,
345346 ),
347+ (
348+ np .array ([
349+ [1 , 2 , 3 , 4 ],
350+ [2 , 3 , 3 , 2 ],
351+ ]).T ,
352+ np .array ([
353+ [1 , 2 , 3 , 4 ],
354+ [2 , 3 , 3 , 2 ],
355+ ]).T ,
356+ None ,
357+ 1 ,
358+ 'neg_mean_squared_error' ,
359+ {
360+ 'multioutput' : 'raw_values' ,
361+ },
362+ np .nan ,
363+ 1 ,
364+ False ,
365+ np .array ([0 , 0 ]),
366+ ),
346367 (
347368 np .array ([
348369 [1 , 2 ],
@@ -355,6 +376,27 @@ class TestKoopmanPipelineScore:
355376 None ,
356377 1 ,
357378 'neg_mean_squared_error' ,
379+ {
380+ 'multioutput' : 'raw_values' ,
381+ },
382+ np .nan ,
383+ 1 ,
384+ False ,
385+ - np .array ([2 ** 2 , 1 ]),
386+ ),
387+ (
388+ np .array ([
389+ [1 , 2 ],
390+ [2 , 3 ],
391+ ]).T ,
392+ np .array ([
393+ [1 , 4 ],
394+ [2 , 2 ],
395+ ]).T ,
396+ None ,
397+ 1 ,
398+ 'neg_mean_squared_error' ,
399+ None ,
358400 np .nan ,
359401 1 ,
360402 False ,
@@ -370,6 +412,7 @@ class TestKoopmanPipelineScore:
370412 None ,
371413 1 ,
372414 'neg_mean_squared_error' ,
415+ None ,
373416 np .nan ,
374417 1 ,
375418 False ,
@@ -385,6 +428,7 @@ class TestKoopmanPipelineScore:
385428 None ,
386429 1 ,
387430 'neg_mean_absolute_error' ,
431+ None ,
388432 np .nan ,
389433 1 ,
390434 False ,
@@ -400,6 +444,7 @@ class TestKoopmanPipelineScore:
400444 None ,
401445 1 ,
402446 'neg_mean_squared_error' ,
447+ None ,
403448 np .nan ,
404449 2 ,
405450 False ,
@@ -415,6 +460,7 @@ class TestKoopmanPipelineScore:
415460 2 ,
416461 1 ,
417462 'neg_mean_squared_error' ,
463+ None ,
418464 np .nan ,
419465 1 ,
420466 False ,
@@ -430,6 +476,7 @@ class TestKoopmanPipelineScore:
430476 None ,
431477 0.5 ,
432478 'neg_mean_squared_error' ,
479+ None ,
433480 np .nan ,
434481 1 ,
435482 False ,
@@ -448,6 +495,7 @@ class TestKoopmanPipelineScore:
448495 None ,
449496 1 ,
450497 'neg_mean_squared_error' ,
498+ None ,
451499 np .nan ,
452500 1 ,
453501 True ,
@@ -465,6 +513,7 @@ class TestKoopmanPipelineScore:
465513 1 ,
466514 1 ,
467515 'neg_mean_squared_error' ,
516+ None ,
468517 np .nan ,
469518 1 ,
470519 True ,
@@ -482,6 +531,7 @@ class TestKoopmanPipelineScore:
482531 None ,
483532 0.5 ,
484533 'neg_mean_squared_error' ,
534+ None ,
485535 np .nan ,
486536 1 ,
487537 True ,
@@ -499,6 +549,7 @@ class TestKoopmanPipelineScore:
499549 1 ,
500550 0.5 ,
501551 'neg_mean_squared_error' ,
552+ None ,
502553 np .nan ,
503554 1 ,
504555 True ,
@@ -516,6 +567,7 @@ class TestKoopmanPipelineScore:
516567 None ,
517568 1 ,
518569 'neg_mean_squared_error' ,
570+ None ,
519571 np .nan ,
520572 1 ,
521573 False ,
@@ -533,6 +585,7 @@ class TestKoopmanPipelineScore:
533585 None ,
534586 1 ,
535587 'neg_mean_squared_error' ,
588+ None ,
536589 - 100 ,
537590 1 ,
538591 False ,
@@ -550,6 +603,7 @@ class TestKoopmanPipelineScore:
550603 None ,
551604 1 ,
552605 'neg_mean_squared_error' ,
606+ None ,
553607 'raise' ,
554608 1 ,
555609 False ,
@@ -565,6 +619,7 @@ class TestKoopmanPipelineScore:
565619 None ,
566620 1 ,
567621 'neg_mean_squared_error' ,
622+ None ,
568623 - 100 ,
569624 1 ,
570625 False ,
@@ -580,6 +635,7 @@ class TestKoopmanPipelineScore:
580635 None ,
581636 1 ,
582637 'neg_mean_squared_error' ,
638+ None ,
583639 'raise' ,
584640 1 ,
585641 False ,
@@ -596,6 +652,7 @@ class TestKoopmanPipelineScore:
596652 None ,
597653 1 ,
598654 'neg_mean_squared_error' ,
655+ None ,
599656 - 100 ,
600657 1 ,
601658 False ,
@@ -610,6 +667,7 @@ def test_score_trajectory(
610667 n_steps ,
611668 discount_factor ,
612669 regression_metric ,
670+ regression_metric_kw ,
613671 error_score ,
614672 min_samples ,
615673 episode_feature ,
@@ -623,6 +681,7 @@ def test_score_trajectory(
623681 n_steps = n_steps ,
624682 discount_factor = discount_factor ,
625683 regression_metric = regression_metric ,
684+ regression_metric_kw = regression_metric_kw ,
626685 error_score = error_score ,
627686 min_samples = min_samples ,
628687 episode_feature = episode_feature ,
@@ -634,6 +693,7 @@ def test_score_trajectory(
634693 n_steps = n_steps ,
635694 discount_factor = discount_factor ,
636695 regression_metric = regression_metric ,
696+ regression_metric_kw = regression_metric_kw ,
637697 error_score = error_score ,
638698 min_samples = min_samples ,
639699 episode_feature = episode_feature ,
0 commit comments