fix(ml): add _model_swap_lock to prediction reads for thread-safe model access
In the /predict and /predict/batch endpoints, capture local references to the model and its metadata (`state.model` and `state.model_info`) while holding `_model_swap_lock`, before running inference. Inference itself runs outside the lock, using those local references, to avoid blocking model swaps during potentially slow computation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ff15adc847
commit
b9beea1574
2 changed files with 56 additions and 34 deletions
|
|
@ -43,7 +43,7 @@
|
||||||
|
|
||||||
- [x] 5.1 `[sonnet]` Replace `error.message` / traceback details with generic `"Internal server error"` in FastAPI exception handlers at lines 640, 778, 1091, 1134, 1199, 1296 of `services/ml/app/main.py`
|
- [x] 5.1 `[sonnet]` Replace `error.message` / traceback details with generic `"Internal server error"` in FastAPI exception handlers at lines 640, 778, 1091, 1134, 1199, 1296 of `services/ml/app/main.py`
|
||||||
- [x] 5.2 `[opus]` Add SHA256 model integrity check: create `models/checksums.sha256` manifest, verify hash before `joblib.load()` in `services/ml/app/main.py:266`
|
- [x] 5.2 `[opus]` Add SHA256 model integrity check: create `models/checksums.sha256` manifest, verify hash before `joblib.load()` in `services/ml/app/main.py:266`
|
||||||
- [ ] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
|
- [x] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
|
||||||
- [ ] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
|
- [ ] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
|
||||||
- [ ] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
|
- [ ] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
|
||||||
- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
|
- [ ] 5.6 `[sonnet]` Implement real health checks: `SELECT 1` for PostgreSQL, MLflow API ping in `services/ml/app/main.py:396-409`
|
||||||
|
|
|
||||||
|
|
@ -642,38 +642,49 @@ async def predict(request: PredictRequest):
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail="No model available"
|
detail="No model available"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not request.candles:
|
if not request.candles:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Candles array cannot be empty"
|
detail="Candles array cannot be empty"
|
||||||
)
|
)
|
||||||
|
|
||||||
if state.pipeline_config is None:
|
if state.pipeline_config is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Pipeline configuration not loaded"
|
detail="Pipeline configuration not loaded"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles")
|
logger.info(f"Predict request: {request.pair or 'unknown'} {request.timeframe or 'unknown'}, {len(request.candles)} candles")
|
||||||
|
|
||||||
|
# Grab model reference under lock to prevent reading a partially-swapped model
|
||||||
|
with _model_swap_lock:
|
||||||
|
current_model = state.model
|
||||||
|
current_model_info = state.model_info
|
||||||
|
|
||||||
|
if current_model is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="No model available"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Convert candles to list of dicts
|
# Convert candles to list of dicts
|
||||||
candles_data = [candle.model_dump() for candle in request.candles]
|
candles_data = [candle.model_dump() for candle in request.candles]
|
||||||
|
|
||||||
# Preprocess candles (feature engineering + windowing)
|
# Preprocess candles (feature engineering + windowing)
|
||||||
X, window_times = preprocess_candles(candles_data, state.pipeline_config)
|
X, window_times = preprocess_candles(candles_data, state.pipeline_config)
|
||||||
|
|
||||||
# Get predictions and probabilities
|
# Get predictions and probabilities (using local reference, outside lock)
|
||||||
if hasattr(state.model, 'predict_proba'):
|
if hasattr(current_model, 'predict_proba'):
|
||||||
y_pred = state.model.predict(X)
|
y_pred = current_model.predict(X)
|
||||||
y_proba = state.model.predict_proba(X)
|
y_proba = current_model.predict_proba(X)
|
||||||
|
|
||||||
# Get confidence as max probability
|
# Get confidence as max probability
|
||||||
confidences = np.max(y_proba, axis=1)
|
confidences = np.max(y_proba, axis=1)
|
||||||
else:
|
else:
|
||||||
# Fallback for models without predict_proba
|
# Fallback for models without predict_proba
|
||||||
y_pred = state.model.predict(X)
|
y_pred = current_model.predict(X)
|
||||||
confidences = np.ones(len(y_pred)) # Default confidence of 1.0
|
confidences = np.ones(len(y_pred)) # Default confidence of 1.0
|
||||||
|
|
||||||
# Get label names (handle both string and int predictions)
|
# Get label names (handle both string and int predictions)
|
||||||
|
|
@ -697,14 +708,14 @@ async def predict(request: PredictRequest):
|
||||||
|
|
||||||
# Build model info for response
|
# Build model info for response
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
model_name=state.model_info["model_name"],
|
model_name=current_model_info["model_name"],
|
||||||
model_version=state.model_info.get("model_version"),
|
model_version=current_model_info.get("model_version"),
|
||||||
model_type=state.model_info["model_type"],
|
model_type=current_model_info["model_type"],
|
||||||
trained_at=state.model_info.get("trained_at"),
|
trained_at=current_model_info.get("trained_at"),
|
||||||
dataset_version=state.model_info.get("dataset_version"),
|
dataset_version=current_model_info.get("dataset_version"),
|
||||||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
|
feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Prediction complete: {len(predictions)} windows, "
|
f"Prediction complete: {len(predictions)} windows, "
|
||||||
f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
|
f"{len(spans)} spans, {len([p for p in predictions if p.label != 'O'])} patterns"
|
||||||
|
|
@ -742,17 +753,28 @@ async def predict_batch(request: BatchPredictRequest):
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail="No model available"
|
detail="No model available"
|
||||||
)
|
)
|
||||||
|
|
||||||
if state.pipeline_config is None:
|
if state.pipeline_config is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Pipeline configuration not loaded"
|
detail="Pipeline configuration not loaded"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Batch predict: {request.pair} {request.timeframe} "
|
f"Batch predict: {request.pair} {request.timeframe} "
|
||||||
f"from {request.start_date} to {request.end_date}"
|
f"from {request.start_date} to {request.end_date}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Grab model reference under lock to prevent reading a partially-swapped model
|
||||||
|
with _model_swap_lock:
|
||||||
|
current_model = state.model
|
||||||
|
current_model_info = state.model_info
|
||||||
|
|
||||||
|
if current_model is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="No model available"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load OHLCV data from raw data path
|
# Load OHLCV data from raw data path
|
||||||
|
|
@ -807,13 +829,13 @@ async def predict_batch(request: BatchPredictRequest):
|
||||||
# Preprocess (feature engineering + windowing)
|
# Preprocess (feature engineering + windowing)
|
||||||
X, window_times = preprocess_candles(batch_candles, state.pipeline_config)
|
X, window_times = preprocess_candles(batch_candles, state.pipeline_config)
|
||||||
|
|
||||||
# Predict
|
# Predict (using local reference, outside lock)
|
||||||
if hasattr(state.model, 'predict_proba'):
|
if hasattr(current_model, 'predict_proba'):
|
||||||
y_pred = state.model.predict(X)
|
y_pred = current_model.predict(X)
|
||||||
y_proba = state.model.predict_proba(X)
|
y_proba = current_model.predict_proba(X)
|
||||||
confidences = np.max(y_proba, axis=1)
|
confidences = np.max(y_proba, axis=1)
|
||||||
else:
|
else:
|
||||||
y_pred = state.model.predict(X)
|
y_pred = current_model.predict(X)
|
||||||
confidences = np.ones(len(y_pred))
|
confidences = np.ones(len(y_pred))
|
||||||
|
|
||||||
# Get labels
|
# Get labels
|
||||||
|
|
@ -839,14 +861,14 @@ async def predict_batch(request: BatchPredictRequest):
|
||||||
|
|
||||||
# Build model info
|
# Build model info
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
model_name=state.model_info["model_name"],
|
model_name=current_model_info["model_name"],
|
||||||
model_version=state.model_info.get("model_version"),
|
model_version=current_model_info.get("model_version"),
|
||||||
model_type=state.model_info["model_type"],
|
model_type=current_model_info["model_type"],
|
||||||
trained_at=state.model_info.get("trained_at"),
|
trained_at=current_model_info.get("trained_at"),
|
||||||
dataset_version=state.model_info.get("dataset_version"),
|
dataset_version=current_model_info.get("dataset_version"),
|
||||||
feature_engineering_enabled=state.model_info["feature_engineering_enabled"]
|
feature_engineering_enabled=current_model_info["feature_engineering_enabled"]
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Batch prediction complete: {len(all_predictions)} candles, "
|
f"Batch prediction complete: {len(all_predictions)} candles, "
|
||||||
f"{len(all_spans)} spans"
|
f"{len(all_spans)} spans"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue