Deploy managed portal with Docker
This commit is contained in:
@@ -18,13 +18,19 @@ type HTTPDoer interface {
|
||||
|
||||
type RemoteClient struct {
|
||||
httpClient HTTPDoer
|
||||
attempts int
|
||||
retryDelay time.Duration
|
||||
}
|
||||
|
||||
func NewRemoteClient(client HTTPDoer) *RemoteClient {
|
||||
if client == nil {
|
||||
client = &http.Client{Timeout: 5 * time.Second}
|
||||
}
|
||||
return &RemoteClient{httpClient: client}
|
||||
return &RemoteClient{
|
||||
httpClient: client,
|
||||
attempts: 5,
|
||||
retryDelay: 200 * time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RemoteClient) GetConfig(ctx context.Context, service Service) (map[string]any, error) {
|
||||
@@ -36,17 +42,18 @@ func (c *RemoteClient) GetConfig(ctx context.Context, service Service) (map[stri
|
||||
}
|
||||
|
||||
func (c *RemoteClient) UpdateRTSP(ctx context.Context, service Service, rtsp string) (map[string]any, error) {
|
||||
body := strings.NewReader(fmt.Sprintf(`{"rtsp_url":%q}`, rtsp))
|
||||
req, err := c.newRequest(ctx, http.MethodPut, service, "/api/manage/config", body)
|
||||
resp, err := c.doRequest(ctx, func() (*http.Request, error) {
|
||||
body := strings.NewReader(fmt.Sprintf(`{"rtsp_url":%q}`, rtsp))
|
||||
req, err := c.newRequest(ctx, http.MethodPut, service, "/api/manage/config", body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request %s %s: %w", req.Method, req.URL.String(), err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
@@ -55,7 +62,7 @@ func (c *RemoteClient) UpdateRTSP(ctx context.Context, service Service, rtsp str
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return nil, fmt.Errorf("decode response %s: %w", req.URL.String(), err)
|
||||
return nil, fmt.Errorf("decode response %s: %w", responseURL(resp, service.APIBaseURL+"/api/manage/config"), err)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
@@ -93,14 +100,13 @@ func (c *RemoteClient) PreviewFile(ctx context.Context, service Service, path st
|
||||
func (c *RemoteClient) Download(ctx context.Context, service Service, path string) (*http.Response, error) {
|
||||
query := url.Values{}
|
||||
query.Set("path", path)
|
||||
req, err := c.newRequest(ctx, http.MethodGet, service, "/api/manage/files/download?"+query.Encode(), nil)
|
||||
|
||||
resp, err := c.doRequest(ctx, func() (*http.Request, error) {
|
||||
return c.newRequest(ctx, http.MethodGet, service, "/api/manage/files/download?"+query.Encode(), nil)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request %s %s: %w", req.Method, req.URL.String(), err)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
defer resp.Body.Close()
|
||||
return nil, decodeAPIError(resp)
|
||||
@@ -109,26 +115,79 @@ func (c *RemoteClient) Download(ctx context.Context, service Service, path strin
|
||||
}
|
||||
|
||||
func (c *RemoteClient) getJSON(ctx context.Context, service Service, endpoint string, target any) error {
|
||||
req, err := c.newRequest(ctx, http.MethodGet, service, endpoint, nil)
|
||||
resp, err := c.doRequest(ctx, func() (*http.Request, error) {
|
||||
return c.newRequest(ctx, http.MethodGet, service, endpoint, nil)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request %s %s: %w", req.Method, req.URL.String(), err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return decodeAPIError(resp)
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(target); err != nil {
|
||||
return fmt.Errorf("decode response %s: %w", req.URL.String(), err)
|
||||
return fmt.Errorf("decode response %s: %w", responseURL(resp, strings.TrimRight(service.APIBaseURL, "/")+endpoint), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseURL(resp *http.Response, fallback string) string {
|
||||
if resp != nil && resp.Request != nil && resp.Request.URL != nil {
|
||||
return resp.Request.URL.String()
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func (c *RemoteClient) doRequest(ctx context.Context, newReq func() (*http.Request, error)) (*http.Response, error) {
|
||||
attempts := c.attempts
|
||||
if attempts <= 0 {
|
||||
attempts = 1
|
||||
}
|
||||
delay := c.retryDelay
|
||||
if delay <= 0 {
|
||||
delay = 100 * time.Millisecond
|
||||
}
|
||||
|
||||
var lastReq *http.Request
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= attempts; attempt++ {
|
||||
req, err := newReq()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lastReq = req
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = err
|
||||
if attempt == attempts {
|
||||
break
|
||||
}
|
||||
if err := sleepWithContext(ctx, delay); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
delay *= 2
|
||||
}
|
||||
|
||||
if lastReq != nil {
|
||||
return nil, fmt.Errorf("request %s %s: %w", lastReq.Method, lastReq.URL.String(), lastErr)
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func sleepWithContext(ctx context.Context, delay time.Duration) error {
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RemoteClient) newRequest(ctx context.Context, method string, service Service, endpoint string, body io.Reader) (*http.Request, error) {
|
||||
base := strings.TrimRight(service.APIBaseURL, "/")
|
||||
req, err := http.NewRequestWithContext(ctx, method, base+endpoint, body)
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type roundTripFunc func(req *http.Request) (*http.Response, error)
|
||||
@@ -139,3 +141,37 @@ func TestRemoteClientRoundTrip(t *testing.T) {
|
||||
t.Fatalf("Content-Disposition = %q", resp.Header.Get("Content-Disposition"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteClientRetriesTransientRequestErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
attempts := 0
|
||||
client := NewRemoteClient(&http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return nil, errors.New("connect: connection refused")
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(`{"config_path":"/srv/store/config/local.yaml"}`)),
|
||||
}, nil
|
||||
})})
|
||||
client.retryDelay = time.Millisecond
|
||||
|
||||
service := Service{
|
||||
ID: "store_dwell_alert",
|
||||
APIBaseURL: "http://managed.invalid/store",
|
||||
}
|
||||
|
||||
configPayload, err := client.GetConfig(context.Background(), service)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConfig() error = %v", err)
|
||||
}
|
||||
if attempts != 3 {
|
||||
t.Fatalf("attempts = %d", attempts)
|
||||
}
|
||||
if got := configPayload["config_path"]; got != "/srv/store/config/local.yaml" {
|
||||
t.Fatalf("config_path = %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,14 @@ type TCPScanner func(ip, netmask string, port int, excludeIPs map[string]bool) (
|
||||
type ForwarderFactory func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error)
|
||||
type ProxyTargetResolver func(ip string) string
|
||||
|
||||
const (
|
||||
webDeviceScanConcurrency = 128
|
||||
webDeviceScanTimeout = 1500 * time.Millisecond
|
||||
webDeviceScanAttempts = 2
|
||||
webDeviceScanRetryDelay = 100 * time.Millisecond
|
||||
maxWebDeviceScanAddrs = 256
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
mu sync.RWMutex
|
||||
allowed map[string]time.Time
|
||||
@@ -74,6 +82,8 @@ func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scheme, host := requestBase(r)
|
||||
interfaces = appendRequestHostInterface(interfaces, host)
|
||||
|
||||
if len(interfaces) == 0 {
|
||||
return &ScanResult{
|
||||
@@ -94,7 +104,6 @@ func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
scheme, host := requestBase(r)
|
||||
for _, iface := range interfaces {
|
||||
devices, scanErr := s.tcpScanner(iface.IP, iface.Netmask, 80, excludeIPs)
|
||||
if scanErr != nil {
|
||||
@@ -135,6 +144,22 @@ func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func appendRequestHostInterface(interfaces []InterfaceInfo, host string) []InterfaceInfo {
|
||||
if !IsPrivateIPv4Literal(host) {
|
||||
return interfaces
|
||||
}
|
||||
for _, iface := range interfaces {
|
||||
if iface.IP == host {
|
||||
return interfaces
|
||||
}
|
||||
}
|
||||
return append(interfaces, InterfaceInfo{
|
||||
Name: "request-host",
|
||||
IP: host,
|
||||
Netmask: "255.255.255.0",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) allowIP(ip string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -390,8 +415,7 @@ func scanTCP(ip string, netmask string, port int, excludeIPs map[string]bool) ([
|
||||
var devices []TCPDevice
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 20)
|
||||
timeout := 2 * time.Second
|
||||
semaphore := make(chan struct{}, webDeviceScanConcurrency)
|
||||
|
||||
current := make(net.IP, len(ipRange.Start))
|
||||
copy(current, ipRange.Start)
|
||||
@@ -414,7 +438,7 @@ func scanTCP(ip string, netmask string, port int, excludeIPs map[string]bool) ([
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
if scanTCPPort(targetIP, port, timeout) {
|
||||
if scanTCPPortWithRetry(targetIP, port) {
|
||||
mu.Lock()
|
||||
devices = append(devices, TCPDevice{IP: targetIP, Port: port})
|
||||
mu.Unlock()
|
||||
@@ -457,7 +481,31 @@ func calculateIPRange(ip string, netmask string) (*ipRange, error) {
|
||||
broadcast[i] |= ^mask[i]
|
||||
}
|
||||
|
||||
return &ipRange{Start: network.IP.To4(), End: broadcast}, nil
|
||||
result := &ipRange{Start: network.IP.To4(), End: broadcast}
|
||||
if ipToUint32(result.End)-ipToUint32(result.Start)+1 <= maxWebDeviceScanAddrs {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
local24Mask := net.CIDRMask(24, 32)
|
||||
local24Network := parseIP.To4().Mask(local24Mask)
|
||||
local24Broadcast := make(net.IP, len(local24Network))
|
||||
copy(local24Broadcast, local24Network)
|
||||
for i := 0; i < len(local24Mask); i++ {
|
||||
local24Broadcast[i] |= ^local24Mask[i]
|
||||
}
|
||||
return &ipRange{Start: local24Network, End: local24Broadcast}, nil
|
||||
}
|
||||
|
||||
func scanTCPPortWithRetry(ip string, port int) bool {
|
||||
for attempt := 0; attempt < webDeviceScanAttempts; attempt++ {
|
||||
if scanTCPPort(ip, port, webDeviceScanTimeout) {
|
||||
return true
|
||||
}
|
||||
if attempt < webDeviceScanAttempts-1 {
|
||||
time.Sleep(webDeviceScanRetryDelay)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func scanTCPPort(ip string, port int, timeout time.Duration) bool {
|
||||
@@ -487,6 +535,10 @@ func ipv4ToUint32(value string) uint32 {
|
||||
if parsed == nil {
|
||||
return 0
|
||||
}
|
||||
return ipToUint32(parsed)
|
||||
}
|
||||
|
||||
func ipToUint32(parsed net.IP) uint32 {
|
||||
ip := parsed.To4()
|
||||
if ip == nil {
|
||||
return 0
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
@@ -71,3 +72,73 @@ func TestScanBuildsDirectURLAndAllowList(t *testing.T) {
|
||||
t.Fatalf("DirectURL = %q", result.Devices[0].DirectURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanIncludesPrivateRequestHostSubnet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService()
|
||||
svc.interfaceGetter = func() ([]InterfaceInfo, error) {
|
||||
return []InterfaceInfo{{
|
||||
Name: "eth0",
|
||||
IP: "172.18.0.2",
|
||||
Netmask: "255.255.0.0",
|
||||
}}, nil
|
||||
}
|
||||
svc.tcpScanner = func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
|
||||
if ip == "192.168.5.189" {
|
||||
if netmask != "255.255.255.0" {
|
||||
t.Fatalf("request host netmask = %q, want 255.255.255.0", netmask)
|
||||
}
|
||||
if !excludeIPs["192.168.5.189"] {
|
||||
t.Fatal("expected request host IP to be excluded")
|
||||
}
|
||||
return []TCPDevice{{IP: "192.168.5.124", Port: 80}}, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
svc.newForwarder = func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
|
||||
return &webDeviceForwarder{ip: ip, port: port, targetAddress: targetAddress}, nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://192.168.5.189:13000/api/web-devices/scan", nil)
|
||||
result, err := svc.Scan(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan() error = %v", err)
|
||||
}
|
||||
if result.Count != 1 {
|
||||
t.Fatalf("result.Count = %d, want 1", result.Count)
|
||||
}
|
||||
if result.Devices[0].Interface != "request-host" {
|
||||
t.Fatalf("Interface = %q, want request-host", result.Devices[0].Interface)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateIPRangeCapsLargeSubnetToLocal24(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipRange, err := calculateIPRange("192.168.5.189", "255.255.0.0")
|
||||
if err != nil {
|
||||
t.Fatalf("calculateIPRange() error = %v", err)
|
||||
}
|
||||
if got := ipRange.Start.String(); got != "192.168.5.0" {
|
||||
t.Fatalf("Start = %q, want 192.168.5.0", got)
|
||||
}
|
||||
if got := ipRange.End.String(); got != "192.168.5.255" {
|
||||
t.Fatalf("End = %q, want 192.168.5.255", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateIPRangeKeepsSmallSubnet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipRange, err := calculateIPRange("192.168.5.189", "255.255.255.240")
|
||||
if err != nil {
|
||||
t.Fatalf("calculateIPRange() error = %v", err)
|
||||
}
|
||||
if got := ipRange.Start.String(); got != "192.168.5.176" {
|
||||
t.Fatalf("Start = %q, want 192.168.5.176", got)
|
||||
}
|
||||
if got := ipRange.End.String(); got != "192.168.5.191" {
|
||||
t.Fatalf("End = %q, want 192.168.5.191", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user