tcpdialer.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. package fasthttp
  2. import (
  3. "context"
  4. "errors"
  5. "net"
  6. "strconv"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. )
  11. // Dial dials the given TCP addr using tcp4.
  12. //
  13. // This function has the following additional features comparing to net.Dial:
  14. //
  15. // - It reduces load on DNS resolver by caching resolved TCP addressed
  16. // for DNSCacheDuration.
  17. // - It dials all the resolved TCP addresses in round-robin manner until
  18. // connection is established. This may be useful if certain addresses
  19. // are temporarily unreachable.
  20. // - It returns ErrDialTimeout if connection cannot be established during
  21. // DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout.
  22. //
  23. // This dialer is intended for custom code wrapping before passing
  24. // to Client.Dial or HostClient.Dial.
  25. //
  26. // For instance, per-host counters and/or limits may be implemented
  27. // by such wrappers.
  28. //
  29. // The addr passed to the function must contain port. Example addr values:
  30. //
  31. // - foobar.baz:443
  32. // - foo.bar:80
  33. // - aaa.com:8080
  34. func Dial(addr string) (net.Conn, error) {
  35. return defaultDialer.Dial(addr)
  36. }
  37. // DialTimeout dials the given TCP addr using tcp4 using the given timeout.
  38. //
  39. // This function has the following additional features comparing to net.Dial:
  40. //
  41. // - It reduces load on DNS resolver by caching resolved TCP addressed
  42. // for DNSCacheDuration.
  43. // - It dials all the resolved TCP addresses in round-robin manner until
  44. // connection is established. This may be useful if certain addresses
  45. // are temporarily unreachable.
  46. //
  47. // This dialer is intended for custom code wrapping before passing
  48. // to Client.Dial or HostClient.Dial.
  49. //
  50. // For instance, per-host counters and/or limits may be implemented
  51. // by such wrappers.
  52. //
  53. // The addr passed to the function must contain port. Example addr values:
  54. //
  55. // - foobar.baz:443
  56. // - foo.bar:80
  57. // - aaa.com:8080
  58. func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
  59. return defaultDialer.DialTimeout(addr, timeout)
  60. }
  61. // DialDualStack dials the given TCP addr using both tcp4 and tcp6.
  62. //
  63. // This function has the following additional features comparing to net.Dial:
  64. //
  65. // - It reduces load on DNS resolver by caching resolved TCP addressed
  66. // for DNSCacheDuration.
  67. // - It dials all the resolved TCP addresses in round-robin manner until
  68. // connection is established. This may be useful if certain addresses
  69. // are temporarily unreachable.
  70. // - It returns ErrDialTimeout if connection cannot be established during
  71. // DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial
  72. // timeout.
  73. //
  74. // This dialer is intended for custom code wrapping before passing
  75. // to Client.Dial or HostClient.Dial.
  76. //
  77. // For instance, per-host counters and/or limits may be implemented
  78. // by such wrappers.
  79. //
  80. // The addr passed to the function must contain port. Example addr values:
  81. //
  82. // - foobar.baz:443
  83. // - foo.bar:80
  84. // - aaa.com:8080
  85. func DialDualStack(addr string) (net.Conn, error) {
  86. return defaultDialer.DialDualStack(addr)
  87. }
  88. // DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
  89. // using the given timeout.
  90. //
  91. // This function has the following additional features comparing to net.Dial:
  92. //
  93. // - It reduces load on DNS resolver by caching resolved TCP addressed
  94. // for DNSCacheDuration.
  95. // - It dials all the resolved TCP addresses in round-robin manner until
  96. // connection is established. This may be useful if certain addresses
  97. // are temporarily unreachable.
  98. //
  99. // This dialer is intended for custom code wrapping before passing
  100. // to Client.Dial or HostClient.Dial.
  101. //
  102. // For instance, per-host counters and/or limits may be implemented
  103. // by such wrappers.
  104. //
  105. // The addr passed to the function must contain port. Example addr values:
  106. //
  107. // - foobar.baz:443
  108. // - foo.bar:80
  109. // - aaa.com:8080
  110. func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
  111. return defaultDialer.DialDualStackTimeout(addr, timeout)
  112. }
  113. var defaultDialer = &TCPDialer{Concurrency: 1000}
  114. // Resolver represents interface of the tcp resolver.
  115. type Resolver interface {
  116. LookupIPAddr(context.Context, string) (names []net.IPAddr, err error)
  117. }
  118. // TCPDialer contains options to control a group of Dial calls.
  119. type TCPDialer struct {
  120. // Concurrency controls the maximum number of concurrent Dials
  121. // that can be performed using this object.
  122. // Setting this to 0 means unlimited.
  123. //
  124. // WARNING: This can only be changed before the first Dial.
  125. // Changes made after the first Dial will not affect anything.
  126. Concurrency int
  127. // LocalAddr is the local address to use when dialing an
  128. // address.
  129. // If nil, a local address is automatically chosen.
  130. LocalAddr *net.TCPAddr
  131. // This may be used to override DNS resolving policy, like this:
  132. // var dialer = &fasthttp.TCPDialer{
  133. // Resolver: &net.Resolver{
  134. // PreferGo: true,
  135. // StrictErrors: false,
  136. // Dial: func (ctx context.Context, network, address string) (net.Conn, error) {
  137. // d := net.Dialer{}
  138. // return d.DialContext(ctx, "udp", "8.8.8.8:53")
  139. // },
  140. // },
  141. // }
  142. Resolver Resolver
  143. // DNSCacheDuration may be used to override the default DNS cache duration (DefaultDNSCacheDuration)
  144. DNSCacheDuration time.Duration
  145. tcpAddrsMap sync.Map
  146. concurrencyCh chan struct{}
  147. once sync.Once
  148. }
  149. // Dial dials the given TCP addr using tcp4.
  150. //
  151. // This function has the following additional features comparing to net.Dial:
  152. //
  153. // - It reduces load on DNS resolver by caching resolved TCP addressed
  154. // for DNSCacheDuration.
  155. // - It dials all the resolved TCP addresses in round-robin manner until
  156. // connection is established. This may be useful if certain addresses
  157. // are temporarily unreachable.
  158. // - It returns ErrDialTimeout if connection cannot be established during
  159. // DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout.
  160. //
  161. // This dialer is intended for custom code wrapping before passing
  162. // to Client.Dial or HostClient.Dial.
  163. //
  164. // For instance, per-host counters and/or limits may be implemented
  165. // by such wrappers.
  166. //
  167. // The addr passed to the function must contain port. Example addr values:
  168. //
  169. // - foobar.baz:443
  170. // - foo.bar:80
  171. // - aaa.com:8080
  172. func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
  173. return d.dial(addr, false, DefaultDialTimeout)
  174. }
  175. // DialTimeout dials the given TCP addr using tcp4 using the given timeout.
  176. //
  177. // This function has the following additional features comparing to net.Dial:
  178. //
  179. // - It reduces load on DNS resolver by caching resolved TCP addressed
  180. // for DNSCacheDuration.
  181. // - It dials all the resolved TCP addresses in round-robin manner until
  182. // connection is established. This may be useful if certain addresses
  183. // are temporarily unreachable.
  184. //
  185. // This dialer is intended for custom code wrapping before passing
  186. // to Client.Dial or HostClient.Dial.
  187. //
  188. // For instance, per-host counters and/or limits may be implemented
  189. // by such wrappers.
  190. //
  191. // The addr passed to the function must contain port. Example addr values:
  192. //
  193. // - foobar.baz:443
  194. // - foo.bar:80
  195. // - aaa.com:8080
  196. func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
  197. return d.dial(addr, false, timeout)
  198. }
  199. // DialDualStack dials the given TCP addr using both tcp4 and tcp6.
  200. //
  201. // This function has the following additional features comparing to net.Dial:
  202. //
  203. // - It reduces load on DNS resolver by caching resolved TCP addressed
  204. // for DNSCacheDuration.
  205. // - It dials all the resolved TCP addresses in round-robin manner until
  206. // connection is established. This may be useful if certain addresses
  207. // are temporarily unreachable.
  208. // - It returns ErrDialTimeout if connection cannot be established during
  209. // DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial
  210. // timeout.
  211. //
  212. // This dialer is intended for custom code wrapping before passing
  213. // to Client.Dial or HostClient.Dial.
  214. //
  215. // For instance, per-host counters and/or limits may be implemented
  216. // by such wrappers.
  217. //
  218. // The addr passed to the function must contain port. Example addr values:
  219. //
  220. // - foobar.baz:443
  221. // - foo.bar:80
  222. // - aaa.com:8080
  223. func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
  224. return d.dial(addr, true, DefaultDialTimeout)
  225. }
  226. // DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
  227. // using the given timeout.
  228. //
  229. // This function has the following additional features comparing to net.Dial:
  230. //
  231. // - It reduces load on DNS resolver by caching resolved TCP addressed
  232. // for DNSCacheDuration.
  233. // - It dials all the resolved TCP addresses in round-robin manner until
  234. // connection is established. This may be useful if certain addresses
  235. // are temporarily unreachable.
  236. //
  237. // This dialer is intended for custom code wrapping before passing
  238. // to Client.Dial or HostClient.Dial.
  239. //
  240. // For instance, per-host counters and/or limits may be implemented
  241. // by such wrappers.
  242. //
  243. // The addr passed to the function must contain port. Example addr values:
  244. //
  245. // - foobar.baz:443
  246. // - foo.bar:80
  247. // - aaa.com:8080
  248. func (d *TCPDialer) DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
  249. return d.dial(addr, true, timeout)
  250. }
  251. func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (net.Conn, error) {
  252. d.once.Do(func() {
  253. if d.Concurrency > 0 {
  254. d.concurrencyCh = make(chan struct{}, d.Concurrency)
  255. }
  256. if d.DNSCacheDuration == 0 {
  257. d.DNSCacheDuration = DefaultDNSCacheDuration
  258. }
  259. go d.tcpAddrsClean()
  260. })
  261. deadline := time.Now().Add(timeout)
  262. addrs, idx, err := d.getTCPAddrs(addr, dualStack, deadline)
  263. if err != nil {
  264. return nil, err
  265. }
  266. network := "tcp4"
  267. if dualStack {
  268. network = "tcp"
  269. }
  270. var conn net.Conn
  271. n := uint32(len(addrs))
  272. for n > 0 {
  273. conn, err = d.tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh)
  274. if err == nil {
  275. return conn, nil
  276. }
  277. if err == ErrDialTimeout {
  278. return nil, err
  279. }
  280. idx++
  281. n--
  282. }
  283. return nil, err
  284. }
  285. func (d *TCPDialer) tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}) (net.Conn, error) {
  286. timeout := time.Until(deadline)
  287. if timeout <= 0 {
  288. return nil, ErrDialTimeout
  289. }
  290. if concurrencyCh != nil {
  291. select {
  292. case concurrencyCh <- struct{}{}:
  293. default:
  294. tc := AcquireTimer(timeout)
  295. isTimeout := false
  296. select {
  297. case concurrencyCh <- struct{}{}:
  298. case <-tc.C:
  299. isTimeout = true
  300. }
  301. ReleaseTimer(tc)
  302. if isTimeout {
  303. return nil, ErrDialTimeout
  304. }
  305. }
  306. defer func() { <-concurrencyCh }()
  307. }
  308. dialer := net.Dialer{}
  309. if d.LocalAddr != nil {
  310. dialer.LocalAddr = d.LocalAddr
  311. }
  312. ctx, cancelCtx := context.WithDeadline(context.Background(), deadline)
  313. defer cancelCtx()
  314. conn, err := dialer.DialContext(ctx, network, addr.String())
  315. if err != nil && ctx.Err() == context.DeadlineExceeded {
  316. return nil, ErrDialTimeout
  317. }
  318. return conn, err
  319. }
  320. // ErrDialTimeout is returned when TCP dialing is timed out.
  321. var ErrDialTimeout = errors.New("dialing to the given TCP address timed out")
  322. // DefaultDialTimeout is timeout used by Dial and DialDualStack
  323. // for establishing TCP connections.
  324. const DefaultDialTimeout = 3 * time.Second
  325. type tcpAddrEntry struct {
  326. addrs []net.TCPAddr
  327. addrsIdx uint32
  328. pending int32
  329. resolveTime time.Time
  330. }
  331. // DefaultDNSCacheDuration is the duration for caching resolved TCP addresses
  332. // by Dial* functions.
  333. const DefaultDNSCacheDuration = time.Minute
  334. func (d *TCPDialer) tcpAddrsClean() {
  335. expireDuration := 2 * d.DNSCacheDuration
  336. for {
  337. time.Sleep(time.Second)
  338. t := time.Now()
  339. d.tcpAddrsMap.Range(func(k, v interface{}) bool {
  340. if e, ok := v.(*tcpAddrEntry); ok && t.Sub(e.resolveTime) > expireDuration {
  341. d.tcpAddrsMap.Delete(k)
  342. }
  343. return true
  344. })
  345. }
  346. }
  347. func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool, deadline time.Time) ([]net.TCPAddr, uint32, error) {
  348. item, exist := d.tcpAddrsMap.Load(addr)
  349. e, ok := item.(*tcpAddrEntry)
  350. if exist && ok && e != nil && time.Since(e.resolveTime) > d.DNSCacheDuration {
  351. // Only let one goroutine re-resolve at a time.
  352. if atomic.SwapInt32(&e.pending, 1) == 0 {
  353. e = nil
  354. }
  355. }
  356. if e == nil {
  357. addrs, err := resolveTCPAddrs(addr, dualStack, d.Resolver, deadline)
  358. if err != nil {
  359. item, exist := d.tcpAddrsMap.Load(addr)
  360. e, ok = item.(*tcpAddrEntry)
  361. if exist && ok && e != nil {
  362. // Set pending to 0 so another goroutine can retry.
  363. atomic.StoreInt32(&e.pending, 0)
  364. }
  365. return nil, 0, err
  366. }
  367. e = &tcpAddrEntry{
  368. addrs: addrs,
  369. resolveTime: time.Now(),
  370. }
  371. d.tcpAddrsMap.Store(addr, e)
  372. }
  373. idx := atomic.AddUint32(&e.addrsIdx, 1)
  374. return e.addrs, idx, nil
  375. }
  376. func resolveTCPAddrs(addr string, dualStack bool, resolver Resolver, deadline time.Time) ([]net.TCPAddr, error) {
  377. host, portS, err := net.SplitHostPort(addr)
  378. if err != nil {
  379. return nil, err
  380. }
  381. port, err := strconv.Atoi(portS)
  382. if err != nil {
  383. return nil, err
  384. }
  385. if resolver == nil {
  386. resolver = net.DefaultResolver
  387. }
  388. ctx, cancel := context.WithDeadline(context.Background(), deadline)
  389. defer cancel()
  390. ipaddrs, err := resolver.LookupIPAddr(ctx, host)
  391. if err != nil {
  392. return nil, err
  393. }
  394. n := len(ipaddrs)
  395. addrs := make([]net.TCPAddr, 0, n)
  396. for i := 0; i < n; i++ {
  397. ip := ipaddrs[i]
  398. if !dualStack && ip.IP.To4() == nil {
  399. continue
  400. }
  401. addrs = append(addrs, net.TCPAddr{
  402. IP: ip.IP,
  403. Port: port,
  404. Zone: ip.Zone,
  405. })
  406. }
  407. if len(addrs) == 0 {
  408. return nil, errNoDNSEntries
  409. }
  410. return addrs, nil
  411. }
  412. var errNoDNSEntries = errors.New("couldn't find DNS entries for the given domain. Try using DialDualStack")