673 lines
15 KiB
Go
673 lines
15 KiB
Go
package webdevice
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type InterfaceInfo struct {
|
|
Name string `json:"name"`
|
|
IP string `json:"ip"`
|
|
Netmask string `json:"netmask"`
|
|
}
|
|
|
|
type TCPDevice struct {
|
|
IP string
|
|
Port int
|
|
}
|
|
|
|
type DeviceInfo struct {
|
|
IP string `json:"ip"`
|
|
Interface string `json:"interface"`
|
|
Port int `json:"port"`
|
|
TargetURL string `json:"target_url"`
|
|
ProxyURL string `json:"proxy_url"`
|
|
DirectURL string `json:"direct_url,omitempty"`
|
|
ForwardPort int `json:"forward_port,omitempty"`
|
|
}
|
|
|
|
type ScanResult struct {
|
|
Interfaces []InterfaceInfo `json:"interfaces"`
|
|
Devices []DeviceInfo `json:"devices"`
|
|
Count int `json:"count"`
|
|
Errors []string `json:"errors,omitempty"`
|
|
Message string `json:"message,omitempty"`
|
|
}
|
|
|
|
type InterfaceGetter func() ([]InterfaceInfo, error)
|
|
type TCPScanner func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error)
|
|
type ForwarderFactory func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error)
|
|
type ProxyTargetResolver func(ip string) string
|
|
|
|
const (
|
|
webDeviceScanConcurrency = 128
|
|
webDeviceScanTimeout = 1500 * time.Millisecond
|
|
webDeviceScanAttempts = 2
|
|
webDeviceScanRetryDelay = 100 * time.Millisecond
|
|
maxWebDeviceScanAddrs = 256
|
|
)
|
|
|
|
type Service struct {
|
|
mu sync.RWMutex
|
|
allowed map[string]time.Time
|
|
forwarders map[string]*webDeviceForwarder
|
|
interfaceGetter InterfaceGetter
|
|
hostLANGetter InterfaceGetter
|
|
tcpScanner TCPScanner
|
|
newForwarder ForwarderFactory
|
|
proxyTarget ProxyTargetResolver
|
|
forwardTarget func(ip string) string
|
|
}
|
|
|
|
func NewService() *Service {
|
|
return &Service{
|
|
allowed: make(map[string]time.Time),
|
|
forwarders: make(map[string]*webDeviceForwarder),
|
|
interfaceGetter: defaultInterfaceGetter,
|
|
hostLANGetter: defaultHostLANInterfaceGetter,
|
|
tcpScanner: scanTCP,
|
|
newForwarder: newWebDeviceForwarder,
|
|
proxyTarget: defaultProxyTarget,
|
|
forwardTarget: defaultForwardTarget,
|
|
}
|
|
}
|
|
|
|
func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
|
|
interfaces, err := s.interfaceGetter()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
_, host := requestBase(r)
|
|
scanInterfaces, scanWarnings := s.scanInterfaces(interfaces, host)
|
|
|
|
if len(scanInterfaces) == 0 {
|
|
return &ScanResult{
|
|
Interfaces: []InterfaceInfo{},
|
|
Devices: []DeviceInfo{},
|
|
Message: "未找到有效的网卡",
|
|
Errors: scanWarnings,
|
|
}, nil
|
|
}
|
|
|
|
excludeIPs := make(map[string]bool)
|
|
for _, iface := range interfaces {
|
|
excludeIPs[iface.IP] = true
|
|
}
|
|
for _, iface := range scanInterfaces {
|
|
excludeIPs[iface.IP] = true
|
|
}
|
|
|
|
result := &ScanResult{
|
|
Interfaces: scanInterfaces,
|
|
Devices: []DeviceInfo{},
|
|
Errors: scanWarnings,
|
|
}
|
|
|
|
for _, iface := range scanInterfaces {
|
|
devices, scanErr := s.tcpScanner(iface.IP, iface.Netmask, 80, excludeIPs)
|
|
if scanErr != nil {
|
|
result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", iface.Name, scanErr))
|
|
continue
|
|
}
|
|
|
|
for _, device := range devices {
|
|
if !IsPrivateIPv4Literal(device.IP) {
|
|
continue
|
|
}
|
|
|
|
s.allowIP(device.IP)
|
|
|
|
deviceInfo := DeviceInfo{
|
|
IP: device.IP,
|
|
Interface: iface.Name,
|
|
Port: device.Port,
|
|
TargetURL: fmt.Sprintf("http://%s/", device.IP),
|
|
ProxyURL: fmt.Sprintf("/proxy/web/%s/", device.IP),
|
|
}
|
|
result.Devices = append(result.Devices, deviceInfo)
|
|
}
|
|
}
|
|
|
|
sort.Slice(result.Devices, func(i, j int) bool {
|
|
return ipv4ToUint32(result.Devices[i].IP) < ipv4ToUint32(result.Devices[j].IP)
|
|
})
|
|
result.Count = len(result.Devices)
|
|
return result, nil
|
|
}
|
|
|
|
func (s *Service) scanInterfaces(containerInterfaces []InterfaceInfo, requestHost string) ([]InterfaceInfo, []string) {
|
|
if s.hostLANGetter != nil {
|
|
hostInterfaces, err := s.hostLANGetter()
|
|
if err == nil && len(hostInterfaces) > 0 {
|
|
return hostInterfaces, nil
|
|
}
|
|
if err != nil {
|
|
return appendRequestHostInterface(containerInterfaces, requestHost), []string{
|
|
"宿主机局域网探测失败,已回退到请求地址网段: " + err.Error(),
|
|
}
|
|
}
|
|
}
|
|
return appendRequestHostInterface(containerInterfaces, requestHost), nil
|
|
}
|
|
|
|
func appendRequestHostInterface(interfaces []InterfaceInfo, host string) []InterfaceInfo {
|
|
if !IsPrivateIPv4Literal(host) {
|
|
return interfaces
|
|
}
|
|
for _, iface := range interfaces {
|
|
if iface.IP == host {
|
|
return interfaces
|
|
}
|
|
}
|
|
return append(interfaces, InterfaceInfo{
|
|
Name: "request-host",
|
|
IP: host,
|
|
Netmask: "255.255.255.0",
|
|
})
|
|
}
|
|
|
|
func (s *Service) allowIP(ip string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.allowed[ip] = time.Now()
|
|
}
|
|
|
|
func (s *Service) AllowIP(ip string) {
|
|
s.allowIP(ip)
|
|
}
|
|
|
|
func (s *Service) IsAllowed(ip string) bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
_, ok := s.allowed[ip]
|
|
return ok
|
|
}
|
|
|
|
func (s *Service) SetInterfaceGetter(getter InterfaceGetter) {
|
|
if getter != nil {
|
|
s.interfaceGetter = getter
|
|
}
|
|
}
|
|
|
|
func (s *Service) SetHostLANGetter(getter InterfaceGetter) {
|
|
if getter != nil {
|
|
s.hostLANGetter = getter
|
|
}
|
|
}
|
|
|
|
func (s *Service) SetTCPScanner(scanner TCPScanner) {
|
|
if scanner != nil {
|
|
s.tcpScanner = scanner
|
|
}
|
|
}
|
|
|
|
func (s *Service) SetForwarderFactory(factory ForwarderFactory) {
|
|
if factory != nil {
|
|
s.newForwarder = factory
|
|
}
|
|
}
|
|
|
|
func (s *Service) SetProxyTargetResolver(resolver ProxyTargetResolver) {
|
|
if resolver != nil {
|
|
s.proxyTarget = resolver
|
|
}
|
|
}
|
|
|
|
func (s *Service) EnsureForwarder(ip string) (int, error) {
|
|
port, ok := WebDeviceForwardPort(ip)
|
|
if !ok {
|
|
return 0, fmt.Errorf("无效的IPv4地址")
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if forwarder, ok := s.forwarders[ip]; ok {
|
|
return forwarder.port, nil
|
|
}
|
|
|
|
targetAddress := defaultForwardTarget(ip)
|
|
if s.forwardTarget != nil {
|
|
targetAddress = s.forwardTarget(ip)
|
|
}
|
|
|
|
forwarder, err := s.newForwarder(ip, port, net.JoinHostPort("0.0.0.0", strconv.Itoa(port)), targetAddress)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
s.forwarders[ip] = forwarder
|
|
go forwarder.serve()
|
|
return port, nil
|
|
}
|
|
|
|
func (s *Service) ProxyTargetURL(ip string) string {
|
|
if s.proxyTarget == nil {
|
|
return defaultProxyTarget(ip)
|
|
}
|
|
return s.proxyTarget(ip)
|
|
}
|
|
|
|
func IsPrivateIPv4Literal(value string) bool {
|
|
ip := net.ParseIP(value)
|
|
if ip == nil {
|
|
return false
|
|
}
|
|
ip4 := ip.To4()
|
|
if ip4 == nil {
|
|
return false
|
|
}
|
|
return ip4.IsPrivate() && !ip4.IsLoopback() && !ip4.IsMulticast() && !ip4.IsUnspecified()
|
|
}
|
|
|
|
func WebDeviceForwardPort(ip string) (int, bool) {
|
|
parsed := net.ParseIP(ip)
|
|
if parsed == nil {
|
|
return 0, false
|
|
}
|
|
ip4 := parsed.To4()
|
|
if ip4 == nil {
|
|
return 0, false
|
|
}
|
|
return 31000 + int(ip4[3]), true
|
|
}
|
|
|
|
func requestBase(r *http.Request) (string, string) {
|
|
scheme := r.Header.Get("X-Forwarded-Proto")
|
|
if scheme == "" {
|
|
scheme = "http"
|
|
if r.TLS != nil {
|
|
scheme = "https"
|
|
}
|
|
}
|
|
|
|
host := r.Header.Get("X-Forwarded-Host")
|
|
if host == "" {
|
|
host = r.Host
|
|
}
|
|
if hostname, _, err := net.SplitHostPort(host); err == nil {
|
|
host = hostname
|
|
}
|
|
return scheme, host
|
|
}
|
|
|
|
func buildDirectURL(scheme, host string, port int) string {
|
|
return scheme + "://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/"
|
|
}
|
|
|
|
type webDeviceForwarder struct {
|
|
ip string
|
|
port int
|
|
targetAddress string
|
|
listener net.Listener
|
|
}
|
|
|
|
type WebDeviceForwarder = webDeviceForwarder
|
|
|
|
func newWebDeviceForwarder(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
|
|
listener, err := net.Listen("tcp", listenAddress)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if port == 0 {
|
|
if tcpAddr, ok := listener.Addr().(*net.TCPAddr); ok {
|
|
port = tcpAddr.Port
|
|
}
|
|
}
|
|
return &webDeviceForwarder{
|
|
ip: ip,
|
|
port: port,
|
|
targetAddress: targetAddress,
|
|
listener: listener,
|
|
}, nil
|
|
}
|
|
|
|
func (f *webDeviceForwarder) serve() {
|
|
if f == nil || f.listener == nil {
|
|
return
|
|
}
|
|
for {
|
|
clientConn, err := f.listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
go f.handle(clientConn)
|
|
}
|
|
}
|
|
|
|
func (f *webDeviceForwarder) handle(clientConn net.Conn) {
|
|
targetConn, err := net.DialTimeout("tcp", f.targetAddress, 10*time.Second)
|
|
if err != nil {
|
|
_ = clientConn.Close()
|
|
return
|
|
}
|
|
|
|
errCh := make(chan error, 2)
|
|
go func() {
|
|
_, err := ioCopy(targetConn, clientConn)
|
|
errCh <- err
|
|
}()
|
|
go func() {
|
|
_, err := ioCopy(clientConn, targetConn)
|
|
errCh <- err
|
|
}()
|
|
|
|
<-errCh
|
|
_ = clientConn.Close()
|
|
_ = targetConn.Close()
|
|
}
|
|
|
|
var ioCopy = func(dst net.Conn, src net.Conn) (int64, error) {
|
|
return copyConn(dst, src)
|
|
}
|
|
|
|
func copyConn(dst net.Conn, src net.Conn) (int64, error) {
|
|
return io.Copy(dst, src)
|
|
}
|
|
|
|
func defaultProxyTarget(ip string) string {
|
|
return "http://" + net.JoinHostPort(ip, "80")
|
|
}
|
|
|
|
func defaultForwardTarget(ip string) string {
|
|
return net.JoinHostPort(ip, "80")
|
|
}
|
|
|
|
func defaultInterfaceGetter() ([]InterfaceInfo, error) {
|
|
var interfaces []InterfaceInfo
|
|
|
|
ifaces, err := net.Interfaces()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("获取网卡列表失败: %w", err)
|
|
}
|
|
|
|
for _, iface := range ifaces {
|
|
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
|
continue
|
|
}
|
|
if strings.Contains(iface.Name, "docker") ||
|
|
strings.Contains(iface.Name, "veth") ||
|
|
strings.Contains(iface.Name, "br-") ||
|
|
strings.Contains(iface.Name, "tun") ||
|
|
strings.Contains(iface.Name, "tap") ||
|
|
strings.HasPrefix(iface.Name, "lo") {
|
|
continue
|
|
}
|
|
|
|
addrs, err := iface.Addrs()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
for _, addr := range addrs {
|
|
ipNet, ok := addr.(*net.IPNet)
|
|
if !ok || ipNet.IP.To4() == nil {
|
|
continue
|
|
}
|
|
|
|
interfaces = append(interfaces, InterfaceInfo{
|
|
Name: iface.Name,
|
|
IP: ipNet.IP.String(),
|
|
Netmask: net.IP(ipNet.Mask).String(),
|
|
})
|
|
break
|
|
}
|
|
}
|
|
|
|
return interfaces, nil
|
|
}
|
|
|
|
func defaultHostLANInterfaceGetter() ([]InterfaceInfo, error) {
|
|
image := strings.TrimSpace(os.Getenv("MANAGED_PORTAL_HOST_SCAN_IMAGE"))
|
|
if image == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
cmd := exec.CommandContext(
|
|
ctx,
|
|
"docker",
|
|
"run",
|
|
"--rm",
|
|
"--network",
|
|
"host",
|
|
"--entrypoint",
|
|
"/sbin/ip",
|
|
image,
|
|
"-o",
|
|
"-4",
|
|
"addr",
|
|
"show",
|
|
"scope",
|
|
"global",
|
|
)
|
|
output, err := cmd.CombinedOutput()
|
|
if ctx.Err() == context.DeadlineExceeded {
|
|
return nil, fmt.Errorf("读取宿主机网卡超时")
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("读取宿主机网卡失败: %w: %s", err, strings.TrimSpace(string(output)))
|
|
}
|
|
return parseHostLANInterfaces(string(output)), nil
|
|
}
|
|
|
|
func parseHostLANInterfaces(output string) []InterfaceInfo {
|
|
var interfaces []InterfaceInfo
|
|
seen := make(map[string]bool)
|
|
for _, line := range strings.Split(output, "\n") {
|
|
fields := strings.Fields(line)
|
|
if len(fields) < 4 {
|
|
continue
|
|
}
|
|
|
|
name := strings.TrimSuffix(fields[1], ":")
|
|
if ignoredInterfaceName(name) {
|
|
continue
|
|
}
|
|
|
|
inetIndex := -1
|
|
for index, field := range fields {
|
|
if field == "inet" {
|
|
inetIndex = index
|
|
break
|
|
}
|
|
}
|
|
if inetIndex == -1 || inetIndex+1 >= len(fields) {
|
|
continue
|
|
}
|
|
|
|
ip, ipNet, err := net.ParseCIDR(fields[inetIndex+1])
|
|
if err != nil || ip.To4() == nil || ipNet == nil {
|
|
continue
|
|
}
|
|
ipString := ip.String()
|
|
if !IsPrivateIPv4Literal(ipString) || seen[ipString] {
|
|
continue
|
|
}
|
|
ones, bits := ipNet.Mask.Size()
|
|
if bits != 32 {
|
|
continue
|
|
}
|
|
mask := net.IP(net.CIDRMask(ones, bits)).String()
|
|
interfaces = append(interfaces, InterfaceInfo{
|
|
Name: name,
|
|
IP: ipString,
|
|
Netmask: mask,
|
|
})
|
|
seen[ipString] = true
|
|
}
|
|
return interfaces
|
|
}
|
|
|
|
func ignoredInterfaceName(name string) bool {
|
|
lower := strings.ToLower(name)
|
|
if lower == "lo" ||
|
|
strings.Contains(lower, "docker") ||
|
|
strings.Contains(lower, "veth") ||
|
|
strings.Contains(lower, "br-") ||
|
|
strings.Contains(lower, "tun") ||
|
|
strings.Contains(lower, "tap") ||
|
|
strings.Contains(lower, "wg") ||
|
|
strings.Contains(lower, "tailscale") ||
|
|
strings.Contains(lower, "utun") ||
|
|
strings.Contains(lower, "ppp") ||
|
|
strings.Contains(lower, "zerotier") ||
|
|
strings.HasPrefix(lower, "zt") {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func scanTCP(ip string, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
|
|
ipRange, err := calculateIPRange(ip, netmask)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if port <= 0 || port > 65535 {
|
|
return nil, fmt.Errorf("无效的端口: %d", port)
|
|
}
|
|
|
|
var devices []TCPDevice
|
|
var mu sync.Mutex
|
|
var wg sync.WaitGroup
|
|
semaphore := make(chan struct{}, webDeviceScanConcurrency)
|
|
|
|
current := make(net.IP, len(ipRange.Start))
|
|
copy(current, ipRange.Start)
|
|
incrementIP(current)
|
|
|
|
for {
|
|
if current.To4().Equal(ipRange.End.To4()) {
|
|
break
|
|
}
|
|
|
|
currentIP := current.String()
|
|
if excludeIPs[currentIP] {
|
|
incrementIP(current)
|
|
continue
|
|
}
|
|
|
|
wg.Add(1)
|
|
semaphore <- struct{}{}
|
|
go func(targetIP string) {
|
|
defer wg.Done()
|
|
defer func() { <-semaphore }()
|
|
|
|
if scanTCPPortWithRetry(targetIP, port) {
|
|
mu.Lock()
|
|
devices = append(devices, TCPDevice{IP: targetIP, Port: port})
|
|
mu.Unlock()
|
|
}
|
|
}(currentIP)
|
|
|
|
incrementIP(current)
|
|
}
|
|
|
|
wg.Wait()
|
|
sort.Slice(devices, func(i, j int) bool {
|
|
return ipv4ToUint32(devices[i].IP) < ipv4ToUint32(devices[j].IP)
|
|
})
|
|
return devices, nil
|
|
}
|
|
|
|
type ipRange struct {
|
|
Start net.IP
|
|
End net.IP
|
|
}
|
|
|
|
func calculateIPRange(ip string, netmask string) (*ipRange, error) {
|
|
parseIP := net.ParseIP(ip)
|
|
if parseIP == nil {
|
|
return nil, fmt.Errorf("无效的IP: %s", ip)
|
|
}
|
|
|
|
mask := net.IPMask(net.ParseIP(netmask).To4())
|
|
if mask == nil {
|
|
return nil, fmt.Errorf("无效的子网掩码: %s", netmask)
|
|
}
|
|
|
|
network := &net.IPNet{
|
|
IP: parseIP.Mask(mask),
|
|
Mask: mask,
|
|
}
|
|
broadcast := make(net.IP, len(network.IP))
|
|
copy(broadcast, network.IP)
|
|
for i := 0; i < len(mask); i++ {
|
|
broadcast[i] |= ^mask[i]
|
|
}
|
|
|
|
result := &ipRange{Start: network.IP.To4(), End: broadcast}
|
|
if ipToUint32(result.End)-ipToUint32(result.Start)+1 <= maxWebDeviceScanAddrs {
|
|
return result, nil
|
|
}
|
|
|
|
local24Mask := net.CIDRMask(24, 32)
|
|
local24Network := parseIP.To4().Mask(local24Mask)
|
|
local24Broadcast := make(net.IP, len(local24Network))
|
|
copy(local24Broadcast, local24Network)
|
|
for i := 0; i < len(local24Mask); i++ {
|
|
local24Broadcast[i] |= ^local24Mask[i]
|
|
}
|
|
return &ipRange{Start: local24Network, End: local24Broadcast}, nil
|
|
}
|
|
|
|
func scanTCPPortWithRetry(ip string, port int) bool {
|
|
for attempt := 0; attempt < webDeviceScanAttempts; attempt++ {
|
|
if scanTCPPort(ip, port, webDeviceScanTimeout) {
|
|
return true
|
|
}
|
|
if attempt < webDeviceScanAttempts-1 {
|
|
time.Sleep(webDeviceScanRetryDelay)
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func scanTCPPort(ip string, port int, timeout time.Duration) bool {
|
|
if port <= 0 || port > 65535 {
|
|
return false
|
|
}
|
|
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", port))
|
|
conn, err := net.DialTimeout("tcp", addr, timeout)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
_ = conn.Close()
|
|
return true
|
|
}
|
|
|
|
func incrementIP(ip net.IP) {
|
|
for i := len(ip) - 1; i >= 0; i-- {
|
|
ip[i]++
|
|
if ip[i] > 0 {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func ipv4ToUint32(value string) uint32 {
|
|
parsed := net.ParseIP(value)
|
|
if parsed == nil {
|
|
return 0
|
|
}
|
|
return ipToUint32(parsed)
|
|
}
|
|
|
|
func ipToUint32(parsed net.IP) uint32 {
|
|
ip := parsed.To4()
|
|
if ip == nil {
|
|
return 0
|
|
}
|
|
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
|
|
}
|