fix: proxy web device resources transparently

This commit is contained in:
Yoilun
2026-05-15 11:12:27 +08:00
parent bd49486304
commit 7498960ba3
4 changed files with 146 additions and 15 deletions

View File

@@ -1,6 +1,7 @@
package webdevice package webdevice
import ( import (
"context"
"errors" "errors"
"io" "io"
"log" "log"
@@ -18,6 +19,13 @@ var (
ErrInvalidProxyURL = errors.New("invalid proxy target") ErrInvalidProxyURL = errors.New("invalid proxy target")
) )
const (
webDeviceProxyConcurrency = 1
webDeviceProxyAttempts = 6
webDeviceProxyRetryDelay = 150 * time.Millisecond
webDeviceProxyQueueWait = 8 * time.Second
)
func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, proxyPath string) error { func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, proxyPath string) error {
if !IsPrivateIPv4Literal(targetIP) { if !IsPrivateIPv4Literal(targetIP) {
return ErrInvalidTargetIP return ErrInvalidTargetIP
@@ -36,6 +44,7 @@ func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, pr
proxyPath = "/" proxyPath = "/"
} }
rawQuery := r.URL.RawQuery rawQuery := r.URL.RawQuery
upgradeRequest := isUpgradeRequest(r.Header)
proxy := &httputil.ReverseProxy{ proxy := &httputil.ReverseProxy{
Transport: retryTransport{ Transport: retryTransport{
@@ -43,8 +52,10 @@ func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, pr
DisableKeepAlives: true, DisableKeepAlives: true,
ResponseHeaderTimeout: 5 * time.Second, ResponseHeaderTimeout: 5 * time.Second,
}, },
attempts: 3, limiter: s.proxyLimiter(targetIP),
delay: 80 * time.Millisecond, attempts: webDeviceProxyAttempts,
delay: webDeviceProxyRetryDelay,
acquireTimeout: webDeviceProxyQueueWait,
}, },
Director: func(req *http.Request) { Director: func(req *http.Request) {
req.URL.Scheme = targetURL.Scheme req.URL.Scheme = targetURL.Scheme
@@ -54,7 +65,7 @@ func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, pr
req.URL.RawQuery = rawQuery req.URL.RawQuery = rawQuery
req.Host = targetURL.Host req.Host = targetURL.Host
req.Header = sanitizeProxyRequestHeader(req.Header, req.URL.Path) req.Header = sanitizeProxyRequestHeader(req.Header, req.URL.Path)
req.Close = true req.Close = !upgradeRequest
}, },
ModifyResponse: func(resp *http.Response) error { ModifyResponse: func(resp *http.Response) error {
proxyPrefix := "/proxy/web/" + targetIP proxyPrefix := "/proxy/web/" + targetIP
@@ -98,9 +109,11 @@ func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, pr
} }
type retryTransport struct { type retryTransport struct {
base http.RoundTripper base http.RoundTripper
attempts int limiter chan struct{}
delay time.Duration attempts int
delay time.Duration
acquireTimeout time.Duration
} }
func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -119,10 +132,14 @@ func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if attempt > 0 { if attempt > 0 {
nextReq = req.Clone(req.Context()) nextReq = req.Clone(req.Context())
nextReq.Header = req.Header.Clone() nextReq.Header = req.Header.Clone()
nextReq.Close = true nextReq.Close = req.Close
} }
if err := acquireProxySlot(req.Context(), t.limiter, t.acquireTimeout); err != nil {
return nil, err
}
resp, err := base.RoundTrip(nextReq) resp, err := base.RoundTrip(nextReq)
releaseProxySlot(t.limiter)
if err == nil { if err == nil {
return resp, nil return resp, nil
} }
@@ -130,11 +147,43 @@ func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if !shouldRetryProxyRequest(req, err) || attempt == attempts-1 { if !shouldRetryProxyRequest(req, err) || attempt == attempts-1 {
return nil, err return nil, err
} }
time.Sleep(t.delay) time.Sleep(t.delay * time.Duration(attempt+1))
} }
return nil, lastErr return nil, lastErr
} }
func acquireProxySlot(ctx context.Context, limiter chan struct{}, timeout time.Duration) error {
if limiter == nil {
return nil
}
if timeout <= 0 {
select {
case limiter <- struct{}{}:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case limiter <- struct{}{}:
return nil
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return context.DeadlineExceeded
}
}
func releaseProxySlot(limiter chan struct{}) {
if limiter == nil {
return
}
<-limiter
}
func shouldRetryProxyRequest(req *http.Request, err error) bool { func shouldRetryProxyRequest(req *http.Request, err error) bool {
if req == nil || err == nil { if req == nil || err == nil {
return false return false
@@ -151,19 +200,25 @@ func shouldRetryProxyRequest(req *http.Request, err error) bool {
} }
func sanitizeProxyRequestHeader(source http.Header, upstreamPath string) http.Header { func sanitizeProxyRequestHeader(source http.Header, upstreamPath string) http.Header {
upgradeRequest := isUpgradeRequest(source)
header := source.Clone() header := source.Clone()
for key := range header { for key := range header {
if isProxyManagedHeader(key) { if isProxyManagedHeader(key, upgradeRequest) {
header.Del(key) header.Del(key)
} }
} }
header.Del("Accept-Encoding")
userAgent := strings.TrimSpace(source.Get("User-Agent")) userAgent := strings.TrimSpace(source.Get("User-Agent"))
if userAgent == "" { if userAgent == "" {
userAgent = "Mozilla/5.0" userAgent = "Mozilla/5.0"
} }
header.Set("User-Agent", userAgent) header.Set("User-Agent", userAgent)
header.Set("Connection", "close") if upgradeRequest {
header.Set("Connection", "Upgrade")
} else {
header.Set("Connection", "close")
}
if !isLoginPagePath(upstreamPath) { if !isLoginPagePath(upstreamPath) {
return header return header
@@ -173,13 +228,12 @@ func sanitizeProxyRequestHeader(source http.Header, upstreamPath string) http.He
return header return header
} }
func isProxyManagedHeader(key string) bool { func isProxyManagedHeader(key string, upgradeRequest bool) bool {
switch http.CanonicalHeaderKey(key) { switch http.CanonicalHeaderKey(key) {
case "Connection", case "Connection",
"Proxy-Connection", "Proxy-Connection",
"Keep-Alive", "Keep-Alive",
"Transfer-Encoding", "Transfer-Encoding",
"Upgrade",
"Te", "Te",
"Trailer", "Trailer",
"Proxy-Authenticate", "Proxy-Authenticate",
@@ -190,11 +244,21 @@ func isProxyManagedHeader(key string) bool {
"X-Forwarded-Proto", "X-Forwarded-Proto",
"X-Real-Ip": "X-Real-Ip":
return true return true
case "Upgrade":
return !upgradeRequest
default: default:
return false return false
} }
} }
func isUpgradeRequest(header http.Header) bool {
if header == nil {
return false
}
connection := strings.ToLower(header.Get("Connection"))
return header.Get("Upgrade") != "" && strings.Contains(connection, "upgrade")
}
func isLoginPagePath(path string) bool { func isLoginPagePath(path string) bool {
path = strings.ToLower(path) path = strings.ToLower(path)
return strings.HasSuffix(path, "/doc/page/login.asp") || path == "/doc/page/login.asp" return strings.HasSuffix(path, "/doc/page/login.asp") || path == "/doc/page/login.asp"

View File

@@ -46,6 +46,24 @@ func TestRewriteText(t *testing.T) {
} }
} }
func TestRewriteTextRewritesScriptAbsoluteURLs(t *testing.T) {
t.Parallel()
targetURL, _ := url.Parse("http://192.168.1.124:80")
body := `window.location.href="/doc/page/login.asp";fetch("/ISAPI/System/time")`
got := RewriteText(body, "/proxy/web/192.168.1.124", targetURL, "application/javascript")
if !strings.Contains(got, `"/proxy/web/192.168.1.124/doc/page/login.asp"`) {
t.Fatalf("script missing rewritten login URL: %s", got)
}
if !strings.Contains(got, `"/proxy/web/192.168.1.124/ISAPI/System/time"`) {
t.Fatalf("script missing rewritten API URL: %s", got)
}
if strings.Contains(got, "data-web-proxy-runtime") {
t.Fatalf("script should not receive HTML runtime: %s", got)
}
}
func TestProxyHTTPRejectsUnscannedIP(t *testing.T) { func TestProxyHTTPRejectsUnscannedIP(t *testing.T) {
t.Parallel() t.Parallel()
@@ -137,6 +155,7 @@ func TestSanitizeProxyRequestHeaderDropsLoginCookie(t *testing.T) {
source.Set("Referer", "http://10.8.0.18:13000/proxy/web/192.168.0.108/") source.Set("Referer", "http://10.8.0.18:13000/proxy/web/192.168.0.108/")
source.Set("Sessiontag", "abc123") source.Set("Sessiontag", "abc123")
source.Set("If-Modified-Since", "0") source.Set("If-Modified-Since", "0")
source.Set("Accept-Encoding", "gzip, deflate")
source.Set("X-Forwarded-For", "10.8.0.1") source.Set("X-Forwarded-For", "10.8.0.1")
loginHeader := sanitizeProxyRequestHeader(source, "/doc/page/login.asp") loginHeader := sanitizeProxyRequestHeader(source, "/doc/page/login.asp")
@@ -149,6 +168,9 @@ func TestSanitizeProxyRequestHeaderDropsLoginCookie(t *testing.T) {
if got := loginHeader.Get("X-Forwarded-For"); got != "" { if got := loginHeader.Get("X-Forwarded-For"); got != "" {
t.Fatalf("login X-Forwarded-For = %q, want empty", got) t.Fatalf("login X-Forwarded-For = %q, want empty", got)
} }
if got := loginHeader.Get("Accept-Encoding"); got != "" {
t.Fatalf("login Accept-Encoding = %q, want empty", got)
}
if got := loginHeader.Get("Sessiontag"); got != "abc123" { if got := loginHeader.Get("Sessiontag"); got != "abc123" {
t.Fatalf("login Sessiontag = %q, want abc123", got) t.Fatalf("login Sessiontag = %q, want abc123", got)
} }
@@ -164,3 +186,27 @@ func TestSanitizeProxyRequestHeaderDropsLoginCookie(t *testing.T) {
t.Fatalf("api If-Modified-Since = %q, want 0", got) t.Fatalf("api If-Modified-Since = %q, want 0", got)
} }
} }
func TestSanitizeProxyRequestHeaderPreservesWebSocketUpgrade(t *testing.T) {
source := http.Header{}
source.Set("User-Agent", "browser")
source.Set("Connection", "keep-alive, Upgrade")
source.Set("Upgrade", "websocket")
source.Set("Sec-Websocket-Key", "abc")
source.Set("Sec-Websocket-Version", "13")
source.Set("Accept-Encoding", "gzip")
header := sanitizeProxyRequestHeader(source, "/webSocket")
if got := header.Get("Connection"); got != "Upgrade" {
t.Fatalf("Connection = %q, want Upgrade", got)
}
if got := header.Get("Upgrade"); got != "websocket" {
t.Fatalf("Upgrade = %q, want websocket", got)
}
if got := header.Get("Sec-Websocket-Key"); got != "abc" {
t.Fatalf("Sec-Websocket-Key = %q, want abc", got)
}
if got := header.Get("Accept-Encoding"); got != "" {
t.Fatalf("Accept-Encoding = %q, want empty", got)
}
}

View File

@@ -76,13 +76,16 @@ var (
func ShouldRewriteBody(contentType string) bool { func ShouldRewriteBody(contentType string) bool {
contentType = strings.ToLower(contentType) contentType = strings.ToLower(contentType)
return strings.Contains(contentType, "text/html") || return strings.Contains(contentType, "text/html") ||
strings.Contains(contentType, "text/css") strings.Contains(contentType, "text/css") ||
strings.Contains(contentType, "javascript")
} }
func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType string) string { func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType string) string {
contentType = strings.ToLower(contentType) contentType = strings.ToLower(contentType)
isHTML := strings.Contains(contentType, "text/html")
isScript := strings.Contains(contentType, "javascript")
if strings.Contains(contentType, "text/html") { if isHTML {
body = webProxyQuotedAttrPattern.ReplaceAllStringFunc(body, func(match string) string { body = webProxyQuotedAttrPattern.ReplaceAllStringFunc(body, func(match string) string {
parts := webProxyQuotedAttrPattern.FindStringSubmatch(match) parts := webProxyQuotedAttrPattern.FindStringSubmatch(match)
if len(parts) != 4 { if len(parts) != 4 {
@@ -100,7 +103,9 @@ func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType strin
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL) rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
return strings.Replace(match, parts[2], rewritten, 1) return strings.Replace(match, parts[2], rewritten, 1)
}) })
}
if isHTML || isScript {
body = webProxyQuotedURLPattern.ReplaceAllStringFunc(body, func(match string) string { body = webProxyQuotedURLPattern.ReplaceAllStringFunc(body, func(match string) string {
parts := webProxyQuotedURLPattern.FindStringSubmatch(match) parts := webProxyQuotedURLPattern.FindStringSubmatch(match)
if len(parts) != 3 { if len(parts) != 3 {
@@ -120,7 +125,7 @@ func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType strin
return "url(" + parts[1] + rewritten + parts[1] + ")" return "url(" + parts[1] + rewritten + parts[1] + ")"
}) })
if strings.Contains(contentType, "text/html") { if isHTML {
body = injectRuntime(body, proxyPrefix) body = injectRuntime(body, proxyPrefix)
} }

View File

@@ -61,6 +61,7 @@ type Service struct {
mu sync.RWMutex mu sync.RWMutex
allowed map[string]time.Time allowed map[string]time.Time
forwarders map[string]*webDeviceForwarder forwarders map[string]*webDeviceForwarder
proxyLimiters map[string]chan struct{}
interfaceGetter InterfaceGetter interfaceGetter InterfaceGetter
hostLANGetter InterfaceGetter hostLANGetter InterfaceGetter
tcpScanner TCPScanner tcpScanner TCPScanner
@@ -73,6 +74,7 @@ func NewService() *Service {
return &Service{ return &Service{
allowed: make(map[string]time.Time), allowed: make(map[string]time.Time),
forwarders: make(map[string]*webDeviceForwarder), forwarders: make(map[string]*webDeviceForwarder),
proxyLimiters: make(map[string]chan struct{}),
interfaceGetter: defaultInterfaceGetter, interfaceGetter: defaultInterfaceGetter,
hostLANGetter: defaultHostLANInterfaceGetter, hostLANGetter: defaultHostLANInterfaceGetter,
tcpScanner: scanTCP, tcpScanner: scanTCP,
@@ -256,6 +258,20 @@ func (s *Service) ProxyTargetURL(ip string) string {
return s.proxyTarget(ip) return s.proxyTarget(ip)
} }
func (s *Service) proxyLimiter(ip string) chan struct{} {
s.mu.Lock()
defer s.mu.Unlock()
if s.proxyLimiters == nil {
s.proxyLimiters = make(map[string]chan struct{})
}
limiter := s.proxyLimiters[ip]
if limiter == nil {
limiter = make(chan struct{}, webDeviceProxyConcurrency)
s.proxyLimiters[ip] = limiter
}
return limiter
}
func IsPrivateIPv4Literal(value string) bool { func IsPrivateIPv4Literal(value string) bool {
ip := net.ParseIP(value) ip := net.ParseIP(value)
if ip == nil { if ip == nil {