Skip to content

Commit 5686044

Browse files
committed
add WriteString
1 parent b6a62f1 commit 5686044

File tree

2 files changed

+130
-1
lines changed

2 files changed

+130
-1
lines changed

adapter_test.go

+105
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,111 @@ type noopHandler struct{}
906906

907907
func (noopHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {}
908908

909+
func TestWriteStringNoCompressionStatic(t *testing.T) {
910+
t.Parallel()
911+
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
912+
if w, ok := w.(interface{ WriteString(string) (int, error) }); ok {
913+
w.WriteString("hello string world!")
914+
return
915+
}
916+
w.Write([]byte("hello bytes world!"))
917+
})
918+
a, _ := DefaultAdapter()
919+
h = a(h)
920+
// Do not send accept-encoding to disable compression
921+
r, _ := http.NewRequest("GET", "/", nil)
922+
t.Run("WriteString", func(t *testing.T) {
923+
w := &discardResponseWriterWithWriteString{}
924+
h.ServeHTTP(w, r)
925+
if w.s != 19 {
926+
t.Fatalf("WriteString not called: %+v", w)
927+
}
928+
})
929+
t.Run("Write", func(t *testing.T) {
930+
w := &discardResponseWriter{}
931+
h.ServeHTTP(w, r)
932+
if w.b != 18 {
933+
t.Fatalf("Write not called: %+v", w)
934+
}
935+
})
936+
}
937+
938+
func TestWriteStringNoCompressionDynamic(t *testing.T) {
939+
t.Parallel()
940+
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
941+
w.Header().Set("Content-Type", "text/uncompressible")
942+
if w, ok := w.(interface{ WriteString(string) (int, error) }); ok {
943+
w.WriteString(testBody) // first WriteString will fallback to Write
944+
w.WriteString(testBody)
945+
return
946+
}
947+
w.Write([]byte(testBody))
948+
w.Write([]byte(testBody))
949+
})
950+
a, _ := DefaultAdapter(ContentTypes([]string{"text/uncompressible"}, true))
951+
h = a(h)
952+
r, _ := http.NewRequest("GET", "/", nil)
953+
r.Header.Set("Accept-Encoding", "gzip")
954+
t.Run("WriteString", func(t *testing.T) {
955+
w := &discardResponseWriterWithWriteString{}
956+
h.ServeHTTP(w, r)
957+
if w.s != len(testBody) || w.b != len(testBody) { // first WriteString falls back to Write
958+
t.Fatalf("WriteString not called: %+v", w)
959+
}
960+
})
961+
t.Run("Write", func(t *testing.T) {
962+
w := &discardResponseWriter{}
963+
h.ServeHTTP(w, r)
964+
if w.b != len(testBody)*2 {
965+
t.Fatalf("Write not called: %+v", w)
966+
}
967+
})
968+
}
969+
970+
type discardResponseWriterWithWriteString struct {
971+
discardResponseWriter
972+
s int
973+
}
974+
975+
func (w *discardResponseWriterWithWriteString) WriteString(s string) (n int, err error) {
976+
w.s += len(s)
977+
return len(s), nil
978+
}
979+
980+
func TestWriteStringEquivalence(t *testing.T) {
981+
t.Parallel()
982+
983+
for _, ae := range []string{"gzip", "uncompressed"} {
984+
for _, ct := range []string{"text", "uncompressible"} {
985+
t.Run(fmt.Sprintf("%s/%s", ae, ct), func(t *testing.T) {
986+
r, _ := http.NewRequest("GET", "/", nil)
987+
r.Header.Set("Accept-Encoding", ae)
988+
a, _ := DefaultAdapter(ContentTypes([]string{"uncompressible"}, true))
989+
990+
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
991+
w.Header().Set("Content-Type", ct)
992+
w.(interface{ WriteString(string) (int, error) }).WriteString(testBody)
993+
w.(interface{ WriteString(string) (int, error) }).WriteString(testBody)
994+
})
995+
h = a(h)
996+
ws := httptest.NewRecorder()
997+
h.ServeHTTP(ws, r)
998+
999+
h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1000+
w.Header().Set("Content-Type", ct)
1001+
w.Write([]byte(testBody))
1002+
w.Write([]byte(testBody))
1003+
})
1004+
h = a(h)
1005+
w := httptest.NewRecorder()
1006+
h.ServeHTTP(w, r)
1007+
1008+
assert.Equal(t, ws.Body.Bytes(), w.Body.Bytes(), "response body mismatch")
1009+
})
1010+
}
1011+
}
1012+
}
1013+
9091014
// --------------------------------------------------------------------
9101015

9111016
func BenchmarkAdapter(b *testing.B) {

response_writer.go

+25-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ var (
3232
_ io.WriteCloser = &compressWriter{}
3333
_ http.Flusher = &compressWriter{}
3434
_ http.Hijacker = &compressWriter{}
35+
_ writeStringer = &compressWriter{}
3536
)
3637

3738
type compressWriterWithCloseNotify struct {
@@ -46,11 +47,12 @@ var (
4647
_ io.WriteCloser = compressWriterWithCloseNotify{}
4748
_ http.Flusher = compressWriterWithCloseNotify{}
4849
_ http.Hijacker = compressWriterWithCloseNotify{}
50+
_ writeStringer = compressWriterWithCloseNotify{}
4951
)
5052

5153
const maxBuf = 1 << 16 // maximum size of recycled buffer
5254

53-
// Write appends data to the gzip writer.
55+
// WriteString compresses and appends the given byte slice to the underlying ResponseWriter.
5456
func (w *compressWriter) Write(b []byte) (int, error) {
5557
if w.w != nil {
5658
// The responseWriter is already initialized: use it.
@@ -108,6 +110,28 @@ func (w *compressWriter) Write(b []byte) (int, error) {
108110
return len(b), nil
109111
}
110112

113+
// WriteString compresses and appends the given string to the underlying ResponseWriter.
114+
//
115+
// This makes use of an optional method (WriteString) exposed by the compressors, or by
116+
// the underlying ResponseWriter.
117+
func (w *compressWriter) WriteString(s string) (int, error) {
118+
// Since WriteString is an optional interface of the compressor, and the actual compressor
119+
// is chosen only after the first call to Write, we can't statically know whether the interface
120+
// is supported. We therefore have to check dynamically.
121+
if ws, _ := w.w.(writeStringer); ws != nil {
122+
// The responseWriter is already initialized and it implements WriteString.
123+
return ws.WriteString(s)
124+
}
125+
// Fallback: the writer has not been initialized yet, or it has been initialized
126+
// and it does not implement WriteString. We could in theory do something unsafe
127+
// here but for now let's keep it simple and fallback to Write.
128+
return w.Write([]byte(s))
129+
}
130+
131+
type writeStringer interface {
132+
WriteString(string) (int, error)
133+
}
134+
111135
// startCompress initializes a compressing writer and writes the buffer.
112136
func (w *compressWriter) startCompress(enc string) error {
113137
comp, ok := w.config.compressor[enc]

0 commit comments

Comments
 (0)