フリーランチ食べたい

No Free Lunch in ML and Life. Pythonや機械学習のことを書きます。

時系列モデル(ARIMA/Prophet/NNなど)を統一的なAPIで扱えるPythonライブラリ「Darts」がかなり便利

時系列モデルを扱う上でデファクトスタンダードになりそうなPythonライブラリが出てきました。

f:id:mergyi:20200825003540p:plain

時系列モデルを扱うPythonライブラリは、 scikit-learn のようなデファクトスタンダードなものがありません。そのため時系列モデルを用いて実装を行うためには、様々なライブラリのAPIなどの仕様を理解しつつ、それに合わせてデータ整形を行い、評価する必要があり、これはなかなか辛い作業でした。

スイスの企業 Unit8 が今年(2020年)6月末に公開した Darts はまさにこういった課題を解決するライブラリです。時系列に関する様々なモデルを scikit-learn ベースのAPIで統一的に扱うことができます。

github.com

Darts は現在、下記のモデルに対応しています。内側では statsmodelsProphet(stan)Pytorch などを使っていて、 Darts はラッパーライブラリになっています。

  • Exponential smoothing,
  • ARIMA & auto-ARIMA,
  • Facebook Prophet,
  • Theta method,
  • FFT (Fast Fourier Transform),
  • Recurrent neural networks (vanilla RNNs, GRU, and LSTM variants),
  • Temporal convolutional network.

また時系列モデル予測における、欠損値保管といった前処理や、BacktestingやGrid Searchといったモデル選択・評価で用いるユーティリティも兼ね備えており、こちらもかなり便利です。

使ってみる

簡単にウォークスルーをして、 Darts の外観を見ていきます。

インストール

まずインストールは通常通り pip install でできます。

pip install u8darts 

現在バージョンは 0.2.2 です。

print(darts.__version__) # -> '0.2.2'

データ読み込み

今回は公式のExample) でも使われている Air Passengers) のデータを使いたいと思います。Darts を使うためにはまず pandas.DataFrame として読み込みます。

import pandas as pd

df = pd.read_csv('AirPassengers.csv', delimiter=",")
df 
# Month    #Passengers
# 1949-01  112
# 1949-02  118
# 1949-03  132
# 1949-04  129
# 1949-05  121

pandas.DataFrame から darts.TimeSeries で下記のように変換します。直感的ですね。

from darts import TimeSeries
series = TimeSeries.from_dataframe(df, time_col='Month', value_cols='#Passengers')

モデル学習・予測

それではモデルを学習し予測してみます。試しに ExponentialSmoothing を使ってみます。

下記のように scikit-learn のAPIに極めて近い形で書くことができます。注意点としては predict には日付データを渡すのではなく、「trainモデルで最新の日付から未来の何時点まで予測するか」を渡します。

# 1959/01/01以前と以後に分割
train, val = series.split_after(pd.Timestamp('19590101'))

from darts.models import ExponentialSmoothing

# モデル生成
model = ExponentialSmoothing()

# 学習
model.fit(train) 

# 予測 (predictには予測数を入れることに注意)
prediction = model.predict(len(val))

可視化

可視化も pandas.DataFrame.plot と同じように darts.TimeSeries.plot を呼ぶことで簡単に行うことができます。

import matplotlib.pyplot as plt

series.plot(label='actual', lw=3)
prediction.plot(label='forecast', lw=3)
plt.legend()
plt.xlabel('Year');

f:id:mergyi:20200825004220p:plain

ほとんど pandasscikit-learn を使って行うワークフローと同じように書けることが感じていただけたかと思います。

他のモデルとの交換も楽々

Darts の素晴らしさを知るために、他のモデルも試してみます。下記のようにAPIが完全に統一されているので、scikit-learn のように簡単に多数のモデルを比較することが可能です。

from darts.models import (
    NaiveSeasonal,
    NaiveDrift,
    Prophet,
    ExponentialSmoothing,
    ARIMA,
    AutoARIMA,
    StandardRegressionModel,
    Theta,
    FFT
)

for model in (
    NaiveSeasonal,
    NaiveDrift,
    Prophet,
    ExponentialSmoothing,
    ARIMA,
    AutoARIMA,
    # StandardRegressionModel, -> 初期化時にtrain_n_points が必要
    Theta,
    FFT
):
    m = model()
    m.fit(train)
    pred = m.predict(len(val))

Deep Learningを使ったモデルの RNNModelTCNModel に関してはデータの加工が必要なのでご注意ください。

Backtesting

時系列モデルではCross Validationの仕方が一般的なモデルとは異なる時系列を過去から未来に向かって少しずつ予測・評価していくBacktestingという手法が使われますが、これに関しても便利に使える関数が用意されています。下記の例は3つのモデルに対してそれぞれBacktestingして評価する例です。

from darts.backtesting import backtest_forecasting

models = [ExponentialSmoothing, AutoARIMA, Prophet]
backtests = [
    backtest_forecasting(series=series,
                         model=model(),
                         start=pd.Timestamp('19540101'),
                         fcast_horizon_n=3)
    for model in models
]

mapeで精度を測りつつ、可視化してみます。

from darts.metrics import mape

series.plot(label="series")
for i, m in enumerate(models):
    err = mape(backtests[i], series)
    backtests[i].plot(lw=3, label=f'{m.__name__}, MAPE={err:.2f}%')
plt.legend();

f:id:mergyi:20200825004240p:plain

今回のデータとハイパーパラメータ(デフォルト)では ExponentialSmoothing が最も精度が良かったようです。

最後に

簡単な例でしたが、Darts の便利さが伝われば幸いです。また公開されたばかりで、バージョンも0.2.2 ということでまだ成熟したライブラリではないですが、現状でも十分実用性はありますし、今後みんなが使うことで開発も活発になり成熟が期待できるでしょう。また、コントリビューションも積極的に受け入れているようなのでIssueを見て取り組んでいきましょう。

また今回使ったコードは公式のExampleをベースにしていて、こちらに更に様々なAPIが紹介されているので興味ある方はぜひチェックしてみてください。

unit8co.github.io

補足: PyFlux

先日こちらの記事で様々な時系列モデル用のPythonライブラリが紹介されていて、その中にPyFlux) がありました。

zakopilo.hatenablog.jp

PyFluxはARIMAモデルを始めとして確率分布をベースにした様々な時系列モデルを扱える Darts と同様非常にライブラリですが、残念ながらPython3.6以降に対応しておらず、作者の方も「メンテする意思がなく他のライブラリを使って欲しい」と言っている) ので、少なくとも本番環境では使わない方が良いと思われます。