feat: initialize managed portal
This commit is contained in:
495
internal/webdevice/service.go
Normal file
495
internal/webdevice/service.go
Normal file
@@ -0,0 +1,495 @@
|
||||
package webdevice
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type InterfaceInfo struct {
|
||||
Name string `json:"name"`
|
||||
IP string `json:"ip"`
|
||||
Netmask string `json:"netmask"`
|
||||
}
|
||||
|
||||
type TCPDevice struct {
|
||||
IP string
|
||||
Port int
|
||||
}
|
||||
|
||||
type DeviceInfo struct {
|
||||
IP string `json:"ip"`
|
||||
Interface string `json:"interface"`
|
||||
Port int `json:"port"`
|
||||
TargetURL string `json:"target_url"`
|
||||
ProxyURL string `json:"proxy_url"`
|
||||
DirectURL string `json:"direct_url,omitempty"`
|
||||
ForwardPort int `json:"forward_port,omitempty"`
|
||||
}
|
||||
|
||||
type ScanResult struct {
|
||||
Interfaces []InterfaceInfo `json:"interfaces"`
|
||||
Devices []DeviceInfo `json:"devices"`
|
||||
Count int `json:"count"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type InterfaceGetter func() ([]InterfaceInfo, error)
|
||||
type TCPScanner func(ip, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error)
|
||||
type ForwarderFactory func(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error)
|
||||
type ProxyTargetResolver func(ip string) string
|
||||
|
||||
type Service struct {
|
||||
mu sync.RWMutex
|
||||
allowed map[string]time.Time
|
||||
forwarders map[string]*webDeviceForwarder
|
||||
interfaceGetter InterfaceGetter
|
||||
tcpScanner TCPScanner
|
||||
newForwarder ForwarderFactory
|
||||
proxyTarget ProxyTargetResolver
|
||||
forwardTarget func(ip string) string
|
||||
}
|
||||
|
||||
func NewService() *Service {
|
||||
return &Service{
|
||||
allowed: make(map[string]time.Time),
|
||||
forwarders: make(map[string]*webDeviceForwarder),
|
||||
interfaceGetter: defaultInterfaceGetter,
|
||||
tcpScanner: scanTCP,
|
||||
newForwarder: newWebDeviceForwarder,
|
||||
proxyTarget: defaultProxyTarget,
|
||||
forwardTarget: defaultForwardTarget,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Scan(r *http.Request) (*ScanResult, error) {
|
||||
interfaces, err := s.interfaceGetter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(interfaces) == 0 {
|
||||
return &ScanResult{
|
||||
Interfaces: []InterfaceInfo{},
|
||||
Devices: []DeviceInfo{},
|
||||
Message: "未找到有效的网卡",
|
||||
}, nil
|
||||
}
|
||||
|
||||
excludeIPs := make(map[string]bool)
|
||||
for _, iface := range interfaces {
|
||||
excludeIPs[iface.IP] = true
|
||||
}
|
||||
|
||||
result := &ScanResult{
|
||||
Interfaces: interfaces,
|
||||
Devices: []DeviceInfo{},
|
||||
Errors: []string{},
|
||||
}
|
||||
|
||||
scheme, host := requestBase(r)
|
||||
for _, iface := range interfaces {
|
||||
devices, scanErr := s.tcpScanner(iface.IP, iface.Netmask, 80, excludeIPs)
|
||||
if scanErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s: %v", iface.Name, scanErr))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, device := range devices {
|
||||
if !IsPrivateIPv4Literal(device.IP) {
|
||||
continue
|
||||
}
|
||||
|
||||
s.allowIP(device.IP)
|
||||
forwardPort, forwardErr := s.EnsureForwarder(device.IP)
|
||||
if forwardErr != nil {
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("%s: 启动网页直连转发失败: %v", device.IP, forwardErr))
|
||||
}
|
||||
|
||||
deviceInfo := DeviceInfo{
|
||||
IP: device.IP,
|
||||
Interface: iface.Name,
|
||||
Port: device.Port,
|
||||
TargetURL: fmt.Sprintf("http://%s/", device.IP),
|
||||
ProxyURL: fmt.Sprintf("/proxy/web/%s/", device.IP),
|
||||
}
|
||||
if forwardErr == nil {
|
||||
deviceInfo.ForwardPort = forwardPort
|
||||
deviceInfo.DirectURL = buildDirectURL(scheme, host, forwardPort)
|
||||
}
|
||||
result.Devices = append(result.Devices, deviceInfo)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(result.Devices, func(i, j int) bool {
|
||||
return ipv4ToUint32(result.Devices[i].IP) < ipv4ToUint32(result.Devices[j].IP)
|
||||
})
|
||||
result.Count = len(result.Devices)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Service) allowIP(ip string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowed[ip] = time.Now()
|
||||
}
|
||||
|
||||
func (s *Service) AllowIP(ip string) {
|
||||
s.allowIP(ip)
|
||||
}
|
||||
|
||||
func (s *Service) IsAllowed(ip string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
_, ok := s.allowed[ip]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *Service) SetInterfaceGetter(getter InterfaceGetter) {
|
||||
if getter != nil {
|
||||
s.interfaceGetter = getter
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetTCPScanner(scanner TCPScanner) {
|
||||
if scanner != nil {
|
||||
s.tcpScanner = scanner
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetForwarderFactory(factory ForwarderFactory) {
|
||||
if factory != nil {
|
||||
s.newForwarder = factory
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetProxyTargetResolver(resolver ProxyTargetResolver) {
|
||||
if resolver != nil {
|
||||
s.proxyTarget = resolver
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) EnsureForwarder(ip string) (int, error) {
|
||||
port, ok := WebDeviceForwardPort(ip)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("无效的IPv4地址")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if forwarder, ok := s.forwarders[ip]; ok {
|
||||
return forwarder.port, nil
|
||||
}
|
||||
|
||||
targetAddress := defaultForwardTarget(ip)
|
||||
if s.forwardTarget != nil {
|
||||
targetAddress = s.forwardTarget(ip)
|
||||
}
|
||||
|
||||
forwarder, err := s.newForwarder(ip, port, net.JoinHostPort("0.0.0.0", strconv.Itoa(port)), targetAddress)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
s.forwarders[ip] = forwarder
|
||||
go forwarder.serve()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func (s *Service) ProxyTargetURL(ip string) string {
|
||||
if s.proxyTarget == nil {
|
||||
return defaultProxyTarget(ip)
|
||||
}
|
||||
return s.proxyTarget(ip)
|
||||
}
|
||||
|
||||
func IsPrivateIPv4Literal(value string) bool {
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil {
|
||||
return false
|
||||
}
|
||||
return ip4.IsPrivate() && !ip4.IsLoopback() && !ip4.IsMulticast() && !ip4.IsUnspecified()
|
||||
}
|
||||
|
||||
func WebDeviceForwardPort(ip string) (int, bool) {
|
||||
parsed := net.ParseIP(ip)
|
||||
if parsed == nil {
|
||||
return 0, false
|
||||
}
|
||||
ip4 := parsed.To4()
|
||||
if ip4 == nil {
|
||||
return 0, false
|
||||
}
|
||||
return 31000 + int(ip4[3]), true
|
||||
}
|
||||
|
||||
func requestBase(r *http.Request) (string, string) {
|
||||
scheme := r.Header.Get("X-Forwarded-Proto")
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = r.Host
|
||||
}
|
||||
if hostname, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = hostname
|
||||
}
|
||||
return scheme, host
|
||||
}
|
||||
|
||||
func buildDirectURL(scheme, host string, port int) string {
|
||||
return scheme + "://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/"
|
||||
}
|
||||
|
||||
type webDeviceForwarder struct {
|
||||
ip string
|
||||
port int
|
||||
targetAddress string
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
type WebDeviceForwarder = webDeviceForwarder
|
||||
|
||||
func newWebDeviceForwarder(ip string, port int, listenAddress, targetAddress string) (*webDeviceForwarder, error) {
|
||||
listener, err := net.Listen("tcp", listenAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if port == 0 {
|
||||
if tcpAddr, ok := listener.Addr().(*net.TCPAddr); ok {
|
||||
port = tcpAddr.Port
|
||||
}
|
||||
}
|
||||
return &webDeviceForwarder{
|
||||
ip: ip,
|
||||
port: port,
|
||||
targetAddress: targetAddress,
|
||||
listener: listener,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *webDeviceForwarder) serve() {
|
||||
if f == nil || f.listener == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
clientConn, err := f.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go f.handle(clientConn)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *webDeviceForwarder) handle(clientConn net.Conn) {
|
||||
targetConn, err := net.DialTimeout("tcp", f.targetAddress, 10*time.Second)
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := ioCopy(targetConn, clientConn)
|
||||
errCh <- err
|
||||
}()
|
||||
go func() {
|
||||
_, err := ioCopy(clientConn, targetConn)
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
<-errCh
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
}
|
||||
|
||||
var ioCopy = func(dst net.Conn, src net.Conn) (int64, error) {
|
||||
return copyConn(dst, src)
|
||||
}
|
||||
|
||||
func copyConn(dst net.Conn, src net.Conn) (int64, error) {
|
||||
return io.Copy(dst, src)
|
||||
}
|
||||
|
||||
func defaultProxyTarget(ip string) string {
|
||||
return "http://" + net.JoinHostPort(ip, "80")
|
||||
}
|
||||
|
||||
func defaultForwardTarget(ip string) string {
|
||||
return net.JoinHostPort(ip, "80")
|
||||
}
|
||||
|
||||
func defaultInterfaceGetter() ([]InterfaceInfo, error) {
|
||||
var interfaces []InterfaceInfo
|
||||
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取网卡列表失败: %w", err)
|
||||
}
|
||||
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(iface.Name, "docker") ||
|
||||
strings.Contains(iface.Name, "veth") ||
|
||||
strings.Contains(iface.Name, "br-") ||
|
||||
strings.Contains(iface.Name, "tun") ||
|
||||
strings.Contains(iface.Name, "tap") ||
|
||||
strings.HasPrefix(iface.Name, "lo") {
|
||||
continue
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipNet.IP.To4() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
interfaces = append(interfaces, InterfaceInfo{
|
||||
Name: iface.Name,
|
||||
IP: ipNet.IP.String(),
|
||||
Netmask: net.IP(ipNet.Mask).String(),
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func scanTCP(ip string, netmask string, port int, excludeIPs map[string]bool) ([]TCPDevice, error) {
|
||||
ipRange, err := calculateIPRange(ip, netmask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if port <= 0 || port > 65535 {
|
||||
return nil, fmt.Errorf("无效的端口: %d", port)
|
||||
}
|
||||
|
||||
var devices []TCPDevice
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 20)
|
||||
timeout := 2 * time.Second
|
||||
|
||||
current := make(net.IP, len(ipRange.Start))
|
||||
copy(current, ipRange.Start)
|
||||
incrementIP(current)
|
||||
|
||||
for {
|
||||
if current.To4().Equal(ipRange.End.To4()) {
|
||||
break
|
||||
}
|
||||
|
||||
currentIP := current.String()
|
||||
if excludeIPs[currentIP] {
|
||||
incrementIP(current)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(targetIP string) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
if scanTCPPort(targetIP, port, timeout) {
|
||||
mu.Lock()
|
||||
devices = append(devices, TCPDevice{IP: targetIP, Port: port})
|
||||
mu.Unlock()
|
||||
}
|
||||
}(currentIP)
|
||||
|
||||
incrementIP(current)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
sort.Slice(devices, func(i, j int) bool {
|
||||
return ipv4ToUint32(devices[i].IP) < ipv4ToUint32(devices[j].IP)
|
||||
})
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
type ipRange struct {
|
||||
Start net.IP
|
||||
End net.IP
|
||||
}
|
||||
|
||||
func calculateIPRange(ip string, netmask string) (*ipRange, error) {
|
||||
parseIP := net.ParseIP(ip)
|
||||
if parseIP == nil {
|
||||
return nil, fmt.Errorf("无效的IP: %s", ip)
|
||||
}
|
||||
|
||||
mask := net.IPMask(net.ParseIP(netmask).To4())
|
||||
if mask == nil {
|
||||
return nil, fmt.Errorf("无效的子网掩码: %s", netmask)
|
||||
}
|
||||
|
||||
network := &net.IPNet{
|
||||
IP: parseIP.Mask(mask),
|
||||
Mask: mask,
|
||||
}
|
||||
broadcast := make(net.IP, len(network.IP))
|
||||
copy(broadcast, network.IP)
|
||||
for i := 0; i < len(mask); i++ {
|
||||
broadcast[i] |= ^mask[i]
|
||||
}
|
||||
|
||||
return &ipRange{Start: network.IP.To4(), End: broadcast}, nil
|
||||
}
|
||||
|
||||
func scanTCPPort(ip string, port int, timeout time.Duration) bool {
|
||||
if port <= 0 || port > 65535 {
|
||||
return false
|
||||
}
|
||||
addr := net.JoinHostPort(ip, fmt.Sprintf("%d", port))
|
||||
conn, err := net.DialTimeout("tcp", addr, timeout)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func incrementIP(ip net.IP) {
|
||||
for i := len(ip) - 1; i >= 0; i-- {
|
||||
ip[i]++
|
||||
if ip[i] > 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ipv4ToUint32(value string) uint32 {
|
||||
parsed := net.ParseIP(value)
|
||||
if parsed == nil {
|
||||
return 0
|
||||
}
|
||||
ip := parsed.To4()
|
||||
if ip == nil {
|
||||
return 0
|
||||
}
|
||||
return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
|
||||
}
|
||||
Reference in New Issue
Block a user