diff --git a/internal/webdevice/proxy.go b/internal/webdevice/proxy.go index 20981ed..e8bb291 100644 --- a/internal/webdevice/proxy.go +++ b/internal/webdevice/proxy.go @@ -10,6 +10,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" ) @@ -50,6 +51,7 @@ func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, pr Transport: retryTransport{ base: &http.Transport{ DisableKeepAlives: true, + DisableCompression: true, ResponseHeaderTimeout: 5 * time.Second, }, limiter: s.proxyLimiter(targetIP), @@ -139,10 +141,20 @@ func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } resp, err := base.RoundTrip(nextReq) - releaseProxySlot(t.limiter) if err == nil { + if resp == nil || resp.Body == nil { + releaseProxySlot(t.limiter) + return resp, nil + } + resp.Body = &releaseOnCloseReadCloser{ + ReadCloser: resp.Body, + release: func() { + releaseProxySlot(t.limiter) + }, + } return resp, nil } + releaseProxySlot(t.limiter) lastErr = err if !shouldRetryProxyRequest(req, err) || attempt == attempts-1 { return nil, err @@ -152,6 +164,22 @@ func (t retryTransport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, lastErr } +type releaseOnCloseReadCloser struct { + io.ReadCloser + once sync.Once + release func() +} + +func (r *releaseOnCloseReadCloser) Close() error { + err := r.ReadCloser.Close() + r.once.Do(func() { + if r.release != nil { + r.release() + } + }) + return err +} + func acquireProxySlot(ctx context.Context, limiter chan struct{}, timeout time.Duration) error { if limiter == nil { return nil diff --git a/internal/webdevice/proxy_test.go b/internal/webdevice/proxy_test.go index 51c9239..2369dfc 100644 --- a/internal/webdevice/proxy_test.go +++ b/internal/webdevice/proxy_test.go @@ -1,6 +1,7 @@ package webdevice import ( + "io" "net/http" "net/http/httptest" "net/url" @@ -148,6 +149,44 @@ func TestProxyHTTPClosesUpstreamConnection(t *testing.T) { } } +func TestRetryTransportHoldsProxySlotUntilBodyClose(t *testing.T) { + limiter := make(chan struct{}, 1) + transport := retryTransport{ + base: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, + Request: req, + }, nil + }), + limiter: limiter, + attempts: 1, + acquireTimeout: time.Second, + } + + req := httptest.NewRequest(http.MethodGet, "http://portal/test", nil) + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip() error = %v", err) + } + if got := len(limiter); got != 1 { + t.Fatalf("limiter len after RoundTrip = %d, want 1", got) + } + if err := resp.Body.Close(); err != nil { + t.Fatalf("Body.Close() error = %v", err) + } + if got := len(limiter); got != 0 { + t.Fatalf("limiter len after Body.Close = %d, want 0", got) + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + func TestSanitizeProxyRequestHeaderDropsLoginCookie(t *testing.T) { source := http.Header{} source.Set("User-Agent", "browser")