Files
managed-portal/internal/webdevice/proxy.go
2026-05-15 11:18:54 +08:00

309 lines
7.3 KiB
Go

package webdevice
import (
"context"
"errors"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"sync"
"time"
)
// ErrInvalidTargetIP is returned by ProxyHTTP when the target is not a
// private IPv4 literal.
var ErrInvalidTargetIP = errors.New("invalid target ip")

// ErrTargetNotAllowed is returned by ProxyHTTP when the service does not
// permit proxying to the target.
var ErrTargetNotAllowed = errors.New("target ip not allowed")

// ErrInvalidProxyURL is returned by ProxyHTTP when the upstream URL built for
// the target does not parse or lacks a scheme or host.
var ErrInvalidProxyURL = errors.New("invalid proxy target")
// Tuning knobs for proxied requests to web devices.
const (
	// webDeviceProxyConcurrency presumably sizes the per-device limiter
	// returned by proxyLimiter (defined outside this file section) — confirm.
	webDeviceProxyConcurrency = 1

	// webDeviceProxyAttempts is the total number of tries retryTransport
	// makes for a single upstream request.
	webDeviceProxyAttempts = 6

	// webDeviceProxyRetryDelay is the back-off unit: retry n waits n times it.
	webDeviceProxyRetryDelay = 150 * time.Millisecond

	// webDeviceProxyQueueWait bounds how long a request waits for a free
	// limiter slot before giving up.
	webDeviceProxyQueueWait = 8 * time.Second
)
// ProxyHTTP reverse-proxies r to the web interface of the device at targetIP.
// The request goes to proxyPath (default "/") joined onto the path of the base
// URL returned by s.ProxyTargetURL, and r's raw query string is preserved.
//
// It returns ErrInvalidTargetIP when targetIP is not a private IPv4 literal,
// ErrTargetNotAllowed when s.IsAllowed rejects it, and ErrInvalidProxyURL when
// the base URL does not parse or lacks a scheme or host. Nothing is written to
// w in those cases. Otherwise the proxied response is written to w and nil is
// returned. Upstream failures are answered with 502 Bad Gateway by the proxy's
// ErrorHandler instead of being returned.
//
// Responses are adapted to live under the "/proxy/web/<targetIP>" prefix:
//   - Location is rewritten with RewriteLocation.
//   - Each Set-Cookie is rewritten with RewriteSetCookie.
//   - Bodies that ShouldRewriteBody selects are buffered, passed through
//     RewriteText and re-framed with a fresh Content-Length. This applies
//     only when a body actually exists and is not content-encoded (see
//     proxyBodyRewritable).
func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, proxyPath string) error {
	if !IsPrivateIPv4Literal(targetIP) {
		return ErrInvalidTargetIP
	}
	if !s.IsAllowed(targetIP) {
		return ErrTargetNotAllowed
	}
	rawTarget := s.ProxyTargetURL(targetIP)
	targetURL, err := url.Parse(rawTarget)
	if err != nil || targetURL.Scheme == "" || targetURL.Host == "" {
		return ErrInvalidProxyURL
	}
	if proxyPath == "" {
		proxyPath = "/"
	}
	rawQuery := r.URL.RawQuery
	upgradeRequest := isUpgradeRequest(r.Header)
	proxyPrefix := "/proxy/web/" + targetIP
	proxy := &httputil.ReverseProxy{
		Transport: retryTransport{
			// One transport per call with keep-alives disabled, so each
			// upstream attempt uses its own connection for a single request.
			base: &http.Transport{
				DisableKeepAlives:     true,
				DisableCompression:    true,
				ResponseHeaderTimeout: 5 * time.Second,
			},
			limiter:        s.proxyLimiter(targetIP),
			attempts:       webDeviceProxyAttempts,
			delay:          webDeviceProxyRetryDelay,
			acquireTimeout: webDeviceProxyQueueWait,
		},
		Director: func(req *http.Request) {
			req.URL.Scheme = targetURL.Scheme
			req.URL.Host = targetURL.Host
			req.URL.Path = JoinProxyTargetPath(targetURL.Path, proxyPath)
			req.URL.RawPath = ""
			req.URL.RawQuery = rawQuery
			req.Host = targetURL.Host
			req.Header = sanitizeProxyRequestHeader(req.Header, req.URL.Path)
			req.Close = !upgradeRequest
		},
		ModifyResponse: func(resp *http.Response) error {
			if location := resp.Header.Get("Location"); location != "" {
				resp.Header.Set("Location", RewriteLocation(targetIP, targetURL, location))
			}
			if cookies := resp.Header.Values("Set-Cookie"); len(cookies) > 0 {
				resp.Header.Del("Set-Cookie")
				for _, cookie := range cookies {
					resp.Header.Add("Set-Cookie", RewriteSetCookie(cookie, proxyPrefix))
				}
			}
			contentType := resp.Header.Get("Content-Type")
			if !ShouldRewriteBody(contentType) || !proxyBodyRewritable(resp, r.Method) {
				return nil
			}
			body, err := io.ReadAll(resp.Body)
			if err != nil {
				// ReverseProxy hands a ModifyResponse error to ErrorHandler.
				return err
			}
			// The body has been read completely; a close error cannot affect
			// what is sent to the client.
			_ = resp.Body.Close()
			rewritten := RewriteText(string(body), proxyPrefix, targetURL, contentType)
			resp.Body = io.NopCloser(strings.NewReader(rewritten))
			resp.ContentLength = int64(len(rewritten))
			resp.Header.Set("Content-Length", strconv.Itoa(len(rewritten)))
			// The bytes changed, so encoding markers and validators computed
			// over the upstream body no longer apply.
			resp.Header.Del("Content-Encoding")
			resp.Header.Del("Content-MD5")
			resp.Header.Del("Etag")
			return nil
		},
		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
			log.Printf("web device proxy failed target=%s path=%s error=%v", targetURL.String(), r.URL.RequestURI(), err)
			http.Error(w, "代理访问失败: "+err.Error(), http.StatusBadGateway)
		},
	}
	proxy.ServeHTTP(closeNotifyWriter{ResponseWriter: w}, r)
	return nil
}

// proxyBodyRewritable reports whether resp has a body that ProxyHTTP may
// buffer and rewrite as text. It returns false in three cases:
//   - HEAD responses, and 1xx, 204 and 304 responses, carry no body.
//     Rewriting one would replace the upstream Content-Length with 0.
//   - A 101 Switching Protocols body is the upgraded connection itself, so
//     reading it to EOF would block for the lifetime of the tunnel.
//   - A body with a Content-Encoding other than identity is not plain text.
func proxyBodyRewritable(resp *http.Response, method string) bool {
	if method == http.MethodHead {
		return false
	}
	if resp.StatusCode < http.StatusOK ||
		resp.StatusCode == http.StatusNoContent ||
		resp.StatusCode == http.StatusNotModified {
		return false
	}
	encoding := strings.TrimSpace(resp.Header.Get("Content-Encoding"))
	return encoding == "" || strings.EqualFold(encoding, "identity")
}
// retryTransport is an http.RoundTripper that admits upstream requests
// through a limiter channel and retries failed attempts; see RoundTrip.
type retryTransport struct {
	// base performs the actual round trip; http.DefaultTransport when nil.
	base http.RoundTripper

	// limiter works as a semaphore: a send takes a slot, a receive frees it.
	// A nil limiter disables admission control.
	limiter chan struct{}

	// attempts is the total number of tries; values below 1 are treated as 1.
	attempts int

	// delay is the back-off unit; retry n waits n*delay first.
	delay time.Duration

	// acquireTimeout bounds the wait for a limiter slot; <= 0 means unbounded.
	acquireTimeout time.Duration
}
// RoundTrip implements http.RoundTripper with per-target admission control
// and retries.
//
// Each attempt first takes a slot in t.limiter, waiting at most
// t.acquireTimeout (see acquireProxySlot). It then delegates to t.base, or to
// http.DefaultTransport when base is nil.
//
// On success the slot stays held until the caller closes the response body.
// Two cases release it immediately instead: responses without a body, and
// 101 Switching Protocols.
//
// A failed attempt releases its slot. If shouldRetryProxyRequest approves and
// tries remain (t.attempts in total, at least 1), the next attempt follows a
// back-off of t.delay times the attempt number. If req's context is done
// during the back-off, RoundTrip stops and returns the context's error.
func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	base := t.base
	if base == nil {
		base = http.DefaultTransport
	}
	attempts := t.attempts
	if attempts < 1 {
		attempts = 1
	}
	ctx := req.Context()
	var lastErr error
	for attempt := 0; attempt < attempts; attempt++ {
		nextReq := req
		if attempt > 0 {
			// Retries send a copy. Clone already deep-copies Header and
			// carries over Close. Body is copied only shallowly, so the retry
			// reuses the same Body reader.
			nextReq = req.Clone(ctx)
		}
		if err := acquireProxySlot(ctx, t.limiter, t.acquireTimeout); err != nil {
			return nil, err
		}
		resp, err := base.RoundTrip(nextReq)
		if err == nil {
			switch {
			case resp == nil || resp.Body == nil:
				releaseProxySlot(t.limiter)
			case resp.StatusCode == http.StatusSwitchingProtocols:
				// For 101 the body is the upgraded connection, and
				// httputil.ReverseProxy requires it to be an io.ReadWriteCloser.
				// Wrapping it in the read-only releaser would make every
				// upgrade fail. The tunnel may also stay open indefinitely,
				// so it must not pin the target's slot and starve ordinary
				// requests.
				releaseProxySlot(t.limiter)
			default:
				resp.Body = &releaseOnCloseReadCloser{
					ReadCloser: resp.Body,
					release: func() {
						releaseProxySlot(t.limiter)
					},
				}
			}
			return resp, nil
		}
		releaseProxySlot(t.limiter)
		lastErr = err
		if !shouldRetryProxyRequest(req, err) || attempt == attempts-1 {
			return nil, err
		}
		// Back off, but stop waiting as soon as the client goes away.
		backoff := time.NewTimer(t.delay * time.Duration(attempt+1))
		select {
		case <-backoff.C:
		case <-ctx.Done():
			backoff.Stop()
			return nil, ctx.Err()
		}
	}
	return nil, lastErr
}
// releaseOnCloseReadCloser decorates a ReadCloser so that closing it also runs
// a release callback, at most once no matter how often Close is called.
type releaseOnCloseReadCloser struct {
	// ReadCloser is the wrapped stream; reads pass straight through to it.
	io.ReadCloser
	// once ensures release runs only on the first Close.
	once sync.Once
	// release is the callback run by the first Close; nil means nothing to run.
	release func()
}
// Close closes the wrapped stream and then, on the first call only, runs the
// release callback if one is set. The wrapped stream's Close error is
// returned on every call.
func (r *releaseOnCloseReadCloser) Close() error {
	closeErr := r.ReadCloser.Close()
	r.once.Do(func() {
		if fn := r.release; fn != nil {
			fn()
		}
	})
	return closeErr
}
// acquireProxySlot blocks until it can place a token in limiter, ctx is done,
// or timeout elapses, whichever happens first.
//
// A nil limiter means no limit, and nil is returned at once. A non-positive
// timeout waits with no deadline. The error is ctx.Err() on cancellation and
// context.DeadlineExceeded when the timeout fires first.
func acquireProxySlot(ctx context.Context, limiter chan struct{}, timeout time.Duration) error {
	if limiter == nil {
		return nil
	}
	// Receiving from a nil channel blocks forever, so leaving deadline nil
	// turns the timeout case below into a no-op.
	var deadline <-chan time.Time
	if timeout > 0 {
		timer := time.NewTimer(timeout)
		defer timer.Stop()
		deadline = timer.C
	}
	select {
	case limiter <- struct{}{}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	case <-deadline:
		return context.DeadlineExceeded
	}
}
// releaseProxySlot frees one slot previously taken by acquireProxySlot by
// receiving a token from limiter. A nil limiter is a no-op.
func releaseProxySlot(limiter chan struct{}) {
	if limiter != nil {
		<-limiter
	}
}
// shouldRetryProxyRequest reports whether a failed upstream attempt for req
// may be retried. All of the following must hold:
//   - req and err are non-nil.
//   - The method is GET, HEAD or OPTIONS.
//   - req has no body. retryTransport resends a shallow Clone whose Body is the
//     same, already-consumed reader, so a body could not be replayed.
//   - req's context is not done yet.
//   - The error text, compared case-insensitively, mentions "eof",
//     "connection reset" or "broken pipe".
func shouldRetryProxyRequest(req *http.Request, err error) bool {
	if req == nil || err == nil {
		return false
	}
	switch req.Method {
	case http.MethodGet, http.MethodHead, http.MethodOptions:
	default:
		return false
	}
	if req.Body != nil && req.Body != http.NoBody {
		return false
	}
	if req.Context().Err() != nil {
		return false
	}
	message := strings.ToLower(err.Error())
	return strings.Contains(message, "eof") ||
		strings.Contains(message, "connection reset") ||
		strings.Contains(message, "broken pipe")
}
// sanitizeProxyRequestHeader returns a copy of source prepared for the
// upstream device; source itself is left untouched. The copy is changed as
// follows:
//   - Every header that isProxyManagedHeader flags for this request is
//     dropped, and so is Accept-Encoding.
//   - User-Agent is set to the client's trimmed value, or "Mozilla/5.0" when
//     that is empty.
//   - Connection is set to "Upgrade" for upgrade requests and to "close"
//     otherwise.
//   - For the device login page (isLoginPagePath), Cookie and Referer are
//     also removed.
func sanitizeProxyRequestHeader(source http.Header, upstreamPath string) http.Header {
	upgrading := isUpgradeRequest(source)
	out := source.Clone()
	for name := range out {
		if isProxyManagedHeader(name, upgrading) {
			out.Del(name)
		}
	}
	out.Del("Accept-Encoding")

	agent := strings.TrimSpace(source.Get("User-Agent"))
	if agent == "" {
		agent = "Mozilla/5.0"
	}
	out.Set("User-Agent", agent)

	connection := "close"
	if upgrading {
		connection = "Upgrade"
	}
	out.Set("Connection", connection)

	if isLoginPagePath(upstreamPath) {
		out.Del("Cookie")
		out.Del("Referer")
	}
	return out
}
// isProxyManagedHeader reports whether the client header named key must not
// be forwarded upstream as-is. key is compared in canonical form. The header
// is flagged when it is:
//   - a connection-level header (Connection, Keep-Alive, Te, ...);
//   - a proxy credential header;
//   - a client-supplied forwarding header.
//
// Upgrade is flagged only when upgradeRequest is false.
func isProxyManagedHeader(key string, upgradeRequest bool) bool {
	name := http.CanonicalHeaderKey(key)
	if name == "Upgrade" {
		return !upgradeRequest
	}
	switch name {
	case "Connection", "Proxy-Connection", "Keep-Alive", "Transfer-Encoding",
		"Te", "Trailer", "Proxy-Authenticate", "Proxy-Authorization":
		return true
	case "Forwarded", "X-Forwarded-For", "X-Forwarded-Host", "X-Forwarded-Proto", "X-Real-Ip":
		return true
	}
	return false
}
// isUpgradeRequest reports whether header asks for a protocol upgrade such as
// WebSocket. An Upgrade header must be present, and some Connection header
// value must contain the "upgrade" token (case-insensitive). Every Connection
// line is examined token by token. A request that sends "Connection:
// keep-alive" and "Connection: Upgrade" on separate lines is therefore
// recognized, matching how httputil.ReverseProxy detects upgrades.
func isUpgradeRequest(header http.Header) bool {
	if header == nil || header.Get("Upgrade") == "" {
		return false
	}
	for _, value := range header.Values("Connection") {
		for _, token := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
				return true
			}
		}
	}
	return false
}
// isLoginPagePath reports whether path ends in "/doc/page/login.asp",
// compared case-insensitively.
func isLoginPagePath(path string) bool {
	const loginSuffix = "/doc/page/login.asp"
	return strings.HasSuffix(strings.ToLower(path), loginSuffix)
}
// closeNotifyWriter wraps the client ResponseWriter handed to the reverse
// proxy. Its CloseNotify method returns a channel on which nothing is ever
// sent. Its Flush method forwards to the wrapped writer when that writer is an
// http.Flusher.
//
// Unwrap exposes the wrapped writer so that http.ResponseController can reach
// optional interfaces this wrapper does not declare itself. Recent versions of
// httputil.ReverseProxy use ResponseController. The most important of these
// interfaces is http.Hijacker, which the proxy needs to tunnel
// 101 Switching Protocols (e.g. WebSocket) responses.
type closeNotifyWriter struct {
	http.ResponseWriter
}

// Unwrap returns the underlying ResponseWriter, following the convention that
// http.ResponseController relies on.
func (w closeNotifyWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// CloseNotify satisfies http.CloseNotifier with a new one-slot channel on
// which nothing is ever sent. A consumer waiting on it never observes a client
// disconnect through this path.
func (w closeNotifyWriter) CloseNotify() <-chan bool {
	return make(chan bool, 1)
}
// Flush forwards to the wrapped writer when it implements http.Flusher and is
// a no-op otherwise.
func (w closeNotifyWriter) Flush() {
	f, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	f.Flush()
}