feat: add SHA256 model integrity check before joblib.load()
Add verify_model_checksum() that validates model files against a models/checksums.sha256 manifest before loading. Fails open when manifest is missing or file not listed (backward compat), raises HTTP 500 on hash mismatch. Created empty manifest placeholder. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
b7f9b2e04d
commit
ff15adc847
3 changed files with 71 additions and 2 deletions
0
models/checksums.sha256
Normal file
0
models/checksums.sha256
Normal file
|
|
@ -42,7 +42,7 @@
|
|||
## 5. ML Service Hardening (Python)
|
||||
|
||||
- [x] 5.1 `[sonnet]` Replace `error.message` / traceback details with generic `"Internal server error"` in FastAPI exception handlers at lines 640, 778, 1091, 1134, 1199, 1296 of `services/ml/app/main.py`
|
||||
- [ ] 5.2 `[opus]` Add SHA256 model integrity check: create `models/checksums.sha256` manifest, verify hash before `joblib.load()` in `services/ml/app/main.py:266`
|
||||
- [x] 5.2 `[opus]` Add SHA256 model integrity check: create `models/checksums.sha256` manifest, verify hash before `joblib.load()` in `services/ml/app/main.py:266`
|
||||
- [ ] 5.3 `[sonnet]` Add `_model_swap_lock` to prediction reads (not just writes) in `services/ml/app/main.py` for thread-safe model access
|
||||
- [ ] 5.4 `[sonnet]` Add date range validation (max 1 year) to `POST /predict/batch` in `services/ml/app/main.py`
|
||||
- [ ] 5.5 `[sonnet]` Add candle time-sort validation/auto-sort to `POST /predict` in `services/ml/app/main.py`
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ FastAPI inference service for candlestick pattern prediction.
|
|||
Provides REST API endpoints for model serving, health checks, and prediction.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
|
@ -258,10 +259,75 @@ def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str,
|
|||
raise
|
||||
|
||||
|
||||
def _parse_checksum_manifest(manifest_text: str) -> Dict[str, str]:
    """Parse ``sha256sum``-style manifest text into a ``{filename: hash}`` map.

    Each meaningful line has the form ``<sha256hash> <filename>``. Blank lines,
    ``#`` comments, and lines without both fields are ignored. When a filename
    appears more than once, the last entry wins.
    """
    checksums: Dict[str, str] = {}
    for raw_line in manifest_text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        parts = line.split(None, 1)
        if len(parts) != 2:
            continue
        digest, name = parts[0].strip(), parts[1].strip()
        # GNU sha256sum writes "<hash> *<file>" in binary mode; without
        # stripping the marker the lookup by bare filename would miss and the
        # integrity check would be silently skipped.
        if name.startswith("*"):
            name = name[1:]
        # hexdigest() is lowercase; normalise so uppercase manifests
        # (e.g. produced by other tooling) do not cause false mismatches.
        checksums[name] = digest.lower()
    return checksums


def _sha256_of_file(path: Path) -> str:
    """Return the lowercase hex SHA256 digest of the file at *path*."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in fixed-size chunks so large model files are not loaded
        # into memory at once.
        while chunk := f.read(8192):
            digest.update(chunk)
    return digest.hexdigest()


def verify_model_checksum(model_path: Path) -> None:
    """
    Verify model file integrity against the SHA256 checksum manifest.

    Looks for checksums in models/checksums.sha256 (resolved relative to the
    current working directory). If the manifest exists and contains an entry
    for the model file's name, verifies that the file's SHA256 hash matches.
    Manifest entries are matched by bare filename; a leading ``*`` binary-mode
    marker is ignored and hashes are compared case-insensitively.

    Logs a warning and continues if the manifest is missing, unreadable, or
    does not list the file (fail-open for backward compatibility).

    Args:
        model_path: Path to the model file to verify.

    Raises:
        HTTPException: With status 500 if the model file cannot be read for
            hashing or if its hash does not match the manifest entry.
    """
    manifest_path = Path("models/checksums.sha256")
    model_path = Path(model_path)

    if not manifest_path.exists():
        logger.warning("Model checksum manifest not found at %s — skipping integrity check", manifest_path)
        return

    # Parse manifest: each line is "<sha256hash> <filename>"
    try:
        checksums = _parse_checksum_manifest(manifest_path.read_text(encoding="utf-8"))
    except (OSError, UnicodeError) as e:
        # Deliberate fail-open: an unreadable manifest must not block loading.
        logger.warning("Failed to read checksum manifest: %s — skipping integrity check", e)
        return

    filename = model_path.name
    expected_hash = checksums.get(filename)
    if expected_hash is None:
        logger.warning("No checksum entry for '%s' in manifest — skipping integrity check", filename)
        return

    # Compute SHA256 of the actual file
    try:
        actual_hash = _sha256_of_file(model_path)
    except OSError as e:
        logger.error("Failed to compute SHA256 for '%s': %s", model_path, e)
        # Generic detail on purpose: do not leak filesystem details to clients.
        raise HTTPException(status_code=500, detail="Internal server error") from e

    if actual_hash != expected_hash:
        logger.error(
            "Model integrity check FAILED for '%s': expected %s, got %s",
            filename, expected_hash, actual_hash
        )
        raise HTTPException(status_code=500, detail="Internal server error")

    logger.info("Model integrity check passed for '%s'", filename)
|
||||
|
||||
|
||||
def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
|
||||
"""
|
||||
Load model from local file using joblib.
|
||||
|
||||
|
||||
Args:
|
||||
model_path: Path to .pkl model file
|
||||
|
||||
|
|
@ -278,6 +344,9 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
|
|||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
# Verify model integrity before loading
|
||||
verify_model_checksum(model_path)
|
||||
|
||||
try:
|
||||
# Load model with joblib
|
||||
model_data = joblib.load(model_path)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue