281 lines
6.7 KiB
Go
281 lines
6.7 KiB
Go
package webdevice
|
|
|
|
import (
	"bufio"
	"context"
	"errors"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strconv"
	"strings"
	"time"
)
|
|
|
|
// ErrInvalidTargetIP is returned by ProxyHTTP when IsPrivateIPv4Literal
// rejects the requested target.
var ErrInvalidTargetIP = errors.New("invalid target ip")

// ErrTargetNotAllowed is returned by ProxyHTTP when the service's allow check
// (Service.IsAllowed) rejects the requested target.
var ErrTargetNotAllowed = errors.New("target ip not allowed")

// ErrInvalidProxyURL is returned by ProxyHTTP when the configured proxy target
// URL cannot be parsed or lacks a scheme or host.
var ErrInvalidProxyURL = errors.New("invalid proxy target")
|
|
|
|
// Tuning knobs for the per-device web proxy.
const (
	// webDeviceProxyConcurrency is not referenced in this section; presumably
	// it sizes the per-target limiter returned by Service.proxyLimiter —
	// confirm against its definition.
	webDeviceProxyConcurrency = 1

	// webDeviceProxyQueueWait bounds how long a request may wait for a free
	// limiter slot before it fails (the retry transport's acquireTimeout).
	webDeviceProxyQueueWait = 8 * time.Second

	// webDeviceProxyAttempts is the total number of upstream tries for a
	// retryable request, counting the first one.
	webDeviceProxyAttempts = 6

	// webDeviceProxyRetryDelay is the base backoff: the wait before the n-th
	// retry is n times this value.
	webDeviceProxyRetryDelay = 150 * time.Millisecond
)
|
|
|
|
func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, proxyPath string) error {
|
|
if !IsPrivateIPv4Literal(targetIP) {
|
|
return ErrInvalidTargetIP
|
|
}
|
|
if !s.IsAllowed(targetIP) {
|
|
return ErrTargetNotAllowed
|
|
}
|
|
|
|
rawTarget := s.ProxyTargetURL(targetIP)
|
|
targetURL, err := url.Parse(rawTarget)
|
|
if err != nil || targetURL.Scheme == "" || targetURL.Host == "" {
|
|
return ErrInvalidProxyURL
|
|
}
|
|
|
|
if proxyPath == "" {
|
|
proxyPath = "/"
|
|
}
|
|
rawQuery := r.URL.RawQuery
|
|
upgradeRequest := isUpgradeRequest(r.Header)
|
|
|
|
proxy := &httputil.ReverseProxy{
|
|
Transport: retryTransport{
|
|
base: &http.Transport{
|
|
DisableKeepAlives: true,
|
|
ResponseHeaderTimeout: 5 * time.Second,
|
|
},
|
|
limiter: s.proxyLimiter(targetIP),
|
|
attempts: webDeviceProxyAttempts,
|
|
delay: webDeviceProxyRetryDelay,
|
|
acquireTimeout: webDeviceProxyQueueWait,
|
|
},
|
|
Director: func(req *http.Request) {
|
|
req.URL.Scheme = targetURL.Scheme
|
|
req.URL.Host = targetURL.Host
|
|
req.URL.Path = JoinProxyTargetPath(targetURL.Path, proxyPath)
|
|
req.URL.RawPath = ""
|
|
req.URL.RawQuery = rawQuery
|
|
req.Host = targetURL.Host
|
|
req.Header = sanitizeProxyRequestHeader(req.Header, req.URL.Path)
|
|
req.Close = !upgradeRequest
|
|
},
|
|
ModifyResponse: func(resp *http.Response) error {
|
|
proxyPrefix := "/proxy/web/" + targetIP
|
|
if location := resp.Header.Get("Location"); location != "" {
|
|
resp.Header.Set("Location", RewriteLocation(targetIP, targetURL, location))
|
|
}
|
|
if cookies := resp.Header.Values("Set-Cookie"); len(cookies) > 0 {
|
|
resp.Header.Del("Set-Cookie")
|
|
for _, cookie := range cookies {
|
|
resp.Header.Add("Set-Cookie", RewriteSetCookie(cookie, proxyPrefix))
|
|
}
|
|
}
|
|
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if ShouldRewriteBody(contentType) {
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_ = resp.Body.Close()
|
|
|
|
rewritten := RewriteText(string(body), proxyPrefix, targetURL, contentType)
|
|
rewrittenBytes := []byte(rewritten)
|
|
resp.Body = io.NopCloser(strings.NewReader(rewritten))
|
|
resp.ContentLength = int64(len(rewrittenBytes))
|
|
resp.Header.Set("Content-Length", strconv.Itoa(len(rewrittenBytes)))
|
|
resp.Header.Del("Content-Encoding")
|
|
resp.Header.Del("Content-MD5")
|
|
resp.Header.Del("Etag")
|
|
}
|
|
return nil
|
|
},
|
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
|
log.Printf("web device proxy failed target=%s path=%s error=%v", targetURL.String(), r.URL.RequestURI(), err)
|
|
http.Error(w, "代理访问失败: "+err.Error(), http.StatusBadGateway)
|
|
},
|
|
}
|
|
|
|
proxy.ServeHTTP(closeNotifyWriter{ResponseWriter: w}, r)
|
|
return nil
|
|
}
|
|
|
|
// retryTransport is an http.RoundTripper that passes each upstream attempt
// through an optional slot limiter and retries idempotent requests that fail
// with connection-level errors (see shouldRetryProxyRequest).
type retryTransport struct {
	// base performs the actual exchange; http.DefaultTransport when nil.
	base http.RoundTripper
	// limiter is used as a counting semaphore: a send takes a slot and a
	// receive returns it. nil disables limiting. Assumes a buffered channel —
	// confirm at the construction site.
	limiter chan struct{}
	// attempts is the total number of tries including the first; values
	// below 1 are treated as 1.
	attempts int
	// delay is the base backoff; the wait before retry n is n*delay.
	delay time.Duration
	// acquireTimeout bounds the wait for a limiter slot; <= 0 means wait
	// until the request context is done.
	acquireTimeout time.Duration
}

// RoundTrip implements http.RoundTripper.
//
// Each attempt holds a limiter slot only while base.RoundTrip runs. For
// *http.Transport that means until response headers arrive, not while the
// caller reads the body.
//
// Between retries it backs off linearly. If the request context is done during
// the backoff, it stops and returns the context's error.
//
// A request body is replayed on retry through req.GetBody. A request with a
// body but no GetBody is not retried, because the failed attempt has already
// consumed the body.
func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	base := t.base
	if base == nil {
		base = http.DefaultTransport
	}
	attempts := t.attempts
	if attempts < 1 {
		attempts = 1
	}
	ctx := req.Context()
	hasBody := req.Body != nil && req.Body != http.NoBody

	outReq := req
	for attempt := 1; ; attempt++ {
		if err := acquireProxySlot(ctx, t.limiter, t.acquireTimeout); err != nil {
			return nil, err
		}
		resp, err := base.RoundTrip(outReq)
		releaseProxySlot(t.limiter)
		if err == nil {
			return resp, nil
		}
		if attempt >= attempts || !shouldRetryProxyRequest(req, err) {
			return nil, err
		}
		if hasBody && req.GetBody == nil {
			return nil, err
		}

		if t.delay > 0 {
			backoff := time.NewTimer(t.delay * time.Duration(attempt))
			select {
			case <-backoff.C:
			case <-ctx.Done():
				backoff.Stop()
				return nil, ctx.Err()
			}
		}

		// Clone deep-copies URL and headers and keeps Close, so the caller's
		// request is never mutated by a retry.
		outReq = req.Clone(ctx)
		if hasBody {
			body, bodyErr := req.GetBody()
			if bodyErr != nil {
				return nil, err
			}
			outReq.Body = body
		}
	}
}

// acquireProxySlot takes one slot from limiter. It blocks until one of three
// things happens:
//   - a slot is free: it returns nil;
//   - ctx is done: it returns ctx.Err();
//   - timeout elapses: it returns context.DeadlineExceeded.
//
// A timeout <= 0 means no queue deadline. A nil limiter succeeds immediately.
func acquireProxySlot(ctx context.Context, limiter chan struct{}, timeout time.Duration) error {
	if limiter == nil {
		return nil
	}
	// A nil channel is never ready, so without a timeout that case is inert.
	var expired <-chan time.Time
	if timeout > 0 {
		timer := time.NewTimer(timeout)
		defer timer.Stop()
		expired = timer.C
	}
	select {
	case limiter <- struct{}{}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	case <-expired:
		return context.DeadlineExceeded
	}
}

// releaseProxySlot returns a slot previously taken by acquireProxySlot. A nil
// limiter is a no-op.
func releaseProxySlot(limiter chan struct{}) {
	if limiter != nil {
		<-limiter
	}
}

// shouldRetryProxyRequest reports whether a failed round trip may be retried.
// It requires both of the following:
//   - the method is GET, HEAD or OPTIONS;
//   - the error text mentions EOF, a reset connection, or a broken pipe
//     (case-insensitively).
func shouldRetryProxyRequest(req *http.Request, err error) bool {
	if req == nil || err == nil {
		return false
	}
	switch req.Method {
	case http.MethodGet, http.MethodHead, http.MethodOptions:
	default:
		return false
	}
	// Matched on the error text, so wrapping and non-wrapping errors are
	// treated alike.
	msg := strings.ToLower(err.Error())
	for _, marker := range [...]string{"eof", "connection reset", "broken pipe"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
// sanitizeProxyRequestHeader returns a copy of source prepared for the
// upstream device; source itself is not modified. A nil source yields a fresh
// header.
//
// It drops:
//   - hop-by-hop and forwarding headers (see isProxyManagedHeader);
//   - every header named in the client's Connection field;
//   - Accept-Encoding, presumably so upstream bodies arrive uncompressed for
//     rewriting.
//
// It guarantees a User-Agent ("Mozilla/5.0" when the client sent none). It
// sets Connection to "Upgrade" for upgrade requests and to "close" otherwise.
// When upstreamPath is the device login page (isLoginPagePath), it also drops
// Cookie and Referer.
func sanitizeProxyRequestHeader(source http.Header, upstreamPath string) http.Header {
	upgradeRequest := isUpgradeRequest(source)
	header := source.Clone()
	if header == nil {
		header = make(http.Header)
	}

	// Fields listed in Connection are hop-by-hop for this request (RFC 9110
	// §7.6.1). Connection is replaced below, which would hide that list from
	// ReverseProxy's own hop-by-hop cleanup, so remove the fields here.
	for _, value := range source.Values("Connection") {
		for _, name := range strings.Split(value, ",") {
			name = strings.TrimSpace(name)
			if name == "" || (upgradeRequest && strings.EqualFold(name, "Upgrade")) {
				continue
			}
			header.Del(name)
		}
	}
	for key := range header {
		if isProxyManagedHeader(key, upgradeRequest) {
			header.Del(key)
		}
	}
	header.Del("Accept-Encoding")

	userAgent := strings.TrimSpace(source.Get("User-Agent"))
	if userAgent == "" {
		userAgent = "Mozilla/5.0"
	}
	header.Set("User-Agent", userAgent)
	if upgradeRequest {
		header.Set("Connection", "Upgrade")
	} else {
		header.Set("Connection", "close")
	}

	if isLoginPagePath(upstreamPath) {
		header.Del("Cookie")
		header.Del("Referer")
	}
	return header
}

// isProxyManagedHeader reports whether key must not be forwarded as sent by
// the client. It covers hop-by-hop headers, proxy credentials, and
// forwarding/client-IP headers. Upgrade is kept only for upgrade requests.
func isProxyManagedHeader(key string, upgradeRequest bool) bool {
	switch http.CanonicalHeaderKey(key) {
	case "Connection",
		"Proxy-Connection",
		"Keep-Alive",
		"Transfer-Encoding",
		"Te",
		"Trailer",
		"Proxy-Authenticate",
		"Proxy-Authorization",
		"Forwarded",
		"X-Forwarded-For",
		"X-Forwarded-Host",
		"X-Forwarded-Proto",
		"X-Real-Ip":
		return true
	case "Upgrade":
		return !upgradeRequest
	default:
		return false
	}
}

// isUpgradeRequest reports whether header asks for a protocol upgrade. That
// requires an Upgrade field to be present and any Connection field line to
// list the "upgrade" token (case-insensitively).
func isUpgradeRequest(header http.Header) bool {
	if header == nil || header.Get("Upgrade") == "" {
		return false
	}
	for _, value := range header.Values("Connection") {
		for _, token := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
				return true
			}
		}
	}
	return false
}

// isLoginPagePath reports whether path ends with "/doc/page/login.asp",
// compared case-insensitively.
func isLoginPagePath(path string) bool {
	return strings.HasSuffix(strings.ToLower(path), "/doc/page/login.asp")
}
|
|
|
|
// closeNotifyWriter wraps the http.ResponseWriter handed to ReverseProxy in
// ProxyHTTP. Its CloseNotify channel never fires. Presumably this keeps a
// wrapped writer's own CloseNotify from being consulted — confirm against the
// router in use.
//
// Embedding exposes only the http.ResponseWriter method set, so the optional
// interfaces of the wrapped writer are forwarded explicitly:
//   - Flush;
//   - Hijack, which ReverseProxy needs to switch protocols (e.g. WebSocket);
//   - Unwrap, for http.ResponseController.
type closeNotifyWriter struct {
	http.ResponseWriter
}

var (
	_ http.Flusher  = closeNotifyWriter{}
	_ http.Hijacker = closeNotifyWriter{}
)

// CloseNotify returns a fresh channel that is never written to or closed.
func (w closeNotifyWriter) CloseNotify() <-chan bool {
	return make(chan bool, 1)
}

// Flush flushes the first writer in the Unwrap chain that implements
// http.Flusher; it is a no-op when none does.
func (w closeNotifyWriter) Flush() {
	rw := w.ResponseWriter
	for {
		switch t := rw.(type) {
		case http.Flusher:
			t.Flush()
			return
		case interface{ Unwrap() http.ResponseWriter }:
			rw = t.Unwrap()
		default:
			return
		}
	}
}

// Hijack takes over the client connection via the first writer in the Unwrap
// chain that implements http.Hijacker. It returns http.ErrNotSupported when
// none does.
func (w closeNotifyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	rw := w.ResponseWriter
	for {
		switch t := rw.(type) {
		case http.Hijacker:
			return t.Hijack()
		case interface{ Unwrap() http.ResponseWriter }:
			rw = t.Unwrap()
		default:
			return nil, nil, http.ErrNotSupported
		}
	}
}

// Unwrap returns the wrapped writer so http.ResponseController can reach it.
func (w closeNotifyWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
|