def _find_expected_sha256(manifest_text: str, model_path: Path) -> "str | None":
    """
    Return the expected digest for *model_path* from a checksum manifest.

    Each non-empty, non-comment manifest line is ``<sha256-hex>  <path>``
    (the layout written by ``sha256sum``). An entry matches when its path
    components equal the trailing components of the absolute model path, so
    ``model.pkl``, ``./model.pkl`` and ``models/model.pkl`` all match
    ``/srv/models/model.pkl``. The most specific (deepest) match wins; among
    equally specific matches the last one listed wins.

    Args:
        manifest_text: Full text of the checksum manifest.
        model_path: Path of the model file being verified.

    Returns:
        The lowercase hex digest of the best matching entry, or None when no
        entry matches.
    """
    model_parts = model_path.absolute().parts
    best_digest = None
    best_depth = 0

    for raw_line in manifest_text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue

        fields = line.split(None, 1)
        if len(fields) != 2:
            continue
        digest, entry = fields

        # sha256sum marks binary-mode entries as "<hash> *<name>"; the marker
        # is not part of the file name.
        if entry.startswith("*"):
            entry = entry[1:]

        entry_parts = Path(entry).parts
        depth = len(entry_parts)
        if depth == 0 or depth > len(model_parts):
            continue

        if model_parts[-depth:] == entry_parts and depth >= best_depth:
            # Hex digests are case-insensitive; normalise so manifests written
            # in uppercase do not cause false mismatches.
            best_digest = digest.lower()
            best_depth = depth

    return best_digest


def verify_model_checksum(model_path: Path) -> None:
    """
    Verify model file integrity against the SHA256 checksum manifest.

    The manifest is read from ``models/checksums.sha256``, resolved against
    the current working directory. Each line is ``<sha256-hex>  <path>``;
    blank lines and lines starting with ``#`` are ignored.

    The check is fail-open for backward compatibility: if the manifest is
    missing or unreadable, or has no entry for the model file, a warning is
    logged and the function returns without verifying anything.

    NOTE(review): because of the fail-open policy, a missing or empty manifest
    (or a service started from a different working directory) disables
    verification entirely. Consider failing closed once every deployed model
    is listed.

    Args:
        model_path: Path to the model file to verify.

    Raises:
        HTTPException: Status 500 if the model file cannot be read for
            hashing, or if its SHA256 hash does not match the manifest entry.
    """
    manifest_path = Path("models/checksums.sha256")
    model_path = Path(model_path)

    # Read directly instead of exists()+read: the file can disappear between
    # the two calls.
    try:
        manifest_text = manifest_path.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        logger.warning("Model checksum manifest not found at %s — skipping integrity check", manifest_path)
        return
    except (OSError, UnicodeDecodeError) as e:
        logger.warning("Failed to read checksum manifest: %s — skipping integrity check", e)
        return

    filename = model_path.name
    expected_hash = _find_expected_sha256(manifest_text, model_path)
    if expected_hash is None:
        logger.warning("No checksum entry for '%s' in manifest — skipping integrity check", filename)
        return

    # Hash the file in chunks so large model files are never fully in memory.
    sha256 = hashlib.sha256()
    try:
        with open(model_path, "rb") as f:
            while chunk := f.read(65536):
                sha256.update(chunk)
    except OSError as e:
        logger.error("Failed to compute SHA256 for '%s': %s", model_path, e)
        raise HTTPException(status_code=500, detail="Internal server error") from e

    actual_hash = sha256.hexdigest()

    if actual_hash != expected_hash:
        logger.error(
            "Model integrity check FAILED for '%s': expected %s, got %s",
            filename, expected_hash, actual_hash
        )
        raise HTTPException(status_code=500, detail="Internal server error")

    logger.info("Model integrity check passed for '%s'", filename)
tuple[Any, Dict[str, Any]]: """ Load model from local file using joblib. - + Args: model_path: Path to .pkl model file @@ -278,6 +344,9 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]: if not model_path.exists(): raise FileNotFoundError(f"Model file not found: {model_path}") + # Verify model integrity before loading + verify_model_checksum(model_path) + try: # Load model with joblib model_data = joblib.load(model_path)