# cold_display_guard/src/cold_display_guard/config.py
# (source listing metadata: Python, 342 lines, 12 KiB)
from __future__ import annotations

import math
import tomllib
from copy import deepcopy
from pathlib import Path
from typing import Any

from cold_display_guard.models import DEFAULT_ZONE_IDS, EngineSettings
# Config file used when no explicit path is given. It is relative, so
# resolve_config_path() resolves it against the current working directory.
DEFAULT_CONFIG_PATH = Path("config/example.toml")
# Largest number of numeric food zones a layout may declare via
# layout.zone_count / layout.zone_ids.
MAX_CUSTOM_FOOD_ZONES = 10
def load_settings(path: str | Path) -> EngineSettings:
    """Read the TOML config at *path* and build the engine settings from it.

    Missing values fall back to defaults: camera ``cold_display_cam_01``,
    a 10 800 s maximum dwell, a 120 s trash confirmation, and
    ``DEFAULT_ZONE_IDS`` when the layout section yields no zone IDs.

    Raises:
        ValueError: From layout zone-ID validation, TOML parsing, or when a
            threshold value cannot be converted with ``int()``.
    """
    document = load_config_document(path)
    threshold_section: dict[str, Any] = document.get("thresholds", {})
    layout_section: dict[str, Any] = document.get("layout", {})
    # Resolve zone IDs before the other fields so validation errors surface first.
    resolved_zone_ids = _zone_ids_from_layout(layout_section) or DEFAULT_ZONE_IDS
    camera = str(document.get("camera_id", "cold_display_cam_01"))
    dwell_limit = int(threshold_section.get("max_dwell_seconds", 10_800))
    trash_delay = int(threshold_section.get("trash_confirmation_seconds", 120))
    return EngineSettings(
        camera_id=camera,
        max_dwell_seconds=dwell_limit,
        trash_confirmation_seconds=trash_delay,
        zone_ids=resolved_zone_ids,
    )
def load_config_document(path: str | Path) -> dict[str, Any]:
config_path = Path(path)
return tomllib.loads(config_path.read_text(encoding="utf-8"))
def save_config_document(path: str | Path, data: dict[str, Any]) -> None:
    """Render *data* as TOML and write it to *path* as UTF-8.

    Missing parent directories are created before rendering. An existing file
    at *path* is overwritten.
    """
    target = Path(path)
    parent_dir = target.parent
    parent_dir.mkdir(exist_ok=True, parents=True)
    document_text = format_config_document(data)
    target.write_text(document_text, encoding="utf-8")
def resolve_config_path(path: str | Path | None = None) -> Path:
    """Return an absolute, user-expanded config path.

    ``None`` selects ``DEFAULT_CONFIG_PATH``. Relative paths are resolved
    against the current working directory.
    """
    chosen = DEFAULT_CONFIG_PATH if path is None else path
    return Path(chosen).expanduser().resolve()
def resolve_project_root(config_path: str | Path) -> Path:
    """Guess the project root for a config file.

    If the config lives directly inside a directory named ``config``, that
    directory's parent is the root. Otherwise the current working directory
    is used.
    """
    resolved = Path(config_path).expanduser().resolve()
    config_dir = resolved.parent
    if config_dir.name != "config":
        return Path.cwd().resolve()
    return config_dir.parent
def merge_calibration(
    data: dict[str, Any],
    zones: list[dict[str, Any]],
    trash_roi: list[list[float]] | None,
    layout_update: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Return a deep copy of *data* with calibration results merged in.

    Args:
        data: Parsed config document; it is not mutated.
        zones: Calibrated food zones, each a mapping with ``id``, ``polygon``
            and an optional ``label``. Entries whose id is ``"trash"`` (any
            case) or blank, or whose polygon has fewer than three valid
            points, are ignored. For duplicate ids the last entry wins.
        trash_roi: Trash-area polygon. It is ignored when ``None`` or when it
            has fewer than three valid points.
        layout_update: Optional layout section from the calibration client.
            When it yields zone IDs, they must be numeric and become the zone
            order.

    Returns:
        The merged document. If there are neither calibrated zones nor
        incoming numeric IDs, the layout and zones are left unchanged.

        With numeric zone IDs:
        - the layout is rewritten as ``zone_count``/``zone_ids`` and any
          ``rows``/``cols`` keys are dropped;
        - zones are re-normalized in ``"1".."N"`` order.

        With legacy (non-numeric) IDs:
        - the layout's effective order (explicit ``zone_ids``, or ids derived
          from ``rows``/``cols``) is kept for ids that have a polygon;
        - newly calibrated ids are appended;
        - calibrated polygons replace stored ones.

    Raises:
        ValueError: If incoming IDs are not numeric, or the resulting numeric
            IDs are mixed with legacy IDs, exceed ``MAX_CUSTOM_FOOD_ZONES``,
            or are not contiguous from 1.
    """
    merged = deepcopy(data)
    incoming_numeric_zone_ids = _incoming_numeric_zone_ids(layout_update)
    valid_zones = _collect_calibrated_zones(zones)
    if valid_zones or incoming_numeric_zone_ids:
        layout = merged.setdefault("layout", {})
        existing_numeric_zone_ids = _existing_numeric_zone_ids(layout)
        if incoming_numeric_zone_ids or existing_numeric_zone_ids or _is_numeric_zone_ids(valid_zones):
            _merge_numeric_calibration_layout(
                merged,
                layout,
                incoming_numeric_zone_ids,
                existing_numeric_zone_ids,
                valid_zones,
            )
        else:
            _merge_legacy_calibration_layout(merged, layout, valid_zones)
    if trash_roi is not None:
        normalized_roi = _normalize_points(trash_roi)
        if len(normalized_roi) >= 3:
            merged.setdefault("trash", {})["roi"] = normalized_roi
    return merged


def _collect_calibrated_zones(zones: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Normalize incoming calibrated zones, keyed by id (later duplicates win)."""
    valid_zones: dict[str, dict[str, Any]] = {}
    for zone in zones:
        # The trash area is configured via trash_roi, never as a food zone.
        if str(zone.get("id", "")).strip().lower() == "trash":
            continue
        normalized = _normalized_zone(zone)
        if normalized is not None:
            valid_zones[normalized["id"]] = normalized
    return valid_zones


def _zones_by_id(zones: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Index stored zone entries by stripped id, skipping entries without an id."""
    return {
        zone_id: zone
        for zone in zones
        if (zone_id := str(zone.get("id", "")).strip())
    }


def _merge_numeric_calibration_layout(
    merged: dict[str, Any],
    layout: dict[str, Any],
    incoming_numeric_zone_ids: list[str],
    existing_numeric_zone_ids: list[str],
    valid_zones: dict[str, dict[str, Any]],
) -> None:
    """Rewrite *merged* in place for a numeric ("1".."N") zone layout."""
    zone_order = _numeric_calibration_zone_order(
        incoming_numeric_zone_ids,
        existing_numeric_zone_ids,
        valid_zones,
    )
    _validate_numeric_zone_ids(zone_order)
    existing_by_id = _zones_by_id(merged.get("zones", []))
    # Numeric layouts are described by zone_count/zone_ids only.
    layout.pop("rows", None)
    layout.pop("cols", None)
    layout["zone_count"] = len(zone_order)
    layout["zone_ids"] = zone_order
    merged["zones"] = _ordered_normalized_zones(zone_order, valid_zones, existing_by_id)


def _merge_legacy_calibration_layout(
    merged: dict[str, Any],
    layout: dict[str, Any],
    valid_zones: dict[str, dict[str, Any]],
) -> None:
    """Rewrite *merged* in place for a legacy (non-numeric) zone layout."""
    existing_by_id = _zones_by_id(merged.get("zones", []))
    existing_by_id.update(valid_zones)
    # Use the layout's effective IDs (explicit zone_ids or rows/cols), matching
    # how the layout is read at load time. Reading only "zone_ids" would drop
    # every stored zone of a rows/cols layout on a partial calibration.
    zone_order = [zone_id for zone_id in _zone_ids_from_layout(layout) if zone_id in existing_by_id]
    for zone_id in valid_zones:
        if zone_id not in zone_order:
            zone_order.append(zone_id)
    layout["zone_ids"] = zone_order
    merged["zones"] = [existing_by_id[zone_id] for zone_id in zone_order]
def format_config_document(data: dict[str, Any]) -> str:
    """Serialize a config document to TOML text.

    Only the fields this module manages are written:
    - top-level ``camera_id`` and ``timezone``;
    - ``[stream].rtsp_url``;
    - the two ``[thresholds]`` values;
    - every ``[runtime]`` entry (booleans and numbers verbatim, anything else
      as a quoted string, so lists and tables do not round-trip);
    - ``[layout]``;
    - each ``[[zones]]`` entry that has an id and a valid polygon;
    - ``[trash].roi`` when it has at least three points;
    - ``[event_sink].path``.

    Missing fields fall back to built-in defaults. Any other keys in *data*
    are not written. The text ends with a newline.

    Raises:
        ValueError: If the layout's zone IDs / zone count are invalid.
    """
    lines: list[str] = []
    lines.append(f'camera_id = "{_escape(str(data.get("camera_id", "cold_display_cam_01")))}"')
    lines.append(f'timezone = "{_escape(str(data.get("timezone", "Asia/Shanghai")))}"')
    lines.append("")

    stream = data.get("stream", {})
    lines.append("[stream]")
    lines.append(f'rtsp_url = "{_escape(str(stream.get("rtsp_url", "")))}"')
    lines.append("")

    thresholds = data.get("thresholds", {})
    lines.append("[thresholds]")
    lines.append(f'max_dwell_seconds = {int(thresholds.get("max_dwell_seconds", 10_800))}')
    lines.append(f'trash_confirmation_seconds = {int(thresholds.get("trash_confirmation_seconds", 120))}')
    lines.append("")

    runtime = data.get("runtime", {})
    if runtime:
        lines.append("[runtime]")
        for key in sorted(runtime):
            value = runtime[key]
            # Keys that are not valid TOML bare keys must be quoted, or the
            # written file will not parse.
            toml_key = _format_toml_key(key)
            if isinstance(value, bool):
                # bool is checked before int (it is an int subclass); TOML
                # spells booleans in lowercase.
                lines.append(f"{toml_key} = {str(value).lower()}")
            elif isinstance(value, int | float):
                lines.append(f"{toml_key} = {value}")
            else:
                lines.append(f'{toml_key} = "{_escape(str(value))}"')
        lines.append("")

    layout = data.get("layout", {})
    zone_ids = list(_zone_ids_from_layout(layout))
    if not zone_ids:
        zone_ids = list(DEFAULT_ZONE_IDS)
    lines.append("[layout]")
    if _is_numeric_zone_ids(zone_ids):
        lines.append(f"zone_count = {len(zone_ids)}")
    else:
        # Legacy grid layouts keep their rows/cols, defaulting to a 2x4 grid.
        rows = int(layout.get("rows", 2))
        cols = int(layout.get("cols", 4))
        lines.append(f"rows = {rows}")
        lines.append(f"cols = {cols}")
    lines.append(f"zone_ids = {_format_string_array(zone_ids)}")
    lines.append("")

    for zone in data.get("zones", []):
        zone_id = str(zone.get("id", "")).strip()
        polygon = _normalize_points(zone.get("polygon", []))
        if not zone_id or len(polygon) < 3:
            continue
        lines.append("[[zones]]")
        lines.append(f'id = "{_escape(zone_id)}"')
        label = str(zone.get("label", "")).strip()
        if label:
            lines.append(f'label = "{_escape(label)}"')
        lines.append(f"polygon = {_format_points(polygon)}")
        lines.append("")

    trash = data.get("trash", {})
    roi = _normalize_points(trash.get("roi", []))
    if len(roi) >= 3:
        lines.append("[trash]")
        lines.append(f"roi = {_format_points(roi)}")
        lines.append("")

    event_sink = data.get("event_sink", {})
    lines.append("[event_sink]")
    lines.append(f'path = "{_escape(str(event_sink.get("path", "logs/events.jsonl")))}"')
    lines.append("")
    return "\n".join(lines)


def _format_toml_key(key: Any) -> str:
    """Return *key* as a TOML key, quoting it unless it is a valid bare key.

    Bare keys may contain only ASCII letters, digits, ``_`` and ``-``. Anything
    else, including the empty string, is written as a quoted key.
    """
    text = str(key)
    if text and all(char.isascii() and (char.isalnum() or char in "-_") for char in text):
        return text
    return f'"{_escape(text)}"'
def _zone_ids_from_layout(layout: dict[str, Any]) -> tuple[str, ...]:
    """Return the zone IDs a layout section declares.

    The sources are tried in this order:
    - an explicit ``zone_ids`` list, validated and checked against
      ``zone_count`` if present;
    - ``zone_count``, expanded to ``"1".."N"``;
    - ``rows`` x ``cols``, expanded to ``"r1c1"``-style IDs.

    Returns ``()`` when none apply.
    """
    explicit_ids = _coerce_zone_ids(layout.get("zone_ids"))
    if not explicit_ids:
        if "zone_count" in layout:
            return tuple(_numeric_zone_ids_from_count(layout["zone_count"]))
        return _zone_ids_from_rows_cols(layout)
    _validate_numeric_zone_ids(explicit_ids)
    _validate_zone_count_matches_ids(layout, explicit_ids)
    return tuple(explicit_ids)
def _zone_ids_from_rows_cols(layout: dict[str, Any]) -> tuple[str, ...]:
rows = int(layout.get("rows", 0))
cols = int(layout.get("cols", 0))
if rows <= 0 or cols <= 0:
return ()
return tuple(f"r{row}c{col}" for row in range(1, rows + 1) for col in range(1, cols + 1))
def _incoming_numeric_zone_ids(layout_update: dict[str, Any] | None) -> list[str]:
    """Return the zone IDs requested by a calibration layout update.

    Returns ``[]`` when *layout_update* is not a dict or declares no zone IDs.

    Raises:
        ValueError: If the declared IDs are not all numeric, or fail layout
            validation.
    """
    if not isinstance(layout_update, dict):
        return []
    requested = _zone_ids_from_layout(layout_update)
    if requested and not _is_numeric_zone_ids(requested):
        raise ValueError("calibration layout zone IDs must be numeric")
    return list(requested)
def _existing_numeric_zone_ids(layout: dict[str, Any]) -> list[str]:
    """Return the layout's zone IDs if all are numeric, otherwise ``[]``."""
    current = list(_zone_ids_from_layout(layout))
    return current if _is_numeric_zone_ids(current) else []
def _ordered_normalized_zones(
    zone_order: list[str],
    valid_zones: dict[str, dict[str, Any]],
    existing_by_id: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """Build normalized zones in *zone_order*.

    For each id, the freshly calibrated entry is preferred over the stored
    one. Ids with neither entry, or with an invalid polygon, are omitted.
    """
    candidates = (
        _normalized_zone(valid_zones.get(zone_id) or existing_by_id.get(zone_id))
        for zone_id in zone_order
    )
    return [candidate for candidate in candidates if candidate is not None]
def _numeric_calibration_zone_order(
    incoming_numeric_zone_ids: list[str],
    existing_numeric_zone_ids: list[str],
    valid_zones: dict[str, dict[str, Any]],
) -> list[str]:
    """Choose the numeric zone order for a calibration merge.

    Rules, in order:
    - Incoming IDs always win.
    - Calibrated IDs are considered only if every calibrated id is numeric;
      they are sorted numerically.
    - When both stored and calibrated IDs exist, the stored order is kept if
      it covers the calibrated IDs; otherwise the calibrated IDs are used.
    - Otherwise, whichever list is non-empty is returned.
    """
    if incoming_numeric_zone_ids:
        return incoming_numeric_zone_ids
    calibrated = sorted(valid_zones, key=int) if _is_numeric_zone_ids(valid_zones) else []
    if not (existing_numeric_zone_ids and calibrated):
        return existing_numeric_zone_ids or calibrated
    if set(calibrated) <= set(existing_numeric_zone_ids):
        return existing_numeric_zone_ids
    return calibrated
def _normalized_zone(zone: dict[str, Any] | None) -> dict[str, Any] | None:
    """Return a clean ``{"id", "polygon"[, "label"]}`` copy of *zone*, or ``None``.

    ``None`` is returned for a missing zone, a blank id, or a polygon with
    fewer than three valid points. Numeric ids always get the canonical
    label ``"区域 N"`` ("Zone N"). Other ids keep their stripped label if it
    is non-empty.
    """
    if zone is None:
        return None
    ident = str(zone.get("id", "")).strip()
    points = _normalize_points(zone.get("polygon", []))
    if not ident or len(points) < 3:
        return None
    caption = str(zone.get("label", "")).strip()
    if ident.isdecimal():
        caption = f"区域 {int(ident)}"
    result: dict[str, Any] = {"id": ident, "polygon": points}
    if caption:
        result["label"] = caption
    return result
def _coerce_zone_ids(value: Any) -> list[str]:
if not isinstance(value, list | tuple):
return []
return [str(item).strip() for item in value if str(item).strip()]
def _numeric_zone_ids_from_count(value: Any) -> list[str]:
    """Expand a declared zone count into the IDs ``"1".."count"``.

    Raises:
        ValueError: If the count is outside ``1..MAX_CUSTOM_FOOD_ZONES``, or
            ``int()`` rejects it.
    """
    count = int(value)
    if not 1 <= count <= MAX_CUSTOM_FOOD_ZONES:
        raise ValueError(f"food zone count must be 1 to {MAX_CUSTOM_FOOD_ZONES}")
    return list(map(str, range(1, count + 1)))
def _is_numeric_zone_ids(zone_ids: Any) -> bool:
return bool(zone_ids) and all(str(zone_id).isdecimal() for zone_id in zone_ids)
def _validate_numeric_zone_ids(zone_ids: list[str] | tuple[str, ...]) -> None:
    """Check that numeric zone IDs are exactly ``"1".."N"`` with ``N <= MAX_CUSTOM_FOOD_ZONES``.

    A sequence containing no numeric id is treated as a legacy layout and
    accepted unchanged.

    Raises:
        ValueError: If numeric and legacy IDs are mixed, there are too many
            IDs, or the IDs are not contiguous from 1 in order.
    """
    decimal_count = sum(1 for zone_id in zone_ids if zone_id.isdecimal())
    if decimal_count == 0:
        return
    if decimal_count != len(zone_ids):
        raise ValueError("numeric food zone IDs must not be mixed with legacy zone IDs")
    # At least one numeric id exists here, so only the upper bound can fail.
    if len(zone_ids) > MAX_CUSTOM_FOOD_ZONES:
        raise ValueError(f"food zone count must be 1 to {MAX_CUSTOM_FOOD_ZONES}")
    if any(zone_id != str(position) for position, zone_id in enumerate(zone_ids, start=1)):
        raise ValueError("numeric food zone IDs must be contiguous from 1")
def _validate_zone_count_matches_ids(layout: dict[str, Any], zone_ids: list[str]) -> None:
    """If the layout declares ``zone_count``, check that it is in range and equals ``len(zone_ids)``.

    Raises:
        ValueError: On an out-of-range or mismatched ``zone_count``.
    """
    if "zone_count" in layout:
        declared = int(layout["zone_count"])
        if not 1 <= declared <= MAX_CUSTOM_FOOD_ZONES:
            raise ValueError(f"food zone count must be 1 to {MAX_CUSTOM_FOOD_ZONES}")
        if declared != len(zone_ids):
            raise ValueError("zone_count must match zone_ids length")
def _normalize_points(value: Any) -> list[list[float]]:
points: list[list[float]] = []
if not isinstance(value, list):
return points
for item in value:
if not isinstance(item, list | tuple) or len(item) != 2:
continue
x = min(1.0, max(0.0, float(item[0])))
y = min(1.0, max(0.0, float(item[1])))
points.append([round(x, 6), round(y, 6)])
return points
def _format_points(points: list[list[float]]) -> str:
return "[" + ", ".join(f"[{x:.6f}, {y:.6f}]" for x, y in points) + "]"
def _format_string_array(values: list[str]) -> str:
    """Render strings as a TOML array of escaped basic strings."""
    quoted = [f'"{_escape(item)}"' for item in values]
    return "[{}]".format(", ".join(quoted))
def _escape(value: str) -> str:
return value.replace("\\", "\\\\").replace('"', '\\"')