streaming.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. package fasthttp
  2. import (
  3. "bufio"
  4. "bytes"
  5. "io"
  6. "sync"
  7. "github.com/valyala/bytebufferpool"
  8. )
  // headerInterface is the part of a request header that requestStream
  // depends on while streaming a body.
  type headerInterface interface {
  	// ReadTrailer consumes the trailer section that follows the final
  	// (zero-size) chunk of a chunked body.
  	ReadTrailer(r *bufio.Reader) error
  	// ContentLength reports the body length; requestStream.Read treats -1
  	// as "chunked transfer encoding".
  	ContentLength() int
  }
  13. type requestStream struct {
  14. header headerInterface
  15. prefetchedBytes *bytes.Reader
  16. reader *bufio.Reader
  17. totalBytesRead int
  18. chunkLeft int
  19. }
  20. func (rs *requestStream) Read(p []byte) (int, error) {
  21. var (
  22. n int
  23. err error
  24. )
  25. if rs.header.ContentLength() == -1 {
  26. if rs.chunkLeft == 0 {
  27. chunkSize, err := parseChunkSize(rs.reader)
  28. if err != nil {
  29. return 0, err
  30. }
  31. if chunkSize == 0 {
  32. err = rs.header.ReadTrailer(rs.reader)
  33. if err != nil && err != io.EOF {
  34. return 0, err
  35. }
  36. return 0, io.EOF
  37. }
  38. rs.chunkLeft = chunkSize
  39. }
  40. bytesToRead := len(p)
  41. if rs.chunkLeft < len(p) {
  42. bytesToRead = rs.chunkLeft
  43. }
  44. n, err = rs.reader.Read(p[:bytesToRead])
  45. rs.totalBytesRead += n
  46. rs.chunkLeft -= n
  47. if err == io.EOF {
  48. err = io.ErrUnexpectedEOF
  49. }
  50. if err == nil && rs.chunkLeft == 0 {
  51. err = readCrLf(rs.reader)
  52. }
  53. return n, err
  54. }
  55. if rs.totalBytesRead == rs.header.ContentLength() {
  56. return 0, io.EOF
  57. }
  58. prefetchedSize := int(rs.prefetchedBytes.Size())
  59. if prefetchedSize > rs.totalBytesRead {
  60. left := prefetchedSize - rs.totalBytesRead
  61. if len(p) > left {
  62. p = p[:left]
  63. }
  64. n, err := rs.prefetchedBytes.Read(p)
  65. rs.totalBytesRead += n
  66. if n == rs.header.ContentLength() {
  67. return n, io.EOF
  68. }
  69. return n, err
  70. }
  71. left := rs.header.ContentLength() - rs.totalBytesRead
  72. if len(p) > left {
  73. p = p[:left]
  74. }
  75. n, err = rs.reader.Read(p)
  76. rs.totalBytesRead += n
  77. if err != nil {
  78. return n, err
  79. }
  80. if rs.totalBytesRead == rs.header.ContentLength() {
  81. err = io.EOF
  82. }
  83. return n, err
  84. }
  85. func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h headerInterface) *requestStream {
  86. rs := requestStreamPool.Get().(*requestStream)
  87. rs.prefetchedBytes = bytes.NewReader(b.B)
  88. rs.reader = r
  89. rs.header = h
  90. return rs
  91. }
  92. func releaseRequestStream(rs *requestStream) {
  93. rs.prefetchedBytes = nil
  94. rs.totalBytesRead = 0
  95. rs.chunkLeft = 0
  96. rs.reader = nil
  97. rs.header = nil
  98. requestStreamPool.Put(rs)
  99. }
  100. var requestStreamPool = sync.Pool{
  101. New: func() any {
  102. return &requestStream{}
  103. },
  104. }