feat: initialize managed portal
This commit is contained in:
101
internal/webdevice/proxy.go
Normal file
101
internal/webdevice/proxy.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidTargetIP = errors.New("invalid target ip")
|
||||
ErrTargetNotAllowed = errors.New("target ip not allowed")
|
||||
ErrInvalidProxyURL = errors.New("invalid proxy target")
|
||||
)
|
||||
|
||||
func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, proxyPath string) error {
|
||||
if !IsPrivateIPv4Literal(targetIP) {
|
||||
return ErrInvalidTargetIP
|
||||
}
|
||||
if !s.IsAllowed(targetIP) {
|
||||
return ErrTargetNotAllowed
|
||||
}
|
||||
|
||||
rawTarget := s.ProxyTargetURL(targetIP)
|
||||
targetURL, err := url.Parse(rawTarget)
|
||||
if err != nil || targetURL.Scheme == "" || targetURL.Host == "" {
|
||||
return ErrInvalidProxyURL
|
||||
}
|
||||
|
||||
if proxyPath == "" {
|
||||
proxyPath = "/"
|
||||
}
|
||||
rawQuery := r.URL.RawQuery
|
||||
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.URL.Path = JoinProxyTargetPath(targetURL.Path, proxyPath)
|
||||
req.URL.RawPath = ""
|
||||
req.URL.RawQuery = rawQuery
|
||||
req.Host = targetURL.Host
|
||||
req.Header.Del("Accept-Encoding")
|
||||
},
|
||||
ModifyResponse: func(resp *http.Response) error {
|
||||
proxyPrefix := "/proxy/web/" + targetIP
|
||||
if location := resp.Header.Get("Location"); location != "" {
|
||||
resp.Header.Set("Location", RewriteLocation(targetIP, targetURL, location))
|
||||
}
|
||||
if cookies := resp.Header.Values("Set-Cookie"); len(cookies) > 0 {
|
||||
resp.Header.Del("Set-Cookie")
|
||||
for _, cookie := range cookies {
|
||||
resp.Header.Add("Set-Cookie", RewriteSetCookie(cookie, proxyPrefix))
|
||||
}
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if ShouldRewriteBody(contentType) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
rewritten := RewriteText(string(body), proxyPrefix, targetURL, contentType)
|
||||
rewrittenBytes := []byte(rewritten)
|
||||
resp.Body = io.NopCloser(strings.NewReader(rewritten))
|
||||
resp.ContentLength = int64(len(rewrittenBytes))
|
||||
resp.Header.Set("Content-Length", strconv.Itoa(len(rewrittenBytes)))
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Del("Content-MD5")
|
||||
resp.Header.Del("Etag")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "代理访问失败: "+err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(closeNotifyWriter{ResponseWriter: w}, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
type closeNotifyWriter struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w closeNotifyWriter) CloseNotify() <-chan bool {
|
||||
ch := make(chan bool, 1)
|
||||
return ch
|
||||
}
|
||||
|
||||
func (w closeNotifyWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
93
internal/webdevice/proxy_test.go
Normal file
93
internal/webdevice/proxy_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRewriteLocation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
targetURL, _ := url.Parse("http://192.168.1.124:80")
|
||||
got := RewriteLocation("192.168.1.124", targetURL, "http://192.168.1.124/ISAPI/Security")
|
||||
if got != "/proxy/web/192.168.1.124/ISAPI/Security" {
|
||||
t.Fatalf("RewriteLocation() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteSetCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := RewriteSetCookie("SID=1; Path=/; Domain=192.168.1.124; HttpOnly", "/proxy/web/192.168.1.124")
|
||||
if strings.Contains(strings.ToLower(got), "domain=") {
|
||||
t.Fatalf("RewriteSetCookie() kept domain: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Path=/proxy/web/192.168.1.124/") {
|
||||
t.Fatalf("RewriteSetCookie() path = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
targetURL, _ := url.Parse("http://192.168.1.124:80")
|
||||
body := `<html><head></head><body><img src="/doc/logo.png"><a href="http://192.168.1.124/ISAPI/x">x</a></body></html>`
|
||||
got := RewriteText(body, "/proxy/web/192.168.1.124", targetURL, "text/html")
|
||||
|
||||
if !strings.Contains(got, `/proxy/web/192.168.1.124/doc/logo.png`) {
|
||||
t.Fatalf("rewritten body missing proxied relative URL: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, `data-web-proxy-runtime`) {
|
||||
t.Fatalf("rewritten body missing runtime injection: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHTTPRejectsUnscannedIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://portal/proxy/web/192.168.1.124/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
err := svc.ProxyHTTP(rec, req, "192.168.1.124", "/")
|
||||
if err != ErrTargetNotAllowed {
|
||||
t.Fatalf("ProxyHTTP() error = %v, want ErrTargetNotAllowed", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHTTPServesAllowedTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", "http://192.168.1.124/ISAPI/test")
|
||||
w.Header().Add("Set-Cookie", "SID=1; Path=/")
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
_, _ = w.Write([]byte(`<html><head></head><body><img src="/doc/logo.png"></body></html>`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
svc := NewService()
|
||||
svc.allowIP("192.168.1.124")
|
||||
svc.proxyTarget = func(ip string) string {
|
||||
return upstream.URL
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://portal/proxy/web/192.168.1.124/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
if err := svc.ProxyHTTP(rec, req, "192.168.1.124", "/"); err != nil {
|
||||
t.Fatalf("ProxyHTTP() error = %v", err)
|
||||
}
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d body = %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := rec.Header().Get("Location"); got != "http://192.168.1.124/ISAPI/test" {
|
||||
t.Fatalf("Location = %q", got)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "/proxy/web/192.168.1.124/doc/logo.png") {
|
||||
t.Fatalf("body = %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
240
internal/webdevice/rewrite.go
Normal file
240
internal/webdevice/rewrite.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func JoinProxyTargetPath(basePath, requestPath string) string {
|
||||
if requestPath == "" {
|
||||
requestPath = "/"
|
||||
}
|
||||
if basePath == "" || basePath == "/" {
|
||||
return requestPath
|
||||
}
|
||||
if strings.HasSuffix(basePath, "/") && strings.HasPrefix(requestPath, "/") {
|
||||
return basePath + strings.TrimPrefix(requestPath, "/")
|
||||
}
|
||||
if !strings.HasSuffix(basePath, "/") && !strings.HasPrefix(requestPath, "/") {
|
||||
return basePath + "/" + requestPath
|
||||
}
|
||||
return basePath + requestPath
|
||||
}
|
||||
|
||||
func RewriteLocation(targetIP string, targetURL *url.URL, location string) string {
|
||||
locationURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
return location
|
||||
}
|
||||
|
||||
proxyPrefix := "/proxy/web/" + targetIP
|
||||
if locationURL.Host == "" && strings.HasPrefix(location, "/") {
|
||||
return proxyPrefix + location
|
||||
}
|
||||
|
||||
if locationURL.Host != "" {
|
||||
locationHost := locationURL.Hostname()
|
||||
locationPort := locationURL.Port()
|
||||
if locationPort == "" && (locationURL.Scheme == "" || locationURL.Scheme == "http") {
|
||||
locationPort = "80"
|
||||
}
|
||||
|
||||
targetHost := targetURL.Hostname()
|
||||
targetPort := targetURL.Port()
|
||||
if targetPort == "" && targetURL.Scheme == "http" {
|
||||
targetPort = "80"
|
||||
}
|
||||
|
||||
if locationHost != targetHost || locationPort != targetPort {
|
||||
return location
|
||||
}
|
||||
|
||||
rewrittenPath := locationURL.EscapedPath()
|
||||
if rewrittenPath == "" {
|
||||
rewrittenPath = "/"
|
||||
}
|
||||
if locationURL.RawQuery != "" {
|
||||
rewrittenPath += "?" + locationURL.RawQuery
|
||||
}
|
||||
if locationURL.Fragment != "" {
|
||||
rewrittenPath += "#" + locationURL.Fragment
|
||||
}
|
||||
return proxyPrefix + rewrittenPath
|
||||
}
|
||||
|
||||
return location
|
||||
}
|
||||
|
||||
var (
|
||||
webProxyQuotedAttrPattern = regexp.MustCompile(`(?i)\b(href|src|action|poster|data-src|data-href)\s*=\s*(['"])([^'"]*)['"]`)
|
||||
webProxyBareAttrPattern = regexp.MustCompile(`(?i)\b(href|src|action|poster|data-src|data-href)\s*=\s*([^'">\s][^>\s]*)`)
|
||||
webProxyCSSURLPattern = regexp.MustCompile(`(?i)url\(\s*(['"]?)([^'"\)\s]+)['"]?\s*\)`)
|
||||
webProxyQuotedURLPattern = regexp.MustCompile(`(['"])(/[^'"<>\s\\)]*)['"]`)
|
||||
)
|
||||
|
||||
func ShouldRewriteBody(contentType string) bool {
|
||||
contentType = strings.ToLower(contentType)
|
||||
return strings.Contains(contentType, "text/html") ||
|
||||
strings.Contains(contentType, "text/css")
|
||||
}
|
||||
|
||||
func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType string) string {
|
||||
contentType = strings.ToLower(contentType)
|
||||
|
||||
if strings.Contains(contentType, "text/html") {
|
||||
body = webProxyQuotedAttrPattern.ReplaceAllStringFunc(body, func(match string) string {
|
||||
parts := webProxyQuotedAttrPattern.FindStringSubmatch(match)
|
||||
if len(parts) != 4 {
|
||||
return match
|
||||
}
|
||||
rewritten := rewriteURL(parts[3], proxyPrefix, targetURL)
|
||||
return strings.Replace(match, parts[2]+parts[3]+parts[2], parts[2]+rewritten+parts[2], 1)
|
||||
})
|
||||
|
||||
body = webProxyBareAttrPattern.ReplaceAllStringFunc(body, func(match string) string {
|
||||
parts := webProxyBareAttrPattern.FindStringSubmatch(match)
|
||||
if len(parts) != 3 {
|
||||
return match
|
||||
}
|
||||
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
|
||||
return strings.Replace(match, parts[2], rewritten, 1)
|
||||
})
|
||||
|
||||
body = webProxyQuotedURLPattern.ReplaceAllStringFunc(body, func(match string) string {
|
||||
parts := webProxyQuotedURLPattern.FindStringSubmatch(match)
|
||||
if len(parts) != 3 {
|
||||
return match
|
||||
}
|
||||
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
|
||||
return parts[1] + rewritten + parts[1]
|
||||
})
|
||||
}
|
||||
|
||||
body = webProxyCSSURLPattern.ReplaceAllStringFunc(body, func(match string) string {
|
||||
parts := webProxyCSSURLPattern.FindStringSubmatch(match)
|
||||
if len(parts) != 3 {
|
||||
return match
|
||||
}
|
||||
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
|
||||
return "url(" + parts[1] + rewritten + parts[1] + ")"
|
||||
})
|
||||
|
||||
if strings.Contains(contentType, "text/html") {
|
||||
body = injectRuntime(body, proxyPrefix)
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
func injectRuntime(body, proxyPrefix string) string {
|
||||
if strings.Contains(body, "data-web-proxy-runtime") {
|
||||
return body
|
||||
}
|
||||
|
||||
script := `<script data-web-proxy-runtime>(function(){var p="` + proxyPrefix + `";var d=["/ISAPI","/SDK","/PSIA","/doc","/webSocket"];function q(x){if(x.indexOf(p+"/")===0||x.indexOf("/proxy/web/")===0){return x}for(var i=0;i<d.length;i++){if(x===d[i]||x.indexOf(d[i]+"/")===0||x.indexOf(d[i]+"?")===0){return p+x}}return x}function r(u){if(typeof u!=="string"){return u}if(u.charAt(0)==="/"&&u.indexOf("//")!==0){return q(u)}try{var a=new URL(u,window.location.href);if(a.origin===window.location.origin){var x=a.pathname+a.search+a.hash;var y=q(x);if(y!==x){return y}}}catch(e){}return u}if(window.XMLHttpRequest){var o=XMLHttpRequest.prototype.open;XMLHttpRequest.prototype.open=function(m,u){arguments[1]=r(u);return o.apply(this,arguments)}}if(window.fetch){var f=window.fetch;window.fetch=function(i,n){if(typeof i==="string"){i=r(i)}else if(i&&i.url){i=new Request(r(i.url),i)}return f.call(this,i,n)}}function a(e){if(!e||!e.getAttribute){return}["src","href","action","data-src","data-href"].forEach(function(k){var v=e.getAttribute(k);if(v){var nv=r(v);if(nv!==v){e.setAttribute(k,nv)}}})}if(window.MutationObserver){new MutationObserver(function(ms){ms.forEach(function(m){if(m.type==="attributes"){a(m.target)}else{Array.prototype.forEach.call(m.addedNodes,function(n){a(n);if(n&&n.querySelectorAll){Array.prototype.forEach.call(n.querySelectorAll("[src],[href],[action],[data-src],[data-href]"),a)}})}})}).observe(document.documentElement,{childList:true,subtree:true,attributes:true,attributeFilter:["src","href","action","data-src","data-href"]})}})();</script>`
|
||||
|
||||
lower := strings.ToLower(body)
|
||||
if idx := strings.Index(lower, "</head>"); idx >= 0 {
|
||||
return body[:idx] + script + body[idx:]
|
||||
}
|
||||
if idx := strings.Index(lower, "<body"); idx >= 0 {
|
||||
if end := strings.Index(body[idx:], ">"); end >= 0 {
|
||||
insertAt := idx + end + 1
|
||||
return body[:insertAt] + script + body[insertAt:]
|
||||
}
|
||||
}
|
||||
return script + body
|
||||
}
|
||||
|
||||
func rewriteURL(rawURL, proxyPrefix string, targetURL *url.URL) string {
|
||||
rawURL = strings.TrimSpace(rawURL)
|
||||
if rawURL == "" ||
|
||||
strings.HasPrefix(rawURL, "#") ||
|
||||
strings.HasPrefix(rawURL, "//") ||
|
||||
strings.HasPrefix(rawURL, proxyPrefix+"/") ||
|
||||
strings.HasPrefix(rawURL, "/proxy/web/") {
|
||||
return rawURL
|
||||
}
|
||||
|
||||
lower := strings.ToLower(rawURL)
|
||||
for _, prefix := range []string{"data:", "blob:", "mailto:", "tel:", "javascript:"} {
|
||||
if strings.HasPrefix(lower, prefix) {
|
||||
return rawURL
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(rawURL, "/") {
|
||||
return proxyPrefix + rawURL
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil || parsed.Host == "" {
|
||||
return rawURL
|
||||
}
|
||||
if targetURL == nil || !sameUpstreamHost(parsed, targetURL) {
|
||||
return rawURL
|
||||
}
|
||||
|
||||
rewrittenPath := parsed.EscapedPath()
|
||||
if rewrittenPath == "" {
|
||||
rewrittenPath = "/"
|
||||
}
|
||||
if parsed.RawQuery != "" {
|
||||
rewrittenPath += "?" + parsed.RawQuery
|
||||
}
|
||||
if parsed.Fragment != "" {
|
||||
rewrittenPath += "#" + parsed.Fragment
|
||||
}
|
||||
return proxyPrefix + rewrittenPath
|
||||
}
|
||||
|
||||
func sameUpstreamHost(left, right *url.URL) bool {
|
||||
leftPort := left.Port()
|
||||
if leftPort == "" && (left.Scheme == "" || left.Scheme == "http") {
|
||||
leftPort = "80"
|
||||
}
|
||||
if leftPort == "" && left.Scheme == "https" {
|
||||
leftPort = "443"
|
||||
}
|
||||
|
||||
rightPort := right.Port()
|
||||
if rightPort == "" && (right.Scheme == "" || right.Scheme == "http") {
|
||||
rightPort = "80"
|
||||
}
|
||||
if rightPort == "" && right.Scheme == "https" {
|
||||
rightPort = "443"
|
||||
}
|
||||
|
||||
return strings.EqualFold(left.Hostname(), right.Hostname()) && leftPort == rightPort
|
||||
}
|
||||
|
||||
func RewriteSetCookie(cookie, proxyPrefix string) string {
|
||||
parts := strings.Split(cookie, ";")
|
||||
if len(parts) == 0 {
|
||||
return cookie
|
||||
}
|
||||
|
||||
rewritten := []string{strings.TrimSpace(parts[0])}
|
||||
hasPath := false
|
||||
for _, part := range parts[1:] {
|
||||
attr := strings.TrimSpace(part)
|
||||
if attr == "" {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(attr)
|
||||
switch {
|
||||
case strings.HasPrefix(lower, "domain="):
|
||||
continue
|
||||
case strings.HasPrefix(lower, "path="):
|
||||
hasPath = true
|
||||
rewritten = append(rewritten, "Path="+proxyPrefix+"/")
|
||||
default:
|
||||
rewritten = append(rewritten, attr)
|
||||
}
|
||||
}
|
||||
if !hasPath {
|
||||
rewritten = append(rewritten, "Path="+proxyPrefix+"/")
|
||||
}
|
||||
return strings.Join(rewritten, "; ")
|
||||
}
|
||||
495
internal/webdevice/service.go
Normal file
495
internal/webdevice/service.go
Normal file
@@ -0,0 +1,495 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type InterfaceInfo struct {
|
||||
Name string `json:"name"`
|
||||
IP string `json:"ip"`
|
||||
Netmask string `json:"netmask"`
|
||||
}
|
||||
|
||||
type TCPDevice struct {
|
||||
IP string
|
||||
Port int
|
||||
}
|
||||
|
||||
type DeviceInfo struct {
|
||||
IP string `json:"ip"`
|
||||
Interface string `json:"interface"`
|
||||
Port int `json:"port"`
|
||||
TargetURL string `json:"target_url"`
|
||||
ProxyURL string `json:"proxy_url"`
|
||||
DirectURL string `json:"direct_url,omitempty"`
|
||||
ForwardPort int `json:"forward_port,omitempty"`
|
||||
}
|
||||
|
||||
type ScanResult struct {
|
||||
Interfaces []InterfaceInfo `json:"interfaces"`
|
||||
Devices []DeviceInfo `json:"devices"`
|
||||
Count int `json:"count"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type InterfaceGetter func() ([]InterfaceInfo, error)
|
||||
type TCPScanner func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error)
|
||||
type ForwarderFactory func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error)
|
||||
type ProxyTargetResolver func(ip string) string
|
||||
|
||||
type Service struct {
|
||||
mu sync.RWMutex
|
||||
allowed map[string]time.Time
|
||||
forwarders map[string]*webDeviceForwarder
|
||||
interfaceGetter InterfaceGetter
|
||||
tcpScanner TCPScanner
|
||||
newForwarder ForwarderFactory
|
||||
proxyTarget ProxyTargetResolver
|
||||
forwardTarget func(ip string) string
|
||||
}
|
||||
|
||||
func NewService() *Service {
|
||||
return &Service{
|
||||
allowed: make(map[string]time.Time),
|
||||
forwarders: make(map[string]*webDeviceForwarder),
|
||||
interfaceGetter: defaultInterfaceGetter,
|
||||
tcpScanner: scanTCP,
|
||||
newForwarder: newWebDeviceForwarder,
|
||||
proxyTarget: defaultProxyTarget,
|
||||
forwardTarget: defaultForwardTarget,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
|
||||
interfaces, err := s.interfaceGetter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(interfaces) == 0 {
|
||||
return &ScanResult{
|
||||
Interfaces: []InterfaceInfo{},
|
||||
Devices: []DeviceInfo{},
|
||||
Message: "未找到有效的网卡",
|
||||
}, nil
|
||||
}
|
||||
|
||||
excludeIPs := make(map[string]bool)
|
||||
for _, iface := range interfaces {
|
||||
excludeIPs[iface.IP] = true
|
||||
}
|
||||
|
||||
result := &ScanResult{
|
||||
Interfaces: interfaces,
|
||||
Devices: []DeviceInfo{},
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
scheme, host := requestBase(r)
|
||||
for _, iface := range interfaces {
|
||||
devices, scanErr := s.tcpScanner(iface.IP, iface.Netmask, 80, excludeIPs)
|
||||
if scanErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", iface.Name, scanErr))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, device := range devices {
|
||||
if !IsPrivateIPv4Literal(device.IP) {
|
||||
continue
|
||||
}
|
||||
|
||||
s.allowIP(device.IP)
|
||||
forwardPort, forwardErr := s.EnsureForwarder(device.IP)
|
||||
if forwardErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s: 启动网页直连转发失败: %v", device.IP, forwardErr))
|
||||
}
|
||||
|
||||
deviceInfo := DeviceInfo{
|
||||
IP: device.IP,
|
||||
Interface: iface.Name,
|
||||
Port: device.Port,
|
||||
TargetURL: fmt.Sprintf("http://%s/", device.IP),
|
||||
ProxyURL: fmt.Sprintf("/proxy/web/%s/", device.IP),
|
||||
}
|
||||
if forwardErr == nil {
|
||||
deviceInfo.ForwardPort = forwardPort
|
||||
deviceInfo.DirectURL = buildDirectURL(scheme, host, forwardPort)
|
||||
}
|
||||
result.Devices = append(result.Devices, deviceInfo)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(result.Devices, func(i, j int) bool {
|
||||
return ipv4ToUint32(result.Devices[i].IP) < ipv4ToUint32(result.Devices[j].IP)
|
||||
})
|
||||
result.Count = len(result.Devices)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Service) allowIP(ip string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowed[ip] = time.Now()
|
||||
}
|
||||
|
||||
func (s *Service) AllowIP(ip string) {
|
||||
s.allowIP(ip)
|
||||
}
|
||||
|
||||
func (s *Service) IsAllowed(ip string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
_, ok := s.allowed[ip]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *Service) SetInterfaceGetter(getter InterfaceGetter) {
|
||||
if getter != nil {
|
||||
s.interfaceGetter = getter
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetTCPScanner(scanner TCPScanner) {
|
||||
if scanner != nil {
|
||||
s.tcpScanner = scanner
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetForwarderFactory(factory ForwarderFactory) {
|
||||
if factory != nil {
|
||||
s.newForwarder = factory
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetProxyTargetResolver(resolver ProxyTargetResolver) {
|
||||
if resolver != nil {
|
||||
s.proxyTarget = resolver
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) EnsureForwarder(ip string) (int, error) {
|
||||
port, ok := WebDeviceForwardPort(ip)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("无效的IPv4地址")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if forwarder, ok := s.forwarders[ip]; ok {
|
||||
return forwarder.port, nil
|
||||
}
|
||||
|
||||
targetAddress := defaultForwardTarget(ip)
|
||||
if s.forwardTarget != nil {
|
||||
targetAddress = s.forwardTarget(ip)
|
||||
}
|
||||
|
||||
forwarder, err := s.newForwarder(ip, port, net.JoinHostPort("0.0.0.0", strconv.Itoa(port)), targetAddress)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
s.forwarders[ip] = forwarder
|
||||
go forwarder.serve()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func (s *Service) ProxyTargetURL(ip string) string {
|
||||
if s.proxyTarget == nil {
|
||||
return defaultProxyTarget(ip)
|
||||
}
|
||||
return s.proxyTarget(ip)
|
||||
}
|
||||
|
||||
func IsPrivateIPv4Literal(value string) bool {
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil {
|
||||
return false
|
||||
}
|
||||
return ip4.IsPrivate() && !ip4.IsLoopback() && !ip4.IsMulticast() && !ip4.IsUnspecified()
|
||||
}
|
||||
|
||||
func WebDeviceForwardPort(ip string) (int, bool) {
|
||||
parsed := net.ParseIP(ip)
|
||||
if parsed == nil {
|
||||
return 0, false
|
||||
}
|
||||
ip4 := parsed.To4()
|
||||
if ip4 == nil {
|
||||
return 0, false
|
||||
}
|
||||
return 31000 + int(ip4[3]), true
|
||||
}
|
||||
|
||||
func requestBase(r *http.Request) (string, string) {
|
||||
scheme := r.Header.Get("X-Forwarded-Proto")
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = r.Host
|
||||
}
|
||||
if hostname, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = hostname
|
||||
}
|
||||
return scheme, host
|
||||
}
|
||||
|
||||
func buildDirectURL(scheme, host string, port int) string {
|
||||
return scheme + "://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/"
|
||||
}
|
||||
|
||||
type webDeviceForwarder struct {
|
||||
ip string
|
||||
port int
|
||||
targetAddress string
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
type WebDeviceForwarder = webDeviceForwarder
|
||||
|
||||
func newWebDeviceForwarder(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
|
||||
listener, err := net.Listen("tcp", listenAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if port == 0 {
|
||||
if tcpAddr, ok := listener.Addr().(*net.TCPAddr); ok {
|
||||
port = tcpAddr.Port
|
||||
}
|
||||
}
|
||||
return &webDeviceForwarder{
|
||||
ip: ip,
|
||||
port: port,
|
||||
targetAddress: targetAddress,
|
||||
listener: listener,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *webDeviceForwarder) serve() {
|
||||
if f == nil || f.listener == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
clientConn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go f.handle(clientConn)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *webDeviceForwarder) handle(clientConn net.Conn) {
|
||||
targetConn, err := net.DialTimeout("tcp", f.targetAddress, 10*time.Second)
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := ioCopy(targetConn, clientConn)
|
||||
errCh <- err
|
||||
}()
|
||||
go func() {
|
||||
_, err := ioCopy(clientConn, targetConn)
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
<-errCh
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
}
|
||||
|
||||
var ioCopy = func(dst net.Conn, src net.Conn) (int64, error) {
|
||||
return copyConn(dst, src)
|
||||
}
|
||||
|
||||
func copyConn(dst net.Conn, src net.Conn) (int64, error) {
|
||||
return io.Copy(dst, src)
|
||||
}
|
||||
|
||||
func defaultProxyTarget(ip string) string {
|
||||
return "http://" + net.JoinHostPort(ip, "80")
|
||||
}
|
||||
|
||||
func defaultForwardTarget(ip string) string {
|
||||
return net.JoinHostPort(ip, "80")
|
||||
}
|
||||
|
||||
func defaultInterfaceGetter() ([]InterfaceInfo, error) {
|
||||
var interfaces []InterfaceInfo
|
||||
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取网卡列表失败: %w", err)
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(iface.Name, "docker") ||
|
||||
strings.Contains(iface.Name, "veth") ||
|
||||
strings.Contains(iface.Name, "br-") ||
|
||||
strings.Contains(iface.Name, "tun") ||
|
||||
strings.Contains(iface.Name, "tap") ||
|
||||
strings.HasPrefix(iface.Name, "lo") {
|
||||
continue
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipNet.IP.To4() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
interfaces = append(interfaces, InterfaceInfo{
|
||||
Name: iface.Name,
|
||||
IP: ipNet.IP.String(),
|
||||
Netmask: net.IP(ipNet.Mask).String(),
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func scanTCP(ip string, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
|
||||
ipRange, err := calculateIPRange(ip, netmask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if port <= 0 || port > 65535 {
|
||||
return nil, fmt.Errorf("无效的端口: %d", port)
|
||||
}
|
||||
|
||||
var devices []TCPDevice
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 20)
|
||||
timeout := 2 * time.Second
|
||||
|
||||
current := make(net.IP, len(ipRange.Start))
|
||||
copy(current, ipRange.Start)
|
||||
incrementIP(current)
|
||||
|
||||
for {
|
||||
if current.To4().Equal(ipRange.End.To4()) {
|
||||
break
|
||||
}
|
||||
|
||||
currentIP := current.String()
|
||||
if excludeIPs[currentIP] {
|
||||
incrementIP(current)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(targetIP string) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
if scanTCPPort(targetIP, port, timeout) {
|
||||
mu.Lock()
|
||||
devices = append(devices, TCPDevice{IP: targetIP, Port: port})
|
||||
mu.Unlock()
|
||||
}
|
||||
}(currentIP)
|
||||
|
||||
incrementIP(current)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
sort.Slice(devices, func(i, j int) bool {
|
||||
return ipv4ToUint32(devices[i].IP) < ipv4ToUint32(devices[j].IP)
|
||||
})
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
type ipRange struct {
|
||||
Start net.IP
|
||||
End net.IP
|
||||
}
|
||||
|
||||
func calculateIPRange(ip string, netmask string) (*ipRange, error) {
|
||||
parseIP := net.ParseIP(ip)
|
||||
if parseIP == nil {
|
||||
return nil, fmt.Errorf("无效的IP: %s", ip)
|
||||
}
|
||||
|
||||
mask := net.IPMask(net.ParseIP(netmask).To4())
|
||||
if mask == nil {
|
||||
return nil, fmt.Errorf("无效的子网掩码: %s", netmask)
|
||||
}
|
||||
|
||||
network := &net.IPNet{
|
||||
IP: parseIP.Mask(mask),
|
||||
Mask: mask,
|
||||
}
|
||||
broadcast := make(net.IP, len(network.IP))
|
||||
copy(broadcast, network.IP)
|
||||
for i := 0; i < len(mask); i++ {
|
||||
broadcast[i] |= ^mask[i]
|
||||
}
|
||||
|
||||
return &ipRange{Start: network.IP.To4(), End: broadcast}, nil
|
||||
}
|
||||
|
||||
func scanTCPPort(ip string, port int, timeout time.Duration) bool {
|
||||
if port <= 0 || port > 65535 {
|
||||
return false
|
||||
}
|
||||
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", port))
|
||||
conn, err := net.DialTimeout("tcp", addr, timeout)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func incrementIP(ip net.IP) {
|
||||
for i := len(ip) - 1; i >= 0; i-- {
|
||||
ip[i]++
|
||||
if ip[i] > 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ipv4ToUint32(value string) uint32 {
|
||||
parsed := net.ParseIP(value)
|
||||
if parsed == nil {
|
||||
return 0
|
||||
}
|
||||
ip := parsed.To4()
|
||||
if ip == nil {
|
||||
return 0
|
||||
}
|
||||
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
|
||||
}
|
||||
73
internal/webdevice/service_test.go
Normal file
73
internal/webdevice/service_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsPrivateIPv4Literal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string]bool{
|
||||
"192.168.1.10": true,
|
||||
"10.0.0.8": true,
|
||||
"172.16.5.2": true,
|
||||
"127.0.0.1": false,
|
||||
"8.8.8.8": false,
|
||||
"0.0.0.0": false,
|
||||
"::1": false,
|
||||
"bad-ip": false,
|
||||
}
|
||||
|
||||
for input, want := range cases {
|
||||
if got := IsPrivateIPv4Literal(input); got != want {
|
||||
t.Fatalf("IsPrivateIPv4Literal(%q) = %v, want %v", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebDeviceForwardPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
port, ok := WebDeviceForwardPort("192.168.1.124")
|
||||
if !ok {
|
||||
t.Fatal("WebDeviceForwardPort() ok = false")
|
||||
}
|
||||
if port != 31124 {
|
||||
t.Fatalf("port = %d, want 31124", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanBuildsDirectURLAndAllowList(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewService()
|
||||
svc.interfaceGetter = func() ([]InterfaceInfo, error) {
|
||||
return []InterfaceInfo{{
|
||||
Name: "eth0",
|
||||
IP: "10.8.0.14",
|
||||
Netmask: "255.255.255.0",
|
||||
}}, nil
|
||||
}
|
||||
svc.tcpScanner = func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
|
||||
return []TCPDevice{{IP: "192.168.1.124", Port: 80}}, nil
|
||||
}
|
||||
svc.newForwarder = func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
|
||||
return &webDeviceForwarder{ip: ip, port: port, targetAddress: targetAddress}, nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "http://10.8.0.14:13000/api/web-devices/scan", nil)
|
||||
result, err := svc.Scan(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Scan() error = %v", err)
|
||||
}
|
||||
if result.Count != 1 {
|
||||
t.Fatalf("result.Count = %d", result.Count)
|
||||
}
|
||||
if !svc.IsAllowed("192.168.1.124") {
|
||||
t.Fatal("expected IP to be allowed after scan")
|
||||
}
|
||||
if result.Devices[0].DirectURL != "http://10.8.0.14:31124/" {
|
||||
t.Fatalf("DirectURL = %q", result.Devices[0].DirectURL)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user