Fix MITRE extraction to use actual S1 API structure + use generatedAlerts for firing status

MITRE fix:
- S1 platform-rules API returns rule["mitre"] = [{tactic, techniques:[{id,title}]}]
  not the flat field names we were checking — updated _extract_mitre to handle
  this as the primary path, keeping flat field fallback for STAR rules
- generatedAlerts field on each platform rule is now stored in the raw JSON
  during import (under the key "generated_alerts"), for use by the firing
  status fix below

Firing status fix:
- sync-rule-firing now reads generatedAlerts from ParsedRule.raw as a fast path
  (instant, no SDL PowerQuery needed), since the platform-rules API returns
  that field directly on every library sync
- SDL PowerQuery retained as a fallback, used only when no library rule has a
  stored generatedAlerts value (e.g. rules imported from detections.json)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Mick
2026-05-22 10:42:48 -04:00
parent 7922de315e
commit 7b4eceefb8
+76 -56
View File
@@ -14,11 +14,36 @@ router = APIRouter()
def _extract_mitre(rule: dict) -> tuple[list[str], list[dict]]: def _extract_mitre(rule: dict) -> tuple[list[str], list[dict]]:
"""Extract (tactics, techniques) from a raw S1 rule dict. """Extract (tactics, techniques) from a raw S1 rule dict.
Handles multiple field name conventions across S1 API versions."""
Primary format (platform-rules API):
rule["mitre"] = [
{"tactic": "Execution", "techniques": [{"id": "T1204", "title": "User Execution"}]},
...
]
Falls back to flat field names used by older API versions / STAR rules.
"""
tactics: list[str] = [] tactics: list[str] = []
techniques: list[dict] = [] techniques: list[dict] = []
for key in ("tactic", "tactics", "mitreTactic", "mitreTactics", "attack.tactic"): # ── Primary: structured mitre array (platform-rules API) ──────────────────
mitre_list = rule.get("mitre")
if isinstance(mitre_list, list):
for item in mitre_list:
if not isinstance(item, dict):
continue
tac = item.get("tactic")
if isinstance(tac, str) and tac.strip():
tactics.append(tac.strip())
for tech in item.get("techniques", []):
if isinstance(tech, dict):
tid = str(tech.get("id", "") or "").strip()
tname = str(tech.get("title") or tech.get("name") or tid).strip()
if tid or tname:
techniques.append({"id": tid, "name": tname})
# ── Fallback: flat field names (STAR rules / older API versions) ──────────
if not tactics:
for key in ("tactic", "tactics", "mitreTactic", "mitreTactics"):
val = rule.get(key) val = rule.get(key)
if isinstance(val, str) and val: if isinstance(val, str) and val:
tactics.extend(v.strip() for v in val.split(",") if v.strip()) tactics.extend(v.strip() for v in val.split(",") if v.strip())
@@ -31,33 +56,20 @@ def _extract_mitre(rule: dict) -> tuple[list[str], list[dict]]:
if n: if n:
tactics.append(n.strip()) tactics.append(n.strip())
for key in ("technique", "techniques", "mitreTechnique", "mitreTechniques", if not techniques:
"attack.technique", "mitreAttack"): for key in ("technique", "techniques", "mitreTechnique", "mitreTechniques", "mitreAttack"):
val = rule.get(key) val = rule.get(key)
if isinstance(val, str) and val: if isinstance(val, list):
for part in val.split(","):
part = part.strip()
if not part:
continue
if part.startswith("T") and len(part) >= 2 and part[1:5].replace(".", "").isdigit():
tid, _, tname = part.partition(" - ")
techniques.append({"id": tid.strip(), "name": tname.strip() or tid.strip()})
else:
techniques.append({"id": "", "name": part})
elif isinstance(val, list):
for v in val: for v in val:
if isinstance(v, str) and v.strip(): if isinstance(v, str) and v.strip():
part = v.strip() techniques.append({"id": v.strip(), "name": v.strip()})
if part.startswith("T") and len(part) >= 5 and part[1:5].replace(".", "").isdigit():
techniques.append({"id": part, "name": part})
else:
techniques.append({"id": "", "name": part})
elif isinstance(v, dict): elif isinstance(v, dict):
tid = v.get("id") or v.get("techniqueId") or v.get("technique_id") or "" tid = str(v.get("id") or v.get("techniqueId") or "").strip()
tname = v.get("name") or v.get("technique") or v.get("techniqueName") or tid tname = str(v.get("name") or v.get("title") or v.get("technique") or tid).strip()
if tid or tname: if tid or tname:
techniques.append({"id": str(tid).strip(), "name": str(tname).strip()}) techniques.append({"id": tid, "name": tname})
# Deduplicate
seen_ids: set = set() seen_ids: set = set()
unique_techniques = [] unique_techniques = []
for t in techniques: for t in techniques:
@@ -152,6 +164,8 @@ def _import_from_api_rules(db, rules: list) -> int:
sources = rule.get("sources") or [] sources = rule.get("sources") or []
tactics, techniques = _extract_mitre(rule) tactics, techniques = _extract_mitre(rule)
# generatedAlerts is returned directly by the platform-rules API
generated_alerts = rule.get("generatedAlerts")
db.add(ParsedRule( db.add(ParsedRule(
rule_id=rule_id, rule_id=rule_id,
name=rule.get("name", "unnamed"), name=rule.get("name", "unnamed"),
@@ -161,6 +175,7 @@ def _import_from_api_rules(db, rules: list) -> int:
"data_sources": sources, "data_sources": sources,
"tactics": tactics, "tactics": tactics,
"techniques": techniques, "techniques": techniques,
"generated_alerts": generated_alerts,
}), }),
)) ))
loaded += 1 loaded += 1
@@ -941,10 +956,31 @@ def get_mitre_coverage(db: Session = Depends(get_db)):
@router.post("/sync-rule-firing") @router.post("/sync-rule-firing")
async def sync_rule_firing(period_days: int = 30, db: Session = Depends(get_db)): async def sync_rule_firing(period_days: int = 30, db: Session = Depends(get_db)):
"""Query SDL for alert/threat counts by rule name over the last N days. """Populate rule firing cache from the generatedAlerts field stored during
Tries multiple field name patterns until one returns results. the last Detection Library sync (platform-rules API). This is instant and
Caches results in rule_firing_cache table.""" requires no SDL PowerQuery. Falls back to SDL PowerQuery if the stored data
from datetime import datetime, timedelta is missing (e.g. rules were imported from the detections.json file fallback)."""
from datetime import datetime
checked_at = datetime.utcnow()
result_rows = []
source = "api"
# ── Fast path: use generatedAlerts stored in ParsedRule.raw ───────────────
rules = db.query(ParsedRule).filter_by(rule_type="library").all()
for rule in rules:
try:
raw_data = json.loads(rule.raw) if rule.raw else {}
except Exception:
raw_data = {}
ga = raw_data.get("generated_alerts")
if ga is not None: # present means rule was imported from the live API
result_rows.append({"rule_name": rule.name, "alerts": int(ga)})
# ── Fallback: SDL PowerQuery (rules imported from detections.json) ─────────
if not result_rows:
source = "powerquery"
from datetime import timedelta
now = datetime.utcnow() now = datetime.utcnow()
from_dt = (now - timedelta(days=period_days)).strftime("%Y-%m-%dT%H:%M:%S.000Z") from_dt = (now - timedelta(days=period_days)).strftime("%Y-%m-%dT%H:%M:%S.000Z")
to_dt = now.strftime("%Y-%m-%dT%H:%M:%S.000Z") to_dt = now.strftime("%Y-%m-%dT%H:%M:%S.000Z")
@@ -952,50 +988,32 @@ async def sync_rule_firing(period_days: int = 30, db: Session = Depends(get_db))
FIRING_QUERIES = [ FIRING_QUERIES = [
("| filter ruleName != '' | group alerts=count() by ruleName | sort -alerts | limit 2000", "ruleName"), ("| filter ruleName != '' | group alerts=count() by ruleName | sort -alerts | limit 2000", "ruleName"),
("| filter threatInfo.detectionEngineRule.name != '' | group alerts=count() by threatInfo.detectionEngineRule.name | sort -alerts | limit 2000", "threatInfo.detectionEngineRule.name"), ("| filter threatInfo.detectionEngineRule.name != '' | group alerts=count() by threatInfo.detectionEngineRule.name | sort -alerts | limit 2000", "threatInfo.detectionEngineRule.name"),
("| filter alert.ruleName != '' | group alerts=count() by alert.ruleName | sort -alerts | limit 2000", "alert.ruleName"),
] ]
result_rows = []
query_used = None
errors = []
for query, name_field in FIRING_QUERIES: for query, name_field in FIRING_QUERIES:
try: try:
result = await s1_client.run_powerquery(query, from_dt, to_dt, max_count=10_000_000) result = await s1_client.run_powerquery(query, from_dt, to_dt, max_count=10_000_000)
err = result.get("error") if isinstance(result, dict) else None rows = result.get("events", []) if isinstance(result, dict) else []
if err:
errors.append(f"{name_field}: {err}")
continue
rows = result.get("events", [])
if rows: if rows:
# Remap the name field to a standard key result_rows = [
result_rows = [{"rule_name": r.get(name_field, r.get("ruleName", "")), "alerts": r.get("alerts", 0)} for r in rows] {"rule_name": r.get(name_field, ""), "alerts": r.get("alerts", 0)}
result_rows = [r for r in result_rows if r["rule_name"]] for r in rows if r.get(name_field)
]
if result_rows: if result_rows:
query_used = query
break break
except Exception as e: except Exception:
errors.append(f"{name_field}: {e}")
continue continue
if not result_rows: if not result_rows:
return { return {
"synced": 0, "synced": 0,
"period_days": period_days,
"rules_with_alerts": 0, "rules_with_alerts": 0,
"query_used": None, "source": source,
"message": "No alert data found. Errors: " + "; ".join(errors) if errors else "No alert data found — SDL may not have alert events in this time window.", "message": "No alert data found. Run Sync Detection Library first to import generatedAlerts from the S1 API.",
} }
# Upsert into cache # Upsert into cache
checked_at = datetime.utcnow() db.query(RuleFiringCache).delete()
for row in result_rows: for row in result_rows:
existing = db.query(RuleFiringCache).filter_by(rule_name=row["rule_name"]).first()
if existing:
existing.alert_count = row["alerts"]
existing.period_days = period_days
existing.checked_at = checked_at
else:
db.add(RuleFiringCache( db.add(RuleFiringCache(
rule_name=row["rule_name"], rule_name=row["rule_name"],
alert_count=row["alerts"], alert_count=row["alerts"],
@@ -1004,11 +1022,13 @@ async def sync_rule_firing(period_days: int = 30, db: Session = Depends(get_db))
)) ))
db.commit() db.commit()
fired = sum(1 for r in result_rows if r["alerts"] > 0)
return { return {
"synced": len(result_rows), "synced": len(result_rows),
"rules_with_alerts": fired,
"rules_never_fired": len(result_rows) - fired,
"source": source,
"period_days": period_days, "period_days": period_days,
"rules_with_alerts": len(result_rows),
"query_used": query_used,
} }