fix(ml): add _model_swap_lock to prediction reads for thread-safe model access

In /predict and /predict/batch endpoints, grab the model reference under
_model_swap_lock before running inference. Inference itself runs outside
the lock (using a local variable) to avoid blocking model swaps during
potentially slow computation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-18 11:26:33 +01:00
parent ff15adc847
commit b9beea1574
2 changed files with 56 additions and 34 deletions

View file

@ -43,7 +43,7 @@
- [x] 5.1 `[sonnet]` Replace `error.message` / traceback details with generic `"Internal server error"` in FastAPI exception handlers at lines 640, 778, 1091, 1134, 1199, 1296 of `services/ml/app/main.py` - [x] 5.1 `[sonnet]` Replace `error.message` / traceback details with generic `"Internal server error"` in FastAPI exception handlers at lines 640, 778, 1091, 1134, 1199, 1296 of `services/ml/app/main.py`
- [x] 5.2 `[opus]` Add SHA256 model integrity check: create `models/checksums.sha256` manifest, verify hash before `joblib.load()` in `services/ml/app/main.py:266` - [x] 5.2 `[opus]` Add SHA256 model integrity check: create `models/checksums.sha256` manifest, verify hash before `joblib.load()` in `services/ml/app/main.py:266`
- [ ] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access - [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
- [ ] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py` - [ ] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
- [ ] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py` - [ ] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409` - [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`

View file

@ -657,6 +657,17 @@ async def predict(request: PredictRequest):
logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles") logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles")
# Grab model reference under lock to prevent reading a partially-swapped model
with _model_swap_lock:
current_model = state.model
current_model_info = state.model_info
if current_model is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No model available"
)
try: try:
# Convert candles to list of dicts # Convert candles to list of dicts
candles_data = [candle.model_dump() for candle in request.candles] candles_data = [candle.model_dump() for candle in request.candles]
@ -664,16 +675,16 @@ async def predict(request: PredictRequest):
# Preprocess candles (feature engineering + windowing) # Preprocess candles (feature engineering + windowing)
X, window_times = preprocess_candles(candles_data, state.pipeline_config) X, window_times = preprocess_candles(candles_data, state.pipeline_config)
# Get predictions and probabilities # Get predictions and probabilities (using local reference, outside lock)
if hasattr(state.model, 'predict_proba'): if hasattr(current_model, 'predict_proba'):
y_pred = state.model.predict(X) y_pred = current_model.predict(X)
y_proba = state.model.predict_proba(X) y_proba = current_model.predict_proba(X)
# Get confidence as max probability # Get confidence as max probability
confidences = np.max(y_proba, axis=1) confidences = np.max(y_proba, axis=1)
else: else:
# Fallback for models without predict_proba # Fallback for models without predict_proba
y_pred = state.model.predict(X) y_pred = current_model.predict(X)
confidences = np.ones(len(y_pred)) # Default confidence of 1.0 confidences = np.ones(len(y_pred)) # Default confidence of 1.0
# Get label names (handle both string and int predictions) # Get label names (handle both string and int predictions)
@ -697,12 +708,12 @@ async def predict(request: PredictRequest):
# Build model info for response # Build model info for response
model_info = ModelInfo( model_info = ModelInfo(
model_name=state.model_info["model_name"], model_name=current_model_info["model_name"],
model_version=state.model_info.get("model_version"), model_version=current_model_info.get("model_version"),
model_type=state.model_info["model_type"], model_type=current_model_info["model_type"],
trained_at=state.model_info.get("trained_at"), trained_at=current_model_info.get("trained_at"),
dataset_version=state.model_info.get("dataset_version"), dataset_version=current_model_info.get("dataset_version"),
feature_engineering_enabled=state.model_info["feature_engineering_enabled"] feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
) )
logger.info( logger.info(
@ -754,6 +765,17 @@ async def predict_batch(request: BatchPredictRequest):
f"from {request.start_date} to {request.end_date}" f"from {request.start_date} to {request.end_date}"
) )
# Grab model reference under lock to prevent reading a partially-swapped model
with _model_swap_lock:
current_model = state.model
current_model_info = state.model_info
if current_model is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="No model available"
)
try: try:
# Load OHLCV data from raw data path # Load OHLCV data from raw data path
raw_path = Path(state.pipeline_config.data.raw_path) raw_path = Path(state.pipeline_config.data.raw_path)
@ -807,13 +829,13 @@ async def predict_batch(request: BatchPredictRequest):
# Preprocess (feature engineering + windowing) # Preprocess (feature engineering + windowing)
X, window_times = preprocess_candles(batch_candles, state.pipeline_config) X, window_times = preprocess_candles(batch_candles, state.pipeline_config)
# Predict # Predict (using local reference, outside lock)
if hasattr(state.model, 'predict_proba'): if hasattr(current_model, 'predict_proba'):
y_pred = state.model.predict(X) y_pred = current_model.predict(X)
y_proba = state.model.predict_proba(X) y_proba = current_model.predict_proba(X)
confidences = np.max(y_proba, axis=1) confidences = np.max(y_proba, axis=1)
else: else:
y_pred = state.model.predict(X) y_pred = current_model.predict(X)
confidences = np.ones(len(y_pred)) confidences = np.ones(len(y_pred))
# Get labels # Get labels
@ -839,12 +861,12 @@ async def predict_batch(request: BatchPredictRequest):
# Build model info # Build model info
model_info = ModelInfo( model_info = ModelInfo(
model_name=state.model_info["model_name"], model_name=current_model_info["model_name"],
model_version=state.model_info.get("model_version"), model_version=current_model_info.get("model_version"),
model_type=state.model_info["model_type"], model_type=current_model_info["model_type"],
trained_at=state.model_info.get("trained_at"), trained_at=current_model_info.get("trained_at"),
dataset_version=state.model_info.get("dataset_version"), dataset_version=current_model_info.get("dataset_version"),
feature_engineering_enabled=state.model_info["feature_engineering_enabled"] feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
) )
logger.info( logger.info(