cache.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. // Special thanks to @codemicro for moving this to fiber core
  2. // Original middleware: github.com/codemicro/fiber-cache
  3. package cache
  4. import (
  5. "strconv"
  6. "strings"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "github.com/gofiber/fiber/v2"
  11. "github.com/gofiber/fiber/v2/utils"
  12. )
  13. // timestampUpdatePeriod is the period which is used to check the cache expiration.
  14. // It should not be too long to provide more or less acceptable expiration error, and in the same
  15. // time it should not be too short to avoid overwhelming of the system
  16. const timestampUpdatePeriod = 300 * time.Millisecond
  17. // cache status
  18. // unreachable: when cache is bypass, or invalid
  19. // hit: cache is served
  20. // miss: do not have cache record
  21. const (
  22. cacheUnreachable = "unreachable"
  23. cacheHit = "hit"
  24. cacheMiss = "miss"
  25. )
  26. // directives
  27. const (
  28. noCache = "no-cache"
  29. noStore = "no-store"
  30. )
  31. var ignoreHeaders = map[string]interface{}{
  32. "Connection": nil,
  33. "Keep-Alive": nil,
  34. "Proxy-Authenticate": nil,
  35. "Proxy-Authorization": nil,
  36. "TE": nil,
  37. "Trailers": nil,
  38. "Transfer-Encoding": nil,
  39. "Upgrade": nil,
  40. "Content-Type": nil, // already stored explicitly by the cache manager
  41. "Content-Encoding": nil, // already stored explicitly by the cache manager
  42. }
  43. // New creates a new middleware handler
  44. func New(config ...Config) fiber.Handler {
  45. // Set default config
  46. cfg := configDefault(config...)
  47. // Nothing to cache
  48. if int(cfg.Expiration.Seconds()) < 0 {
  49. return func(c *fiber.Ctx) error {
  50. return c.Next()
  51. }
  52. }
  53. var (
  54. // Cache settings
  55. mux = &sync.RWMutex{}
  56. timestamp = uint64(time.Now().Unix())
  57. )
  58. // Create manager to simplify storage operations ( see manager.go )
  59. manager := newManager(cfg.Storage)
  60. // Create indexed heap for tracking expirations ( see heap.go )
  61. heap := &indexedHeap{}
  62. // count stored bytes (sizes of response bodies)
  63. var storedBytes uint
  64. // Update timestamp in the configured interval
  65. go func() {
  66. for {
  67. atomic.StoreUint64(&timestamp, uint64(time.Now().Unix()))
  68. time.Sleep(timestampUpdatePeriod)
  69. }
  70. }()
  71. // Delete key from both manager and storage
  72. deleteKey := func(dkey string) {
  73. manager.del(dkey)
  74. // External storage saves body data with different key
  75. if cfg.Storage != nil {
  76. manager.del(dkey + "_body")
  77. }
  78. }
  79. // Return new handler
  80. return func(c *fiber.Ctx) error {
  81. // Refrain from caching
  82. if hasRequestDirective(c, noStore) {
  83. return c.Next()
  84. }
  85. // Only cache selected methods
  86. var isExists bool
  87. for _, method := range cfg.Methods {
  88. if c.Method() == method {
  89. isExists = true
  90. }
  91. }
  92. if !isExists {
  93. c.Set(cfg.CacheHeader, cacheUnreachable)
  94. return c.Next()
  95. }
  96. // Get key from request
  97. // TODO(allocation optimization): try to minimize the allocation from 2 to 1
  98. key := cfg.KeyGenerator(c) + "_" + c.Method()
  99. // Get entry from pool
  100. e := manager.get(key)
  101. // Lock entry
  102. mux.Lock()
  103. // Get timestamp
  104. ts := atomic.LoadUint64(&timestamp)
  105. // Check if entry is expired
  106. if e.exp != 0 && ts >= e.exp {
  107. deleteKey(key)
  108. if cfg.MaxBytes > 0 {
  109. _, size := heap.remove(e.heapidx)
  110. storedBytes -= size
  111. }
  112. } else if e.exp != 0 && !hasRequestDirective(c, noCache) {
  113. // Separate body value to avoid msgp serialization
  114. // We can store raw bytes with Storage 👍
  115. if cfg.Storage != nil {
  116. e.body = manager.getRaw(key + "_body")
  117. }
  118. // Set response headers from cache
  119. c.Response().SetBodyRaw(e.body)
  120. c.Response().SetStatusCode(e.status)
  121. c.Response().Header.SetContentTypeBytes(e.ctype)
  122. if len(e.cencoding) > 0 {
  123. c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
  124. }
  125. for k, v := range e.headers {
  126. c.Response().Header.SetBytesV(k, v)
  127. }
  128. // Set Cache-Control header if enabled
  129. if cfg.CacheControl {
  130. maxAge := strconv.FormatUint(e.exp-ts, 10)
  131. c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
  132. }
  133. c.Set(cfg.CacheHeader, cacheHit)
  134. mux.Unlock()
  135. // Return response
  136. return nil
  137. }
  138. // make sure we're not blocking concurrent requests - do unlock
  139. mux.Unlock()
  140. // Continue stack, return err to Fiber if exist
  141. if err := c.Next(); err != nil {
  142. return err
  143. }
  144. // lock entry back and unlock on finish
  145. mux.Lock()
  146. defer mux.Unlock()
  147. // Don't cache response if Next returns true
  148. if cfg.Next != nil && cfg.Next(c) {
  149. c.Set(cfg.CacheHeader, cacheUnreachable)
  150. return nil
  151. }
  152. // Don't try to cache if body won't fit into cache
  153. bodySize := uint(len(c.Response().Body()))
  154. if cfg.MaxBytes > 0 && bodySize > cfg.MaxBytes {
  155. c.Set(cfg.CacheHeader, cacheUnreachable)
  156. return nil
  157. }
  158. // Remove oldest to make room for new
  159. if cfg.MaxBytes > 0 {
  160. for storedBytes+bodySize > cfg.MaxBytes {
  161. key, size := heap.removeFirst()
  162. deleteKey(key)
  163. storedBytes -= size
  164. }
  165. }
  166. // Cache response
  167. e.body = utils.CopyBytes(c.Response().Body())
  168. e.status = c.Response().StatusCode()
  169. e.ctype = utils.CopyBytes(c.Response().Header.ContentType())
  170. e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding))
  171. // Store all response headers
  172. // (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1)
  173. if cfg.StoreResponseHeaders {
  174. e.headers = make(map[string][]byte)
  175. c.Response().Header.VisitAll(
  176. func(key, value []byte) {
  177. // create real copy
  178. keyS := string(key)
  179. if _, ok := ignoreHeaders[keyS]; !ok {
  180. e.headers[keyS] = utils.CopyBytes(value)
  181. }
  182. },
  183. )
  184. }
  185. // default cache expiration
  186. expiration := cfg.Expiration
  187. // Calculate expiration by response header or other setting
  188. if cfg.ExpirationGenerator != nil {
  189. expiration = cfg.ExpirationGenerator(c, &cfg)
  190. }
  191. e.exp = ts + uint64(expiration.Seconds())
  192. // Store entry in heap
  193. if cfg.MaxBytes > 0 {
  194. e.heapidx = heap.put(key, e.exp, bodySize)
  195. storedBytes += bodySize
  196. }
  197. // For external Storage we store raw body separated
  198. if cfg.Storage != nil {
  199. manager.setRaw(key+"_body", e.body, expiration)
  200. // avoid body msgp encoding
  201. e.body = nil
  202. manager.set(key, e, expiration)
  203. manager.release(e)
  204. } else {
  205. // Store entry in memory
  206. manager.set(key, e, expiration)
  207. }
  208. c.Set(cfg.CacheHeader, cacheMiss)
  209. // Finish response
  210. return nil
  211. }
  212. }
  213. // Check if request has directive
  214. func hasRequestDirective(c *fiber.Ctx, directive string) bool {
  215. return strings.Contains(c.Get(fiber.HeaderCacheControl), directive)
  216. }