feat: add SHA256 model integrity check before joblib.load()

Add verify_model_checksum(), which validates model files against a
models/checksums.sha256 manifest before loading. It fails open when the
manifest is missing or the file is not listed (for backward compatibility)
and raises HTTP 500 on a hash mismatch. Also adds an empty manifest placeholder.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marko Djordjevic 2026-02-18 11:25:14 +01:00
parent b7f9b2e04d
commit ff15adc847
3 changed files with 71 additions and 2 deletions

View file

@ -4,6 +4,7 @@ FastAPI inference service for candlestick pattern prediction.
Provides REST API endpoints for model serving, health checks, and prediction.
"""
import hashlib
import logging
import os
import re
@ -258,10 +259,75 @@ def load_model_from_mlflow(model_name: str, stage: str) -> tuple[Any, Dict[str,
raise
def verify_model_checksum(model_path: Path) -> None:
    """
    Verify model file integrity against the SHA256 checksum manifest.

    The manifest is read from ``models/checksums.sha256``. Blank lines and
    lines starting with ``#`` are ignored; every other line is expected to be
    ``<sha256hex> <filename>``. The ``sha256sum`` binary-mode form
    ``<sha256hex> *<filename>`` is also accepted, and the hex digest is
    compared case-insensitively. Entries are looked up by the base name of
    ``model_path``.

    The check fails open for backward compatibility: if the manifest is
    missing or unreadable, or has no entry for the file, a warning is logged
    and the function returns without raising.

    Args:
        model_path: Path to the model file to verify.

    Raises:
        HTTPException: With status 500 if the model file cannot be read for
            hashing, or if its SHA256 digest does not match the manifest.
    """
    manifest_path = Path("models/checksums.sha256")
    model_path = Path(model_path)

    # EAFP: read directly instead of exists()-then-read, which can race with
    # the manifest being removed between the two calls.
    try:
        manifest_text = manifest_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        logger.warning("Model checksum manifest not found at %s — skipping integrity check", manifest_path)
        return
    except (OSError, UnicodeError) as e:
        logger.warning("Failed to read checksum manifest: %s — skipping integrity check", e)
        return

    # Parse manifest: each line is "<sha256hash> <filename>"
    checksums: Dict[str, str] = {}
    for raw_line in manifest_text.splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        parts = entry.split(None, 1)
        if len(parts) != 2:
            # Malformed line (no filename column) — ignore it.
            continue
        digest, name = parts
        # hexdigest() is always lowercase; normalise so manifests produced by
        # tools that emit uppercase hex do not cause false mismatches.
        digest = digest.lower()
        checksums[name] = digest
        # sha256sum writes "<hash> *<name>" in binary mode. Register the bare
        # name too, without overriding an explicit entry for that name.
        if name.startswith("*"):
            checksums.setdefault(name[1:], digest)

    filename = model_path.name
    expected_hash = checksums.get(filename)
    if expected_hash is None:
        logger.warning("No checksum entry for '%s' in manifest — skipping integrity check", filename)
        return

    # Compute SHA256 of the actual file, streaming so large models are not
    # loaded into memory at once.
    sha256 = hashlib.sha256()
    try:
        with open(model_path, "rb") as f:
            while chunk := f.read(65536):
                sha256.update(chunk)
    except OSError as e:
        logger.error("Failed to compute SHA256 for '%s': %s", model_path, e)
        raise HTTPException(status_code=500, detail="Internal server error") from e

    actual_hash = sha256.hexdigest()
    if actual_hash != expected_hash:
        logger.error(
            "Model integrity check FAILED for '%s': expected %s, got %s",
            filename, expected_hash, actual_hash
        )
        raise HTTPException(status_code=500, detail="Internal server error")

    logger.info("Model integrity check passed for '%s'", filename)
def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
"""
Load model from local file using joblib.
Args:
model_path: Path to .pkl model file
@ -278,6 +344,9 @@ def load_model_from_local(model_path: str) -> tuple[Any, Dict[str, Any]]:
if not model_path.exists():
raise FileNotFoundError(f"Model file not found: {model_path}")
# Verify model integrity before loading
verify_model_checksum(model_path)
try:
# Load model with joblib
model_data = joblib.load(model_path)