Files
managed-portal/internal/webdevice/service.go
2026-04-27 10:04:36 +08:00

496 lines
11 KiB
Go

package webdevice
import (
"fmt"
"io"
"net"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
)
type (
	// InterfaceInfo describes a local network interface whose IPv4 subnet is
	// a candidate for scanning. IP and Netmask are dotted-quad strings.
	InterfaceInfo struct {
		Name    string `json:"name"`
		IP      string `json:"ip"`
		Netmask string `json:"netmask"`
	}

	// TCPDevice is a raw scan hit: a host that accepted a TCP connection
	// on Port.
	TCPDevice struct {
		IP   string
		Port int
	}

	// DeviceInfo is the client-facing description of a discovered device.
	// DirectURL and ForwardPort are left empty (and omitted from JSON) when
	// no direct forwarder could be started for the device.
	DeviceInfo struct {
		IP          string `json:"ip"`
		Interface   string `json:"interface"`
		Port        int    `json:"port"`
		TargetURL   string `json:"target_url"`
		ProxyURL    string `json:"proxy_url"`
		DirectURL   string `json:"direct_url,omitempty"`
		ForwardPort int    `json:"forward_port,omitempty"`
	}

	// ScanResult aggregates one Scan call: the interfaces examined, the
	// devices found, per-interface/per-device error strings and an optional
	// human-readable message.
	ScanResult struct {
		Interfaces []InterfaceInfo `json:"interfaces"`
		Devices    []DeviceInfo    `json:"devices"`
		Count      int             `json:"count"`
		Errors     []string        `json:"errors,omitempty"`
		Message    string          `json:"message,omitempty"`
	}
)
type (
	// InterfaceGetter returns the local interfaces whose subnets Scan probes.
	InterfaceGetter func() ([]InterfaceInfo, error)
	// TCPScanner is expected to probe the subnet of ip/netmask for hosts that
	// accept TCP connections on port, skipping addresses present in
	// excludeIPs.
	TCPScanner func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error)
	// ForwarderFactory creates a forwarder bound to listenAddress that relays
	// to targetAddress. EnsureForwarder starts serving it afterwards.
	ForwarderFactory func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error)
	// ProxyTargetResolver maps a device IP to the URL used by ProxyTargetURL.
	ProxyTargetResolver func(ip string) string
)

// Service discovers web devices on local subnets, keeps an allow-list of the
// IPs it has found and owns one TCP forwarder per device IP.
//
// mu guards allowed and forwarders. The hook fields below are replaced via
// the Set* methods without taking mu, so they are presumably meant to be
// configured before the Service is used concurrently.
type Service struct {
	mu sync.RWMutex

	// allowed maps each allowed device IP to the time it was last allowed.
	allowed map[string]time.Time
	// forwarders maps a device IP to its running forwarder.
	forwarders map[string]*webDeviceForwarder

	interfaceGetter InterfaceGetter
	tcpScanner      TCPScanner
	newForwarder    ForwarderFactory
	proxyTarget     ProxyTargetResolver
	// forwardTarget yields the host:port a forwarder relays to; when nil,
	// EnsureForwarder falls back to defaultForwardTarget.
	forwardTarget func(ip string) string
}
// NewService returns a Service with empty allow-list and forwarder tables,
// wired to the real implementations: system interface discovery, TCP subnet
// scanning, listening forwarders and port-80 proxy/forward targets.
func NewService() *Service {
	svc := new(Service)
	svc.allowed = map[string]time.Time{}
	svc.forwarders = map[string]*webDeviceForwarder{}
	svc.interfaceGetter = defaultInterfaceGetter
	svc.tcpScanner = scanTCP
	svc.newForwarder = newWebDeviceForwarder
	svc.proxyTarget = defaultProxyTarget
	svc.forwardTarget = defaultForwardTarget
	return svc
}
// Scan probes port 80 on the subnet of every interface returned by
// s.interfaceGetter and reports each responding private IPv4 host.
//
// For every reported device it:
//   - adds the IP to the allow-list;
//   - ensures a direct forwarder is running for it;
//   - fills in proxy and direct URLs. Direct URLs are built from the scheme
//     and host the client used to reach this server (see requestBase).
//
// Scan failures for an interface and forwarder failures for a device are
// collected in Errors rather than aborting the whole scan. Only an error from
// the interface getter itself is returned.
//
// A device seen through several interfaces is reported once, under the first
// interface that found it. This happens, for example, when two NICs sit on
// the same subnet. Devices are sorted by numeric IPv4 order.
func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
	interfaces, err := s.interfaceGetter()
	if err != nil {
		return nil, err
	}
	if len(interfaces) == 0 {
		return &ScanResult{
			Interfaces: []InterfaceInfo{},
			Devices:    []DeviceInfo{},
			Message:    "未找到有效的网卡",
		}, nil
	}

	// The host's own interface addresses are never reported as devices.
	excludeIPs := make(map[string]bool, len(interfaces))
	for _, iface := range interfaces {
		excludeIPs[iface.IP] = true
	}

	result := &ScanResult{
		Interfaces: interfaces,
		Devices:    []DeviceInfo{},
		Errors:     []string{},
	}
	scheme, host := requestBase(r)
	seen := make(map[string]bool)

	for _, iface := range interfaces {
		devices, scanErr := s.tcpScanner(iface.IP, iface.Netmask, 80, excludeIPs)
		if scanErr != nil {
			result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", iface.Name, scanErr))
			continue
		}
		for _, device := range devices {
			if !IsPrivateIPv4Literal(device.IP) || seen[device.IP] {
				continue
			}
			seen[device.IP] = true

			s.allowIP(device.IP)
			deviceInfo := DeviceInfo{
				IP:        device.IP,
				Interface: iface.Name,
				Port:      device.Port,
				TargetURL: fmt.Sprintf("http://%s/", device.IP),
				ProxyURL:  fmt.Sprintf("/proxy/web/%s/", device.IP),
			}
			forwardPort, forwardErr := s.EnsureForwarder(device.IP)
			if forwardErr != nil {
				result.Errors = append(result.Errors, fmt.Sprintf("%s: 启动网页直连转发失败: %v", device.IP, forwardErr))
			} else {
				deviceInfo.ForwardPort = forwardPort
				deviceInfo.DirectURL = buildDirectURL(scheme, host, forwardPort)
			}
			result.Devices = append(result.Devices, deviceInfo)
		}
	}

	sort.SliceStable(result.Devices, func(i, j int) bool {
		return ipv4ToUint32(result.Devices[i].IP) < ipv4ToUint32(result.Devices[j].IP)
	})
	result.Count = len(result.Devices)
	return result, nil
}
// allowIP records ip in the allow-list, stamping it with the current time.
// No code in this file removes entries.
func (s *Service) allowIP(ip string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.allowed[ip] = time.Now()
}

// AllowIP is the exported form of allowIP. It adds ip to the allow-list.
func (s *Service) AllowIP(ip string) {
	s.allowIP(ip)
}

// IsAllowed reports whether ip has ever been added to the allow-list.
func (s *Service) IsAllowed(ip string) bool {
	s.mu.RLock()
	_, present := s.allowed[ip]
	s.mu.RUnlock()
	return present
}
// SetInterfaceGetter replaces the interface discovery hook. A nil getter is
// ignored so the current hook stays in place.
func (s *Service) SetInterfaceGetter(getter InterfaceGetter) {
	if getter == nil {
		return
	}
	s.interfaceGetter = getter
}

// SetTCPScanner replaces the subnet scanner. A nil scanner is ignored.
func (s *Service) SetTCPScanner(scanner TCPScanner) {
	if scanner == nil {
		return
	}
	s.tcpScanner = scanner
}

// SetForwarderFactory replaces the forwarder constructor used by
// EnsureForwarder. A nil factory is ignored.
func (s *Service) SetForwarderFactory(factory ForwarderFactory) {
	if factory == nil {
		return
	}
	s.newForwarder = factory
}

// SetProxyTargetResolver replaces the resolver used by ProxyTargetURL. A nil
// resolver is ignored.
func (s *Service) SetProxyTargetResolver(resolver ProxyTargetResolver) {
	if resolver == nil {
		return
	}
	s.proxyTarget = resolver
}
// EnsureForwarder makes sure a forwarder for ip exists and is serving. It
// returns the port that forwarder listens on.
//
// A new forwarder is requested from s.newForwarder with these arguments:
//   - the deterministic port from WebDeviceForwardPort;
//   - a listen address of 0.0.0.0:<port>;
//   - the target from s.forwardTarget, or defaultForwardTarget when the hook
//     is nil.
//
// The forwarder is then served on its own goroutine. Calls are serialized by
// s.mu, so an IP gets at most one forwarder.
//
// The reported port is the forwarder's own port when it has one, otherwise
// the deterministic port. Cached and freshly created forwarders follow the
// same rule, so repeated calls agree. A factory returning (nil, nil) is
// reported as an error instead of being cached, because a nil entry would
// be dereferenced on the next call.
func (s *Service) EnsureForwarder(ip string) (int, error) {
	port, ok := WebDeviceForwardPort(ip)
	if !ok {
		return 0, fmt.Errorf("无效的IPv4地址")
	}
	portOf := func(f *webDeviceForwarder) int {
		if f.port > 0 {
			return f.port
		}
		return port
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if existing, ok := s.forwarders[ip]; ok && existing != nil {
		return portOf(existing), nil
	}

	targetAddress := defaultForwardTarget(ip)
	if s.forwardTarget != nil {
		targetAddress = s.forwardTarget(ip)
	}
	listenAddress := net.JoinHostPort("0.0.0.0", strconv.Itoa(port))
	forwarder, err := s.newForwarder(ip, port, listenAddress, targetAddress)
	if err != nil {
		return 0, err
	}
	if forwarder == nil {
		return 0, fmt.Errorf("创建网页直连转发器失败: %s", ip)
	}
	s.forwarders[ip] = forwarder
	go forwarder.serve()
	return portOf(forwarder), nil
}
// ProxyTargetURL returns the upstream URL that the reverse proxy should use
// for the device at ip. It uses the configured resolver when one is set,
// otherwise defaultProxyTarget.
func (s *Service) ProxyTargetURL(ip string) string {
	if resolve := s.proxyTarget; resolve != nil {
		return resolve(ip)
	}
	return defaultProxyTarget(ip)
}
// IsPrivateIPv4Literal reports whether value parses as an IPv4 address in a
// private range (10/8, 172.16/12, 192.168/16) that is not loopback,
// multicast or unspecified. Any textual form that net.ParseIP maps to IPv4,
// including IPv4-mapped IPv6 such as ::ffff:a.b.c.d, is accepted.
func IsPrivateIPv4Literal(value string) bool {
	v4 := net.ParseIP(value).To4()
	switch {
	case v4 == nil:
		return false
	case v4.IsLoopback(), v4.IsMulticast(), v4.IsUnspecified():
		return false
	default:
		return v4.IsPrivate()
	}
}
// WebDeviceForwardPort maps an IPv4 address to its deterministic forwarder
// port, 31000 plus the last octet (31000..31255). The boolean is false when
// ip is not an IPv4 address.
//
// Devices on different subnets that share a last octet map to the same port.
func WebDeviceForwardPort(ip string) (int, bool) {
	const basePort = 31000
	if v4 := net.ParseIP(ip).To4(); v4 != nil {
		return basePort + int(v4[len(v4)-1]), true
	}
	return 0, false
}
func requestBase(r *http.Request) (string, string) {
scheme := r.Header.Get("X-Forwarded-Proto")
if scheme == "" {
scheme = "http"
if r.TLS != nil {
scheme = "https"
}
}
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
if hostname, _, err := net.SplitHostPort(host); err == nil {
host = hostname
}
return scheme, host
}
// buildDirectURL formats "<scheme>://<host>:<port>/". net.JoinHostPort
// brackets IPv6 hosts.
func buildDirectURL(scheme, host string, port int) string {
	authority := net.JoinHostPort(host, strconv.Itoa(port))
	return scheme + "://" + authority + "/"
}
type (
	// webDeviceForwarder accepts TCP connections on listener and relays each
	// one to targetAddress. ip identifies the device and port is the
	// listening port.
	webDeviceForwarder struct {
		ip            string
		port          int
		targetAddress string
		listener      net.Listener
	}

	// WebDeviceForwarder exports the forwarder type, for example so that
	// ForwarderFactory implementations can name it outside this package.
	WebDeviceForwarder = webDeviceForwarder
)
// newWebDeviceForwarder opens a TCP listener on listenAddress and returns an
// unstarted forwarder around it. The caller must call serve.
//
// When port is 0, the port actually bound by the listener is recorded
// instead (for example when listenAddress ends in ":0"). A listen failure is
// returned unchanged.
func newWebDeviceForwarder(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
	ln, err := net.Listen("tcp", listenAddress)
	if err != nil {
		return nil, err
	}
	fwd := &webDeviceForwarder{
		ip:            ip,
		port:          port,
		targetAddress: targetAddress,
		listener:      ln,
	}
	if fwd.port == 0 {
		if bound, ok := ln.Addr().(*net.TCPAddr); ok {
			fwd.port = bound.Port
		}
	}
	return fwd, nil
}
// serve accepts connections on f.listener and hands each one to handle on its
// own goroutine. It returns immediately for a nil forwarder or listener.
//
// Temporary Accept failures are retried with exponential backoff, from 5ms up
// to a 1s cap, as net/http.Server.Serve does. EMFILE (descriptor exhaustion)
// is one such failure. Without the retry, a single transient error would
// stop the forwarder for good while it stays registered in Service. Any
// other error ends the loop; closing the listener is the normal way to stop
// it.
func (f *webDeviceForwarder) serve() {
	if f == nil || f.listener == nil {
		return
	}
	var backoff time.Duration
	for {
		clientConn, err := f.listener.Accept()
		if err != nil {
			//lint:ignore SA1019 Temporary is deprecated but is still what net/http uses to detect retryable Accept errors.
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				if backoff == 0 {
					backoff = 5 * time.Millisecond
				} else {
					backoff *= 2
				}
				if backoff > time.Second {
					backoff = time.Second
				}
				time.Sleep(backoff)
				continue
			}
			return
		}
		backoff = 0
		go f.handle(clientConn)
	}
}
// handle relays bytes between clientConn and a new connection to
// f.targetAddress.
//
// If the target cannot be dialled within 10 seconds, the client connection is
// closed and handle returns. Otherwise both directions are copied
// concurrently via ioCopy. As soon as either direction stops, both
// connections are closed, which also unblocks the other copy. The channel
// is buffered for two results so that second goroutine can still deliver
// its result and exit.
func (f *webDeviceForwarder) handle(clientConn net.Conn) {
	const dialTimeout = 10 * time.Second
	targetConn, err := net.DialTimeout("tcp", f.targetAddress, dialTimeout)
	if err != nil {
		_ = clientConn.Close()
		return
	}
	defer func() {
		_ = clientConn.Close()
		_ = targetConn.Close()
	}()

	done := make(chan error, 2)
	pump := func(dst, src net.Conn) {
		_, copyErr := ioCopy(dst, src)
		done <- copyErr
	}
	go pump(targetConn, clientConn)
	go pump(clientConn, targetConn)
	<-done
}
var ioCopy = func(dst net.Conn, src net.Conn) (int64, error) {
return copyConn(dst, src)
}
func copyConn(dst net.Conn, src net.Conn) (int64, error) {
return io.Copy(dst, src)
}
// defaultProxyTarget returns the HTTP base URL of the device's port 80,
// e.g. "http://192.168.1.10:80".
func defaultProxyTarget(ip string) string {
	return "http://" + defaultForwardTarget(ip)
}

// defaultForwardTarget returns the device's port-80 dial address in
// host:port form. IPv6 hosts are bracketed.
func defaultForwardTarget(ip string) string {
	const httpPort = "80"
	return net.JoinHostPort(ip, httpPort)
}
// defaultInterfaceGetter lists the machine's up, non-loopback interfaces with
// their first usable IPv4 address and dotted netmask.
//
// Interfaces are skipped when their name contains docker, veth, br-, tun or
// tap, or starts with lo; these are typically container and VPN interfaces.
// Interfaces whose addresses cannot be read are skipped too.
//
// IPv4 link-local addresses (169.254.0.0/16) are skipped. Scan discards every
// non-private device address, so a link-local subnet can never yield a
// result. Probing its ~65k hosts would only add a long, useless scan. If an
// interface has another IPv4 address, that one is used instead.
func defaultInterfaceGetter() ([]InterfaceInfo, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return nil, fmt.Errorf("获取网卡列表失败: %w", err)
	}
	var interfaces []InterfaceInfo
	for _, iface := range ifaces {
		if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
			continue
		}
		if strings.Contains(iface.Name, "docker") ||
			strings.Contains(iface.Name, "veth") ||
			strings.Contains(iface.Name, "br-") ||
			strings.Contains(iface.Name, "tun") ||
			strings.Contains(iface.Name, "tap") ||
			strings.HasPrefix(iface.Name, "lo") {
			continue
		}
		addrs, err := iface.Addrs()
		if err != nil {
			continue
		}
		for _, addr := range addrs {
			ipNet, ok := addr.(*net.IPNet)
			if !ok {
				continue
			}
			ip4 := ipNet.IP.To4()
			if ip4 == nil || ip4.IsLinkLocalUnicast() {
				continue
			}
			interfaces = append(interfaces, InterfaceInfo{
				Name:    iface.Name,
				IP:      ipNet.IP.String(),
				Netmask: net.IP(ipNet.Mask).String(),
			})
			break // one address per interface
		}
	}
	return interfaces, nil
}
// scanTCP probes every usable host address in the IPv4 subnet given by ip and
// netmask for an open TCP port. It returns the hosts that accepted a
// connection, sorted in numeric IPv4 order.
//
// The network and broadcast addresses are not probed. Neither is any address
// whose dotted string is a key in excludeIPs; a nil map excludes nothing. At
// most 20 dials run at once, each bounded by a 2 second timeout.
//
// The host range is walked with integer arithmetic, so the loop always
// terminates. A /32, where network == broadcast, yields no hosts instead of
// wrapping around the entire IPv4 space. Subnets wider than /16 are rejected
// with an error, because probing them at this concurrency would take hours.
func scanTCP(ip string, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
	const (
		maxScanHosts = 1<<16 - 2 // usable host addresses in a /16
		concurrency  = 20
		timeout      = 2 * time.Second
	)
	subnet, err := calculateIPRange(ip, netmask)
	if err != nil {
		return nil, err
	}
	if port <= 0 || port > 65535 {
		return nil, fmt.Errorf("无效的端口: %d", port)
	}
	start, end := subnet.Start.To4(), subnet.End.To4()
	if start == nil || end == nil {
		return nil, fmt.Errorf("无效的IP: %s", ip)
	}
	toUint := func(b net.IP) uint64 {
		return uint64(b[0])<<24 | uint64(b[1])<<16 | uint64(b[2])<<8 | uint64(b[3])
	}
	// first is the lowest host address. last is the broadcast address and is
	// excluded. uint64 keeps first = network+1 from overflowing.
	first, last := toUint(start)+1, toUint(end)
	if last > first && last-first > maxScanHosts {
		return nil, fmt.Errorf("子网过大，拒绝扫描: %s/%s", ip, netmask)
	}

	var (
		devices []TCPDevice
		mu      sync.Mutex
		wg      sync.WaitGroup
	)
	semaphore := make(chan struct{}, concurrency)
	for n := first; n < last; n++ {
		targetIP := net.IPv4(byte(n>>24), byte(n>>16), byte(n>>8), byte(n)).String()
		if excludeIPs[targetIP] {
			continue
		}
		wg.Add(1)
		semaphore <- struct{}{}
		go func(targetIP string) {
			defer wg.Done()
			defer func() { <-semaphore }()
			if scanTCPPort(targetIP, port, timeout) {
				mu.Lock()
				devices = append(devices, TCPDevice{IP: targetIP, Port: port})
				mu.Unlock()
			}
		}(targetIP)
	}
	wg.Wait()

	sort.Slice(devices, func(i, j int) bool {
		return ipv4ToUint32(devices[i].IP) < ipv4ToUint32(devices[j].IP)
	})
	return devices, nil
}
// ipRange is an inclusive IPv4 address range. Both ends are 4-byte IPs.
type ipRange struct {
	Start net.IP // network address
	End   net.IP // broadcast address
}

// calculateIPRange returns the network and broadcast addresses of the IPv4
// subnet containing ip under netmask. netmask must be a dotted IPv4 mask.
//
// An error is returned for each of these inputs:
//   - an ip that is not an IPv4 address (including real IPv6 addresses, which
//     previously caused an index-out-of-range panic);
//   - an unparsable netmask;
//   - a non-contiguous netmask (e.g. 255.0.255.0), which does not describe a
//     subnet.
func calculateIPRange(ip string, netmask string) (*ipRange, error) {
	addr := net.ParseIP(ip).To4()
	if addr == nil {
		return nil, fmt.Errorf("无效的IP: %s", ip)
	}
	mask := net.IPMask(net.ParseIP(netmask).To4())
	if mask == nil {
		return nil, fmt.Errorf("无效的子网掩码: %s", netmask)
	}
	if _, bits := mask.Size(); bits == 0 {
		// Size reports 0 bits for a non-canonical mask.
		return nil, fmt.Errorf("无效的子网掩码: %s", netmask)
	}
	network := addr.Mask(mask)
	broadcast := make(net.IP, net.IPv4len)
	for i := range broadcast {
		broadcast[i] = network[i] | ^mask[i]
	}
	return &ipRange{Start: network, End: broadcast}, nil
}
func scanTCPPort(ip string, port int, timeout time.Duration) bool {
if port <= 0 || port > 65535 {
return false
}
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", port))
conn, err := net.DialTimeout("tcp", addr, timeout)
if err != nil {
return false
}
_ = conn.Close()
return true
}
func incrementIP(ip net.IP) {
for i := len(ip) - 1; i >= 0; i-- {
ip[i]++
if ip[i] > 0 {
return
}
}
}
// ipv4ToUint32 converts a textual IPv4 address to its big-endian numeric
// value, for ordering. Anything that does not parse as IPv4 yields 0.
func ipv4ToUint32(value string) uint32 {
	octets := net.ParseIP(value).To4()
	if octets == nil {
		return 0
	}
	var n uint32
	for _, b := range octets {
		n = n<<8 | uint32(b)
	}
	return n
}