ML Documentation

GRU-D(GRU with Decay)の詳細解説と実装ガイド

1. GRU-Dとは何か

1.1 背景と動機

GRU-D(Gated Recurrent Unit with Decay) は、不規則にサンプリングされた時系列データ、特に欠損値を含むデータを効果的に処理するために設計された深層学習モデルです。2016年にChe et al.によって提案されました。

従来のRNN/LSTM/GRUの問題点:
- 固定時間間隔を前提としている
- 欠損値に対する明示的な処理メカニズムがない
- 観測間の時間間隔の情報を活用できない

GRU-Dの解決策:
- 時間減衰メカニズム: 観測間の時間間隔に基づいて情報を減衰
- 欠損値の明示的処理: マスキングと入力減衰
- 時間情報の活用: 時間間隔を追加の入力として使用

1.2 数学的基礎

標準的なGRUの更新式:

r_t = σ(W_r[x_t, h_{t-1}] + b_r)  # リセットゲート
z_t = σ(W_z[x_t, h_{t-1}] + b_z)  # 更新ゲート
h̃_t = tanh(W_h[x_t, r_t ⊙ h_{t-1}] + b_h)  # 候補隠れ状態
h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t  # 最終隠れ状態

GRU-Dの拡張:

γ_t = exp(-max(0, W_γδ_t + b_γ))  # 時間減衰
h'_{t-1} = γ_t ⊙ h_{t-1}  # 減衰した隠れ状態

x'_t = m_t ⊙ x_t + (1 - m_t) ⊙ (γ_x ⊙ x̂_{t-1} + (1 - γ_x) ⊙ x̄)  # 入力減衰

ここで:
- δ_t: 前回の観測からの経過時間
- m_t: マスキングベクトル(1: 観測あり、0: 欠損)
- x̂_{t-1}: 前回の観測値
- : 訓練データの平均値

2. GRU-Dの詳細実装

2.1 基本的なGRU-Dセル

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GRUDCell(nn.Module):
    """GRU-Dの単一セル実装"""

    def __init__(self, input_size, hidden_size, x_mean=None):
        super(GRUDCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 標準的なGRUパラメータ
        self.W_ir = nn.Linear(input_size, hidden_size, bias=True)
        self.W_hr = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_iz = nn.Linear(input_size, hidden_size, bias=True)
        self.W_hz = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_in = nn.Linear(input_size, hidden_size, bias=True)
        self.W_hn = nn.Linear(hidden_size, hidden_size, bias=False)

        # GRU-D特有のパラメータ
        # 時間減衰パラメータ
        self.W_gamma_h = nn.Linear(1, hidden_size, bias=True)
        self.W_gamma_x = nn.Linear(1, input_size, bias=True)

        # 欠損値処理用の平均値
        if x_mean is not None:
            self.register_buffer('x_mean', torch.tensor(x_mean, dtype=torch.float32))
        else:
            self.register_buffer('x_mean', torch.zeros(input_size))

        self.reset_parameters()

    def reset_parameters(self):
        """パラメータの初期化"""
        std = 1.0 / np.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-std, std)

    def forward(self, x, h_prev, mask, delta_t, x_last_observed):
        """
        Args:
            x: 現在の入力 (batch_size, input_size)
            h_prev: 前の隠れ状態 (batch_size, hidden_size)
            mask: マスク (batch_size, input_size) - 1: 観測あり, 0: 欠損
            delta_t: 前回からの経過時間 (batch_size, 1)
            x_last_observed: 最後に観測された値 (batch_size, input_size)

        Returns:
            h: 新しい隠れ状態 (batch_size, hidden_size)
            (gamma_h, gamma_x): デバッグ用の減衰率
        """
        batch_size = x.size(0)

        # 1. 時間減衰の計算
        gamma_h = torch.exp(-F.relu(self.W_gamma_h(delta_t)))
        gamma_x = torch.exp(-F.relu(self.W_gamma_x(delta_t)))

        # 2. 隠れ状態の減衰
        h_prev = gamma_h * h_prev

        # 3. 入力の減衰と補完
        # 欠損値を前回の観測値と平均値の組み合わせで補完
        x_complement = gamma_x * x_last_observed + (1 - gamma_x) * self.x_mean
        x = mask * x + (1 - mask) * x_complement

        # 4. 標準的なGRUの計算
        # リセットゲート
        r = torch.sigmoid(self.W_ir(x) + self.W_hr(h_prev))

        # 更新ゲート
        z = torch.sigmoid(self.W_iz(x) + self.W_hz(h_prev))

        # 候補隠れ状態
        n = torch.tanh(self.W_in(x) + self.W_hn(r * h_prev))

        # 新しい隠れ状態
        h = (1 - z) * h_prev + z * n

        return h, (gamma_h, gamma_x)

2.2 完全なGRU-Dモデル

class GRUD(nn.Module):
    """完全なGRU-Dモデル"""

    def __init__(self, input_size, hidden_size, output_size, 
                 num_layers=1, dropout=0.0, x_mean=None):
        super(GRUD, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout = dropout

        # GRU-Dセルのスタック
        self.cells = nn.ModuleList()
        for i in range(num_layers):
            if i == 0:
                cell = GRUDCell(input_size, hidden_size, x_mean)
            else:
                cell = GRUDCell(hidden_size, hidden_size)
            self.cells.append(cell)

        # ドロップアウト層
        if dropout > 0:
            self.dropout_layer = nn.Dropout(dropout)

        # 出力層
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x, mask, delta_t, x_last_observed=None, h_0=None):
        """
        Args:
            x: 入力シーケンス (batch_size, seq_len, input_size)
            mask: マスクシーケンス (batch_size, seq_len, input_size)
            delta_t: 時間間隔 (batch_size, seq_len, 1)
            x_last_observed: 最後の観測値 (batch_size, seq_len, input_size)
            h_0: 初期隠れ状態 (num_layers, batch_size, hidden_size)

        Returns:
            output: 出力シーケンス (batch_size, seq_len, output_size)
            h_n: 最終隠れ状態 (num_layers, batch_size, hidden_size)
        """
        batch_size, seq_len, _ = x.size()

        # 初期隠れ状態
        if h_0 is None:
            h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, 
                            device=x.device)

        # 最後の観測値の初期化
        if x_last_observed is None:
            x_last_observed = torch.zeros_like(x)
            # 最初の非欠損値で初期化
            for b in range(batch_size):
                for t in range(seq_len):
                    if mask[b, t].sum() > 0:
                        x_last_observed[b, 0] = x[b, t] * mask[b, t]
                        break

        outputs = []
        h = list(h_0)

        for t in range(seq_len):
            # 現在の入力
            x_t = x[:, t, :]
            mask_t = mask[:, t, :]
            delta_t_current = delta_t[:, t, :]

            # 最後の観測値の更新
            if t > 0:
                x_last_obs = x_last_observed[:, t, :]
            else:
                x_last_obs = x_last_observed[:, 0, :]

            # 各層を通過
            input_t = x_t
            for layer in range(self.num_layers):
                h[layer], _ = self.cells[layer](
                    input_t, h[layer], mask_t if layer == 0 else torch.ones_like(input_t),
                    delta_t_current, x_last_obs if layer == 0 else input_t
                )
                input_t = h[layer]

                # ドロップアウト(最後の層以外)
                if self.dropout > 0 and layer < self.num_layers - 1:
                    input_t = self.dropout_layer(input_t)

            outputs.append(input_t)

        # 出力の生成
        outputs = torch.stack(outputs, dim=1)
        output = self.output_layer(outputs)
        h_n = torch.stack(h)

        return output, h_n

    def predict_next(self, x_last, h_last, delta_t, mask=None):
        """次の時点の予測(推論用)"""
        if mask is None:
            mask = torch.ones_like(x_last)

        # 単一時点の予測
        x = x_last.unsqueeze(1)
        mask = mask.unsqueeze(1)
        delta_t = delta_t.unsqueeze(1)

        output, h_n = self.forward(x, mask, delta_t, x, h_last)

        return output[:, 0, :], h_n

2.3 学習用のヘルパー関数

class GRUDTrainer:
    """GRU-Dモデルの学習を管理するクラス"""

    def __init__(self, model, learning_rate=0.001, weight_decay=1e-5):
        self.model = model
        self.optimizer = torch.optim.Adam(
            model.parameters(), 
            lr=learning_rate, 
            weight_decay=weight_decay
        )
        self.criterion = nn.MSELoss()

    def compute_loss(self, predictions, targets, mask):
        """マスクを考慮した損失計算"""
        # 観測されたデータポイントのみで損失を計算
        masked_predictions = predictions * mask
        masked_targets = targets * mask

        # マスクされた要素数で正規化
        n_observed = mask.sum()
        if n_observed > 0:
            loss = torch.sum((masked_predictions - masked_targets) ** 2) / n_observed
        else:
            loss = torch.tensor(0.0, device=predictions.device)

        return loss

    def train_epoch(self, dataloader):
        """1エポックの学習"""
        self.model.train()
        total_loss = 0.0
        n_batches = 0

        for batch in dataloader:
            x, mask, delta_t, x_last_obs, targets = batch

            # 勾配リセット
            self.optimizer.zero_grad()

            # 予測
            predictions, _ = self.model(x, mask, delta_t, x_last_obs)

            # 損失計算
            loss = self.compute_loss(predictions, targets, mask)

            # バックプロパゲーション
            loss.backward()

            # 勾配クリッピング
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # パラメータ更新
            self.optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        return total_loss / n_batches

3. データ前処理

3.1 不規則時系列データの準備

class IrregularTimeSeriesProcessor:
    """不規則時系列データの前処理"""

    def __init__(self, regular_interval='1min'):
        self.regular_interval = regular_interval

    def process_irregular_data(self, timestamps, values, start_time, end_time):
        """
        不規則データを処理してGRU-D用の形式に変換

        Args:
            timestamps: タイムスタンプのリスト
            values: 値のリスト
            start_time: 開始時刻
            end_time: 終了時刻

        Returns:
            regular_times: 規則的な時刻グリッド
            interpolated_values: 補間された値
            mask: マスク(1: 実測値, 0: 欠損)
            delta_t: 時間間隔
            last_observed_values: 最後の観測値
        """
        # 規則的な時刻グリッドを作成
        regular_times = pd.date_range(
            start=start_time, 
            end=end_time, 
            freq=self.regular_interval
        )

        n_times = len(regular_times)
        n_features = values.shape[1] if len(values.shape) > 1 else 1

        # 初期化
        interpolated_values = np.zeros((n_times, n_features))
        mask = np.zeros((n_times, n_features))
        delta_t = np.zeros((n_times, 1))
        last_observed_values = np.zeros((n_times, n_features))

        # 各時刻での処理
        last_obs_idx = -1
        last_obs_val = np.zeros(n_features)

        for i, t in enumerate(regular_times):
            # 最も近い観測を見つける
            time_diffs = np.abs(timestamps - t)
            closest_idx = np.argmin(time_diffs)

            # 閾値内であれば実測値として使用
            if time_diffs[closest_idx] < pd.Timedelta(self.regular_interval) / 2:
                interpolated_values[i] = values[closest_idx]
                mask[i] = 1
                last_obs_idx = i
                last_obs_val = values[closest_idx]

            # 時間間隔の計算
            if last_obs_idx >= 0:
                delta_t[i] = (i - last_obs_idx) * pd.Timedelta(self.regular_interval).total_seconds() / 60
            else:
                delta_t[i] = 0

            # 最後の観測値
            last_observed_values[i] = last_obs_val

        return regular_times, interpolated_values, mask, delta_t, last_observed_values

3.2 バッチデータの作成

class GRUDDataset(torch.utils.data.Dataset):
    """GRU-D用のデータセット"""

    def __init__(self, data, sequence_length=60, prediction_horizon=5):
        """
        Args:
            data: 処理済みデータ(values, mask, delta_t, last_observed)
            sequence_length: 入力シーケンスの長さ
            prediction_horizon: 予測ホライズン
        """
        self.values = data['values']
        self.mask = data['mask']
        self.delta_t = data['delta_t']
        self.last_observed = data['last_observed']
        self.sequence_length = sequence_length
        self.prediction_horizon = prediction_horizon

        # 有効なインデックスを計算
        self.valid_indices = []
        total_length = len(self.values)
        for i in range(total_length - sequence_length - prediction_horizon + 1):
            # 少なくとも一部のデータが観測されているシーケンスのみ
            if self.mask[i:i+sequence_length].sum() > 0:
                self.valid_indices.append(i)

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        start_idx = self.valid_indices[idx]
        end_idx = start_idx + self.sequence_length
        target_idx = end_idx + self.prediction_horizon - 1

        # 入力シーケンス
        x = torch.FloatTensor(self.values[start_idx:end_idx])
        mask = torch.FloatTensor(self.mask[start_idx:end_idx])
        delta_t = torch.FloatTensor(self.delta_t[start_idx:end_idx])
        last_obs = torch.FloatTensor(self.last_observed[start_idx:end_idx])

        # ターゲット
        if target_idx < len(self.values):
            target = torch.FloatTensor(self.values[target_idx])
        else:
            target = torch.FloatTensor(self.values[-1])

        return x, mask, delta_t, last_obs, target

4. 実践的な使用例

4.1 モデルの学習

def train_grud_model(train_data, val_data, config):
    """GRU-Dモデルの学習"""

    # データの平均値を計算(欠損値補完用)
    x_mean = train_data['values'][train_data['mask'] == 1].mean(axis=0)

    # モデルの初期化
    model = GRUD(
        input_size=config['input_size'],
        hidden_size=config['hidden_size'],
        output_size=config['output_size'],
        num_layers=config['num_layers'],
        dropout=config['dropout'],
        x_mean=x_mean
    )

    # データローダーの作成
    train_dataset = GRUDDataset(train_data, sequence_length=60)
    val_dataset = GRUDDataset(val_data, sequence_length=60)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=32, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=32, shuffle=False
    )

    # トレーナーの初期化
    trainer = GRUDTrainer(model, learning_rate=0.001)

    # 学習ループ
    best_val_loss = float('inf')
    patience = 0

    for epoch in range(100):
        # 学習
        train_loss = trainer.train_epoch(train_loader)

        # 検証
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                x, mask, delta_t, last_obs, targets = batch
                predictions, _ = model(x, mask, delta_t, last_obs)
                loss = trainer.compute_loss(predictions[:, -1, :], targets, mask[:, -1, :])
                val_loss += loss.item()

        val_loss /= len(val_loader)

        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_grud_model.pth')
            patience = 0
        else:
            patience += 1
            if patience >= 10:
                print("Early stopping")
                break

    return model

4.2 リアルタイム予測

class GRUDPredictor:
    """リアルタイム予測のためのクラス"""

    def __init__(self, model, buffer_size=60):
        self.model = model
        self.buffer_size = buffer_size

        # バッファの初期化
        self.value_buffer = deque(maxlen=buffer_size)
        self.mask_buffer = deque(maxlen=buffer_size)
        self.time_buffer = deque(maxlen=buffer_size)
        self.last_observation = None
        self.hidden_state = None

    def update(self, value, timestamp, is_observed=True):
        """新しいデータポイントでバッファを更新"""
        # バッファに追加
        self.value_buffer.append(value)
        self.mask_buffer.append(1.0 if is_observed else 0.0)
        self.time_buffer.append(timestamp)

        # 最後の観測値を更新
        if is_observed:
            self.last_observation = value

    def predict_next(self, future_timestamp):
        """次の時点の予測"""
        if len(self.value_buffer) < 10:  # 最小限のデータが必要
            return None

        # 時間間隔の計算
        if self.last_observation is not None:
            delta_t = (future_timestamp - self.time_buffer[-1]).total_seconds() / 60
        else:
            delta_t = 0

        # テンソルに変換
        x = torch.FloatTensor(list(self.value_buffer)).unsqueeze(0)
        mask = torch.FloatTensor(list(self.mask_buffer)).unsqueeze(0)
        delta_t_tensor = torch.FloatTensor([[delta_t]])

        # 予測
        self.model.eval()
        with torch.no_grad():
            if self.hidden_state is None:
                # 全シーケンスを処理
                output, self.hidden_state = self.model(
                    x.unsqueeze(-1), 
                    mask.unsqueeze(-1), 
                    delta_t_tensor.unsqueeze(1),
                    x.unsqueeze(-1)
                )
                prediction = output[:, -1, :]
            else:
                # 増分更新
                prediction, self.hidden_state = self.model.predict_next(
                    x[:, -1:], 
                    self.hidden_state, 
                    delta_t_tensor,
                    mask[:, -1:]
                )

        return prediction.squeeze().numpy()

5. GRU-Dの利点と制限

5.1 利点

  1. 欠損値の自然な処理: 補間や削除なしに欠損値を直接扱える
  2. 時間情報の活用: 観測間隔の情報を有効活用
  3. メモリ効率: 固定長バッファで不規則データを処理
  4. 解釈可能性: 減衰メカニズムが直感的

5.2 制限

  1. 計算コスト: 標準GRUより計算量が多い
  2. ハイパーパラメータ: 減衰率の調整が必要
  3. 長期依存: 非常に長い時間間隔では性能低下

5.3 改良版

class ImprovedGRUD(nn.Module):
    """改良版GRU-D with attention"""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.grud = GRUD(input_size, hidden_size, hidden_size)

        # Attention mechanism
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            dropout=0.1
        )

        # Output projection
        self.output_projection = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, output_size)
        )

    def forward(self, x, mask, delta_t, x_last_observed):
        # GRU-D encoding
        grud_out, _ = self.grud(x, mask, delta_t, x_last_observed)

        # Self-attention
        attended, _ = self.attention(
            grud_out.transpose(0, 1),
            grud_out.transpose(0, 1),
            grud_out.transpose(0, 1),
            key_padding_mask=(mask.sum(dim=-1) == 0)
        )
        attended = attended.transpose(0, 1)

        # Output
        output = self.output_projection(attended)

        return output

6. まとめ

GRU-Dは不規則時系列データ、特に欠損値を含むデータの処理において強力なツールです。時間減衰メカニズムにより、観測間隔の情報を自然に活用でき、医療データ、金融データ、センサーデータなど様々な分野で応用可能です。

実装時の重要なポイント:
1. データの前処理で適切な時間間隔を計算
2. 欠損値パターンに応じた減衰率の調整
3. 十分な学習データと適切な正則化
4. タスクに応じた出力層の設計