In this notebook I demonstrate my JAX implementation of the ETS variant commonly denoted as \((A_d, A)\), which has an additive damped trend and an additive seasonality component. JAX is a high-performance tensor library that offers automatic differentiation and JIT compilation behind a familiar NumPy-like interface. There are some caveats, however; for example, it only allows pure functions (so no side effects are allowed).
The main goal of this notebook is to explore JAX functionality, especially lax.scan, which offers a convenient way around writing for-loops.
A much more detailed introduction to ETS models can be found in Forecasting: Principles and Practice.
Imports and Setup
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax.lax import scan
from jax.scipy.optimize import minimize
import jaxopt
from functools import partial
import pandas as pd

matplotlib.style.use("ggplot")
plt.rcParams["figure.figsize"] = [10, 4]
rng_key = jax.random.PRNGKey(seed=42)
Data Loading
We will use the AirPassengers dataset.
= "https://raw.githubusercontent.com/selva86/datasets/master/AirPassengers.csv"
url = pd.read_csv(
y
url,=["date"],
parse_dates"date").sort_index()["value"]
).set_index(= jnp.array(y)[:-12]
y_train = jnp.array(y)[-12:]
y_test
y.plot()
The scan function
Let’s start with a quick rundown of how the scan function works. A crude Python implementation of it would look like this:
def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
Here, `f` takes `carry`, which can be thought of as a state, and an element `x` of `xs`, and returns a transformed `carry` and `y`. The scan function loops through the sequence `xs`, applying `f` to each element and updating the `carry` accordingly.
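As a tiny, made-up illustration of this contract (the lambda and the names final_carry and ys are mine, not part of the model below), a running sum returns the final carry together with the stacked per-step outputs; this works with the crude implementation above as well as with jax.lax.scan itself:

final_carry, ys = scan(lambda carry, x: (carry + x, carry + x), 0, jnp.arange(5))
# final_carry is 10, ys holds the running sums [0, 1, 3, 6, 10]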
As an example, we will calculate the square of the cumulative sum of an integer sequence. Note that this is not a very efficient implementation; we do it this way only for the sake of presentation.
xs = jax.random.randint(rng_key, (100,), minval=-100, maxval=100)

def cumsum_squared(xs):
    def transition_fn(carry, x):
        previous_cumsum = carry
        cumsum = previous_cumsum + x  # Update carry
        y = cumsum ** 2  # Calculate y
        return cumsum, y

    return scan(
        transition_fn,
        init=0,
        xs=xs,
    )

assert jnp.all(jnp.cumsum(xs) ** 2 == cumsum_squared(xs)[1])
Implementing the model
We are ready to start implementing the model, which is given by
\[\begin{aligned} y_t &= l_{t-1} + \varphi b_{t-1} + s_{t - m} + \varepsilon_{t}, \\ l_t &= l_{t - 1} + \varphi b_{t - 1} + \alpha \varepsilon_t, \\ b_t &= \varphi b_{t-1} + \beta \varepsilon_t, \\ s_{t} &= s_{t - m} + \gamma \varepsilon_t. \end{aligned}\] The conditional one step ahead expectation is given by \[\begin{aligned} \mu_{y, t} = l_{t-1} + \varphi b_{t-1} + s_{t - m}, \end{aligned}\]and the conditional \(h\) steps ahead expectation by
\[\begin{aligned} \hat{y}_{t + h} = l_{t} + \sum_{i=1}^{h}\varphi^i b_{t} + s_{t - \bigl\lceil \frac{h}{m} \bigr\rceil m + h}. \end{aligned}\]The latter might seem a bit obnoxious, but it can be derived from the model using \(\mathbb{E}(\varepsilon) = 0\). Luckily, in our approach it suffices to always forecast one step ahead.
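As a quick sanity check of that formula, take \(h = 2\): since \(\mathbb{E}(\varepsilon_{t+1}) = \mathbb{E}(\varepsilon_{t+2}) = 0\), the state equations give \(\mathbb{E}(l_{t+1}) = l_t + \varphi b_t\) and \(\mathbb{E}(b_{t+1}) = \varphi b_t\), hence \[\begin{aligned} \hat{y}_{t+2} = \mathbb{E}\bigl(l_{t+1} + \varphi b_{t+1} + s_{t+2-m} + \varepsilon_{t+2}\bigr) = l_t + (\varphi + \varphi^2)\, b_t + s_{t+2-m}, \end{aligned}\] which is exactly the general expression with \(h = 2\).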
def transition_fn(carry, obs, alpha, beta, gamma, phi):
    epsilon, level, trend, season = carry

    forecast = level + trend + season[0]
    new_epsilon = jax.lax.cond(
        jnp.isnan(obs),
        lambda: 0.0,
        lambda: obs - forecast,
    )
    new_level = level + trend + alpha * epsilon
    new_trend = phi * trend + beta * epsilon
    new_season = jnp.roll(season.at[0].set(season[0] + gamma * epsilon), -1)

    return (new_epsilon, new_level, new_trend, new_season), forecast
Two things to mention:

- When `obs` is `nan`, we want to return the conditional one step ahead expectation. To do this, we used the `jax.lax.cond` condition handler. This is equivalent to `0.0 if jnp.isnan(obs) else obs - forecast`; however, Python if-else statements on traced values are not allowed in jit-compiled functions.
- Since JAX requires functions to be pure, mutation of variables is not allowed, i.e. `season[0] = season[0] + gamma * epsilon` would yield an error. Instead, we update the array using the syntax `.at[idx].set(y)`. A tiny standalone illustration of both constructs follows this list.
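Here is that illustration as a toy snippet, independent of the model (the values and names a, b, branch are made up):

a = jnp.zeros(3)
b = a.at[0].set(5.0)   # functional update: returns a new array, a itself is unchanged
branch = jax.lax.cond(
    b[0] > 1.0,        # traceable replacement for a Python if-else on traced values
    lambda: b[0] * 2.0,
    lambda: b[0] - 1.0,
)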
Prediction and optimization
Predicting is now as easy as scanning using some `initial_state`. For the optimization of the parameters I used LBFGS from the jaxopt library with the mean squared error as the loss function; other loss functions could also be considered.
We remark that constraints on the parameters could (and should) be imposed, and the initial state should be optimized instead of being set heuristically; however, we won't bother with that, since it is not the focus of this notebook (one way to impose such constraints is sketched right after the `opt_coeffs` definition below). Again, I refer the reader to Forecasting: Principles and Practice, Chapter 8.6.
@jax.jit
def predict(y, coeffs, initial_state):
    return scan(
        partial(
            transition_fn,
            alpha=coeffs[0],
            beta=coeffs[1],
            gamma=coeffs[2],
            phi=coeffs[3],
        ),
        initial_state,
        y,
    )
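As a quick, made-up smoke test (example_coeffs and example_state are placeholders, not the fitted values): predict returns the final carry together with the array of one step ahead forecasts, which is why the [1] index appears below.

example_coeffs = jnp.ones(4) / 10  # arbitrary placeholder parameters
example_state = (0.0, 0.0, 0.0, y_train[:12].astype("float32"))  # heuristic initial state
final_carry, fitted = predict(y_train, example_coeffs, example_state)
# fitted contains one forecast per observation in y_train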
@jax.jit
def squared_error(y, yhat):
    return (y - yhat) ** 2
@jax.jit
def opt_coeffs(
    y,
    n_periods,
    initial_state,
):
    def loss_func(coeffs):
        return squared_error(y, predict(y, coeffs, initial_state)[1]).mean()

    x_star = jaxopt.LBFGS(fun=loss_func, maxiter=500).run(jnp.ones(4) / 10)

    return x_star.params
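As mentioned above, the smoothing parameters are usually constrained, e.g. to \((0, 1)\). One possible way to do that, sketched here as a hypothetical variant (opt_coeffs_constrained is not part of the notebook), is to optimize unconstrained values and squash them with a sigmoid before they reach transition_fn:

# Hypothetical sketch, not the implementation used below: optimize raw,
# unconstrained parameters and map them into (0, 1) with a sigmoid.
@jax.jit
def opt_coeffs_constrained(y, initial_state):
    def loss_func(raw_coeffs):
        coeffs = jax.nn.sigmoid(raw_coeffs)  # alpha, beta, gamma, phi in (0, 1)
        return squared_error(y, predict(y, coeffs, initial_state)[1]).mean()

    x_star = jaxopt.LBFGS(fun=loss_func, maxiter=500).run(jnp.zeros(4))
    return jax.nn.sigmoid(x_star.params)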
Note that on the first run, the `opt_coeffs` function is compiled, which slows things down.
%%time
n_periods = 12
initial_state = (1.0, .0, .0, y_train[:12].astype("float32"))

coeffs = opt_coeffs(y_train, n_periods, initial_state)
However, on later runs it should be pretty fast, at least on inputs of the same shape; this sounds restrictive, but it can be worked around using padding.
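To make the padding idea concrete, here is one possible sketch (pad_to and MAX_LEN are my assumptions, not part of the notebook): pad every series with NaN up to a fixed length so that opt_coeffs always sees the same shape. Since transition_fn already treats NaN observations as missing, only the loss has to ignore the padded tail, e.g. by replacing .mean() with jnp.nanmean.

MAX_LEN = 256  # assumed fixed length, large enough for any series we fit

def pad_to(series, length=MAX_LEN):
    # Pad with NaN so the padded tail is treated as missing by transition_fn.
    series = jnp.asarray(series, dtype=jnp.float32)
    return jnp.pad(series, (0, length - series.shape[0]), constant_values=jnp.nan)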
%%timeit
coeffs = opt_coeffs(y_train, n_periods, initial_state)
54.7 ms ± 18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
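If even the one-off compilation delay matters, recent JAX versions also allow triggering compilation ahead of time via the lower/compile interface of jitted functions; this is an optional aside, not something done in the original notebook:

# Optional: pay the compilation cost up front instead of on the first call.
compiled_opt_coeffs = opt_coeffs.lower(y_train, n_periods, initial_state).compile()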
fig, ax = plt.subplots(figsize=(20, 8))

ax.plot(y.index, y, label="True")
ax.plot(
    y.index,
    predict(
        jnp.concat(
            [
                y_train,
                jnp.array(n_periods * [jnp.nan]),
            ]
        ),
        coeffs,
        initial_state,
    )[1],
    label="ETS",
)
ax.axvline(
    y.index[-12],
    color="red",
    linestyle="dashed",
    label="Train-test split",
)
ax.legend()
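Finally, y_test was set aside at the beginning but is not scored above; a minimal, illustrative way to do so (yhat_test and mae are my additions) is to take the last n_periods forecasts from the same NaN-padded call and compare them with the held-out values:

yhat_test = predict(
    jnp.concat([y_train, jnp.array(n_periods * [jnp.nan])]),
    coeffs,
    initial_state,
)[1][-n_periods:]
mae = jnp.mean(jnp.abs(y_test - yhat_test))  # mean absolute error on the hold-out year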