peripconn.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. package fasthttp
  2. import (
  3. "net"
  4. "sync"
  5. )
  // perIPConnCounter records how many connections are currently registered
  // for each remote IP, keyed by the uint32 form of the address. The zero
  // value is usable: the map is allocated lazily by Register.
  type perIPConnCounter struct {
  	// lock guards m; Register and Unregister hold it while touching the map.
  	lock sync.Mutex
  	// m maps an IP key to its current connection count.
  	m map[uint32]int

  	// pool recycles *perIPConn wrappers handed out by acquirePerIPConn and
  	// returned by releasePerIPConn, so wrappers can be reused.
  	pool sync.Pool
  }
  11. func (cc *perIPConnCounter) Register(ip uint32) int {
  12. cc.lock.Lock()
  13. if cc.m == nil {
  14. cc.m = make(map[uint32]int)
  15. }
  16. n := cc.m[ip] + 1
  17. cc.m[ip] = n
  18. cc.lock.Unlock()
  19. return n
  20. }
  21. func (cc *perIPConnCounter) Unregister(ip uint32) {
  22. cc.lock.Lock()
  23. defer cc.lock.Unlock()
  24. if cc.m == nil {
  25. // developer safeguard
  26. panic("BUG: perIPConnCounter.Register() wasn't called")
  27. }
  28. n := cc.m[ip] - 1
  29. if n < 0 {
  30. n = 0
  31. }
  32. cc.m[ip] = n
  33. }
  34. type perIPConn struct {
  35. net.Conn
  36. ip uint32
  37. perIPConnCounter *perIPConnCounter
  38. }
  39. func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perIPConn {
  40. v := counter.pool.Get()
  41. if v == nil {
  42. return &perIPConn{
  43. perIPConnCounter: counter,
  44. Conn: conn,
  45. ip: ip,
  46. }
  47. }
  48. c := v.(*perIPConn)
  49. c.Conn = conn
  50. c.ip = ip
  51. return c
  52. }
  53. func releasePerIPConn(c *perIPConn) {
  54. c.Conn = nil
  55. c.perIPConnCounter.pool.Put(c)
  56. }
  57. func (c *perIPConn) Close() error {
  58. err := c.Conn.Close()
  59. c.perIPConnCounter.Unregister(c.ip)
  60. releasePerIPConn(c)
  61. return err
  62. }
  63. func getUint32IP(c net.Conn) uint32 {
  64. return ip2uint32(getConnIP4(c))
  65. }
  // getConnIP4 returns the IPv4 remote address of c.
  //
  // If the remote address is not a *net.TCPAddr, including a nil address or
  // a typed-nil *net.TCPAddr, it returns net.IPv4zero. For a TCP address it
  // returns the 4-byte form from IP.To4, which is nil when the peer is not
  // an IPv4 (or IPv4-mapped) address.
  func getConnIP4(c net.Conn) net.IP {
  	tcpAddr, ok := c.RemoteAddr().(*net.TCPAddr)
  	// A typed-nil *net.TCPAddr passes the assertion, so check it explicitly
  	// instead of dereferencing it below.
  	if !ok || tcpAddr == nil {
  		return net.IPv4zero
  	}
  	return tcpAddr.IP.To4()
  }
  // ip2uint32 packs an IPv4 address into a uint32, first octet in the high
  // byte (the inverse of uint322ip).
  //
  // Both the 4-byte form and the 16-byte IPv4-mapped form (what net.IPv4 and
  // net.ParseIP return) are accepted. Anything that is not an IPv4 address
  // (nil, IPv6, wrong length) yields 0.
  func ip2uint32(ip net.IP) uint32 {
  	// Normalize first: without this, a 16-byte IPv4-mapped address would
  	// silently map to 0.
  	ip4 := ip.To4()
  	if ip4 == nil {
  		return 0
  	}
  	return uint32(ip4[0])<<24 | uint32(ip4[1])<<16 | uint32(ip4[2])<<8 | uint32(ip4[3])
  }
  // uint322ip unpacks a uint32 (first octet in the high byte) into a fresh
  // 4-byte net.IP. It is the inverse of ip2uint32 for 4-byte addresses.
  func uint322ip(ip uint32) net.IP {
  	return net.IP{
  		byte(ip >> 24),
  		byte(ip >> 16),
  		byte(ip >> 8),
  		byte(ip),
  	}
  }