"""Load the local-batch configuration from a small YAML subset.

The parser in this module is deliberately minimal. It understands nested
mappings, ``- `` list items (scalars or mappings), block scalars introduced
by ``>``, ``>-``, ``|`` or ``|-``, inline ``[...]`` lists written as Python
literals, quoted strings, booleans, nulls, ints and floats. Inline comments
(``key: value  # note``) are NOT stripped; only whole-line comments are
skipped.
"""

from __future__ import annotations

import ast
from pathlib import Path
from typing import Any

from .paths import resolve_path, validate_output_dir

DEFAULT_CONFIG_PATH = (
    Path(__file__).resolve().parent.parent / "config" / "local_batch.yaml"
)


def load_config(
    config_path: str | Path = DEFAULT_CONFIG_PATH,
    *,
    input_dir: str | Path | None = None,
    output_dir: str | Path | None = None,
) -> dict[str, Any]:
    """Parse the config file, apply defaults and overrides, and normalize it.

    Args:
        config_path: Location of the YAML config file.
        input_dir: When given, replaces ``input.dir`` from the file.
        output_dir: When given, replaces ``output.dir`` from the file.

    Returns:
        The merged configuration. ``input.dir`` and ``output.dir`` are the
        stringified results of ``resolve_path``; ``input.extensions`` is a
        list of lower-case, dot-prefixed strings; ``input.recursive`` is a
        bool; ``ffprobe.timeout_seconds`` is an int.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file contains a line or value the parser cannot
            handle, or ``ffprobe.timeout_seconds`` is not int-convertible.
        Whatever ``validate_output_dir`` raises for an input/output pair it
        rejects (defined in ``.paths``; not visible from this module).
    """
    path = Path(config_path).expanduser().resolve(strict=False)
    raw_config = _parse_simple_yaml(path)
    config = _with_defaults(raw_config)

    # A config stored in ".../config/<file>" uses the directory above
    # "config" as the base for relative paths; otherwise the file's own
    # directory is the base.
    base_dir = path.parent.parent if path.parent.name == "config" else path.parent

    if input_dir is not None:
        config["input"]["dir"] = str(input_dir)
    if output_dir is not None:
        config["output"]["dir"] = str(output_dir)

    config["input"]["dir"] = str(
        resolve_path(config["input"]["dir"], base_dir=base_dir)
    )
    config["output"]["dir"] = str(
        resolve_path(config["output"]["dir"], base_dir=base_dir)
    )
    validate_output_dir(config["input"]["dir"], config["output"]["dir"])

    extensions = config["input"].get("extensions", [])
    config["input"]["extensions"] = _normalize_extensions(extensions)
    config["input"]["recursive"] = bool(config["input"].get("recursive", True))

    config.setdefault("ffprobe", {})
    config["ffprobe"]["timeout_seconds"] = int(
        config["ffprobe"].get("timeout_seconds", 30)
    )
    return config


def _with_defaults(config: dict[str, Any]) -> dict[str, Any]:
    """Return the built-in defaults overlaid with ``config``.

    The merge is one level deep: for a section that is a dict both in the
    defaults and in ``config``, the user's keys are applied with
    ``dict.update`` (so a nested dict such as ``ffmpeg.codec_decoders`` is
    replaced wholesale, not merged). Any other section value replaces the
    default outright. The defaults are rebuilt on every call, so the
    returned dict is never shared between calls.
    """
    merged: dict[str, Any] = {
        "input": {
            "dir": "./videos",
            "recursive": True,
            "extensions": [".mp4", ".mov", ".mkv", ".avi", ".flv", ".ts", ".m4v"],
        },
        "output": {
            "dir": "./outputs/local-batch",
            "overwrite": False,
            "resume": True,
            "keep_frames": True,
        },
        "source": {"mode": "local"},
        "hik_cloud": {
            "api_base_url": "https://api2.hik-cloud.com",
            "download_path": "/v1/carrier/cstorage/open/play/download",
            "access_token": None,
            "access_token_env": "HIK_CLOUD_ACCESS_TOKEN",
            "devices": [],
            "time_ranges": [],
            "chunk_seconds": 600,
            "timeout_seconds": 60,
            "download_timeout_seconds": 600,
        },
        "ffprobe": {"timeout_seconds": 30},
        "ffmpeg": {
            "prefer_nvdec": True,
            "allow_cpu_fallback": False,
            "hwaccel": "cuda",
            "codec_decoders": {"h264": "h264_cuvid", "hevc": "hevc_cuvid"},
            "frame_fps": 1,
            "frame_width": 640,
            "jpeg_quality": 4,
            "timeout_seconds_per_video": 3600,
        },
        "clip": {
            "length_seconds": 10,
            "stride_seconds": 10,
            "frames_per_clip": 8,
            "min_frames_per_clip": 4,
        },
        "vlm": {
            "api_base_url": "http://localhost:8679",
            "chat_completions_path": "/v1/chat/completions",
            "model": "memai-zhengxin-v3-20260413",
            "timeout_seconds": 120,
            "max_tokens": 512,
            "temperature": 0,
            "batch_size": 1,
            "image_transport": "data_uri",
            "retries": 1,
        },
        "prompt": {
            "system": "You are a store video analysis assistant. Return strict JSON only.",
            "user": "Analyze this clip. Return events and screen_time. If no event, return events: [].",
        },
        "schema": {
            "version": "local-batch-v1",
            "event_types": [
                "customer_enter",
                "customer_leave",
                "queue_detected",
                "staff_absent",
                "staff_present",
                "area_crowded",
                "abnormal_behavior",
                "unknown",
            ],
            "require_strict_json": True,
            "parse_retry": 1,
            "merge_gap_seconds": 30,
        },
        "runtime": {"timezone": "Asia/Shanghai", "log_level": "INFO"},
    }
    for section, values in config.items():
        if isinstance(values, dict) and isinstance(merged.get(section), dict):
            merged[section].update(values)
        else:
            merged[section] = values
    return merged


def _normalize_extensions(extensions: list[str] | str | None) -> list[str]:
    """Return ``extensions`` lower-cased and each prefixed with a dot.

    A bare string is treated as a single extension rather than being
    iterated character by character (``extensions: mp4`` in the config
    yields a ``str`` here), and ``None`` (``extensions: null``) yields an
    empty list instead of raising ``TypeError``.
    """
    if extensions is None:
        return []
    if isinstance(extensions, str):
        extensions = [extensions]

    normalized = []
    for extension in extensions:
        value = str(extension).lower()
        if not value.startswith("."):
            value = f".{value}"
        normalized.append(value)
    return normalized


def _parse_simple_yaml(path: Path) -> dict[str, Any]:
    """Parse the YAML subset described in the module docstring.

    Args:
        path: File to read (UTF-8).

    Returns:
        The top-level mapping.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: On a list item whose parent is not a list, a mapping
            entry whose parent is a list, a non-list line without ``:``,
            or a scalar that ``_parse_scalar`` rejects.
    """
    if not path.exists():
        raise FileNotFoundError(f"config file not found: {path}")

    root: dict[str, Any] = {}
    # Each entry is (indent, container): a line belongs to the innermost
    # container whose recorded indent is strictly smaller than the line's.
    # The root uses -1 so it is never popped.
    stack: list[tuple[int, dict[str, Any] | list[Any]]] = [(-1, root)]
    lines = path.read_text(encoding="utf-8").splitlines()
    index = 0

    while index < len(lines):
        raw_line = lines[index].rstrip()
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#"):
            index += 1
            continue

        # Only leading spaces count towards indentation.
        indent = len(raw_line) - len(raw_line.lstrip(" "))
        while indent <= stack[-1][0]:
            stack.pop()
        parent = stack[-1][1]

        if stripped.startswith("- "):
            if not isinstance(parent, list):
                raise ValueError(f"list item without list parent: {raw_line}")
            item = stripped[2:].strip()
            if ":" in item:
                # "- key: value" opens a mapping element; later lines
                # indented deeper than the dash add keys to that mapping.
                key, value = item.split(":", 1)
                mapping: dict[str, Any] = {}
                parent.append(mapping)
                key = key.strip()
                value = value.strip()
                if not value:
                    child = _new_child_container(lines, index)
                    mapping[key] = child
                    stack.append((indent, mapping))
                    # The key itself sits two columns right of the dash,
                    # so its children must be indented beyond indent + 2.
                    stack.append((indent + 2, child))
                else:
                    mapping[key] = _parse_scalar(value)
                    stack.append((indent, mapping))
            else:
                parent.append(_parse_scalar(item))
            index += 1
            continue

        if not isinstance(parent, dict):
            raise ValueError(
                f"mapping entry inside list is not supported: {raw_line}"
            )
        if ":" not in stripped:
            raise ValueError(f"unsupported config line: {raw_line}")

        key, value = stripped.split(":", 1)
        key = key.strip()
        value = value.strip()

        if _is_block_scalar(value):
            # The helper returns the index of the first line after the
            # block, so the loop must not advance again.
            parent[key], index = _parse_block_scalar(lines, index, indent, value)
            continue

        if not value:
            child = _new_child_container(lines, index)
            parent[key] = child
            stack.append((indent, child))
        else:
            parent[key] = _parse_scalar(value)
        index += 1

    return root


def _new_child_container(
    lines: list[str], current_index: int
) -> dict[str, Any] | list[Any]:
    """Return ``[]`` if the next meaningful line is a list item, else ``{}``."""
    next_stripped = _next_stripped(lines, current_index)
    if next_stripped and next_stripped.startswith("- "):
        return []
    return {}


def _next_stripped(lines: list[str], current_index: int) -> str | None:
    """Return the next non-blank, non-comment line (stripped), or ``None``."""
    for raw_line in lines[current_index + 1 :]:
        stripped = raw_line.strip()
        if stripped and not raw_line.lstrip().startswith("#"):
            return stripped
    return None


def _is_block_scalar(value: str) -> bool:
    """Return whether ``value`` is one of the supported block-scalar markers."""
    return value in {">", ">-", "|", "|-"}


def _parse_block_scalar(
    lines: list[str],
    start_index: int,
    parent_indent: int,
    marker: str,
) -> tuple[str, int]:
    """Collect the block scalar that follows the line at ``start_index``.

    The block runs until the first non-blank line indented no deeper than
    ``parent_indent``. The first non-blank content line fixes how many
    leading columns are removed from every content line. Lines are always
    joined with ``"\\n"``: the folded marker ``>`` is not folded into
    spaces, and no trailing newline is appended. A marker ending in ``-``
    additionally drops trailing blank lines.

    Returns:
        ``(text, next_index)`` where ``next_index`` is the index of the
        first line that is not part of the block.
    """
    content_lines: list[str] = []
    content_indent: int | None = None
    index = start_index + 1

    while index < len(lines):
        raw_line = lines[index].rstrip()
        stripped = raw_line.strip()
        if not stripped:
            content_lines.append("")
            index += 1
            continue
        indent = len(raw_line) - len(raw_line.lstrip(" "))
        if indent <= parent_indent:
            break
        if content_indent is None:
            content_indent = indent
        content_lines.append(raw_line[content_indent:])
        index += 1

    if marker.endswith("-"):
        while content_lines and content_lines[-1] == "":
            content_lines.pop()
    return "\n".join(content_lines), index


def _literal_eval(value: str) -> Any:
    """Evaluate a Python literal, reporting failures as a config error.

    ``ast.literal_eval`` raises ``SyntaxError`` or ``ValueError`` (and
    ``TypeError`` for some malformed nodes) with messages that do not
    mention the config; re-raise as ``ValueError`` naming the value so it
    matches the parser's other errors.
    """
    try:
        return ast.literal_eval(value)
    except (SyntaxError, ValueError, TypeError) as error:
        raise ValueError(f"invalid config value: {value}") from error


def _parse_scalar(value: str) -> Any:
    """Convert one scalar token from the config into a Python value.

    Recognized, in order: ``true``/``false`` and ``null``/``none``
    (case-insensitive), a bracketed list or a quoted string written as a
    Python literal, an int, a float; anything else is returned unchanged
    as a string.

    Raises:
        ValueError: If a bracketed or quoted value is not a valid Python
            literal, or a bracketed value does not evaluate to a list.
    """
    lower = value.lower()
    if lower == "true":
        return True
    if lower == "false":
        return False
    if lower in {"null", "none"}:
        return None

    if value.startswith("[") and value.endswith("]"):
        parsed = _literal_eval(value)
        if not isinstance(parsed, list):
            raise ValueError(f"expected list value: {value}")
        return parsed

    if (value.startswith('"') and value.endswith('"')) or (
        value.startswith("'") and value.endswith("'")
    ):
        return _literal_eval(value)

    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        return value