Files
video-ai-analysis/video_ai_analysis_poc/config.py
2026-06-17 11:33:54 +08:00

279 lines
9.0 KiB
Python

from __future__ import annotations
import ast
from pathlib import Path
from typing import Any
from .paths import resolve_path, validate_output_dir
# Config shipped with the project: ``config/local_batch.yaml`` located two
# directory levels above this (resolved) module file.
DEFAULT_CONFIG_PATH = Path(__file__).resolve().parent.parent.joinpath(
    "config", "local_batch.yaml"
)
def load_config(
    config_path: str | Path = DEFAULT_CONFIG_PATH,
    *,
    input_dir: str | Path | None = None,
    output_dir: str | Path | None = None,
) -> dict[str, Any]:
    """Load the batch config file, merge defaults and overrides, and normalise it.

    Args:
        config_path: Config file to read; ``~`` is expanded and the path is
            resolved without requiring it to exist first.
        input_dir: When given, replaces ``input.dir`` from the file.
        output_dir: When given, replaces ``output.dir`` from the file.

    ``input.dir`` and ``output.dir`` (overrides included) are replaced by
    ``str(resolve_path(value, base_dir=...))``. ``base_dir`` is the parent of
    the ``config/`` directory when the file lives in one, otherwise the file's
    own directory. NOTE(review): relative CLI overrides are therefore
    presumably anchored at that base rather than the current working
    directory — confirm against ``paths.resolve_path``.

    Returns:
        The merged config dict. ``input.extensions`` is normalised,
        ``input.recursive`` is coerced with ``bool()`` and
        ``ffprobe.timeout_seconds`` with ``int()``.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If the file uses unsupported syntax.
        Whatever ``validate_output_dir`` raises for an invalid input/output
        pair (not visible here).
    """
    config_file = Path(config_path).expanduser().resolve(strict=False)
    config = _with_defaults(_parse_simple_yaml(config_file))

    if config_file.parent.name == "config":
        base_dir = config_file.parent.parent
    else:
        base_dir = config_file.parent

    input_section = config["input"]
    output_section = config["output"]
    if input_dir is not None:
        input_section["dir"] = str(input_dir)
    if output_dir is not None:
        output_section["dir"] = str(output_dir)
    for section in (input_section, output_section):
        section["dir"] = str(resolve_path(section["dir"], base_dir=base_dir))
    validate_output_dir(input_section["dir"], output_section["dir"])

    input_section["extensions"] = _normalize_extensions(
        input_section.get("extensions", [])
    )
    input_section["recursive"] = bool(input_section.get("recursive", True))

    ffprobe_section = config.setdefault("ffprobe", {})
    ffprobe_section["timeout_seconds"] = int(
        ffprobe_section.get("timeout_seconds", 30)
    )
    return config
def _with_defaults(config: dict[str, Any]) -> dict[str, Any]:
merged: dict[str, Any] = {
"input": {
"dir": "./videos",
"recursive": True,
"extensions": [".mp4", ".mov", ".mkv", ".avi", ".flv", ".ts", ".m4v"],
},
"output": {
"dir": "./outputs/local-batch",
"overwrite": False,
"resume": True,
"keep_frames": True,
},
"source": {"mode": "local"},
"hik_cloud": {
"api_base_url": "https://api2.hik-cloud.com",
"download_path": "/v1/carrier/cstorage/open/play/download",
"access_token": None,
"access_token_env": "HIK_CLOUD_ACCESS_TOKEN",
"devices": [],
"time_ranges": [],
"chunk_seconds": 600,
"timeout_seconds": 60,
"download_timeout_seconds": 600,
},
"ffprobe": {"timeout_seconds": 30},
"ffmpeg": {
"prefer_nvdec": True,
"allow_cpu_fallback": False,
"hwaccel": "cuda",
"codec_decoders": {"h264": "h264_cuvid", "hevc": "hevc_cuvid"},
"frame_fps": 1,
"frame_width": 640,
"jpeg_quality": 4,
"timeout_seconds_per_video": 3600,
},
"clip": {
"length_seconds": 10,
"stride_seconds": 10,
"frames_per_clip": 8,
"min_frames_per_clip": 4,
},
"vlm": {
"api_base_url": "http://localhost:8679",
"chat_completions_path": "/v1/chat/completions",
"model": "memai-zhengxin-v3-20260413",
"timeout_seconds": 120,
"max_tokens": 512,
"temperature": 0,
"batch_size": 1,
"image_transport": "data_uri",
"retries": 1,
},
"prompt": {
"system": "You are a store video analysis assistant. Return strict JSON only.",
"user": "Analyze this clip. Return events and screen_time. If no event, return events: [].",
},
"schema": {
"version": "local-batch-v1",
"event_types": [
"customer_enter",
"customer_leave",
"queue_detected",
"staff_absent",
"staff_present",
"area_crowded",
"abnormal_behavior",
"unknown",
],
"require_strict_json": True,
"parse_retry": 1,
"merge_gap_seconds": 30,
},
"runtime": {"timezone": "Asia/Shanghai", "log_level": "INFO"},
}
for section, values in config.items():
if isinstance(values, dict) and isinstance(merged.get(section), dict):
merged[section].update(values)
else:
merged[section] = values
return merged
def _normalize_extensions(extensions: list[str]) -> list[str]:
normalized = []
for extension in extensions:
value = str(extension).lower()
if not value.startswith("."):
value = f".{value}"
normalized.append(value)
return normalized
def _parse_simple_yaml(path: Path) -> dict[str, Any]:
    """Parse the restricted YAML subset used by the batch config files.

    This is a small hand-written parser, not a complete YAML implementation.
    It supports:

    * nested block mappings (``key: value`` and ``key:`` followed by deeper
      lines);
    * block sequences, indented under their key or at the same indentation
      as the key (``key:`` then ``- item``);
    * sequence items that are scalars, or mappings started on the dash line
      (``- name: x`` with further keys aligned under ``name``);
    * block scalars (``key: |`` and friends), also as the first key of a
      sequence-item mapping;
    * full-line ``#`` comments and inline comments (a ``#`` preceded by
      whitespace and outside quotes), as in YAML;
    * scalar values converted by ``_parse_scalar``.

    Indentation is counted in spaces only. A ``key:`` with no value becomes
    an empty dict, or an empty list when the next content line starts with
    ``- ``, rather than YAML's null. NOTE(review): kept as-is because
    ``_with_defaults`` merges an empty section dict but would replace a
    section set to ``None``.

    In a sequence item, only a ``:`` followed by whitespace (or ending the
    item) starts a mapping, so ``- http://host`` or ``- 08:00:00`` stay
    scalars. On plain ``key: value`` lines the first ``:`` still splits.

    Args:
        path: UTF-8 config file to read.

    Returns:
        The top-level mapping.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: For a line outside the supported subset: a list item
            under a mapping, a mapping entry directly under a list, or a
            line without ``:``.
    """
    if not path.exists():
        raise FileNotFoundError(f"config file not found: {path}")
    root: dict[str, Any] = {}
    # Entries are (indent of the line that opened the container, container).
    # The root's -1 indent guarantees it is never popped.
    stack: list[tuple[int, dict[str, Any] | list[Any]]] = [(-1, root)]
    lines = path.read_text(encoding="utf-8").splitlines()
    index = 0
    while index < len(lines):
        raw_line = lines[index].rstrip()
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#"):
            index += 1
            continue
        stripped = _strip_inline_comment(stripped)
        indent = len(raw_line) - len(raw_line.lstrip(" "))
        is_list_item = stripped.startswith("- ")
        parent = _find_parent(stack, indent, is_list_item)

        if is_list_item:
            if not isinstance(parent, list):
                raise ValueError(f"list item without list parent: {raw_line}")
            item = stripped[2:].strip()
            entry = _split_mapping_item(item)
            if entry is None:
                parent.append(_parse_scalar(item))
                index += 1
                continue
            key, value = entry
            mapping: dict[str, Any] = {}
            parent.append(mapping)
            stack.append((indent, mapping))
            # Column of the first key after "- "; the mapping's later keys
            # line up here and an indentless sequence value may too.
            key_indent = indent + len(stripped) - len(stripped[1:].lstrip())
            if _is_block_scalar(value):
                mapping[key], index = _parse_block_scalar(
                    lines, index, key_indent, value
                )
                continue
            if value:
                mapping[key] = _parse_scalar(value)
            else:
                child = _empty_child(lines, index)
                mapping[key] = child
                stack.append((key_indent, child))
            index += 1
            continue

        if not isinstance(parent, dict):
            raise ValueError(f"mapping entry inside list is not supported: {raw_line}")
        if ":" not in stripped:
            raise ValueError(f"unsupported config line: {raw_line}")
        key, value = stripped.split(":", 1)
        key = key.strip()
        value = value.strip()
        if _is_block_scalar(value):
            parent[key], index = _parse_block_scalar(lines, index, indent, value)
            continue
        if value:
            parent[key] = _parse_scalar(value)
        else:
            child = _empty_child(lines, index)
            parent[key] = child
            stack.append((indent, child))
        index += 1
    return root


def _find_parent(
    stack: list[tuple[int, dict[str, Any] | list[Any]]],
    indent: int,
    is_list_item: bool,
) -> dict[str, Any] | list[Any]:
    """Pop finished containers off ``stack``; return the one a line at ``indent`` belongs to."""
    while True:
        owner_indent, container = stack[-1]
        if indent > owner_indent:
            break
        # YAML allows a sequence at the same indentation as the key that
        # owns it ("key:" then "- item"), so such a list stays open.
        if indent == owner_indent and is_list_item and isinstance(container, list):
            break
        stack.pop()
    return stack[-1][1]


def _empty_child(lines: list[str], index: int) -> dict[str, Any] | list[Any]:
    """Return the container for a value-less ``key:`` at ``lines[index]``.

    The result is a list if the next content line starts with ``- ``,
    otherwise a dict.
    """
    next_stripped = _next_stripped(lines, index)
    return [] if next_stripped and next_stripped.startswith("- ") else {}


def _split_mapping_item(item: str) -> tuple[str, str] | None:
    """Split the text after ``- `` into ``(key, value)``, or return None for a scalar item."""
    if item[:1] in {'"', "'", "["}:
        return None
    for pos, char in enumerate(item):
        # Only ": " (or a trailing ":") separates a key, as in YAML.
        if char == ":" and (pos + 1 == len(item) or item[pos + 1] in " \t"):
            return item[:pos].strip(), item[pos + 1 :].strip()
    return None


def _strip_inline_comment(text: str) -> str:
    """Remove a trailing comment: a ``#`` preceded by whitespace and outside quotes.

    A quote only opens a quoted span at the start of a token, so apostrophes
    inside words (``don't``) do not. Inside ``"..."`` a backslash escapes the
    next character. Inside ``'...'`` the pair ``''`` is an escaped quote.
    """
    quote: str | None = None
    pos = 0
    while pos < len(text):
        char = text[pos]
        if quote == '"':
            if char == "\\":
                pos += 2
                continue
            if char == '"':
                quote = None
        elif quote == "'":
            if char == "'":
                if text[pos + 1 : pos + 2] == "'":
                    pos += 2
                    continue
                quote = None
        elif char in "\"'" and (pos == 0 or text[pos - 1] in " \t[,:"):
            quote = char
        elif char == "#" and pos > 0 and text[pos - 1] in " \t":
            return text[:pos].rstrip()
        pos += 1
    return text
def _next_stripped(lines: list[str], current_index: int) -> str | None:
for raw_line in lines[current_index + 1 :]:
stripped = raw_line.strip()
if stripped and not raw_line.lstrip().startswith("#"):
return stripped
return None
def _is_block_scalar(value: str) -> bool:
return value in {">", ">-", "|", "|-"}
def _parse_block_scalar(
lines: list[str],
start_index: int,
parent_indent: int,
marker: str,
) -> tuple[str, int]:
content_lines: list[str] = []
content_indent: int | None = None
index = start_index + 1
while index < len(lines):
raw_line = lines[index].rstrip()
stripped = raw_line.strip()
if not stripped:
content_lines.append("")
index += 1
continue
indent = len(raw_line) - len(raw_line.lstrip(" "))
if indent <= parent_indent:
break
if content_indent is None:
content_indent = indent
content_lines.append(raw_line[content_indent:])
index += 1
if marker.endswith("-"):
while content_lines and content_lines[-1] == "":
content_lines.pop()
return "\n".join(content_lines), index
def _parse_scalar(value: str) -> Any:
lower = value.lower()
if lower == "true":
return True
if lower == "false":
return False
if lower in {"null", "none"}:
return None
if value.startswith("[") and value.endswith("]"):
parsed = ast.literal_eval(value)
if not isinstance(parsed, list):
raise ValueError(f"expected list value: {value}")
return parsed
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
return ast.literal_eval(value)
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
return value