fix(ml): add _model_swap_lock to prediction reads for thread-safe model access
In /predict and /predict/batch endpoints, grab the model reference under _model_swap_lock before running inference. Inference itself runs outside the lock (using a local variable) to avoid blocking model swaps during potentially slow computation. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ff15adc847
commit
b9beea1574
2 changed files with 56 additions and 34 deletions
|
|
@ -43,7 +43,7 @@
|
|||
|
||||
- [x] 5.1 `[sonnet]` Replace `error.message` / traceback details with generic `"Internal server error"` in FastAPI exception handlers at lines 640, 778, 1091, 1134, 1199, 1296 of `services/ml/app/main.py`
|
||||
- [x] 5.2 `[opus]` Add SHA256 model integrity check: create `models/checksums.sha256` manifest, verify hash before `joblib.load()` in `services/ml/app/main.py:266`
|
||||
- [ ] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
|
||||
- [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
|
||||
- [ ] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
|
||||
- [ ] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
|
||||
- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
|
||||
|
|
|
|||
|
|
@ -642,38 +642,49 @@ async def predict(request: PredictRequest):
|
|||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="No model available"
|
||||
)
|
||||
|
||||
|
||||
if not request.candles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Candles array cannot be empty"
|
||||
)
|
||||
|
||||
|
||||
if state.pipeline_config is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Pipeline configuration not loaded"
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles")
|
||||
|
||||
|
||||
# Grab model reference under lock to prevent reading a partially-swapped model
|
||||
with _model_swap_lock:
|
||||
current_model = state.model
|
||||
current_model_info = state.model_info
|
||||
|
||||
if current_model is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="No model available"
|
||||
)
|
||||
|
||||
try:
|
||||
# Convert candles to list of dicts
|
||||
candles_data = [candle.model_dump() for candle in request.candles]
|
||||
|
||||
|
||||
# Preprocess candles (feature engineering + windowing)
|
||||
X, window_times = preprocess_candles(candles_data, state.pipeline_config)
|
||||
|
||||
# Get predictions and probabilities
|
||||
if hasattr(state.model, 'predict_proba'):
|
||||
y_pred = state.model.predict(X)
|
||||
y_proba = state.model.predict_proba(X)
|
||||
|
||||
|
||||
# Get predictions and probabilities (using local reference, outside lock)
|
||||
if hasattr(current_model, 'predict_proba'):
|
||||
y_pred = current_model.predict(X)
|
||||
y_proba = current_model.predict_proba(X)
|
||||
|
||||
# Get confidence as max probability
|
||||
confidences = np.max(y_proba, axis=1)
|
||||
else:
|
||||
# Fallback for models without predict_proba
|
||||
y_pred = state.model.predict(X)
|
||||
y_pred = current_model.predict(X)
|
||||
confidences = np.ones(len(y_pred)) # Default confidence of 1.0
|
||||
|
||||
# Get label names (handle both string and int predictions)
|
||||
|
|
@ -697,14 +708,14 @@ async def predict(request: PredictRequest):
|
|||
|
||||
# Build model info for response
|
||||
model_info = ModelInfo(
|
||||
model_name=state.model_info["model_name"],
|
||||
model_version=state.model_info.get("model_version"),
|
||||
model_type=state.model_info["model_type"],
|
||||
trained_at=state.model_info.get("trained_at"),
|
||||
dataset_version=state.model_info.get("dataset_version"),
|
||||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
|
||||
model_name=current_model_info["model_name"],
|
||||
model_version=current_model_info.get("model_version"),
|
||||
model_type=current_model_info["model_type"],
|
||||
trained_at=current_model_info.get("trained_at"),
|
||||
dataset_version=current_model_info.get("dataset_version"),
|
||||
feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Prediction complete: {len(predictions)} windows, "
|
||||
f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
|
||||
|
|
@ -742,17 +753,28 @@ async def predict_batch(request: BatchPredictRequest):
|
|||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="No model available"
|
||||
)
|
||||
|
||||
|
||||
if state.pipeline_config is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Pipeline configuration not loaded"
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Batch predict: {request.pair} {request.timeframe} "
|
||||
f"from {request.start_date} to {request.end_date}"
|
||||
)
|
||||
|
||||
# Grab model reference under lock to prevent reading a partially-swapped model
|
||||
with _model_swap_lock:
|
||||
current_model = state.model
|
||||
current_model_info = state.model_info
|
||||
|
||||
if current_model is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="No model available"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load OHLCV data from raw data path
|
||||
|
|
@ -807,13 +829,13 @@ async def predict_batch(request: BatchPredictRequest):
|
|||
# Preprocess (feature engineering + windowing)
|
||||
X, window_times = preprocess_candles(batch_candles, state.pipeline_config)
|
||||
|
||||
# Predict
|
||||
if hasattr(state.model, 'predict_proba'):
|
||||
y_pred = state.model.predict(X)
|
||||
y_proba = state.model.predict_proba(X)
|
||||
# Predict (using local reference, outside lock)
|
||||
if hasattr(current_model, 'predict_proba'):
|
||||
y_pred = current_model.predict(X)
|
||||
y_proba = current_model.predict_proba(X)
|
||||
confidences = np.max(y_proba, axis=1)
|
||||
else:
|
||||
y_pred = state.model.predict(X)
|
||||
y_pred = current_model.predict(X)
|
||||
confidences = np.ones(len(y_pred))
|
||||
|
||||
# Get labels
|
||||
|
|
@ -839,14 +861,14 @@ async def predict_batch(request: BatchPredictRequest):
|
|||
|
||||
# Build model info
|
||||
model_info = ModelInfo(
|
||||
model_name=state.model_info["model_name"],
|
||||
model_version=state.model_info.get("model_version"),
|
||||
model_type=state.model_info["model_type"],
|
||||
trained_at=state.model_info.get("trained_at"),
|
||||
dataset_version=state.model_info.get("dataset_version"),
|
||||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
|
||||
model_name=current_model_info["model_name"],
|
||||
model_version=current_model_info.get("model_version"),
|
||||
model_type=current_model_info["model_type"],
|
||||
trained_at=current_model_info.get("trained_at"),
|
||||
dataset_version=current_model_info.get("dataset_version"),
|
||||
feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Batch prediction complete: {len(all_predictions)} candles, "
|
||||
f"{len(all_spans)} spans"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue