package server

import (
	"bufio"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/signal"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"file-manger-server/util"
)

// ConnPool tracks registered connections and enforces an upper bound on how
// many may be registered at the same time.
type ConnPool struct {
	connections  sync.Map           // net.Conn -> *ConnInfo
	maxConns     int32              // maximum number of registered connections
	currentConns atomic.Int32       // number of currently registered connections
	ctx          context.Context    // cancelled by Shutdown; stops the cleaner and heartbeats
	cancel       context.CancelFunc // cancels ctx
}

// ConnInfo holds the per-connection state kept by a ConnPool.
type ConnInfo struct {
	conn         net.Conn      // underlying connection
	lastActive   atomic.Int64  // UnixNano time of the last successfully read message
	heartbeatTTL time.Duration // idle time after which cleanup evicts the connection
	closed       atomic.Bool   // set once Close has closed conn
	bufReader    *bufio.Reader // buffered reader wrapping conn
}

const (
	MaxConnections    = 1000             // maximum number of concurrent connections
	HeartbeatInterval = 30 * time.Second // interval between outgoing heartbeat packets
	ReadTimeout       = 2 * time.Minute  // read deadline applied before each frame
	MaxMessageSize    = 10 * 1024 * 1024 // largest accepted frame body, in bytes (10 MiB)
)

// NewConnPool returns an empty pool that accepts at most maxConns connections.
func NewConnPool(maxConns int) *ConnPool {
	ctx, cancel := context.WithCancel(context.Background())
	return &ConnPool{
		maxConns: int32(maxConns),
		ctx:      ctx,
		cancel:   cancel,
	}
}

// Add registers conn with the pool.
//
// It returns an error, without registering conn, when the pool is full or has
// been shut down. The caller remains responsible for closing conn in that case.
func (p *ConnPool) Add(conn net.Conn) error {
	// Reserve a slot atomically. A separate Load followed by Add(1) would let
	// concurrent callers overshoot maxConns.
	for {
		n := p.currentConns.Load()
		if n >= p.maxConns {
			return errors.New("connection limit reached")
		}
		if p.currentConns.CompareAndSwap(n, n+1) {
			break
		}
	}

	ci := &ConnInfo{
		conn:         conn,
		heartbeatTTL: HeartbeatInterval * 3,
		bufReader:    bufio.NewReaderSize(conn, 4096),
	}
	ci.lastActive.Store(time.Now().UnixNano())
	p.connections.Store(conn, ci)

	// Shutdown cancels ctx before sweeping the map. Re-checking after Store
	// guarantees the connection is either swept by Shutdown or rolled back
	// here, so it can never outlive the pool.
	if p.ctx.Err() != nil {
		if _, loaded := p.connections.LoadAndDelete(conn); loaded {
			p.currentConns.Add(-1)
		}
		return errors.New("connection pool is shut down")
	}
	return nil
}

// Remove unregisters conn, closes it and frees its slot.
// Calling it for a connection that is not registered is a no-op.
func (p *ConnPool) Remove(conn net.Conn) {
	if ci, loaded := p.connections.LoadAndDelete(conn); loaded {
		ci.(*ConnInfo).Close()
		p.currentConns.Add(-1)
	}
}

// Close closes the underlying connection exactly once, however many times it
// is called and from however many goroutines.
func (ci *ConnInfo) Close() {
	if ci.closed.CompareAndSwap(false, true) {
		// Best effort: there is nothing useful to do with a close error here.
		_ = ci.conn.Close()
	}
}

// StartCleaner launches a goroutine that evicts idle connections once per
// minute. The goroutine exits when Shutdown is called.
func (p *ConnPool) StartCleaner() {
	go func() {
		ticker := time.NewTicker(1 * time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				p.cleanup()
			case <-p.ctx.Done():
				return
			}
		}
	}()
}

// cleanup removes every connection that has not delivered a message within
// its heartbeatTTL.
func (p *ConnPool) cleanup() {
	p.connections.Range(func(_, value any) bool {
		ci := value.(*ConnInfo)
		if time.Since(time.Unix(0, ci.lastActive.Load())) > ci.heartbeatTTL {
			p.Remove(ci.conn)
		}
		return true
	})
}

// Shutdown stops the cleaner and closes and unregisters every connection.
// Subsequent calls to Add are rejected.
func (p *ConnPool) Shutdown() {
	p.cancel()
	p.connections.Range(func(key, _ any) bool {
		// Go through Remove so currentConns stays consistent. Deleting only
		// the map entry would leave the counter inflated forever, because the
		// handler's own Remove would then find nothing to decrement.
		p.Remove(key.(net.Conn))
		return true
	})
}

// Run listens on :8080 and serves connections until SIGINT or SIGTERM is
// received.
func Run() {
	pool := NewConnPool(MaxConnections)
	pool.StartCleaner()

	listener, err := net.Listen("tcp", ":8080")
	if err != nil {
		log.Fatal(err)
	}

	// Signal handling: closing the listener makes Accept fail with
	// net.ErrClosed, which ends the accept loop below.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigCh)
	go func() {
		<-sigCh
		log.Println("Shutting down server...")
		listener.Close()
		pool.Shutdown()
	}()

	const maxAcceptBackoff = time.Second
	var backoff time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				break // normal shutdown
			}
			log.Printf("Accept error: %v", err)
			// Back off on repeated failures (e.g. EMFILE) instead of
			// spinning the CPU and flooding the log.
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			time.Sleep(backoff)
			continue
		}
		backoff = 0

		if err := pool.Add(conn); err != nil {
			conn.Close()
			log.Printf("Reject connection: %v", err)
			continue
		}
		go handleConnection(pool, conn)
	}
}

// handleConnection reads frames from conn and processes them until the
// connection fails, times out or is closed. It always unregisters conn on exit.
func handleConnection(pool *ConnPool, conn net.Conn) {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Connection panic: %v", r)
		}
		pool.Remove(conn)
	}()

	v, ok := pool.connections.Load(conn)
	if !ok {
		// Already evicted (by cleanup or Shutdown) before this goroutine ran.
		return
	}
	connInfo := v.(*ConnInfo)

	// Derive from the pool context so Shutdown also stops the heartbeat.
	ctx, cancel := context.WithCancel(pool.ctx)
	defer cancel()
	go heartbeat(ctx, connInfo)

	for {
		if err := connInfo.conn.SetReadDeadline(time.Now().Add(ReadTimeout)); err != nil {
			// Only fails once the connection is unusable (e.g. already closed).
			return
		}

		// Frames are length-prefixed, so TCP stream boundaries do not matter.
		message, err := readMessage(connInfo.bufReader)
		if err != nil {
			// io.EOF: the peer hung up cleanly. net.ErrClosed: we closed the
			// connection ourselves (cleanup, failed heartbeat or Shutdown).
			// Neither is an error worth logging.
			if !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
				log.Printf("Read error: %v", err)
			}
			return
		}

		connInfo.lastActive.Store(time.Now().UnixNano())

		if err := processMessage(connInfo, message); err != nil {
			log.Printf("Process message error: %v", err)
			return
		}
	}
}

// Sentinel errors returned (wrapped) by readMessage.
var (
	ErrInvalidStartByte = errors.New("invalid start byte")
	ErrChecksumMismatch = errors.New("checksum mismatch")
	ErrMessageTooLarge  = errors.New("message exceeds size limit")
)

// readMessage reads one frame from r and returns its body.
//
// Frame layout:
//
//	0x2a | length (4 bytes, big-endian) | body (length bytes) | checksum (2 bytes, big-endian)
//
// The checksum is util.Crc16CCITT computed over the body. Errors wrap the
// underlying I/O error, or one of ErrInvalidStartByte, ErrMessageTooLarge or
// ErrChecksumMismatch.
func readMessage(r *bufio.Reader) ([]byte, error) {
	startByte, err := r.ReadByte()
	if err != nil {
		return nil, fmt.Errorf("read start byte failed: %w", err)
	}
	if startByte != 0x2a {
		return nil, fmt.Errorf("%w: received 0x%02x", ErrInvalidStartByte, startByte)
	}

	lengthBuf := make([]byte, 4)
	if _, err := io.ReadFull(r, lengthBuf); err != nil {
		return nil, fmt.Errorf("read length failed: %w", err)
	}
	dataLength := binary.BigEndian.Uint32(lengthBuf)

	// Reject oversized frames before allocating: the length is peer-controlled.
	if dataLength > MaxMessageSize {
		return nil, fmt.Errorf("%w: %d bytes", ErrMessageTooLarge, dataLength)
	}

	dataBody := make([]byte, dataLength)
	if _, err := io.ReadFull(r, dataBody); err != nil {
		return nil, fmt.Errorf("read data body failed: %w", err)
	}

	checksumBuf := make([]byte, 2)
	if _, err := io.ReadFull(r, checksumBuf); err != nil {
		return nil, fmt.Errorf("read checksum failed: %w", err)
	}
	receivedChecksum := binary.BigEndian.Uint16(checksumBuf)

	calculatedChecksum := util.Crc16CCITT(dataBody)
	if calculatedChecksum != receivedChecksum {
		return nil, fmt.Errorf("%w: expected %04X, got %04X",
			ErrChecksumMismatch, calculatedChecksum, receivedChecksum)
	}

	return dataBody, nil
}

// heartbeat writes "PING" to the connection every HeartbeatInterval until ctx
// is cancelled or the connection is closed. A failed write closes the
// connection, which in turn makes the owning read loop exit.
func heartbeat(ctx context.Context, ci *ConnInfo) {
	ticker := time.NewTicker(HeartbeatInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if ci.closed.Load() {
				return
			}
			if _, err := ci.conn.Write([]byte("PING")); err != nil {
				ci.Close()
				return
			}
		case <-ctx.Done():
			return
		}
	}
}

// processMessage is the example business handler: it prints the message as
// hex and echoes it back to the sender. A failed echo is returned so the
// caller can drop the connection.
func processMessage(info *ConnInfo, msg []byte) error {
	fmt.Printf("Processing message: %s\n", util.ToHexBytes(msg))
	if _, err := info.conn.Write(msg); err != nil {
		return fmt.Errorf("write reply: %w", err)
	}
	return nil
}