---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%"
)
```
# kantime
<!-- badges: start -->
<!-- badges: end -->
{kantime} is a minimal wrapper for Time Series Kolmogorov-Arnold Networks (KAN) in R. It binds the KAN model from Nixtla's Python library {neuralforecast} for use in R, with additional steps to bridge it to {modeltime}, hence the name {kantime}!
## Installation
You can install the development version of kantime like so:
``` r
devtools::install_github("frankiethull/kantime")
```
## Example
This basic example shows the barebones {reticulate} Python wrappers. These are the core bindings, before the model is registered and configured with {parsnip} and then bridged to {modeltime}.
#### setup {kantime}
{kantime} requires Python and the {neuralforecast} package.
```{r}
# load the R library -------------------
library(kantime)
# setup Python environment -------------
# 1) create a virtual env ~
# kantime::create_neuralforecast_env()
# 2) use the virtual env ~
kantime::use_neuralforecast_env()
# 3) install neuralforecast ~
#kantime::install_neuralforecast()
```
##### example data for testing the library:
```{r}
## data comes from sister repo
air_passengers_df <- readr::read_csv("https://raw.githubusercontent.com/frankiethull/nixtla-r-tutorial/refs/heads/main/airpassengersDF.csv", show_col_types = FALSE)
Sys.sleep(1)
# first 132 months for training, the remaining 12 months for testing
train_df <- air_passengers_df |> dplyr::slice(1:132)
test_df <- air_passengers_df |> dplyr::slice(-(1:132))
```
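
The `slice(1:132)` split above works because the data holds a single series. With several *unique_id*s, the last `h` observations would need to be held out per series instead. A minimal base R sketch, assuming a long data frame with `unique_id`, `ds` and `y` columns; `split_last_h()` is a hypothetical helper, not part of {kantime}:

```r
# Hypothetical helper (not part of {kantime}): hold out the last `h`
# observations of every series in a long unique_id / ds / y data frame.
split_last_h <- function(df, h) {
  # order rows by series, then by time, so "last" means most recent
  df <- df[order(df$unique_id, df$ds), , drop = FALSE]
  idx <- seq_len(nrow(df))
  # length of each row's series, and the row's position within that series
  n_in_series   <- ave(idx, df$unique_id, FUN = length)
  pos_in_series <- ave(idx, df$unique_id, FUN = seq_along)
  is_test <- pos_in_series > n_in_series - h
  list(
    train = df[!is_test, , drop = FALSE],
    test  = df[is_test, , drop = FALSE]
  )
}

# toy data: two series of five monthly observations each
toy <- data.frame(
  unique_id = rep(c("a", "b"), each = 5),
  ds = rep(seq(as.Date("2020-01-01"), by = "month", length.out = 5), times = 2),
  y = 1:10
)
parts <- split_last_h(toy, h = 2)
```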
#### barebones internals, no xregs, with conformal prediction
Note that Nixtla's design requires a *unique_id* column identifying each time series, *ds* for the time column, and *y* for the outcome variable. The internals are minimal Python wrappers that **fit** and **predict** on time series data like so:
```r
# kan specs --
kan_model_specs <- kantime:::kan_spec(h = 12L)
# fit
kan_model_fit <- kantime:::conformal_fit(model_spec = kan_model_specs, df = train_df)
# predict
conf_kan_preds <- kantime:::conformal_predict(
model_spec = kan_model_specs,
model_fit = kan_model_fit,
level = 90L)
```
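
Conformal prediction turns point forecasts into intervals using the model's errors on data it was not trained on. The base R sketch below only illustrates the generic split-conformal idea behind a `level = 90L` interval, using hypothetical calibration vectors; {neuralforecast}'s own procedure differs in its details, and `conformal_interval()` is an illustrative name, not a {kantime} function:

```r
# Generic split-conformal interval (illustration only, not neuralforecast's code).
# calib_actual / calib_pred: held-out actuals and the model's forecasts for them.
conformal_interval <- function(point_forecast, calib_actual, calib_pred, level = 90) {
  scores <- abs(calib_actual - calib_pred)    # nonconformity scores
  n <- length(scores)
  alpha <- 1 - level / 100
  k <- ceiling((n + 1) * (1 - alpha))         # finite-sample corrected rank
  q <- if (k > n) Inf else sort(scores)[k]    # too few scores -> unbounded interval
  data.frame(
    mean = point_forecast,
    lo   = point_forecast - q,
    hi   = point_forecast + q
  )
}

# toy calibration set: nine small errors and one large one
calib_actual <- 1:10
calib_pred   <- calib_actual + c(rep(0.5, 9), 2)
intervals <- conformal_interval(c(100, 110), calib_actual, calib_pred, level = 90)
```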
These internals are built with {reticulate}, calling Nixtla's {neuralforecast} from R. While the approach works, the results and workflow can be clunky in R, especially when fitting many unique IDs (fitting models for many different time series loses the pandas row IDs). Additionally, running parallel processes from R via {reticulate} for cross-validation, tuning, and training in Python needs more robust handling.

{kantime} is super experimental in its design, which is also why this is not a full {neuralforecast} binding. This binding in particular leverages the root KAN model and a few helper utilities, then binds the base KAN model to {parsnip} and bridges it to {modeltime}. This gives up pieces of the underlying Nixtla functionality in exchange for {tidymodels} and {modeltime} functionality.
#### {neuralforecast}'s kan with a {modeltime} bridge
Given these internals, we can bind the functions to {modeltime}, which is similar to registering a {parsnip} model but requires an additional bridge. The bridge fit implementation is shown below; predict methods also have to be added, and quite a bit of underlying work goes into making it fully functional.
```r
kan_bridge <- kantime:::kan_bridge_fit_impl(
x = train_df |> dplyr::select(-y),
y = train_df |> dplyr::pull(y),
h = 12L,
input_size = 24L,
max_steps = 10L,
freq = "ME"
)
kan_bridge |> predict()
```
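
The `kan_bridge_fit_impl()` call above takes `x` and `y` and returns an object that `predict()` understands. That shape is plain S3: a fit implementation returning a classed object, plus a `predict()` method for the class. A minimal base R sketch of the pattern, using a hypothetical seasonal-naive model instead of KAN; `naive_bridge_fit_impl()` and the `naive_bridge` class are illustrative names, and a real {modeltime} bridge must also follow {modeltime}'s own conventions, which this sketch does not:

```r
# Hypothetical bridge-shaped model (illustration only): seasonal naive forecast.
naive_bridge_fit_impl <- function(x, y, h = 12L, season = 12L) {
  stopifnot(length(y) >= season)
  structure(
    list(
      last_season = tail(y, season),  # the only "parameters" this toy model needs
      h = h,
      x = x                           # store predictors for a later predict step
    ),
    class = "naive_bridge"
  )
}

# predict() method for the class: repeat the last observed season over h steps
predict.naive_bridge <- function(object, ...) {
  rep_len(object$last_season, object$h)
}

x <- data.frame(ds = seq(as.Date("2000-01-01"), by = "month", length.out = 24))
bridge <- naive_bridge_fit_impl(x, y = 1:24, h = 14L)
bridge_preds <- bridge |> predict()
```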
#### {kantime} workflow within {modeltime}
Remember that once the model has been bridged to {modeltime}, we lose some of the underlying Nixtla utilities but gain access to {tidymodels} and {modeltime} utilities in R. What does this mean? We can now use this KAN binding with tidymodels tools like `initial_time_split()`, and we can leverage the {modeltime} toolkit to further calibrate and validate KAN models. In fact, we went through these additional bridging steps to get a full suite of backtesting, ensembling, calibration, and scoring tools that will be very familiar if you already use {modeltime} and/or {tidymodels}.
```{r}
library(parsnip)
library(modeltime)
# time split
splits <- rsample::initial_time_split(air_passengers_df, prop = .92)
training <- rsample::training(splits)
testing <- rsample::testing(splits)
# kantime fit
kantime_fit <- kan(h = 12L,
input_size = 24L,
max_steps = 10L,
freq = "ME") |>
set_engine("kan") |>
fit(y ~ ds + unique_id, data = training)
# kantime predict
kan_point_predictions <-
kantime_fit |>
modeltime_table() |>
modeltime_forecast(actual_data = training, new_data = testing)
kan_point_predictions |>
plot_modeltime_forecast(.conf_interval_show = FALSE, .interactive = FALSE)
```
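
Scoring is one of the {modeltime} tools this bridge unlocks: `modeltime_calibrate()` followed by `modeltime_accuracy()` reports holdout metrics such as MAE, MAPE and RMSE. As a rough sketch of what those numbers mean, here is a base R version computed from hypothetical actual and forecast vectors. It is not {modeltime}'s implementation ({yardstick}'s metric definitions are the reference), and `forecast_accuracy()` is an illustrative name:

```r
# Point-forecast accuracy metrics (illustration; compare with yardstick/modeltime output)
forecast_accuracy <- function(actual, forecast) {
  err <- actual - forecast
  c(
    mae  = mean(abs(err)),
    mape = mean(abs(err / actual)) * 100,  # percent; undefined if any actual is 0
    rmse = sqrt(mean(err^2))
  )
}

acc <- forecast_accuracy(actual = c(100, 200, 400), forecast = c(110, 190, 400))
```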
##### To Do's
1) *KAN:* KAN is still univariate, without xreg support. (I wasn't sure this would work, so I started by bridging only a few KAN parameters.)
2) *KAN:* printing could be more informative.
3) `predict(..., type = "prob")`: by design, kantime wraps Nixtla's conformal method, so prediction intervals could be mapped to `predict()` calls, bypassing the need to calibrate in R, depending on the situation.
4) add in-sample predictions to the modeltime bridge
5) *UX:* Nixtla's requirements, i.e. **y, ds and unique_id**, could be handled and tidied within the internals
6) *UI:* pandas/reticulate arg types! The internals should convert numeric inputs to integers before passing them to Python, on behalf of the R programmer.
7) handle warnings like `NIXTLA_ID_AS_COL`, and
8) the `.nested.col = purrr::map2(...)` warning in the final `modeltime_forecast()`
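
For items 5 and 6, one possible direction is a small pre-processing step on the R side. A minimal base R sketch with hypothetical helper names: `as_nixtla_df()` renames user columns to Nixtla's *unique_id* / *ds* / *y*, and `as_integer_args()` converts whole-number doubles to integers, since {reticulate} passes an R double such as `12` to Python as a float (`12.0`), which is why the examples above write `12L`:

```r
# Hypothetical helpers (not part of {kantime}).

# Rename user-facing columns to the names Nixtla expects.
as_nixtla_df <- function(df, id_col, date_col, value_col) {
  data.frame(
    unique_id = as.character(df[[id_col]]),
    ds        = df[[date_col]],
    y         = as.numeric(df[[value_col]])
  )
}

# Convert length-1 whole-number doubles (e.g. h = 12) to integers so reticulate
# sends Python ints instead of floats; everything else passes through unchanged.
as_integer_args <- function(...) {
  lapply(list(...), function(a) {
    is_whole <- is.double(a) && length(a) == 1 && !is.na(a) &&
      abs(a) <= .Machine$integer.max && a == round(a)
    if (is_whole) as.integer(a) else a
  })
}

sales <- data.frame(
  store = c("s1", "s1"),
  month = as.Date(c("2024-01-31", "2024-02-29")),
  units = c(5, 7)
)
nixtla_df <- as_nixtla_df(sales, id_col = "store", date_col = "month", value_col = "units")
args <- as_integer_args(h = 12, input_size = 24, freq = "ME", learning_rate = 1e-3)
```

One design caveat: a float hyperparameter written as a whole number (say `1`) would also be converted, so a real implementation would likely need an allow-list of the parameters Python expects as integers.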