Files

698 lines
16 KiB
Go

package webdevice
import (
"context"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"sort"
"strconv"
"strings"
"sync"
"time"
)
// InterfaceInfo describes one IPv4 address of a network interface whose
// subnet is a candidate for scanning.
type InterfaceInfo struct {
	// Name is the interface name (or the synthetic "request-host").
	Name string `json:"name"`
	// IP is the interface address in dotted-quad form.
	IP string `json:"ip"`
	// Netmask is the subnet mask in dotted-quad form, e.g. "255.255.255.0".
	Netmask string `json:"netmask"`
}

// TCPDevice is a host that accepted a TCP connection on Port during a scan.
type TCPDevice struct {
	IP   string
	Port int
}

// DeviceInfo is one discovered web device as reported to the client.
type DeviceInfo struct {
	// IP is the device address.
	IP string `json:"ip"`
	// Interface is the name of the scanned interface the device was found on.
	Interface string `json:"interface"`
	// Port is the TCP port that answered the probe.
	Port int `json:"port"`
	// TargetURL is the device's own URL, "http://<ip>/".
	TargetURL string `json:"target_url"`
	// ProxyURL is the relative path "/proxy/web/<ip>/" on this server.
	ProxyURL string `json:"proxy_url"`
	// DirectURL points at the device's transparent forwarder on the host the
	// client used to reach this server; empty when no forwarder is available.
	DirectURL string `json:"direct_url,omitempty"`
	// ForwardPort is the forwarder's listening port; 0 when unavailable.
	ForwardPort int `json:"forward_port,omitempty"`
}

// ScanResult is the outcome of Service.Scan.
type ScanResult struct {
	// Interfaces are the interfaces whose subnets were scanned.
	Interfaces []InterfaceInfo `json:"interfaces"`
	// Devices are the discovered devices.
	Devices []DeviceInfo `json:"devices"`
	// Count is len(Devices).
	Count int `json:"count"`
	// Errors collects non-fatal problems (warnings, per-interface failures).
	Errors []string `json:"errors,omitempty"`
	// Message is a human-readable status, set when nothing could be scanned.
	Message string `json:"message,omitempty"`
}
// InterfaceGetter returns the interfaces whose subnets are candidates for scanning.
type InterfaceGetter func() ([]InterfaceInfo, error)

// TCPScanner probes the subnet of ip/netmask for hosts accepting TCP
// connections on port, skipping the addresses in excludeIPs. scanTCP is the
// default implementation.
type TCPScanner func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error)

// ForwarderFactory creates a forwarder for device ip that listens on
// listenAddress and relays to targetAddress; port is the port derived for ip.
// The caller (EnsureForwarder) starts it by calling serve.
// newWebDeviceForwarder is the default implementation.
type ForwarderFactory func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error)

// ProxyTargetResolver maps a device IP to an upstream base URL, returned by
// Service.ProxyTargetURL (presumably consumed by the /proxy/web/ handler).
type ProxyTargetResolver func(ip string) string
const (
	// webDeviceScanConcurrency caps the number of in-flight probes in scanTCP.
	webDeviceScanConcurrency = 128
	// webDeviceScanTimeout is the dial timeout of a single probe.
	webDeviceScanTimeout = 1500 * time.Millisecond
	// webDeviceScanAttempts is the number of dial attempts per address.
	webDeviceScanAttempts = 2
	// webDeviceScanRetryDelay is the pause between attempts on the same address.
	webDeviceScanRetryDelay = 100 * time.Millisecond
	// maxWebDeviceScanAddrs bounds a scan: subnets with more addresses are
	// narrowed to the /24 around the interface address (calculateIPRange).
	maxWebDeviceScanAddrs = 256
)
// Service discovers web devices on the local networks and keeps the state
// around them: which device IPs are allowed, one TCP forwarder per device and
// one proxy concurrency limiter per device. Its collaborators are function
// hooks that can be replaced through the Set* methods.
type Service struct {
	// mu guards allowed, forwarders and proxyLimiters. The hook fields below
	// are written by the Set* methods without taking mu.
	mu sync.RWMutex
	// allowed maps a device IP to the time it was last passed to allowIP.
	allowed map[string]time.Time
	// forwarders holds the running forwarder for each device IP.
	forwarders map[string]*webDeviceForwarder
	// proxyLimiters holds one buffered channel per device IP (see proxyLimiter).
	proxyLimiters map[string]chan struct{}
	// interfaceGetter lists this machine's/container's interfaces.
	interfaceGetter InterfaceGetter
	// hostLANGetter lists the host's LAN interfaces; preferred when it returns any.
	hostLANGetter InterfaceGetter
	tcpScanner    TCPScanner
	newForwarder  ForwarderFactory
	proxyTarget   ProxyTargetResolver
	// forwardTarget returns the host:port a forwarder for ip dials.
	forwardTarget func(ip string) string
}
// NewService returns a Service wired to the production implementations:
// local interface listing, Docker-based host LAN listing, the real TCP
// scanner, real TCP forwarders and the default port-80 targets.
func NewService() *Service {
	s := new(Service)

	// State maps.
	s.allowed = map[string]time.Time{}
	s.forwarders = map[string]*webDeviceForwarder{}
	s.proxyLimiters = map[string]chan struct{}{}

	// Replaceable collaborators.
	s.interfaceGetter = defaultInterfaceGetter
	s.hostLANGetter = defaultHostLANInterfaceGetter
	s.tcpScanner = scanTCP
	s.newForwarder = newWebDeviceForwarder
	s.proxyTarget = defaultProxyTarget
	s.forwardTarget = defaultForwardTarget

	return s
}
// Scan discovers web devices (hosts answering on TCP port 80) on the local
// networks and prepares access to them.
//
// The scanned interfaces come from scanInterfaces. Addresses of this machine
// and of the scanned interfaces are excluded. Every private IPv4 device found
// is added to the allow list. Each device is reported once, with:
//   - a proxy URL;
//   - a direct URL through its transparent forwarder, when one could be
//     started; the URL is built from the request's scheme and host.
//
// Non-fatal problems (host LAN warnings, failed interface scans, failed
// forwarders) are collected in ScanResult.Errors. Only a failure of the
// interface getter is returned as an error.
func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
	interfaces, err := s.interfaceGetter()
	if err != nil {
		return nil, err
	}
	// The request's scheme and host are the same for every device, so they
	// are resolved once.
	scheme, host := requestBase(r)
	scanInterfaces, scanWarnings := s.scanInterfaces(interfaces, host)
	if len(scanInterfaces) == 0 {
		return &ScanResult{
			Interfaces: []InterfaceInfo{},
			Devices:    []DeviceInfo{},
			Message:    "未找到有效的网卡",
			Errors:     scanWarnings,
		}, nil
	}

	// Never report our own (container or host) addresses as devices.
	excludeIPs := make(map[string]bool, len(interfaces)+len(scanInterfaces))
	for _, iface := range interfaces {
		excludeIPs[iface.IP] = true
	}
	for _, iface := range scanInterfaces {
		excludeIPs[iface.IP] = true
	}

	result := &ScanResult{
		Interfaces: scanInterfaces,
		Devices:    []DeviceInfo{},
		Errors:     scanWarnings,
	}
	// Several interfaces can cover the same subnet (e.g. two addresses on one
	// LAN). Report each device once, under the first interface that found it.
	reported := make(map[string]bool)
	for _, iface := range scanInterfaces {
		devices, scanErr := s.tcpScanner(iface.IP, iface.Netmask, 80, excludeIPs)
		if scanErr != nil {
			result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", iface.Name, scanErr))
			continue
		}
		for _, device := range devices {
			if !IsPrivateIPv4Literal(device.IP) || reported[device.IP] {
				continue
			}
			reported[device.IP] = true
			s.allowIP(device.IP)
			deviceInfo := DeviceInfo{
				IP:        device.IP,
				Interface: iface.Name,
				Port:      device.Port,
				TargetURL: fmt.Sprintf("http://%s/", device.IP),
				ProxyURL:  fmt.Sprintf("/proxy/web/%s/", device.IP),
			}
			forwardPort, forwardErr := s.EnsureForwarder(device.IP)
			if forwardErr != nil {
				// ProxyURL is built without the forwarder (the proxy presumably
				// targets ProxyTargetURL(ip) directly), so the device is still
				// listed. Only the omitempty direct fields are left unset.
				result.Errors = append(result.Errors, fmt.Sprintf("%s: 启动%s透明转发失败: %v", iface.Name, device.IP, forwardErr))
			} else {
				deviceInfo.DirectURL = buildDirectURL(scheme, host, forwardPort)
				deviceInfo.ForwardPort = forwardPort
			}
			result.Devices = append(result.Devices, deviceInfo)
		}
	}
	sort.Slice(result.Devices, func(i, j int) bool {
		return ipv4ToUint32(result.Devices[i].IP) < ipv4ToUint32(result.Devices[j].IP)
	})
	result.Count = len(result.Devices)
	return result, nil
}
// scanInterfaces chooses the interfaces whose subnets should be scanned.
//
// The host LAN interfaces win when that getter is configured, succeeds and
// returns at least one interface. Otherwise the container interfaces are used,
// extended with the request host when it is a private IPv4 address. A host
// getter error becomes a warning instead of failing the scan.
func (s *Service) scanInterfaces(containerInterfaces []InterfaceInfo, requestHost string) ([]InterfaceInfo, []string) {
	var warnings []string
	if getHostLAN := s.hostLANGetter; getHostLAN != nil {
		hostInterfaces, err := getHostLAN()
		switch {
		case err != nil:
			warnings = []string{"宿主机局域网探测失败,已回退到请求地址网段: " + err.Error()}
		case len(hostInterfaces) > 0:
			return hostInterfaces, nil
		}
	}
	return appendRequestHostInterface(containerInterfaces, requestHost), warnings
}
// appendRequestHostInterface adds a synthetic "request-host" interface with a
// /24 netmask for host, unless host is not a private IPv4 literal or is
// already the IP of one of interfaces. In those cases interfaces is returned
// unchanged.
//
// NOTE(review): the result is built with append, so it may share (and write
// into spare capacity of) the caller's backing array.
func appendRequestHostInterface(interfaces []InterfaceInfo, host string) []InterfaceInfo {
	if !IsPrivateIPv4Literal(host) {
		return interfaces
	}
	known := false
	for i := range interfaces {
		if interfaces[i].IP == host {
			known = true
			break
		}
	}
	if known {
		return interfaces
	}
	synthetic := InterfaceInfo{Name: "request-host", IP: host, Netmask: "255.255.255.0"}
	return append(interfaces, synthetic)
}
// allowIP marks ip as allowed, recording the current time.
func (s *Service) allowIP(ip string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	stamp := time.Now()
	s.allowed[ip] = stamp
}

// AllowIP is the exported form of allowIP: it marks ip as allowed.
func (s *Service) AllowIP(ip string) { s.allowIP(ip) }

// IsAllowed reports whether ip has ever been marked as allowed.
func (s *Service) IsAllowed(ip string) (allowed bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	_, allowed = s.allowed[ip]
	return allowed
}
// SetInterfaceGetter replaces the local interface getter. nil is ignored.
// NOTE(review): setters do not take s.mu; presumably meant for setup before use.
func (s *Service) SetInterfaceGetter(getter InterfaceGetter) {
	if getter == nil {
		return
	}
	s.interfaceGetter = getter
}

// SetHostLANGetter replaces the host LAN interface getter. nil is ignored.
func (s *Service) SetHostLANGetter(getter InterfaceGetter) {
	if getter == nil {
		return
	}
	s.hostLANGetter = getter
}

// SetTCPScanner replaces the subnet scanner. nil is ignored.
func (s *Service) SetTCPScanner(scanner TCPScanner) {
	if scanner == nil {
		return
	}
	s.tcpScanner = scanner
}

// SetForwarderFactory replaces the forwarder constructor. nil is ignored.
func (s *Service) SetForwarderFactory(factory ForwarderFactory) {
	if factory == nil {
		return
	}
	s.newForwarder = factory
}

// SetProxyTargetResolver replaces the proxy target resolver. nil is ignored.
func (s *Service) SetProxyTargetResolver(resolver ProxyTargetResolver) {
	if resolver == nil {
		return
	}
	s.proxyTarget = resolver
}
// EnsureForwarder starts, at most once per ip, a transparent TCP forwarder
// and returns its listening port.
//
// How the forwarder is set up:
//   - The port comes from WebDeviceForwardPort.
//   - The forwarder listens on 0.0.0.0:<port>.
//   - It dials forwardTarget(ip), falling back to defaultForwardTarget.
//
// An existing forwarder for ip is reused. The whole operation runs under s.mu,
// so concurrent calls for the same ip start at most one forwarder. An invalid
// IPv4 address or a factory error is returned as an error.
//
// NOTE: WebDeviceForwardPort depends only on the last octet. Devices on
// different subnets with the same last octet therefore compete for one port;
// with the default factory the second one fails because the port is taken.
func (s *Service) EnsureForwarder(ip string) (int, error) {
	port, ok := WebDeviceForwardPort(ip)
	if !ok {
		return 0, fmt.Errorf("无效的IPv4地址")
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	// Tolerate a zero-value Service the same way proxyLimiter (lazy map) and
	// ProxyTargetURL (default fallback) do, instead of panicking on a nil map
	// or a nil factory.
	if s.forwarders == nil {
		s.forwarders = make(map[string]*webDeviceForwarder)
	}
	if forwarder, ok := s.forwarders[ip]; ok {
		return forwarder.port, nil
	}
	targetAddress := defaultForwardTarget(ip)
	if s.forwardTarget != nil {
		targetAddress = s.forwardTarget(ip)
	}
	newForwarder := s.newForwarder
	if newForwarder == nil {
		newForwarder = newWebDeviceForwarder
	}
	forwarder, err := newForwarder(ip, port, net.JoinHostPort("0.0.0.0", strconv.Itoa(port)), targetAddress)
	if err != nil {
		return 0, err
	}
	s.forwarders[ip] = forwarder
	// Nothing in this function stops the forwarder; it keeps serving after return.
	go forwarder.serve()
	return port, nil
}
// ProxyTargetURL returns the upstream base URL for device ip. It uses the
// configured resolver, or defaultProxyTarget when none is set.
func (s *Service) ProxyTargetURL(ip string) string {
	resolve := s.proxyTarget
	if resolve == nil {
		resolve = defaultProxyTarget
	}
	return resolve(ip)
}
// proxyLimiter returns the buffered channel (capacity
// webDeviceProxyConcurrency) associated with ip. The channel, and the map
// itself, are created on first use. Presumably callers use it as a counting
// semaphore to cap concurrent proxied requests per device.
func (s *Service) proxyLimiter(ip string) chan struct{} {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.proxyLimiters == nil {
		s.proxyLimiters = map[string]chan struct{}{}
	}
	if existing, ok := s.proxyLimiters[ip]; ok && existing != nil {
		return existing
	}
	created := make(chan struct{}, webDeviceProxyConcurrency)
	s.proxyLimiters[ip] = created
	return created
}
// IsPrivateIPv4Literal reports whether value is an IPv4 literal (or an
// IPv4-mapped IPv6 literal) in an RFC 1918 private range: 10/8, 172.16/12
// or 192.168/16.
func IsPrivateIPv4Literal(value string) bool {
	// To4 on a nil IP (unparsable input) is nil.
	ip4 := net.ParseIP(value).To4()
	if ip4 == nil {
		return false
	}
	// The RFC 1918 ranges never overlap loopback (127/8), multicast (224/4)
	// or 0.0.0.0, so IsPrivate alone decides.
	return ip4.IsPrivate()
}
// WebDeviceForwardPort returns the forwarder port for an IPv4 address:
// 31000 plus its last octet (31000-31255). ok is false if ip is not an IPv4
// (or IPv4-mapped) literal.
//
// NOTE: only the last octet is used, so addresses on different subnets that
// share it map to the same port.
func WebDeviceForwardPort(ip string) (int, bool) {
	const basePort = 31000
	ip4 := net.ParseIP(ip).To4()
	if ip4 == nil {
		return 0, false
	}
	lastOctet := ip4[len(ip4)-1]
	return basePort + int(lastOctet), true
}
// requestBase returns the scheme and the bare hostname (no port, no IPv6
// brackets) the client used to reach this server. It honours the
// X-Forwarded-Proto and X-Forwarded-Host headers set by a reverse proxy.
//
// Header handling:
//   - Both headers may hold a comma-separated list when proxies are chained;
//     the first entry (closest to the client) is used.
//   - The scheme is echoed into URLs returned to the browser and the header
//     is client-controlled, so only "http" and "https" are accepted.
//     Anything else falls back to the connection's own scheme.
//   - IPv6 brackets are stripped so net.JoinHostPort can add them back
//     exactly once.
func requestBase(r *http.Request) (string, string) {
	scheme := strings.ToLower(firstForwardedValue(r.Header.Get("X-Forwarded-Proto")))
	if scheme != "http" && scheme != "https" {
		scheme = "http"
		if r.TLS != nil {
			scheme = "https"
		}
	}
	host := firstForwardedValue(r.Header.Get("X-Forwarded-Host"))
	if host == "" {
		host = r.Host
	}
	if hostname, _, err := net.SplitHostPort(host); err == nil {
		host = hostname
	} else if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		// Bracketed IPv6 literal without a port, e.g. "[::1]".
		host = host[1 : len(host)-1]
	}
	return scheme, host
}

// firstForwardedValue returns the first comma-separated element of a
// forwarding header value, trimmed of surrounding whitespace.
func firstForwardedValue(value string) string {
	if i := strings.IndexByte(value, ','); i >= 0 {
		value = value[:i]
	}
	return strings.TrimSpace(value)
}
// buildDirectURL returns "<scheme>://<host>:<port>/". net.JoinHostPort
// brackets IPv6 hosts.
func buildDirectURL(scheme, host string, port int) string {
	authority := net.JoinHostPort(host, strconv.Itoa(port))
	return scheme + "://" + authority + "/"
}
// webDeviceForwarder accepts TCP connections on listener and relays each one
// to targetAddress (see serve and handle).
type webDeviceForwarder struct {
	ip            string // device IP this forwarder was created for
	port          int    // port reported to callers
	targetAddress string // host:port dialed for every accepted connection
	listener      net.Listener
}

// WebDeviceForwarder is an exported alias of webDeviceForwarder.
type WebDeviceForwarder = webDeviceForwarder

// newWebDeviceForwarder opens a TCP listener on listenAddress and returns a
// forwarder for it; it does not start serving. When port is 0, the port the
// listener actually bound is recorded instead.
func newWebDeviceForwarder(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
	ln, err := net.Listen("tcp", listenAddress)
	if err != nil {
		return nil, err
	}
	fwd := &webDeviceForwarder{
		ip:            ip,
		port:          port,
		targetAddress: targetAddress,
		listener:      ln,
	}
	if fwd.port == 0 {
		if bound, ok := ln.Addr().(*net.TCPAddr); ok {
			fwd.port = bound.Port
		}
	}
	return fwd, nil
}
// serve accepts connections until the listener fails permanently (for
// example because it was closed). Each connection is handled in its own
// goroutine.
//
// Temporary accept errors (e.g. running out of file descriptors during a
// burst) are retried with a capped exponential backoff, as net/http's
// Server.Serve does. Returning on them would silently kill the forwarder while
// the Service keeps advertising its port. A nil forwarder or listener makes
// serve return immediately.
func (f *webDeviceForwarder) serve() {
	if f == nil || f.listener == nil {
		return
	}
	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration
	for {
		clientConn, err := f.listener.Accept()
		if err != nil {
			// Probe Temporary through an anonymous interface (net.Error's
			// method is deprecated). A closed listener reports false.
			if te, ok := err.(interface{ Temporary() bool }); ok && te.Temporary() {
				if backoff == 0 {
					backoff = minBackoff
				} else {
					backoff *= 2
				}
				if backoff > maxBackoff {
					backoff = maxBackoff
				}
				time.Sleep(backoff)
				continue
			}
			return
		}
		backoff = 0
		go f.handle(clientConn)
	}
}
// handle dials targetAddress (10s timeout) and copies bytes between
// clientConn and the target in both directions via ioCopy. As soon as either
// direction stops, both connections are closed, which also unblocks the other
// copy. If the dial fails, clientConn is closed and nothing else happens.
//
// NOTE(review): because the first direction to finish closes both sockets, a
// client that half-closes its write side right after sending a request can
// have the response cut off.
func (f *webDeviceForwarder) handle(clientConn net.Conn) {
	targetConn, dialErr := net.DialTimeout("tcp", f.targetAddress, 10*time.Second)
	if dialErr != nil {
		_ = clientConn.Close() // best effort; there is no one to report to
		return
	}

	// Room for both results, so the copy that finishes second never blocks
	// on its send after we stopped receiving.
	finished := make(chan error, 2)
	pipe := func(dst, src net.Conn) {
		_, copyErr := ioCopy(dst, src)
		finished <- copyErr
	}
	go pipe(targetConn, clientConn)
	go pipe(clientConn, targetConn)

	<-finished
	_ = clientConn.Close()
	_ = targetConn.Close()
}
// ioCopy is the copy routine handle uses for each direction. It is a variable
// so it can be replaced (presumably by tests).
var ioCopy = copyConn

// copyConn streams src into dst until EOF or an error (see io.Copy) and
// returns the number of bytes written.
func copyConn(dst net.Conn, src net.Conn) (written int64, err error) {
	written, err = io.Copy(dst, src)
	return written, err
}
// defaultProxyTarget returns the base URL of ip's port-80 HTTP server,
// e.g. "http://192.168.1.2:80".
func defaultProxyTarget(ip string) string {
	return "http://" + defaultForwardTarget(ip)
}

// defaultForwardTarget returns ip's port 80 as a host:port dial address
// (IPv6 hosts are bracketed).
func defaultForwardTarget(ip string) string {
	const httpPort = "80"
	return net.JoinHostPort(ip, httpPort)
}
// defaultInterfaceGetter lists this machine's (or container's) interfaces
// that are up and not loopback, each with its first IPv4 address.
//
// Virtual, VPN and overlay interfaces are skipped using ignoredInterfaceName,
// the same case-insensitive filter applied to the host LAN listing, so both
// discovery paths agree on what is worth scanning. Names starting with "lo"
// are also skipped. Interfaces whose addresses cannot be read, or that have no
// IPv4 address, are left out.
func defaultInterfaceGetter() ([]InterfaceInfo, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return nil, fmt.Errorf("获取网卡列表失败: %w", err)
	}
	var interfaces []InterfaceInfo
	for _, iface := range ifaces {
		if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
			continue
		}
		if strings.HasPrefix(iface.Name, "lo") || ignoredInterfaceName(iface.Name) {
			continue
		}
		addrs, err := iface.Addrs()
		if err != nil {
			// Best effort: an unreadable interface is simply not scanned.
			continue
		}
		for _, addr := range addrs {
			ipNet, ok := addr.(*net.IPNet)
			if !ok || ipNet.IP.To4() == nil {
				continue
			}
			interfaces = append(interfaces, InterfaceInfo{
				Name:    iface.Name,
				IP:      ipNet.IP.String(),
				Netmask: net.IP(ipNet.Mask).String(),
			})
			break // only the first IPv4 address per interface
		}
	}
	return interfaces, nil
}
// defaultHostLANInterfaceGetter lists the Docker host's private LAN
// interfaces. It runs `ip -o -4 addr show scope global` in a throwaway
// container on the host network namespace (`docker run --rm --network host`).
//
// Requirements and behaviour:
//   - The image comes from MANAGED_PORTAL_HOST_SCAN_IMAGE and must provide
//     /sbin/ip.
//   - When the variable is unset or blank, the getter is disabled and
//     returns (nil, nil).
//   - The command is killed after 10 seconds.
//   - Failures include the command's combined output.
func defaultHostLANInterfaceGetter() ([]InterfaceInfo, error) {
	image := strings.TrimSpace(os.Getenv("MANAGED_PORTAL_HOST_SCAN_IMAGE"))
	if image == "" {
		return nil, nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cmd := exec.CommandContext(ctx, "docker",
		"run", "--rm",
		"--network", "host",
		"--entrypoint", "/sbin/ip",
		image,
		"-o", "-4", "addr", "show", "scope", "global",
	)
	output, err := cmd.CombinedOutput()
	if err != nil {
		// Blame the deadline only when the command actually failed. A run that
		// completed just before the deadline still produced usable output.
		if ctx.Err() == context.DeadlineExceeded {
			return nil, fmt.Errorf("读取宿主机网卡超时")
		}
		return nil, fmt.Errorf("读取宿主机网卡失败: %w: %s", err, strings.TrimSpace(string(output)))
	}
	return parseHostLANInterfaces(string(output)), nil
}
// parseHostLANInterfaces turns `ip -o -4 addr show` output into private IPv4
// interfaces, one per unique address. A line is expected to look roughly like
//
//	2: eth0    inet 192.168.1.10/24 brd 192.168.1.255 scope global eth0 ...
//
// A line is skipped when:
//   - it has fewer than 4 fields;
//   - the interface name is ignored by ignoredInterfaceName;
//   - it has no parsable IPv4 CIDR after the first "inet";
//   - the address is not private, or was already seen;
//   - the mask is not 32 bits wide.
func parseHostLANInterfaces(output string) []InterfaceInfo {
	var found []InterfaceInfo
	seen := map[string]bool{}
	for _, line := range strings.Split(output, "\n") {
		fields := strings.Fields(line)
		if len(fields) < 4 {
			continue
		}
		name := strings.TrimSuffix(fields[1], ":")
		if ignoredInterfaceName(name) {
			continue
		}

		// The CIDR follows the first "inet" token; "" means none was found.
		cidr := ""
		for i := 0; i+1 < len(fields); i++ {
			if fields[i] == "inet" {
				cidr = fields[i+1]
				break
			}
		}
		addr, subnet, err := net.ParseCIDR(cidr)
		if err != nil || addr.To4() == nil || subnet == nil {
			continue
		}

		addrText := addr.String()
		if seen[addrText] || !IsPrivateIPv4Literal(addrText) {
			continue
		}
		ones, bits := subnet.Mask.Size()
		if bits != 32 {
			continue
		}
		found = append(found, InterfaceInfo{
			Name:    name,
			IP:      addrText,
			Netmask: net.IP(net.CIDRMask(ones, bits)).String(),
		})
		seen[addrText] = true
	}
	return found
}
// ignoredInterfaceName reports whether an interface name (compared
// case-insensitively) belongs to a virtual, VPN or overlay interface that
// should not be scanned. It matches:
//   - exactly "lo";
//   - names starting with "zt";
//   - names containing a known marker (docker, veth, br-, tun, tap, wg,
//     tailscale, utun, ppp, zerotier).
func ignoredInterfaceName(name string) bool {
	lowered := strings.ToLower(name)
	if lowered == "lo" || strings.HasPrefix(lowered, "zt") {
		return true
	}
	markers := []string{
		"docker", "veth", "br-", "tun", "tap",
		"wg", "tailscale", "utun", "ppp", "zerotier",
	}
	for _, marker := range markers {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// scanTCP probes the host addresses of the subnet of ip/netmask (as bounded
// by calculateIPRange) for an open TCP port. Addresses in excludeIPs are
// skipped. Responders are returned sorted by address.
//
// Which addresses are probed:
//   - For ordinary subnets, the network and broadcast addresses are skipped.
//   - /31 and /32 ranges have no such addresses (RFC 3021), so every address
//     in them is probed.
//
// Iteration is bounded numerically, so a single-address range cannot run
// past its end. Up to webDeviceScanConcurrency probes run at once, each with
// retries (scanTCPPortWithRetry). An invalid ip, netmask or port is returned
// as an error.
func scanTCP(ip string, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
	ipRange, err := calculateIPRange(ip, netmask)
	if err != nil {
		return nil, err
	}
	if port <= 0 || port > 65535 {
		return nil, fmt.Errorf("无效的端口: %d", port)
	}

	first, last := ipToUint32(ipRange.Start), ipToUint32(ipRange.End)
	if last >= first && last-first >= 2 {
		// Skip the network and broadcast addresses.
		first++
		last--
	}

	var (
		devices   []TCPDevice
		mu        sync.Mutex
		wg        sync.WaitGroup
		semaphore = make(chan struct{}, webDeviceScanConcurrency)
	)
	// uint64 so the loop cannot wrap around at 255.255.255.255.
	for n := uint64(first); n <= uint64(last); n++ {
		targetIP := net.IPv4(byte(n>>24), byte(n>>16), byte(n>>8), byte(n)).String()
		if excludeIPs[targetIP] {
			continue
		}
		wg.Add(1)
		semaphore <- struct{}{}
		go func(targetIP string) {
			defer wg.Done()
			defer func() { <-semaphore }()
			if scanTCPPortWithRetry(targetIP, port) {
				mu.Lock()
				devices = append(devices, TCPDevice{IP: targetIP, Port: port})
				mu.Unlock()
			}
		}(targetIP)
	}
	wg.Wait()
	sort.Slice(devices, func(i, j int) bool {
		return ipv4ToUint32(devices[i].IP) < ipv4ToUint32(devices[j].IP)
	})
	return devices, nil
}
// ipRange is an inclusive range of IPv4 addresses, from a subnet's network
// address (Start) to its broadcast address (End). Both are 4-byte IPs.
type ipRange struct {
	Start net.IP
	End   net.IP
}

// calculateIPRange returns the network..broadcast range of ip/netmask. When
// that subnet has more than maxWebDeviceScanAddrs addresses, it returns the
// /24 containing ip instead, so a scan stays bounded.
//
// Only IPv4 is supported. Any other ip is rejected as invalid; masking an
// IPv6 address with a 4-byte mask would yield a nil network. The size check
// compares the span End-Start, not End-Start+1, because the latter overflows
// uint32 to 0 for a 0.0.0.0 netmask and would admit the whole address space.
func calculateIPRange(ip string, netmask string) (*ipRange, error) {
	ip4 := net.ParseIP(ip).To4()
	if ip4 == nil {
		return nil, fmt.Errorf("无效的IP: %s", ip)
	}
	mask := net.IPMask(net.ParseIP(netmask).To4())
	if mask == nil {
		return nil, fmt.Errorf("无效的子网掩码: %s", netmask)
	}
	network, broadcast := ipv4NetworkAndBroadcast(ip4, mask)
	if ipToUint32(broadcast)-ipToUint32(network) < maxWebDeviceScanAddrs {
		return &ipRange{Start: network, End: broadcast}, nil
	}
	network, broadcast = ipv4NetworkAndBroadcast(ip4, net.CIDRMask(24, 32))
	return &ipRange{Start: network, End: broadcast}, nil
}

// ipv4NetworkAndBroadcast returns the network and broadcast addresses of the
// 4-byte ip4 under the 4-byte mask. The computation is byte-wise, so
// non-contiguous masks still give a well-defined result.
func ipv4NetworkAndBroadcast(ip4 net.IP, mask net.IPMask) (net.IP, net.IP) {
	network := ip4.Mask(mask)
	broadcast := make(net.IP, len(network))
	for i := range network {
		broadcast[i] = network[i] | ^mask[i]
	}
	return network, broadcast
}
// scanTCPPortWithRetry probes ip:port up to webDeviceScanAttempts times,
// each with webDeviceScanTimeout, pausing webDeviceScanRetryDelay between
// attempts. It reports whether any attempt connected.
func scanTCPPortWithRetry(ip string, port int) bool {
	for remaining := webDeviceScanAttempts; remaining > 0; remaining-- {
		if scanTCPPort(ip, port, webDeviceScanTimeout) {
			return true
		}
		if remaining > 1 {
			time.Sleep(webDeviceScanRetryDelay)
		}
	}
	return false
}
// scanTCPPort reports whether a TCP connection to ip:port can be opened
// within timeout. The connection is closed immediately. Ports outside
// 1-65535 are never probed.
func scanTCPPort(ip string, port int, timeout time.Duration) bool {
	if port < 1 || port > 65535 {
		return false
	}
	conn, err := net.DialTimeout("tcp", net.JoinHostPort(ip, strconv.Itoa(port)), timeout)
	if err == nil {
		_ = conn.Close() // only reachability matters
	}
	return err == nil
}
// incrementIP adds one to ip in place, treating it as a big-endian integer.
// The all-ones address wraps around to all zeros.
func incrementIP(ip net.IP) {
	idx := len(ip) - 1
	for idx >= 0 {
		ip[idx]++
		if ip[idx] != 0 {
			return // no carry
		}
		idx--
	}
}
// ipv4ToUint32 parses value and returns its IPv4 address as a big-endian
// uint32. Unparsable or non-IPv4 input yields 0.
func ipv4ToUint32(value string) uint32 {
	// ipToUint32 already maps a nil IP (failed parse) to 0.
	return ipToUint32(net.ParseIP(value))
}

// ipToUint32 returns the IPv4 form of parsed as a big-endian uint32, or 0 if
// parsed has no IPv4 form.
func ipToUint32(parsed net.IP) uint32 {
	ip4 := parsed.To4()
	if ip4 == nil {
		return 0
	}
	var packed uint32
	for _, octet := range ip4 {
		packed = packed<<8 | uint32(octet)
	}
	return packed
}