diff --git a/internal/webdevice/proxy.go b/internal/webdevice/proxy.go index 53456d9..20981ed 100644 --- a/internal/webdevice/proxy.go +++ b/internal/webdevice/proxy.go @@ -1,6 +1,7 @@ package webdevice import ( + "context" "errors" "io" "log" @@ -18,6 +19,13 @@ var ( ErrInvalidProxyURL = errors.New("invalid proxy target") ) +const ( + webDeviceProxyConcurrency = 1 + webDeviceProxyAttempts = 6 + webDeviceProxyRetryDelay = 150 * time.Millisecond + webDeviceProxyQueueWait = 8 * time.Second +) + func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, proxyPath string) error { if !IsPrivateIPv4Literal(targetIP) { return ErrInvalidTargetIP @@ -36,6 +44,7 @@ func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, pr proxyPath = "/" } rawQuery := r.URL.RawQuery + upgradeRequest := isUpgradeRequest(r.Header) proxy := &httputil.ReverseProxy{ Transport: retryTransport{ @@ -43,8 +52,10 @@ func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, pr DisableKeepAlives: true, ResponseHeaderTimeout: 5 * time.Second, }, - attempts: 3, - delay: 80 * time.Millisecond, + limiter: s.proxyLimiter(targetIP), + attempts: webDeviceProxyAttempts, + delay: webDeviceProxyRetryDelay, + acquireTimeout: webDeviceProxyQueueWait, }, Director: func(req *http.Request) { req.URL.Scheme = targetURL.Scheme @@ -54,7 +65,7 @@ func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, pr req.URL.RawQuery = rawQuery req.Host = targetURL.Host req.Header = sanitizeProxyRequestHeader(req.Header, req.URL.Path) - req.Close = true + req.Close = !upgradeRequest }, ModifyResponse: func(resp *http.Response) error { proxyPrefix := "/proxy/web/" + targetIP @@ -98,9 +109,11 @@ func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, pr } type retryTransport struct { - base http.RoundTripper - attempts int - delay time.Duration + base http.RoundTripper + limiter chan struct{} + attempts int + delay time.Duration + acquireTimeout time.Duration } func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) { @@ -119,10 +132,14 @@ func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) { if attempt > 0 { nextReq = req.Clone(req.Context()) nextReq.Header = req.Header.Clone() - nextReq.Close = true + nextReq.Close = req.Close } + if err := acquireProxySlot(req.Context(), t.limiter, t.acquireTimeout); err != nil { + return nil, err + } resp, err := base.RoundTrip(nextReq) + releaseProxySlot(t.limiter) if err == nil { return resp, nil } @@ -130,11 +147,43 @@ func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) { if !shouldRetryProxyRequest(req, err) || attempt == attempts-1 { return nil, err } - time.Sleep(t.delay) + time.Sleep(t.delay * time.Duration(attempt+1)) } return nil, lastErr } +func acquireProxySlot(ctx context.Context, limiter chan struct{}, timeout time.Duration) error { + if limiter == nil { + return nil + } + if timeout <= 0 { + select { + case limiter <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case limiter <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return context.DeadlineExceeded + } +} + +func releaseProxySlot(limiter chan struct{}) { + if limiter == nil { + return + } + <-limiter +} + func shouldRetryProxyRequest(req *http.Request, err error) bool { if req == nil || err == nil { return false @@ -151,19 +200,25 @@ func shouldRetryProxyRequest(req *http.Request, err error) bool { } func sanitizeProxyRequestHeader(source http.Header, upstreamPath string) http.Header { + upgradeRequest := isUpgradeRequest(source) header := source.Clone() for key := range header { - if isProxyManagedHeader(key) { + if isProxyManagedHeader(key, upgradeRequest) { header.Del(key) } } + header.Del("Accept-Encoding") userAgent := strings.TrimSpace(source.Get("User-Agent")) if userAgent == "" { userAgent = "Mozilla/5.0" } header.Set("User-Agent", userAgent) - header.Set("Connection", "close") + if upgradeRequest { + header.Set("Connection", "Upgrade") + } else { + header.Set("Connection", "close") + } if !isLoginPagePath(upstreamPath) { return header @@ -173,13 +228,12 @@ func sanitizeProxyRequestHeader(source http.Header, upstreamPath string) http.He return header } -func isProxyManagedHeader(key string) bool { +func isProxyManagedHeader(key string, upgradeRequest bool) bool { switch http.CanonicalHeaderKey(key) { case "Connection", "Proxy-Connection", "Keep-Alive", "Transfer-Encoding", - "Upgrade", "Te", "Trailer", "Proxy-Authenticate", @@ -190,11 +244,21 @@ func isProxyManagedHeader(key string) bool { "X-Forwarded-Proto", "X-Real-Ip": return true + case "Upgrade": + return !upgradeRequest default: return false } } +func isUpgradeRequest(header http.Header) bool { + if header == nil { + return false + } + connection := strings.ToLower(header.Get("Connection")) + return header.Get("Upgrade") != "" && strings.Contains(connection, "upgrade") +} + func isLoginPagePath(path string) bool { path = strings.ToLower(path) return strings.HasSuffix(path, "/doc/page/login.asp") || path == "/doc/page/login.asp" diff --git a/internal/webdevice/proxy_test.go b/internal/webdevice/proxy_test.go index f8954e3..51c9239 100644 --- a/internal/webdevice/proxy_test.go +++ b/internal/webdevice/proxy_test.go @@ -46,6 +46,24 @@ func TestRewriteText(t *testing.T) { } } +func TestRewriteTextRewritesScriptAbsoluteURLs(t *testing.T) { + t.Parallel() + + targetURL, _ := url.Parse("http://192.168.1.124:80") + body := `window.location.href="/doc/page/login.asp";fetch("/ISAPI/System/time")` + got := RewriteText(body, "/proxy/web/192.168.1.124", targetURL, "application/javascript") + + if !strings.Contains(got, `"/proxy/web/192.168.1.124/doc/page/login.asp"`) { + t.Fatalf("script missing rewritten login URL: %s", got) + } + if !strings.Contains(got, `"/proxy/web/192.168.1.124/ISAPI/System/time"`) { + t.Fatalf("script missing rewritten API URL: %s", got) + } + if strings.Contains(got, "data-web-proxy-runtime") { + t.Fatalf("script should not receive HTML runtime: %s", got) + } +} + func TestProxyHTTPRejectsUnscannedIP(t *testing.T) { t.Parallel() @@ -137,6 +155,7 @@ func TestSanitizeProxyRequestHeaderDropsLoginCookie(t *testing.T) { source.Set("Referer", "http://10.8.0.18:13000/proxy/web/192.168.0.108/") source.Set("Sessiontag", "abc123") source.Set("If-Modified-Since", "0") + source.Set("Accept-Encoding", "gzip, deflate") source.Set("X-Forwarded-For", "10.8.0.1") loginHeader := sanitizeProxyRequestHeader(source, "/doc/page/login.asp") @@ -149,6 +168,9 @@ func TestSanitizeProxyRequestHeaderDropsLoginCookie(t *testing.T) { if got := loginHeader.Get("X-Forwarded-For"); got != "" { t.Fatalf("login X-Forwarded-For = %q, want empty", got) } + if got := loginHeader.Get("Accept-Encoding"); got != "" { + t.Fatalf("login Accept-Encoding = %q, want empty", got) + } if got := loginHeader.Get("Sessiontag"); got != "abc123" { t.Fatalf("login Sessiontag = %q, want abc123", got) } @@ -164,3 +186,27 @@ func TestSanitizeProxyRequestHeaderDropsLoginCookie(t *testing.T) { t.Fatalf("api If-Modified-Since = %q, want 0", got) } } + +func TestSanitizeProxyRequestHeaderPreservesWebSocketUpgrade(t *testing.T) { + source := http.Header{} + source.Set("User-Agent", "browser") + source.Set("Connection", "keep-alive, Upgrade") + source.Set("Upgrade", "websocket") + source.Set("Sec-Websocket-Key", "abc") + source.Set("Sec-Websocket-Version", "13") + source.Set("Accept-Encoding", "gzip") + + header := sanitizeProxyRequestHeader(source, "/webSocket") + if got := header.Get("Connection"); got != "Upgrade" { + t.Fatalf("Connection = %q, want Upgrade", got) + } + if got := header.Get("Upgrade"); got != "websocket" { + t.Fatalf("Upgrade = %q, want websocket", got) + } + if got := header.Get("Sec-Websocket-Key"); got != "abc" { + t.Fatalf("Sec-Websocket-Key = %q, want abc", got) + } + if got := header.Get("Accept-Encoding"); got != "" { + t.Fatalf("Accept-Encoding = %q, want empty", got) + } +} diff --git a/internal/webdevice/rewrite.go b/internal/webdevice/rewrite.go index 95fabb2..a3aafbf 100644 --- a/internal/webdevice/rewrite.go +++ b/internal/webdevice/rewrite.go @@ -76,13 +76,16 @@ var ( func ShouldRewriteBody(contentType string) bool { contentType = strings.ToLower(contentType) return strings.Contains(contentType, "text/html") || - strings.Contains(contentType, "text/css") + strings.Contains(contentType, "text/css") || + strings.Contains(contentType, "javascript") } func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType string) string { contentType = strings.ToLower(contentType) + isHTML := strings.Contains(contentType, "text/html") + isScript := strings.Contains(contentType, "javascript") - if strings.Contains(contentType, "text/html") { + if isHTML { body = webProxyQuotedAttrPattern.ReplaceAllStringFunc(body, func(match string) string { parts := webProxyQuotedAttrPattern.FindStringSubmatch(match) if len(parts) != 4 { @@ -100,7 +103,9 @@ func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType strin rewritten := rewriteURL(parts[2], proxyPrefix, targetURL) return strings.Replace(match, parts[2], rewritten, 1) }) + } + if isHTML || isScript { body = webProxyQuotedURLPattern.ReplaceAllStringFunc(body, func(match string) string { parts := webProxyQuotedURLPattern.FindStringSubmatch(match) if len(parts) != 3 { @@ -120,7 +125,7 @@ func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType strin return "url(" + parts[1] + rewritten + parts[1] + ")" }) - if strings.Contains(contentType, "text/html") { + if isHTML { body = injectRuntime(body, proxyPrefix) } diff --git a/internal/webdevice/service.go b/internal/webdevice/service.go index 6428a01..331f960 100644 --- a/internal/webdevice/service.go +++ b/internal/webdevice/service.go @@ -61,6 +61,7 @@ type Service struct { mu sync.RWMutex allowed map[string]time.Time forwarders map[string]*webDeviceForwarder + proxyLimiters map[string]chan struct{} interfaceGetter InterfaceGetter hostLANGetter InterfaceGetter tcpScanner TCPScanner @@ -73,6 +74,7 @@ func NewService() *Service { return &Service{ allowed: make(map[string]time.Time), forwarders: make(map[string]*webDeviceForwarder), + proxyLimiters: make(map[string]chan struct{}), interfaceGetter: defaultInterfaceGetter, hostLANGetter: defaultHostLANInterfaceGetter, tcpScanner: scanTCP, @@ -256,6 +258,20 @@ func (s *Service) ProxyTargetURL(ip string) string { return s.proxyTarget(ip) } +func (s *Service) proxyLimiter(ip string) chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + if s.proxyLimiters == nil { + s.proxyLimiters = make(map[string]chan struct{}) + } + limiter := s.proxyLimiters[ip] + if limiter == nil { + limiter = make(chan struct{}, webDeviceProxyConcurrency) + s.proxyLimiters[ip] = limiter + } + return limiter +} + func IsPrivateIPv4Literal(value string) bool { ip := net.ParseIP(value) if ip == nil {