Real-Time Data Processing and Streaming Machine Learning
Overview
Cryptocurrency markets run 24 hours a day, 365 days a year, so applying data processing and machine learning in real time is essential. This document explains how to implement streaming data processing combined with online learning.
1. Architecture Overview
1.1 Data Flow Architecture
[Exchange API/WebSocket] → [Data Collection] → [Stream Processing] → [ML Processing] → [Action Layer]
          ↓                       ↓                     ↓                    ↓                ↓
  [Binance/Coinbase]         [Kafka/Redis]         [Spark/Flink]        [Online ML]     [Trading Bot]
1.2 Key Components
- Data collection: real-time data acquisition over WebSocket connections
- Message queue: buffering with Apache Kafka / Redis Streams
- Stream processing: preprocessing with Apache Spark Streaming / Flink
- ML processing: prediction with online learning algorithms
- Action: automated trading / alert generation
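To make the flow concrete, here is a minimal sketch, assuming in-process asyncio queues as stand-ins for the message queue between layers (the real components are implemented in the sections below):

import asyncio

async def pipeline_sketch():
    # Hypothetical in-process queues standing in for Kafka/Redis
    raw_q, feature_q = asyncio.Queue(), asyncio.Queue()

    async def collect():        # data collection layer (section 2)
        await raw_q.put({'symbol': 'BTCUSDT', 'price': 50_000.0})

    async def preprocess():     # stream processing layer (section 3)
        await feature_q.put(await raw_q.get())

    async def predict():        # ML processing layer (section 4)
        print('features:', await feature_q.get())

    await asyncio.gather(collect(), preprocess(), predict())

asyncio.run(pipeline_sketch())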
2. Implementing the Data Collection Layer
2.1 WebSocket Connection
import asyncio
import websockets
import json
from datetime import datetime

class CryptoWebSocketClient:
    def __init__(self, exchange, symbols):
        self.exchange = exchange
        self.symbols = symbols
        self.callbacks = []

    async def connect_binance(self):
        """Connect to the Binance combined ticker stream."""
        streams = [f"{symbol.lower()}@ticker" for symbol in self.symbols]
        url = f"wss://stream.binance.com:9443/stream?streams={'/'.join(streams)}"
        async with websockets.connect(url) as websocket:
            while True:
                try:
                    data = await websocket.recv()
                    parsed_data = self.parse_binance_data(json.loads(data))
                    await self.process_data(parsed_data)
                except websockets.ConnectionClosed:
                    # Let connection drops propagate so the recovery layer can reconnect
                    raise
                except Exception as e:
                    print(f"Error: {e}")
                    await asyncio.sleep(5)

    def parse_binance_data(self, raw_data):
        """Parse a raw combined-stream message into a flat dict."""
        if 'data' in raw_data:
            ticker = raw_data['data']
            return {
                'exchange': 'binance',
                'symbol': ticker['s'],
                'price': float(ticker['c']),
                'volume': float(ticker['v']),
                'timestamp': datetime.fromtimestamp(ticker['E'] / 1000),
                'bid': float(ticker['b']),
                'ask': float(ticker['a']),
                'price_change_24h': float(ticker['p'])
            }
        return None

    async def process_data(self, data):
        """Run every registered callback on the parsed data."""
        if data:
            for callback in self.callbacks:
                await callback(data)
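A minimal usage sketch, assuming a user-defined async callback (print_tick is hypothetical):

async def print_tick(tick):
    # Log each parsed ticker update
    print(tick['symbol'], tick['price'])

client = CryptoWebSocketClient('binance', ['BTCUSDT', 'ETHUSDT'])
client.callbacks.append(print_tick)
# asyncio.run(client.connect_binance())  # runs until interrupted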
2.2 Multi-Exchange Support
class MultiExchangeAggregator:
    def __init__(self):
        self.exchanges = {}
        self.aggregated_data = {}

    async def add_exchange(self, name, client):
        """Register an exchange client and subscribe to its data."""
        self.exchanges[name] = client
        client.callbacks.append(self.aggregate_data)

    async def aggregate_data(self, data):
        """Aggregate data from multiple exchanges."""
        key = f"{data['symbol']}_{data['exchange']}"
        self.aggregated_data[key] = data
        # Check for arbitrage opportunities on every update
        await self.check_arbitrage(data['symbol'])

    async def check_arbitrage(self, symbol):
        """Detect price differences between exchanges."""
        prices = []
        for exchange in self.exchanges:
            key = f"{symbol}_{exchange}"
            if key in self.aggregated_data:
                prices.append({
                    'exchange': exchange,
                    'bid': self.aggregated_data[key]['bid'],
                    'ask': self.aggregated_data[key]['ask']
                })
        if len(prices) >= 2:
            # Spread between the highest bid and the lowest ask
            max_bid = max(prices, key=lambda x: x['bid'])
            min_ask = min(prices, key=lambda x: x['ask'])
            spread_pct = ((max_bid['bid'] - min_ask['ask']) / min_ask['ask']) * 100
            if spread_pct > 0.1:  # notify if the spread exceeds 0.1%
                await self.notify_arbitrage({
                    'symbol': symbol,
                    'buy_exchange': min_ask['exchange'],
                    'sell_exchange': max_bid['exchange'],
                    'profit_pct': spread_pct
                })

    async def notify_arbitrage(self, opportunity):
        """Placeholder notification hook; replace with alerting of your choice."""
        print(f"Arbitrage opportunity: {opportunity}")
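A worked example of the spread calculation, with hypothetical quotes:

# Best bid on exchange A = 50,100 USDT; best ask on exchange B = 50,000 USDT
spread_pct = ((50_100 - 50_000) / 50_000) * 100
print(spread_pct)  # 0.2 -> exceeds the 0.1% threshold, so a notification fires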
3. Stream Processing Layer
3.1 Data Pipeline with Apache Kafka
from kafka import KafkaProducer, KafkaConsumer
import json

class KafkaStreamProcessor:
    def __init__(self, bootstrap_servers):
        self.bootstrap_servers = bootstrap_servers
        self.producer = KafkaProducer(
            bootstrap_servers=bootstrap_servers,
            value_serializer=lambda v: json.dumps(v).encode('utf-8')
        )

    def publish_market_data(self, topic, data):
        """Publish market data; kafka-python's send() is non-blocking and returns a future."""
        self.producer.send(topic, value=data)

    def create_consumer(self, topics, group_id):
        """Create a consumer subscribed to the given topics."""
        return KafkaConsumer(
            *topics,
            bootstrap_servers=self.bootstrap_servers,
            group_id=group_id,
            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
            auto_offset_reset='latest'
        )
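A minimal usage sketch, assuming a broker at localhost:9092 and a hypothetical market-data topic:

processor = KafkaStreamProcessor(['localhost:9092'])
consumer = processor.create_consumer(['market-data'], group_id='ml-workers')
processor.publish_market_data('market-data', {'symbol': 'BTCUSDT', 'price': 50000.0})
processor.producer.flush()  # force delivery of the buffered send
for message in consumer:    # blocks; with 'latest' it only sees messages published after subscribing
    print(message.value)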
3.2 Window Processing and Aggregation
from collections import deque
from datetime import datetime, timedelta
import numpy as np

class StreamWindowProcessor:
    def __init__(self, window_size_seconds=60):
        self.window_size = timedelta(seconds=window_size_seconds)
        self.windows = {}

    def add_data_point(self, symbol, price, volume, timestamp):
        """Add a data point to the symbol's sliding window."""
        if symbol not in self.windows:
            self.windows[symbol] = deque()
        # Evict data that has fallen out of the window
        cutoff_time = timestamp - self.window_size
        while self.windows[symbol] and self.windows[symbol][0]['timestamp'] < cutoff_time:
            self.windows[symbol].popleft()
        # Append the new data point
        self.windows[symbol].append({
            'price': price,
            'volume': volume,
            'timestamp': timestamp
        })

    def calculate_window_statistics(self, symbol):
        """Compute summary statistics over the current window."""
        if symbol not in self.windows or not self.windows[symbol]:
            return None
        prices = [d['price'] for d in self.windows[symbol]]
        volumes = [d['volume'] for d in self.windows[symbol]]
        return {
            'symbol': symbol,
            'window_size': len(prices),
            'mean_price': np.mean(prices),
            'std_price': np.std(prices),
            'min_price': np.min(prices),
            'max_price': np.max(prices),
            'total_volume': np.sum(volumes),
            'price_velocity': (prices[-1] - prices[0]) / len(prices) if len(prices) > 1 else 0,
            'volatility': np.std(np.diff(prices)) if len(prices) > 1 else 0
        }
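A minimal sketch with synthetic one-second ticks, illustrating eviction and the resulting statistics:

from datetime import datetime, timedelta

proc = StreamWindowProcessor(window_size_seconds=60)
now = datetime.now()
for i in range(120):
    # 120 ticks, one per second; older ticks fall out of the 60-second window
    proc.add_data_point('BTCUSDT', 50_000 + i, 1.0, now + timedelta(seconds=i))
print(proc.calculate_window_statistics('BTCUSDT'))  # window_size is about 60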
4. Online Machine Learning
4.1 Adaptive Learning Algorithm
import numpy as np
from sklearn.preprocessing import StandardScaler

class OnlineLearningModel:
    def __init__(self, learning_rate=0.01, feature_dim=10):
        self.learning_rate = learning_rate
        self.weights = np.random.randn(feature_dim) * 0.01
        self.bias = 0
        self.scaler = StandardScaler()
        self.feature_buffer = []
        self.min_samples_for_update = 100

    def extract_features(self, window_stats, technical_indicators):
        """Build the feature vector from window stats and indicators."""
        features = [
            window_stats['mean_price'],
            window_stats['std_price'],
            window_stats['price_velocity'],
            window_stats['volatility'],
            window_stats['total_volume'],
            technical_indicators.get('rsi', 50),
            technical_indicators.get('macd', 0),
            technical_indicators.get('bb_position', 0.5),
            technical_indicators.get('volume_ratio', 1),
            technical_indicators.get('price_momentum', 0)
        ]
        return np.array(features)

    def predict(self, features):
        """Predict a signal in [-1, 1]; returns 0 (neutral) until warmed up."""
        if len(self.feature_buffer) < self.min_samples_for_update:
            self.feature_buffer.append(features)
            # Fit the scaler once the warm-up buffer is full; otherwise
            # transform() below would raise NotFittedError
            if len(self.feature_buffer) == self.min_samples_for_update:
                self.scaler.fit(np.array(self.feature_buffer))
            return 0  # neutral while there is not enough data
        # Normalize the features
        features_scaled = self.scaler.transform([features])[0]
        # Linear prediction
        prediction = np.dot(self.weights, features_scaled) + self.bias
        return np.tanh(prediction)  # squash into the range [-1, 1]

    def update(self, features, actual_return):
        """Online update via SGD."""
        if len(self.feature_buffer) >= self.min_samples_for_update:
            # Update the running scaler statistics
            self.scaler.partial_fit([features])
            features_scaled = self.scaler.transform([features])[0]
            # Prediction error
            prediction = np.dot(self.weights, features_scaled) + self.bias
            error = actual_return - prediction
            # Gradient-descent update
            self.weights += self.learning_rate * error * features_scaled
            self.bias += self.learning_rate * error
            # Adaptive learning-rate decay
            self.learning_rate *= 0.9999
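A minimal warm-up and update sketch with synthetic features, illustrating the warm-up threshold and the sign convention of actual_return:

import numpy as np

model = OnlineLearningModel(feature_dim=10)
rng = np.random.default_rng(0)
for _ in range(100):
    model.predict(rng.normal(size=10))   # warm-up: fills the buffer, always returns 0
x = rng.normal(size=10)
signal = model.predict(x)                # now a real signal in [-1, 1]
model.update(x, actual_return=0.002)     # e.g. a realized +0.2% return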
4.2 Ensemble Online Learning
class EnsembleOnlineModel:
    def __init__(self, n_models=5):
        self.models = [
            OnlineLearningModel(learning_rate=0.01 * (i + 1))
            for i in range(n_models)
        ]
        self.model_weights = np.ones(n_models) / n_models
        self.performance_history = [[] for _ in range(n_models)]

    def predict(self, features):
        """Weighted ensemble prediction."""
        predictions = []
        for i, model in enumerate(self.models):
            pred = model.predict(features)
            predictions.append(pred * self.model_weights[i])
        return np.sum(predictions)

    def update(self, features, actual_return):
        """Update every model, then re-weight by recent performance."""
        for i, model in enumerate(self.models):
            # Individual prediction
            pred = model.predict(features)
            error = abs(actual_return - pred)
            # Track recent errors
            self.performance_history[i].append(error)
            if len(self.performance_history[i]) > 100:
                self.performance_history[i].pop(0)
            # Update the model
            model.update(features, actual_return)
        # Re-weight the models based on performance
        if all(len(hist) >= 10 for hist in self.performance_history):
            avg_errors = [np.mean(hist) for hist in self.performance_history]
            # Use the inverse of the mean error as the weight
            weights = 1 / (np.array(avg_errors) + 1e-6)
            self.model_weights = weights / np.sum(weights)
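A worked sketch of the inverse-error re-weighting, using hypothetical mean errors:

import numpy as np

avg_errors = np.array([0.010, 0.020, 0.040])  # hypothetical per-model mean errors
weights = 1 / (avg_errors + 1e-6)
weights /= weights.sum()
print(weights.round(3))  # approx. [0.571, 0.286, 0.143]: lower error, higher weight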
5. Real-Time Quality Management
5.1 Data Quality Monitoring
from collections import deque
import numpy as np

class DataQualityMonitor:
    def __init__(self):
        self.metrics = {
            'latency': deque(maxlen=1000),
            'missing_data': deque(maxlen=1000),
            'outliers': deque(maxlen=1000),
            'prediction_accuracy': deque(maxlen=1000)
        }
        self.alerts = []

    def raise_alert(self, alert_type, message):
        """Record an alert; hook up external notification as needed."""
        self.alerts.append({'type': alert_type, 'message': message})

    def check_latency(self, timestamp, received_time):
        """Check end-to-end latency."""
        latency = (received_time - timestamp).total_seconds()
        self.metrics['latency'].append(latency)
        if latency > 1.0:  # more than one second behind
            self.raise_alert('HIGH_LATENCY', f'Latency: {latency:.2f}s')
        return latency

    def check_data_integrity(self, data):
        """Validate an incoming data point."""
        issues = []
        # Required fields
        required_fields = ['price', 'volume', 'timestamp']
        for field in required_fields:
            if field not in data or data[field] is None:
                issues.append(f'Missing field: {field}')
        # Outlier detection (guard against None before comparing)
        if data.get('price') is not None:
            if data['price'] <= 0 or data['price'] > 1e9:
                issues.append(f'Invalid price: {data["price"]}')
        # Volume validation
        if data.get('volume') is not None:
            if data['volume'] < 0:
                issues.append(f'Negative volume: {data["volume"]}')
        if issues:
            self.metrics['missing_data'].append(len(issues))
            self.raise_alert('DATA_INTEGRITY', ', '.join(issues))
        return len(issues) == 0

    def monitor_prediction_drift(self, predictions, actuals):
        """Detect drift in directional prediction accuracy."""
        if len(predictions) != len(actuals):
            return
        # Rolling-window directional accuracy
        window_size = 100
        if len(predictions) >= window_size:
            recent_acc = np.mean([
                1 if np.sign(p) == np.sign(a) else 0
                for p, a in zip(predictions[-window_size:], actuals[-window_size:])
            ])
            self.metrics['prediction_accuracy'].append(recent_acc)
            # Alert if accuracy drops below the threshold
            if recent_acc < 0.45:
                self.raise_alert('MODEL_DRIFT', f'Accuracy dropped to {recent_acc:.2%}')
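A minimal sketch feeding one malformed tick through the monitor (field names match the parser from section 2):

from datetime import datetime

monitor = DataQualityMonitor()
bad_tick = {'price': -1.0, 'volume': 2.5, 'timestamp': datetime.now()}
print(monitor.check_data_integrity(bad_tick))  # False: price <= 0 fails validation
print(monitor.alerts[-1])                      # the recorded DATA_INTEGRITY alert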
5.2 Automatic Recovery Mechanism
import asyncio

class AutoRecoverySystem:
    def __init__(self):
        self.connection_retries = {}
        self.max_retries = 5
        self.backoff_factor = 2

    async def reconnect_with_backoff(self, exchange, connect_func):
        """Reconnect with exponential backoff."""
        if exchange not in self.connection_retries:
            self.connection_retries[exchange] = 0
        retry_count = self.connection_retries[exchange]
        if retry_count >= self.max_retries:
            raise Exception(f"Max retries exceeded for {exchange}")
        wait_time = self.backoff_factor ** retry_count
        print(f"Waiting {wait_time}s before reconnecting to {exchange}")
        await asyncio.sleep(wait_time)
        try:
            await connect_func()
            self.connection_retries[exchange] = 0  # reset on success
        except Exception:
            self.connection_retries[exchange] += 1
            raise
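A small sketch of the resulting wait schedule; with backoff_factor = 2 the delays grow as 1 s, 2 s, 4 s, 8 s, 16 s before the retry limit is hit:

recovery = AutoRecoverySystem()
print([recovery.backoff_factor ** n for n in range(recovery.max_retries)])  # [1, 2, 4, 8, 16]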
6. Performance Optimization
6.1 Memory-Efficient Data Structures
import numpy as np

class CircularBuffer:
    """Fixed-size circular buffer backed by a NumPy array."""
    def __init__(self, size, dtype=np.float32):
        self.size = size
        self.buffer = np.zeros(size, dtype=dtype)
        self.index = 0
        self.is_full = False

    def append(self, value):
        self.buffer[self.index] = value
        self.index = (self.index + 1) % self.size
        if self.index == 0:
            self.is_full = True

    def get_data(self):
        # Return values in insertion order (oldest first)
        if self.is_full:
            return np.concatenate([
                self.buffer[self.index:],
                self.buffer[:self.index]
            ])
        else:
            return self.buffer[:self.index]
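A quick sketch demonstrating the wraparound order:

buf = CircularBuffer(size=3)
for v in [1.0, 2.0, 3.0, 4.0]:
    buf.append(v)
print(buf.get_data())  # [2. 3. 4.]: the oldest value (1.0) was overwritten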
6.2 Parallel Processing
import asyncio
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp

class ParallelStreamProcessor:
    def __init__(self, n_workers=None):
        self.n_workers = n_workers or mp.cpu_count()
        self.executor = ProcessPoolExecutor(max_workers=self.n_workers)

    async def process_batch_parallel(self, data_batch, process_func):
        """Process a batch in parallel across worker processes."""
        # Split the batch into roughly equal chunks (at least one item each,
        # so small batches don't produce a zero step size)
        chunk_size = max(1, len(data_batch) // self.n_workers)
        chunks = [
            data_batch[i:i + chunk_size]
            for i in range(0, len(data_batch), chunk_size)
        ]
        # Run the chunks in parallel
        loop = asyncio.get_running_loop()
        futures = [
            loop.run_in_executor(self.executor, process_func, chunk)
            for chunk in chunks
        ]
        results = await asyncio.gather(*futures)
        # Flatten the per-chunk results
        return [item for sublist in results for item in sublist]
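A minimal sketch with a toy top-level function (ProcessPoolExecutor requires process_func to be picklable, i.e. defined at module level):

def square_all(chunk):
    return [x * x for x in chunk]

async def demo():
    proc = ParallelStreamProcessor(n_workers=2)
    print(await proc.process_batch_parallel(list(range(10)), square_all))

# asyncio.run(demo())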
7. Implementation Example: An Integrated System
class RealTimeCryptoMLSystem:
    def __init__(self):
        self.ws_client = CryptoWebSocketClient('binance', ['BTCUSDT', 'ETHUSDT'])
        self.stream_processor = StreamWindowProcessor(window_size_seconds=60)
        self.ml_model = EnsembleOnlineModel(n_models=5)
        self.quality_monitor = DataQualityMonitor()
        self.recovery_system = AutoRecoverySystem()

    async def start(self):
        """Start the system."""
        # Register the WebSocket callback
        self.ws_client.callbacks.append(self.process_market_data)
        # Connect and keep reconnecting on failure
        while True:
            try:
                await self.ws_client.connect_binance()
            except Exception as e:
                print(f"Connection error: {e}")
                await self.recovery_system.reconnect_with_backoff(
                    'binance',
                    self.ws_client.connect_binance
                )

    async def process_market_data(self, data):
        """Process one market data update."""
        # Latency check
        self.quality_monitor.check_latency(
            data['timestamp'],
            datetime.now()
        )
        # Data quality check
        if not self.quality_monitor.check_data_integrity(data):
            return
        # Sliding-window update
        self.stream_processor.add_data_point(
            data['symbol'],
            data['price'],
            data['volume'],
            data['timestamp']
        )
        # Window statistics
        stats = self.stream_processor.calculate_window_statistics(data['symbol'])
        if stats and stats['window_size'] >= 30:
            # Technical indicators (implemented separately)
            technical_indicators = self.calculate_technical_indicators(data['symbol'])
            # Feature extraction
            features = self.ml_model.models[0].extract_features(
                stats,
                technical_indicators
            )
            # Prediction
            prediction = self.ml_model.predict(features)
            # Act on the prediction
            await self.execute_trading_action(data['symbol'], prediction)

    def calculate_technical_indicators(self, symbol):
        """Placeholder: return an empty dict so extract_features falls back to its defaults."""
        return {}

    async def execute_trading_action(self, symbol, prediction):
        """Execute an action based on the prediction."""
        threshold = 0.7
        if prediction > threshold:
            print(f"BUY signal for {symbol}: {prediction:.3f}")
            # buy-order logic goes here
        elif prediction < -threshold:
            print(f"SELL signal for {symbol}: {prediction:.3f}")
            # sell-order logic goes here
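A minimal entry point, assuming the classes above live in one module:

import asyncio

if __name__ == "__main__":
    system = RealTimeCryptoMLSystem()
    asyncio.run(system.start())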
Summary
The following elements are key when implementing a real-time streaming ML system:
- Low latency: speed through WebSocket feeds and in-memory processing
- Scalability: horizontal scaling with Kafka/Redis
- Reliability: automatic recovery and quality monitoring
- Adaptability: online learning to track changing market conditions
- Efficiency: parallel processing and memory optimization
Combining these elements yields a robust system that runs 24/7.