TCPServer.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. package server
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/binary"
  6. "errors"
  7. "file-manger-server/util"
  8. "fmt"
  9. "io"
  10. "log"
  11. "net"
  12. "os"
  13. "os/signal"
  14. "sync"
  15. "sync/atomic"
  16. "syscall"
  17. "time"
  18. )
  // ConnPool is the improved connection pool: it tracks the server's live
  // connections and the limit on how many may be registered.
  type ConnPool struct {
  	// Lifecycle context of the pool and the function that cancels it.
  	ctx    context.Context
  	cancel context.CancelFunc

  	// Registered connections; a sync.Map is used instead of a built-in map.
  	connections sync.Map

  	maxConns     int32        // maximum number of connections
  	currentConns atomic.Int32 // current number of connections (atomic)
  }
  // ConnInfo holds the per-connection state that the pool keeps alongside the
  // raw net.Conn.
  type ConnInfo struct {
  	conn      net.Conn      // the underlying connection
  	bufReader *bufio.Reader // buffered reader

  	heartbeatTTL time.Duration // heartbeat TTL
  	lastActive   atomic.Int64  // time of last activity (atomic)
  	closed       atomic.Bool   // closed-state flag
  }
  // Server-wide limits and timing parameters.
  const (
  	// MaxConnections is the maximum number of connections.
  	MaxConnections = 1000

  	// HeartbeatInterval is the heartbeat interval.
  	HeartbeatInterval = time.Second * 30

  	// ReadTimeout is the read timeout.
  	ReadTimeout = time.Minute * 2
  )
  39. func NewConnPool(maxConns int) *ConnPool {
  40. ctx, cancel := context.WithCancel(context.Background())
  41. return &ConnPool{
  42. maxConns: int32(maxConns),
  43. ctx: ctx,
  44. cancel: cancel,
  45. }
  46. }
  47. // Add 添加连接到池(带数量检查)
  48. func (p *ConnPool) Add(conn net.Conn) error {
  49. if p.currentConns.Load() >= p.maxConns {
  50. return errors.New("connection limit reached")
  51. }
  52. ci := &ConnInfo{
  53. conn: conn,
  54. heartbeatTTL: HeartbeatInterval * 3,
  55. bufReader: bufio.NewReaderSize(conn, 4096),
  56. }
  57. ci.lastActive.Store(time.Now().UnixNano())
  58. p.connections.Store(conn, ci)
  59. p.currentConns.Add(1)
  60. return nil
  61. }
  62. // Remove 移除连接
  63. func (p *ConnPool) Remove(conn net.Conn) {
  64. if ci, loaded := p.connections.LoadAndDelete(conn); loaded {
  65. ci.(*ConnInfo).Close()
  66. p.currentConns.Add(-1)
  67. }
  68. }
  69. // Close 关闭连接
  70. func (ci *ConnInfo) Close() {
  71. if ci.closed.CompareAndSwap(false, true) {
  72. ci.conn.Close()
  73. }
  74. }
  75. // StartCleaner 启动自动清理协程
  76. func (p *ConnPool) StartCleaner() {
  77. go func() {
  78. ticker := time.NewTicker(1 * time.Minute)
  79. defer ticker.Stop()
  80. for {
  81. select {
  82. case <-ticker.C:
  83. p.cleanup()
  84. case <-p.ctx.Done():
  85. return
  86. }
  87. }
  88. }()
  89. }
  90. // 清理失效连接
  91. func (p *ConnPool) cleanup() {
  92. p.connections.Range(func(key, value any) bool {
  93. ci := value.(*ConnInfo)
  94. if time.Since(time.Unix(0, ci.lastActive.Load())) > ci.heartbeatTTL {
  95. p.Remove(ci.conn)
  96. }
  97. return true
  98. })
  99. }
  100. // Shutdown 优雅关闭
  101. func (p *ConnPool) Shutdown() {
  102. p.cancel()
  103. p.connections.Range(func(key, value any) bool {
  104. ci := value.(*ConnInfo)
  105. ci.Close()
  106. p.connections.Delete(key)
  107. return true
  108. })
  109. }
  110. func Run() {
  111. pool := NewConnPool(MaxConnections)
  112. pool.StartCleaner()
  113. listener, err := net.Listen("tcp", ":8080")
  114. if err != nil {
  115. log.Fatal(err)
  116. }
  117. // 信号处理
  118. sigCh := make(chan os.Signal, 1)
  119. signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
  120. go func() {
  121. <-sigCh
  122. log.Println("Shutting down server...")
  123. listener.Close()
  124. pool.Shutdown()
  125. }()
  126. for {
  127. conn, err := listener.Accept()
  128. if err != nil {
  129. if errors.Is(err, net.ErrClosed) {
  130. break // 正常关闭
  131. }
  132. log.Printf("Accept error: %v", err)
  133. continue
  134. }
  135. if err := pool.Add(conn); err != nil {
  136. conn.Close()
  137. log.Printf("Reject connection: %v", err)
  138. continue
  139. }
  140. go handleConnection(pool, conn)
  141. }
  142. }
  143. func handleConnection(pool *ConnPool, conn net.Conn) {
  144. defer func() {
  145. if r := recover(); r != nil {
  146. log.Printf("Connection panic: %v", r)
  147. }
  148. pool.Remove(conn)
  149. }()
  150. // 获取连接信息
  151. ci, ok := pool.connections.Load(conn)
  152. if !ok {
  153. return
  154. }
  155. connInfo := ci.(*ConnInfo)
  156. // 心跳检测
  157. ctx, cancel := context.WithCancel(context.Background())
  158. defer cancel()
  159. go heartbeat(ctx, connInfo)
  160. // 消息处理循环
  161. for {
  162. // 设置读取超时
  163. connInfo.conn.SetReadDeadline(time.Now().Add(ReadTimeout))
  164. // 使用长度前缀协议处理粘包
  165. message, err := readMessage(connInfo.bufReader)
  166. if err != nil {
  167. if !errors.Is(err, io.EOF) {
  168. log.Printf("Read error: %v", err)
  169. }
  170. return
  171. }
  172. // 更新活跃时间
  173. connInfo.lastActive.Store(time.Now().UnixNano())
  174. // 处理业务逻辑
  175. if err := processMessage(connInfo, message); err != nil {
  176. log.Printf("Process message error: %v", err)
  177. return
  178. }
  179. }
  180. }
  // Sentinel errors returned (wrapped) by readMessage; match them with
  // errors.Is.
  var (
  	// ErrInvalidStartByte reports a frame whose first byte is not the
  	// expected start byte.
  	ErrInvalidStartByte = errors.New("invalid start byte")

  	// ErrChecksumMismatch reports a frame whose received checksum differs
  	// from the one calculated over its body.
  	ErrChecksumMismatch = errors.New("checksum mismatch")

  	// ErrMessageTooLarge reports a frame whose declared length exceeds the
  	// size limit.
  	ErrMessageTooLarge = errors.New("message exceeds size limit")
  )
  187. func readMessage(r *bufio.Reader) ([]byte, error) {
  188. // 读取起始字节 (0x2a)
  189. startByte, err := r.ReadByte()
  190. if err != nil {
  191. return nil, fmt.Errorf("read start byte failed: %w", err)
  192. }
  193. if startByte != 0x2a {
  194. return nil, fmt.Errorf("%w: received 0x%02x", ErrInvalidStartByte, startByte)
  195. }
  196. // 读取长度字段 (4字节大端序)
  197. lengthBuf := make([]byte, 4)
  198. if _, err := io.ReadFull(r, lengthBuf); err != nil {
  199. return nil, fmt.Errorf("read length failed: %w", err)
  200. }
  201. dataLength := binary.BigEndian.Uint32(lengthBuf)
  202. // 验证数据长度 (10MB限制)
  203. if dataLength > 10*1024*1024 {
  204. return nil, fmt.Errorf("%w: %d bytes", ErrMessageTooLarge, dataLength)
  205. }
  206. // 读取数据体
  207. dataBody := make([]byte, dataLength)
  208. if _, err := io.ReadFull(r, dataBody); err != nil {
  209. return nil, fmt.Errorf("read data body failed: %w", err)
  210. }
  211. // 读取校验和 (2字节大端序)
  212. checksumBuf := make([]byte, 2)
  213. if _, err := io.ReadFull(r, checksumBuf); err != nil {
  214. return nil, fmt.Errorf("read checksum failed: %w", err)
  215. }
  216. receivedChecksum := binary.BigEndian.Uint16(checksumBuf)
  217. // 计算校验和 (CRC16-CCITT)
  218. calculatedChecksum := util.Crc16CCITT(dataBody)
  219. if calculatedChecksum != receivedChecksum {
  220. return nil, fmt.Errorf("%w: expected %04X, got %04X",
  221. ErrChecksumMismatch, calculatedChecksum, receivedChecksum)
  222. }
  223. return dataBody, nil
  224. }
  225. // 心跳检测
  226. func heartbeat(ctx context.Context, ci *ConnInfo) {
  227. ticker := time.NewTicker(HeartbeatInterval)
  228. defer ticker.Stop()
  229. for {
  230. select {
  231. case <-ticker.C:
  232. if ci.closed.Load() {
  233. return
  234. }
  235. // 发送心跳包
  236. if _, err := ci.conn.Write([]byte("PING")); err != nil {
  237. ci.Close()
  238. return
  239. }
  240. case <-ctx.Done():
  241. return
  242. }
  243. }
  244. }
  245. // 示例消息处理
  246. func processMessage(info *ConnInfo, msg []byte) error {
  247. // 实现业务逻辑
  248. fmt.Printf("Processing message: %s\n", util.ToHexBytes(msg))
  249. info.conn.Write(msg)
  250. return nil
  251. }