package server import ( "bufio" "context" "encoding/binary" "errors" "file-manger-server/util" "fmt" "io" "log" "net" "os" "os/signal" "sync" "sync/atomic" "syscall" "time" ) // ConnPool 改进后的连接池结构 type ConnPool struct { connections sync.Map // 使用sync.Map替代原生map maxConns int32 // 最大连接数 currentConns atomic.Int32 // 当前连接数(原子操作) ctx context.Context // 上下文 cancel context.CancelFunc // 用于取消上下文 } type ConnInfo struct { conn net.Conn // 原始连接 lastActive atomic.Int64 // 最后活跃时间(原子操作) heartbeatTTL time.Duration // 心跳TTL closed atomic.Bool // 关闭状态标记 bufReader *bufio.Reader // 带缓冲的读取器 } const ( MaxConnections = 1000 // 最大连接数 HeartbeatInterval = 30 * time.Second // 心跳间隔 ReadTimeout = 2 * time.Minute // 读取超时时间 ) func NewConnPool(maxConns int) *ConnPool { ctx, cancel := context.WithCancel(context.Background()) return &ConnPool{ maxConns: int32(maxConns), ctx: ctx, cancel: cancel, } } // Add 添加连接到池(带数量检查) func (p *ConnPool) Add(conn net.Conn) error { if p.currentConns.Load() >= p.maxConns { return errors.New("connection limit reached") } fmt.Println(fmt.Sprintf("Ip \"%s\" connect.", conn.RemoteAddr())) ci := &ConnInfo{ conn: conn, heartbeatTTL: HeartbeatInterval * 3, bufReader: bufio.NewReaderSize(conn, 4096), } ci.lastActive.Store(time.Now().UnixNano()) p.connections.Store(conn, ci) p.currentConns.Add(1) return nil } // Remove 移除连接 func (p *ConnPool) Remove(conn net.Conn) { if ci, loaded := p.connections.LoadAndDelete(conn); loaded { ci.(*ConnInfo).Close() p.currentConns.Add(-1) } } // Close 关闭连接 func (ci *ConnInfo) Close() { if ci.closed.CompareAndSwap(false, true) { ci.conn.Close() } } // StartCleaner 启动自动清理协程 func (p *ConnPool) StartCleaner() { go func() { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: p.cleanup() case <-p.ctx.Done(): return } } }() } // 清理失效连接 func (p *ConnPool) cleanup() { p.connections.Range(func(key, value any) bool { ci := value.(*ConnInfo) if time.Since(time.Unix(0, ci.lastActive.Load())) > ci.heartbeatTTL { p.Remove(ci.conn) } return true }) } // Shutdown 优雅关闭 func (p *ConnPool) Shutdown() { p.cancel() // 设置关闭超时为60秒 ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() done := make(chan struct{}) go func() { p.connections.Range(func(key, value any) bool { ci := value.(*ConnInfo) ci.Close() p.connections.Delete(key) return true }) close(done) }() select { case <-done: case <-ctx.Done(): log.Println("强制关闭剩余连接") } } var Pool *ConnPool func Run() { Pool = NewConnPool(MaxConnections) Pool.StartCleaner() listener, err := net.Listen("tcp", "0.0.0.0:8080") if err != nil { log.Fatal(err) } // 信号处理 sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) // 修改信号处理goroutine go func() { <-sigCh log.Println("Shutting down server...") // 先关闭listener停止接受新连接 listener.Close() // 然后关闭连接池 Pool.Shutdown() // 添加退出等待逻辑 time.Sleep(1 * time.Second) // 等待处理中的请求完成 os.Exit(0) // 确保程序退出 }() for { conn, err := listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { break // 正常关闭 } log.Printf("Accept error: %v", err) continue } if err := Pool.Add(conn); err != nil { conn.Close() log.Printf("Reject connection: %v", err) continue } go handleConnection(Pool, conn) } } func handleConnection(pool *ConnPool, conn net.Conn) { defer func() { if r := recover(); r != nil { log.Printf("Connection panic: %v", r) } pool.Remove(conn) }() // 获取连接信息 ci, ok := pool.connections.Load(conn) if !ok { return } connInfo := ci.(*ConnInfo) // 心跳检测 ctx, cancel := context.WithCancel(context.Background()) defer cancel() go heartbeat(ctx, connInfo) // 消息处理循环 for { if connInfo.closed.Load() { ctx.Done() return } // 设置读取超时 connInfo.conn.SetReadDeadline(time.Now().Add(ReadTimeout)) // 使用长度前缀协议处理粘包 message, err := readMessage(connInfo.bufReader) if err != nil { if !errors.Is(err, io.EOF) { log.Printf("Read error: %v", err) } return } //处理业务 HandlerReceiveFile(connInfo, message) // 更新活跃时间 connInfo.lastActive.Store(time.Now().UnixNano()) // 处理业务逻辑 if err := processMessage(connInfo, message); err != nil { log.Printf("Process message error: %v", err) return } } } // 自定义错误类型 var ( ErrInvalidStartByte = errors.New("invalid start byte") ErrChecksumMismatch = errors.New("checksum mismatch") ErrMessageTooLarge = errors.New("message exceeds size limit") ) func readMessage(r *bufio.Reader) ([]byte, error) { // 读取起始字节 (0x2a) startByte, err := r.ReadByte() if err != nil { return nil, fmt.Errorf("read start byte failed: %w", err) } if startByte != 0x2a { return nil, fmt.Errorf("%w: received 0x%02x", ErrInvalidStartByte, startByte) } // 读取长度字段 (4字节大端序) lengthBuf := make([]byte, 4) if _, err := io.ReadFull(r, lengthBuf); err != nil { return nil, fmt.Errorf("read length failed: %w", err) } dataLength := binary.BigEndian.Uint32(lengthBuf) // 验证数据长度 (10MB限制) if dataLength > 10*1024*1024 { return nil, fmt.Errorf("%w: %d bytes", ErrMessageTooLarge, dataLength) } // 读取数据体 dataBody := make([]byte, dataLength) if _, err := io.ReadFull(r, dataBody); err != nil { return nil, fmt.Errorf("read data body failed: %w", err) } // 读取校验和 (2字节大端序) checksumBuf := make([]byte, 2) if _, err := io.ReadFull(r, checksumBuf); err != nil { return nil, fmt.Errorf("read checksum failed: %w", err) } receivedChecksum := binary.BigEndian.Uint16(checksumBuf) // 计算校验和 (CRC16-CCITT) calculatedChecksum := util.Crc16CCITT(dataBody) if calculatedChecksum != receivedChecksum { return nil, fmt.Errorf("%w: expected %04X, got %04X", ErrChecksumMismatch, calculatedChecksum, receivedChecksum) } return dataBody, nil } // 心跳检测 func heartbeat(ctx context.Context, ci *ConnInfo) { ticker := time.NewTicker(HeartbeatInterval) defer ticker.Stop() for { select { case <-ticker.C: if ci.closed.Load() { return } // 发送心跳包 if _, err := ci.conn.Write([]byte("PING")); err != nil { ci.Close() return } case <-ctx.Done(): return } } } // 示例消息处理 func processMessage(info *ConnInfo, msg []byte) error { // 实现业务逻辑 fmt.Printf("Processing message: %s\n", util.ToHexBytes(msg)) info.conn.Write(msg) return nil }