@@ -40,7 +40,9 @@ type WriteFlusher interface {
40
40
}
41
41
42
42
type CompressResponseWriter struct {
43
- http.ResponseWriter
43
+ Header * BufferedServerHeader
44
+ ControllerResponse * Response
45
+ OriginalWriter io.Writer
44
46
compressWriter WriteFlusher
45
47
compressionType string
46
48
headersWritten bool
@@ -49,23 +51,33 @@ type CompressResponseWriter struct {
49
51
closed bool
50
52
}
51
53
52
- // CompressFilter does compresssion of response body in gzip/deflate if
54
+ // CompressFilter does compression of response body in gzip/deflate if
53
55
// `results.compressed=true` in the app.conf
54
56
func CompressFilter (c * Controller , fc []Filter ) {
55
- fc [0 ](c , fc [1 :])
56
- if Config .BoolDefault ("results.compressed" , false ) {
57
+ if c .Response .ServerHeader != nil && Config .BoolDefault ("results.compressed" , false ) {
57
58
if c .Response .Status != http .StatusNoContent && c .Response .Status != http .StatusNotModified {
58
- writer := CompressResponseWriter {c .Response .Out , nil , "" , false , make (chan bool , 1 ), nil , false }
59
- writer .DetectCompressionType (c .Request , c .Response )
60
- w , ok := c .Response .Out .(http.CloseNotifier )
61
- if ok {
62
- writer .parentNotify = w .CloseNotify ()
59
+ if found , compressType , compressWriter := detectCompressionType (c .Request , c .Response ); found {
60
+ writer := CompressResponseWriter {
61
+ ControllerResponse :c .Response ,
62
+ OriginalWriter : c .Response .GetWriter (),
63
+ compressWriter : compressWriter ,
64
+ compressionType :compressType ,
65
+ headersWritten :false ,
66
+ closeNotify : make (chan bool , 1 ),
67
+ closed : false }
68
+ // Swap out the header with our own
69
+ writer .Header = NewBufferedServerHeader (c .Response .ServerHeader )
70
+ c .Response .ServerHeader = writer .Header
71
+ if w , ok := c .Response .GetWriter ().(http.CloseNotifier ); ok {
72
+ writer .parentNotify = w .CloseNotify ()
73
+ }
74
+ c .Response .SetWriter (& writer )
63
75
}
64
- c .Response .Out = & writer
65
76
} else {
66
77
TRACE .Printf ("Compression disabled for response status (%d)" , c .Response .Status )
67
78
}
68
79
}
80
+ fc [0 ](c , fc [1 :])
69
81
}
70
82
71
83
func (c CompressResponseWriter ) CloseNotify () <- chan bool {
@@ -77,16 +89,16 @@ func (c CompressResponseWriter) CloseNotify() <-chan bool {
77
89
78
90
func (c * CompressResponseWriter ) prepareHeaders () {
79
91
if c .compressionType != "" {
80
- responseMime := c .Header () .Get ("Content-Type" )
92
+ responseMime := c .Header .Get ("Content-Type" )
81
93
responseMime = strings .TrimSpace (strings .SplitN (responseMime , ";" , 2 )[0 ])
82
94
shouldEncode := false
83
95
84
- if c .Header () .Get ("Content-Encoding" ) == "" {
96
+ if c .Header .Get ("Content-Encoding" ) == "" {
85
97
for _ , compressableMime := range compressableMimes {
86
98
if responseMime == compressableMime {
87
99
shouldEncode = true
88
- c .Header () .Set ("Content-Encoding" , c .compressionType )
89
- c .Header () .Del ("Content-Length" )
100
+ c .Header .Set ("Content-Encoding" , c .compressionType )
101
+ c .Header .Del ("Content-Length" )
90
102
break
91
103
}
92
104
}
@@ -97,20 +109,26 @@ func (c *CompressResponseWriter) prepareHeaders() {
97
109
c .compressionType = ""
98
110
}
99
111
}
112
+ c .Header .Release ()
100
113
}
101
114
102
115
func (c * CompressResponseWriter ) WriteHeader (status int ) {
103
116
c .headersWritten = true
104
117
c .prepareHeaders ()
105
- c .ResponseWriter . WriteHeader (status )
118
+ c .Header . SetStatus (status )
106
119
}
107
120
108
121
func (c * CompressResponseWriter ) Close () error {
109
- if c .compressionType != "" {
110
- _ = c .compressWriter .Close ()
111
- }
112
- if w , ok := c .ResponseWriter .(io.Closer ); ok {
113
- _ = w .Close ()
122
+ if ! c .headersWritten {
123
+ c .prepareHeaders ()
124
+ }
125
+ if c .compressionType != "" {
126
+ c .Header .Del ("Content-Length" )
127
+ if err := c .compressWriter .Close (); err != nil {
128
+ // TODO When writing directly to stream, an error will be generated
129
+ ERROR .Println ("Error closing compress writer" ,c .compressionType , err )
130
+ }
131
+
114
132
}
115
133
// Non-blocking write to the closenotifier, if we for some reason should
116
134
// get called multiple times
@@ -123,6 +141,7 @@ func (c *CompressResponseWriter) Close() error {
123
141
}
124
142
125
143
func (c * CompressResponseWriter ) Write (b []byte ) (int , error ) {
144
+ println ("*** Write called" )
126
145
// Abort if parent has been closed
127
146
if c .parentNotify != nil {
128
147
select {
@@ -135,23 +154,22 @@ func (c *CompressResponseWriter) Write(b []byte) (int, error) {
135
154
if c .closed {
136
155
return 0 , io .ErrClosedPipe
137
156
}
157
+
138
158
if ! c .headersWritten {
139
159
c .prepareHeaders ()
140
160
c .headersWritten = true
141
161
}
142
-
143
162
if c .compressionType != "" {
144
163
return c .compressWriter .Write (b )
145
164
}
146
-
147
- return c .ResponseWriter .Write (b )
165
+ return c .OriginalWriter .Write (b )
148
166
}
149
167
150
- // DetectCompressionType method detects the comperssion type
168
+ // DetectCompressionType method detects the compression type
151
169
// from header "Accept-Encoding"
152
- func ( c * CompressResponseWriter ) DetectCompressionType ( req * Request , resp * Response ) {
170
+ func detectCompressionType ( req * Request , resp * Response ) ( found bool , compressionType string , compressionKind WriteFlusher ) {
153
171
if Config .BoolDefault ("results.compressed" , false ) {
154
- acceptedEncodings := strings .Split (req .Request . Header . Get ("Accept-Encoding" ), "," )
172
+ acceptedEncodings := strings .Split (req .HttpHeaderValue ("Accept-Encoding" ), "," )
155
173
156
174
largestQ := 0.0
157
175
chosenEncoding := len (compressionTypes )
@@ -216,13 +234,98 @@ func (c *CompressResponseWriter) DetectCompressionType(req *Request, resp *Respo
216
234
return
217
235
}
218
236
219
- c . compressionType = compressionTypes [chosenEncoding ]
237
+ compressionType = compressionTypes [chosenEncoding ]
220
238
221
- switch c . compressionType {
239
+ switch compressionType {
222
240
case "gzip" :
223
- c .compressWriter = gzip .NewWriter (resp .Out )
241
+ compressionKind = gzip .NewWriter (resp .GetWriter ())
242
+ found = true
224
243
case "deflate" :
225
- c .compressWriter = zlib .NewWriter (resp .Out )
244
+ compressionKind = zlib .NewWriter (resp .GetWriter ())
245
+ found = true
226
246
}
227
247
}
248
+ return
249
+ }
250
+
251
+
252
+ // This class will not send content out until the Released is called, from that point on it will act normally
253
+ // It implements all the ServerHeader
254
+ type BufferedServerHeader struct {
255
+ cookieList []string
256
+ headerMap map [string ][]string
257
+ status int
258
+ released bool
259
+ original ServerHeader
260
+ }
261
+ func NewBufferedServerHeader (o ServerHeader ) * BufferedServerHeader {
262
+ return & BufferedServerHeader {original :o ,headerMap :map [string ][]string {}}
263
+ }
264
+ func (bsh * BufferedServerHeader ) SetCookie (cookie string ) {
265
+ if bsh .released {
266
+ bsh .original .SetCookie (cookie )
267
+ } else {
268
+ bsh .cookieList = append (bsh .cookieList ,cookie )
269
+ }
270
+ }
271
+ func (bsh * BufferedServerHeader ) GetCookie (key string ) (value ServerCookie , err error ) {
272
+ return bsh .original .GetCookie (key )
273
+ }
274
+ func (bsh * BufferedServerHeader ) Set (key string , value string ){
275
+ if bsh .released {
276
+ bsh .original .Set (key ,value )
277
+ } else {
278
+ bsh .headerMap [key ]= []string {value }
279
+ }
280
+ }
281
+ func (bsh * BufferedServerHeader ) Add (key string , value string ) {
282
+ if bsh .released {
283
+ bsh .original .Set (key ,value )
284
+ } else {
285
+ old := []string {}
286
+ if v ,found := bsh .headerMap [key ];found {
287
+ old = v
288
+ }
289
+ bsh .headerMap [key ]= append (old ,value )
290
+ }
291
+
292
+ }
293
+ func (bsh * BufferedServerHeader ) Del (key string ){
294
+ if bsh .released {
295
+ bsh .original .Del (key )
296
+ } else {
297
+ delete (bsh .headerMap ,key )
298
+ }
299
+
300
+ }
301
+ func (bsh * BufferedServerHeader ) Get (key string ) (value string ){
302
+ if bsh .released {
303
+ value = bsh .original .Get (key )
304
+ } else {
305
+ if v ,found := bsh .headerMap [key ]; found && len (v )> 0 {
306
+ value = v [0 ]
307
+ } else {
308
+ value = bsh .original .Get (key )
309
+ }
310
+ }
311
+ return
312
+ }
313
+ func (bsh * BufferedServerHeader ) SetStatus (statusCode int ) {
314
+ if bsh .released {
315
+ bsh .original .SetStatus (statusCode )
316
+ } else {
317
+ bsh .status = statusCode
318
+ }
319
+ }
320
+ func (bsh * BufferedServerHeader ) Release () {
321
+ bsh .released = true
322
+ bsh .original .SetStatus (bsh .status )
323
+ for k ,v := range bsh .headerMap {
324
+ for _ ,r := range v {
325
+ bsh .original .Set (k , r )
326
+ }
327
+ }
328
+ for _ ,c := range bsh .cookieList {
329
+ bsh .original .SetCookie (c )
330
+ }
228
331
}
0 commit comments