Fix MITRE extraction to use actual S1 API structure + use generatedAlerts for firing status

MITRE fix:
- S1 platform-rules API returns rule["mitre"] = [{tactic, techniques:[{id,title}]}]
  not the flat field names we were checking — updated _extract_mitre to handle
  this as the primary path, keeping flat field fallback for STAR rules
- generatedAlerts field on each platform rule stored in raw JSON during import

Firing status fix:
- sync-rule-firing now reads generatedAlerts from ParsedRule.raw as fast path
  (instant, no SDL PowerQuery needed) since it's returned directly by the
  platform-rules API on every library sync
- SDL PowerQuery retained as fallback for rules imported from detections.json

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Mick
2026-05-22 10:42:48 -04:00
parent 7922de315e
commit 7b4eceefb8
+108 -88
View File
@@ -14,50 +14,62 @@ router = APIRouter()
def _extract_mitre(rule: dict) -> tuple[list[str], list[dict]]:
"""Extract (tactics, techniques) from a raw S1 rule dict.
Handles multiple field name conventions across S1 API versions."""
Primary format (platform-rules API):
rule["mitre"] = [
{"tactic": "Execution", "techniques": [{"id": "T1204", "title": "User Execution"}]},
...
]
Falls back to flat field names used by older API versions / STAR rules.
"""
tactics: list[str] = []
techniques: list[dict] = []
for key in ("tactic", "tactics", "mitreTactic", "mitreTactics", "attack.tactic"):
val = rule.get(key)
if isinstance(val, str) and val:
tactics.extend(v.strip() for v in val.split(",") if v.strip())
elif isinstance(val, list):
for v in val:
if isinstance(v, str) and v:
tactics.append(v.strip())
elif isinstance(v, dict):
n = v.get("name") or v.get("tactic") or ""
if n:
tactics.append(n.strip())
for key in ("technique", "techniques", "mitreTechnique", "mitreTechniques",
"attack.technique", "mitreAttack"):
val = rule.get(key)
if isinstance(val, str) and val:
for part in val.split(","):
part = part.strip()
if not part:
continue
if part.startswith("T") and len(part) >= 2 and part[1:5].replace(".", "").isdigit():
tid, _, tname = part.partition(" - ")
techniques.append({"id": tid.strip(), "name": tname.strip() or tid.strip()})
else:
techniques.append({"id": "", "name": part})
elif isinstance(val, list):
for v in val:
if isinstance(v, str) and v.strip():
part = v.strip()
if part.startswith("T") and len(part) >= 5 and part[1:5].replace(".", "").isdigit():
techniques.append({"id": part, "name": part})
else:
techniques.append({"id": "", "name": part})
elif isinstance(v, dict):
tid = v.get("id") or v.get("techniqueId") or v.get("technique_id") or ""
tname = v.get("name") or v.get("technique") or v.get("techniqueName") or tid
# ── Primary: structured mitre array (platform-rules API) ──────────────────
mitre_list = rule.get("mitre")
if isinstance(mitre_list, list):
for item in mitre_list:
if not isinstance(item, dict):
continue
tac = item.get("tactic")
if isinstance(tac, str) and tac.strip():
tactics.append(tac.strip())
for tech in item.get("techniques", []):
if isinstance(tech, dict):
tid = str(tech.get("id", "") or "").strip()
tname = str(tech.get("title") or tech.get("name") or tid).strip()
if tid or tname:
techniques.append({"id": str(tid).strip(), "name": str(tname).strip()})
techniques.append({"id": tid, "name": tname})
# ── Fallback: flat field names (STAR rules / older API versions) ──────────
if not tactics:
for key in ("tactic", "tactics", "mitreTactic", "mitreTactics"):
val = rule.get(key)
if isinstance(val, str) and val:
tactics.extend(v.strip() for v in val.split(",") if v.strip())
elif isinstance(val, list):
for v in val:
if isinstance(v, str) and v:
tactics.append(v.strip())
elif isinstance(v, dict):
n = v.get("name") or v.get("tactic") or ""
if n:
tactics.append(n.strip())
if not techniques:
for key in ("technique", "techniques", "mitreTechnique", "mitreTechniques", "mitreAttack"):
val = rule.get(key)
if isinstance(val, list):
for v in val:
if isinstance(v, str) and v.strip():
techniques.append({"id": v.strip(), "name": v.strip()})
elif isinstance(v, dict):
tid = str(v.get("id") or v.get("techniqueId") or "").strip()
tname = str(v.get("name") or v.get("title") or v.get("technique") or tid).strip()
if tid or tname:
techniques.append({"id": tid, "name": tname})
# Deduplicate
seen_ids: set = set()
unique_techniques = []
for t in techniques:
@@ -152,6 +164,8 @@ def _import_from_api_rules(db, rules: list) -> int:
sources = rule.get("sources") or []
tactics, techniques = _extract_mitre(rule)
# generatedAlerts is returned directly by the platform-rules API
generated_alerts = rule.get("generatedAlerts")
db.add(ParsedRule(
rule_id=rule_id,
name=rule.get("name", "unnamed"),
@@ -161,6 +175,7 @@ def _import_from_api_rules(db, rules: list) -> int:
"data_sources": sources,
"tactics": tactics,
"techniques": techniques,
"generated_alerts": generated_alerts,
}),
))
loaded += 1
@@ -941,74 +956,79 @@ def get_mitre_coverage(db: Session = Depends(get_db)):
@router.post("/sync-rule-firing")
async def sync_rule_firing(period_days: int = 30, db: Session = Depends(get_db)):
"""Query SDL for alert/threat counts by rule name over the last N days.
Tries multiple field name patterns until one returns results.
Caches results in rule_firing_cache table."""
from datetime import datetime, timedelta
now = datetime.utcnow()
from_dt = (now - timedelta(days=period_days)).strftime("%Y-%m-%dT%H:%M:%S.000Z")
to_dt = now.strftime("%Y-%m-%dT%H:%M:%S.000Z")
FIRING_QUERIES = [
("| filter ruleName != '' | group alerts=count() by ruleName | sort -alerts | limit 2000", "ruleName"),
("| filter threatInfo.detectionEngineRule.name != '' | group alerts=count() by threatInfo.detectionEngineRule.name | sort -alerts | limit 2000", "threatInfo.detectionEngineRule.name"),
("| filter alert.ruleName != '' | group alerts=count() by alert.ruleName | sort -alerts | limit 2000", "alert.ruleName"),
]
"""Populate rule firing cache from the generatedAlerts field stored during
the last Detection Library sync (platform-rules API). This is instant and
requires no SDL PowerQuery. Falls back to SDL PowerQuery if the stored data
is missing (e.g. rules were imported from the detections.json file fallback)."""
from datetime import datetime
checked_at = datetime.utcnow()
result_rows = []
query_used = None
errors = []
source = "api"
for query, name_field in FIRING_QUERIES:
# ── Fast path: use generatedAlerts stored in ParsedRule.raw ───────────────
rules = db.query(ParsedRule).filter_by(rule_type="library").all()
for rule in rules:
try:
result = await s1_client.run_powerquery(query, from_dt, to_dt, max_count=10_000_000)
err = result.get("error") if isinstance(result, dict) else None
if err:
errors.append(f"{name_field}: {err}")
raw_data = json.loads(rule.raw) if rule.raw else {}
except Exception:
raw_data = {}
ga = raw_data.get("generated_alerts")
if ga is not None: # present means rule was imported from the live API
result_rows.append({"rule_name": rule.name, "alerts": int(ga)})
# ── Fallback: SDL PowerQuery (rules imported from detections.json) ─────────
if not result_rows:
source = "powerquery"
from datetime import timedelta
now = datetime.utcnow()
from_dt = (now - timedelta(days=period_days)).strftime("%Y-%m-%dT%H:%M:%S.000Z")
to_dt = now.strftime("%Y-%m-%dT%H:%M:%S.000Z")
FIRING_QUERIES = [
("| filter ruleName != '' | group alerts=count() by ruleName | sort -alerts | limit 2000", "ruleName"),
("| filter threatInfo.detectionEngineRule.name != '' | group alerts=count() by threatInfo.detectionEngineRule.name | sort -alerts | limit 2000", "threatInfo.detectionEngineRule.name"),
]
for query, name_field in FIRING_QUERIES:
try:
result = await s1_client.run_powerquery(query, from_dt, to_dt, max_count=10_000_000)
rows = result.get("events", []) if isinstance(result, dict) else []
if rows:
result_rows = [
{"rule_name": r.get(name_field, ""), "alerts": r.get("alerts", 0)}
for r in rows if r.get(name_field)
]
if result_rows:
break
except Exception:
continue
rows = result.get("events", [])
if rows:
# Remap the name field to a standard key
result_rows = [{"rule_name": r.get(name_field, r.get("ruleName", "")), "alerts": r.get("alerts", 0)} for r in rows]
result_rows = [r for r in result_rows if r["rule_name"]]
if result_rows:
query_used = query
break
except Exception as e:
errors.append(f"{name_field}: {e}")
continue
if not result_rows:
return {
"synced": 0,
"period_days": period_days,
"rules_with_alerts": 0,
"query_used": None,
"message": "No alert data found. Errors: " + "; ".join(errors) if errors else "No alert data found — SDL may not have alert events in this time window.",
"source": source,
"message": "No alert data found. Run Sync Detection Library first to import generatedAlerts from the S1 API.",
}
# Upsert into cache
checked_at = datetime.utcnow()
db.query(RuleFiringCache).delete()
for row in result_rows:
existing = db.query(RuleFiringCache).filter_by(rule_name=row["rule_name"]).first()
if existing:
existing.alert_count = row["alerts"]
existing.period_days = period_days
existing.checked_at = checked_at
else:
db.add(RuleFiringCache(
rule_name=row["rule_name"],
alert_count=row["alerts"],
period_days=period_days,
checked_at=checked_at,
))
db.add(RuleFiringCache(
rule_name=row["rule_name"],
alert_count=row["alerts"],
period_days=period_days,
checked_at=checked_at,
))
db.commit()
fired = sum(1 for r in result_rows if r["alerts"] > 0)
return {
"synced": len(result_rows),
"rules_with_alerts": fired,
"rules_never_fired": len(result_rows) - fired,
"source": source,
"period_days": period_days,
"rules_with_alerts": len(result_rows),
"query_used": query_used,
}