| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291 |
- package server
- import (
- "bufio"
- "context"
- "encoding/binary"
- "errors"
- "file-manger-server/util"
- "fmt"
- "io"
- "log"
- "net"
- "os"
- "os/signal"
- "sync"
- "sync/atomic"
- "syscall"
- "time"
- )
// ConnPool tracks live client connections, caps how many may be registered
// at once, and can evict connections that have been idle for too long.
// Create one with NewConnPool; the zero value is not usable.
type ConnPool struct {
	connections  sync.Map           // net.Conn -> *ConnInfo
	maxConns     int32              // maximum number of registered connections
	currentConns atomic.Int32       // number of registered connections (updated atomically)
	ctx          context.Context    // cancelled by Shutdown; stops the cleaner goroutine
	cancel       context.CancelFunc // cancels ctx
}

// ConnInfo is the state a ConnPool keeps for each registered connection.
type ConnInfo struct {
	conn         net.Conn      // underlying connection
	lastActive   atomic.Int64  // time of last recorded activity, in Unix nanoseconds
	heartbeatTTL time.Duration // idle time after which cleanup evicts the connection
	closed       atomic.Bool   // set once Close has closed conn
	bufReader    *bufio.Reader // buffered reader wrapping conn
}

const (
	MaxConnections    = 1000             // connection limit used by Run
	HeartbeatInterval = 30 * time.Second // interval between heartbeat pings
	ReadTimeout       = 2 * time.Minute  // read deadline applied before each message
)

// NewConnPool returns an empty pool that accepts at most maxConns
// connections. Call Shutdown to release its resources.
func NewConnPool(maxConns int) *ConnPool {
	ctx, cancel := context.WithCancel(context.Background())
	return &ConnPool{
		maxConns: int32(maxConns),
		ctx:      ctx,
		cancel:   cancel,
	}
}

// Add registers conn with the pool and marks it as active now.
//
// It returns an error, leaving conn unregistered, when the pool already holds
// maxConns connections, when conn is already registered, or when the pool
// has been shut down. In every error case the caller keeps ownership of conn
// and should close it (it may already be closed if Shutdown ran concurrently).
func (p *ConnPool) Add(conn net.Conn) error {
	if p.ctx.Err() != nil {
		return errors.New("connection pool is shut down")
	}

	// Reserve a slot with a single atomic increment so concurrent callers can
	// never push the count past maxConns (a separate load-then-add check lets
	// them race past the limit).
	if p.currentConns.Add(1) > p.maxConns {
		p.currentConns.Add(-1)
		return errors.New("connection limit reached")
	}

	ci := &ConnInfo{
		conn:         conn,
		heartbeatTTL: HeartbeatInterval * 3,
		bufReader:    bufio.NewReaderSize(conn, 4096),
	}
	ci.lastActive.Store(time.Now().UnixNano())
	if _, loaded := p.connections.LoadOrStore(conn, ci); loaded {
		// Registering the same connection again would count it twice.
		p.currentConns.Add(-1)
		return errors.New("connection already registered")
	}

	// Shutdown may have swept the map between the first check and the store;
	// undo the registration so conn is not left behind in a dead pool.
	if p.ctx.Err() != nil {
		p.Remove(conn)
		return errors.New("connection pool is shut down")
	}
	return nil
}

// Remove unregisters conn and closes it. Removing a connection that is not
// registered (or was already removed) is a no-op.
func (p *ConnPool) Remove(conn net.Conn) {
	v, loaded := p.connections.LoadAndDelete(conn)
	if !loaded {
		return
	}
	v.(*ConnInfo).Close()
	p.currentConns.Add(-1)
}

// Close closes the underlying connection exactly once; later calls do nothing.
func (ci *ConnInfo) Close() {
	if ci.closed.CompareAndSwap(false, true) {
		// Best effort: there is nothing useful to do with a close error here.
		_ = ci.conn.Close()
	}
}

// StartCleaner launches a goroutine that runs cleanup once a minute until
// Shutdown cancels the pool's context. Each call starts another goroutine.
func (p *ConnPool) StartCleaner() {
	go func() {
		ticker := time.NewTicker(1 * time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				p.cleanup()
			case <-p.ctx.Done():
				return
			}
		}
	}()
}

// cleanup evicts (closes and unregisters) every connection whose last
// recorded activity is older than its heartbeatTTL.
func (p *ConnPool) cleanup() {
	p.connections.Range(func(_, value any) bool {
		ci := value.(*ConnInfo)
		idle := time.Since(time.Unix(0, ci.lastActive.Load()))
		if idle > ci.heartbeatTTL {
			// sync.Map permits deleting entries from inside Range.
			p.Remove(ci.conn)
		}
		return true
	})
}

// Shutdown stops the cleaner goroutine, then closes and unregisters every
// connection. Add rejects connections offered after Shutdown.
func (p *ConnPool) Shutdown() {
	p.cancel()
	p.connections.Range(func(key, _ any) bool {
		// Going through Remove keeps currentConns in sync with the map.
		// Deleting entries directly would leave the counter permanently
		// inflated, because the handlers' own Remove calls would then find
		// nothing to delete.
		p.Remove(key.(net.Conn))
		return true
	})
}
- func Run() {
- pool := NewConnPool(MaxConnections)
- pool.StartCleaner()
- listener, err := net.Listen("tcp", ":8080")
- if err != nil {
- log.Fatal(err)
- }
- // 信号处理
- sigCh := make(chan os.Signal, 1)
- signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
- go func() {
- <-sigCh
- log.Println("Shutting down server...")
- listener.Close()
- pool.Shutdown()
- }()
- for {
- conn, err := listener.Accept()
- if err != nil {
- if errors.Is(err, net.ErrClosed) {
- break // 正常关闭
- }
- log.Printf("Accept error: %v", err)
- continue
- }
- if err := pool.Add(conn); err != nil {
- conn.Close()
- log.Printf("Reject connection: %v", err)
- continue
- }
- go handleConnection(pool, conn)
- }
- }
- func handleConnection(pool *ConnPool, conn net.Conn) {
- defer func() {
- if r := recover(); r != nil {
- log.Printf("Connection panic: %v", r)
- }
- pool.Remove(conn)
- }()
- // 获取连接信息
- ci, ok := pool.connections.Load(conn)
- if !ok {
- return
- }
- connInfo := ci.(*ConnInfo)
- // 心跳检测
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
- go heartbeat(ctx, connInfo)
- // 消息处理循环
- for {
- // 设置读取超时
- connInfo.conn.SetReadDeadline(time.Now().Add(ReadTimeout))
- // 使用长度前缀协议处理粘包
- message, err := readMessage(connInfo.bufReader)
- if err != nil {
- if !errors.Is(err, io.EOF) {
- log.Printf("Read error: %v", err)
- }
- return
- }
- // 更新活跃时间
- connInfo.lastActive.Store(time.Now().UnixNano())
- // 处理业务逻辑
- if err := processMessage(connInfo, message); err != nil {
- log.Printf("Process message error: %v", err)
- return
- }
- }
- }
// Sentinel errors used by readMessage. They are wrapped with extra detail,
// so compare against them with errors.Is rather than ==.

// ErrInvalidStartByte: the frame did not begin with the 0x2a marker byte.
var ErrInvalidStartByte = errors.New("invalid start byte")

// ErrChecksumMismatch: the received checksum differs from the one computed
// over the frame body.
var ErrChecksumMismatch = errors.New("checksum mismatch")

// ErrMessageTooLarge: the frame's declared body length exceeds the limit.
var ErrMessageTooLarge = errors.New("message exceeds size limit")
- func readMessage(r *bufio.Reader) ([]byte, error) {
- // 读取起始字节 (0x2a)
- startByte, err := r.ReadByte()
- if err != nil {
- return nil, fmt.Errorf("read start byte failed: %w", err)
- }
- if startByte != 0x2a {
- return nil, fmt.Errorf("%w: received 0x%02x", ErrInvalidStartByte, startByte)
- }
- // 读取长度字段 (4字节大端序)
- lengthBuf := make([]byte, 4)
- if _, err := io.ReadFull(r, lengthBuf); err != nil {
- return nil, fmt.Errorf("read length failed: %w", err)
- }
- dataLength := binary.BigEndian.Uint32(lengthBuf)
- // 验证数据长度 (10MB限制)
- if dataLength > 10*1024*1024 {
- return nil, fmt.Errorf("%w: %d bytes", ErrMessageTooLarge, dataLength)
- }
- // 读取数据体
- dataBody := make([]byte, dataLength)
- if _, err := io.ReadFull(r, dataBody); err != nil {
- return nil, fmt.Errorf("read data body failed: %w", err)
- }
- // 读取校验和 (2字节大端序)
- checksumBuf := make([]byte, 2)
- if _, err := io.ReadFull(r, checksumBuf); err != nil {
- return nil, fmt.Errorf("read checksum failed: %w", err)
- }
- receivedChecksum := binary.BigEndian.Uint16(checksumBuf)
- // 计算校验和 (CRC16-CCITT)
- calculatedChecksum := util.Crc16CCITT(dataBody)
- if calculatedChecksum != receivedChecksum {
- return nil, fmt.Errorf("%w: expected %04X, got %04X",
- ErrChecksumMismatch, calculatedChecksum, receivedChecksum)
- }
- return dataBody, nil
- }
- // 心跳检测
- func heartbeat(ctx context.Context, ci *ConnInfo) {
- ticker := time.NewTicker(HeartbeatInterval)
- defer ticker.Stop()
- for {
- select {
- case <-ticker.C:
- if ci.closed.Load() {
- return
- }
- // 发送心跳包
- if _, err := ci.conn.Write([]byte("PING")); err != nil {
- ci.Close()
- return
- }
- case <-ctx.Done():
- return
- }
- }
- }
- // 示例消息处理
- func processMessage(info *ConnInfo, msg []byte) error {
- // 实现业务逻辑
- fmt.Printf("Processing message: %s\n", util.ToHexBytes(msg))
- info.conn.Write(msg)
- return nil
- }
|