diff --git a/openspec/changes/code-review-fix/tasks.md b/openspec/changes/code-review-fix/tasks.md index f803286..d8bbe7c 100644 --- a/openspec/changes/code-review-fix/tasks.md +++ b/openspec/changes/code-review-fix/tasks.md @@ -43,7 +43,7 @@ - [x] 5.1 `[sonnet]` Replace `error.message` / traceback details with generic `"Internal server error"` in FastAPI exception handlers at lines 640, 778, 1091, 1134, 1199, 1296 of `services/ml/app/main.py` - [x] 5.2 `[opus]` Add SHA256 model integrity check: create `models/checksums.sha256` manifest, verify hash before `joblib.load()` in `services/ml/app/main.py:266` -- [ ] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access +- [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access - [ ] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py` - [ ] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py` - [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409` diff --git a/services/ml/app/main.py b/services/ml/app/main.py index 6d635bb..60cef4a 100644 --- a/services/ml/app/main.py +++ b/services/ml/app/main.py @@ -642,38 +642,49 @@ async def predict(request: PredictRequest): status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="No model available" ) - + if not request.candles: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Candles array cannot be empty" ) - + if state.pipeline_config is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Pipeline configuration not loaded" ) - + logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles") - + + # Grab model reference under lock to prevent reading a partially-swapped model + with _model_swap_lock: + current_model = state.model + current_model_info = state.model_info + + if current_model is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="No model available" + ) + try: # Convert candles to list of dicts candles_data = [candle.model_dump() for candle in request.candles] - + # Preprocess candles (feature engineering + windowing) X, window_times = preprocess_candles(candles_data, state.pipeline_config) - - # Get predictions and probabilities - if hasattr(state.model, 'predict_proba'): - y_pred = state.model.predict(X) - y_proba = state.model.predict_proba(X) - + + # Get predictions and probabilities (using local reference, outside lock) + if hasattr(current_model, 'predict_proba'): + y_pred = current_model.predict(X) + y_proba = current_model.predict_proba(X) + # Get confidence as max probability confidences = np.max(y_proba, axis=1) else: # Fallback for models without predict_proba - y_pred = state.model.predict(X) + y_pred = current_model.predict(X) confidences = np.ones(len(y_pred)) # Default confidence of 1.0 # Get label names (handle both string and int predictions) @@ -697,14 +708,14 @@ async def predict(request: PredictRequest): # Build model info for response model_info = ModelInfo( - model_name=state.model_info["model_name"], - model_version=state.model_info.get("model_version"), - model_type=state.model_info["model_type"], - trained_at=state.model_info.get("trained_at"), - dataset_version=state.model_info.get("dataset_version"), - feature_engineering_enabled=state.model_info["feature_engineering_enabled"] + model_name=current_model_info["model_name"], + model_version=current_model_info.get("model_version"), + model_type=current_model_info["model_type"], + trained_at=current_model_info.get("trained_at"), + dataset_version=current_model_info.get("dataset_version"), + feature_engineering_enabled=current_model_info["feature_engineering_enabled"] ) - + logger.info( f"Prediction complete: {len(predictions)} windows, " f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns" @@ -742,17 +753,28 @@ async def predict_batch(request: BatchPredictRequest): status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="No model available" ) - + if state.pipeline_config is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Pipeline configuration not loaded" ) - + logger.info( f"Batch predict: {request.pair} {request.timeframe} " f"from {request.start_date} to {request.end_date}" ) + + # Grab model reference under lock to prevent reading a partially-swapped model + with _model_swap_lock: + current_model = state.model + current_model_info = state.model_info + + if current_model is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="No model available" + ) try: # Load OHLCV data from raw data path @@ -807,13 +829,13 @@ async def predict_batch(request: BatchPredictRequest): # Preprocess (feature engineering + windowing) X, window_times = preprocess_candles(batch_candles, state.pipeline_config) - # Predict - if hasattr(state.model, 'predict_proba'): - y_pred = state.model.predict(X) - y_proba = state.model.predict_proba(X) + # Predict (using local reference, outside lock) + if hasattr(current_model, 'predict_proba'): + y_pred = current_model.predict(X) + y_proba = current_model.predict_proba(X) confidences = np.max(y_proba, axis=1) else: - y_pred = state.model.predict(X) + y_pred = current_model.predict(X) confidences = np.ones(len(y_pred)) # Get labels @@ -839,14 +861,14 @@ async def predict_batch(request: BatchPredictRequest): # Build model info model_info = ModelInfo( - model_name=state.model_info["model_name"], - model_version=state.model_info.get("model_version"), - model_type=state.model_info["model_type"], - trained_at=state.model_info.get("trained_at"), - dataset_version=state.model_info.get("dataset_version"), - feature_engineering_enabled=state.model_info["feature_engineering_enabled"] + model_name=current_model_info["model_name"], + model_version=current_model_info.get("model_version"), + model_type=current_model_info["model_type"], + trained_at=current_model_info.get("trained_at"), + dataset_version=current_model_info.get("dataset_version"), + feature_engineering_enabled=current_model_info["feature_engineering_enabled"] ) - + logger.info( f"Batch prediction complete: {len(all_predictions)} candles, " f"{len(all_spans)} spans"