TCPServer.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. package server
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/binary"
  6. "errors"
  7. "file-manger-server/util"
  8. "fmt"
  9. "io"
  10. "log"
  11. "net"
  12. "os"
  13. "os/signal"
  14. "sync"
  15. "sync/atomic"
  16. "syscall"
  17. "time"
  18. )
  19. // ConnPool 改进后的连接池结构
  20. type ConnPool struct {
  21. connections sync.Map // 使用sync.Map替代原生map
  22. maxConns int32 // 最大连接数
  23. currentConns atomic.Int32 // 当前连接数(原子操作)
  24. ctx context.Context // 上下文
  25. cancel context.CancelFunc // 用于取消上下文
  26. }
  27. type ConnInfo struct {
  28. conn net.Conn // 原始连接
  29. lastActive atomic.Int64 // 最后活跃时间(原子操作)
  30. heartbeatTTL time.Duration // 心跳TTL
  31. closed atomic.Bool // 关闭状态标记
  32. bufReader *bufio.Reader // 带缓冲的读取器
  33. }
  34. const (
  35. MaxConnections = 1000 // 最大连接数
  36. HeartbeatInterval = 30 * time.Second // 心跳间隔
  37. ReadTimeout = 2 * time.Minute // 读取超时时间
  38. )
  39. func NewConnPool(maxConns int) *ConnPool {
  40. ctx, cancel := context.WithCancel(context.Background())
  41. return &ConnPool{
  42. maxConns: int32(maxConns),
  43. ctx: ctx,
  44. cancel: cancel,
  45. }
  46. }
  47. // Add 添加连接到池(带数量检查)
  48. func (p *ConnPool) Add(conn net.Conn) error {
  49. if p.currentConns.Load() >= p.maxConns {
  50. return errors.New("connection limit reached")
  51. }
  52. fmt.Println(fmt.Sprintf("Ip \"%s\" connect.", conn.RemoteAddr()))
  53. ci := &ConnInfo{
  54. conn: conn,
  55. heartbeatTTL: HeartbeatInterval * 3,
  56. bufReader: bufio.NewReaderSize(conn, 4096),
  57. }
  58. ci.lastActive.Store(time.Now().UnixNano())
  59. p.connections.Store(conn, ci)
  60. p.currentConns.Add(1)
  61. return nil
  62. }
  63. // Remove 移除连接
  64. func (p *ConnPool) Remove(conn net.Conn) {
  65. if ci, loaded := p.connections.LoadAndDelete(conn); loaded {
  66. ci.(*ConnInfo).Close()
  67. p.currentConns.Add(-1)
  68. }
  69. }
  70. // Close 关闭连接
  71. func (ci *ConnInfo) Close() {
  72. if ci.closed.CompareAndSwap(false, true) {
  73. ci.conn.Close()
  74. }
  75. }
  76. // StartCleaner 启动自动清理协程
  77. func (p *ConnPool) StartCleaner() {
  78. go func() {
  79. ticker := time.NewTicker(1 * time.Minute)
  80. defer ticker.Stop()
  81. for {
  82. select {
  83. case <-ticker.C:
  84. p.cleanup()
  85. case <-p.ctx.Done():
  86. return
  87. }
  88. }
  89. }()
  90. }
  91. // 清理失效连接
  92. func (p *ConnPool) cleanup() {
  93. p.connections.Range(func(key, value any) bool {
  94. ci := value.(*ConnInfo)
  95. if time.Since(time.Unix(0, ci.lastActive.Load())) > ci.heartbeatTTL {
  96. p.Remove(ci.conn)
  97. }
  98. return true
  99. })
  100. }
  101. // Shutdown 优雅关闭
  102. func (p *ConnPool) Shutdown() {
  103. p.cancel()
  104. // 设置关闭超时为60秒
  105. ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
  106. defer cancel()
  107. done := make(chan struct{})
  108. go func() {
  109. p.connections.Range(func(key, value any) bool {
  110. ci := value.(*ConnInfo)
  111. ci.Close()
  112. p.connections.Delete(key)
  113. return true
  114. })
  115. close(done)
  116. }()
  117. select {
  118. case <-done:
  119. case <-ctx.Done():
  120. log.Println("强制关闭剩余连接")
  121. }
  122. }
  123. var Pool *ConnPool
  124. func Run() {
  125. Pool = NewConnPool(MaxConnections)
  126. Pool.StartCleaner()
  127. listener, err := net.Listen("tcp", "0.0.0.0:8080")
  128. if err != nil {
  129. log.Fatal(err)
  130. }
  131. // 信号处理
  132. sigCh := make(chan os.Signal, 1)
  133. signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
  134. // 修改信号处理goroutine
  135. go func() {
  136. <-sigCh
  137. log.Println("Shutting down server...")
  138. // 先关闭listener停止接受新连接
  139. listener.Close()
  140. // 然后关闭连接池
  141. Pool.Shutdown()
  142. // 添加退出等待逻辑
  143. time.Sleep(1 * time.Second) // 等待处理中的请求完成
  144. os.Exit(0) // 确保程序退出
  145. }()
  146. for {
  147. conn, err := listener.Accept()
  148. if err != nil {
  149. if errors.Is(err, net.ErrClosed) {
  150. break // 正常关闭
  151. }
  152. log.Printf("Accept error: %v", err)
  153. continue
  154. }
  155. if err := Pool.Add(conn); err != nil {
  156. conn.Close()
  157. log.Printf("Reject connection: %v", err)
  158. continue
  159. }
  160. go handleConnection(Pool, conn)
  161. }
  162. }
  163. func handleConnection(pool *ConnPool, conn net.Conn) {
  164. defer func() {
  165. if r := recover(); r != nil {
  166. log.Printf("Connection panic: %v", r)
  167. }
  168. pool.Remove(conn)
  169. }()
  170. // 获取连接信息
  171. ci, ok := pool.connections.Load(conn)
  172. if !ok {
  173. return
  174. }
  175. connInfo := ci.(*ConnInfo)
  176. // 心跳检测
  177. ctx, cancel := context.WithCancel(context.Background())
  178. defer cancel()
  179. go heartbeat(ctx, connInfo)
  180. // 消息处理循环
  181. for {
  182. if connInfo.closed.Load() {
  183. ctx.Done()
  184. return
  185. }
  186. // 设置读取超时
  187. connInfo.conn.SetReadDeadline(time.Now().Add(ReadTimeout))
  188. // 使用长度前缀协议处理粘包
  189. message, err := readMessage(connInfo.bufReader)
  190. if err != nil {
  191. if !errors.Is(err, io.EOF) {
  192. log.Printf("Read error: %v", err)
  193. }
  194. return
  195. }
  196. //处理业务
  197. HandlerReceiveFile(connInfo, message)
  198. // 更新活跃时间
  199. connInfo.lastActive.Store(time.Now().UnixNano())
  200. // 处理业务逻辑
  201. if err := processMessage(connInfo, message); err != nil {
  202. log.Printf("Process message error: %v", err)
  203. return
  204. }
  205. }
  206. }
  207. // 自定义错误类型
  208. var (
  209. ErrInvalidStartByte = errors.New("invalid start byte")
  210. ErrChecksumMismatch = errors.New("checksum mismatch")
  211. ErrMessageTooLarge = errors.New("message exceeds size limit")
  212. )
  213. func readMessage(r *bufio.Reader) ([]byte, error) {
  214. // 读取起始字节 (0x2a)
  215. startByte, err := r.ReadByte()
  216. if err != nil {
  217. return nil, fmt.Errorf("read start byte failed: %w", err)
  218. }
  219. if startByte != 0x2a {
  220. return nil, fmt.Errorf("%w: received 0x%02x", ErrInvalidStartByte, startByte)
  221. }
  222. // 读取长度字段 (4字节大端序)
  223. lengthBuf := make([]byte, 4)
  224. if _, err := io.ReadFull(r, lengthBuf); err != nil {
  225. return nil, fmt.Errorf("read length failed: %w", err)
  226. }
  227. dataLength := binary.BigEndian.Uint32(lengthBuf)
  228. // 验证数据长度 (10MB限制)
  229. if dataLength > 10*1024*1024 {
  230. return nil, fmt.Errorf("%w: %d bytes", ErrMessageTooLarge, dataLength)
  231. }
  232. // 读取数据体
  233. dataBody := make([]byte, dataLength)
  234. if _, err := io.ReadFull(r, dataBody); err != nil {
  235. return nil, fmt.Errorf("read data body failed: %w", err)
  236. }
  237. // 读取校验和 (2字节大端序)
  238. checksumBuf := make([]byte, 2)
  239. if _, err := io.ReadFull(r, checksumBuf); err != nil {
  240. return nil, fmt.Errorf("read checksum failed: %w", err)
  241. }
  242. receivedChecksum := binary.BigEndian.Uint16(checksumBuf)
  243. // 计算校验和 (CRC16-CCITT)
  244. calculatedChecksum := util.Crc16CCITT(dataBody)
  245. if calculatedChecksum != receivedChecksum {
  246. return nil, fmt.Errorf("%w: expected %04X, got %04X",
  247. ErrChecksumMismatch, calculatedChecksum, receivedChecksum)
  248. }
  249. return dataBody, nil
  250. }
  251. // 心跳检测
  252. func heartbeat(ctx context.Context, ci *ConnInfo) {
  253. ticker := time.NewTicker(HeartbeatInterval)
  254. defer ticker.Stop()
  255. for {
  256. select {
  257. case <-ticker.C:
  258. if ci.closed.Load() {
  259. return
  260. }
  261. // 发送心跳包
  262. if _, err := ci.conn.Write([]byte("PING")); err != nil {
  263. ci.Close()
  264. return
  265. }
  266. case <-ctx.Done():
  267. return
  268. }
  269. }
  270. }
  271. // 示例消息处理
  272. func processMessage(info *ConnInfo, msg []byte) error {
  273. // 实现业务逻辑
  274. fmt.Printf("Processing message: %s\n", util.ToHexBytes(msg))
  275. info.conn.Write(msg)
  276. return nil
  277. }