ML Documentation

グラフニューラルネットワークによる暗号資産の価格相関分析

1. はじめに

1.1 なぜグラフニューラルネットワークか

暗号資産市場は複雑な相互依存関係を持つネットワークです:

これらの関係性を捉えるには、従来の時系列モデルでは限界があり、グラフ構造を扱えるGNN(Graph Neural Networks)が有効です。

1.2 GNNの利点

  1. 非ユークリッド構造: 暗号資産間の複雑な関係性を自然に表現
  2. 情報伝播: ネットワーク効果を直接モデル化
  3. 動的グラフ: 時間とともに変化する関係性を捉える
  4. マルチスケール: ローカルとグローバルな相関を同時に学習

2. グラフ構築手法

2.1 相関グラフの構築

import numpy as np
import pandas as pd
import networkx as nx
from scipy.stats import pearsonr, spearmanr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GraphConv
from torch_geometric.data import Data, DataLoader

class CryptoCorrelationGraph:
    """暗号資産の相関グラフ構築"""

    def __init__(self, price_data: pd.DataFrame, window_size: int = 30):
        self.price_data = price_data
        self.window_size = window_size
        self.graph = None

    def build_correlation_graph(self, threshold: float = 0.5, method: str = 'pearson'):
        """相関に基づくグラフ構築"""
        # リターンの計算
        returns = self.price_data.pct_change().dropna()

        # 相関行列の計算
        if method == 'pearson':
            corr_matrix = returns.corr(method='pearson')
        elif method == 'spearman':
            corr_matrix = returns.corr(method='spearman')
        elif method == 'dynamic':
            corr_matrix = self._calculate_dynamic_correlation(returns)

        # グラフの構築
        self.graph = nx.Graph()
        assets = corr_matrix.columns

        # ノードの追加
        for asset in assets:
            self.graph.add_node(asset, 
                              price=self.price_data[asset].iloc[-1],
                              returns_mean=returns[asset].mean(),
                              returns_std=returns[asset].std(),
                              volume=self._get_volume(asset))

        # エッジの追加(閾値以上の相関)
        for i, asset1 in enumerate(assets):
            for j, asset2 in enumerate(assets):
                if i < j:  # 重複を避ける
                    correlation = corr_matrix.loc[asset1, asset2]
                    if abs(correlation) > threshold:
                        self.graph.add_edge(asset1, asset2, 
                                          weight=correlation,
                                          abs_weight=abs(correlation))

        return self.graph

    def _calculate_dynamic_correlation(self, returns, method='dcc'):
        """動的条件付き相関(DCC)の計算"""
        from arch import arch_model

        n_assets = len(returns.columns)
        T = len(returns)

        # 各資産のGARCHモデルをフィット
        standardized_residuals = pd.DataFrame()

        for asset in returns.columns:
            model = arch_model(returns[asset], vol='GARCH', p=1, q=1)
            result = model.fit(disp='off')
            standardized_residuals[asset] = result.resid / result.conditional_volatility

        # DCCパラメータの推定(簡略化版)
        Q_bar = standardized_residuals.corr()
        a, b = 0.01, 0.95  # DCCパラメータ(通常は最尤推定)

        # 動的相関の計算
        Q_t = Q_bar.copy()
        dynamic_corr = []

        for t in range(1, T):
            # DCC更新式
            epsilon_t = standardized_residuals.iloc[t].values.reshape(-1, 1)
            Q_t = (1 - a - b) * Q_bar + a * (epsilon_t @ epsilon_t.T) + b * Q_t

            # 相関行列に変換
            D_t = np.diag(1 / np.sqrt(np.diag(Q_t)))
            R_t = D_t @ Q_t @ D_t
            dynamic_corr.append(R_t)

        # 平均相関を返す
        avg_corr = np.mean(dynamic_corr, axis=0)
        return pd.DataFrame(avg_corr, index=returns.columns, columns=returns.columns)

    def build_transaction_graph(self, transaction_data: pd.DataFrame):
        """トランザクションデータからグラフ構築"""
        self.tx_graph = nx.DiGraph()

        # アドレス間の送金をエッジとして追加
        for _, tx in transaction_data.iterrows():
            from_addr = tx['from_address']
            to_addr = tx['to_address']
            amount = tx['amount']

            # ノードの追加(存在しない場合)
            if from_addr not in self.tx_graph:
                self.tx_graph.add_node(from_addr)
            if to_addr not in self.tx_graph:
                self.tx_graph.add_node(to_addr)

            # エッジの追加または更新
            if self.tx_graph.has_edge(from_addr, to_addr):
                self.tx_graph[from_addr][to_addr]['weight'] += amount
                self.tx_graph[from_addr][to_addr]['count'] += 1
            else:
                self.tx_graph.add_edge(from_addr, to_addr, weight=amount, count=1)

        return self.tx_graph

    def build_defi_protocol_graph(self, protocol_data: dict):
        """DeFiプロトコル間の依存関係グラフ"""
        self.defi_graph = nx.DiGraph()

        # プロトコル間の流動性フローを追加
        for protocol, dependencies in protocol_data.items():
            self.defi_graph.add_node(protocol, 
                                   tvl=dependencies.get('tvl', 0),
                                   type=dependencies.get('type', 'unknown'))

            for dep_protocol, flow_amount in dependencies.get('flows', {}).items():
                self.defi_graph.add_edge(protocol, dep_protocol, 
                                       weight=flow_amount)

        return self.defi_graph

2.2 グラフ特徴量の抽出

class GraphFeatureExtractor:
    """グラフ構造からの特徴量抽出"""

    def __init__(self, graph: nx.Graph):
        self.graph = graph

    def extract_node_features(self, node):
        """ノード特徴量の抽出"""
        features = []

        # 基本的なグラフ統計量
        features.append(self.graph.degree(node))
        features.append(nx.clustering(self.graph, node))
        features.append(nx.closeness_centrality(self.graph)[node])
        features.append(nx.betweenness_centrality(self.graph)[node])

        # 加重特徴量
        weighted_degree = sum([self.graph[node][neighbor]['abs_weight'] 
                              for neighbor in self.graph.neighbors(node)])
        features.append(weighted_degree)

        # 近傍の統計量
        neighbor_degrees = [self.graph.degree(n) for n in self.graph.neighbors(node)]
        if neighbor_degrees:
            features.extend([
                np.mean(neighbor_degrees),
                np.std(neighbor_degrees),
                np.max(neighbor_degrees),
                np.min(neighbor_degrees)
            ])
        else:
            features.extend([0, 0, 0, 0])

        # PageRank
        pagerank = nx.pagerank(self.graph, weight='abs_weight')
        features.append(pagerank.get(node, 0))

        # コミュニティ検出
        communities = nx.community.greedy_modularity_communities(self.graph)
        for i, community in enumerate(communities):
            if node in community:
                features.append(i)
                features.append(len(community))
                break
        else:
            features.extend([0, 1])

        return np.array(features)

    def extract_edge_features(self, node1, node2):
        """エッジ特徴量の抽出"""
        if not self.graph.has_edge(node1, node2):
            return np.zeros(10)  # デフォルト特徴量

        features = []

        # 基本的なエッジ属性
        edge_data = self.graph[node1][node2]
        features.append(edge_data.get('weight', 0))
        features.append(edge_data.get('abs_weight', 0))

        # 共通近傍
        common_neighbors = len(list(nx.common_neighbors(self.graph, node1, node2)))
        features.append(common_neighbors)

        # Adamic-Adar指標
        aa_index = sum([1 / np.log(self.graph.degree(n)) 
                       for n in nx.common_neighbors(self.graph, node1, node2)])
        features.append(aa_index)

        # Jaccard係数
        neighbors1 = set(self.graph.neighbors(node1))
        neighbors2 = set(self.graph.neighbors(node2))
        if neighbors1 or neighbors2:
            jaccard = len(neighbors1 & neighbors2) / len(neighbors1 | neighbors2)
        else:
            jaccard = 0
        features.append(jaccard)

        # 最短経路長(エッジを除いて)
        temp_graph = self.graph.copy()
        temp_graph.remove_edge(node1, node2)
        try:
            shortest_path_length = nx.shortest_path_length(temp_graph, node1, node2)
        except nx.NetworkXNoPath:
            shortest_path_length = float('inf')
        features.append(1 / (1 + shortest_path_length))

        return np.array(features)

3. GNNモデルの実装

3.1 基本的なGCN(Graph Convolutional Network)

class CryptoGCN(nn.Module):
    """暗号資産価格予測用GCN"""

    def __init__(self, num_features: int, hidden_dim: int = 64, 
                 output_dim: int = 1, num_layers: int = 3):
        super(CryptoGCN, self).__init__()

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        # 入力層
        self.convs.append(GCNConv(num_features, hidden_dim))
        self.bns.append(nn.BatchNorm1d(hidden_dim))

        # 隠れ層
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        # 出力層
        self.convs.append(GCNConv(hidden_dim, output_dim))

        # 追加の全結合層
        self.fc = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x, edge_index, edge_weight=None):
        # グラフ畳み込み
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index, edge_weight)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)

        # 最終層
        x = self.convs[-1](x, edge_index, edge_weight)

        # 追加の変換
        x = self.fc(x)

        return x

3.2 グラフアテンションネットワーク(GAT)

class CryptoGAT(nn.Module):
    """注意機構を持つグラフニューラルネットワーク"""

    def __init__(self, num_features: int, hidden_dim: int = 64,
                 output_dim: int = 1, num_heads: int = 8):
        super(CryptoGAT, self).__init__()

        # マルチヘッドアテンション層
        self.conv1 = GATConv(num_features, hidden_dim, heads=num_heads, 
                            dropout=0.6, concat=True)
        self.conv2 = GATConv(hidden_dim * num_heads, hidden_dim, 
                            heads=num_heads, dropout=0.6, concat=True)
        self.conv3 = GATConv(hidden_dim * num_heads, output_dim, 
                            heads=1, concat=False, dropout=0.6)

        # 時系列特徴量との結合層
        self.temporal_encoder = nn.LSTM(
            input_size=num_features,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.5
        )

        # 最終予測層
        self.predictor = nn.Sequential(
            nn.Linear(output_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x, edge_index, temporal_features=None):
        # グラフアテンション
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # 時系列特徴量の処理
        if temporal_features is not None:
            _, (h_n, _) = self.temporal_encoder(temporal_features)
            temporal_encoding = h_n[-1]  # 最後の隠れ状態

            # グラフ特徴量と時系列特徴量の結合
            combined = torch.cat([x, temporal_encoding.unsqueeze(0).expand(x.size(0), -1)], dim=1)
            x = self.predictor(combined)

        return x

3.3 時空間グラフニューラルネットワーク

class SpatioTemporalGNN(nn.Module):
    """時空間グラフニューラルネットワーク"""

    def __init__(self, num_features: int, hidden_dim: int = 64,
                 temporal_length: int = 24, num_nodes: int = 100):
        super(SpatioTemporalGNN, self).__init__()

        self.temporal_length = temporal_length
        self.num_nodes = num_nodes

        # 空間的グラフ畳み込み
        self.spatial_conv1 = GraphConv(num_features, hidden_dim)
        self.spatial_conv2 = GraphConv(hidden_dim, hidden_dim)

        # 時間的畳み込み
        self.temporal_conv = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=3,
            padding=1
        )

        # ゲート機構
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

        # 出力層
        self.output_layer = nn.Linear(hidden_dim, 1)

    def forward(self, x_seq, edge_index_seq):
        """
        Args:
            x_seq: (batch_size, temporal_length, num_nodes, num_features)
            edge_index_seq: 各時点でのエッジインデックス
        """
        batch_size, T, N, F = x_seq.shape

        # 各時点で空間的畳み込み
        spatial_features = []
        for t in range(T):
            x_t = x_seq[:, t, :, :].reshape(batch_size * N, F)
            edge_index_t = edge_index_seq[t]

            # 空間的特徴抽出
            h = F.relu(self.spatial_conv1(x_t, edge_index_t))
            h = self.spatial_conv2(h, edge_index_t)

            spatial_features.append(h.reshape(batch_size, N, -1))

        # (batch_size, num_nodes, hidden_dim, temporal_length)
        spatial_features = torch.stack(spatial_features, dim=3)

        # 時間的畳み込み
        temporal_features = []
        for n in range(N):
            node_features = spatial_features[:, n, :, :]  # (batch, hidden, T)
            temporal_out = self.temporal_conv(node_features)
            temporal_features.append(temporal_out)

        temporal_features = torch.stack(temporal_features, dim=1)

        # ゲート機構で空間・時間特徴を統合
        combined = torch.cat([
            spatial_features[:, :, :, -1],  # 最新の空間特徴
            temporal_features[:, :, :, -1]  # 最新の時間特徴
        ], dim=2)

        gate_values = self.gate(combined)
        gated_features = gate_values * spatial_features[:, :, :, -1] + \
                        (1 - gate_values) * temporal_features[:, :, :, -1]

        # 最終予測
        output = self.output_layer(gated_features)

        return output

4. 動的グラフ学習

4.1 動的グラフの構築と更新

class DynamicCryptoGraph:
    """時間とともに変化するグラフ構造"""

    def __init__(self, initial_graph: nx.Graph, update_interval: int = 3600):
        self.graphs = [initial_graph]
        self.update_interval = update_interval
        self.current_time = 0

    def update_graph(self, new_data: dict, timestamp: int):
        """グラフの動的更新"""
        # 最新のグラフをコピー
        current_graph = self.graphs[-1].copy()

        # 価格相関の更新
        if 'price_updates' in new_data:
            self._update_correlations(current_graph, new_data['price_updates'])

        # 新しいトランザクションの追加
        if 'transactions' in new_data:
            self._add_transactions(current_graph, new_data['transactions'])

        # エッジの減衰(古い関係性を弱める)
        self._decay_edges(current_graph, timestamp)

        # 新しいノードの追加(新規上場など)
        if 'new_assets' in new_data:
            self._add_new_nodes(current_graph, new_data['new_assets'])

        self.graphs.append(current_graph)
        self.current_time = timestamp

    def _update_correlations(self, graph, price_updates):
        """価格相関の動的更新"""
        # 指数移動平均で相関を更新
        alpha = 0.1  # 学習率

        for edge in graph.edges():
            node1, node2 = edge
            if node1 in price_updates and node2 in price_updates:
                # 新しい相関の計算
                new_corr = np.corrcoef(
                    price_updates[node1][-30:],
                    price_updates[node2][-30:]
                )[0, 1]

                # 既存の相関との加重平均
                old_corr = graph[node1][node2].get('weight', 0)
                updated_corr = alpha * new_corr + (1 - alpha) * old_corr

                graph[node1][node2]['weight'] = updated_corr
                graph[node1][node2]['abs_weight'] = abs(updated_corr)

    def _decay_edges(self, graph, current_timestamp):
        """時間経過によるエッジの減衰"""
        decay_rate = 0.99  # 1時間ごとの減衰率
        hours_passed = (current_timestamp - self.current_time) / 3600

        edges_to_remove = []
        for edge in graph.edges():
            # 重みを減衰
            graph[edge[0]][edge[1]]['weight'] *= (decay_rate ** hours_passed)
            graph[edge[0]][edge[1]]['abs_weight'] *= (decay_rate ** hours_passed)

            # 閾値以下になったエッジを削除
            if graph[edge[0]][edge[1]]['abs_weight'] < 0.1:
                edges_to_remove.append(edge)

        graph.remove_edges_from(edges_to_remove)

4.2 時間認識GNN

class TimeAwareGNN(nn.Module):
    """時間情報を考慮したGNN"""

    def __init__(self, num_features: int, hidden_dim: int = 64):
        super(TimeAwareGNN, self).__init__()

        # 時間エンコーディング
        self.time_encoder = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # グラフ畳み込み(時間条件付き)
        self.conv1 = GCNConv(num_features + hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        # 時間ゲート
        self.time_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

        # 出力層
        self.output = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, edge_time, current_time):
        """
        Args:
            x: ノード特徴量
            edge_index: エッジインデックス
            edge_time: 各エッジの生成時刻
            current_time: 現在時刻
        """
        # 時間差の計算
        time_diff = current_time - edge_time
        time_encoding = self.time_encoder(time_diff.unsqueeze(-1))

        # ノード特徴量に時間エンコーディングを結合
        x_with_time = torch.cat([x, time_encoding.mean(dim=0).expand(x.size(0), -1)], dim=1)

        # グラフ畳み込み
        h = F.relu(self.conv1(x_with_time, edge_index))
        h = self.conv2(h, edge_index)

        # 時間ゲートによる調整
        gate = self.time_gate(torch.cat([h, time_encoding.mean(dim=0).expand(h.size(0), -1)], dim=1))
        h = gate * h

        # 予測
        out = self.output(h)

        return out

5. クロスチェーン分析

5.1 マルチチェーングラフの構築

class CrossChainGraph:
    """クロスチェーン相互作用のグラフ"""

    def __init__(self):
        self.multi_graph = nx.MultiDiGraph()
        self.chain_metadata = {}

    def add_chain(self, chain_name: str, metadata: dict):
        """ブロックチェーンの追加"""
        self.chain_metadata[chain_name] = metadata

    def add_cross_chain_transaction(self, tx_data: dict):
        """クロスチェーントランザクションの追加"""
        source_chain = tx_data['source_chain']
        target_chain = tx_data['target_chain']
        asset = tx_data['asset']
        amount = tx_data['amount']

        # マルチグラフにエッジを追加
        self.multi_graph.add_edge(
            f"{source_chain}:{asset}",
            f"{target_chain}:{asset}",
            amount=amount,
            timestamp=tx_data['timestamp'],
            tx_hash=tx_data['tx_hash'],
            bridge=tx_data.get('bridge', 'unknown')
        )

    def analyze_bridge_flows(self, time_window: int = 86400):
        """ブリッジ経由のフロー分析"""
        current_time = time.time()
        bridge_flows = {}

        for edge in self.multi_graph.edges(data=True):
            source, target, data = edge
            if current_time - data['timestamp'] <= time_window:
                bridge = data['bridge']
                if bridge not in bridge_flows:
                    bridge_flows[bridge] = {
                        'total_volume': 0,
                        'transaction_count': 0,
                        'unique_assets': set(),
                        'chain_pairs': set()
                    }

                bridge_flows[bridge]['total_volume'] += data['amount']
                bridge_flows[bridge]['transaction_count'] += 1
                bridge_flows[bridge]['unique_assets'].add(source.split(':')[1])
                bridge_flows[bridge]['chain_pairs'].add((source.split(':')[0], 
                                                         target.split(':')[0]))

        return bridge_flows

    def detect_arbitrage_opportunities(self):
        """アービトラージ機会の検出"""
        opportunities = []

        # 各アセットについて
        assets = set()
        for node in self.multi_graph.nodes():
            assets.add(node.split(':')[1])

        for asset in assets:
            # 各チェーンでの価格を取得
            chain_prices = {}
            for chain in self.chain_metadata.keys():
                node = f"{chain}:{asset}"
                if node in self.multi_graph.nodes():
                    # 実際の実装では外部APIから価格を取得
                    chain_prices[chain] = self._get_asset_price(chain, asset)

            # 価格差を計算
            if len(chain_prices) >= 2:
                max_chain = max(chain_prices, key=chain_prices.get)
                min_chain = min(chain_prices, key=chain_prices.get)

                price_diff_pct = (chain_prices[max_chain] - chain_prices[min_chain]) / \
                                chain_prices[min_chain] * 100

                if price_diff_pct > 0.5:  # 0.5%以上の価格差
                    opportunities.append({
                        'asset': asset,
                        'buy_chain': min_chain,
                        'sell_chain': max_chain,
                        'price_diff_pct': price_diff_pct,
                        'estimated_profit': self._estimate_profit(
                            asset, min_chain, max_chain, price_diff_pct
                        )
                    })

        return opportunities

5.2 クロスチェーンGNNモデル

class CrossChainGNN(nn.Module):
    """クロスチェーン相互作用を考慮したGNN"""

    def __init__(self, num_chains: int, num_assets: int, 
                 chain_embedding_dim: int = 32, asset_embedding_dim: int = 64):
        super(CrossChainGNN, self).__init__()

        # チェーンとアセットのエンベディング
        self.chain_embedding = nn.Embedding(num_chains, chain_embedding_dim)
        self.asset_embedding = nn.Embedding(num_assets, asset_embedding_dim)

        # インターチェーン相互作用層
        self.inter_chain_attention = nn.MultiheadAttention(
            embed_dim=chain_embedding_dim + asset_embedding_dim,
            num_heads=8,
            dropout=0.1
        )

        # グラフ畳み込み層
        feature_dim = chain_embedding_dim + asset_embedding_dim
        self.conv1 = GATConv(feature_dim, 128, heads=4, concat=True)
        self.conv2 = GATConv(128 * 4, 128, heads=4, concat=True)
        self.conv3 = GATConv(128 * 4, 64, heads=1, concat=False)

        # 予測層
        self.predictor = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, 1)
        )

    def forward(self, chain_ids, asset_ids, edge_index, bridge_features=None):
        # エンベディングの取得
        chain_emb = self.chain_embedding(chain_ids)
        asset_emb = self.asset_embedding(asset_ids)

        # 特徴量の結合
        node_features = torch.cat([chain_emb, asset_emb], dim=1)

        # クロスチェーン注意機構
        if bridge_features is not None:
            attended_features, _ = self.inter_chain_attention(
                node_features.unsqueeze(0),
                node_features.unsqueeze(0),
                node_features.unsqueeze(0)
            )
            node_features = node_features + attended_features.squeeze(0)

        # グラフ畳み込み
        x = F.dropout(node_features, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # 予測
        output = self.predictor(x)

        return output

6. 実装例:統合システム

6.1 エンドツーエンドのGNN予測システム

class CryptoGNNPredictionSystem:
    """暗号資産のGNN予測システム"""

    def __init__(self, config: dict):
        self.config = config

        # グラフ構築器
        self.graph_builder = CryptoCorrelationGraph(
            price_data=self._load_price_data(),
            window_size=config['correlation_window']
        )

        # 動的グラフ管理
        self.dynamic_graph = DynamicCryptoGraph(
            initial_graph=self.graph_builder.build_correlation_graph(),
            update_interval=config['update_interval']
        )

        # モデル
        self.model = self._build_model()

        # 特徴抽出器
        self.feature_extractor = GraphFeatureExtractor(self.dynamic_graph.graphs[-1])

    def _build_model(self):
        """モデルの構築"""
        if self.config['model_type'] == 'gcn':
            return CryptoGCN(
                num_features=self.config['num_features'],
                hidden_dim=self.config['hidden_dim'],
                output_dim=1,
                num_layers=self.config['num_layers']
            )
        elif self.config['model_type'] == 'gat':
            return CryptoGAT(
                num_features=self.config['num_features'],
                hidden_dim=self.config['hidden_dim'],
                output_dim=1,
                num_heads=self.config['num_heads']
            )
        elif self.config['model_type'] == 'stgnn':
            return SpatioTemporalGNN(
                num_features=self.config['num_features'],
                hidden_dim=self.config['hidden_dim'],
                temporal_length=self.config['temporal_length'],
                num_nodes=self.config['num_nodes']
            )

    def prepare_graph_data(self, timestamp: int):
        """グラフデータの準備"""
        current_graph = self.dynamic_graph.graphs[-1]

        # ノード特徴量
        node_features = []
        node_mapping = {}

        for i, node in enumerate(current_graph.nodes()):
            features = self.feature_extractor.extract_node_features(node)
            node_features.append(features)
            node_mapping[node] = i

        # エッジインデックス
        edge_index = []
        edge_attr = []

        for edge in current_graph.edges():
            source_idx = node_mapping[edge[0]]
            target_idx = node_mapping[edge[1]]
            edge_index.append([source_idx, target_idx])
            edge_index.append([target_idx, source_idx])  # 無向グラフ

            edge_features = self.feature_extractor.extract_edge_features(edge[0], edge[1])
            edge_attr.append(edge_features)
            edge_attr.append(edge_features)  # 両方向で同じ特徴量

        # PyTorch Geometric形式に変換
        x = torch.FloatTensor(node_features)
        edge_index = torch.LongTensor(edge_index).t().contiguous()
        edge_attr = torch.FloatTensor(edge_attr)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

        return data, node_mapping

    def train(self, train_data: list, val_data: list, epochs: int = 100):
        """モデルの訓練"""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        criterion = nn.MSELoss()

        best_val_loss = float('inf')

        for epoch in range(epochs):
            # 訓練
            self.model.train()
            train_loss = 0

            for batch in train_data:
                optimizer.zero_grad()

                # グラフデータの準備
                graph_data, node_mapping = self.prepare_graph_data(batch['timestamp'])

                # 予測
                predictions = self.model(graph_data.x, graph_data.edge_index)

                # ターゲットの準備
                targets = self._prepare_targets(batch, node_mapping)

                loss = criterion(predictions, targets)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

            # 検証
            self.model.eval()
            val_loss = 0

            with torch.no_grad():
                for batch in val_data:
                    graph_data, node_mapping = self.prepare_graph_data(batch['timestamp'])
                    predictions = self.model(graph_data.x, graph_data.edge_index)
                    targets = self._prepare_targets(batch, node_mapping)
                    val_loss += criterion(predictions, targets).item()

            # ログ出力
            print(f"Epoch {epoch}: Train Loss = {train_loss/len(train_data):.4f}, "
                  f"Val Loss = {val_loss/len(val_data):.4f}")

            # ベストモデルの保存
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), 'best_gnn_model.pth')

    def predict(self, current_timestamp: int):
        """予測の実行"""
        self.model.eval()

        # 最新のグラフデータを準備
        graph_data, node_mapping = self.prepare_graph_data(current_timestamp)

        with torch.no_grad():
            predictions = self.model(graph_data.x, graph_data.edge_index)

        # 予測結果をアセット名にマッピング
        results = {}
        for asset, idx in node_mapping.items():
            results[asset] = {
                'predicted_return': predictions[idx].item(),
                'confidence': self._calculate_prediction_confidence(
                    graph_data, idx, predictions[idx]
                )
            }

        return results

    def _calculate_prediction_confidence(self, graph_data, node_idx, prediction):
        """予測の信頼度計算"""
        # ノードの中心性に基づく信頼度
        node_degree = (graph_data.edge_index[0] == node_idx).sum().item()
        max_degree = graph_data.edge_index.shape[1] / graph_data.x.shape[0]

        degree_confidence = min(node_degree / max_degree, 1.0)

        # 予測値の大きさに基づく信頼度調整
        prediction_magnitude = abs(prediction.item())
        magnitude_confidence = 1 / (1 + np.exp(-prediction_magnitude))

        # 総合的な信頼度
        confidence = 0.7 * degree_confidence + 0.3 * magnitude_confidence

        return confidence

6.2 バックテストとパフォーマンス評価

class GNNBacktester:
    """GNNモデルのバックテスター"""

    def __init__(self, model, graph_builder):
        self.model = model
        self.graph_builder = graph_builder

    def backtest(self, historical_data: pd.DataFrame, 
                 start_date: str, end_date: str,
                 rebalance_frequency: str = 'daily'):
        """バックテストの実行"""
        results = {
            'dates': [],
            'portfolio_value': [],
            'positions': [],
            'trades': []
        }

        # 初期資本
        capital = 10000
        positions = {}

        # 時系列でループ
        dates = pd.date_range(start=start_date, end=end_date, freq=rebalance_frequency)

        for date in dates:
            # その時点でのグラフを構築
            current_data = historical_data[historical_data.index <= date]
            graph = self.graph_builder.build_correlation_graph(
                current_data.tail(30)  # 直近30日のデータ
            )

            # 予測を実行
            predictions = self._get_predictions(graph, date)

            # ポートフォリオの最適化
            target_positions = self._optimize_portfolio(
                predictions, capital, positions
            )

            # 取引の実行
            trades = self._execute_trades(
                positions, target_positions, 
                current_data.iloc[-1]
            )

            # ポートフォリオ価値の計算
            portfolio_value = self._calculate_portfolio_value(
                positions, current_data.iloc[-1]
            )

            # 結果の記録
            results['dates'].append(date)
            results['portfolio_value'].append(portfolio_value)
            results['positions'].append(positions.copy())
            results['trades'].append(trades)

        return self._analyze_results(results)

    def _optimize_portfolio(self, predictions, capital, current_positions):
        """予測に基づくポートフォリオ最適化"""
        # 予測リターンでソート
        sorted_assets = sorted(
            predictions.items(), 
            key=lambda x: x[1]['predicted_return'] * x[1]['confidence'],
            reverse=True
        )

        # 上位N個のアセットに均等配分(簡略化)
        n_assets = min(10, len(sorted_assets))
        target_positions = {}

        for i in range(n_assets):
            asset = sorted_assets[i][0]
            weight = 1.0 / n_assets
            target_positions[asset] = weight * capital

        return target_positions

    def _analyze_results(self, results):
        """バックテスト結果の分析"""
        df = pd.DataFrame({
            'date': results['dates'],
            'portfolio_value': results['portfolio_value']
        })

        # リターンの計算
        df['returns'] = df['portfolio_value'].pct_change()

        # パフォーマンス指標
        total_return = (df['portfolio_value'].iloc[-1] - df['portfolio_value'].iloc[0]) / \
                      df['portfolio_value'].iloc[0]

        sharpe_ratio = df['returns'].mean() / df['returns'].std() * np.sqrt(252)

        max_drawdown = (df['portfolio_value'].cummax() - df['portfolio_value']).max() / \
                      df['portfolio_value'].cummax().max()

        # 取引統計
        total_trades = sum(len(trades) for trades in results['trades'])

        return {
            'total_return': total_return,
            'sharpe_ratio': sharpe_ratio,
            'max_drawdown': max_drawdown,
            'total_trades': total_trades,
            'portfolio_history': df
        }

7. 最適化とスケーラビリティ

7.1 グラフサンプリング

class GraphSampler:
    """大規模グラフのサンプリング"""

    def __init__(self, graph: nx.Graph):
        self.graph = graph

    def neighbor_sampling(self, target_nodes: list, num_samples: list):
        """近傍サンプリング(GraphSAGE方式)"""
        sampled_nodes = set(target_nodes)
        current_layer = target_nodes

        for num_sample in num_samples:
            next_layer = []
            for node in current_layer:
                neighbors = list(self.graph.neighbors(node))
                if len(neighbors) > num_sample:
                    sampled = np.random.choice(neighbors, num_sample, replace=False)
                else:
                    sampled = neighbors
                next_layer.extend(sampled)
                sampled_nodes.update(sampled)

            current_layer = list(set(next_layer))

        # サブグラフの抽出
        subgraph = self.graph.subgraph(sampled_nodes)

        return subgraph

    def importance_sampling(self, num_nodes: int):
        """重要度に基づくサンプリング"""
        # PageRankで重要度を計算
        pagerank = nx.pagerank(self.graph)

        # 重要度に比例した確率でサンプリング
        nodes = list(self.graph.nodes())
        probs = [pagerank[node] for node in nodes]
        probs = np.array(probs) / sum(probs)

        sampled_nodes = np.random.choice(
            nodes, size=num_nodes, replace=False, p=probs
        )

        return self.graph.subgraph(sampled_nodes)

7.2 分散処理

import torch.distributed as dist
from torch_geometric.data import DataLoader
from torch_geometric.nn import DataParallel

class DistributedGNNTrainer:
    """分散GNN訓練"""

    def __init__(self, model, world_size: int, rank: int):
        self.model = model
        self.world_size = world_size
        self.rank = rank

        # 分散処理の初期化
        dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)

        # モデルを分散化
        self.model = DataParallel(model)
        self.model = self.model.to(f'cuda:{rank}')

    def train_epoch(self, data_loader):
        """分散訓練の1エポック"""
        self.model.train()

        for batch in data_loader:
            # 各GPUにデータを分配
            batch = batch.to(f'cuda:{self.rank}')

            # フォワードパス
            out = self.model(batch.x, batch.edge_index)
            loss = F.mse_loss(out, batch.y)

            # バックワードパス
            loss.backward()

            # 勾配の同期
            for param in self.model.parameters():
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                param.grad.data /= self.world_size

            # パラメータ更新
            self.optimizer.step()
            self.optimizer.zero_grad()

8. まとめとベストプラクティス

8.1 GNN適用のガイドライン

  1. グラフ構築
    - 相関閾値は0.3-0.5が適切
    - 動的グラフは1時間ごとに更新
    - エッジの減衰率は市場の性質に応じて調整

  2. モデル選択
    - 小規模(<100ノード): GCN
    - 中規模(100-1000ノード): GAT
    - 大規模(>1000ノード): GraphSAGE with sampling

  3. 特徴量エンジニアリング
    - グラフ統計量は必須
    - 時系列特徴量との組み合わせが効果的
    - クロスチェーンデータの統合で精度向上

  4. 最適化
    - バッチ正規化とドロップアウトは必須
    - 学習率は0.001から開始
    - Early stoppingで過学習を防ぐ

8.2 実装上の注意点

# 実装チェックリスト
def gnn_implementation_checklist():
    checklist = {
        "データ準備": [
            "価格データの正規化",
            "欠損値の処理",
            "外れ値の除去"
        ],
        "グラフ構築": [
            "適切な相関指標の選択",
            "エッジ重みの正規化",
            "孤立ノードの処理"
        ],
        "モデル設計": [
            "適切な層数(3-5層)",
            "過学習対策",
            "スキップ接続の検討"
        ],
        "訓練": [
            "学習率スケジューリング",
            "勾配クリッピング",
            "検証データでの評価"
        ],
        "運用": [
            "リアルタイム更新の実装",
            "異常検知の組み込み",
            "モデルの定期的な再学習"
        ]
    }
    return checklist

GNNは暗号資産の複雑な相互依存関係を捉える強力なツールです。適切な実装により、従来の手法では見逃されていたパターンを発見し、より精度の高い予測が可能になります。