feat: initialize managed portal
This commit is contained in:
29
internal/config/config.go
Normal file
29
internal/config/config.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package config
|
||||
|
||||
import "os"
|
||||
|
||||
type Config struct {
|
||||
HTTPAddr string
|
||||
WebDistDir string
|
||||
RegistryPath string
|
||||
}
|
||||
|
||||
func Load() *Config {
|
||||
cfg := &Config{
|
||||
HTTPAddr: ":8080",
|
||||
WebDistDir: "web/dist",
|
||||
RegistryPath: "managed_services.yaml",
|
||||
}
|
||||
|
||||
if value := os.Getenv("MANAGED_PORTAL_HTTP_ADDR"); value != "" {
|
||||
cfg.HTTPAddr = value
|
||||
}
|
||||
if value := os.Getenv("MANAGED_PORTAL_WEB_DIST_DIR"); value != "" {
|
||||
cfg.WebDistDir = value
|
||||
}
|
||||
if value := os.Getenv("MANAGED_PORTAL_REGISTRY_PATH"); value != "" {
|
||||
cfg.RegistryPath = value
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
98
internal/managed/docker_runtime.go
Normal file
98
internal/managed/docker_runtime.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package managed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DockerRuntime interface {
|
||||
GetContainerStatus(containerName string) (string, error)
|
||||
RestartContainer(containerName string) error
|
||||
}
|
||||
|
||||
type CommandRunner func(ctx context.Context, name string, args ...string) ([]byte, error)
|
||||
|
||||
type DockerController struct {
|
||||
runner CommandRunner
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewDockerController() *DockerController {
|
||||
return &DockerController{
|
||||
runner: func(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
return exec.CommandContext(ctx, name, args...).CombinedOutput()
|
||||
},
|
||||
timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDockerControllerWithRunner(runner CommandRunner) *DockerController {
|
||||
return &DockerController{
|
||||
runner: runner,
|
||||
timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func NormalizeContainerStatus(raw string) string {
|
||||
switch strings.TrimSpace(strings.ToLower(raw)) {
|
||||
case "running":
|
||||
return "running"
|
||||
case "created", "exited", "paused":
|
||||
return "stopped"
|
||||
case "dead", "removing", "restarting":
|
||||
return "failed"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DockerController) GetContainerStatus(containerName string) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := c.runner(ctx, "docker", "inspect", "--format", "{{.State.Status}}", containerName)
|
||||
status := NormalizeContainerStatus(string(output))
|
||||
if status != "unknown" {
|
||||
return status, nil
|
||||
}
|
||||
if err != nil {
|
||||
return status, fmt.Errorf("docker inspect %s: %w", containerName, err)
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (c *DockerController) RestartContainer(containerName string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
output, err := c.runner(ctx, "docker", "restart", containerName)
|
||||
if err != nil {
|
||||
trimmed := strings.TrimSpace(string(output))
|
||||
if trimmed != "" {
|
||||
return fmt.Errorf("docker restart %s: %w: %s", containerName, err, trimmed)
|
||||
}
|
||||
return fmt.Errorf("docker restart %s: %w", containerName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsDockerUnavailable(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var execErr *exec.Error
|
||||
if errors.As(err, &execErr) {
|
||||
return execErr.Err == exec.ErrNotFound
|
||||
}
|
||||
|
||||
message := err.Error()
|
||||
return errors.Is(err, exec.ErrNotFound) ||
|
||||
strings.Contains(message, `exec: "docker": executable file not found`) ||
|
||||
strings.Contains(message, "failed to connect to the docker API") ||
|
||||
strings.Contains(message, "docker.sock")
|
||||
}
|
||||
88
internal/managed/docker_runtime_test.go
Normal file
88
internal/managed/docker_runtime_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package managed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeContainerStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string]string{
|
||||
"running": "running",
|
||||
"created": "stopped",
|
||||
" exited \n": "stopped",
|
||||
"paused": "stopped",
|
||||
"dead": "failed",
|
||||
"removing": "failed",
|
||||
"restarting": "failed",
|
||||
"unknown": "unknown",
|
||||
"": "unknown",
|
||||
}
|
||||
|
||||
for input, want := range cases {
|
||||
if got := NormalizeContainerStatus(input); got != want {
|
||||
t.Fatalf("NormalizeContainerStatus(%q) = %q, want %q", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDockerControllerGetContainerStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
controller := NewDockerControllerWithRunner(func(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
if name != "docker" {
|
||||
t.Fatalf("name = %q", name)
|
||||
}
|
||||
return []byte("running\n"), nil
|
||||
})
|
||||
|
||||
status, err := controller.GetContainerStatus("store-dwell-alert")
|
||||
if err != nil {
|
||||
t.Fatalf("GetContainerStatus() error = %v", err)
|
||||
}
|
||||
if status != "running" {
|
||||
t.Fatalf("status = %q", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDockerControllerRestartContainer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
called := false
|
||||
controller := NewDockerControllerWithRunner(func(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
called = true
|
||||
return []byte("store-dwell-alert"), nil
|
||||
})
|
||||
|
||||
if err := controller.RestartContainer("store-dwell-alert"); err != nil {
|
||||
t.Fatalf("RestartContainer() error = %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("runner was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDockerControllerRestartContainerIncludesOutputOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
controller := NewDockerControllerWithRunner(func(ctx context.Context, name string, args ...string) ([]byte, error) {
|
||||
return []byte("permission denied"), errors.New("exit status 1")
|
||||
})
|
||||
|
||||
err := controller.RestartContainer("store-dwell-alert")
|
||||
if err == nil || !strings.Contains(err.Error(), "permission denied") {
|
||||
t.Fatalf("RestartContainer() error = %v, want output included", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDockerUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !IsDockerUnavailable(&exec.Error{Name: "docker", Err: exec.ErrNotFound}) {
|
||||
t.Fatal("IsDockerUnavailable() = false, want true")
|
||||
}
|
||||
}
|
||||
209
internal/managed/manager.go
Normal file
209
internal/managed/manager.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package managed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
registry *Registry
|
||||
docker DockerRuntime
|
||||
remote *RemoteClient
|
||||
}
|
||||
|
||||
type ServiceState struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"display_name"`
|
||||
ProjectType string `json:"project_type"`
|
||||
ProjectRoot string `json:"project_root"`
|
||||
ContainerName string `json:"container_name"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ServiceName string `json:"service_name"`
|
||||
ConfigPath string `json:"config_path"`
|
||||
RTSPField string `json:"rtsp_field"`
|
||||
ResultType string `json:"result_type"`
|
||||
ResultPaths map[string]string `json:"result_paths"`
|
||||
Status string `json:"status"`
|
||||
RTSP string `json:"rtsp,omitempty"`
|
||||
Summary *ResultSummary `json:"summary,omitempty"`
|
||||
ResultFiles []ResultFile `json:"result_files,omitempty"`
|
||||
ConfigError string `json:"config_error,omitempty"`
|
||||
ResultError string `json:"result_error,omitempty"`
|
||||
ServiceError string `json:"service_error,omitempty"`
|
||||
}
|
||||
|
||||
func NewManager(registry *Registry, docker DockerRuntime, remote *RemoteClient) *Manager {
|
||||
if registry == nil {
|
||||
registry = EmptyRegistry()
|
||||
}
|
||||
if docker == nil {
|
||||
docker = NewDockerController()
|
||||
}
|
||||
if remote == nil {
|
||||
remote = NewRemoteClient(nil)
|
||||
}
|
||||
return &Manager{
|
||||
registry: registry,
|
||||
docker: docker,
|
||||
remote: remote,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) List() []*ServiceState {
|
||||
states := make([]*ServiceState, 0, len(m.registry.Services))
|
||||
for _, service := range m.registry.Services {
|
||||
states = append(states, m.snapshot(service, false))
|
||||
}
|
||||
return states
|
||||
}
|
||||
|
||||
func (m *Manager) Detail(id string) (*ServiceState, error) {
|
||||
service, err := m.lookup(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.snapshot(service, true), nil
|
||||
}
|
||||
|
||||
func (m *Manager) Summary(id string) (*ResultSummary, error) {
|
||||
service, err := m.lookup(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.remote.GetSummary(context.Background(), service)
|
||||
}
|
||||
|
||||
func (m *Manager) Files(id string) ([]ResultFile, error) {
|
||||
service, err := m.lookup(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.remote.GetFiles(context.Background(), service)
|
||||
}
|
||||
|
||||
func (m *Manager) PreviewFile(id, path string, lines int) (*FilePreview, error) {
|
||||
service, err := m.lookup(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.remote.PreviewFile(context.Background(), service, path, lines)
|
||||
}
|
||||
|
||||
func (m *Manager) Download(ctx context.Context, id, path string) (*http.Response, error) {
|
||||
service, err := m.lookup(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.remote.Download(ctx, service, path)
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateRTSP(id, rtsp string) (*ServiceState, error) {
|
||||
service, err := m.lookup(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(rtsp) == "" {
|
||||
return nil, fmt.Errorf("rtsp url is required")
|
||||
}
|
||||
if _, err := m.remote.UpdateRTSP(context.Background(), service, rtsp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.snapshot(service, true), nil
|
||||
}
|
||||
|
||||
func (m *Manager) Restart(id string) (*ServiceState, error) {
|
||||
service, err := m.lookup(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.docker.RestartContainer(service.ContainerName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.snapshot(service, true), nil
|
||||
}
|
||||
|
||||
func (m *Manager) lookup(id string) (Service, error) {
|
||||
service, ok := m.registry.Get(id)
|
||||
if !ok {
|
||||
return Service{}, fmt.Errorf("%w: %s", ErrServiceNotFound, id)
|
||||
}
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *Manager) snapshot(service Service, includeFiles bool) *ServiceState {
|
||||
state := &ServiceState{
|
||||
ID: service.ID,
|
||||
DisplayName: service.DisplayName,
|
||||
ProjectType: service.ProjectType,
|
||||
ProjectRoot: service.ProjectRoot,
|
||||
ContainerName: service.ContainerName,
|
||||
APIBaseURL: service.APIBaseURL,
|
||||
ServiceName: service.ServiceName,
|
||||
ConfigPath: service.ConfigPath,
|
||||
RTSPField: service.RTSPField,
|
||||
ResultType: service.ResultType,
|
||||
ResultPaths: service.ResultPaths,
|
||||
Status: "unknown",
|
||||
}
|
||||
|
||||
if payload, err := m.remote.GetConfig(context.Background(), service); err == nil {
|
||||
state.RTSP = extractRTSP(payload)
|
||||
state.ConfigPath = extractConfigPath(payload)
|
||||
} else {
|
||||
state.ConfigError = err.Error()
|
||||
}
|
||||
|
||||
if status, err := m.docker.GetContainerStatus(service.ContainerName); err == nil {
|
||||
state.Status = status
|
||||
} else {
|
||||
state.Status = status
|
||||
if !IsDockerUnavailable(err) {
|
||||
state.ServiceError = err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
if summary, err := m.remote.GetSummary(context.Background(), service); err == nil {
|
||||
state.Summary = summary
|
||||
} else {
|
||||
state.ResultError = err.Error()
|
||||
}
|
||||
|
||||
if includeFiles {
|
||||
if files, err := m.remote.GetFiles(context.Background(), service); err == nil {
|
||||
state.ResultFiles = files
|
||||
} else if state.ResultError == "" {
|
||||
state.ResultError = err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func extractRTSP(payload map[string]any) string {
|
||||
if payload == nil {
|
||||
return ""
|
||||
}
|
||||
if stream, ok := payload["stream"].(map[string]any); ok {
|
||||
if rtsp, ok := stream["rtsp_url"].(string); ok {
|
||||
return strings.TrimSpace(rtsp)
|
||||
}
|
||||
}
|
||||
if runtime, ok := payload["runtime"].(map[string]any); ok {
|
||||
if rtsp, ok := runtime["rtsp_url"].(string); ok {
|
||||
return strings.TrimSpace(rtsp)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractConfigPath(payload map[string]any) string {
|
||||
if payload == nil {
|
||||
return ""
|
||||
}
|
||||
if configPath, ok := payload["config_path"].(string); ok {
|
||||
return strings.TrimSpace(configPath)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
194
internal/managed/manager_test.go
Normal file
194
internal/managed/manager_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package managed
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeDockerRuntime struct {
|
||||
statusByContainer map[string]string
|
||||
restarted []string
|
||||
}
|
||||
|
||||
func (f *fakeDockerRuntime) GetContainerStatus(containerName string) (string, error) {
|
||||
if status, ok := f.statusByContainer[containerName]; ok {
|
||||
return status, nil
|
||||
}
|
||||
return "unknown", nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerRuntime) RestartContainer(containerName string) error {
|
||||
f.restarted = append(f.restarted, containerName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestManagedDockerAndRemoteAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
storeConfig := map[string]any{
|
||||
"config_path": "/srv/store/config/local.yaml",
|
||||
"stream": map[string]any{
|
||||
"rtsp_url": "rtsp://store-old/stream",
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
response := func(status int, body any) (*http.Response, error) {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(bytes.NewReader(data)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/config":
|
||||
return response(http.StatusOK, storeConfig)
|
||||
case r.Method == http.MethodPut && r.URL.Path == "/store/api/manage/config":
|
||||
var payload map[string]string
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode update payload: %v", err)
|
||||
}
|
||||
storeConfig["stream"].(map[string]any)["rtsp_url"] = payload["rtsp_url"]
|
||||
return response(http.StatusOK, storeConfig)
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/summary":
|
||||
return response(http.StatusOK, ResultSummary{
|
||||
ResultType: "store_dwell_alert",
|
||||
Headline: "Latest report shows 1 active customers, longest dwell 900s",
|
||||
LastResultTime: "2026-04-16T10:00:00+08:00",
|
||||
Metrics: map[string]any{
|
||||
"active_customer_count": 1,
|
||||
"longest_dwell_seconds": 900,
|
||||
},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/files":
|
||||
return response(http.StatusOK, map[string]any{
|
||||
"files": []ResultFile{{
|
||||
Path: "logs/events.jsonl",
|
||||
Name: "events.jsonl",
|
||||
}},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/files/preview":
|
||||
return response(http.StatusOK, FilePreview{
|
||||
Path: "logs/events.jsonl",
|
||||
Lines: []string{"line1", "line2"},
|
||||
Count: 2,
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/files/download":
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Disposition": []string{`attachment; filename="events.jsonl"`},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader("downloaded")),
|
||||
}, nil
|
||||
default:
|
||||
t.Fatalf("unexpected child request: %s %s", r.Method, r.URL.String())
|
||||
return nil, nil
|
||||
}
|
||||
})}
|
||||
|
||||
registry := &Registry{
|
||||
Services: []Service{{
|
||||
ID: "store_dwell_alert",
|
||||
DisplayName: "Store Dwell Alert",
|
||||
ProjectType: "store_dwell_alert",
|
||||
ProjectRoot: "/srv/store",
|
||||
ContainerName: "store-dwell-alert",
|
||||
ServiceName: "store-dwell-alert",
|
||||
APIBaseURL: "http://managed.invalid/store",
|
||||
ResultType: "store_dwell_alert",
|
||||
}},
|
||||
}
|
||||
|
||||
docker := &fakeDockerRuntime{
|
||||
statusByContainer: map[string]string{
|
||||
"store-dwell-alert": "running",
|
||||
},
|
||||
}
|
||||
manager := NewManager(registry, docker, NewRemoteClient(client))
|
||||
|
||||
states := manager.List()
|
||||
if len(states) != 1 {
|
||||
t.Fatalf("len(List()) = %d, want 1", len(states))
|
||||
}
|
||||
if states[0].Status != "running" {
|
||||
t.Fatalf("List()[0].Status = %q", states[0].Status)
|
||||
}
|
||||
|
||||
state, err := manager.Detail("store_dwell_alert")
|
||||
if err != nil {
|
||||
t.Fatalf("Detail() error = %v", err)
|
||||
}
|
||||
if state.Status != "running" {
|
||||
t.Fatalf("state.Status = %q", state.Status)
|
||||
}
|
||||
if state.RTSP != "rtsp://store-old/stream" {
|
||||
t.Fatalf("state.RTSP = %q", state.RTSP)
|
||||
}
|
||||
if state.ConfigPath != "/srv/store/config/local.yaml" {
|
||||
t.Fatalf("state.ConfigPath = %q", state.ConfigPath)
|
||||
}
|
||||
if state.Summary == nil || state.Summary.Metrics["longest_dwell_seconds"] != float64(900) {
|
||||
t.Fatalf("unexpected summary: %#v", state.Summary)
|
||||
}
|
||||
if len(state.ResultFiles) != 1 {
|
||||
t.Fatalf("len(ResultFiles) = %d", len(state.ResultFiles))
|
||||
}
|
||||
|
||||
if _, err := manager.UpdateRTSP("store_dwell_alert", "rtsp://store-new/stream"); err != nil {
|
||||
t.Fatalf("UpdateRTSP() error = %v", err)
|
||||
}
|
||||
if got := storeConfig["stream"].(map[string]any)["rtsp_url"]; got != "rtsp://store-new/stream" {
|
||||
t.Fatalf("updated rtsp = %#v", got)
|
||||
}
|
||||
|
||||
summary, err := manager.Summary("store_dwell_alert")
|
||||
if err != nil {
|
||||
t.Fatalf("Summary() error = %v", err)
|
||||
}
|
||||
if summary.Headline == "" {
|
||||
t.Fatalf("Summary().Headline is empty")
|
||||
}
|
||||
|
||||
files, err := manager.Files("store_dwell_alert")
|
||||
if err != nil {
|
||||
t.Fatalf("Files() error = %v", err)
|
||||
}
|
||||
if len(files) != 1 {
|
||||
t.Fatalf("len(Files()) = %d", len(files))
|
||||
}
|
||||
|
||||
preview, err := manager.PreviewFile("store_dwell_alert", "logs/events.jsonl", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("PreviewFile() error = %v", err)
|
||||
}
|
||||
if preview.Count != 2 {
|
||||
t.Fatalf("preview.Count = %d", preview.Count)
|
||||
}
|
||||
|
||||
resp, err := manager.Download(context.Background(), "store_dwell_alert", "logs/events.jsonl")
|
||||
if err != nil {
|
||||
t.Fatalf("Download() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if !strings.Contains(resp.Header.Get("Content-Disposition"), "events.jsonl") {
|
||||
t.Fatalf("Content-Disposition = %q", resp.Header.Get("Content-Disposition"))
|
||||
}
|
||||
|
||||
if _, err := manager.Restart("store_dwell_alert"); err != nil {
|
||||
t.Fatalf("Restart() error = %v", err)
|
||||
}
|
||||
if len(docker.restarted) != 1 || docker.restarted[0] != "store-dwell-alert" {
|
||||
t.Fatalf("restarted = %#v", docker.restarted)
|
||||
}
|
||||
}
|
||||
134
internal/managed/registry.go
Normal file
134
internal/managed/registry.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package managed
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var ErrServiceNotFound = errors.New("managed service not found")
|
||||
|
||||
type Registry struct {
|
||||
Services []Service `yaml:"services"`
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
ID string `yaml:"id" json:"id"`
|
||||
DisplayName string `yaml:"display_name" json:"display_name"`
|
||||
ProjectType string `yaml:"project_type" json:"project_type"`
|
||||
ProjectRoot string `yaml:"project_root" json:"project_root"`
|
||||
ContainerName string `yaml:"container_name" json:"container_name"`
|
||||
APIBaseURL string `yaml:"api_base_url" json:"api_base_url"`
|
||||
ServiceName string `yaml:"service_name" json:"service_name"`
|
||||
ConfigPath string `yaml:"config_path" json:"config_path"`
|
||||
RTSPField string `yaml:"rtsp_field" json:"rtsp_field"`
|
||||
ResultType string `yaml:"result_type" json:"result_type"`
|
||||
ResultPaths map[string]string `yaml:"result_paths" json:"result_paths"`
|
||||
}
|
||||
|
||||
func EmptyRegistry() *Registry {
|
||||
return &Registry{Services: []Service{}}
|
||||
}
|
||||
|
||||
func LoadRegistry(path string) (*Registry, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read managed registry: %w", err)
|
||||
}
|
||||
|
||||
var registry Registry
|
||||
if err := yaml.Unmarshal(data, ®istry); err != nil {
|
||||
return nil, fmt.Errorf("parse managed registry: %w", err)
|
||||
}
|
||||
|
||||
baseDir := filepath.Dir(path)
|
||||
ids := make(map[string]struct{}, len(registry.Services))
|
||||
for i := range registry.Services {
|
||||
svc := ®istry.Services[i]
|
||||
if err := normalizeService(baseDir, svc); err != nil {
|
||||
return nil, fmt.Errorf("service[%d]: %w", i, err)
|
||||
}
|
||||
if _, exists := ids[svc.ID]; exists {
|
||||
return nil, fmt.Errorf("duplicate service id %q", svc.ID)
|
||||
}
|
||||
ids[svc.ID] = struct{}{}
|
||||
}
|
||||
|
||||
if registry.Services == nil {
|
||||
registry.Services = []Service{}
|
||||
}
|
||||
|
||||
return ®istry, nil
|
||||
}
|
||||
|
||||
func (r *Registry) Get(id string) (Service, bool) {
|
||||
for _, svc := range r.Services {
|
||||
if svc.ID == id {
|
||||
return svc, true
|
||||
}
|
||||
}
|
||||
return Service{}, false
|
||||
}
|
||||
|
||||
func normalizeService(baseDir string, svc *Service) error {
|
||||
svc.ID = strings.TrimSpace(svc.ID)
|
||||
svc.DisplayName = strings.TrimSpace(svc.DisplayName)
|
||||
svc.ProjectType = strings.TrimSpace(svc.ProjectType)
|
||||
svc.ContainerName = strings.TrimSpace(svc.ContainerName)
|
||||
svc.APIBaseURL = strings.TrimSpace(svc.APIBaseURL)
|
||||
svc.ServiceName = strings.TrimSpace(svc.ServiceName)
|
||||
svc.ConfigPath = strings.TrimSpace(svc.ConfigPath)
|
||||
svc.RTSPField = strings.TrimSpace(svc.RTSPField)
|
||||
svc.ResultType = strings.TrimSpace(svc.ResultType)
|
||||
|
||||
if svc.ID == "" {
|
||||
return errors.New("id is required")
|
||||
}
|
||||
if svc.DisplayName == "" {
|
||||
return errors.New("display_name is required")
|
||||
}
|
||||
if svc.ProjectType == "" {
|
||||
return errors.New("project_type is required")
|
||||
}
|
||||
if svc.ContainerName == "" {
|
||||
return errors.New("container_name is required")
|
||||
}
|
||||
if svc.APIBaseURL == "" {
|
||||
return errors.New("api_base_url is required")
|
||||
}
|
||||
if svc.ResultType == "" {
|
||||
return errors.New("result_type is required")
|
||||
}
|
||||
|
||||
projectRoot := strings.TrimSpace(svc.ProjectRoot)
|
||||
if projectRoot == "" {
|
||||
return errors.New("project_root is required")
|
||||
}
|
||||
svc.ProjectRoot = resolvePath(baseDir, projectRoot)
|
||||
if svc.ServiceName == "" {
|
||||
svc.ServiceName = svc.ContainerName
|
||||
}
|
||||
if svc.ConfigPath != "" {
|
||||
svc.ConfigPath = resolvePath(baseDir, svc.ConfigPath)
|
||||
}
|
||||
|
||||
if svc.ResultPaths == nil {
|
||||
svc.ResultPaths = map[string]string{}
|
||||
}
|
||||
for key, path := range svc.ResultPaths {
|
||||
svc.ResultPaths[key] = resolvePath(baseDir, strings.TrimSpace(path))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolvePath(baseDir, path string) string {
|
||||
if filepath.IsAbs(path) {
|
||||
return filepath.Clean(path)
|
||||
}
|
||||
return filepath.Clean(filepath.Join(baseDir, path))
|
||||
}
|
||||
101
internal/managed/registry_test.go
Normal file
101
internal/managed/registry_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package managed
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadRegistry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root := t.TempDir()
|
||||
registryPath := filepath.Join(root, "managed_services.yaml")
|
||||
writeFile(t, registryPath, `
|
||||
services:
|
||||
- id: store_dwell_alert
|
||||
display_name: Store Dwell Alert
|
||||
project_type: store_dwell_alert
|
||||
project_root: ./store_dwell_alert
|
||||
container_name: store-dwell-alert
|
||||
api_base_url: http://store-dwell-alert:18081
|
||||
config_path: ./configs/store.yaml
|
||||
result_type: store_dwell_alert
|
||||
- id: people_flow_project
|
||||
display_name: People Flow Project
|
||||
project_type: people_flow_project
|
||||
project_root: ./people_flow_project
|
||||
container_name: people-flow-project
|
||||
api_base_url: http://people-flow-project:18082
|
||||
result_type: people_flow_project
|
||||
`)
|
||||
|
||||
registry, err := LoadRegistry(registryPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
if len(registry.Services) != 2 {
|
||||
t.Fatalf("len(Services) = %d, want 2", len(registry.Services))
|
||||
}
|
||||
|
||||
store := registry.Services[0]
|
||||
if store.ID != "store_dwell_alert" {
|
||||
t.Fatalf("store.ID = %q", store.ID)
|
||||
}
|
||||
if store.ProjectRoot != filepath.Join(root, "store_dwell_alert") {
|
||||
t.Fatalf("store.ProjectRoot = %q", store.ProjectRoot)
|
||||
}
|
||||
if store.ConfigPath != filepath.Join(root, "configs", "store.yaml") {
|
||||
t.Fatalf("store.ConfigPath = %q", store.ConfigPath)
|
||||
}
|
||||
if store.ServiceName != "store-dwell-alert" {
|
||||
t.Fatalf("store.ServiceName = %q", store.ServiceName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRegistryRejectsDuplicateIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root := t.TempDir()
|
||||
registryPath := filepath.Join(root, "managed_services.yaml")
|
||||
writeFile(t, registryPath, `
|
||||
services:
|
||||
- id: repeated
|
||||
display_name: One
|
||||
project_type: store
|
||||
project_root: ./one
|
||||
container_name: one
|
||||
api_base_url: http://one
|
||||
result_type: store
|
||||
- id: repeated
|
||||
display_name: Two
|
||||
project_type: people
|
||||
project_root: ./two
|
||||
container_name: two
|
||||
api_base_url: http://two
|
||||
result_type: people
|
||||
`)
|
||||
|
||||
_, err := LoadRegistry(registryPath)
|
||||
if err == nil || !strings.Contains(err.Error(), `duplicate service id "repeated"`) {
|
||||
t.Fatalf("LoadRegistry() error = %v, want duplicate id", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRegistryRejectsMissingRequiredFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root := t.TempDir()
|
||||
registryPath := filepath.Join(root, "managed_services.yaml")
|
||||
writeFile(t, registryPath, `
|
||||
services:
|
||||
- id: missing_fields
|
||||
display_name: Missing Fields
|
||||
`)
|
||||
|
||||
_, err := LoadRegistry(registryPath)
|
||||
if err == nil || !strings.Contains(err.Error(), "project_type is required") {
|
||||
t.Fatalf("LoadRegistry() error = %v, want missing field error", err)
|
||||
}
|
||||
}
|
||||
154
internal/managed/remote_client.go
Normal file
154
internal/managed/remote_client.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package managed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HTTPDoer interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
type RemoteClient struct {
|
||||
httpClient HTTPDoer
|
||||
}
|
||||
|
||||
func NewRemoteClient(client HTTPDoer) *RemoteClient {
|
||||
if client == nil {
|
||||
client = &http.Client{Timeout: 5 * time.Second}
|
||||
}
|
||||
return &RemoteClient{httpClient: client}
|
||||
}
|
||||
|
||||
func (c *RemoteClient) GetConfig(ctx context.Context, service Service) (map[string]any, error) {
|
||||
var payload map[string]any
|
||||
if err := c.getJSON(ctx, service, "/api/manage/config", &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) UpdateRTSP(ctx context.Context, service Service, rtsp string) (map[string]any, error) {
|
||||
body := strings.NewReader(fmt.Sprintf(`{"rtsp_url":%q}`, rtsp))
|
||||
req, err := c.newRequest(ctx, http.MethodPut, service, "/api/manage/config", body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request %s %s: %w", req.Method, req.URL.String(), err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, decodeAPIError(resp)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return nil, fmt.Errorf("decode response %s: %w", req.URL.String(), err)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) GetSummary(ctx context.Context, service Service) (*ResultSummary, error) {
|
||||
var summary ResultSummary
|
||||
if err := c.getJSON(ctx, service, "/api/manage/summary", &summary); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &summary, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) GetFiles(ctx context.Context, service Service) ([]ResultFile, error) {
|
||||
var payload struct {
|
||||
Files []ResultFile `json:"files"`
|
||||
}
|
||||
if err := c.getJSON(ctx, service, "/api/manage/files", &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload.Files, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) PreviewFile(ctx context.Context, service Service, path string, lines int) (*FilePreview, error) {
|
||||
query := url.Values{}
|
||||
query.Set("path", path)
|
||||
query.Set("lines", fmt.Sprintf("%d", lines))
|
||||
|
||||
var preview FilePreview
|
||||
if err := c.getJSON(ctx, service, "/api/manage/files/preview?"+query.Encode(), &preview); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &preview, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Download(ctx context.Context, service Service, path string) (*http.Response, error) {
|
||||
query := url.Values{}
|
||||
query.Set("path", path)
|
||||
req, err := c.newRequest(ctx, http.MethodGet, service, "/api/manage/files/download?"+query.Encode(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request %s %s: %w", req.Method, req.URL.String(), err)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
defer resp.Body.Close()
|
||||
return nil, decodeAPIError(resp)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) getJSON(ctx context.Context, service Service, endpoint string, target any) error {
|
||||
req, err := c.newRequest(ctx, http.MethodGet, service, endpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request %s %s: %w", req.Method, req.URL.String(), err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return decodeAPIError(resp)
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
|
||||
return fmt.Errorf("decode response %s: %w", req.URL.String(), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) newRequest(ctx context.Context, method string, service Service, endpoint string, body io.Reader) (*http.Request, error) {
|
||||
base := strings.TrimRight(service.APIBaseURL, "/")
|
||||
req, err := http.NewRequestWithContext(ctx, method, base+endpoint, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build request for %s%s: %w", base, endpoint, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeAPIError(resp *http.Response) error {
|
||||
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(data, &payload); err == nil {
|
||||
if message, ok := payload["error"].(string); ok && strings.TrimSpace(message) != "" {
|
||||
return errors.New(message)
|
||||
}
|
||||
}
|
||||
message := strings.TrimSpace(string(data))
|
||||
if message == "" {
|
||||
message = resp.Status
|
||||
}
|
||||
return errors.New(message)
|
||||
}
|
||||
141
internal/managed/remote_client_test.go
Normal file
141
internal/managed/remote_client_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package managed
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return fn(req)
|
||||
}
|
||||
|
||||
func TestRemoteClientRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
storeConfig := map[string]any{
|
||||
"config_path": "/srv/store/config/local.yaml",
|
||||
"stream": map[string]any{
|
||||
"rtsp_url": "rtsp://store-old/stream",
|
||||
},
|
||||
}
|
||||
|
||||
client := NewRemoteClient(&http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
response := func(status int, body any) (*http.Response, error) {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(bytes.NewReader(data)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/config":
|
||||
return response(http.StatusOK, storeConfig)
|
||||
case r.Method == http.MethodPut && r.URL.Path == "/store/api/manage/config":
|
||||
var payload map[string]string
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode update payload: %v", err)
|
||||
}
|
||||
storeConfig["stream"].(map[string]any)["rtsp_url"] = payload["rtsp_url"]
|
||||
return response(http.StatusOK, storeConfig)
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/summary":
|
||||
return response(http.StatusOK, ResultSummary{
|
||||
ResultType: "store_dwell_alert",
|
||||
Headline: "Latest report shows 1 active customers, longest dwell 900s",
|
||||
LastResultTime: "2026-04-16T10:00:00+08:00",
|
||||
Metrics: map[string]any{
|
||||
"active_customer_count": 1,
|
||||
"longest_dwell_seconds": 900,
|
||||
},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/files":
|
||||
return response(http.StatusOK, map[string]any{
|
||||
"files": []ResultFile{{
|
||||
Path: "logs/events.jsonl",
|
||||
Name: "events.jsonl",
|
||||
}},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/files/preview":
|
||||
return response(http.StatusOK, FilePreview{
|
||||
Path: "logs/events.jsonl",
|
||||
Lines: []string{"line1", "line2"},
|
||||
Count: 2,
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/files/download":
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Disposition": []string{`attachment; filename="events.jsonl"`},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader("downloaded")),
|
||||
}, nil
|
||||
default:
|
||||
t.Fatalf("unexpected child request: %s %s", r.Method, r.URL.String())
|
||||
return nil, nil
|
||||
}
|
||||
})})
|
||||
|
||||
service := Service{
|
||||
ID: "store_dwell_alert",
|
||||
APIBaseURL: "http://managed.invalid/store",
|
||||
}
|
||||
|
||||
configPayload, err := client.GetConfig(context.Background(), service)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConfig() error = %v", err)
|
||||
}
|
||||
if got := configPayload["config_path"]; got != "/srv/store/config/local.yaml" {
|
||||
t.Fatalf("config_path = %#v", got)
|
||||
}
|
||||
|
||||
if _, err := client.UpdateRTSP(context.Background(), service, "rtsp://store-new/stream"); err != nil {
|
||||
t.Fatalf("UpdateRTSP() error = %v", err)
|
||||
}
|
||||
if got := storeConfig["stream"].(map[string]any)["rtsp_url"]; got != "rtsp://store-new/stream" {
|
||||
t.Fatalf("updated rtsp = %#v", got)
|
||||
}
|
||||
|
||||
summary, err := client.GetSummary(context.Background(), service)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSummary() error = %v", err)
|
||||
}
|
||||
if summary.ResultType != "store_dwell_alert" {
|
||||
t.Fatalf("summary.ResultType = %q", summary.ResultType)
|
||||
}
|
||||
|
||||
files, err := client.GetFiles(context.Background(), service)
|
||||
if err != nil {
|
||||
t.Fatalf("GetFiles() error = %v", err)
|
||||
}
|
||||
if len(files) != 1 || files[0].Path != "logs/events.jsonl" {
|
||||
t.Fatalf("files = %#v", files)
|
||||
}
|
||||
|
||||
preview, err := client.PreviewFile(context.Background(), service, "logs/events.jsonl", 2)
|
||||
if err != nil {
|
||||
t.Fatalf("PreviewFile() error = %v", err)
|
||||
}
|
||||
if preview.Count != 2 {
|
||||
t.Fatalf("preview.Count = %d", preview.Count)
|
||||
}
|
||||
|
||||
resp, err := client.Download(context.Background(), service, "logs/events.jsonl")
|
||||
if err != nil {
|
||||
t.Fatalf("Download() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if !strings.Contains(resp.Header.Get("Content-Disposition"), "events.jsonl") {
|
||||
t.Fatalf("Content-Disposition = %q", resp.Header.Get("Content-Disposition"))
|
||||
}
|
||||
}
|
||||
19
internal/managed/test_helpers_test.go
Normal file
19
internal/managed/test_helpers_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package managed
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeFile(t *testing.T, path, content string) string {
|
||||
t.Helper()
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
t.Fatalf("MkdirAll(%q): %v", path, err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("WriteFile(%q): %v", path, err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
41
internal/managed/types.go
Normal file
41
internal/managed/types.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package managed
|
||||
|
||||
type ResultSummary struct {
|
||||
ResultType string `json:"result_type"`
|
||||
Headline string `json:"headline"`
|
||||
LastResultTime string `json:"last_result_time,omitempty"`
|
||||
Metrics map[string]any `json:"metrics,omitempty"`
|
||||
}
|
||||
|
||||
type StoreDwellWindowStat struct {
|
||||
WindowStart string `json:"window_start"`
|
||||
WindowEnd string `json:"window_end"`
|
||||
ActiveCustomerCount int `json:"active_customer_count"`
|
||||
ActiveWaitSeconds []int `json:"active_wait_seconds"`
|
||||
ClosedWaitSeconds []int `json:"closed_wait_seconds"`
|
||||
MaxWaitSeconds int `json:"max_wait_seconds"`
|
||||
}
|
||||
|
||||
type PeopleFlowWindowStat struct {
|
||||
WindowStart string `json:"window_start"`
|
||||
WindowEnd string `json:"window_end"`
|
||||
TotalPeople int `json:"total_people"`
|
||||
AgeCounts map[string]int `json:"age_counts"`
|
||||
GenderCounts map[string]int `json:"gender_counts"`
|
||||
UnknownAttributes int `json:"unknown_attributes"`
|
||||
}
|
||||
|
||||
type ResultFile struct {
|
||||
Path string `json:"path"`
|
||||
Name string `json:"name"`
|
||||
Label string `json:"label"`
|
||||
Kind string `json:"kind"`
|
||||
Size int64 `json:"size"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
}
|
||||
|
||||
type FilePreview struct {
|
||||
Path string `json:"path"`
|
||||
Lines []string `json:"lines"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
271
internal/server/managed_handlers_test.go
Normal file
271
internal/server/managed_handlers_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"managed-portal/internal/managed"
|
||||
)
|
||||
|
||||
type fakeDockerRuntime struct {
|
||||
statusByContainer map[string]string
|
||||
restarted []string
|
||||
}
|
||||
|
||||
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return fn(req)
|
||||
}
|
||||
|
||||
func (f *fakeDockerRuntime) GetContainerStatus(containerName string) (string, error) {
|
||||
if status, ok := f.statusByContainer[containerName]; ok {
|
||||
return status, nil
|
||||
}
|
||||
return "unknown", nil
|
||||
}
|
||||
|
||||
func (f *fakeDockerRuntime) RestartContainer(containerName string) error {
|
||||
f.restarted = append(f.restarted, containerName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestManagedServicesHandlers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
storeConfigPath := "/srv/store/config/local.yaml"
|
||||
storeRTSP := "rtsp://store-old/stream"
|
||||
|
||||
srv := New(nil)
|
||||
registry := &managed.Registry{
|
||||
Services: []managed.Service{{
|
||||
ID: "store_dwell_alert",
|
||||
DisplayName: "Store Dwell Alert",
|
||||
ProjectType: "store_dwell_alert",
|
||||
ProjectRoot: "/srv/store",
|
||||
ContainerName: "store-dwell-alert",
|
||||
ServiceName: "store-dwell-alert",
|
||||
APIBaseURL: "http://managed.invalid/store",
|
||||
ResultType: "store_dwell_alert",
|
||||
}, {
|
||||
ID: "people_flow_project",
|
||||
DisplayName: "People Flow Project",
|
||||
ProjectType: "people_flow_project",
|
||||
ProjectRoot: "/srv/people",
|
||||
ContainerName: "people-flow-project",
|
||||
ServiceName: "people-flow-project",
|
||||
APIBaseURL: "http://managed.invalid/people",
|
||||
ResultType: "people_flow_project",
|
||||
}},
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
response := func(status int, body any) (*http.Response, error) {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(bytes.NewReader(data)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/config":
|
||||
return response(http.StatusOK, map[string]any{
|
||||
"config_path": storeConfigPath,
|
||||
"stream": map[string]any{"rtsp_url": storeRTSP},
|
||||
})
|
||||
case r.Method == http.MethodPut && r.URL.Path == "/store/api/manage/config":
|
||||
var payload map[string]string
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("decode store config update: %v", err)
|
||||
}
|
||||
storeRTSP = payload["rtsp_url"]
|
||||
return response(http.StatusOK, map[string]any{
|
||||
"config_path": storeConfigPath,
|
||||
"stream": map[string]any{"rtsp_url": storeRTSP},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/summary":
|
||||
return response(http.StatusOK, managed.ResultSummary{
|
||||
ResultType: "store_dwell_alert",
|
||||
Headline: "Latest report shows 2 active customers, longest dwell 850s",
|
||||
LastResultTime: "2026-04-16T09:30:00+08:00",
|
||||
Metrics: map[string]any{
|
||||
"longest_dwell_seconds": 850,
|
||||
},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/files":
|
||||
return response(http.StatusOK, map[string]any{
|
||||
"files": []managed.ResultFile{{
|
||||
Path: "logs/events.jsonl",
|
||||
Name: "events.jsonl",
|
||||
}},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/files/preview":
|
||||
return response(http.StatusOK, managed.FilePreview{
|
||||
Path: "logs/events.jsonl",
|
||||
Lines: []string{"preview"},
|
||||
Count: 1,
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/store/api/manage/files/download":
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Disposition": []string{`attachment; filename="events.jsonl"`},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader("store-download")),
|
||||
}, nil
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/people/api/manage/config":
|
||||
return response(http.StatusOK, map[string]any{
|
||||
"config_path": "/srv/people/config/local.yaml",
|
||||
"runtime": map[string]any{"rtsp_url": "rtsp://people-old/stream"},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/people/api/manage/summary":
|
||||
return response(http.StatusOK, managed.ResultSummary{
|
||||
ResultType: "people_flow_project",
|
||||
Headline: "Latest window counted 5 people",
|
||||
LastResultTime: "2026-04-16T09:00:00+08:00",
|
||||
Metrics: map[string]any{
|
||||
"recent_window_stats": []map[string]any{{"total_people": 5}},
|
||||
},
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/people/api/manage/files":
|
||||
return response(http.StatusOK, map[string]any{"files": []managed.ResultFile{}})
|
||||
default:
|
||||
t.Fatalf("unexpected child request: %s %s", r.Method, r.URL.String())
|
||||
return nil, nil
|
||||
}
|
||||
})}
|
||||
|
||||
docker := &fakeDockerRuntime{
|
||||
statusByContainer: map[string]string{
|
||||
"store-dwell-alert": "running",
|
||||
"people-flow-project": "stopped",
|
||||
},
|
||||
}
|
||||
srv.managedManager = managed.NewManager(registry, docker, managed.NewRemoteClient(client))
|
||||
|
||||
t.Run("GET /api/managed-services", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/managed-services", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Services []managed.ServiceState `json:"services"`
|
||||
}
|
||||
if err := json.Unmarshal(recorder.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(payload.Services) != 2 {
|
||||
t.Fatalf("len(services) = %d", len(payload.Services))
|
||||
}
|
||||
if payload.Services[0].RTSP == "" {
|
||||
t.Fatalf("expected RTSP in list response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GET /api/managed-services/:id", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/managed-services/store_dwell_alert", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), "longest_dwell_seconds") {
|
||||
t.Fatalf("detail response missing summary metrics: %s", recorder.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PUT /api/managed-services/:id/config", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
body := bytes.NewBufferString(`{"rtsp_url":"rtsp://store-new/stream"}`)
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/managed-services/store_dwell_alert/config", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if storeRTSP != "rtsp://store-new/stream" {
|
||||
t.Fatalf("storeRTSP = %q", storeRTSP)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("POST /api/managed-services/:id/restart", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/managed-services/people_flow_project/restart", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if len(docker.restarted) != 1 || docker.restarted[0] != "people-flow-project" {
|
||||
t.Fatalf("restarted = %#v", docker.restarted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GET /api/managed-services/:id/results/summary", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/managed-services/store_dwell_alert/results/summary", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), "active customers") {
|
||||
t.Fatalf("summary response = %s", recorder.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GET /api/managed-services/:id/results/files", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/managed-services/store_dwell_alert/results/files", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), "events.jsonl") {
|
||||
t.Fatalf("files response missing expected file: %s", recorder.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GET /api/managed-services/:id/results/preview", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/managed-services/store_dwell_alert/results/preview?path=logs/events.jsonl&lines=1", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), "preview") {
|
||||
t.Fatalf("preview response = %s", recorder.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GET /api/managed-services/:id/results/download", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/managed-services/store_dwell_alert/results/download?path=logs/events.jsonl", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if got := recorder.Body.String(); got != "store-download" {
|
||||
t.Fatalf("download body = %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
223
internal/server/server.go
Normal file
223
internal/server/server.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"managed-portal/internal/config"
|
||||
"managed-portal/internal/managed"
|
||||
"managed-portal/internal/webdevice"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
cfg *config.Config
|
||||
engine *gin.Engine
|
||||
managedManager *managed.Manager
|
||||
webDeviceSvc *webdevice.Service
|
||||
}
|
||||
|
||||
func New(cfg *config.Config) *Server {
|
||||
if cfg == nil {
|
||||
cfg = config.Load()
|
||||
}
|
||||
|
||||
engine := gin.New()
|
||||
engine.Use(gin.Logger(), gin.Recovery())
|
||||
engine.Use(cors.Default())
|
||||
|
||||
srv := &Server{
|
||||
cfg: cfg,
|
||||
engine: engine,
|
||||
}
|
||||
srv.managedManager = managed.NewManager(loadRegistry(cfg.RegistryPath), nil, nil)
|
||||
srv.webDeviceSvc = webdevice.NewService()
|
||||
srv.registerRoutes()
|
||||
return srv
|
||||
}
|
||||
|
||||
func (s *Server) registerRoutes() {
|
||||
api := s.engine.Group("/api")
|
||||
api.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
api.GET("/managed-services", s.listManagedServices)
|
||||
api.GET("/managed-services/:id", s.getManagedService)
|
||||
api.PUT("/managed-services/:id/config", s.updateManagedServiceConfig)
|
||||
api.POST("/managed-services/:id/restart", s.restartManagedService)
|
||||
api.GET("/managed-services/:id/results/summary", s.getManagedServiceSummary)
|
||||
api.GET("/managed-services/:id/results/files", s.listManagedServiceFiles)
|
||||
api.GET("/managed-services/:id/results/preview", s.previewManagedServiceFile)
|
||||
api.GET("/managed-services/:id/results/download", s.downloadManagedServiceFile)
|
||||
api.GET("/web-devices/scan", s.scanWebDevices)
|
||||
s.engine.Any("/proxy/web/:ip/*proxyPath", s.proxyWebDevice)
|
||||
}
|
||||
|
||||
func (s *Server) Engine() *gin.Engine {
|
||||
return s.engine
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
return s.engine.Run(s.cfg.HTTPAddr)
|
||||
}
|
||||
|
||||
func loadRegistry(path string) *managed.Registry {
|
||||
registry, err := managed.LoadRegistry(path)
|
||||
if err != nil {
|
||||
return managed.EmptyRegistry()
|
||||
}
|
||||
return registry
|
||||
}
|
||||
|
||||
func (s *Server) listManagedServices(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"services": s.managedManager.List(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) getManagedService(c *gin.Context) {
|
||||
service, err := s.managedManager.Detail(c.Param("id"))
|
||||
if err != nil {
|
||||
s.handleManagedError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"service": service})
|
||||
}
|
||||
|
||||
func (s *Server) updateManagedServiceConfig(c *gin.Context) {
|
||||
var req struct {
|
||||
RTSPURL string `json:"rtsp_url"`
|
||||
RTSP string `json:"rtsp"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rtsp := strings.TrimSpace(req.RTSPURL)
|
||||
if rtsp == "" {
|
||||
rtsp = strings.TrimSpace(req.RTSP)
|
||||
}
|
||||
|
||||
service, err := s.managedManager.UpdateRTSP(c.Param("id"), rtsp)
|
||||
if err != nil {
|
||||
s.handleManagedError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"service": service})
|
||||
}
|
||||
|
||||
func (s *Server) restartManagedService(c *gin.Context) {
|
||||
service, err := s.managedManager.Restart(c.Param("id"))
|
||||
if err != nil {
|
||||
s.handleManagedError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"service": service})
|
||||
}
|
||||
|
||||
func (s *Server) getManagedServiceSummary(c *gin.Context) {
|
||||
summary, err := s.managedManager.Summary(c.Param("id"))
|
||||
if err != nil {
|
||||
s.handleManagedError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"summary": summary})
|
||||
}
|
||||
|
||||
func (s *Server) listManagedServiceFiles(c *gin.Context) {
|
||||
files, err := s.managedManager.Files(c.Param("id"))
|
||||
if err != nil {
|
||||
s.handleManagedError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
}
|
||||
|
||||
func (s *Server) previewManagedServiceFile(c *gin.Context) {
|
||||
lines := 2000
|
||||
if raw := strings.TrimSpace(c.Query("lines")); raw != "" {
|
||||
parsed, err := strconv.Atoi(raw)
|
||||
if err != nil || parsed <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "lines 参数必须为正整数"})
|
||||
return
|
||||
}
|
||||
if parsed > 2000 {
|
||||
parsed = 2000
|
||||
}
|
||||
lines = parsed
|
||||
}
|
||||
|
||||
preview, err := s.managedManager.PreviewFile(c.Param("id"), c.Query("path"), lines)
|
||||
if err != nil {
|
||||
s.handleManagedError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, preview)
|
||||
}
|
||||
|
||||
func (s *Server) downloadManagedServiceFile(c *gin.Context) {
|
||||
resp, err := s.managedManager.Download(c.Request.Context(), c.Param("id"), c.Query("path"))
|
||||
if err != nil {
|
||||
s.handleManagedError(c, err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if contentType := resp.Header.Get("Content-Type"); contentType != "" {
|
||||
c.Header("Content-Type", contentType)
|
||||
}
|
||||
if contentDisposition := resp.Header.Get("Content-Disposition"); contentDisposition != "" {
|
||||
c.Header("Content-Disposition", contentDisposition)
|
||||
}
|
||||
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取文件失败"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleManagedError(c *gin.Context, err error) {
|
||||
switch {
|
||||
case errors.Is(err, managed.ErrServiceNotFound):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "被管理服务不存在"})
|
||||
case os.IsNotExist(err):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
case strings.Contains(err.Error(), "rtsp url"):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case strings.Contains(err.Error(), "invalid file path"):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) scanWebDevices(c *gin.Context) {
|
||||
result, err := s.webDeviceSvc.Scan(c.Request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取网卡信息失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (s *Server) proxyWebDevice(c *gin.Context) {
|
||||
err := s.webDeviceSvc.ProxyHTTP(c.Writer, c.Request, c.Param("ip"), c.Param("proxyPath"))
|
||||
switch {
|
||||
case err == nil:
|
||||
return
|
||||
case errors.Is(err, webdevice.ErrInvalidTargetIP):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "仅支持内网IPv4地址"})
|
||||
case errors.Is(err, webdevice.ErrTargetNotAllowed):
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "目标IP未在扫描结果中,请先扫描网页设备"})
|
||||
case errors.Is(err, webdevice.ErrInvalidProxyURL):
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "代理目标无效"})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
23
internal/server/server_test.go
Normal file
23
internal/server/server_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHealth(t *testing.T) {
|
||||
srv := New(nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
srv.Engine().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
if body := rec.Body.String(); !strings.Contains(body, `"status":"ok"`) {
|
||||
t.Fatalf("unexpected body: %s", body)
|
||||
}
|
||||
}
|
||||
83
internal/server/webdevice_handlers_test.go
Normal file
83
internal/server/webdevice_handlers_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"managed-portal/internal/webdevice"
|
||||
)
|
||||
|
||||
func TestWebDeviceHandlers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("GET /api/web-devices/scan", func(t *testing.T) {
|
||||
srv := New(nil)
|
||||
svc := webdevice.NewService()
|
||||
svc.SetInterfaceGetter(func() ([]webdevice.InterfaceInfo, error) {
|
||||
return []webdevice.InterfaceInfo{{
|
||||
Name: "eth0",
|
||||
IP: "10.8.0.14",
|
||||
Netmask: "255.255.255.0",
|
||||
}}, nil
|
||||
})
|
||||
svc.SetTCPScanner(func(ip, netmask string, port int, excludeIPs map[string]bool) ([]webdevice.TCPDevice, error) {
|
||||
return []webdevice.TCPDevice{{IP: "192.168.1.124", Port: 80}}, nil
|
||||
})
|
||||
svc.SetForwarderFactory(func(ip string, port int, listenAddress, targetAddress string) (*webdevice.WebDeviceForwarder, error) {
|
||||
return nil, nil
|
||||
})
|
||||
srv.webDeviceSvc = svc
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://10.8.0.14:13000/api/web-devices/scan", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), "192.168.1.124") {
|
||||
t.Fatalf("scan response = %s", recorder.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ANY /proxy/web/:ip/*proxyPath rejects unscanned IP", func(t *testing.T) {
|
||||
srv := New(nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://portal/proxy/web/192.168.1.124/", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ANY /proxy/web/:ip/*proxyPath proxies allowed IP", func(t *testing.T) {
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
_, _ = w.Write([]byte(`<html><head></head><body><img src="/doc/logo.png"></body></html>`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
srv := New(nil)
|
||||
svc := webdevice.NewService()
|
||||
svc.AllowIP("192.168.1.124")
|
||||
svc.SetProxyTargetResolver(func(ip string) string {
|
||||
return upstream.URL
|
||||
})
|
||||
srv.webDeviceSvc = svc
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://portal/proxy/web/192.168.1.124/", nil)
|
||||
srv.engine.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
if !strings.Contains(recorder.Body.String(), "/proxy/web/192.168.1.124/doc/logo.png") {
|
||||
t.Fatalf("proxy response = %s", recorder.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
101
internal/webdevice/proxy.go
Normal file
101
internal/webdevice/proxy.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidTargetIP = errors.New("invalid target ip")
|
||||
ErrTargetNotAllowed = errors.New("target ip not allowed")
|
||||
ErrInvalidProxyURL = errors.New("invalid proxy target")
|
||||
)
|
||||
|
||||
func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, proxyPath string) error {
|
||||
if !IsPrivateIPv4Literal(targetIP) {
|
||||
return ErrInvalidTargetIP
|
||||
}
|
||||
if !s.IsAllowed(targetIP) {
|
||||
return ErrTargetNotAllowed
|
||||
}
|
||||
|
||||
rawTarget := s.ProxyTargetURL(targetIP)
|
||||
targetURL, err := url.Parse(rawTarget)
|
||||
if err != nil || targetURL.Scheme == "" || targetURL.Host == "" {
|
||||
return ErrInvalidProxyURL
|
||||
}
|
||||
|
||||
if proxyPath == "" {
|
||||
proxyPath = "/"
|
||||
}
|
||||
rawQuery := r.URL.RawQuery
|
||||
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.URL.Path = JoinProxyTargetPath(targetURL.Path, proxyPath)
|
||||
req.URL.RawPath = ""
|
||||
req.URL.RawQuery = rawQuery
|
||||
req.Host = targetURL.Host
|
||||
req.Header.Del("Accept-Encoding")
|
||||
},
|
||||
ModifyResponse: func(resp *http.Response) error {
|
||||
proxyPrefix := "/proxy/web/" + targetIP
|
||||
if location := resp.Header.Get("Location"); location != "" {
|
||||
resp.Header.Set("Location", RewriteLocation(targetIP, targetURL, location))
|
||||
}
|
||||
if cookies := resp.Header.Values("Set-Cookie"); len(cookies) > 0 {
|
||||
resp.Header.Del("Set-Cookie")
|
||||
for _, cookie := range cookies {
|
||||
resp.Header.Add("Set-Cookie", RewriteSetCookie(cookie, proxyPrefix))
|
||||
}
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if ShouldRewriteBody(contentType) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
rewritten := RewriteText(string(body), proxyPrefix, targetURL, contentType)
|
||||
rewrittenBytes := []byte(rewritten)
|
||||
resp.Body = io.NopCloser(strings.NewReader(rewritten))
|
||||
resp.ContentLength = int64(len(rewrittenBytes))
|
||||
resp.Header.Set("Content-Length", strconv.Itoa(len(rewrittenBytes)))
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Del("Content-MD5")
|
||||
resp.Header.Del("Etag")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "代理访问失败: "+err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(closeNotifyWriter{ResponseWriter: w}, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
type closeNotifyWriter struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w closeNotifyWriter) CloseNotify() <-chan bool {
|
||||
ch := make(chan bool, 1)
|
||||
return ch
|
||||
}
|
||||
|
||||
func (w closeNotifyWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
93
internal/webdevice/proxy_test.go
Normal file
93
internal/webdevice/proxy_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRewriteLocation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
targetURL, _ := url.Parse("http://192.168.1.124:80")
|
||||
got := RewriteLocation("192.168.1.124", targetURL, "http://192.168.1.124/ISAPI/Security")
|
||||
if got != "/proxy/web/192.168.1.124/ISAPI/Security" {
|
||||
t.Fatalf("RewriteLocation() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteSetCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := RewriteSetCookie("SID=1; Path=/; Domain=192.168.1.124; HttpOnly", "/proxy/web/192.168.1.124")
|
||||
if strings.Contains(strings.ToLower(got), "domain=") {
|
||||
t.Fatalf("RewriteSetCookie() kept domain: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Path=/proxy/web/192.168.1.124/") {
|
||||
t.Fatalf("RewriteSetCookie() path = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
targetURL, _ := url.Parse("http://192.168.1.124:80")
|
||||
body := `<html><head></head><body><img src="/doc/logo.png"><a href="http://192.168.1.124/ISAPI/x">x</a></body></html>`
|
||||
got := RewriteText(body, "/proxy/web/192.168.1.124", targetURL, "text/html")
|
||||
|
||||
if !strings.Contains(got, `/proxy/web/192.168.1.124/doc/logo.png`) {
|
||||
t.Fatalf("rewritten body missing proxied relative URL: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, `data-web-proxy-runtime`) {
|
||||
t.Fatalf("rewritten body missing runtime injection: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHTTPRejectsUnscannedIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://portal/proxy/web/192.168.1.124/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
err := svc.ProxyHTTP(rec, req, "192.168.1.124", "/")
|
||||
if err != ErrTargetNotAllowed {
|
||||
t.Fatalf("ProxyHTTP() error = %v, want ErrTargetNotAllowed", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHTTPServesAllowedTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", "http://192.168.1.124/ISAPI/test")
|
||||
w.Header().Add("Set-Cookie", "SID=1; Path=/")
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
_, _ = w.Write([]byte(`<html><head></head><body><img src="/doc/logo.png"></body></html>`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
svc := NewService()
|
||||
svc.allowIP("192.168.1.124")
|
||||
svc.proxyTarget = func(ip string) string {
|
||||
return upstream.URL
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://portal/proxy/web/192.168.1.124/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
if err := svc.ProxyHTTP(rec, req, "192.168.1.124", "/"); err != nil {
|
||||
t.Fatalf("ProxyHTTP() error = %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := rec.Header().Get("Location"); got != "http://192.168.1.124/ISAPI/test" {
|
||||
t.Fatalf("Location = %q", got)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "/proxy/web/192.168.1.124/doc/logo.png") {
|
||||
t.Fatalf("body = %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
240
internal/webdevice/rewrite.go
Normal file
240
internal/webdevice/rewrite.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func JoinProxyTargetPath(basePath, requestPath string) string {
|
||||
if requestPath == "" {
|
||||
requestPath = "/"
|
||||
}
|
||||
if basePath == "" || basePath == "/" {
|
||||
return requestPath
|
||||
}
|
||||
if strings.HasSuffix(basePath, "/") && strings.HasPrefix(requestPath, "/") {
|
||||
return basePath + strings.TrimPrefix(requestPath, "/")
|
||||
}
|
||||
if !strings.HasSuffix(basePath, "/") && !strings.HasPrefix(requestPath, "/") {
|
||||
return basePath + "/" + requestPath
|
||||
}
|
||||
return basePath + requestPath
|
||||
}
|
||||
|
||||
func RewriteLocation(targetIP string, targetURL *url.URL, location string) string {
|
||||
locationURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
return location
|
||||
}
|
||||
|
||||
proxyPrefix := "/proxy/web/" + targetIP
|
||||
if locationURL.Host == "" && strings.HasPrefix(location, "/") {
|
||||
return proxyPrefix + location
|
||||
}
|
||||
|
||||
if locationURL.Host != "" {
|
||||
locationHost := locationURL.Hostname()
|
||||
locationPort := locationURL.Port()
|
||||
if locationPort == "" && (locationURL.Scheme == "" || locationURL.Scheme == "http") {
|
||||
locationPort = "80"
|
||||
}
|
||||
|
||||
targetHost := targetURL.Hostname()
|
||||
targetPort := targetURL.Port()
|
||||
if targetPort == "" && targetURL.Scheme == "http" {
|
||||
targetPort = "80"
|
||||
}
|
||||
|
||||
if locationHost != targetHost || locationPort != targetPort {
|
||||
return location
|
||||
}
|
||||
|
||||
rewrittenPath := locationURL.EscapedPath()
|
||||
if rewrittenPath == "" {
|
||||
rewrittenPath = "/"
|
||||
}
|
||||
if locationURL.RawQuery != "" {
|
||||
rewrittenPath += "?" + locationURL.RawQuery
|
||||
}
|
||||
if locationURL.Fragment != "" {
|
||||
rewrittenPath += "#" + locationURL.Fragment
|
||||
}
|
||||
return proxyPrefix + rewrittenPath
|
||||
}
|
||||
|
||||
return location
|
||||
}
|
||||
|
||||
var (
|
||||
webProxyQuotedAttrPattern = regexp.MustCompile(`(?i)\b(href|src|action|poster|data-src|data-href)\s*=\s*(['"])([^'"]*)['"]`)
|
||||
webProxyBareAttrPattern = regexp.MustCompile(`(?i)\b(href|src|action|poster|data-src|data-href)\s*=\s*([^'">\s][^>\s]*)`)
|
||||
webProxyCSSURLPattern = regexp.MustCompile(`(?i)url\(\s*(['"]?)([^'"\)\s]+)['"]?\s*\)`)
|
||||
webProxyQuotedURLPattern = regexp.MustCompile(`(['"])(/[^'"<>\s\\)]*)['"]`)
|
||||
)
|
||||
|
||||
func ShouldRewriteBody(contentType string) bool {
|
||||
contentType = strings.ToLower(contentType)
|
||||
return strings.Contains(contentType, "text/html") ||
|
||||
strings.Contains(contentType, "text/css")
|
||||
}
|
||||
|
||||
func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType string) string {
|
||||
contentType = strings.ToLower(contentType)
|
||||
|
||||
if strings.Contains(contentType, "text/html") {
|
||||
body = webProxyQuotedAttrPattern.ReplaceAllStringFunc(body, func(match string) string {
|
||||
parts := webProxyQuotedAttrPattern.FindStringSubmatch(match)
|
||||
if len(parts) != 4 {
|
||||
return match
|
||||
}
|
||||
rewritten := rewriteURL(parts[3], proxyPrefix, targetURL)
|
||||
return strings.Replace(match, parts[2]+parts[3]+parts[2], parts[2]+rewritten+parts[2], 1)
|
||||
})
|
||||
|
||||
body = webProxyBareAttrPattern.ReplaceAllStringFunc(body, func(match string) string {
|
||||
parts := webProxyBareAttrPattern.FindStringSubmatch(match)
|
||||
if len(parts) != 3 {
|
||||
return match
|
||||
}
|
||||
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
|
||||
return strings.Replace(match, parts[2], rewritten, 1)
|
||||
})
|
||||
|
||||
body = webProxyQuotedURLPattern.ReplaceAllStringFunc(body, func(match string) string {
|
||||
parts := webProxyQuotedURLPattern.FindStringSubmatch(match)
|
||||
if len(parts) != 3 {
|
||||
return match
|
||||
}
|
||||
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
|
||||
return parts[1] + rewritten + parts[1]
|
||||
})
|
||||
}
|
||||
|
||||
body = webProxyCSSURLPattern.ReplaceAllStringFunc(body, func(match string) string {
|
||||
parts := webProxyCSSURLPattern.FindStringSubmatch(match)
|
||||
if len(parts) != 3 {
|
||||
return match
|
||||
}
|
||||
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
|
||||
return "url(" + parts[1] + rewritten + parts[1] + ")"
|
||||
})
|
||||
|
||||
if strings.Contains(contentType, "text/html") {
|
||||
body = injectRuntime(body, proxyPrefix)
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
func injectRuntime(body, proxyPrefix string) string {
|
||||
if strings.Contains(body, "data-web-proxy-runtime") {
|
||||
return body
|
||||
}
|
||||
|
||||
script := `<script data-web-proxy-runtime>(function(){var p="` + proxyPrefix + `";var d=["/ISAPI","/SDK","/PSIA","/doc","/webSocket"];function q(x){if(x.indexOf(p+"/")===0||x.indexOf("/proxy/web/")===0){return x}for(var i=0;i<d.length;i++){if(x===d[i]||x.indexOf(d[i]+"/")===0||x.indexOf(d[i]+"?")===0){return p+x}}return x}function r(u){if(typeof u!=="string"){return u}if(u.charAt(0)==="/"&&u.indexOf("//")!==0){return q(u)}try{var a=new URL(u,window.location.href);if(a.origin===window.location.origin){var x=a.pathname+a.search+a.hash;var y=q(x);if(y!==x){return y}}}catch(e){}return u}if(window.XMLHttpRequest){var o=XMLHttpRequest.prototype.open;XMLHttpRequest.prototype.open=function(m,u){arguments[1]=r(u);return o.apply(this,arguments)}}if(window.fetch){var f=window.fetch;window.fetch=function(i,n){if(typeof i==="string"){i=r(i)}else if(i&&i.url){i=new Request(r(i.url),i)}return f.call(this,i,n)}}function a(e){if(!e||!e.getAttribute){return}["src","href","action","data-src","data-href"].forEach(function(k){var v=e.getAttribute(k);if(v){var nv=r(v);if(nv!==v){e.setAttribute(k,nv)}}})}if(window.MutationObserver){new MutationObserver(function(ms){ms.forEach(function(m){if(m.type==="attributes"){a(m.target)}else{Array.prototype.forEach.call(m.addedNodes,function(n){a(n);if(n&&n.querySelectorAll){Array.prototype.forEach.call(n.querySelectorAll("[src],[href],[action],[data-src],[data-href]"),a)}})}})}).observe(document.documentElement,{childList:true,subtree:true,attributes:true,attributeFilter:["src","href","action","data-src","data-href"]})}})();</script>`
|
||||
|
||||
lower := strings.ToLower(body)
|
||||
if idx := strings.Index(lower, "</head>"); idx >= 0 {
|
||||
return body[:idx] + script + body[idx:]
|
||||
}
|
||||
if idx := strings.Index(lower, "<body"); idx >= 0 {
|
||||
if end := strings.Index(body[idx:], ">"); end >= 0 {
|
||||
insertAt := idx + end + 1
|
||||
return body[:insertAt] + script + body[insertAt:]
|
||||
}
|
||||
}
|
||||
return script + body
|
||||
}
|
||||
|
||||
func rewriteURL(rawURL, proxyPrefix string, targetURL *url.URL) string {
|
||||
rawURL = strings.TrimSpace(rawURL)
|
||||
if rawURL == "" ||
|
||||
strings.HasPrefix(rawURL, "#") ||
|
||||
strings.HasPrefix(rawURL, "//") ||
|
||||
strings.HasPrefix(rawURL, proxyPrefix+"/") ||
|
||||
strings.HasPrefix(rawURL, "/proxy/web/") {
|
||||
return rawURL
|
||||
}
|
||||
|
||||
lower := strings.ToLower(rawURL)
|
||||
for _, prefix := range []string{"data:", "blob:", "mailto:", "tel:", "javascript:"} {
|
||||
if strings.HasPrefix(lower, prefix) {
|
||||
return rawURL
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(rawURL, "/") {
|
||||
return proxyPrefix + rawURL
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil || parsed.Host == "" {
|
||||
return rawURL
|
||||
}
|
||||
if targetURL == nil || !sameUpstreamHost(parsed, targetURL) {
|
||||
return rawURL
|
||||
}
|
||||
|
||||
rewrittenPath := parsed.EscapedPath()
|
||||
if rewrittenPath == "" {
|
||||
rewrittenPath = "/"
|
||||
}
|
||||
if parsed.RawQuery != "" {
|
||||
rewrittenPath += "?" + parsed.RawQuery
|
||||
}
|
||||
if parsed.Fragment != "" {
|
||||
rewrittenPath += "#" + parsed.Fragment
|
||||
}
|
||||
return proxyPrefix + rewrittenPath
|
||||
}
|
||||
|
||||
func sameUpstreamHost(left, right *url.URL) bool {
|
||||
leftPort := left.Port()
|
||||
if leftPort == "" && (left.Scheme == "" || left.Scheme == "http") {
|
||||
leftPort = "80"
|
||||
}
|
||||
if leftPort == "" && left.Scheme == "https" {
|
||||
leftPort = "443"
|
||||
}
|
||||
|
||||
rightPort := right.Port()
|
||||
if rightPort == "" && (right.Scheme == "" || right.Scheme == "http") {
|
||||
rightPort = "80"
|
||||
}
|
||||
if rightPort == "" && right.Scheme == "https" {
|
||||
rightPort = "443"
|
||||
}
|
||||
|
||||
return strings.EqualFold(left.Hostname(), right.Hostname()) && leftPort == rightPort
|
||||
}
|
||||
|
||||
func RewriteSetCookie(cookie, proxyPrefix string) string {
|
||||
parts := strings.Split(cookie, ";")
|
||||
if len(parts) == 0 {
|
||||
return cookie
|
||||
}
|
||||
|
||||
rewritten := []string{strings.TrimSpace(parts[0])}
|
||||
hasPath := false
|
||||
for _, part := range parts[1:] {
|
||||
attr := strings.TrimSpace(part)
|
||||
if attr == "" {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(attr)
|
||||
switch {
|
||||
case strings.HasPrefix(lower, "domain="):
|
||||
continue
|
||||
case strings.HasPrefix(lower, "path="):
|
||||
hasPath = true
|
||||
rewritten = append(rewritten, "Path="+proxyPrefix+"/")
|
||||
default:
|
||||
rewritten = append(rewritten, attr)
|
||||
}
|
||||
}
|
||||
if !hasPath {
|
||||
rewritten = append(rewritten, "Path="+proxyPrefix+"/")
|
||||
}
|
||||
return strings.Join(rewritten, "; ")
|
||||
}
|
||||
495
internal/webdevice/service.go
Normal file
495
internal/webdevice/service.go
Normal file
@@ -0,0 +1,495 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type InterfaceInfo struct {
|
||||
Name string `json:"name"`
|
||||
IP string `json:"ip"`
|
||||
Netmask string `json:"netmask"`
|
||||
}
|
||||
|
||||
type TCPDevice struct {
|
||||
IP string
|
||||
Port int
|
||||
}
|
||||
|
||||
type DeviceInfo struct {
|
||||
IP string `json:"ip"`
|
||||
Interface string `json:"interface"`
|
||||
Port int `json:"port"`
|
||||
TargetURL string `json:"target_url"`
|
||||
ProxyURL string `json:"proxy_url"`
|
||||
DirectURL string `json:"direct_url,omitempty"`
|
||||
ForwardPort int `json:"forward_port,omitempty"`
|
||||
}
|
||||
|
||||
type ScanResult struct {
|
||||
Interfaces []InterfaceInfo `json:"interfaces"`
|
||||
Devices []DeviceInfo `json:"devices"`
|
||||
Count int `json:"count"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type InterfaceGetter func() ([]InterfaceInfo, error)
|
||||
type TCPScanner func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error)
|
||||
type ForwarderFactory func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error)
|
||||
type ProxyTargetResolver func(ip string) string
|
||||
|
||||
type Service struct {
|
||||
mu sync.RWMutex
|
||||
allowed map[string]time.Time
|
||||
forwarders map[string]*webDeviceForwarder
|
||||
interfaceGetter InterfaceGetter
|
||||
tcpScanner TCPScanner
|
||||
newForwarder ForwarderFactory
|
||||
proxyTarget ProxyTargetResolver
|
||||
forwardTarget func(ip string) string
|
||||
}
|
||||
|
||||
func NewService() *Service {
|
||||
return &Service{
|
||||
allowed: make(map[string]time.Time),
|
||||
forwarders: make(map[string]*webDeviceForwarder),
|
||||
interfaceGetter: defaultInterfaceGetter,
|
||||
tcpScanner: scanTCP,
|
||||
newForwarder: newWebDeviceForwarder,
|
||||
proxyTarget: defaultProxyTarget,
|
||||
forwardTarget: defaultForwardTarget,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
|
||||
interfaces, err := s.interfaceGetter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(interfaces) == 0 {
|
||||
return &ScanResult{
|
||||
Interfaces: []InterfaceInfo{},
|
||||
Devices: []DeviceInfo{},
|
||||
Message: "未找到有效的网卡",
|
||||
}, nil
|
||||
}
|
||||
|
||||
excludeIPs := make(map[string]bool)
|
||||
for _, iface := range interfaces {
|
||||
excludeIPs[iface.IP] = true
|
||||
}
|
||||
|
||||
result := &ScanResult{
|
||||
Interfaces: interfaces,
|
||||
Devices: []DeviceInfo{},
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
scheme, host := requestBase(r)
|
||||
for _, iface := range interfaces {
|
||||
devices, scanErr := s.tcpScanner(iface.IP, iface.Netmask, 80, excludeIPs)
|
||||
if scanErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", iface.Name, scanErr))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, device := range devices {
|
||||
if !IsPrivateIPv4Literal(device.IP) {
|
||||
continue
|
||||
}
|
||||
|
||||
s.allowIP(device.IP)
|
||||
forwardPort, forwardErr := s.EnsureForwarder(device.IP)
|
||||
if forwardErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s: 启动网页直连转发失败: %v", device.IP, forwardErr))
|
||||
}
|
||||
|
||||
deviceInfo := DeviceInfo{
|
||||
IP: device.IP,
|
||||
Interface: iface.Name,
|
||||
Port: device.Port,
|
||||
TargetURL: fmt.Sprintf("http://%s/", device.IP),
|
||||
ProxyURL: fmt.Sprintf("/proxy/web/%s/", device.IP),
|
||||
}
|
||||
if forwardErr == nil {
|
||||
deviceInfo.ForwardPort = forwardPort
|
||||
deviceInfo.DirectURL = buildDirectURL(scheme, host, forwardPort)
|
||||
}
|
||||
result.Devices = append(result.Devices, deviceInfo)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(result.Devices, func(i, j int) bool {
|
||||
return ipv4ToUint32(result.Devices[i].IP) < ipv4ToUint32(result.Devices[j].IP)
|
||||
})
|
||||
result.Count = len(result.Devices)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Service) allowIP(ip string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowed[ip] = time.Now()
|
||||
}
|
||||
|
||||
func (s *Service) AllowIP(ip string) {
|
||||
s.allowIP(ip)
|
||||
}
|
||||
|
||||
func (s *Service) IsAllowed(ip string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
_, ok := s.allowed[ip]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *Service) SetInterfaceGetter(getter InterfaceGetter) {
|
||||
if getter != nil {
|
||||
s.interfaceGetter = getter
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetTCPScanner(scanner TCPScanner) {
|
||||
if scanner != nil {
|
||||
s.tcpScanner = scanner
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetForwarderFactory(factory ForwarderFactory) {
|
||||
if factory != nil {
|
||||
s.newForwarder = factory
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetProxyTargetResolver(resolver ProxyTargetResolver) {
|
||||
if resolver != nil {
|
||||
s.proxyTarget = resolver
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) EnsureForwarder(ip string) (int, error) {
|
||||
port, ok := WebDeviceForwardPort(ip)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("无效的IPv4地址")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if forwarder, ok := s.forwarders[ip]; ok {
|
||||
return forwarder.port, nil
|
||||
}
|
||||
|
||||
targetAddress := defaultForwardTarget(ip)
|
||||
if s.forwardTarget != nil {
|
||||
targetAddress = s.forwardTarget(ip)
|
||||
}
|
||||
|
||||
forwarder, err := s.newForwarder(ip, port, net.JoinHostPort("0.0.0.0", strconv.Itoa(port)), targetAddress)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
s.forwarders[ip] = forwarder
|
||||
go forwarder.serve()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func (s *Service) ProxyTargetURL(ip string) string {
|
||||
if s.proxyTarget == nil {
|
||||
return defaultProxyTarget(ip)
|
||||
}
|
||||
return s.proxyTarget(ip)
|
||||
}
|
||||
|
||||
func IsPrivateIPv4Literal(value string) bool {
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil {
|
||||
return false
|
||||
}
|
||||
return ip4.IsPrivate() && !ip4.IsLoopback() && !ip4.IsMulticast() && !ip4.IsUnspecified()
|
||||
}
|
||||
|
||||
func WebDeviceForwardPort(ip string) (int, bool) {
|
||||
parsed := net.ParseIP(ip)
|
||||
if parsed == nil {
|
||||
return 0, false
|
||||
}
|
||||
ip4 := parsed.To4()
|
||||
if ip4 == nil {
|
||||
return 0, false
|
||||
}
|
||||
return 31000 + int(ip4[3]), true
|
||||
}
|
||||
|
||||
func requestBase(r *http.Request) (string, string) {
|
||||
scheme := r.Header.Get("X-Forwarded-Proto")
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = r.Host
|
||||
}
|
||||
if hostname, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = hostname
|
||||
}
|
||||
return scheme, host
|
||||
}
|
||||
|
||||
func buildDirectURL(scheme, host string, port int) string {
|
||||
return scheme + "://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/"
|
||||
}
|
||||
|
||||
type webDeviceForwarder struct {
|
||||
ip string
|
||||
port int
|
||||
targetAddress string
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
type WebDeviceForwarder = webDeviceForwarder
|
||||
|
||||
func newWebDeviceForwarder(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
|
||||
listener, err := net.Listen("tcp", listenAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if port == 0 {
|
||||
if tcpAddr, ok := listener.Addr().(*net.TCPAddr); ok {
|
||||
port = tcpAddr.Port
|
||||
}
|
||||
}
|
||||
return &webDeviceForwarder{
|
||||
ip: ip,
|
||||
port: port,
|
||||
targetAddress: targetAddress,
|
||||
listener: listener,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *webDeviceForwarder) serve() {
|
||||
if f == nil || f.listener == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
clientConn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go f.handle(clientConn)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *webDeviceForwarder) handle(clientConn net.Conn) {
|
||||
targetConn, err := net.DialTimeout("tcp", f.targetAddress, 10*time.Second)
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := ioCopy(targetConn, clientConn)
|
||||
errCh <- err
|
||||
}()
|
||||
go func() {
|
||||
_, err := ioCopy(clientConn, targetConn)
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
<-errCh
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
}
|
||||
|
||||
var ioCopy = func(dst net.Conn, src net.Conn) (int64, error) {
|
||||
return copyConn(dst, src)
|
||||
}
|
||||
|
||||
func copyConn(dst net.Conn, src net.Conn) (int64, error) {
|
||||
return io.Copy(dst, src)
|
||||
}
|
||||
|
||||
func defaultProxyTarget(ip string) string {
|
||||
return "http://" + net.JoinHostPort(ip, "80")
|
||||
}
|
||||
|
||||
func defaultForwardTarget(ip string) string {
|
||||
return net.JoinHostPort(ip, "80")
|
||||
}
|
||||
|
||||
func defaultInterfaceGetter() ([]InterfaceInfo, error) {
|
||||
var interfaces []InterfaceInfo
|
||||
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取网卡列表失败: %w", err)
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(iface.Name, "docker") ||
|
||||
strings.Contains(iface.Name, "veth") ||
|
||||
strings.Contains(iface.Name, "br-") ||
|
||||
strings.Contains(iface.Name, "tun") ||
|
||||
strings.Contains(iface.Name, "tap") ||
|
||||
strings.HasPrefix(iface.Name, "lo") {
|
||||
continue
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipNet.IP.To4() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
interfaces = append(interfaces, InterfaceInfo{
|
||||
Name: iface.Name,
|
||||
IP: ipNet.IP.String(),
|
||||
Netmask: net.IP(ipNet.Mask).String(),
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func scanTCP(ip string, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
|
||||
ipRange, err := calculateIPRange(ip, netmask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if port <= 0 || port > 65535 {
|
||||
return nil, fmt.Errorf("无效的端口: %d", port)
|
||||
}
|
||||
|
||||
var devices []TCPDevice
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 20)
|
||||
timeout := 2 * time.Second
|
||||
|
||||
current := make(net.IP, len(ipRange.Start))
|
||||
copy(current, ipRange.Start)
|
||||
incrementIP(current)
|
||||
|
||||
for {
|
||||
if current.To4().Equal(ipRange.End.To4()) {
|
||||
break
|
||||
}
|
||||
|
||||
currentIP := current.String()
|
||||
if excludeIPs[currentIP] {
|
||||
incrementIP(current)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(targetIP string) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
if scanTCPPort(targetIP, port, timeout) {
|
||||
mu.Lock()
|
||||
devices = append(devices, TCPDevice{IP: targetIP, Port: port})
|
||||
mu.Unlock()
|
||||
}
|
||||
}(currentIP)
|
||||
|
||||
incrementIP(current)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
sort.Slice(devices, func(i, j int) bool {
|
||||
return ipv4ToUint32(devices[i].IP) < ipv4ToUint32(devices[j].IP)
|
||||
})
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
type ipRange struct {
|
||||
Start net.IP
|
||||
End net.IP
|
||||
}
|
||||
|
||||
func calculateIPRange(ip string, netmask string) (*ipRange, error) {
|
||||
parseIP := net.ParseIP(ip)
|
||||
if parseIP == nil {
|
||||
return nil, fmt.Errorf("无效的IP: %s", ip)
|
||||
}
|
||||
|
||||
mask := net.IPMask(net.ParseIP(netmask).To4())
|
||||
if mask == nil {
|
||||
return nil, fmt.Errorf("无效的子网掩码: %s", netmask)
|
||||
}
|
||||
|
||||
network := &net.IPNet{
|
||||
IP: parseIP.Mask(mask),
|
||||
Mask: mask,
|
||||
}
|
||||
broadcast := make(net.IP, len(network.IP))
|
||||
copy(broadcast, network.IP)
|
||||
for i := 0; i < len(mask); i++ {
|
||||
broadcast[i] |= ^mask[i]
|
||||
}
|
||||
|
||||
return &ipRange{Start: network.IP.To4(), End: broadcast}, nil
|
||||
}
|
||||
|
||||
func scanTCPPort(ip string, port int, timeout time.Duration) bool {
|
||||
if port <= 0 || port > 65535 {
|
||||
return false
|
||||
}
|
||||
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", port))
|
||||
conn, err := net.DialTimeout("tcp", addr, timeout)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func incrementIP(ip net.IP) {
|
||||
for i := len(ip) - 1; i >= 0; i-- {
|
||||
ip[i]++
|
||||
if ip[i] > 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ipv4ToUint32(value string) uint32 {
|
||||
parsed := net.ParseIP(value)
|
||||
if parsed == nil {
|
||||
return 0
|
||||
}
|
||||
ip := parsed.To4()
|
||||
if ip == nil {
|
||||
return 0
|
||||
}
|
||||
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
|
||||
}
|
||||
73
internal/webdevice/service_test.go
Normal file
73
internal/webdevice/service_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsPrivateIPv4Literal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string]bool{
|
||||
"192.168.1.10": true,
|
||||
"10.0.0.8": true,
|
||||
"172.16.5.2": true,
|
||||
"127.0.0.1": false,
|
||||
"8.8.8.8": false,
|
||||
"0.0.0.0": false,
|
||||
"::1": false,
|
||||
"bad-ip": false,
|
||||
}
|
||||
|
||||
for input, want := range cases {
|
||||
if got := IsPrivateIPv4Literal(input); got != want {
|
||||
t.Fatalf("IsPrivateIPv4Literal(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebDeviceForwardPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
port, ok := WebDeviceForwardPort("192.168.1.124")
|
||||
if !ok {
|
||||
t.Fatal("WebDeviceForwardPort() ok = false")
|
||||
}
|
||||
if port != 31124 {
|
||||
t.Fatalf("port = %d, want 31124", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanBuildsDirectURLAndAllowList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService()
|
||||
svc.interfaceGetter = func() ([]InterfaceInfo, error) {
|
||||
return []InterfaceInfo{{
|
||||
Name: "eth0",
|
||||
IP: "10.8.0.14",
|
||||
Netmask: "255.255.255.0",
|
||||
}}, nil
|
||||
}
|
||||
svc.tcpScanner = func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
|
||||
return []TCPDevice{{IP: "192.168.1.124", Port: 80}}, nil
|
||||
}
|
||||
svc.newForwarder = func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
|
||||
return &webDeviceForwarder{ip: ip, port: port, targetAddress: targetAddress}, nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://10.8.0.14:13000/api/web-devices/scan", nil)
|
||||
result, err := svc.Scan(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan() error = %v", err)
|
||||
}
|
||||
if result.Count != 1 {
|
||||
t.Fatalf("result.Count = %d", result.Count)
|
||||
}
|
||||
if !svc.IsAllowed("192.168.1.124") {
|
||||
t.Fatal("expected IP to be allowed after scan")
|
||||
}
|
||||
if result.Devices[0].DirectURL != "http://10.8.0.14:31124/" {
|
||||
t.Fatalf("DirectURL = %q", result.Devices[0].DirectURL)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user