fix(ml): add _model_swap_lock to prediction reads for thread-safe model access

In /predict and /predict/batch endpoints, grab the model reference under
_model_swap_lock before running inference. Inference itself runs outside
the lock (using a local variable) to avoid blocking model swaps during
potentially slow computation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-18 11:26:33 +01:00
parent ff15adc847
commit b9beea1574
2 changed files with 56 additions and 34 deletions

View file

@ -43,7 +43,7 @@
- [x] 5.1 `[sonnet]` Replace `error.message` / traceback details with generic `"Internal server error"` in FastAPI exception handlers at lines 640, 778, 1091, 1134, 1199, 1296 of `services/ml/app/main.py`
- [x] 5.2 `[opus]` Add SHA256 model integrity check: create `models/checksums.sha256` manifest, verify hash before `joblib.load()` in `services/ml/app/main.py:266`
- [ ] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
- [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
- [ ] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
- [ ] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`

View file

@ -642,38 +642,49 @@ async def predict(request: PredictRequest):
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No model available"
)
if not request.candles:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Candles array cannot be empty"
)
if state.pipeline_config is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Pipeline configuration not loaded"
)
logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles")
# Grab model reference under lock to prevent reading a partially-swapped model
with _model_swap_lock:
current_model = state.model
current_model_info = state.model_info
if current_model is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No model available"
)
try:
# Convert candles to list of dicts
candles_data = [candle.model_dump() for candle in request.candles]
# Preprocess candles (feature engineering + windowing)
X, window_times = preprocess_candles(candles_data, state.pipeline_config)
# Get predictions and probabilities
if hasattr(state.model, 'predict_proba'):
y_pred = state.model.predict(X)
y_proba = state.model.predict_proba(X)
# Get predictions and probabilities (using local reference, outside lock)
if hasattr(current_model, 'predict_proba'):
y_pred = current_model.predict(X)
y_proba = current_model.predict_proba(X)
# Get confidence as max probability
confidences = np.max(y_proba, axis=1)
else:
# Fallback for models without predict_proba
y_pred = state.model.predict(X)
y_pred = current_model.predict(X)
confidences = np.ones(len(y_pred)) # Default confidence of 1.0
# Get label names (handle both string and int predictions)
@ -697,14 +708,14 @@ async def predict(request: PredictRequest):
# Build model info for response
model_info = ModelInfo(
model_name=state.model_info["model_name"],
model_version=state.model_info.get("model_version"),
model_type=state.model_info["model_type"],
trained_at=state.model_info.get("trained_at"),
dataset_version=state.model_info.get("dataset_version"),
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
model_name=current_model_info["model_name"],
model_version=current_model_info.get("model_version"),
model_type=current_model_info["model_type"],
trained_at=current_model_info.get("trained_at"),
dataset_version=current_model_info.get("dataset_version"),
feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
)
logger.info(
f"Prediction complete: {len(predictions)} windows, "
f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
@ -742,17 +753,28 @@ async def predict_batch(request: BatchPredictRequest):
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No model available"
)
if state.pipeline_config is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Pipeline configuration not loaded"
)
logger.info(
f"Batch predict: {request.pair} {request.timeframe} "
f"from {request.start_date} to {request.end_date}"
)
# Grab model reference under lock to prevent reading a partially-swapped model
with _model_swap_lock:
current_model = state.model
current_model_info = state.model_info
if current_model is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No model available"
)
try:
# Load OHLCV data from raw data path
@ -807,13 +829,13 @@ async def predict_batch(request: BatchPredictRequest):
# Preprocess (feature engineering + windowing)
X, window_times = preprocess_candles(batch_candles, state.pipeline_config)
# Predict
if hasattr(state.model, 'predict_proba'):
y_pred = state.model.predict(X)
y_proba = state.model.predict_proba(X)
# Predict (using local reference, outside lock)
if hasattr(current_model, 'predict_proba'):
y_pred = current_model.predict(X)
y_proba = current_model.predict_proba(X)
confidences = np.max(y_proba, axis=1)
else:
y_pred = state.model.predict(X)
y_pred = current_model.predict(X)
confidences = np.ones(len(y_pred))
# Get labels
@ -839,14 +861,14 @@ async def predict_batch(request: BatchPredictRequest):
# Build model info
model_info = ModelInfo(
model_name=state.model_info["model_name"],
model_version=state.model_info.get("model_version"),
model_type=state.model_info["model_type"],
trained_at=state.model_info.get("trained_at"),
dataset_version=state.model_info.get("dataset_version"),
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
model_name=current_model_info["model_name"],
model_version=current_model_info.get("model_version"),
model_type=current_model_info["model_type"],
trained_at=current_model_info.get("trained_at"),
dataset_version=current_model_info.get("dataset_version"),
feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
)
logger.info(
f"Batch prediction complete: {len(all_predictions)} candles, "
f"{len(all_spans)} spans"