feat: initialize managed portal

This commit is contained in:
Yoilun
2026-04-27 10:04:36 +08:00
commit d4e351df71
145 changed files with 13425 additions and 0 deletions

101
internal/webdevice/proxy.go Normal file
View File

@@ -0,0 +1,101 @@
package webdevice
import (
"errors"
"io"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
)
var (
ErrInvalidTargetIP = errors.New("invalid target ip")
ErrTargetNotAllowed = errors.New("target ip not allowed")
ErrInvalidProxyURL = errors.New("invalid proxy target")
)
func (s *Service) ProxyHTTP(w http.ResponseWriter, r *http.Request, targetIP, proxyPath string) error {
if !IsPrivateIPv4Literal(targetIP) {
return ErrInvalidTargetIP
}
if !s.IsAllowed(targetIP) {
return ErrTargetNotAllowed
}
rawTarget := s.ProxyTargetURL(targetIP)
targetURL, err := url.Parse(rawTarget)
if err != nil || targetURL.Scheme == "" || targetURL.Host == "" {
return ErrInvalidProxyURL
}
if proxyPath == "" {
proxyPath = "/"
}
rawQuery := r.URL.RawQuery
proxy := &httputil.ReverseProxy{
Director: func(req *http.Request) {
req.URL.Scheme = targetURL.Scheme
req.URL.Host = targetURL.Host
req.URL.Path = JoinProxyTargetPath(targetURL.Path, proxyPath)
req.URL.RawPath = ""
req.URL.RawQuery = rawQuery
req.Host = targetURL.Host
req.Header.Del("Accept-Encoding")
},
ModifyResponse: func(resp *http.Response) error {
proxyPrefix := "/proxy/web/" + targetIP
if location := resp.Header.Get("Location"); location != "" {
resp.Header.Set("Location", RewriteLocation(targetIP, targetURL, location))
}
if cookies := resp.Header.Values("Set-Cookie"); len(cookies) > 0 {
resp.Header.Del("Set-Cookie")
for _, cookie := range cookies {
resp.Header.Add("Set-Cookie", RewriteSetCookie(cookie, proxyPrefix))
}
}
contentType := resp.Header.Get("Content-Type")
if ShouldRewriteBody(contentType) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
_ = resp.Body.Close()
rewritten := RewriteText(string(body), proxyPrefix, targetURL, contentType)
rewrittenBytes := []byte(rewritten)
resp.Body = io.NopCloser(strings.NewReader(rewritten))
resp.ContentLength = int64(len(rewrittenBytes))
resp.Header.Set("Content-Length", strconv.Itoa(len(rewrittenBytes)))
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-MD5")
resp.Header.Del("Etag")
}
return nil
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "代理访问失败: "+err.Error(), http.StatusBadGateway)
},
}
proxy.ServeHTTP(closeNotifyWriter{ResponseWriter: w}, r)
return nil
}
type closeNotifyWriter struct {
http.ResponseWriter
}
func (w closeNotifyWriter) CloseNotify() <-chan bool {
ch := make(chan bool, 1)
return ch
}
func (w closeNotifyWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

View File

@@ -0,0 +1,93 @@
package webdevice
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestRewriteLocation(t *testing.T) {
t.Parallel()
targetURL, _ := url.Parse("http://192.168.1.124:80")
got := RewriteLocation("192.168.1.124", targetURL, "http://192.168.1.124/ISAPI/Security")
if got != "/proxy/web/192.168.1.124/ISAPI/Security" {
t.Fatalf("RewriteLocation() = %q", got)
}
}
func TestRewriteSetCookie(t *testing.T) {
t.Parallel()
got := RewriteSetCookie("SID=1; Path=/; Domain=192.168.1.124; HttpOnly", "/proxy/web/192.168.1.124")
if strings.Contains(strings.ToLower(got), "domain=") {
t.Fatalf("RewriteSetCookie() kept domain: %q", got)
}
if !strings.Contains(got, "Path=/proxy/web/192.168.1.124/") {
t.Fatalf("RewriteSetCookie() path = %q", got)
}
}
func TestRewriteText(t *testing.T) {
t.Parallel()
targetURL, _ := url.Parse("http://192.168.1.124:80")
body := `<html><head></head><body><img src="/doc/logo.png"><a href="http://192.168.1.124/ISAPI/x">x</a></body></html>`
got := RewriteText(body, "/proxy/web/192.168.1.124", targetURL, "text/html")
if !strings.Contains(got, `/proxy/web/192.168.1.124/doc/logo.png`) {
t.Fatalf("rewritten body missing proxied relative URL: %s", got)
}
if !strings.Contains(got, `data-web-proxy-runtime`) {
t.Fatalf("rewritten body missing runtime injection: %s", got)
}
}
func TestProxyHTTPRejectsUnscannedIP(t *testing.T) {
t.Parallel()
svc := NewService()
req := httptest.NewRequest(http.MethodGet, "http://portal/proxy/web/192.168.1.124/", nil)
rec := httptest.NewRecorder()
err := svc.ProxyHTTP(rec, req, "192.168.1.124", "/")
if err != ErrTargetNotAllowed {
t.Fatalf("ProxyHTTP() error = %v, want ErrTargetNotAllowed", err)
}
}
func TestProxyHTTPServesAllowedTarget(t *testing.T) {
t.Parallel()
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", "http://192.168.1.124/ISAPI/test")
w.Header().Add("Set-Cookie", "SID=1; Path=/")
w.Header().Set("Content-Type", "text/html")
_, _ = w.Write([]byte(`<html><head></head><body><img src="/doc/logo.png"></body></html>`))
}))
defer upstream.Close()
svc := NewService()
svc.allowIP("192.168.1.124")
svc.proxyTarget = func(ip string) string {
return upstream.URL
}
req := httptest.NewRequest(http.MethodGet, "http://portal/proxy/web/192.168.1.124/", nil)
rec := httptest.NewRecorder()
if err := svc.ProxyHTTP(rec, req, "192.168.1.124", "/"); err != nil {
t.Fatalf("ProxyHTTP() error = %v", err)
}
if rec.Code != http.StatusOK {
t.Fatalf("status = %d body = %s", rec.Code, rec.Body.String())
}
if got := rec.Header().Get("Location"); got != "http://192.168.1.124/ISAPI/test" {
t.Fatalf("Location = %q", got)
}
if !strings.Contains(rec.Body.String(), "/proxy/web/192.168.1.124/doc/logo.png") {
t.Fatalf("body = %s", rec.Body.String())
}
}

View File

@@ -0,0 +1,240 @@
package webdevice
import (
"net/url"
"regexp"
"strings"
)
func JoinProxyTargetPath(basePath, requestPath string) string {
if requestPath == "" {
requestPath = "/"
}
if basePath == "" || basePath == "/" {
return requestPath
}
if strings.HasSuffix(basePath, "/") && strings.HasPrefix(requestPath, "/") {
return basePath + strings.TrimPrefix(requestPath, "/")
}
if !strings.HasSuffix(basePath, "/") && !strings.HasPrefix(requestPath, "/") {
return basePath + "/" + requestPath
}
return basePath + requestPath
}
func RewriteLocation(targetIP string, targetURL *url.URL, location string) string {
locationURL, err := url.Parse(location)
if err != nil {
return location
}
proxyPrefix := "/proxy/web/" + targetIP
if locationURL.Host == "" && strings.HasPrefix(location, "/") {
return proxyPrefix + location
}
if locationURL.Host != "" {
locationHost := locationURL.Hostname()
locationPort := locationURL.Port()
if locationPort == "" && (locationURL.Scheme == "" || locationURL.Scheme == "http") {
locationPort = "80"
}
targetHost := targetURL.Hostname()
targetPort := targetURL.Port()
if targetPort == "" && targetURL.Scheme == "http" {
targetPort = "80"
}
if locationHost != targetHost || locationPort != targetPort {
return location
}
rewrittenPath := locationURL.EscapedPath()
if rewrittenPath == "" {
rewrittenPath = "/"
}
if locationURL.RawQuery != "" {
rewrittenPath += "?" + locationURL.RawQuery
}
if locationURL.Fragment != "" {
rewrittenPath += "#" + locationURL.Fragment
}
return proxyPrefix + rewrittenPath
}
return location
}
var (
webProxyQuotedAttrPattern = regexp.MustCompile(`(?i)\b(href|src|action|poster|data-src|data-href)\s*=\s*(['"])([^'"]*)['"]`)
webProxyBareAttrPattern = regexp.MustCompile(`(?i)\b(href|src|action|poster|data-src|data-href)\s*=\s*([^'">\s][^>\s]*)`)
webProxyCSSURLPattern = regexp.MustCompile(`(?i)url\(\s*(['"]?)([^'"\)\s]+)['"]?\s*\)`)
webProxyQuotedURLPattern = regexp.MustCompile(`(['"])(/[^'"<>\s\\)]*)['"]`)
)
func ShouldRewriteBody(contentType string) bool {
contentType = strings.ToLower(contentType)
return strings.Contains(contentType, "text/html") ||
strings.Contains(contentType, "text/css")
}
func RewriteText(body, proxyPrefix string, targetURL *url.URL, contentType string) string {
contentType = strings.ToLower(contentType)
if strings.Contains(contentType, "text/html") {
body = webProxyQuotedAttrPattern.ReplaceAllStringFunc(body, func(match string) string {
parts := webProxyQuotedAttrPattern.FindStringSubmatch(match)
if len(parts) != 4 {
return match
}
rewritten := rewriteURL(parts[3], proxyPrefix, targetURL)
return strings.Replace(match, parts[2]+parts[3]+parts[2], parts[2]+rewritten+parts[2], 1)
})
body = webProxyBareAttrPattern.ReplaceAllStringFunc(body, func(match string) string {
parts := webProxyBareAttrPattern.FindStringSubmatch(match)
if len(parts) != 3 {
return match
}
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
return strings.Replace(match, parts[2], rewritten, 1)
})
body = webProxyQuotedURLPattern.ReplaceAllStringFunc(body, func(match string) string {
parts := webProxyQuotedURLPattern.FindStringSubmatch(match)
if len(parts) != 3 {
return match
}
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
return parts[1] + rewritten + parts[1]
})
}
body = webProxyCSSURLPattern.ReplaceAllStringFunc(body, func(match string) string {
parts := webProxyCSSURLPattern.FindStringSubmatch(match)
if len(parts) != 3 {
return match
}
rewritten := rewriteURL(parts[2], proxyPrefix, targetURL)
return "url(" + parts[1] + rewritten + parts[1] + ")"
})
if strings.Contains(contentType, "text/html") {
body = injectRuntime(body, proxyPrefix)
}
return body
}
func injectRuntime(body, proxyPrefix string) string {
if strings.Contains(body, "data-web-proxy-runtime") {
return body
}
script := `<script data-web-proxy-runtime>(function(){var p="` + proxyPrefix + `";var d=["/ISAPI","/SDK","/PSIA","/doc","/webSocket"];function q(x){if(x.indexOf(p+"/")===0||x.indexOf("/proxy/web/")===0){return x}for(var i=0;i<d.length;i++){if(x===d[i]||x.indexOf(d[i]+"/")===0||x.indexOf(d[i]+"?")===0){return p+x}}return x}function r(u){if(typeof u!=="string"){return u}if(u.charAt(0)==="/"&&u.indexOf("//")!==0){return q(u)}try{var a=new URL(u,window.location.href);if(a.origin===window.location.origin){var x=a.pathname+a.search+a.hash;var y=q(x);if(y!==x){return y}}}catch(e){}return u}if(window.XMLHttpRequest){var o=XMLHttpRequest.prototype.open;XMLHttpRequest.prototype.open=function(m,u){arguments[1]=r(u);return o.apply(this,arguments)}}if(window.fetch){var f=window.fetch;window.fetch=function(i,n){if(typeof i==="string"){i=r(i)}else if(i&&i.url){i=new Request(r(i.url),i)}return f.call(this,i,n)}}function a(e){if(!e||!e.getAttribute){return}["src","href","action","data-src","data-href"].forEach(function(k){var v=e.getAttribute(k);if(v){var nv=r(v);if(nv!==v){e.setAttribute(k,nv)}}})}if(window.MutationObserver){new MutationObserver(function(ms){ms.forEach(function(m){if(m.type==="attributes"){a(m.target)}else{Array.prototype.forEach.call(m.addedNodes,function(n){a(n);if(n&&n.querySelectorAll){Array.prototype.forEach.call(n.querySelectorAll("[src],[href],[action],[data-src],[data-href]"),a)}})}})}).observe(document.documentElement,{childList:true,subtree:true,attributes:true,attributeFilter:["src","href","action","data-src","data-href"]})}})();</script>`
lower := strings.ToLower(body)
if idx := strings.Index(lower, "</head>"); idx >= 0 {
return body[:idx] + script + body[idx:]
}
if idx := strings.Index(lower, "<body"); idx >= 0 {
if end := strings.Index(body[idx:], ">"); end >= 0 {
insertAt := idx + end + 1
return body[:insertAt] + script + body[insertAt:]
}
}
return script + body
}
func rewriteURL(rawURL, proxyPrefix string, targetURL *url.URL) string {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" ||
strings.HasPrefix(rawURL, "#") ||
strings.HasPrefix(rawURL, "//") ||
strings.HasPrefix(rawURL, proxyPrefix+"/") ||
strings.HasPrefix(rawURL, "/proxy/web/") {
return rawURL
}
lower := strings.ToLower(rawURL)
for _, prefix := range []string{"data:", "blob:", "mailto:", "tel:", "javascript:"} {
if strings.HasPrefix(lower, prefix) {
return rawURL
}
}
if strings.HasPrefix(rawURL, "/") {
return proxyPrefix + rawURL
}
parsed, err := url.Parse(rawURL)
if err != nil || parsed.Host == "" {
return rawURL
}
if targetURL == nil || !sameUpstreamHost(parsed, targetURL) {
return rawURL
}
rewrittenPath := parsed.EscapedPath()
if rewrittenPath == "" {
rewrittenPath = "/"
}
if parsed.RawQuery != "" {
rewrittenPath += "?" + parsed.RawQuery
}
if parsed.Fragment != "" {
rewrittenPath += "#" + parsed.Fragment
}
return proxyPrefix + rewrittenPath
}
func sameUpstreamHost(left, right *url.URL) bool {
leftPort := left.Port()
if leftPort == "" && (left.Scheme == "" || left.Scheme == "http") {
leftPort = "80"
}
if leftPort == "" && left.Scheme == "https" {
leftPort = "443"
}
rightPort := right.Port()
if rightPort == "" && (right.Scheme == "" || right.Scheme == "http") {
rightPort = "80"
}
if rightPort == "" && right.Scheme == "https" {
rightPort = "443"
}
return strings.EqualFold(left.Hostname(), right.Hostname()) && leftPort == rightPort
}
func RewriteSetCookie(cookie, proxyPrefix string) string {
parts := strings.Split(cookie, ";")
if len(parts) == 0 {
return cookie
}
rewritten := []string{strings.TrimSpace(parts[0])}
hasPath := false
for _, part := range parts[1:] {
attr := strings.TrimSpace(part)
if attr == "" {
continue
}
lower := strings.ToLower(attr)
switch {
case strings.HasPrefix(lower, "domain="):
continue
case strings.HasPrefix(lower, "path="):
hasPath = true
rewritten = append(rewritten, "Path="+proxyPrefix+"/")
default:
rewritten = append(rewritten, attr)
}
}
if !hasPath {
rewritten = append(rewritten, "Path="+proxyPrefix+"/")
}
return strings.Join(rewritten, "; ")
}

View File

@@ -0,0 +1,495 @@
package webdevice
import (
"fmt"
"io"
"net"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
)
type InterfaceInfo struct {
Name string `json:"name"`
IP string `json:"ip"`
Netmask string `json:"netmask"`
}
type TCPDevice struct {
IP string
Port int
}
type DeviceInfo struct {
IP string `json:"ip"`
Interface string `json:"interface"`
Port int `json:"port"`
TargetURL string `json:"target_url"`
ProxyURL string `json:"proxy_url"`
DirectURL string `json:"direct_url,omitempty"`
ForwardPort int `json:"forward_port,omitempty"`
}
type ScanResult struct {
Interfaces []InterfaceInfo `json:"interfaces"`
Devices []DeviceInfo `json:"devices"`
Count int `json:"count"`
Errors []string `json:"errors,omitempty"`
Message string `json:"message,omitempty"`
}
type InterfaceGetter func() ([]InterfaceInfo, error)
type TCPScanner func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error)
type ForwarderFactory func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error)
type ProxyTargetResolver func(ip string) string
type Service struct {
mu sync.RWMutex
allowed map[string]time.Time
forwarders map[string]*webDeviceForwarder
interfaceGetter InterfaceGetter
tcpScanner TCPScanner
newForwarder ForwarderFactory
proxyTarget ProxyTargetResolver
forwardTarget func(ip string) string
}
func NewService() *Service {
return &Service{
allowed: make(map[string]time.Time),
forwarders: make(map[string]*webDeviceForwarder),
interfaceGetter: defaultInterfaceGetter,
tcpScanner: scanTCP,
newForwarder: newWebDeviceForwarder,
proxyTarget: defaultProxyTarget,
forwardTarget: defaultForwardTarget,
}
}
func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
interfaces, err := s.interfaceGetter()
if err != nil {
return nil, err
}
if len(interfaces) == 0 {
return &ScanResult{
Interfaces: []InterfaceInfo{},
Devices: []DeviceInfo{},
Message: "未找到有效的网卡",
}, nil
}
excludeIPs := make(map[string]bool)
for _, iface := range interfaces {
excludeIPs[iface.IP] = true
}
result := &ScanResult{
Interfaces: interfaces,
Devices: []DeviceInfo{},
Errors: []string{},
}
scheme, host := requestBase(r)
for _, iface := range interfaces {
devices, scanErr := s.tcpScanner(iface.IP, iface.Netmask, 80, excludeIPs)
if scanErr != nil {
result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", iface.Name, scanErr))
continue
}
for _, device := range devices {
if !IsPrivateIPv4Literal(device.IP) {
continue
}
s.allowIP(device.IP)
forwardPort, forwardErr := s.EnsureForwarder(device.IP)
if forwardErr != nil {
result.Errors = append(result.Errors, fmt.Sprintf("%s: 启动网页直连转发失败: %v", device.IP, forwardErr))
}
deviceInfo := DeviceInfo{
IP: device.IP,
Interface: iface.Name,
Port: device.Port,
TargetURL: fmt.Sprintf("http://%s/", device.IP),
ProxyURL: fmt.Sprintf("/proxy/web/%s/", device.IP),
}
if forwardErr == nil {
deviceInfo.ForwardPort = forwardPort
deviceInfo.DirectURL = buildDirectURL(scheme, host, forwardPort)
}
result.Devices = append(result.Devices, deviceInfo)
}
}
sort.Slice(result.Devices, func(i, j int) bool {
return ipv4ToUint32(result.Devices[i].IP) < ipv4ToUint32(result.Devices[j].IP)
})
result.Count = len(result.Devices)
return result, nil
}
func (s *Service) allowIP(ip string) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowed[ip] = time.Now()
}
func (s *Service) AllowIP(ip string) {
s.allowIP(ip)
}
func (s *Service) IsAllowed(ip string) bool {
s.mu.RLock()
defer s.mu.RUnlock()
_, ok := s.allowed[ip]
return ok
}
func (s *Service) SetInterfaceGetter(getter InterfaceGetter) {
if getter != nil {
s.interfaceGetter = getter
}
}
func (s *Service) SetTCPScanner(scanner TCPScanner) {
if scanner != nil {
s.tcpScanner = scanner
}
}
func (s *Service) SetForwarderFactory(factory ForwarderFactory) {
if factory != nil {
s.newForwarder = factory
}
}
func (s *Service) SetProxyTargetResolver(resolver ProxyTargetResolver) {
if resolver != nil {
s.proxyTarget = resolver
}
}
func (s *Service) EnsureForwarder(ip string) (int, error) {
port, ok := WebDeviceForwardPort(ip)
if !ok {
return 0, fmt.Errorf("无效的IPv4地址")
}
s.mu.Lock()
defer s.mu.Unlock()
if forwarder, ok := s.forwarders[ip]; ok {
return forwarder.port, nil
}
targetAddress := defaultForwardTarget(ip)
if s.forwardTarget != nil {
targetAddress = s.forwardTarget(ip)
}
forwarder, err := s.newForwarder(ip, port, net.JoinHostPort("0.0.0.0", strconv.Itoa(port)), targetAddress)
if err != nil {
return 0, err
}
s.forwarders[ip] = forwarder
go forwarder.serve()
return port, nil
}
func (s *Service) ProxyTargetURL(ip string) string {
if s.proxyTarget == nil {
return defaultProxyTarget(ip)
}
return s.proxyTarget(ip)
}
func IsPrivateIPv4Literal(value string) bool {
ip := net.ParseIP(value)
if ip == nil {
return false
}
ip4 := ip.To4()
if ip4 == nil {
return false
}
return ip4.IsPrivate() && !ip4.IsLoopback() && !ip4.IsMulticast() && !ip4.IsUnspecified()
}
func WebDeviceForwardPort(ip string) (int, bool) {
parsed := net.ParseIP(ip)
if parsed == nil {
return 0, false
}
ip4 := parsed.To4()
if ip4 == nil {
return 0, false
}
return 31000 + int(ip4[3]), true
}
func requestBase(r *http.Request) (string, string) {
scheme := r.Header.Get("X-Forwarded-Proto")
if scheme == "" {
scheme = "http"
if r.TLS != nil {
scheme = "https"
}
}
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
if hostname, _, err := net.SplitHostPort(host); err == nil {
host = hostname
}
return scheme, host
}
func buildDirectURL(scheme, host string, port int) string {
return scheme + "://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/"
}
type webDeviceForwarder struct {
ip string
port int
targetAddress string
listener net.Listener
}
type WebDeviceForwarder = webDeviceForwarder
func newWebDeviceForwarder(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
listener, err := net.Listen("tcp", listenAddress)
if err != nil {
return nil, err
}
if port == 0 {
if tcpAddr, ok := listener.Addr().(*net.TCPAddr); ok {
port = tcpAddr.Port
}
}
return &webDeviceForwarder{
ip: ip,
port: port,
targetAddress: targetAddress,
listener: listener,
}, nil
}
func (f *webDeviceForwarder) serve() {
if f == nil || f.listener == nil {
return
}
for {
clientConn, err := f.listener.Accept()
if err != nil {
return
}
go f.handle(clientConn)
}
}
func (f *webDeviceForwarder) handle(clientConn net.Conn) {
targetConn, err := net.DialTimeout("tcp", f.targetAddress, 10*time.Second)
if err != nil {
_ = clientConn.Close()
return
}
errCh := make(chan error, 2)
go func() {
_, err := ioCopy(targetConn, clientConn)
errCh <- err
}()
go func() {
_, err := ioCopy(clientConn, targetConn)
errCh <- err
}()
<-errCh
_ = clientConn.Close()
_ = targetConn.Close()
}
var ioCopy = func(dst net.Conn, src net.Conn) (int64, error) {
return copyConn(dst, src)
}
func copyConn(dst net.Conn, src net.Conn) (int64, error) {
return io.Copy(dst, src)
}
func defaultProxyTarget(ip string) string {
return "http://" + net.JoinHostPort(ip, "80")
}
func defaultForwardTarget(ip string) string {
return net.JoinHostPort(ip, "80")
}
func defaultInterfaceGetter() ([]InterfaceInfo, error) {
var interfaces []InterfaceInfo
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("获取网卡列表失败: %w", err)
}
for _, iface := range ifaces {
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
continue
}
if strings.Contains(iface.Name, "docker") ||
strings.Contains(iface.Name, "veth") ||
strings.Contains(iface.Name, "br-") ||
strings.Contains(iface.Name, "tun") ||
strings.Contains(iface.Name, "tap") ||
strings.HasPrefix(iface.Name, "lo") {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok || ipNet.IP.To4() == nil {
continue
}
interfaces = append(interfaces, InterfaceInfo{
Name: iface.Name,
IP: ipNet.IP.String(),
Netmask: net.IP(ipNet.Mask).String(),
})
break
}
}
return interfaces, nil
}
func scanTCP(ip string, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
ipRange, err := calculateIPRange(ip, netmask)
if err != nil {
return nil, err
}
if port <= 0 || port > 65535 {
return nil, fmt.Errorf("无效的端口: %d", port)
}
var devices []TCPDevice
var mu sync.Mutex
var wg sync.WaitGroup
semaphore := make(chan struct{}, 20)
timeout := 2 * time.Second
current := make(net.IP, len(ipRange.Start))
copy(current, ipRange.Start)
incrementIP(current)
for {
if current.To4().Equal(ipRange.End.To4()) {
break
}
currentIP := current.String()
if excludeIPs[currentIP] {
incrementIP(current)
continue
}
wg.Add(1)
semaphore <- struct{}{}
go func(targetIP string) {
defer wg.Done()
defer func() { <-semaphore }()
if scanTCPPort(targetIP, port, timeout) {
mu.Lock()
devices = append(devices, TCPDevice{IP: targetIP, Port: port})
mu.Unlock()
}
}(currentIP)
incrementIP(current)
}
wg.Wait()
sort.Slice(devices, func(i, j int) bool {
return ipv4ToUint32(devices[i].IP) < ipv4ToUint32(devices[j].IP)
})
return devices, nil
}
type ipRange struct {
Start net.IP
End net.IP
}
func calculateIPRange(ip string, netmask string) (*ipRange, error) {
parseIP := net.ParseIP(ip)
if parseIP == nil {
return nil, fmt.Errorf("无效的IP: %s", ip)
}
mask := net.IPMask(net.ParseIP(netmask).To4())
if mask == nil {
return nil, fmt.Errorf("无效的子网掩码: %s", netmask)
}
network := &net.IPNet{
IP: parseIP.Mask(mask),
Mask: mask,
}
broadcast := make(net.IP, len(network.IP))
copy(broadcast, network.IP)
for i := 0; i < len(mask); i++ {
broadcast[i] |= ^mask[i]
}
return &ipRange{Start: network.IP.To4(), End: broadcast}, nil
}
func scanTCPPort(ip string, port int, timeout time.Duration) bool {
if port <= 0 || port > 65535 {
return false
}
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", port))
conn, err := net.DialTimeout("tcp", addr, timeout)
if err != nil {
return false
}
_ = conn.Close()
return true
}
func incrementIP(ip net.IP) {
for i := len(ip) - 1; i >= 0; i-- {
ip[i]++
if ip[i] > 0 {
return
}
}
}
func ipv4ToUint32(value string) uint32 {
parsed := net.ParseIP(value)
if parsed == nil {
return 0
}
ip := parsed.To4()
if ip == nil {
return 0
}
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}

View File

@@ -0,0 +1,73 @@
package webdevice
import (
"net/http/httptest"
"testing"
)
func TestIsPrivateIPv4Literal(t *testing.T) {
t.Parallel()
cases := map[string]bool{
"192.168.1.10": true,
"10.0.0.8": true,
"172.16.5.2": true,
"127.0.0.1": false,
"8.8.8.8": false,
"0.0.0.0": false,
"::1": false,
"bad-ip": false,
}
for input, want := range cases {
if got := IsPrivateIPv4Literal(input); got != want {
t.Fatalf("IsPrivateIPv4Literal(%q) = %v, want %v", input, got, want)
}
}
}
func TestWebDeviceForwardPort(t *testing.T) {
t.Parallel()
port, ok := WebDeviceForwardPort("192.168.1.124")
if !ok {
t.Fatal("WebDeviceForwardPort() ok = false")
}
if port != 31124 {
t.Fatalf("port = %d, want 31124", port)
}
}
func TestScanBuildsDirectURLAndAllowList(t *testing.T) {
t.Parallel()
svc := NewService()
svc.interfaceGetter = func() ([]InterfaceInfo, error) {
return []InterfaceInfo{{
Name: "eth0",
IP: "10.8.0.14",
Netmask: "255.255.255.0",
}}, nil
}
svc.tcpScanner = func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
return []TCPDevice{{IP: "192.168.1.124", Port: 80}}, nil
}
svc.newForwarder = func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
return &webDeviceForwarder{ip: ip, port: port, targetAddress: targetAddress}, nil
}
req := httptest.NewRequest("GET", "http://10.8.0.14:13000/api/web-devices/scan", nil)
result, err := svc.Scan(req)
if err != nil {
t.Fatalf("Scan() error = %v", err)
}
if result.Count != 1 {
t.Fatalf("result.Count = %d", result.Count)
}
if !svc.IsAllowed("192.168.1.124") {
t.Fatal("expected IP to be allowed after scan")
}
if result.Devices[0].DirectURL != "http://10.8.0.14:31124/" {
t.Fatalf("DirectURL = %q", result.Devices[0].DirectURL)
}
}