ML Documentation

Real-Time Data Processing and Streaming Machine Learning

Overview

Cryptocurrency markets trade around the clock, 365 days a year, so processing data and applying machine learning in real time is essential. This document walks through an implementation that combines streaming data processing with online learning.

1. Architecture Overview

1.1 Data Flow Architecture

[Exchange API/WebSocket] → [Data Collection] → [Stream Processing] → [ML Layer] → [Action Layer]
          ↓                       ↓                    ↓                  ↓             ↓
  [Binance/Coinbase]        [Kafka/Redis]         [Spark/Flink]      [Online ML]  [Trading Bot]

1.2 Key Components

The layers in the diagram map onto the sections that follow:

  1. Data collection layer: exchange WebSocket clients (Section 2)
  2. Stream processing layer: Kafka pipelines and window aggregation (Section 3)
  3. ML layer: online and ensemble learning models (Section 4)
  4. Quality control: data monitoring and automatic recovery (Section 5)
  5. Action layer: trading logic driven by model predictions (Section 7)

2. Implementing the Data Collection Layer

2.1 WebSocket Connection

import asyncio
import websockets
import json
from datetime import datetime

class CryptoWebSocketClient:
    def __init__(self, exchange, symbols):
        self.exchange = exchange
        self.symbols = symbols
        self.callbacks = []

    async def connect_binance(self):
        """Binance WebSocket接続"""
        streams = [f"{symbol.lower()}@ticker" for symbol in self.symbols]
        url = f"wss://stream.binance.com:9443/stream?streams={'/'.join(streams)}"

        async with websockets.connect(url) as websocket:
            while True:
                try:
                    data = await websocket.recv()
                    parsed_data = self.parse_binance_data(json.loads(data))
                    await self.process_data(parsed_data)
                except websockets.ConnectionClosed:
                    # The socket is gone; re-raise so the caller can reconnect
                    # (see the recovery mechanism in Section 5.2)
                    raise
                except Exception as e:
                    print(f"Error: {e}")
                    await asyncio.sleep(5)

    def parse_binance_data(self, raw_data):
        """データパース"""
        if 'data' in raw_data:
            ticker = raw_data['data']
            return {
                'exchange': 'binance',
                'symbol': ticker['s'],
                'price': float(ticker['c']),
                'volume': float(ticker['v']),
                'timestamp': datetime.fromtimestamp(ticker['E'] / 1000),
                'bid': float(ticker['b']),
                'ask': float(ticker['a']),
                'price_change_24h': float(ticker['p'])
            }
        return None

    async def process_data(self, data):
        """データ処理コールバック実行"""
        if data:
            for callback in self.callbacks:
                await callback(data)
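
A minimal usage sketch (the symbols and the print callback here are illustrative):

async def print_ticker(data):
    print(f"{data['symbol']}: {data['price']}")

async def main():
    client = CryptoWebSocketClient('binance', ['BTCUSDT', 'ETHUSDT'])
    client.callbacks.append(print_ticker)
    await client.connect_binance()

# asyncio.run(main())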

2.2 Multi-Exchange Support

class MultiExchangeAggregator:
    def __init__(self):
        self.exchanges = {}
        self.aggregated_data = {}

    async def add_exchange(self, name, client):
        """取引所の追加"""
        self.exchanges[name] = client
        client.callbacks.append(self.aggregate_data)

    async def aggregate_data(self, data):
        """複数取引所のデータ集約"""
        key = f"{data['symbol']}_{data['exchange']}"
        self.aggregated_data[key] = data

        # Detect arbitrage opportunities
        await self.check_arbitrage(data['symbol'])

    async def check_arbitrage(self, symbol):
        """取引所間の価格差検出"""
        prices = []
        for exchange in self.exchanges:
            key = f"{symbol}_{exchange}"
            if key in self.aggregated_data:
                prices.append({
                    'exchange': exchange,
                    'bid': self.aggregated_data[key]['bid'],
                    'ask': self.aggregated_data[key]['ask']
                })

        if len(prices) >= 2:
            # Compute the spread between the highest bid and the lowest ask
            max_bid = max(prices, key=lambda x: x['bid'])
            min_ask = min(prices, key=lambda x: x['ask'])

            spread_pct = ((max_bid['bid'] - min_ask['ask']) / min_ask['ask']) * 100

            if spread_pct > 0.1:  # Notify when the spread exceeds 0.1% (fees not considered)
                await self.notify_arbitrage({
                    'symbol': symbol,
                    'buy_exchange': min_ask['exchange'],
                    'sell_exchange': max_bid['exchange'],
                    'profit_pct': spread_pct
                })

    async def notify_arbitrage(self, opportunity):
        """Minimal notification hook; swap in real alerting (Slack, email, ...) as needed"""
        print(f"Arbitrage opportunity: {opportunity}")
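
Wiring the aggregator to clients looks like the sketch below; coinbase_client is hypothetical here, since only the Binance client is implemented above:

async def setup_aggregator(binance_client, coinbase_client):
    aggregator = MultiExchangeAggregator()
    await aggregator.add_exchange('binance', binance_client)
    await aggregator.add_exchange('coinbase', coinbase_client)
    # From now on every ticker update also triggers check_arbitrage()
    return aggregator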

3. Stream Processing Layer

3.1 A Data Pipeline Using Apache Kafka

from kafka import KafkaProducer, KafkaConsumer
import json

class KafkaStreamProcessor:
    def __init__(self, bootstrap_servers):
        self.bootstrap_servers = bootstrap_servers
        self.producer = KafkaProducer(
            bootstrap_servers=bootstrap_servers,
            value_serializer=lambda v: json.dumps(v).encode('utf-8')
        )

    def publish_market_data(self, topic, data):
        """Publish market data (kafka-python's send() is already non-blocking)"""
        self.producer.send(topic, value=data)

    def create_consumer(self, topics, group_id):
        """コンシューマーの作成"""
        return KafkaConsumer(
            *topics,
            bootstrap_servers=self.bootstrap_servers,
            group_id=group_id,
            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
            auto_offset_reset='latest'
        )
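
A usage sketch, assuming a broker is reachable at localhost:9092 (the topic and group names are illustrative):

processor = KafkaStreamProcessor(bootstrap_servers='localhost:9092')
processor.publish_market_data('market-ticks', {'symbol': 'BTCUSDT', 'price': 65000.0})

consumer = processor.create_consumer(['market-ticks'], group_id='ml-workers')
for message in consumer:
    print(message.value)  # already deserialized to a dict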

3.2 Windowing and Aggregation

from collections import deque
from datetime import datetime, timedelta
import numpy as np

class StreamWindowProcessor:
    def __init__(self, window_size_seconds=60):
        self.window_size = timedelta(seconds=window_size_seconds)
        self.windows = {}

    def add_data_point(self, symbol, price, volume, timestamp):
        """データポイントの追加"""
        if symbol not in self.windows:
            self.windows[symbol] = deque()

        # Evict data that has fallen out of the window
        cutoff_time = timestamp - self.window_size
        while self.windows[symbol] and self.windows[symbol][0]['timestamp'] < cutoff_time:
            self.windows[symbol].popleft()

        # Append the new data point
        self.windows[symbol].append({
            'price': price,
            'volume': volume,
            'timestamp': timestamp
        })

    def calculate_window_statistics(self, symbol):
        """ウィンドウ内の統計情報計算"""
        if symbol not in self.windows or not self.windows[symbol]:
            return None

        prices = [d['price'] for d in self.windows[symbol]]
        volumes = [d['volume'] for d in self.windows[symbol]]

        return {
            'symbol': symbol,
            'window_size': len(prices),
            'mean_price': np.mean(prices),
            'std_price': np.std(prices),
            'min_price': np.min(prices),
            'max_price': np.max(prices),
            'total_volume': np.sum(volumes),
            'price_velocity': (prices[-1] - prices[0]) / len(prices) if len(prices) > 1 else 0,
            'volatility': np.std(np.diff(prices)) if len(prices) > 1 else 0
        }
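
For example, feeding two ticks and reading the window statistics back:

from datetime import datetime

processor = StreamWindowProcessor(window_size_seconds=60)
now = datetime.now()
processor.add_data_point('BTCUSDT', 65000.0, 1.2, now)
processor.add_data_point('BTCUSDT', 65010.0, 0.8, now)

stats = processor.calculate_window_statistics('BTCUSDT')
print(stats['mean_price'], stats['total_volume'])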

4. Online Machine Learning

4.1 Adaptive Learning Algorithm

import numpy as np
from sklearn.preprocessing import StandardScaler

class OnlineLearningModel:
    def __init__(self, learning_rate=0.01, feature_dim=10):
        self.learning_rate = learning_rate
        self.weights = np.random.randn(feature_dim) * 0.01
        self.bias = 0
        self.scaler = StandardScaler()
        self.feature_buffer = []
        self.min_samples_for_update = 100

    def extract_features(self, window_stats, technical_indicators):
        """特徴量抽出"""
        features = [
            window_stats['mean_price'],
            window_stats['std_price'],
            window_stats['price_velocity'],
            window_stats['volatility'],
            window_stats['total_volume'],
            technical_indicators.get('rsi', 50),
            technical_indicators.get('macd', 0),
            technical_indicators.get('bb_position', 0.5),
            technical_indicators.get('volume_ratio', 1),
            technical_indicators.get('price_momentum', 0)
        ]
        return np.array(features)

    def predict(self, features):
        """Predict; returns 0 (neutral) until the warm-up buffer is full"""
        if len(self.feature_buffer) < self.min_samples_for_update:
            self.feature_buffer.append(features)
            if len(self.feature_buffer) == self.min_samples_for_update:
                # Fit the scaler once on the warm-up samples so transform()
                # below does not fail on the first real prediction
                self.scaler.fit(np.array(self.feature_buffer))
            return 0  # neutral while there is not enough data

        # Normalize the features
        features_scaled = self.scaler.transform([features])[0]

        # Linear prediction, squashed into [-1, 1]
        prediction = np.dot(self.weights, features_scaled) + self.bias
        return np.tanh(prediction)

    def update(self, features, actual_return):
        """オンライン更新(SGD)"""
        if len(self.feature_buffer) >= self.min_samples_for_update:
            # Update the scaler's running statistics
            self.scaler.partial_fit([features])
            features_scaled = self.scaler.transform([features])[0]

            # Prediction error
            prediction = np.dot(self.weights, features_scaled) + self.bias
            error = actual_return - prediction

            # Gradient-descent step (SGD on squared error)
            self.weights += self.learning_rate * error * features_scaled
            self.bias += self.learning_rate * error

            # Gradually decay the learning rate
            self.learning_rate *= 0.9999
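
In practice the model warms up for min_samples_for_update observations and then alternates predict/update. A sketch with random stand-in data (purely illustrative):

model = OnlineLearningModel(feature_dim=10)
for step in range(500):
    features = np.random.randn(10)            # stand-in for real features
    signal = model.predict(features)          # 0 (neutral) during warm-up
    actual_return = np.random.randn() * 0.01  # stand-in for the realized return
    model.update(features, actual_return)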

4.2 Ensemble Online Learning

class EnsembleOnlineModel:
    def __init__(self, n_models=5):
        self.models = [
            OnlineLearningModel(learning_rate=0.01 * (i + 1))
            for i in range(n_models)
        ]
        self.model_weights = np.ones(n_models) / n_models
        self.performance_history = [[] for _ in range(n_models)]

    def predict(self, features):
        """アンサンブル予測"""
        predictions = []
        for i, model in enumerate(self.models):
            pred = model.predict(features)
            predictions.append(pred * self.model_weights[i])

        return np.sum(predictions)

    def update(self, features, actual_return):
        """各モデルの更新と重み調整"""
        for i, model in enumerate(self.models):
            # 個別予測
            pred = model.predict(features)
            error = abs(actual_return - pred)

            # Update the rolling performance history
            self.performance_history[i].append(error)
            if len(self.performance_history[i]) > 100:
                self.performance_history[i].pop(0)

            # SGD update of the member model
            model.update(features, actual_return)

        # Rebalance model weights based on recent performance
        if all(len(hist) >= 10 for hist in self.performance_history):
            avg_errors = [np.mean(hist) for hist in self.performance_history]
            # Weight each model by the inverse of its average error
            weights = 1 / (np.array(avg_errors) + 1e-6)
            self.model_weights = weights / np.sum(weights)
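
The same loop drives the ensemble; over time the inverse-error weighting shifts weight toward the better-performing members. A sketch with random stand-in data:

ensemble = EnsembleOnlineModel(n_models=5)
stream = ((np.random.randn(10), np.random.randn() * 0.01) for _ in range(500))
for features, actual_return in stream:
    signal = ensemble.predict(features)
    ensemble.update(features, actual_return)

print(ensemble.model_weights)  # normalized; larger = historically lower error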

5. Real-Time Quality Control

5.1 Data Quality Monitoring

from collections import deque
from datetime import datetime
import numpy as np

class DataQualityMonitor:
    def __init__(self):
        self.metrics = {
            'latency': deque(maxlen=1000),
            'missing_data': deque(maxlen=1000),
            'outliers': deque(maxlen=1000),
            'prediction_accuracy': deque(maxlen=1000)
        }
        self.alerts = []

    def raise_alert(self, alert_type, message):
        """Record an alert; wire this into real alerting in production"""
        self.alerts.append({'type': alert_type, 'message': message,
                            'time': datetime.now()})
        print(f"[ALERT] {alert_type}: {message}")

    def check_latency(self, timestamp, received_time):
        """レイテンシーチェック"""
        latency = (received_time - timestamp).total_seconds()
        self.metrics['latency'].append(latency)

        if latency > 1.0:  # more than one second of delay
            self.raise_alert('HIGH_LATENCY', f'Latency: {latency:.2f}s')

        return latency

    def check_data_integrity(self, data):
        """データ整合性チェック"""
        issues = []

        # Verify required fields
        required_fields = ['price', 'volume', 'timestamp']
        for field in required_fields:
            if field not in data or data[field] is None:
                issues.append(f'Missing field: {field}')

        # Detect implausible prices
        if 'price' in data:
            if data['price'] <= 0 or data['price'] > 1e9:
                issues.append(f'Invalid price: {data["price"]}')

        # Validate volume
        if 'volume' in data:
            if data['volume'] < 0:
                issues.append(f'Negative volume: {data["volume"]}')

        if issues:
            self.metrics['missing_data'].append(len(issues))
            self.raise_alert('DATA_INTEGRITY', ', '.join(issues))

        return len(issues) == 0

    def monitor_prediction_drift(self, predictions, actuals):
        """予測精度のドリフト検出"""
        if len(predictions) != len(actuals):
            return

        # Directional accuracy over a rolling window
        window_size = 100
        if len(predictions) >= window_size:
            recent_acc = np.mean([
                1 if np.sign(p) == np.sign(a) else 0
                for p, a in zip(predictions[-window_size:], actuals[-window_size:])
            ])

            self.metrics['prediction_accuracy'].append(recent_acc)

            # Alert when accuracy drops below the threshold
            if recent_acc < 0.45:
                self.raise_alert('MODEL_DRIFT', f'Accuracy dropped to {recent_acc:.2%}')

5.2 Automatic Recovery Mechanism

import asyncio

class AutoRecoverySystem:
    def __init__(self):
        self.connection_retries = {}
        self.max_retries = 5
        self.backoff_factor = 2

    async def reconnect_with_backoff(self, exchange, connect_func):
        """指数バックオフでの再接続"""
        if exchange not in self.connection_retries:
            self.connection_retries[exchange] = 0

        retry_count = self.connection_retries[exchange]

        if retry_count >= self.max_retries:
            raise Exception(f"Max retries exceeded for {exchange}")

        wait_time = self.backoff_factor ** retry_count
        print(f"Waiting {wait_time}s before reconnecting to {exchange}")
        await asyncio.sleep(wait_time)

        try:
            await connect_func()
            self.connection_retries[exchange] = 0  # reset on success
        except Exception:
            self.connection_retries[exchange] += 1
            raise
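
A small demonstration with a deliberately failing connector (flaky_connect is illustrative); with backoff_factor=2 the retries wait 1s, 2s, 4s, 8s, and 16s before giving up:

async def flaky_connect():
    raise ConnectionError("simulated outage")

async def demo():
    recovery = AutoRecoverySystem()
    while True:
        try:
            await recovery.reconnect_with_backoff('binance', flaky_connect)
        except ConnectionError:
            continue  # single attempt failed; back off and retry
        except Exception as e:
            print(e)  # "Max retries exceeded for binance" after five attempts
            break

# asyncio.run(demo())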

6. Performance Optimization

6.1 Memory-Efficient Data Structures

import numpy as np

class CircularBuffer:
    """固定サイズの循環バッファ"""
    def __init__(self, size, dtype=np.float32):
        self.size = size
        self.buffer = np.zeros(size, dtype=dtype)
        self.index = 0
        self.is_full = False

    def append(self, value):
        self.buffer[self.index] = value
        self.index = (self.index + 1) % self.size
        if self.index == 0:
            self.is_full = True

    def get_data(self):
        """Return buffered values in chronological (oldest-first) order"""
        if self.is_full:
            return np.concatenate([
                self.buffer[self.index:],
                self.buffer[:self.index]
            ])
        else:
            return self.buffer[:self.index]
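
For instance, with a buffer of size 3 the oldest value is overwritten in place, and get_data() returns the survivors oldest-first:

buf = CircularBuffer(size=3)
for value in [1.0, 2.0, 3.0, 4.0]:
    buf.append(value)

print(buf.get_data())  # [2. 3. 4.]; 1.0 was overwritten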

6.2 Implementing Parallel Processing

import asyncio
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp

class ParallelStreamProcessor:
    def __init__(self, n_workers=None):
        self.n_workers = n_workers or mp.cpu_count()
        self.executor = ProcessPoolExecutor(max_workers=self.n_workers)

    async def process_batch_parallel(self, data_batch, process_func):
        """Process a batch in parallel across worker processes"""
        # Split the batch into roughly equal chunks (at least one item each,
        # so small batches do not produce a zero step size)
        chunk_size = max(1, len(data_batch) // self.n_workers)
        chunks = [
            data_batch[i:i + chunk_size]
            for i in range(0, len(data_batch), chunk_size)
        ]

        # Run the chunks in parallel
        loop = asyncio.get_running_loop()
        futures = [
            loop.run_in_executor(self.executor, process_func, chunk)
            for chunk in chunks
        ]

        results = await asyncio.gather(*futures)

        # Flatten the per-chunk results
        return [item for sublist in results for item in sublist]
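
Because ProcessPoolExecutor ships work to worker processes via pickling, process_func must be a module-level function; lambdas and closures will fail. A sketch (normalize_chunk is illustrative):

def normalize_chunk(chunk):
    # Runs inside a worker process; must be importable from the main module
    return [x / 100.0 for x in chunk]

async def demo():
    processor = ParallelStreamProcessor(n_workers=4)
    result = await processor.process_batch_parallel(list(range(1000)), normalize_chunk)
    print(len(result))  # 1000

# asyncio.run(demo())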

7. Example: An Integrated System

class RealTimeCryptoMLSystem:
    def __init__(self):
        self.ws_client = CryptoWebSocketClient('binance', ['BTCUSDT', 'ETHUSDT'])
        self.stream_processor = StreamWindowProcessor(window_size_seconds=60)
        self.ml_model = EnsembleOnlineModel(n_models=5)
        self.quality_monitor = DataQualityMonitor()
        self.recovery_system = AutoRecoverySystem()

    async def start(self):
        """システム起動"""
        # WebSocketコールバックの設定
        self.ws_client.callbacks.append(self.process_market_data)

        # 接続開始
        while True:
            try:
                await self.ws_client.connect_binance()
            except Exception as e:
                print(f"Connection error: {e}")
                await self.recovery_system.reconnect_with_backoff(
                    'binance',
                    self.ws_client.connect_binance
                )

    async def process_market_data(self, data):
        """市場データの処理"""
        # レイテンシーチェック
        latency = self.quality_monitor.check_latency(
            data['timestamp'],
            datetime.now()
        )

        # Data quality check
        if not self.quality_monitor.check_data_integrity(data):
            return

        # Window processing
        self.stream_processor.add_data_point(
            data['symbol'],
            data['price'],
            data['volume'],
            data['timestamp']
        )

        # Window statistics
        stats = self.stream_processor.calculate_window_statistics(data['symbol'])

        if stats and stats['window_size'] >= 30:
            # Technical indicators (implemented separately; see the placeholder below)
            technical_indicators = self.calculate_technical_indicators(data['symbol'])

            # Feature extraction
            features = self.ml_model.models[0].extract_features(
                stats,
                technical_indicators
            )

            # Prediction
            prediction = self.ml_model.predict(features)

            # Act on the prediction
            await self.execute_trading_action(data['symbol'], prediction)

    async def execute_trading_action(self, symbol, prediction):
        """Act on the model's prediction"""
        threshold = 0.7

        if prediction > threshold:
            print(f"BUY signal for {symbol}: {prediction:.3f}")
            # Buy-order logic goes here
        elif prediction < -threshold:
            print(f"SELL signal for {symbol}: {prediction:.3f}")
            # Sell-order logic goes here

    def calculate_technical_indicators(self, symbol):
        """Placeholder so the class runs end to end: returns an empty dict,
        letting extract_features() fall back to its neutral defaults. The
        real indicator computation (RSI, MACD, ...) is implemented separately."""
        return {}
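
The system can then be launched from a standard entry point:

if __name__ == '__main__':
    system = RealTimeCryptoMLSystem()
    asyncio.run(system.start())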

Summary

A real-time streaming ML system stands or falls on the following elements:

  1. Low latency: WebSocket feeds and in-memory processing
  2. Scalability: horizontal scaling via Kafka/Redis
  3. Reliability: automatic recovery and quality monitoring
  4. Adaptability: online learning that tracks market shifts
  5. Efficiency: parallel processing and memory optimization

Combined, these building blocks yield a robust system that can run 24/7.