diff --git a/dns/dnsmessage/message.go b/dns/dnsmessage/message.go index 0215a5dde..b6b4f9c19 100644 --- a/dns/dnsmessage/message.go +++ b/dns/dnsmessage/message.go @@ -492,7 +492,7 @@ func (r *Resource) GoString() string { // A ResourceBody is a DNS resource record minus the header. type ResourceBody interface { // pack packs a Resource except for its header. - pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) + pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) // realType returns the actual type of the Resource. This is used to // fill in the header Type field. @@ -503,7 +503,7 @@ type ResourceBody interface { } // pack appends the wire format of the Resource to msg. -func (r *Resource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *Resource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { if r.Body == nil { return msg, errNilResouceBody } @@ -1129,7 +1129,7 @@ func (m *Message) AppendPack(b []byte) ([]byte, error) { // DNS messages can be a maximum of 512 bytes long. Without compression, // many DNS response messages are over this limit, so enabling // compression will help ensure compliance. - compression := map[string]int{} + compression := map[string]uint16{} for i := range m.Questions { var err error @@ -1220,7 +1220,7 @@ type Builder struct { // compression is a mapping from name suffixes to their starting index // in msg. - compression map[string]int + compression map[string]uint16 } // NewBuilder creates a new builder with compression disabled. @@ -1257,7 +1257,7 @@ func NewBuilder(buf []byte, h Header) Builder { // // Compression should be enabled before any sections are added for best results. func (b *Builder) EnableCompression() { - b.compression = map[string]int{} + b.compression = map[string]uint16{} } func (b *Builder) startCheck(s section) error { @@ -1673,7 +1673,7 @@ func (h *ResourceHeader) GoString() string { // pack appends the wire format of the ResourceHeader to oldMsg. // // lenOff is the offset in msg where the Length field was packed. -func (h *ResourceHeader) pack(oldMsg []byte, compression map[string]int, compressionOff int) (msg []byte, lenOff int, err error) { +func (h *ResourceHeader) pack(oldMsg []byte, compression map[string]uint16, compressionOff int) (msg []byte, lenOff int, err error) { msg = oldMsg if msg, err = h.Name.pack(msg, compression, compressionOff); err != nil { return oldMsg, 0, &nestedError{"Name", err} @@ -1901,7 +1901,7 @@ func unpackBytes(msg []byte, off int, field []byte) (int, error) { const nonEncodedNameMax = 254 -// A Name is a non-encoded domain name. It is used instead of strings to avoid +// A Name is a non-encoded and non-escaped domain name. It is used instead of strings to avoid // allocations. type Name struct { Data [255]byte @@ -1928,6 +1928,8 @@ func MustNewName(name string) Name { } // String implements fmt.Stringer.String. +// +// Note: characters inside the labels are not escaped in any way. func (n Name) String() string { return string(n.Data[:n.Length]) } @@ -1944,7 +1946,7 @@ func (n *Name) GoString() string { // // The compression map will be updated with new domain suffixes. If compression // is nil, compression will not be used. -func (n *Name) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (n *Name) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { oldMsg := msg if n.Length > nonEncodedNameMax { @@ -2008,7 +2010,7 @@ func (n *Name) pack(msg []byte, compression map[string]int, compressionOff int) // multiple times (for next labels). nameAsStr = string(n.Data[:n.Length]) } - compression[nameAsStr[i:]] = newPtr + compression[nameAsStr[i:]] = uint16(newPtr) } } } @@ -2148,7 +2150,7 @@ type Question struct { } // pack appends the wire format of the Question to msg. -func (q *Question) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (q *Question) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { msg, err := q.Name.pack(msg, compression, compressionOff) if err != nil { return msg, &nestedError{"Name", err} @@ -2244,7 +2246,7 @@ func (r *CNAMEResource) realType() Type { } // pack appends the wire format of the CNAMEResource to msg. -func (r *CNAMEResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *CNAMEResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { return r.CNAME.pack(msg, compression, compressionOff) } @@ -2272,7 +2274,7 @@ func (r *MXResource) realType() Type { } // pack appends the wire format of the MXResource to msg. -func (r *MXResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *MXResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { oldMsg := msg msg = packUint16(msg, r.Pref) msg, err := r.MX.pack(msg, compression, compressionOff) @@ -2311,7 +2313,7 @@ func (r *NSResource) realType() Type { } // pack appends the wire format of the NSResource to msg. -func (r *NSResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *NSResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { return r.NS.pack(msg, compression, compressionOff) } @@ -2338,7 +2340,7 @@ func (r *PTRResource) realType() Type { } // pack appends the wire format of the PTRResource to msg. -func (r *PTRResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *PTRResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { return r.PTR.pack(msg, compression, compressionOff) } @@ -2375,7 +2377,7 @@ func (r *SOAResource) realType() Type { } // pack appends the wire format of the SOAResource to msg. -func (r *SOAResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *SOAResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { oldMsg := msg msg, err := r.NS.pack(msg, compression, compressionOff) if err != nil { @@ -2447,7 +2449,7 @@ func (r *TXTResource) realType() Type { } // pack appends the wire format of the TXTResource to msg. -func (r *TXTResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *TXTResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { oldMsg := msg for _, s := range r.TXT { var err error @@ -2503,7 +2505,7 @@ func (r *SRVResource) realType() Type { } // pack appends the wire format of the SRVResource to msg. -func (r *SRVResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *SRVResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { oldMsg := msg msg = packUint16(msg, r.Priority) msg = packUint16(msg, r.Weight) @@ -2554,7 +2556,7 @@ func (r *AResource) realType() Type { } // pack appends the wire format of the AResource to msg. -func (r *AResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *AResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { return packBytes(msg, r.A[:]), nil } @@ -2588,7 +2590,7 @@ func (r *AAAAResource) GoString() string { } // pack appends the wire format of the AAAAResource to msg. -func (r *AAAAResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *AAAAResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { return packBytes(msg, r.AAAA[:]), nil } @@ -2628,7 +2630,7 @@ func (r *OPTResource) realType() Type { return TypeOPT } -func (r *OPTResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *OPTResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { for _, opt := range r.Options { msg = packUint16(msg, opt.Code) l := uint16(len(opt.Data)) @@ -2686,7 +2688,7 @@ func (r *UnknownResource) realType() Type { } // pack appends the wire format of the UnknownResource to msg. -func (r *UnknownResource) pack(msg []byte, compression map[string]int, compressionOff int) ([]byte, error) { +func (r *UnknownResource) pack(msg []byte, compression map[string]uint16, compressionOff int) ([]byte, error) { return packBytes(msg, r.Data[:]), nil } diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go index 23fb3d574..c84d5a3aa 100644 --- a/dns/dnsmessage/message_test.go +++ b/dns/dnsmessage/message_test.go @@ -164,7 +164,7 @@ func TestQuestionPackUnpack(t *testing.T) { Type: TypeA, Class: ClassINET, } - buf, err := want.pack(make([]byte, 1, 50), map[string]int{}, 1) + buf, err := want.pack(make([]byte, 1, 50), map[string]uint16{}, 1) if err != nil { t.Fatal("Question.pack() =", err) } @@ -243,7 +243,7 @@ func TestNamePackUnpack(t *testing.T) { for _, test := range tests { in := MustNewName(test.in) - buf, err := in.pack(make([]byte, 0, 30), map[string]int{}, 0) + buf, err := in.pack(make([]byte, 0, 30), map[string]uint16{}, 0) if err != test.err { t.Errorf("got %q.pack() = %v, want = %v", test.in, err, test.err) continue @@ -305,7 +305,7 @@ func TestNameUnpackTooLongName(t *testing.T) { func TestIncompressibleName(t *testing.T) { name := MustNewName("example.com.") - compression := map[string]int{} + compression := map[string]uint16{} buf, err := name.pack(make([]byte, 0, 100), compression, 0) if err != nil { t.Fatal("first Name.pack() =", err) @@ -623,7 +623,7 @@ func TestVeryLongTxt(t *testing.T) { strings.Repeat(".", 255), }}, } - buf, err := want.pack(make([]byte, 0, 8000), map[string]int{}, 0) + buf, err := want.pack(make([]byte, 0, 8000), map[string]uint16{}, 0) if err != nil { t.Fatal("Resource.pack() =", err) } @@ -647,7 +647,7 @@ func TestVeryLongTxt(t *testing.T) { func TestTooLongTxt(t *testing.T) { rb := TXTResource{[]string{strings.Repeat(".", 256)}} - if _, err := rb.pack(make([]byte, 0, 8000), map[string]int{}, 0); err != errStringTooLong { + if _, err := rb.pack(make([]byte, 0, 8000), map[string]uint16{}, 0); err != errStringTooLong { t.Errorf("packing TXTResource with 256 character string: got err = %v, want = %v", err, errStringTooLong) } } diff --git a/go.mod b/go.mod index b16f4e5e6..38ac82b44 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module golang.org/x/net go 1.17 require ( - golang.org/x/crypto v0.13.0 - golang.org/x/sys v0.12.0 - golang.org/x/term v0.12.0 + golang.org/x/crypto v0.14.0 + golang.org/x/sys v0.13.0 + golang.org/x/term v0.13.0 golang.org/x/text v0.13.0 ) diff --git a/go.sum b/go.sum index 0fd3311f4..dc4dc125c 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -20,14 +20,14 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= diff --git a/http2/server.go b/http2/server.go index 6d5e00887..02c88b6b3 100644 --- a/http2/server.go +++ b/http2/server.go @@ -581,9 +581,11 @@ type serverConn struct { advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client curClientStreams uint32 // number of open streams initiated by the client curPushedStreams uint32 // number of open streams initiated by server push + curHandlers uint32 // number of running handler goroutines maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes streams map[uint32]*stream + unstartedHandlers []unstartedHandler initialStreamSendWindowSize int32 maxFrameSize int32 peerMaxHeaderListSize uint32 // zero means unknown (default) @@ -981,6 +983,8 @@ func (sc *serverConn) serve() { return case gracefulShutdownMsg: sc.startGracefulShutdownInternal() + case handlerDoneMsg: + sc.handlerDone() default: panic("unknown timer") } @@ -1020,6 +1024,7 @@ var ( idleTimerMsg = new(serverMessage) shutdownTimerMsg = new(serverMessage) gracefulShutdownMsg = new(serverMessage) + handlerDoneMsg = new(serverMessage) ) func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) } @@ -1892,9 +1897,11 @@ func (st *stream) copyTrailersToHandlerRequest() { // onReadTimeout is run on its own goroutine (from time.AfterFunc) // when the stream's ReadTimeout has fired. func (st *stream) onReadTimeout() { - // Wrap the ErrDeadlineExceeded to avoid callers depending on us - // returning the bare error. - st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded)) + if st.body != nil { + // Wrap the ErrDeadlineExceeded to avoid callers depending on us + // returning the bare error. + st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded)) + } } // onWriteTimeout is run on its own goroutine (from time.AfterFunc) @@ -2012,13 +2019,10 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // (in Go 1.8), though. That's a more sane option anyway. if sc.hs.ReadTimeout != 0 { sc.conn.SetReadDeadline(time.Time{}) - if st.body != nil { - st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) - } + st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } - go sc.runHandler(rw, req, handler) - return nil + return sc.scheduleHandler(id, rw, req, handler) } func (sc *serverConn) upgradeRequest(req *http.Request) { @@ -2038,6 +2042,10 @@ func (sc *serverConn) upgradeRequest(req *http.Request) { sc.conn.SetReadDeadline(time.Time{}) } + // This is the first request on the connection, + // so start the handler directly rather than going + // through scheduleHandler. + sc.curHandlers++ go sc.runHandler(rw, req, sc.handler.ServeHTTP) } @@ -2278,8 +2286,62 @@ func (sc *serverConn) newResponseWriter(st *stream, req *http.Request) *response return &responseWriter{rws: rws} } +type unstartedHandler struct { + streamID uint32 + rw *responseWriter + req *http.Request + handler func(http.ResponseWriter, *http.Request) +} + +// scheduleHandler starts a handler goroutine, +// or schedules one to start as soon as an existing handler finishes. +func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) error { + sc.serveG.check() + maxHandlers := sc.advMaxStreams + if sc.curHandlers < maxHandlers { + sc.curHandlers++ + go sc.runHandler(rw, req, handler) + return nil + } + if len(sc.unstartedHandlers) > int(4*sc.advMaxStreams) { + return sc.countError("too_many_early_resets", ConnectionError(ErrCodeEnhanceYourCalm)) + } + sc.unstartedHandlers = append(sc.unstartedHandlers, unstartedHandler{ + streamID: streamID, + rw: rw, + req: req, + handler: handler, + }) + return nil +} + +func (sc *serverConn) handlerDone() { + sc.serveG.check() + sc.curHandlers-- + i := 0 + maxHandlers := sc.advMaxStreams + for ; i < len(sc.unstartedHandlers); i++ { + u := sc.unstartedHandlers[i] + if sc.streams[u.streamID] == nil { + // This stream was reset before its goroutine had a chance to start. + continue + } + if sc.curHandlers >= maxHandlers { + break + } + sc.curHandlers++ + go sc.runHandler(u.rw, u.req, u.handler) + sc.unstartedHandlers[i] = unstartedHandler{} // don't retain references + } + sc.unstartedHandlers = sc.unstartedHandlers[i:] + if len(sc.unstartedHandlers) == 0 { + sc.unstartedHandlers = nil + } +} + // Run on its own goroutine. func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { + defer sc.sendServeMsg(handlerDoneMsg) didPanic := true defer func() { rw.rws.stream.cancelCtx() diff --git a/http2/server_test.go b/http2/server_test.go index b99c5af54..22657cbfe 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4664,3 +4664,116 @@ func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) { st.ts.Config.Close() <-donec } + +func TestServerMaxHandlerGoroutines(t *testing.T) { + const maxHandlers = 10 + handlerc := make(chan chan bool) + donec := make(chan struct{}) + defer close(donec) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + stopc := make(chan bool, 1) + select { + case handlerc <- stopc: + case <-donec: + } + select { + case shouldPanic := <-stopc: + if shouldPanic { + panic(http.ErrAbortHandler) + } + case <-donec: + } + }, func(s *Server) { + s.MaxConcurrentStreams = maxHandlers + }) + defer st.Close() + + st.writePreface() + st.writeInitialSettings() + st.writeSettingsAck() + + // Make maxHandlers concurrent requests. + // Reset them all, but only after the handler goroutines have started. + var stops []chan bool + streamID := uint32(1) + for i := 0; i < maxHandlers; i++ { + st.writeHeaders(HeadersFrameParam{ + StreamID: streamID, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + stops = append(stops, <-handlerc) + st.fr.WriteRSTStream(streamID, ErrCodeCancel) + streamID += 2 + } + + // Start another request, and immediately reset it. + st.writeHeaders(HeadersFrameParam{ + StreamID: streamID, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + st.fr.WriteRSTStream(streamID, ErrCodeCancel) + streamID += 2 + + // Start another two requests. Don't reset these. + for i := 0; i < 2; i++ { + st.writeHeaders(HeadersFrameParam{ + StreamID: streamID, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + streamID += 2 + } + + // The initial maxHandlers handlers are still executing, + // so the last two requests don't start any new handlers. + select { + case <-handlerc: + t.Errorf("handler unexpectedly started while maxHandlers are already running") + case <-time.After(1 * time.Millisecond): + } + + // Tell two handlers to exit. + // The pending requests which weren't reset start handlers. + stops[0] <- false // normal exit + stops[1] <- true // panic + stops = stops[2:] + stops = append(stops, <-handlerc) + stops = append(stops, <-handlerc) + + // Make a bunch more requests. + // Eventually, the server tells us to go away. + for i := 0; i < 5*maxHandlers; i++ { + st.writeHeaders(HeadersFrameParam{ + StreamID: streamID, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + st.fr.WriteRSTStream(streamID, ErrCodeCancel) + streamID += 2 + } +Frames: + for { + f, err := st.readFrame() + if err != nil { + st.t.Fatal(err) + } + switch f := f.(type) { + case *GoAwayFrame: + if f.ErrCode != ErrCodeEnhanceYourCalm { + t.Errorf("err code = %v; want %v", f.ErrCode, ErrCodeEnhanceYourCalm) + } + break Frames + default: + } + } + + for _, s := range stops { + close(s) + } +} diff --git a/internal/quic/config.go b/internal/quic/config.go index df493579f..b390d6911 100644 --- a/internal/quic/config.go +++ b/internal/quic/config.go @@ -18,28 +18,64 @@ type Config struct { // It must be non-nil and include at least one certificate or else set GetCertificate. TLSConfig *tls.Config - // StreamReadBufferSize is the maximum amount of data sent by the peer that a + // MaxBidiRemoteStreams limits the number of simultaneous bidirectional streams + // a peer may open. + // If zero, the default value of 100 is used. + // If negative, the limit is zero. + MaxBidiRemoteStreams int64 + + // MaxUniRemoteStreams limits the number of simultaneous unidirectional streams + // a peer may open. + // If zero, the default value of 100 is used. + // If negative, the limit is zero. + MaxUniRemoteStreams int64 + + // MaxStreamReadBufferSize is the maximum amount of data sent by the peer that a // stream will buffer for reading. // If zero, the default value of 1MiB is used. // If negative, the limit is zero. - StreamReadBufferSize int64 + MaxStreamReadBufferSize int64 - // StreamWriteBufferSize is the maximum amount of data a stream will buffer for + // MaxStreamWriteBufferSize is the maximum amount of data a stream will buffer for // sending to the peer. // If zero, the default value of 1MiB is used. // If negative, the limit is zero. - StreamWriteBufferSize int64 + MaxStreamWriteBufferSize int64 + + // MaxConnReadBufferSize is the maximum amount of data sent by the peer that a + // connection will buffer for reading, across all streams. + // If zero, the default value of 1MiB is used. + // If negative, the limit is zero. + MaxConnReadBufferSize int64 } -func configDefault(v, def int64) int64 { - switch v { - case -1: - return 0 - case 0: +func configDefault(v, def, limit int64) int64 { + switch { + case v == 0: return def + case v < 0: + return 0 + default: + return min(v, limit) } - return v } -func (c *Config) streamReadBufferSize() int64 { return configDefault(c.StreamReadBufferSize, 1<<20) } -func (c *Config) streamWriteBufferSize() int64 { return configDefault(c.StreamWriteBufferSize, 1<<20) } +func (c *Config) maxBidiRemoteStreams() int64 { + return configDefault(c.MaxBidiRemoteStreams, 100, maxStreamsLimit) +} + +func (c *Config) maxUniRemoteStreams() int64 { + return configDefault(c.MaxUniRemoteStreams, 100, maxStreamsLimit) +} + +func (c *Config) maxStreamReadBufferSize() int64 { + return configDefault(c.MaxStreamReadBufferSize, 1<<20, maxVarint) +} + +func (c *Config) maxStreamWriteBufferSize() int64 { + return configDefault(c.MaxStreamWriteBufferSize, 1<<20, maxVarint) +} + +func (c *Config) maxConnReadBufferSize() int64 { + return configDefault(c.MaxConnReadBufferSize, 1<<20, maxVarint) +} diff --git a/internal/quic/config_test.go b/internal/quic/config_test.go index cec57c5e3..d292854f5 100644 --- a/internal/quic/config_test.go +++ b/internal/quic/config_test.go @@ -10,16 +10,25 @@ import "testing" func TestConfigTransportParameters(t *testing.T) { const ( - wantInitialMaxStreamData = int64(2) + wantInitialMaxData = int64(1) + wantInitialMaxStreamData = int64(2) + wantInitialMaxStreamsBidi = int64(3) + wantInitialMaxStreamsUni = int64(4) ) tc := newTestConn(t, clientSide, func(c *Config) { - c.StreamReadBufferSize = wantInitialMaxStreamData + c.MaxBidiRemoteStreams = wantInitialMaxStreamsBidi + c.MaxUniRemoteStreams = wantInitialMaxStreamsUni + c.MaxStreamReadBufferSize = wantInitialMaxStreamData + c.MaxConnReadBufferSize = wantInitialMaxData }) tc.handshake() if tc.sentTransportParameters == nil { t.Fatalf("conn didn't send transport parameters during handshake") } p := tc.sentTransportParameters + if got, want := p.initialMaxData, wantInitialMaxData; got != want { + t.Errorf("initial_max_data = %v, want %v", got, want) + } if got, want := p.initialMaxStreamDataBidiLocal, wantInitialMaxStreamData; got != want { t.Errorf("initial_max_stream_data_bidi_local = %v, want %v", got, want) } @@ -29,4 +38,10 @@ func TestConfigTransportParameters(t *testing.T) { if got, want := p.initialMaxStreamDataUni, wantInitialMaxStreamData; got != want { t.Errorf("initial_max_stream_data_uni = %v, want %v", got, want) } + if got, want := p.initialMaxStreamsBidi, wantInitialMaxStreamsBidi; got != want { + t.Errorf("initial_max_stream_data_uni = %v, want %v", got, want) + } + if got, want := p.initialMaxStreamsUni, wantInitialMaxStreamsUni; got != want { + t.Errorf("initial_max_stream_data_uni = %v, want %v", got, want) + } } diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 04dcd7b6b..9db00fe09 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -20,7 +20,7 @@ import ( // Multiple goroutines may invoke methods on a Conn simultaneously. type Conn struct { side connSide - listener connListener + listener *Listener config *Config testHooks connTestHooks peerAddr netip.AddrPort @@ -31,24 +31,22 @@ type Conn struct { w packetWriter acks [numberSpaceCount]ackState // indexed by number space + lifetime lifetimeState connIDState connIDState loss lossState streams streamsState - // errForPeer is set when the connection is being closed. - errForPeer error - connCloseSent [numberSpaceCount]bool - // idleTimeout is the time at which the connection will be closed due to inactivity. // https://www.rfc-editor.org/rfc/rfc9000#section-10.1 maxIdleTimeout time.Duration idleTimeout time.Time // Packet protection keys, CRYPTO streams, and TLS state. - rkeys [numberSpaceCount]keys - wkeys [numberSpaceCount]keys - crypto [numberSpaceCount]cryptoStream - tls *tls.QUICConn + keysInitial fixedKeyPair + keysHandshake fixedKeyPair + keysAppData updatingKeyPair + crypto [numberSpaceCount]cryptoStream + tls *tls.QUICConn // handshakeConfirmed is set when the handshake is confirmed. // For server connections, it tracks sending HANDSHAKE_DONE. @@ -61,22 +59,16 @@ type Conn struct { testSendPing sentVal } -// The connListener is the Conn's Listener. -// Defined as an interface so we can swap it out in tests. -type connListener interface { - sendDatagram(p []byte, addr netip.AddrPort) error -} - // connTestHooks override conn behavior in tests. type connTestHooks interface { nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any) handleTLSEvent(tls.QUICEvent) newConnID(seq int64) ([]byte, error) - waitAndLockGate(ctx context.Context, g *gate) error - waitOnDone(ctx context.Context, ch <-chan struct{}) error + waitUntil(ctx context.Context, until func() bool) error + timeNow() time.Time } -func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) { +func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l *Listener, hooks connTestHooks) (*Conn, error) { c := &Conn{ side: side, listener: l, @@ -94,35 +86,45 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip. // non-blocking operation. c.msgc = make(chan any, 1) + var originalDstConnID []byte if c.side == clientSide { - if err := c.connIDState.initClient(c.newConnIDFunc()); err != nil { + if err := c.connIDState.initClient(c); err != nil { return nil, err } initialConnID, _ = c.connIDState.dstConnID() } else { - if err := c.connIDState.initServer(c.newConnIDFunc(), initialConnID); err != nil { + if err := c.connIDState.initServer(c, initialConnID); err != nil { return nil, err } + originalDstConnID = initialConnID } // The smallest allowed maximum QUIC datagram size is 1200 bytes. // TODO: PMTU discovery. const maxDatagramSize = 1200 + c.keysAppData.init() c.loss.init(c.side, maxDatagramSize, now) c.streamsInit() + c.lifetimeInit() - // TODO: initial_source_connection_id, retry_source_connection_id - c.startTLS(now, initialConnID, transportParameters{ + // TODO: retry_source_connection_id + if err := c.startTLS(now, initialConnID, transportParameters{ initialSrcConnID: c.connIDState.srcConnID(), + originalDstConnID: originalDstConnID, ackDelayExponent: ackDelayExponent, maxUDPPayloadSize: maxUDPPayloadSize, maxAckDelay: maxAckDelay, disableActiveMigration: true, - initialMaxStreamDataBidiLocal: config.streamReadBufferSize(), - initialMaxStreamDataBidiRemote: config.streamReadBufferSize(), - initialMaxStreamDataUni: config.streamReadBufferSize(), + initialMaxData: config.maxConnReadBufferSize(), + initialMaxStreamDataBidiLocal: config.maxStreamReadBufferSize(), + initialMaxStreamDataBidiRemote: config.maxStreamReadBufferSize(), + initialMaxStreamDataUni: config.maxStreamReadBufferSize(), + initialMaxStreamsBidi: c.streams.remoteLimit[bidiStream].max, + initialMaxStreamsUni: c.streams.remoteLimit[uniStream].max, activeConnIDLimit: activeConnIDLimit, - }) + }); err != nil { + return nil, err + } go c.loop(now) return c, nil @@ -145,6 +147,7 @@ func (c *Conn) confirmHandshake(now time.Time) { if c.side == serverSide { // When the server confirms the handshake, it sends a HANDSHAKE_DONE. c.handshakeConfirmed.setUnsent() + c.listener.serverConnEstablished(c) } else { // The client never sends a HANDSHAKE_DONE, so we set handshakeConfirmed // to the received state, indicating that the handshake is confirmed and we @@ -160,19 +163,29 @@ func (c *Conn) confirmHandshake(now time.Time) { // discardKeys discards unused packet protection keys. // https://www.rfc-editor.org/rfc/rfc9001#section-4.9 func (c *Conn) discardKeys(now time.Time, space numberSpace) { - c.rkeys[space].discard() - c.wkeys[space].discard() + switch space { + case initialSpace: + c.keysInitial.discard() + case handshakeSpace: + c.keysHandshake.discard() + } c.loss.discardKeys(now, space) } // receiveTransportParameters applies transport parameters sent by the peer. func (c *Conn) receiveTransportParameters(p transportParameters) error { + if err := c.connIDState.validateTransportParameters(c.side, p); err != nil { + return err + } + c.streams.outflow.setMaxData(p.initialMaxData) + c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi) + c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni) c.streams.peerInitialMaxStreamDataBidiLocal = p.initialMaxStreamDataBidiLocal c.streams.peerInitialMaxStreamDataRemote[bidiStream] = p.initialMaxStreamDataBidiRemote c.streams.peerInitialMaxStreamDataRemote[uniStream] = p.initialMaxStreamDataUni c.peerAckDelayExponent = p.ackDelayExponent c.loss.setMaxAckDelay(p.maxAckDelay) - if err := c.connIDState.setPeerActiveConnIDLimit(p.activeConnIDLimit, c.newConnIDFunc()); err != nil { + if err := c.connIDState.setPeerActiveConnIDLimit(c, p.activeConnIDLimit); err != nil { return err } if p.preferredAddrConnID != nil { @@ -206,6 +219,7 @@ type ( func (c *Conn) loop(now time.Time) { defer close(c.donec) defer c.tls.Close() + defer c.listener.connDrained(c) // The connection timer sends a message to the connection loop on expiry. // We need to give it an expiry when creating it, so set the initial timeout to @@ -229,8 +243,12 @@ func (c *Conn) loop(now time.Time) { // since the Initial and Handshake spaces always ack immediately. nextTimeout := sendTimeout nextTimeout = firstTime(nextTimeout, c.idleTimeout) - nextTimeout = firstTime(nextTimeout, c.loss.timer) - nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck) + if !c.isClosingOrDraining() { + nextTimeout = firstTime(nextTimeout, c.loss.timer) + nextTimeout = firstTime(nextTimeout, c.acks[appDataSpace].nextAck) + } else { + nextTimeout = firstTime(nextTimeout, c.lifetime.drainEndTime) + } var m any if hooks != nil { @@ -267,6 +285,11 @@ func (c *Conn) loop(now time.Time) { return } c.loss.advance(now, c.handleAckOrLoss) + if c.lifetimeAdvance(now) { + // The connection has completed the draining period, + // and may be shut down. + return + } case wakeEvent: // We're being woken up to try sending some frames. case func(time.Time, *Conn): @@ -311,16 +334,16 @@ func (c *Conn) runOnLoop(f func(now time.Time, c *Conn)) error { return nil } -func (c *Conn) waitAndLockGate(ctx context.Context, g *gate) error { - if c.testHooks != nil { - return c.testHooks.waitAndLockGate(ctx, g) - } - return g.waitAndLockContext(ctx) -} - func (c *Conn) waitOnDone(ctx context.Context, ch <-chan struct{}) error { if c.testHooks != nil { - return c.testHooks.waitOnDone(ctx, ch) + return c.testHooks.waitUntil(ctx, func() bool { + select { + case <-ch: + return true + default: + } + return false + }) } // Check the channel before the context. // We always prefer to return results when available, @@ -338,21 +361,6 @@ func (c *Conn) waitOnDone(ctx context.Context, ch <-chan struct{}) error { return nil } -// abort terminates a connection with an error. -func (c *Conn) abort(now time.Time, err error) { - if c.errForPeer == nil { - c.errForPeer = err - } -} - -// exit fully terminates a connection immediately. -func (c *Conn) exit() { - c.runOnLoop(func(now time.Time, c *Conn) { - c.exited = true - }) - <-c.donec -} - // firstTime returns the earliest non-zero time, or zero if both times are zero. func firstTime(a, b time.Time) time.Time { switch { @@ -366,10 +374,3 @@ func firstTime(a, b time.Time) time.Time { return b } } - -func (c *Conn) newConnIDFunc() newConnIDFunc { - if c.testHooks != nil { - return c.testHooks.newConnID - } - return newRandomConnID -} diff --git a/internal/quic/conn_async_test.go b/internal/quic/conn_async_test.go index 5b419c4e5..dc2a57f9d 100644 --- a/internal/quic/conn_async_test.go +++ b/internal/quic/conn_async_test.go @@ -83,10 +83,7 @@ func (a *asyncOp[T]) result() (v T, err error) { // A blockedAsync is a blocked async operation. type blockedAsync struct { - // Exactly one of these will be set, depending on the type of blocked operation. - g *gate - ch <-chan struct{} - + until func() bool // when this returns true, the operation is unblocked donec chan struct{} // closed when the operation is unblocked } @@ -130,31 +127,12 @@ func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[ return a } -// waitAndLockGate replaces gate.waitAndLock in tests. -func (as *asyncTestState) waitAndLockGate(ctx context.Context, g *gate) error { - if g.lockIfSet() { - // Gate can be acquired without blocking. +// waitUntil waits for a blocked async operation to complete. +// The operation is complete when the until func returns true. +func (as *asyncTestState) waitUntil(ctx context.Context, until func() bool) error { + if until() { return nil } - return as.block(ctx, &blockedAsync{ - g: g, - }) -} - -// waitOnDone replaces receiving from a chan struct{} in tests. -func (as *asyncTestState) waitOnDone(ctx context.Context, ch <-chan struct{}) error { - select { - case <-ch: - return nil // read without blocking - default: - } - return as.block(ctx, &blockedAsync{ - ch: ch, - }) -} - -// block waits for a blocked async operation to complete. -func (as *asyncTestState) block(ctx context.Context, b *blockedAsync) error { if err := ctx.Err(); err != nil { // Context has already expired. return err @@ -166,7 +144,10 @@ func (as *asyncTestState) block(ctx context.Context, b *blockedAsync) error { // which may have unpredictable results. panic("blocking async point with unexpected Context") } - b.donec = make(chan struct{}) + b := &blockedAsync{ + until: until, + donec: make(chan struct{}), + } // Record this as a pending blocking operation. as.mu.Lock() as.blocked[b] = struct{}{} @@ -188,20 +169,9 @@ func (as *asyncTestState) wakeAsync() bool { as.mu.Lock() var woken *blockedAsync for w := range as.blocked { - switch { - case w.g != nil: - if w.g.lockIfSet() { - woken = w - } - case w.ch != nil: - select { - case <-w.ch: - woken = w - default: - } - } - if woken != nil { - delete(as.blocked, woken) + if w.until() { + woken = w + delete(as.blocked, w) break } } diff --git a/internal/quic/conn_close.go b/internal/quic/conn_close.go new file mode 100644 index 000000000..b8b86fd6f --- /dev/null +++ b/internal/quic/conn_close.go @@ -0,0 +1,252 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "errors" + "time" +) + +// lifetimeState tracks the state of a connection. +// +// This is fairly coupled to the rest of a Conn, but putting it in a struct of its own helps +// reason about operations that cause state transitions. +type lifetimeState struct { + readyc chan struct{} // closed when TLS handshake completes + drainingc chan struct{} // closed when entering the draining state + + // Possible states for the connection: + // + // Alive: localErr and finalErr are both nil. + // + // Closing: localErr is non-nil and finalErr is nil. + // We have sent a CONNECTION_CLOSE to the peer or are about to + // (if connCloseSentTime is zero) and are waiting for the peer to respond. + // drainEndTime is set to the time the closing state ends. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.1 + // + // Draining: finalErr is non-nil. + // If localErr is nil, we're waiting for the user to provide us with a final status + // to send to the peer. + // Otherwise, we've either sent a CONNECTION_CLOSE to the peer or are about to + // (if connCloseSentTime is zero). + // drainEndTime is set to the time the draining state ends. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 + localErr error // error sent to the peer + finalErr error // error sent by the peer, or transport error; always set before draining + + connCloseSentTime time.Time // send time of last CONNECTION_CLOSE frame + connCloseDelay time.Duration // delay until next CONNECTION_CLOSE frame sent + drainEndTime time.Time // time the connection exits the draining state +} + +func (c *Conn) lifetimeInit() { + c.lifetime.readyc = make(chan struct{}) + c.lifetime.drainingc = make(chan struct{}) +} + +var errNoPeerResponse = errors.New("peer did not respond to CONNECTION_CLOSE") + +// advance is called when time passes. +func (c *Conn) lifetimeAdvance(now time.Time) (done bool) { + if c.lifetime.drainEndTime.IsZero() || c.lifetime.drainEndTime.After(now) { + return false + } + // The connection drain period has ended, and we can shut down. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-7 + c.lifetime.drainEndTime = time.Time{} + if c.lifetime.finalErr == nil { + // The peer never responded to our CONNECTION_CLOSE. + c.enterDraining(errNoPeerResponse) + } + return true +} + +// confirmHandshake is called when the TLS handshake completes. +func (c *Conn) handshakeDone() { + close(c.lifetime.readyc) +} + +// isDraining reports whether the conn is in the draining state. +// +// The draining state is entered once an endpoint receives a CONNECTION_CLOSE frame. +// The endpoint will no longer send any packets, but we retain knowledge of the connection +// until the end of the drain period to ensure we discard packets for the connection +// rather than treating them as starting a new connection. +// +// https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 +func (c *Conn) isDraining() bool { + return c.lifetime.finalErr != nil +} + +// isClosingOrDraining reports whether the conn is in the closing or draining states. +func (c *Conn) isClosingOrDraining() bool { + return c.lifetime.localErr != nil || c.lifetime.finalErr != nil +} + +// sendOK reports whether the conn can send frames at this time. +func (c *Conn) sendOK(now time.Time) bool { + if !c.isClosingOrDraining() { + return true + } + // We are closing or draining. + if c.lifetime.localErr == nil { + // We're waiting for the user to close the connection, providing us with + // a final status to send to the peer. + return false + } + // Past this point, returning true will result in the conn sending a CONNECTION_CLOSE + // due to localErr being set. + if c.lifetime.drainEndTime.IsZero() { + // The closing and draining states should last for at least three times + // the current PTO interval. We currently use exactly that minimum. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2-5 + // + // The drain period begins when we send or receive a CONNECTION_CLOSE, + // whichever comes first. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2-3 + c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) + } + if c.lifetime.connCloseSentTime.IsZero() { + // We haven't sent a CONNECTION_CLOSE yet. Do so. + // Either we're initiating an immediate close + // (and will enter the closing state as soon as we send CONNECTION_CLOSE), + // or we've read a CONNECTION_CLOSE from our peer + // (and may send one CONNECTION_CLOSE before entering the draining state). + // + // Set the initial delay before we will send another CONNECTION_CLOSE. + // + // RFC 9000 states that we should rate limit CONNECTION_CLOSE frames, + // but leaves the implementation of the limit up to us. Here, we start + // with the same delay as the PTO timer (RFC 9002, Section 6.2.1), + // not including max_ack_delay, and double it on every CONNECTION_CLOSE sent. + c.lifetime.connCloseDelay = c.loss.rtt.smoothedRTT + max(4*c.loss.rtt.rttvar, timerGranularity) + c.lifetime.drainEndTime = now.Add(3 * c.loss.ptoBasePeriod()) + return true + } + if c.isDraining() { + // We are in the draining state, and will send no more packets. + return false + } + maxRecvTime := c.acks[initialSpace].maxRecvTime + if t := c.acks[handshakeSpace].maxRecvTime; t.After(maxRecvTime) { + maxRecvTime = t + } + if t := c.acks[appDataSpace].maxRecvTime; t.After(maxRecvTime) { + maxRecvTime = t + } + if maxRecvTime.Before(c.lifetime.connCloseSentTime.Add(c.lifetime.connCloseDelay)) { + // After sending CONNECTION_CLOSE, ignore packets from the peer for + // a delay. On the next packet received after the delay, send another + // CONNECTION_CLOSE. + return false + } + c.lifetime.connCloseSentTime = now + c.lifetime.connCloseDelay *= 2 + return true +} + +// enterDraining enters the draining state. +func (c *Conn) enterDraining(err error) { + if c.isDraining() { + return + } + if e, ok := c.lifetime.localErr.(localTransportError); ok && transportError(e) != errNo { + // If we've terminated the connection due to a peer protocol violation, + // record the final error on the connection as our reason for termination. + c.lifetime.finalErr = c.lifetime.localErr + } else { + c.lifetime.finalErr = err + } + close(c.lifetime.drainingc) + c.streams.queue.close(c.lifetime.finalErr) +} + +func (c *Conn) waitReady(ctx context.Context) error { + select { + case <-c.lifetime.readyc: + return nil + case <-c.lifetime.drainingc: + return c.lifetime.finalErr + default: + } + select { + case <-c.lifetime.readyc: + return nil + case <-c.lifetime.drainingc: + return c.lifetime.finalErr + case <-ctx.Done(): + return ctx.Err() + } +} + +// Close closes the connection. +// +// Close is equivalent to: +// +// conn.Abort(nil) +// err := conn.Wait(context.Background()) +func (c *Conn) Close() error { + c.Abort(nil) + <-c.lifetime.drainingc + return c.lifetime.finalErr +} + +// Wait waits for the peer to close the connection. +// +// If the connection is closed locally and the peer does not close its end of the connection, +// Wait will return with a non-nil error after the drain period expires. +// +// If the peer closes the connection with a NO_ERROR transport error, Wait returns nil. +// If the peer closes the connection with an application error, Wait returns an ApplicationError +// containing the peer's error code and reason. +// If the peer closes the connection with any other status, Wait returns a non-nil error. +func (c *Conn) Wait(ctx context.Context) error { + if err := c.waitOnDone(ctx, c.lifetime.drainingc); err != nil { + return err + } + return c.lifetime.finalErr +} + +// Abort closes the connection and returns immediately. +// +// If err is nil, Abort sends a transport error of NO_ERROR to the peer. +// If err is an ApplicationError, Abort sends its error code and text. +// Otherwise, Abort sends a transport error of APPLICATION_ERROR with the error's text. +func (c *Conn) Abort(err error) { + if err == nil { + err = localTransportError(errNo) + } + c.sendMsg(func(now time.Time, c *Conn) { + c.abort(now, err) + }) +} + +// abort terminates a connection with an error. +func (c *Conn) abort(now time.Time, err error) { + if c.lifetime.localErr != nil { + return // already closing + } + c.lifetime.localErr = err +} + +// abortImmediately terminates a connection. +// The connection does not send a CONNECTION_CLOSE, and skips the draining period. +func (c *Conn) abortImmediately(now time.Time, err error) { + c.abort(now, err) + c.enterDraining(err) + c.exited = true +} + +// exit fully terminates a connection immediately. +func (c *Conn) exit() { + c.sendMsg(func(now time.Time, c *Conn) { + c.enterDraining(errors.New("connection closed")) + c.exited = true + }) +} diff --git a/internal/quic/conn_close_test.go b/internal/quic/conn_close_test.go new file mode 100644 index 000000000..20c00e754 --- /dev/null +++ b/internal/quic/conn_close_test.go @@ -0,0 +1,186 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "crypto/tls" + "errors" + "testing" + "time" +) + +func TestConnCloseResponseBackoff(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.handshake() + + tc.conn.Abort(nil) + tc.wantFrame("aborting connection generates CONN_CLOSE", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errNo, + }) + + waiting := runAsync(tc, func(ctx context.Context) (struct{}, error) { + return struct{}{}, tc.conn.Wait(ctx) + }) + if _, err := waiting.result(); err != errNotDone { + t.Errorf("conn.Wait() = %v, want still waiting", err) + } + + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.wantIdle("packets received immediately after CONN_CLOSE receive no response") + + tc.advance(1100 * time.Microsecond) + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.wantFrame("receiving packet 1.1ms after CONN_CLOSE generates another CONN_CLOSE", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errNo, + }) + + tc.advance(1100 * time.Microsecond) + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.wantIdle("no response to packet, because CONN_CLOSE backoff is now 2ms") + + tc.advance(1000 * time.Microsecond) + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.wantFrame("2ms since last CONN_CLOSE, receiving a packet generates another CONN_CLOSE", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errNo, + }) + if _, err := waiting.result(); err != errNotDone { + t.Errorf("conn.Wait() = %v, want still waiting", err) + } + + tc.advance(100000 * time.Microsecond) + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.wantIdle("drain timer expired, no more responses") + + if _, err := waiting.result(); !errors.Is(err, errNoPeerResponse) { + t.Errorf("blocked conn.Wait() = %v, want errNoPeerResponse", err) + } + if err := tc.conn.Wait(canceledContext()); !errors.Is(err, errNoPeerResponse) { + t.Errorf("non-blocking conn.Wait() = %v, want errNoPeerResponse", err) + } +} + +func TestConnCloseWithPeerResponse(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.handshake() + + tc.conn.Abort(nil) + tc.wantFrame("aborting connection generates CONN_CLOSE", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errNo, + }) + + waiting := runAsync(tc, func(ctx context.Context) (struct{}, error) { + return struct{}{}, tc.conn.Wait(ctx) + }) + if _, err := waiting.result(); err != errNotDone { + t.Errorf("conn.Wait() = %v, want still waiting", err) + } + + tc.writeFrames(packetType1RTT, debugFrameConnectionCloseApplication{ + code: 20, + }) + + wantErr := &ApplicationError{ + Code: 20, + } + if _, err := waiting.result(); !errors.Is(err, wantErr) { + t.Errorf("blocked conn.Wait() = %v, want %v", err, wantErr) + } + if err := tc.conn.Wait(canceledContext()); !errors.Is(err, wantErr) { + t.Errorf("non-blocking conn.Wait() = %v, want %v", err, wantErr) + } +} + +func TestConnClosePeerCloses(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.handshake() + + wantErr := &ApplicationError{ + Code: 42, + Reason: "why?", + } + tc.writeFrames(packetType1RTT, debugFrameConnectionCloseApplication{ + code: wantErr.Code, + reason: wantErr.Reason, + }) + tc.wantIdle("CONN_CLOSE response not sent until user closes this side") + + if err := tc.conn.Wait(canceledContext()); !errors.Is(err, wantErr) { + t.Errorf("conn.Wait() = %v, want %v", err, wantErr) + } + + tc.conn.Abort(&ApplicationError{ + Code: 9, + Reason: "because", + }) + tc.wantFrame("CONN_CLOSE sent after user closes connection", + packetType1RTT, debugFrameConnectionCloseApplication{ + code: 9, + reason: "because", + }) +} + +func TestConnCloseReceiveInInitial(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.wantFrame("client sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errConnectionRefused, + }) + tc.wantIdle("CONN_CLOSE response not sent until user closes this side") + + wantErr := peerTransportError{code: errConnectionRefused} + if err := tc.conn.Wait(canceledContext()); !errors.Is(err, wantErr) { + t.Errorf("conn.Wait() = %v, want %v", err, wantErr) + } + + tc.conn.Abort(&ApplicationError{Code: 1}) + tc.wantFrame("CONN_CLOSE in Initial frame is APPLICATION_ERROR", + packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errApplicationError, + }) + tc.wantIdle("no more frames to send") +} + +func TestConnCloseReceiveInHandshake(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + tc.wantFrame("client sends Initial CRYPTO frame", + packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, debugFrameConnectionCloseTransport{ + code: errConnectionRefused, + }) + tc.wantIdle("CONN_CLOSE response not sent until user closes this side") + + wantErr := peerTransportError{code: errConnectionRefused} + if err := tc.conn.Wait(canceledContext()); !errors.Is(err, wantErr) { + t.Errorf("conn.Wait() = %v, want %v", err, wantErr) + } + + // The conn has Initial and Handshake keys, so it will send CONN_CLOSE in both spaces. + tc.conn.Abort(&ApplicationError{Code: 1}) + tc.wantFrame("CONN_CLOSE in Initial frame is APPLICATION_ERROR", + packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errApplicationError, + }) + tc.wantFrame("CONN_CLOSE in Handshake frame is APPLICATION_ERROR", + packetTypeHandshake, debugFrameConnectionCloseTransport{ + code: errApplicationError, + }) + tc.wantIdle("no more frames to send") +} diff --git a/internal/quic/conn_flow.go b/internal/quic/conn_flow.go new file mode 100644 index 000000000..4f1ab6eaf --- /dev/null +++ b/internal/quic/conn_flow.go @@ -0,0 +1,141 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "sync/atomic" + "time" +) + +// connInflow tracks connection-level flow control for data sent by the peer to us. +// +// There are four byte offsets of significance in the stream of data received from the peer, +// each >= to the previous: +// +// - bytes read by the user +// - bytes received from the peer +// - limit sent to the peer in a MAX_DATA frame +// - potential new limit to sent to the peer +// +// We maintain a flow control window, so as bytes are read by the user +// the potential limit is extended correspondingly. +// +// We keep an atomic counter of bytes read by the user and not yet applied to the +// potential limit (credit). When this count grows large enough, we update the +// new limit to send and mark that we need to send a new MAX_DATA frame. +type connInflow struct { + sent sentVal // set when we need to send a MAX_DATA update to the peer + usedLimit int64 // total bytes sent by the peer, must be less than sentLimit + sentLimit int64 // last MAX_DATA sent to the peer + newLimit int64 // new MAX_DATA to send + + credit atomic.Int64 // bytes read but not yet applied to extending the flow-control window +} + +func (c *Conn) inflowInit() { + // The initial MAX_DATA limit is sent as a transport parameter. + c.streams.inflow.sentLimit = c.config.maxConnReadBufferSize() + c.streams.inflow.newLimit = c.streams.inflow.sentLimit +} + +// handleStreamBytesReadOffLoop records that the user has consumed bytes from a stream. +// We may extend the peer's flow control window. +// +// This is called indirectly by the user, via Read or CloseRead. +func (c *Conn) handleStreamBytesReadOffLoop(n int64) { + if n == 0 { + return + } + if c.shouldUpdateFlowControl(c.streams.inflow.credit.Add(n)) { + // We should send a MAX_DATA update to the peer. + // Record this on the Conn's main loop. + c.sendMsg(func(now time.Time, c *Conn) { + // A MAX_DATA update may have already happened, so check again. + if c.shouldUpdateFlowControl(c.streams.inflow.credit.Load()) { + c.sendMaxDataUpdate() + } + }) + } +} + +// handleStreamBytesReadOnLoop extends the peer's flow control window after +// data has been discarded due to a RESET_STREAM frame. +// +// This is called on the conn's loop. +func (c *Conn) handleStreamBytesReadOnLoop(n int64) { + if c.shouldUpdateFlowControl(c.streams.inflow.credit.Add(n)) { + c.sendMaxDataUpdate() + } +} + +func (c *Conn) sendMaxDataUpdate() { + c.streams.inflow.sent.setUnsent() + // Apply current credit to the limit. + // We don't strictly need to do this here + // since appendMaxDataFrame will do so as well, + // but this avoids redundant trips down this path + // if the MAX_DATA frame doesn't go out right away. + c.streams.inflow.newLimit += c.streams.inflow.credit.Swap(0) +} + +func (c *Conn) shouldUpdateFlowControl(credit int64) bool { + return shouldUpdateFlowControl(c.config.maxConnReadBufferSize(), credit) +} + +// handleStreamBytesReceived records that the peer has sent us stream data. +func (c *Conn) handleStreamBytesReceived(n int64) error { + c.streams.inflow.usedLimit += n + if c.streams.inflow.usedLimit > c.streams.inflow.sentLimit { + return localTransportError(errFlowControl) + } + return nil +} + +// appendMaxDataFrame appends a MAX_DATA frame to the current packet. +// +// It returns true if no more frames need appending, +// false if it could not fit a frame in the current packet. +func (c *Conn) appendMaxDataFrame(w *packetWriter, pnum packetNumber, pto bool) bool { + if c.streams.inflow.sent.shouldSendPTO(pto) { + // Add any unapplied credit to the new limit now. + c.streams.inflow.newLimit += c.streams.inflow.credit.Swap(0) + if !w.appendMaxDataFrame(c.streams.inflow.newLimit) { + return false + } + c.streams.inflow.sentLimit += c.streams.inflow.newLimit + c.streams.inflow.sent.setSent(pnum) + } + return true +} + +// ackOrLossMaxData records the fate of a MAX_DATA frame. +func (c *Conn) ackOrLossMaxData(pnum packetNumber, fate packetFate) { + c.streams.inflow.sent.ackLatestOrLoss(pnum, fate) +} + +// connOutflow tracks connection-level flow control for data sent by us to the peer. +type connOutflow struct { + max int64 // largest MAX_DATA received from peer + used int64 // total bytes of STREAM data sent to peer +} + +// setMaxData updates the connection-level flow control limit +// with the initial limit conveyed in transport parameters +// or an update from a MAX_DATA frame. +func (f *connOutflow) setMaxData(maxData int64) { + f.max = max(f.max, maxData) +} + +// avail returns the number of connection-level flow control bytes available. +func (f *connOutflow) avail() int64 { + return f.max - f.used +} + +// consume records consumption of n bytes of flow. +func (f *connOutflow) consume(n int64) { + f.used += n +} diff --git a/internal/quic/conn_flow_test.go b/internal/quic/conn_flow_test.go new file mode 100644 index 000000000..03e0757a6 --- /dev/null +++ b/internal/quic/conn_flow_test.go @@ -0,0 +1,430 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "testing" +) + +func TestConnInflowReturnOnRead(t *testing.T) { + ctx := canceledContext() + tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { + c.MaxConnReadBufferSize = 64 + }) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + data: make([]byte, 64), + }) + const readSize = 8 + if n, err := s.ReadContext(ctx, make([]byte, readSize)); n != readSize || err != nil { + t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, readSize) + } + tc.wantFrame("available window increases, send a MAX_DATA", + packetType1RTT, debugFrameMaxData{ + max: 64 + readSize, + }) + if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64-readSize || err != nil { + t.Fatalf("s.Read() = %v, %v; want %v, nil", n, err, 64-readSize) + } + tc.wantFrame("available window increases, send a MAX_DATA", + packetType1RTT, debugFrameMaxData{ + max: 128, + }) + // Peer can write up to the new limit. + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: 64, + data: make([]byte, 64), + }) + tc.wantIdle("connection is idle") + if n, err := s.ReadContext(ctx, make([]byte, 64)); n != 64 || err != nil { + t.Fatalf("offset 64: s.Read() = %v, %v; want %v, nil", n, err, 64) + } +} + +func TestConnInflowReturnOnRacingReads(t *testing.T) { + // Perform two reads at the same time, + // one for half of MaxConnReadBufferSize + // and one for one byte. + // + // We should observe a single MAX_DATA update. + // Depending on the ordering of events, + // this may include the credit from just the larger read + // or the credit from both. + ctx := canceledContext() + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxConnReadBufferSize = 64 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, 0), + data: make([]byte, 32), + }) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, 1), + data: make([]byte, 32), + }) + s1, err := tc.conn.AcceptStream(ctx) + if err != nil { + t.Fatalf("conn.AcceptStream() = %v", err) + } + s2, err := tc.conn.AcceptStream(ctx) + if err != nil { + t.Fatalf("conn.AcceptStream() = %v", err) + } + read1 := runAsync(tc, func(ctx context.Context) (int, error) { + return s1.ReadContext(ctx, make([]byte, 16)) + }) + read2 := runAsync(tc, func(ctx context.Context) (int, error) { + return s2.ReadContext(ctx, make([]byte, 1)) + }) + // This MAX_DATA might extend the window by 16 or 17, depending on + // whether the second write occurs before the update happens. + tc.wantFrameType("MAX_DATA update is sent", + packetType1RTT, debugFrameMaxData{}) + tc.wantIdle("redundant MAX_DATA is not sent") + if _, err := read1.result(); err != nil { + t.Errorf("ReadContext #1 = %v", err) + } + if _, err := read2.result(); err != nil { + t.Errorf("ReadContext #2 = %v", err) + } +} + +func TestConnInflowReturnOnClose(t *testing.T) { + tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { + c.MaxConnReadBufferSize = 64 + }) + tc.ignoreFrame(frameTypeStopSending) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + data: make([]byte, 64), + }) + s.CloseRead() + tc.wantFrame("closing stream updates connection-level flow control", + packetType1RTT, debugFrameMaxData{ + max: 128, + }) +} + +func TestConnInflowReturnOnReset(t *testing.T) { + tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { + c.MaxConnReadBufferSize = 64 + }) + tc.ignoreFrame(frameTypeStopSending) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + data: make([]byte, 32), + }) + tc.writeFrames(packetType1RTT, debugFrameResetStream{ + id: s.id, + finalSize: 64, + }) + s.CloseRead() + tc.wantFrame("receiving stream reseet updates connection-level flow control", + packetType1RTT, debugFrameMaxData{ + max: 128, + }) +} + +func TestConnInflowStreamViolation(t *testing.T) { + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxConnReadBufferSize = 100 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + // Total MAX_DATA consumed: 50 + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 0), + data: make([]byte, 50), + }) + // Total MAX_DATA consumed: 80 + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, 0), + off: 20, + data: make([]byte, 10), + }) + // Total MAX_DATA consumed: 100 + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 0), + off: 70, + fin: true, + }) + // This stream has already consumed quota for these bytes. + // Total MAX_DATA consumed: 100 + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, 0), + data: make([]byte, 20), + }) + tc.wantIdle("peer has consumed all MAX_DATA quota") + + // Total MAX_DATA consumed: 101 + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 2), + data: make([]byte, 1), + }) + tc.wantFrame("peer violates MAX_DATA limit", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errFlowControl, + }) +} + +func TestConnInflowResetViolation(t *testing.T) { + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxConnReadBufferSize = 100 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, bidiStream, 0), + data: make([]byte, 100), + }) + tc.wantIdle("peer has consumed all MAX_DATA quota") + + tc.writeFrames(packetType1RTT, debugFrameResetStream{ + id: newStreamID(clientSide, uniStream, 0), + finalSize: 0, + }) + tc.wantIdle("stream reset does not consume MAX_DATA quota, no error") + + tc.writeFrames(packetType1RTT, debugFrameResetStream{ + id: newStreamID(clientSide, uniStream, 1), + finalSize: 1, + }) + tc.wantFrame("RESET_STREAM final size violates MAX_DATA limit", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errFlowControl, + }) +} + +func TestConnInflowMultipleStreams(t *testing.T) { + ctx := canceledContext() + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxConnReadBufferSize = 128 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + var streams []*Stream + for _, id := range []streamID{ + newStreamID(clientSide, uniStream, 0), + newStreamID(clientSide, uniStream, 1), + newStreamID(clientSide, bidiStream, 0), + newStreamID(clientSide, bidiStream, 1), + } { + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: id, + data: make([]byte, 32), + }) + s, err := tc.conn.AcceptStream(ctx) + if err != nil { + t.Fatalf("AcceptStream() = %v", err) + } + streams = append(streams, s) + if n, err := s.ReadContext(ctx, make([]byte, 1)); err != nil || n != 1 { + t.Fatalf("s.Read() = %v, %v; want 1, nil", n, err) + } + } + tc.wantIdle("streams have read data, but not enough to update MAX_DATA") + + if n, err := streams[0].ReadContext(ctx, make([]byte, 32)); err != nil || n != 31 { + t.Fatalf("s.Read() = %v, %v; want 31, nil", n, err) + } + tc.wantFrame("read enough data to trigger a MAX_DATA update", + packetType1RTT, debugFrameMaxData{ + max: 128 + 32 + 1 + 1 + 1, + }) + + tc.ignoreFrame(frameTypeStopSending) + streams[2].CloseRead() + tc.wantFrame("closed stream triggers another MAX_DATA update", + packetType1RTT, debugFrameMaxData{ + max: 128 + 32 + 1 + 32 + 1, + }) +} + +func TestConnOutflowBlocked(t *testing.T) { + tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, + permissiveTransportParameters, + func(p *transportParameters) { + p.initialMaxData = 10 + }) + tc.ignoreFrame(frameTypeAck) + + data := makeTestData(32) + n, err := s.Write(data) + if n != len(data) || err != nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data)) + } + + tc.wantFrame("stream writes data up to MAX_DATA limit", + packetType1RTT, debugFrameStream{ + id: s.id, + data: data[:10], + }) + tc.wantIdle("stream is blocked by MAX_DATA limit") + + tc.writeFrames(packetType1RTT, debugFrameMaxData{ + max: 20, + }) + tc.wantFrame("stream writes data up to new MAX_DATA limit", + packetType1RTT, debugFrameStream{ + id: s.id, + off: 10, + data: data[10:20], + }) + tc.wantIdle("stream is blocked by new MAX_DATA limit") + + tc.writeFrames(packetType1RTT, debugFrameMaxData{ + max: 100, + }) + tc.wantFrame("stream writes remaining data", + packetType1RTT, debugFrameStream{ + id: s.id, + off: 20, + data: data[20:], + }) +} + +func TestConnOutflowMaxDataDecreases(t *testing.T) { + tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, + permissiveTransportParameters, + func(p *transportParameters) { + p.initialMaxData = 10 + }) + tc.ignoreFrame(frameTypeAck) + + // Decrease in MAX_DATA is ignored. + tc.writeFrames(packetType1RTT, debugFrameMaxData{ + max: 5, + }) + + data := makeTestData(32) + n, err := s.Write(data) + if n != len(data) || err != nil { + t.Fatalf("s.Write() = %v, %v; want %v, nil", n, err, len(data)) + } + + tc.wantFrame("stream writes data up to MAX_DATA limit", + packetType1RTT, debugFrameStream{ + id: s.id, + data: data[:10], + }) +} + +func TestConnOutflowMaxDataRoundRobin(t *testing.T) { + ctx := canceledContext() + tc := newTestConn(t, clientSide, permissiveTransportParameters, + func(p *transportParameters) { + p.initialMaxData = 0 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + s1, err := tc.conn.newLocalStream(ctx, uniStream) + if err != nil { + t.Fatalf("conn.newLocalStream(%v) = %v", uniStream, err) + } + s2, err := tc.conn.newLocalStream(ctx, uniStream) + if err != nil { + t.Fatalf("conn.newLocalStream(%v) = %v", uniStream, err) + } + + s1.Write(make([]byte, 10)) + s2.Write(make([]byte, 10)) + + tc.writeFrames(packetType1RTT, debugFrameMaxData{ + max: 1, + }) + tc.wantFrame("stream 1 writes data up to MAX_DATA limit", + packetType1RTT, debugFrameStream{ + id: s1.id, + data: []byte{0}, + }) + + tc.writeFrames(packetType1RTT, debugFrameMaxData{ + max: 2, + }) + tc.wantFrame("stream 2 writes data up to MAX_DATA limit", + packetType1RTT, debugFrameStream{ + id: s2.id, + data: []byte{0}, + }) + + tc.writeFrames(packetType1RTT, debugFrameMaxData{ + max: 3, + }) + tc.wantFrame("stream 1 writes data up to MAX_DATA limit", + packetType1RTT, debugFrameStream{ + id: s1.id, + off: 1, + data: []byte{0}, + }) +} + +func TestConnOutflowMetaAndData(t *testing.T) { + tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream, + permissiveTransportParameters, + func(p *transportParameters) { + p.initialMaxData = 0 + }) + tc.ignoreFrame(frameTypeAck) + + data := makeTestData(32) + s.Write(data) + + s.CloseRead() + tc.wantFrame("CloseRead sends a STOP_SENDING, not flow controlled", + packetType1RTT, debugFrameStopSending{ + id: s.id, + }) + + tc.writeFrames(packetType1RTT, debugFrameMaxData{ + max: 100, + }) + tc.wantFrame("unblocked MAX_DATA", + packetType1RTT, debugFrameStream{ + id: s.id, + data: data, + }) +} + +func TestConnOutflowResentData(t *testing.T) { + tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream, + permissiveTransportParameters, + func(p *transportParameters) { + p.initialMaxData = 10 + }) + tc.ignoreFrame(frameTypeAck) + + data := makeTestData(15) + s.Write(data[:8]) + tc.wantFrame("data is under MAX_DATA limit, all sent", + packetType1RTT, debugFrameStream{ + id: s.id, + data: data[:8], + }) + + // Lose the last STREAM packet. + const pto = false + tc.triggerLossOrPTO(packetType1RTT, false) + tc.wantFrame("lost STREAM data is retransmitted", + packetType1RTT, debugFrameStream{ + id: s.id, + data: data[:8], + }) + + s.Write(data[8:]) + tc.wantFrame("new data is sent up to the MAX_DATA limit", + packetType1RTT, debugFrameStream{ + id: s.id, + off: 8, + data: data[8:10], + }) +} diff --git a/internal/quic/conn_id.go b/internal/quic/conn_id.go index 561dea2c1..045e646ac 100644 --- a/internal/quic/conn_id.go +++ b/internal/quic/conn_id.go @@ -55,10 +55,10 @@ type connID struct { send sentVal } -func (s *connIDState) initClient(newID newConnIDFunc) error { +func (s *connIDState) initClient(c *Conn) error { // Client chooses its initial connection ID, and sends it // in the Source Connection ID field of the first Initial packet. - locid, err := newID(0) + locid, err := c.newConnID(0) if err != nil { return err } @@ -70,7 +70,7 @@ func (s *connIDState) initClient(newID newConnIDFunc) error { // Client chooses an initial, transient connection ID for the server, // and sends it in the Destination Connection ID field of the first Initial packet. - remid, err := newID(-1) + remid, err := c.newConnID(-1) if err != nil { return err } @@ -78,10 +78,12 @@ func (s *connIDState) initClient(newID newConnIDFunc) error { seq: -1, cid: remid, }) + const retired = false + c.listener.connIDsChanged(c, retired, s.local[:]) return nil } -func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error { +func (s *connIDState) initServer(c *Conn, dstConnID []byte) error { // Client-chosen, transient connection ID received in the first Initial packet. // The server will not use this as the Source Connection ID of packets it sends, // but remembers it because it may receive packets sent to this destination. @@ -92,7 +94,7 @@ func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error { // Server chooses a connection ID, and sends it in the Source Connection ID of // the response to the clent. - locid, err := newID(0) + locid, err := c.newConnID(0) if err != nil { return err } @@ -101,6 +103,8 @@ func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error { cid: locid, }) s.nextLocalSeq = 1 + const retired = false + c.listener.connIDsChanged(c, retired, s.local[:]) return nil } @@ -125,20 +129,21 @@ func (s *connIDState) dstConnID() (cid []byte, ok bool) { // setPeerActiveConnIDLimit sets the active_connection_id_limit // transport parameter received from the peer. -func (s *connIDState) setPeerActiveConnIDLimit(lim int64, newID newConnIDFunc) error { +func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error { s.peerActiveConnIDLimit = lim - return s.issueLocalIDs(newID) + return s.issueLocalIDs(c) } -func (s *connIDState) issueLocalIDs(newID newConnIDFunc) error { +func (s *connIDState) issueLocalIDs(c *Conn) error { toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit) for i := range s.local { if s.local[i].seq != -1 && !s.local[i].retired { toIssue-- } } + prev := len(s.local) for toIssue > 0 { - cid, err := newID(s.nextLocalSeq) + cid, err := c.newConnID(s.nextLocalSeq) if err != nil { return err } @@ -151,24 +156,62 @@ func (s *connIDState) issueLocalIDs(newID newConnIDFunc) error { s.needSend = true toIssue-- } + const retired = false + c.listener.connIDsChanged(c, retired, s.local[prev:]) + return nil +} + +// validateTransportParameters verifies the original_destination_connection_id and +// initial_source_connection_id transport parameters match the expected values. +func (s *connIDState) validateTransportParameters(side connSide, p transportParameters) error { + // TODO: Consider returning more detailed errors, for debugging. + switch side { + case clientSide: + // Verify original_destination_connection_id matches + // the transient remote connection ID we chose. + if len(s.remote) == 0 || s.remote[0].seq != -1 { + return localTransportError(errInternal) + } + if !bytes.Equal(s.remote[0].cid, p.originalDstConnID) { + return localTransportError(errTransportParameter) + } + // Remove the transient remote connection ID. + // We have no further need for it. + s.remote = append(s.remote[:0], s.remote[1:]...) + case serverSide: + if p.originalDstConnID != nil { + // Clients do not send original_destination_connection_id. + return localTransportError(errTransportParameter) + } + } + // Verify initial_source_connection_id matches the first remote connection ID. + if len(s.remote) == 0 || s.remote[0].seq != 0 { + return localTransportError(errInternal) + } + if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) { + return localTransportError(errTransportParameter) + } return nil } // handlePacket updates the connection ID state during the handshake // (Initial and Handshake packets). -func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID []byte) { +func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) { switch { - case ptype == packetTypeInitial && side == clientSide: + case ptype == packetTypeInitial && c.side == clientSide: if len(s.remote) == 1 && s.remote[0].seq == -1 { // We're a client connection processing the first Initial packet // from the server. Replace the transient remote connection ID // with the Source Connection ID from the packet. - s.remote[0] = connID{ + // Leave the transient ID the list for now, since we'll need it when + // processing the transport parameters. + s.remote[0].retired = true + s.remote = append(s.remote, connID{ seq: 0, cid: cloneBytes(srcConnID), - } + }) } - case ptype == packetTypeInitial && side == serverSide: + case ptype == packetTypeInitial && c.side == serverSide: if len(s.remote) == 0 { // We're a server connection processing the first Initial packet // from the client. Set the client's connection ID. @@ -177,11 +220,13 @@ func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID [] cid: cloneBytes(srcConnID), }) } - case ptype == packetTypeHandshake && side == serverSide: - if len(s.local) > 0 && s.local[0].seq == -1 { + case ptype == packetTypeHandshake && c.side == serverSide: + if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired { // We're a server connection processing the first Handshake packet from // the client. Discard the transient, client-chosen connection ID used // for Initial packets; the client will never send it again. + const retired = true + c.listener.connIDsChanged(c, retired, s.local[0:1]) s.local = append(s.local[:0], s.local[1:]...) } } @@ -204,7 +249,7 @@ func (s *connIDState) handleNewConnID(seq, retire int64, cid []byte, resetToken active := 0 for i := range s.remote { rcid := &s.remote[i] - if !rcid.retired && rcid.seq < s.retireRemotePriorTo { + if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo { s.retireRemote(rcid) } if !rcid.retired { @@ -263,17 +308,19 @@ func (s *connIDState) retireRemote(rcid *connID) { s.needSend = true } -func (s *connIDState) handleRetireConnID(seq int64, newID newConnIDFunc) error { +func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error { if seq >= s.nextLocalSeq { return localTransportError(errProtocolViolation) } for i := range s.local { if s.local[i].seq == seq { + const retired = true + c.listener.connIDsChanged(c, retired, s.local[i:i+1]) s.local = append(s.local[:i], s.local[i+1:]...) break } } - s.issueLocalIDs(newID) + s.issueLocalIDs(c) return nil } @@ -355,7 +402,12 @@ func cloneBytes(b []byte) []byte { return n } -type newConnIDFunc func(seq int64) ([]byte, error) +func (c *Conn) newConnID(seq int64) ([]byte, error) { + if c.testHooks != nil { + return c.testHooks.newConnID(seq) + } + return newRandomConnID(seq) +} func newRandomConnID(_ int64) ([]byte, error) { // It is not necessary for connection IDs to be cryptographically secure, diff --git a/internal/quic/conn_id_test.go b/internal/quic/conn_id_test.go index d479cd4a8..44755ecf4 100644 --- a/internal/quic/conn_id_test.go +++ b/internal/quic/conn_id_test.go @@ -11,100 +11,138 @@ import ( "crypto/tls" "fmt" "net/netip" - "reflect" + "strings" "testing" ) func TestConnIDClientHandshake(t *testing.T) { + tc := newTestConn(t, clientSide) // On initialization, the client chooses local and remote IDs. // // The order in which we allocate the two isn't actually important, // but test is a lot simpler if we assume. - var s connIDState - s.initClient(newConnIDSequence()) - if got, want := string(s.srcConnID()), "local-1"; got != want { - t.Errorf("after initClient: srcConnID = %q, want %q", got, want) + if got, want := tc.conn.connIDState.srcConnID(), testLocalConnID(0); !bytes.Equal(got, want) { + t.Errorf("after initialization: srcConnID = %x, want %x", got, want) } - dstConnID, _ := s.dstConnID() - if got, want := string(dstConnID), "local-2"; got != want { - t.Errorf("after initClient: dstConnID = %q, want %q", got, want) + dstConnID, _ := tc.conn.connIDState.dstConnID() + if got, want := dstConnID, testLocalConnID(-1); !bytes.Equal(got, want) { + t.Errorf("after initialization: dstConnID = %x, want %x", got, want) } // The server's first Initial packet provides the client with a // non-transient remote connection ID. - s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1")) - dstConnID, _ = s.dstConnID() - if got, want := string(dstConnID), "remote-1"; got != want { - t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + dstConnID, _ = tc.conn.connIDState.dstConnID() + if got, want := dstConnID, testPeerConnID(0); !bytes.Equal(got, want) { + t.Errorf("after receiving Initial: dstConnID = %x, want %x", got, want) } wantLocal := []connID{{ - cid: []byte("local-1"), + cid: testLocalConnID(0), seq: 0, }} - if !reflect.DeepEqual(s.local, wantLocal) { - t.Errorf("local ids: %v, want %v", s.local, wantLocal) + if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) { + t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal)) } wantRemote := []connID{{ - cid: []byte("remote-1"), + cid: testLocalConnID(-1), + seq: -1, + }, { + cid: testPeerConnID(0), seq: 0, }} - if !reflect.DeepEqual(s.remote, wantRemote) { - t.Errorf("remote ids: %v, want %v", s.remote, wantRemote) + if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) { + t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote)) } } func TestConnIDServerHandshake(t *testing.T) { + tc := newTestConn(t, serverSide) // On initialization, the server is provided with the client-chosen // transient connection ID, and allocates an ID of its own. // The Initial packet sets the remote connection ID. - var s connIDState - s.initServer(newConnIDSequence(), []byte("transient")) - s.handlePacket(serverSide, packetTypeInitial, []byte("remote-1")) - if got, want := string(s.srcConnID()), "local-1"; got != want { + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial][:1], + }) + if got, want := tc.conn.connIDState.srcConnID(), testLocalConnID(0); !bytes.Equal(got, want) { t.Errorf("after initClient: srcConnID = %q, want %q", got, want) } - dstConnID, _ := s.dstConnID() - if got, want := string(dstConnID), "remote-1"; got != want { + dstConnID, _ := tc.conn.connIDState.dstConnID() + if got, want := dstConnID, testPeerConnID(0); !bytes.Equal(got, want) { t.Errorf("after initClient: dstConnID = %q, want %q", got, want) } + // The Initial flight of CRYPTO data includes transport parameters, + // which cause us to allocate another local connection ID. + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + off: 1, + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial][1:], + }) wantLocal := []connID{{ - cid: []byte("transient"), + cid: testPeerConnID(-1), seq: -1, }, { - cid: []byte("local-1"), + cid: testLocalConnID(0), seq: 0, + }, { + cid: testLocalConnID(1), + seq: 1, }} - if !reflect.DeepEqual(s.local, wantLocal) { - t.Errorf("local ids: %v, want %v", s.local, wantLocal) + if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) { + t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal)) } wantRemote := []connID{{ - cid: []byte("remote-1"), + cid: testPeerConnID(0), seq: 0, }} - if !reflect.DeepEqual(s.remote, wantRemote) { - t.Errorf("remote ids: %v, want %v", s.remote, wantRemote) + if got := tc.conn.connIDState.remote; !connIDListEqual(got, wantRemote) { + t.Errorf("remote ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantRemote)) } // The client's first Handshake packet permits the server to discard the // transient connection ID. - s.handlePacket(serverSide, packetTypeHandshake, []byte("remote-1")) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) wantLocal = []connID{{ - cid: []byte("local-1"), + cid: testLocalConnID(0), seq: 0, + }, { + cid: testLocalConnID(1), + seq: 1, }} - if !reflect.DeepEqual(s.local, wantLocal) { - t.Errorf("after handshake local ids: %v, want %v", s.local, wantLocal) + if got := tc.conn.connIDState.local; !connIDListEqual(got, wantLocal) { + t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal)) + } +} + +func connIDListEqual(a, b []connID) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].seq != b[i].seq { + return false + } + if !bytes.Equal(a[i].cid, b[i].cid) { + return false + } } + return true } -func newConnIDSequence() newConnIDFunc { - var n uint64 - return func(_ int64) ([]byte, error) { - n++ - return []byte(fmt.Sprintf("local-%v", n)), nil +func fmtConnIDList(s []connID) string { + var strs []string + for _, cid := range s { + strs = append(strs, fmt.Sprintf("[seq:%v cid:{%x}]", cid.seq, cid.cid)) } + return "{" + strings.Join(strs, " ") + "}" } func TestNewRandomConnID(t *testing.T) { @@ -226,10 +264,12 @@ func TestConnIDPeerRetiresConnID(t *testing.T) { } func TestConnIDPeerWithZeroLengthConnIDSendsNewConnectionID(t *testing.T) { - // An endpoint that selects a zero-length connection ID during the handshake + // "An endpoint that selects a zero-length connection ID during the handshake // cannot issue a new connection ID." // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-8 - tc := newTestConn(t, clientSide) + tc := newTestConn(t, clientSide, func(p *transportParameters) { + p.initialSrcConnID = []byte{} + }) tc.peerConnID = []byte{} tc.ignoreFrame(frameTypeAck) tc.uncheckedHandshake() @@ -501,6 +541,7 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) { // Peer gives us more conn ids than our advertised limit, // including a conn id in the preferred address transport parameter. tc := newTestConn(t, serverSide, func(p *transportParameters) { + p.initialSrcConnID = []byte{} p.preferredAddrV4 = netip.MustParseAddrPort("0.0.0.0:0") p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0") p.preferredAddrConnID = testPeerConnID(1) @@ -517,3 +558,31 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) { code: errProtocolViolation, }) } + +func TestConnIDInitialSrcConnIDMismatch(t *testing.T) { + // "Endpoints MUST validate that received [initial_source_connection_id] + // parameters match received connection ID values." + // https://www.rfc-editor.org/rfc/rfc9000#section-7.3-3 + testSides(t, "", func(t *testing.T, side connSide) { + tc := newTestConn(t, side, func(p *transportParameters) { + p.initialSrcConnID = []byte("invalid") + }) + tc.ignoreFrame(frameTypeAck) + tc.ignoreFrame(frameTypeCrypto) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + if side == clientSide { + // Server transport parameters are carried in the Handshake packet. + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + } + tc.wantFrame("initial_source_connection_id transport parameter mismatch", + packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errTransportParameter, + }) + }) +} diff --git a/internal/quic/conn_loss.go b/internal/quic/conn_loss.go index 103db9fa4..85bda314e 100644 --- a/internal/quic/conn_loss.go +++ b/internal/quic/conn_loss.go @@ -44,6 +44,8 @@ func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetF case frameTypeCrypto: start, end := sent.nextRange() c.crypto[space].ackOrLoss(start, end, fate) + case frameTypeMaxData: + c.ackOrLossMaxData(sent.num, fate) case frameTypeResetStream, frameTypeStopSending, frameTypeMaxStreamData, @@ -64,6 +66,10 @@ func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetF } fin := f&streamFinBit != 0 s.ackOrLossData(sent.num, start, end, fin, fate) + case frameTypeMaxStreamsBidi: + c.streams.remoteLimit[bidiStream].sendMax.ackLatestOrLoss(sent.num, fate) + case frameTypeMaxStreamsUni: + c.streams.remoteLimit[uniStream].sendMax.ackLatestOrLoss(sent.num, fate) case frameTypeNewConnectionID: seq := int64(sent.nextInt()) c.connIDState.ackOrLossNewConnectionID(sent.num, seq, fate) diff --git a/internal/quic/conn_loss_test.go b/internal/quic/conn_loss_test.go index bb4303033..9b8846251 100644 --- a/internal/quic/conn_loss_test.go +++ b/internal/quic/conn_loss_test.go @@ -174,9 +174,7 @@ func TestLostStreamFrameEmpty(t *testing.T) { // be retransmitted if lost. lostFrameTest(t, func(t *testing.T, pto bool) { ctx := canceledContext() - tc := newTestConn(t, clientSide, func(p *transportParameters) { - p.initialMaxStreamDataBidiRemote = 100 - }) + tc := newTestConn(t, clientSide, permissiveTransportParameters) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -291,18 +289,58 @@ func TestLostStreamPartialLoss(t *testing.T) { tc.wantIdle("no more frames sent after packet loss") } +func TestLostMaxDataFrame(t *testing.T) { + // "An updated value is sent in a MAX_DATA frame if the packet + // containing the most recently sent MAX_DATA frame is declared lost [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.7 + lostFrameTest(t, func(t *testing.T, pto bool) { + const maxWindowSize = 32 + buf := make([]byte, maxWindowSize) + tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { + c.MaxConnReadBufferSize = 32 + }) + + // We send MAX_DATA = 63. + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + off: 0, + data: make([]byte, maxWindowSize), + }) + if n, err := s.Read(buf[:maxWindowSize-1]); err != nil || n != maxWindowSize-1 { + t.Fatalf("Read() = %v, %v; want %v, nil", n, err, maxWindowSize-1) + } + tc.wantFrame("conn window is extended after reading data", + packetType1RTT, debugFrameMaxData{ + max: (maxWindowSize * 2) - 1, + }) + + // MAX_DATA = 64, which is only one more byte, so we don't send the frame. + if n, err := s.Read(buf); err != nil || n != 1 { + t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1) + } + tc.wantIdle("read doesn't extend window enough to send another MAX_DATA") + + // The MAX_DATA = 63 packet was lost, so we send 64. + tc.triggerLossOrPTO(packetType1RTT, pto) + tc.wantFrame("resent MAX_DATA includes most current value", + packetType1RTT, debugFrameMaxData{ + max: maxWindowSize * 2, + }) + }) +} + func TestLostMaxStreamDataFrame(t *testing.T) { // "[...] an updated value is sent when the packet containing // the most recent MAX_STREAM_DATA frame for a stream is lost" // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.8 lostFrameTest(t, func(t *testing.T, pto bool) { - const maxWindowSize = 10 + const maxWindowSize = 32 buf := make([]byte, maxWindowSize) tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { - c.StreamReadBufferSize = maxWindowSize + c.MaxStreamReadBufferSize = maxWindowSize }) - // We send MAX_STREAM_DATA = 19. + // We send MAX_STREAM_DATA = 63. tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, off: 0, @@ -317,13 +355,13 @@ func TestLostMaxStreamDataFrame(t *testing.T) { max: (maxWindowSize * 2) - 1, }) - // MAX_STREAM_DATA = 20, which is only one more byte, so we don't send the frame. + // MAX_STREAM_DATA = 64, which is only one more byte, so we don't send the frame. if n, err := s.Read(buf); err != nil || n != 1 { t.Fatalf("Read() = %v, %v; want %v, nil", n, err, 1) } tc.wantIdle("read doesn't extend window enough to send another MAX_STREAM_DATA") - // The MAX_STREAM_DATA = 19 packet was lost, so we send 20. + // The MAX_STREAM_DATA = 63 packet was lost, so we send 64. tc.triggerLossOrPTO(packetType1RTT, pto) tc.wantFrame("resent MAX_STREAM_DATA includes most current value", packetType1RTT, debugFrameMaxStreamData{ @@ -341,7 +379,7 @@ func TestLostMaxStreamDataFrameAfterStreamFinReceived(t *testing.T) { const maxWindowSize = 10 buf := make([]byte, maxWindowSize) tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { - c.StreamReadBufferSize = maxWindowSize + c.MaxStreamReadBufferSize = maxWindowSize }) tc.writeFrames(packetType1RTT, debugFrameStream{ @@ -370,6 +408,95 @@ func TestLostMaxStreamDataFrameAfterStreamFinReceived(t *testing.T) { }) } +func TestLostMaxStreamsFrameMostRecent(t *testing.T) { + // "[...] an updated value is sent when a packet containing the + // most recent MAX_STREAMS for a stream type frame is declared lost [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.9 + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + lostFrameTest(t, func(t *testing.T, pto bool) { + ctx := canceledContext() + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxUniRemoteStreams = 1 + c.MaxBidiRemoteStreams = 1 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, styp, 0), + fin: true, + }) + s, err := tc.conn.AcceptStream(ctx) + if err != nil { + t.Fatalf("AcceptStream() = %v", err) + } + s.CloseContext(ctx) + if styp == bidiStream { + tc.wantFrame("stream is closed", + packetType1RTT, debugFrameStream{ + id: s.id, + data: []byte{}, + fin: true, + }) + tc.writeAckForAll() + } + tc.wantFrame("closing stream updates peer's MAX_STREAMS", + packetType1RTT, debugFrameMaxStreams{ + streamType: styp, + max: 2, + }) + + tc.triggerLossOrPTO(packetType1RTT, pto) + tc.wantFrame("lost MAX_STREAMS is resent", + packetType1RTT, debugFrameMaxStreams{ + streamType: styp, + max: 2, + }) + }) + }) +} + +func TestLostMaxStreamsFrameNotMostRecent(t *testing.T) { + // Send two MAX_STREAMS frames, lose the first one. + // + // No PTO mode for this test: The ack that causes the first frame + // to be lost arms the loss timer for the second, so the PTO timer is not armed. + const pto = false + ctx := canceledContext() + tc := newTestConn(t, serverSide, func(c *Config) { + c.MaxUniRemoteStreams = 2 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + for i := int64(0); i < 2; i++ { + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, i), + fin: true, + }) + s, err := tc.conn.AcceptStream(ctx) + if err != nil { + t.Fatalf("AcceptStream() = %v", err) + } + if err := s.CloseContext(ctx); err != nil { + t.Fatalf("stream.Close() = %v", err) + } + tc.wantFrame("closing stream updates peer's MAX_STREAMS", + packetType1RTT, debugFrameMaxStreams{ + streamType: uniStream, + max: 3 + i, + }) + } + + // The second MAX_STREAMS frame is acked. + tc.writeAckForLatest() + + // The first MAX_STREAMS frame is lost. + tc.conn.ping(appDataSpace) + tc.wantFrame("connection should send a PING frame", + packetType1RTT, debugFramePing{}) + tc.triggerLossOrPTO(packetType1RTT, pto) + tc.wantIdle("superseded MAX_DATA is not resent on loss") +} + func TestLostStreamDataBlockedFrame(t *testing.T) { // "A new [STREAM_DATA_BLOCKED] frame is sent if a packet containing // the most recent frame for a scope is lost [...]" diff --git a/internal/quic/conn_recv.go b/internal/quic/conn_recv.go index e0a91ab00..9b1ba1ae1 100644 --- a/internal/quic/conn_recv.go +++ b/internal/quic/conn_recv.go @@ -7,12 +7,18 @@ package quic import ( + "bytes" + "encoding/binary" + "errors" "time" ) func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { buf := dgram.b c.loss.datagramReceived(now, len(buf)) + if c.isDraining() { + return + } for len(buf) > 0 { var n int ptype := getPacketType(buf) @@ -23,11 +29,14 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { // https://www.rfc-editor.org/rfc/rfc9000#section-14.1-4 return } - n = c.handleLongHeader(now, ptype, initialSpace, buf) + n = c.handleLongHeader(now, ptype, initialSpace, c.keysInitial.r, buf) case packetTypeHandshake: - n = c.handleLongHeader(now, ptype, handshakeSpace, buf) + n = c.handleLongHeader(now, ptype, handshakeSpace, c.keysHandshake.r, buf) case packetType1RTT: n = c.handle1RTT(now, buf) + case packetTypeVersionNegotiation: + c.handleVersionNegotiation(now, buf) + return default: return } @@ -40,21 +49,27 @@ func (c *Conn) handleDatagram(now time.Time, dgram *datagram) { } } -func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, buf []byte) int { - if !c.rkeys[space].isSet() { +func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpace, k fixedKeys, buf []byte) int { + if !k.isSet() { return skipLongHeaderPacket(buf) } pnumMax := c.acks[space].largestSeen() - p, n := parseLongHeaderPacket(buf, c.rkeys[space], pnumMax) + p, n := parseLongHeaderPacket(buf, k, pnumMax) if n < 0 { return -1 } - if p.reservedBits != 0 { + if buf[0]&reservedLongBits != 0 { + // Reserved header bits must be 0. // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1 c.abort(now, localTransportError(errProtocolViolation)) return -1 } + if p.version != quicVersion1 { + // The peer has changed versions on us mid-handshake? + c.abort(now, localTransportError(errProtocolViolation)) + return -1 + } if !c.acks[space].shouldProcess(p.num) { return n @@ -63,7 +78,7 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa if logPackets { logInboundLongPacket(c, p) } - c.connIDState.handlePacket(c.side, p.ptype, p.srcConnID) + c.connIDState.handlePacket(c, p.ptype, p.srcConnID) ackEliciting := c.handleFrames(now, ptype, space, p.payload) c.acks[space].receive(now, space, p.num, ackEliciting) if p.ptype == packetTypeHandshake && c.side == serverSide { @@ -78,18 +93,24 @@ func (c *Conn) handleLongHeader(now time.Time, ptype packetType, space numberSpa } func (c *Conn) handle1RTT(now time.Time, buf []byte) int { - if !c.rkeys[appDataSpace].isSet() { + if !c.keysAppData.canRead() { // 1-RTT packets extend to the end of the datagram, // so skip the remainder of the datagram if we can't parse this. return len(buf) } pnumMax := c.acks[appDataSpace].largestSeen() - p, n := parse1RTTPacket(buf, c.rkeys[appDataSpace], connIDLen, pnumMax) - if n < 0 { + p, err := parse1RTTPacket(buf, &c.keysAppData, connIDLen, pnumMax) + if err != nil { + // A localTransportError terminates the connection. + // Other errors indicate an unparseable packet, but otherwise may be ignored. + if _, ok := err.(localTransportError); ok { + c.abort(now, err) + } return -1 } - if p.reservedBits != 0 { + if buf[0]&reserved1RTTBits != 0 { + // Reserved header bits must be 0. // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1 c.abort(now, localTransportError(errProtocolViolation)) return -1 @@ -107,6 +128,42 @@ func (c *Conn) handle1RTT(now time.Time, buf []byte) int { return len(buf) } +var errVersionNegotiation = errors.New("server does not support QUIC version 1") + +func (c *Conn) handleVersionNegotiation(now time.Time, pkt []byte) { + if c.side != clientSide { + return // servers don't handle Version Negotiation packets + } + // "A client MUST discard any Version Negotiation packet if it has + // received and successfully processed any other packet [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2 + if !c.keysInitial.canRead() { + return // discarded Initial keys, connection is already established + } + if c.acks[initialSpace].seen.numRanges() != 0 { + return // processed at least one packet + } + _, srcConnID, versions := parseVersionNegotiation(pkt) + if len(c.connIDState.remote) < 1 || !bytes.Equal(c.connIDState.remote[0].cid, srcConnID) { + return // Source Connection ID doesn't match what we sent + } + for len(versions) >= 4 { + ver := binary.BigEndian.Uint32(versions) + if ver == 1 { + // "A client MUST discard a Version Negotiation packet that lists + // the QUIC version selected by the client." + // https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2 + return + } + versions = versions[4:] + } + // "A client that supports only this version of QUIC MUST + // abandon the current connection attempt if it receives + // a Version Negotiation packet, [with the two exceptions handled above]." + // https://www.rfc-editor.org/rfc/rfc9000#section-6.2-2 + c.abortImmediately(now, errVersionNegotiation) +} + func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, payload []byte) (ackEliciting bool) { if len(payload) == 0 { // "An endpoint MUST treat receipt of a packet containing no frames @@ -186,7 +243,7 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, if !frameOK(c, ptype, __01) { return } - _, n = consumeMaxDataFrame(payload) + n = c.handleMaxDataFrame(now, payload) case frameTypeMaxStreamData: if !frameOK(c, ptype, __01) { return @@ -196,7 +253,12 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, if !frameOK(c, ptype, __01) { return } - _, _, n = consumeMaxStreamsFrame(payload) + n = c.handleMaxStreamsFrame(now, payload) + case frameTypeDataBlocked: + if !frameOK(c, ptype, __01) { + return + } + _, n = consumeDataBlockedFrame(payload) case frameTypeStreamsBlockedBidi, frameTypeStreamsBlockedUni: if !frameOK(c, ptype, __01) { return @@ -218,15 +280,13 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, } n = c.handleRetireConnectionIDFrame(now, space, payload) case frameTypeConnectionCloseTransport: - // CONNECTION_CLOSE is OK in all spaces. - _, _, _, n = consumeConnectionCloseTransportFrame(payload) - // TODO: https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 - c.abort(now, localTransportError(errNo)) + // Transport CONNECTION_CLOSE is OK in all spaces. + n = c.handleConnectionCloseTransportFrame(now, payload) case frameTypeConnectionCloseApplication: - // CONNECTION_CLOSE is OK in all spaces. - _, _, n = consumeConnectionCloseApplicationFrame(payload) - // TODO: https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2.2 - c.abort(now, localTransportError(errNo)) + if !frameOK(c, ptype, __01) { + return + } + n = c.handleConnectionCloseApplicationFrame(now, payload) case frameTypeHandshakeDone: if !frameOK(c, ptype, ___1) { return @@ -244,7 +304,7 @@ func (c *Conn) handleFrames(now time.Time, ptype packetType, space numberSpace, func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) int { c.loss.receiveAckStart() - _, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) { + largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) { if end > c.loss.nextNumber(space) { // Acknowledgement of a packet we never sent. c.abort(now, localTransportError(errProtocolViolation)) @@ -277,11 +337,26 @@ func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) delay = ackDelay.Duration(uint8(c.peerAckDelayExponent)) } c.loss.receiveAckEnd(now, space, delay, c.handleAckOrLoss) + if space == appDataSpace { + c.keysAppData.handleAckFor(largest) + } + return n +} + +func (c *Conn) handleMaxDataFrame(now time.Time, payload []byte) int { + maxData, n := consumeMaxDataFrame(payload) + if n < 0 { + return -1 + } + c.streams.outflow.setMaxData(maxData) return n } func (c *Conn) handleMaxStreamDataFrame(now time.Time, payload []byte) int { id, maxStreamData, n := consumeMaxStreamDataFrame(payload) + if n < 0 { + return -1 + } if s := c.streamForFrame(now, id, sendStream); s != nil { if err := s.handleMaxStreamData(maxStreamData); err != nil { c.abort(now, err) @@ -291,6 +366,15 @@ func (c *Conn) handleMaxStreamDataFrame(now time.Time, payload []byte) int { return n } +func (c *Conn) handleMaxStreamsFrame(now time.Time, payload []byte) int { + styp, max, n := consumeMaxStreamsFrame(payload) + if n < 0 { + return -1 + } + c.streams.localLimit[styp].setMax(max) + return n +} + func (c *Conn) handleResetStreamFrame(now time.Time, space numberSpace, payload []byte) int { id, code, finalSize, n := consumeResetStreamFrame(payload) if n < 0 { @@ -356,12 +440,30 @@ func (c *Conn) handleRetireConnectionIDFrame(now time.Time, space numberSpace, p if n < 0 { return -1 } - if err := c.connIDState.handleRetireConnID(seq, c.newConnIDFunc()); err != nil { + if err := c.connIDState.handleRetireConnID(c, seq); err != nil { c.abort(now, err) } return n } +func (c *Conn) handleConnectionCloseTransportFrame(now time.Time, payload []byte) int { + code, _, reason, n := consumeConnectionCloseTransportFrame(payload) + if n < 0 { + return -1 + } + c.enterDraining(peerTransportError{code: code, reason: reason}) + return n +} + +func (c *Conn) handleConnectionCloseApplicationFrame(now time.Time, payload []byte) int { + code, reason, n := consumeConnectionCloseApplicationFrame(payload) + if n < 0 { + return -1 + } + c.enterDraining(&ApplicationError{Code: code, Reason: reason}) + return n +} + func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payload []byte) int { if c.side == serverSide { // Clients should never send HANDSHAKE_DONE. @@ -369,6 +471,8 @@ func (c *Conn) handleHandshakeDoneFrame(now time.Time, space numberSpace, payloa c.abort(now, localTransportError(errProtocolViolation)) return -1 } - c.confirmHandshake(now) + if !c.isClosingOrDraining() { + c.confirmHandshake(now) + } return 1 } diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index 9d315fb39..00b02c2a3 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -16,6 +16,8 @@ import ( // // If sending is blocked by pacing, it returns the next time // a datagram may be sent. +// +// If sending is blocked indefinitely, it returns the zero Time. func (c *Conn) maybeSend(now time.Time) (next time.Time) { // Assumption: The congestion window is not underutilized. // If congestion control, pacing, and anti-amplification all permit sending, @@ -39,6 +41,9 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // If anti-amplification blocks sending, then no packet can be sent. return next } + if !c.sendOK(now) { + return time.Time{} + } // We may still send ACKs, even if congestion control or pacing limit sending. // Prepare to write a datagram of at most maxSendSize bytes. @@ -54,12 +59,12 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // Initial packet. pad := false var sentInitial *sentPacket - if k := c.wkeys[initialSpace]; k.isSet() { + if c.keysInitial.canWrite() { pnumMaxAcked := c.acks[initialSpace].largestSeen() pnum := c.loss.nextNumber(initialSpace) p := longPacket{ ptype: packetTypeInitial, - version: 1, + version: quicVersion1, num: pnum, dstConnID: dstConnID, srcConnID: c.connIDState.srcConnID(), @@ -69,7 +74,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if logPackets { logSentPacket(c, packetTypeInitial, pnum, p.srcConnID, p.dstConnID, c.w.payload()) } - sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p) + sentInitial = c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysInitial.w, p) if sentInitial != nil { // Client initial packets need to be sent in a datagram padded to // at least 1200 bytes. We can't add the padding yet, however, @@ -81,12 +86,12 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } // Handshake packet. - if k := c.wkeys[handshakeSpace]; k.isSet() { + if c.keysHandshake.canWrite() { pnumMaxAcked := c.acks[handshakeSpace].largestSeen() pnum := c.loss.nextNumber(handshakeSpace) p := longPacket{ ptype: packetTypeHandshake, - version: 1, + version: quicVersion1, num: pnum, dstConnID: dstConnID, srcConnID: c.connIDState.srcConnID(), @@ -96,7 +101,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if logPackets { logSentPacket(c, packetTypeHandshake, pnum, p.srcConnID, p.dstConnID, c.w.payload()) } - if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, p); sent != nil { + if sent := c.w.finishProtectedLongHeaderPacket(pnumMaxAcked, c.keysHandshake.w, p); sent != nil { c.loss.packetSent(now, handshakeSpace, sent) if c.side == clientSide { // "[...] a client MUST discard Initial keys when it first @@ -108,7 +113,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } // 1-RTT packet. - if k := c.wkeys[appDataSpace]; k.isSet() { + if c.keysAppData.canWrite() { pnumMaxAcked := c.acks[appDataSpace].largestSeen() pnum := c.loss.nextNumber(appDataSpace) c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID) @@ -123,7 +128,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { if logPackets { logSentPacket(c, packetType1RTT, pnum, nil, dstConnID, c.w.payload()) } - if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, k); sent != nil { + if sent := c.w.finish1RTTPacket(pnum, pnumMaxAcked, dstConnID, &c.keysAppData); sent != nil { c.loss.packetSent(now, appDataSpace, sent) } } @@ -152,7 +157,10 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { sentInitial.inFlight = true } } - if k := c.wkeys[initialSpace]; k.isSet() { + // If we're a client and this Initial packet is coalesced + // with a Handshake packet, then we've discarded Initial keys + // since constructing the packet and shouldn't record it as in-flight. + if c.keysInitial.canWrite() { c.loss.packetSent(now, initialSpace, sentInitial) } } @@ -162,23 +170,8 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { } func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, limit ccLimit) { - if c.errForPeer != nil { - // This is the bare minimum required to send a CONNECTION_CLOSE frame - // when closing a connection immediately, for example in response to a - // protocol error. - // - // This does not handle the closing and draining states - // (https://www.rfc-editor.org/rfc/rfc9000.html#section-10.2), - // but it's enough to let us write tests that result in a CONNECTION_CLOSE, - // and have those tests still pass when we finish implementing - // connection shutdown. - // - // TODO: Finish implementing connection shutdown. - if !c.connCloseSent[space] { - c.exited = true - c.appendConnectionCloseFrame(c.errForPeer) - c.connCloseSent[space] = true - } + if c.lifetime.localErr != nil { + c.appendConnectionCloseFrame(now, space, c.lifetime.localErr) return } @@ -204,16 +197,23 @@ func (c *Conn) appendFrames(now time.Time, space numberSpace, pnum packetNumber, // All frames other than ACK and PADDING are ack-eliciting, // so if the packet is ack-eliciting we've added additional // frames to it. - if shouldSendAck || c.w.sent.ackEliciting { - // Either we are willing to send an ACK-only packet, - // or we've added additional frames. - c.acks[space].sentAck() - } else { + if !shouldSendAck && !c.w.sent.ackEliciting { // There's nothing in this packet but ACK frames, and // we don't want to send an ACK-only packet at this time. // Abandoning the packet means we wrote an ACK frame for // nothing, but constructing the frame is cheap. c.w.abandonPacket() + return + } + // Either we are willing to send an ACK-only packet, + // or we've added additional frames. + c.acks[space].sentAck() + if !c.w.sent.ackEliciting && c.keysAppData.needAckEliciting() { + // The peer has initiated a key update. + // We haven't sent them any packets yet in the new phase. + // Make this an ack-eliciting packet. + // Their ack of this packet will complete the key update. + c.w.appendPingFrame() } }() } @@ -322,11 +322,20 @@ func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool { return c.w.appendAckFrame(seen, d) } -func (c *Conn) appendConnectionCloseFrame(err error) { - // TODO: Send application errors. +func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err error) { + c.lifetime.connCloseSentTime = now switch e := err.(type) { case localTransportError: c.w.appendConnectionCloseTransportFrame(transportError(e), 0, "") + case *ApplicationError: + if space != appDataSpace { + // "CONNECTION_CLOSE frames signaling application errors (type 0x1d) + // MUST only appear in the application data packet number space." + // https://www.rfc-editor.org/rfc/rfc9000#section-12.5-2.2 + c.w.appendConnectionCloseTransportFrame(errApplicationError, 0, "") + } else { + c.w.appendConnectionCloseApplicationFrame(e.Code, e.Reason) + } default: // TLS alerts are sent using error codes [0x0100,0x01ff). // https://www.rfc-editor.org/rfc/rfc9000#section-20.1-2.36.1 @@ -335,8 +344,8 @@ func (c *Conn) appendConnectionCloseFrame(err error) { // tls.AlertError is a uint8, so this can't exceed 0x01ff. code := errTLSBase + transportError(alert) c.w.appendConnectionCloseTransportFrame(code, 0, "") - return + } else { + c.w.appendConnectionCloseTransportFrame(errInternal, 0, "") } - c.w.appendConnectionCloseTransportFrame(errInternal, 0, "") } } diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go index 0ede284e2..a0793297e 100644 --- a/internal/quic/conn_streams.go +++ b/internal/quic/conn_streams.go @@ -18,29 +18,40 @@ type streamsState struct { streamsMu sync.Mutex streams map[streamID]*Stream - opened [streamTypeCount]int64 // number of streams opened by us + + // Limits on the number of streams, indexed by streamType. + localLimit [streamTypeCount]localStreamLimits + remoteLimit [streamTypeCount]remoteStreamLimits // Peer configuration provided in transport parameters. peerInitialMaxStreamDataRemote [streamTypeCount]int64 // streams opened by us peerInitialMaxStreamDataBidiLocal int64 // streams opened by them - // Streams with frames to send are stored in a circular linked list. - // sendHead is the next stream to write, or nil if there are no streams - // with data to send. sendTail is the last stream to write. - needSend atomic.Bool - sendMu sync.Mutex - sendHead *Stream - sendTail *Stream + // Connection-level flow control. + inflow connInflow + outflow connOutflow + + // Streams with frames to send are stored in one of two circular linked lists, + // depending on whether they require connection-level flow control. + needSend atomic.Bool + sendMu sync.Mutex + queueMeta streamRing // streams with any non-flow-controlled frames + queueData streamRing // streams with only flow-controlled frames } func (c *Conn) streamsInit() { c.streams.streams = make(map[streamID]*Stream) c.streams.queue = newQueue[*Stream]() + c.streams.localLimit[bidiStream].init() + c.streams.localLimit[uniStream].init() + c.streams.remoteLimit[bidiStream].init(c.config.maxBidiRemoteStreams()) + c.streams.remoteLimit[uniStream].init(c.config.maxUniRemoteStreams()) + c.inflowInit() } // AcceptStream waits for and returns the next stream created by the peer. func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) { - return c.streams.queue.getWithHooks(ctx, c.testHooks) + return c.streams.queue.get(ctx, c.testHooks) } // NewStream creates a stream. @@ -60,19 +71,20 @@ func (c *Conn) NewSendOnlyStream(ctx context.Context) (*Stream, error) { } func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, error) { - // TODO: Stream limits. c.streams.streamsMu.Lock() defer c.streams.streamsMu.Unlock() - num := c.streams.opened[styp] - c.streams.opened[styp]++ + num, err := c.streams.localLimit[styp].open(ctx, c) + if err != nil { + return nil, err + } s := newStream(c, newStreamID(c.side, styp, num)) - s.outmaxbuf = c.config.streamWriteBufferSize() + s.outmaxbuf = c.config.maxStreamWriteBufferSize() s.outwin = c.streams.peerInitialMaxStreamDataRemote[styp] if styp == bidiStream { - s.inmaxbuf = c.config.streamReadBufferSize() - s.inwin = c.config.streamReadBufferSize() + s.inmaxbuf = c.config.maxStreamReadBufferSize() + s.inwin = c.config.maxStreamReadBufferSize() } s.inUnlock() s.outUnlock() @@ -122,20 +134,50 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) c.streams.streamsMu.Lock() defer c.streams.streamsMu.Unlock() - if s := c.streams.streams[id]; s != nil { + s, isOpen := c.streams.streams[id] + if s != nil { return s } - // TODO: Check for closed streams, once we support closing streams. + + num := id.num() + styp := id.streamType() if id.initiator() == c.side { + if num < c.streams.localLimit[styp].opened { + // This stream was created by us, and has been closed. + return nil + } + // Received a frame for a stream that should be originated by us, + // but which we never created. c.abort(now, localTransportError(errStreamState)) return nil + } else { + // if isOpen, this is a stream that was implicitly opened by a + // previous frame for a larger-numbered stream, but we haven't + // actually created it yet. + if !isOpen && num < c.streams.remoteLimit[styp].opened { + // This stream was created by the peer, and has been closed. + return nil + } } - s := newStream(c, id) - s.inmaxbuf = c.config.streamReadBufferSize() - s.inwin = c.config.streamReadBufferSize() + prevOpened := c.streams.remoteLimit[styp].opened + if err := c.streams.remoteLimit[styp].open(id); err != nil { + c.abort(now, err) + return nil + } + + // Receiving a frame for a stream implicitly creates all streams + // with the same initiator and type and a lower number. + // Add a nil entry to the streams map for each implicitly created stream. + for n := newStreamID(id.initiator(), id.streamType(), prevOpened); n < id; n += 4 { + c.streams.streams[n] = nil + } + + s = newStream(c, id) + s.inmaxbuf = c.config.maxStreamReadBufferSize() + s.inwin = c.config.maxStreamReadBufferSize() if id.streamType() == bidiStream { - s.outmaxbuf = c.config.streamWriteBufferSize() + s.outmaxbuf = c.config.maxStreamWriteBufferSize() s.outwin = c.streams.peerInitialMaxStreamDataBidiLocal } s.inUnlock() @@ -146,34 +188,85 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType) return s } -// queueStreamForSend marks a stream as containing frames that need sending. -func (c *Conn) queueStreamForSend(s *Stream) { +// maybeQueueStreamForSend marks a stream as containing frames that need sending. +func (c *Conn) maybeQueueStreamForSend(s *Stream, state streamState) { + if state.wantQueue() == state.inQueue() { + return // already on the right queue + } c.streams.sendMu.Lock() defer c.streams.sendMu.Unlock() - if s.next != nil { - // Already in the queue. - return - } - if c.streams.sendHead == nil { - // The queue was empty. - c.streams.sendHead = s - c.streams.sendTail = s - s.next = s - } else { - // Insert this stream at the end of the queue. - c.streams.sendTail.next = s - c.streams.sendTail = s - s.next = c.streams.sendHead - } + state = s.state.load() // may have changed while waiting + c.queueStreamForSendLocked(s, state) + c.streams.needSend.Store(true) c.wake() } +// queueStreamForSendLocked moves a stream to the correct send queue, +// or removes it from all queues. +// +// state is the last known stream state. +func (c *Conn) queueStreamForSendLocked(s *Stream, state streamState) { + for { + wantQueue := state.wantQueue() + inQueue := state.inQueue() + if inQueue == wantQueue { + return // already on the right queue + } + + switch inQueue { + case metaQueue: + c.streams.queueMeta.remove(s) + case dataQueue: + c.streams.queueData.remove(s) + } + + switch wantQueue { + case metaQueue: + c.streams.queueMeta.append(s) + state = s.state.set(streamQueueMeta, streamQueueMeta|streamQueueData) + case dataQueue: + c.streams.queueData.append(s) + state = s.state.set(streamQueueData, streamQueueMeta|streamQueueData) + case noQueue: + state = s.state.set(0, streamQueueMeta|streamQueueData) + } + + // If the stream state changed while we were moving the stream, + // we might now be on the wrong queue. + // + // For example: + // - stream has data to send: streamOutSendData|streamQueueData + // - appendStreamFrames sends all the data: streamQueueData + // - concurrently, more data is written: streamOutSendData|streamQueueData + // - appendStreamFrames calls us with the last state it observed + // (streamQueueData). + // - We remove the stream from the queue and observe the updated state: + // streamOutSendData + // - We realize that the stream needs to go back on the data queue. + // + // Go back around the loop to confirm we're on the correct queue. + } +} + // appendStreamFrames writes stream-related frames to the current packet. // // It returns true if no more frames need appending, // false if not everything fit in the current packet. func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool) bool { + // MAX_DATA + if !c.appendMaxDataFrame(w, pnum, pto) { + return false + } + + // MAX_STREAM_DATA + if !c.streams.remoteLimit[uniStream].appendFrame(w, uniStream, pnum, pto) { + return false + } + if !c.streams.remoteLimit[bidiStream].appendFrame(w, bidiStream, pnum, pto) { + return false + } + if pto { return c.appendStreamFramesPTO(w, pnum) } @@ -182,65 +275,107 @@ func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool) } c.streams.sendMu.Lock() defer c.streams.sendMu.Unlock() - for { - s := c.streams.sendHead - const pto = false - + // queueMeta contains streams with non-flow-controlled frames to send. + for c.streams.queueMeta.head != nil { + s := c.streams.queueMeta.head state := s.state.load() - if state&streamInSend != 0 { + if state&(streamQueueMeta|streamConnRemoved) != streamQueueMeta { + panic("BUG: queueMeta stream is not streamQueueMeta") + } + if state&streamInSendMeta != 0 { s.ingate.lock() ok := s.appendInFramesLocked(w, pnum, pto) state = s.inUnlockNoQueue() if !ok { return false } + if state&streamInSendMeta != 0 { + panic("BUG: streamInSendMeta set after successfully appending frames") + } } - - if state&streamOutSend != 0 { - avail := w.avail() + if state&streamOutSendMeta != 0 { s.outgate.lock() + // This might also append flow-controlled frames if we have any + // and available conn-level quota. That's fine. ok := s.appendOutFramesLocked(w, pnum, pto) state = s.outUnlockNoQueue() - if !ok { - // We've sent some data for this stream, but it still has more to send. - // If the stream got a reasonable chance to put data in a packet, - // advance sendHead to the next stream in line, to avoid starvation. - // We'll come back to this stream after going through the others. - // - // If the packet was already mostly out of space, leave sendHead alone - // and come back to this stream again on the next packet. - if avail > 512 { - c.streams.sendHead = s.next - c.streams.sendTail = s - } + // We're checking both ok and state, because appendOutFramesLocked + // might have filled up the packet with flow-controlled data. + // If so, we want to move the stream to queueData for any remaining frames. + if !ok && state&streamOutSendMeta != 0 { return false } + if state&streamOutSendMeta != 0 { + panic("BUG: streamOutSendMeta set after successfully appending frames") + } } - - if state == streamInDone|streamOutDone { + // We've sent all frames for this stream, so remove it from the send queue. + c.streams.queueMeta.remove(s) + if state&(streamInDone|streamOutDone) == streamInDone|streamOutDone { // Stream is finished, remove it from the conn. - s.state.set(streamConnRemoved, streamConnRemoved) + state = s.state.set(streamConnRemoved, streamQueueMeta|streamConnRemoved) delete(c.streams.streams, s.id) - // TODO: Provide the peer with additional stream quota (MAX_STREAMS). + // Record finalization of remote streams, to know when + // to extend the peer's stream limit. + if s.id.initiator() != c.side { + c.streams.remoteLimit[s.id.streamType()].close() + } + } else { + state = s.state.set(0, streamQueueMeta|streamConnRemoved) } - - next := s.next - s.next = nil - if (next == s) != (s == c.streams.sendTail) { - panic("BUG: sendable stream list state is inconsistent") + // The stream may have flow-controlled data to send, + // or something might have added non-flow-controlled frames after we + // unlocked the stream. + // If so, put the stream back on a queue. + c.queueStreamForSendLocked(s, state) + } + // queueData contains streams with flow-controlled frames. + for c.streams.queueData.head != nil { + avail := c.streams.outflow.avail() + if avail == 0 { + break // no flow control quota available + } + s := c.streams.queueData.head + s.outgate.lock() + ok := s.appendOutFramesLocked(w, pnum, pto) + state := s.outUnlockNoQueue() + if !ok { + // We've sent some data for this stream, but it still has more to send. + // If the stream got a reasonable chance to put data in a packet, + // advance sendHead to the next stream in line, to avoid starvation. + // We'll come back to this stream after going through the others. + // + // If the packet was already mostly out of space, leave sendHead alone + // and come back to this stream again on the next packet. + if avail > 512 { + c.streams.queueData.head = s.next + } + return false } - if s == c.streams.sendTail { - // This was the last stream. - c.streams.sendHead = nil - c.streams.sendTail = nil - c.streams.needSend.Store(false) + if state&streamQueueData == 0 { + panic("BUG: queueData stream is not streamQueueData") + } + if state&streamOutSendData != 0 { + // We must have run out of connection-level flow control: + // appendOutFramesLocked says it wrote all it can, but there's + // still data to send. + // + // Advance sendHead to the next stream in line to avoid starvation. + if c.streams.outflow.avail() != 0 { + panic("BUG: streamOutSendData set and flow control available after send") + } + c.streams.queueData.head = s.next return true } - // We've sent all data for this stream, so remove it from the list. - c.streams.sendTail.next = next - c.streams.sendHead = next + c.streams.queueData.remove(s) + state = s.state.set(0, streamQueueData) + c.queueStreamForSendLocked(s, state) + } + if c.streams.queueMeta.head == nil && c.streams.queueData.head == nil { + c.streams.needSend.Store(false) } + return true } // appendStreamFramesPTO writes stream-related frames to the current packet @@ -251,6 +386,7 @@ func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool) func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool { c.streams.sendMu.Lock() defer c.streams.sendMu.Unlock() + const pto = true for _, s := range c.streams.streams { const pto = true s.ingate.lock() @@ -259,6 +395,7 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool { if !inOK { return false } + s.outgate.lock() outOK := s.appendOutFramesLocked(w, pnum, pto) s.outUnlockNoQueue() @@ -268,3 +405,37 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool { } return true } + +// A streamRing is a circular linked list of streams. +type streamRing struct { + head *Stream +} + +// remove removes s from the ring. +// s must be on the ring. +func (r *streamRing) remove(s *Stream) { + if s.next == s { + r.head = nil // s was the last stream in the ring + } else { + s.prev.next = s.next + s.next.prev = s.prev + if r.head == s { + r.head = s.next + } + } +} + +// append places s at the last position in the ring. +// s must not be attached to any ring. +func (r *streamRing) append(s *Stream) { + if r.head == nil { + r.head = s + s.next = s + s.prev = s + } else { + s.prev = r.head.prev + s.next = r.head + s.prev.next = s + s.next.prev = s + } +} diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index 9bbc994b1..69f982c3a 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -10,15 +10,13 @@ import ( "context" "fmt" "io" + "math" "testing" ) func TestStreamsCreate(t *testing.T) { ctx := canceledContext() - tc := newTestConn(t, clientSide, func(p *transportParameters) { - p.initialMaxStreamDataBidiLocal = 100 - p.initialMaxStreamDataBidiRemote = 100 - }) + tc := newTestConn(t, clientSide, permissiveTransportParameters) tc.handshake() c, err := tc.conn.NewStream(ctx) @@ -126,7 +124,7 @@ func TestStreamsBlockingAccept(t *testing.T) { } } -func TestStreamsStreamNotCreated(t *testing.T) { +func TestStreamsLocalStreamNotCreated(t *testing.T) { // "An endpoint MUST terminate the connection with error STREAM_STATE_ERROR // if it receives a STREAM frame for a locally initiated stream that has // not yet been created [...]" @@ -144,13 +142,39 @@ func TestStreamsStreamNotCreated(t *testing.T) { }) } +func TestStreamsLocalStreamClosed(t *testing.T) { + tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, permissiveTransportParameters) + s.CloseWrite() + tc.wantFrame("FIN for closed stream", + packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, uniStream, 0), + fin: true, + data: []byte{}, + }) + tc.writeAckForAll() + + tc.writeFrames(packetType1RTT, debugFrameStopSending{ + id: newStreamID(clientSide, uniStream, 0), + }) + tc.wantIdle("frame for finalized stream is ignored") + + // ACKing the last stream packet should have cleaned up the stream. + // Check that we don't have any state left. + if got := len(tc.conn.streams.streams); got != 0 { + t.Fatalf("after close, len(tc.conn.streams.streams) = %v, want 0", got) + } + if tc.conn.streams.queueMeta.head != nil { + t.Fatalf("after close, stream send queue is not empty; should be") + } +} + func TestStreamsStreamSendOnly(t *testing.T) { // "An endpoint MUST terminate the connection with error STREAM_STATE_ERROR // if it receives a STREAM frame for a locally initiated stream that has // not yet been created [...]" // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.8-3 ctx := canceledContext() - tc := newTestConn(t, serverSide) + tc := newTestConn(t, serverSide, permissiveTransportParameters) tc.handshake() c, err := tc.conn.NewSendOnlyStream(ctx) @@ -183,7 +207,7 @@ func TestStreamsWriteQueueFairness(t *testing.T) { p.initialMaxData = 1<<62 - 1 p.initialMaxStreamDataBidiRemote = dataLen }, func(c *Config) { - c.StreamWriteBufferSize = dataLen + c.MaxStreamWriteBufferSize = dataLen }) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -342,3 +366,115 @@ func TestStreamsShutdown(t *testing.T) { }) } } + +func TestStreamsCreateAndCloseRemote(t *testing.T) { + // This test exercises creating new streams in response to frames + // from the peer, and cleaning up after streams are fully closed. + // + // It's overfitted to the current implementation, but works through + // a number of corner cases in that implementation. + // + // Disable verbose logging in this test: It sends a lot of packets, + // and they're not especially interesting on their own. + defer func(vv bool) { + *testVV = vv + }(*testVV) + *testVV = false + ctx := canceledContext() + tc := newTestConn(t, serverSide, permissiveTransportParameters) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + type op struct { + id streamID + } + type streamOp op + type resetOp op + type acceptOp op + const noStream = math.MaxInt64 + stringID := func(id streamID) string { + return fmt.Sprintf("%v/%v", id.streamType(), id.num()) + } + for _, op := range []any{ + "opening bidi/5 implicitly opens bidi/0-4", + streamOp{newStreamID(clientSide, bidiStream, 5)}, + acceptOp{newStreamID(clientSide, bidiStream, 5)}, + "bidi/3 was implicitly opened", + streamOp{newStreamID(clientSide, bidiStream, 3)}, + acceptOp{newStreamID(clientSide, bidiStream, 3)}, + resetOp{newStreamID(clientSide, bidiStream, 3)}, + "bidi/3 is done, frames for it are discarded", + streamOp{newStreamID(clientSide, bidiStream, 3)}, + "open and close some uni streams as well", + streamOp{newStreamID(clientSide, uniStream, 0)}, + acceptOp{newStreamID(clientSide, uniStream, 0)}, + streamOp{newStreamID(clientSide, uniStream, 1)}, + acceptOp{newStreamID(clientSide, uniStream, 1)}, + streamOp{newStreamID(clientSide, uniStream, 2)}, + acceptOp{newStreamID(clientSide, uniStream, 2)}, + resetOp{newStreamID(clientSide, uniStream, 1)}, + resetOp{newStreamID(clientSide, uniStream, 0)}, + resetOp{newStreamID(clientSide, uniStream, 2)}, + "closing an implicitly opened stream causes us to accept it", + resetOp{newStreamID(clientSide, bidiStream, 0)}, + acceptOp{newStreamID(clientSide, bidiStream, 0)}, + resetOp{newStreamID(clientSide, bidiStream, 1)}, + acceptOp{newStreamID(clientSide, bidiStream, 1)}, + resetOp{newStreamID(clientSide, bidiStream, 2)}, + acceptOp{newStreamID(clientSide, bidiStream, 2)}, + "stream bidi/3 was reset previously", + resetOp{newStreamID(clientSide, bidiStream, 3)}, + resetOp{newStreamID(clientSide, bidiStream, 4)}, + acceptOp{newStreamID(clientSide, bidiStream, 4)}, + "stream bidi/5 was reset previously", + resetOp{newStreamID(clientSide, bidiStream, 5)}, + "stream bidi/6 was not implicitly opened", + resetOp{newStreamID(clientSide, bidiStream, 6)}, + acceptOp{newStreamID(clientSide, bidiStream, 6)}, + } { + if _, ok := op.(acceptOp); !ok { + if s, err := tc.conn.AcceptStream(ctx); err == nil { + t.Fatalf("accepted stream %v, want none", stringID(s.id)) + } + } + switch op := op.(type) { + case string: + t.Log("# " + op) + case streamOp: + t.Logf("open stream %v", stringID(op.id)) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: streamID(op.id), + }) + case resetOp: + t.Logf("reset stream %v", stringID(op.id)) + tc.writeFrames(packetType1RTT, debugFrameResetStream{ + id: op.id, + }) + case acceptOp: + s, err := tc.conn.AcceptStream(ctx) + if err != nil { + t.Fatalf("AcceptStream() = %q; want stream %v", err, stringID(op.id)) + } + if s.id != op.id { + t.Fatalf("accepted stram %v; want stream %v", err, stringID(op.id)) + } + t.Logf("accepted stream %v", stringID(op.id)) + // Immediately close the stream, so the stream becomes done when the + // peer closes its end. + s.CloseContext(ctx) + } + p := tc.readPacket() + if p != nil { + tc.writeFrames(p.ptype, debugFrameAck{ + ranges: []i64range[packetNumber]{{0, p.num + 1}}, + }) + } + } + // Every stream should be fully closed now. + // Check that we don't have any state left. + if got := len(tc.conn.streams.streams); got != 0 { + t.Fatalf("after test, len(tc.conn.streams.streams) = %v, want 0", got) + } + if tc.conn.streams.queueMeta.head != nil { + t.Fatalf("after test, stream send queue is not empty; should be") + } +} diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index ea720d575..6a359e89a 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -74,12 +74,14 @@ func (d testDatagram) String() string { } type testPacket struct { - ptype packetType - version uint32 - num packetNumber - dstConnID []byte - srcConnID []byte - frames []debugFrame + ptype packetType + version uint32 + num packetNumber + keyPhaseBit bool + keyNumber int + dstConnID []byte + srcConnID []byte + frames []debugFrame } func (p testPacket) String() string { @@ -100,25 +102,33 @@ func (p testPacket) String() string { return b.String() } +// maxTestKeyPhases is the maximum number of 1-RTT keys we'll generate in a test. +const maxTestKeyPhases = 3 + // A testConn is a Conn whose external interactions (sending and receiving packets, // setting timers) can be manipulated in tests. type testConn struct { t *testing.T conn *Conn + listener *testListener now time.Time timer time.Time timerLastFired time.Time idlec chan struct{} // only accessed on the conn's loop - // Read and write keys are distinct from the conn's keys, + // Keys are distinct from the conn's keys, // because the test may know about keys before the conn does. // For example, when sending a datagram with coalesced // Initial and Handshake packets to a client conn, // we use Handshake keys to encrypt the packet. // The client only acquires those keys when it processes // the Initial packet. - rkeys [numberSpaceCount]keyData // for packets sent to the conn - wkeys [numberSpaceCount]keyData // for packets sent by the conn + keysInitial fixedKeyPair + keysHandshake fixedKeyPair + rkeyAppData test1RTTKeys + wkeyAppData test1RTTKeys + rsecrets [numberSpaceCount]keySecret + wsecrets [numberSpaceCount]keySecret // testConn uses a test hook to snoop on the conn's TLS events. // CRYPTO data produced by the conn's QUICConn is placed in @@ -142,19 +152,29 @@ type testConn struct { sentFrames []debugFrame lastPacket *testPacket + recvDatagram chan *datagram + // Transport parameters sent by the conn. sentTransportParameters *transportParameters // Frame types to ignore in tests. ignoreFrames map[byte]bool + // Values to set in packets sent to the conn. + sendKeyNumber int + sendKeyPhaseBit bool + asyncTestState } -type keyData struct { +type test1RTTKeys struct { + hdr headerKey + pkt [maxTestKeyPhases]packetKey +} + +type keySecret struct { suite uint16 secret []byte - k keys } // newTestConn creates a Conn for testing. @@ -173,6 +193,7 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { }, cryptoDataOut: make(map[tls.QUICEncryptionLevel][]byte), cryptoDataIn: make(map[tls.QUICEncryptionLevel][]byte), + recvDatagram: make(chan *datagram), } t.Cleanup(tc.cleanup) @@ -180,6 +201,10 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { TLSConfig: newTestTLSConfig(side), } peerProvidedParams := defaultTransportParameters() + peerProvidedParams.initialSrcConnID = testPeerConnID(0) + if side == clientSide { + peerProvidedParams.originalDstConnID = testLocalConnID(-1) + } for _, o := range opts { switch o := o.(type) { case func(*Config): @@ -196,12 +221,7 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { var initialConnID []byte if side == serverSide { // The initial connection ID for the server is chosen by the client. - // When creating a server-side connection, pick a random connection ID here. - var err error - initialConnID, err = newRandomConnID(0) - if err != nil { - tc.t.Fatal(err) - } + initialConnID = testPeerConnID(-1) } peerQUICConfig := &tls.QUICConfig{TLSConfig: newTestTLSConfig(side.peer())} @@ -213,21 +233,20 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { tc.peerTLSConn.SetTransportParameters(marshalTransportParameters(peerProvidedParams)) tc.peerTLSConn.Start(context.Background()) - conn, err := newConn( + tc.listener = newTestListener(t, config, (*testConnHooks)(tc)) + conn, err := tc.listener.l.newConn( tc.now, side, initialConnID, - netip.MustParseAddrPort("127.0.0.1:443"), - config, - (*testConnListener)(tc), - (*testConnHooks)(tc)) + netip.MustParseAddrPort("127.0.0.1:443")) if err != nil { tc.t.Fatal(err) } tc.conn = conn - tc.wkeys[initialSpace].k = conn.wkeys[initialSpace] - tc.rkeys[initialSpace].k = conn.rkeys[initialSpace] + conn.keysAppData.updateAfter = maxPacketNumber // disable key updates + tc.keysInitial.r = conn.keysInitial.w + tc.keysInitial.w = conn.keysInitial.r tc.wait() return tc @@ -305,6 +324,8 @@ func (tc *testConn) wait() { select { case <-idlec: case <-tc.conn.donec: + // We may have async ops that can proceed now that the conn is done. + tc.wakeAsync() } if fail { panic(fail) @@ -316,6 +337,7 @@ func (tc *testConn) cleanup() { return } tc.conn.exit() + <-tc.conn.donec } func (tc *testConn) logDatagram(text string, d *testDatagram) { @@ -329,12 +351,20 @@ func (tc *testConn) logDatagram(text string, d *testDatagram) { } tc.t.Logf("%v datagram%v", text, pad) for _, p := range d.packets { + var s string switch p.ptype { case packetType1RTT: - tc.t.Logf(" %v pnum=%v", p.ptype, p.num) + s = fmt.Sprintf(" %v pnum=%v", p.ptype, p.num) default: - tc.t.Logf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID) + s = fmt.Sprintf(" %v pnum=%v ver=%v dst={%x} src={%x}", p.ptype, p.num, p.version, p.dstConnID, p.srcConnID) + } + if p.keyPhaseBit { + s += fmt.Sprintf(" KeyPhase") } + if p.keyNumber != 0 { + s += fmt.Sprintf(" keynum=%v", p.keyNumber) + } + tc.t.Log(s) for _, f := range p.frames { tc.t.Logf(" %v", f) } @@ -360,6 +390,7 @@ func (tc *testConn) write(d *testDatagram) { for len(buf) < d.paddedSize { buf = append(buf, 0) } + // TODO: This should use tc.listener.write. tc.conn.sendMsg(&datagram{ b: buf, }) @@ -377,12 +408,14 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) { } d := &testDatagram{ packets: []*testPacket{{ - ptype: ptype, - num: tc.peerNextPacketNum[space], - frames: frames, - version: 1, - dstConnID: dstConnID, - srcConnID: tc.peerConnID, + ptype: ptype, + num: tc.peerNextPacketNum[space], + keyNumber: tc.sendKeyNumber, + keyPhaseBit: tc.sendKeyPhaseBit, + frames: frames, + version: quicVersion1, + dstConnID: dstConnID, + srcConnID: tc.peerConnID, }}, } if ptype == packetTypeInitial && tc.conn.side == serverSide { @@ -427,11 +460,10 @@ func (tc *testConn) readDatagram() *testDatagram { tc.wait() tc.sentPackets = nil tc.sentFrames = nil - if len(tc.sentDatagrams) == 0 { + buf := tc.listener.read() + if buf == nil { return nil } - buf := tc.sentDatagrams[0] - tc.sentDatagrams = tc.sentDatagrams[1:] d := tc.parseTestDatagram(buf) // Log the datagram before removing ignored frames. // When things go wrong, it's useful to see all the frames. @@ -576,6 +608,22 @@ func (tc *testConn) wantFrame(expectation string, wantType packetType, want debu } } +// wantFrameType indicates that we expect the Conn to send a frame, +// although we don't care about the contents. +func (tc *testConn) wantFrameType(expectation string, wantType packetType, want debugFrame) { + tc.t.Helper() + got, gotType := tc.readFrame() + if got == nil { + tc.t.Fatalf("%v:\nconnection is idle\nwant %v frame: %v", expectation, wantType, want) + } + if gotType != wantType { + tc.t.Fatalf("%v:\ngot %v packet, want %v\ngot frame: %v", expectation, gotType, wantType, got) + } + if reflect.TypeOf(got) != reflect.TypeOf(want) { + tc.t.Fatalf("%v:\ngot frame: %v\nwant frame of type: %v", expectation, got, want) + } +} + // wantIdle indicates that we expect the Conn to not send any more frames. func (tc *testConn) wantIdle(expectation string) { tc.t.Helper() @@ -609,14 +657,19 @@ func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte { for _, f := range p.frames { f.write(&w) } - space := spaceForPacketType(p.ptype) - if !tc.rkeys[space].k.isSet() { - tc.t.Fatalf("sending packet with no %v keys available", space) - return nil - } w.appendPaddingTo(pad) if p.ptype != packetType1RTT { - w.finishProtectedLongHeaderPacket(pnumMaxAcked, tc.rkeys[space].k, longPacket{ + var k fixedKeys + switch p.ptype { + case packetTypeInitial: + k = tc.keysInitial.w + case packetTypeHandshake: + k = tc.keysHandshake.w + } + if !k.isSet() { + tc.t.Fatalf("sending %v packet with no write key", p.ptype) + } + w.finishProtectedLongHeaderPacket(pnumMaxAcked, k, longPacket{ ptype: p.ptype, version: p.version, num: p.num, @@ -624,7 +677,25 @@ func (tc *testConn) encodeTestPacket(p *testPacket, pad int) []byte { srcConnID: p.srcConnID, }) } else { - w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, tc.rkeys[space].k) + if !tc.wkeyAppData.hdr.isSet() { + tc.t.Fatalf("sending 1-RTT packet with no write key") + } + // Somewhat hackish: Generate a temporary updatingKeyPair that will + // always use our desired key phase. + k := &updatingKeyPair{ + w: updatingKeys{ + hdr: tc.wkeyAppData.hdr, + pkt: [2]packetKey{ + tc.wkeyAppData.pkt[p.keyNumber], + tc.wkeyAppData.pkt[p.keyNumber], + }, + }, + updateAfter: maxPacketNumber, + } + if p.keyPhaseBit { + k.phase |= keyPhaseBit + } + w.finish1RTTPacket(p.num, pnumMaxAcked, p.dstConnID, k) } return w.datagram() } @@ -640,13 +711,19 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram { break } ptype := getPacketType(buf) - space := spaceForPacketType(ptype) - if !tc.wkeys[space].k.isSet() { - tc.t.Fatalf("no keys for space %v, packet type %v", space, ptype) - } if isLongHeader(buf[0]) { + var k fixedKeyPair + switch ptype { + case packetTypeInitial: + k = tc.keysInitial + case packetTypeHandshake: + k = tc.keysHandshake + } + if !k.canRead() { + tc.t.Fatalf("reading %v packet with no read key", ptype) + } var pnumMax packetNumber // TODO: Track packet numbers. - p, n := parseLongHeaderPacket(buf, tc.wkeys[space].k, pnumMax) + p, n := parseLongHeaderPacket(buf, k.r, pnumMax) if n < 0 { tc.t.Fatalf("packet parse error") } @@ -664,22 +741,45 @@ func (tc *testConn) parseTestDatagram(buf []byte) *testDatagram { }) buf = buf[n:] } else { + if !tc.rkeyAppData.hdr.isSet() { + tc.t.Fatalf("reading 1-RTT packet with no read key") + } var pnumMax packetNumber // TODO: Track packet numbers. - p, n := parse1RTTPacket(buf, tc.wkeys[space].k, len(tc.peerConnID), pnumMax) - if n < 0 { - tc.t.Fatalf("packet parse error") + pnumOff := 1 + len(tc.peerConnID) + // Try unprotecting the packet with the first maxTestKeyPhases keys. + var phase int + var pnum packetNumber + var hdr []byte + var pay []byte + var err error + for phase = 0; phase < maxTestKeyPhases; phase++ { + b := append([]byte{}, buf...) + hdr, pay, pnum, err = tc.rkeyAppData.hdr.unprotect(b, pnumOff, pnumMax) + if err != nil { + tc.t.Fatalf("1-RTT packet header parse error") + } + k := tc.rkeyAppData.pkt[phase] + pay, err = k.unprotect(hdr, pay, pnum) + if err == nil { + break + } } - frames, err := tc.parseTestFrames(p.payload) + if err != nil { + tc.t.Fatalf("1-RTT packet payload parse error") + } + frames, err := tc.parseTestFrames(pay) if err != nil { tc.t.Fatal(err) } d.packets = append(d.packets, &testPacket{ - ptype: packetType1RTT, - num: p.num, - dstConnID: buf[1:][:len(tc.peerConnID)], - frames: frames, + ptype: packetType1RTT, + num: pnum, + dstConnID: hdr[1:][:len(tc.peerConnID)], + keyPhaseBit: hdr[0]&keyPhaseBit != 0, + keyNumber: phase, + frames: frames, }) - buf = buf[n:] + buf = buf[len(buf):] } } // This is rather hackish: If the last frame in the last packet @@ -745,12 +845,7 @@ type testConnHooks testConn // and verify that both sides of the connection are getting // matching keys. func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { - setKey := func(keys *[numberSpaceCount]keyData, e tls.QUICEvent) { - k, err := newKeys(e.Suite, e.Data) - if err != nil { - tc.t.Errorf("newKeys: %v", err) - return - } + checkKey := func(typ string, secrets *[numberSpaceCount]keySecret, e tls.QUICEvent) { var space numberSpace switch { case e.Level == tls.QUICEncryptionLevelHandshake: @@ -761,25 +856,37 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { tc.t.Errorf("unexpected encryption level %v", e.Level) return } - s := "read" - if keys == &tc.wkeys { - s = "write" + if secrets[space].secret == nil { + secrets[space].suite = e.Suite + secrets[space].secret = append([]byte{}, e.Data...) + } else if secrets[space].suite != e.Suite || !bytes.Equal(secrets[space].secret, e.Data) { + tc.t.Errorf("%v key mismatch for level for level %v", typ, e.Level) } - if keys[space].k.isSet() { - if keys[space].suite != e.Suite || !bytes.Equal(keys[space].secret, e.Data) { - tc.t.Errorf("%v key mismatch for level for level %v", s, e.Level) - } - return + } + setAppDataKey := func(suite uint16, secret []byte, k *test1RTTKeys) { + k.hdr.init(suite, secret) + for i := 0; i < len(k.pkt); i++ { + k.pkt[i].init(suite, secret) + secret = updateSecret(suite, secret) } - keys[space].suite = e.Suite - keys[space].secret = append([]byte{}, e.Data...) - keys[space].k = k } switch e.Kind { case tls.QUICSetReadSecret: - setKey(&tc.rkeys, e) + checkKey("write", &tc.wsecrets, e) + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + tc.keysHandshake.w.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData) + } case tls.QUICSetWriteSecret: - setKey(&tc.wkeys, e) + checkKey("read", &tc.rsecrets, e) + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + tc.keysHandshake.r.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData) + } case tls.QUICWriteData: tc.cryptoDataOut[e.Level] = append(tc.cryptoDataOut[e.Level], e.Data...) tc.peerTLSConn.HandleData(e.Level, e.Data) @@ -790,9 +897,21 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { case tls.QUICNoEvent: return case tls.QUICSetReadSecret: - setKey(&tc.wkeys, e) + checkKey("write", &tc.rsecrets, e) + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + tc.keysHandshake.r.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + setAppDataKey(e.Suite, e.Data, &tc.rkeyAppData) + } case tls.QUICSetWriteSecret: - setKey(&tc.rkeys, e) + checkKey("read", &tc.wsecrets, e) + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + tc.keysHandshake.w.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + setAppDataKey(e.Suite, e.Data, &tc.wkeyAppData) + } case tls.QUICWriteData: tc.cryptoDataIn[e.Level] = append(tc.cryptoDataIn[e.Level], e.Data...) case tls.QUICTransportParameters: @@ -844,6 +963,10 @@ func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) { return testLocalConnID(seq), nil } +func (tc *testConnHooks) timeNow() time.Time { + return tc.now +} + // testLocalConnID returns the connection ID with a given sequence number // used by a Conn under test. func testLocalConnID(seq int64) []byte { @@ -861,14 +984,6 @@ func testPeerConnID(seq int64) []byte { return []byte{0xbe, 0xee, 0xff, byte(seq)} } -// testConnListener implements connListener. -type testConnListener testConn - -func (tc *testConnListener) sendDatagram(p []byte, addr netip.AddrPort) error { - tc.sentDatagrams = append(tc.sentDatagrams, append([]byte(nil), p...)) - return nil -} - // canceledContext returns a canceled Context. // // Functions which take a context preference progress over cancelation. diff --git a/internal/quic/crypto_stream.go b/internal/quic/crypto_stream.go index 75dea87d0..8aa8f7b82 100644 --- a/internal/quic/crypto_stream.go +++ b/internal/quic/crypto_stream.go @@ -118,7 +118,7 @@ func (s *cryptoStream) ackOrLoss(start, end int64, fate packetFate) { // copy the data it wants into position. func (s *cryptoStream) dataToSend(pto bool, f func(off, size int64) (sent int64)) { for { - off, size := dataToSend(s.out, s.outunsent, s.outacked, pto) + off, size := dataToSend(s.out.start, s.out.end, s.outunsent, s.outacked, pto) if size == 0 { return } diff --git a/internal/quic/errors.go b/internal/quic/errors.go index f15685932..8e01bb7cb 100644 --- a/internal/quic/errors.go +++ b/internal/quic/errors.go @@ -114,7 +114,13 @@ type ApplicationError struct { Reason string } -func (e ApplicationError) Error() string { +func (e *ApplicationError) Error() string { // TODO: Include the Reason string here, but sanitize it first. return fmt.Sprintf("AppError %v", e.Code) } + +// Is reports a match if err is an *ApplicationError with a matching Code. +func (e *ApplicationError) Is(err error) bool { + e2, ok := err.(*ApplicationError) + return ok && e2.Code == e.Code +} diff --git a/internal/quic/gate.go b/internal/quic/gate.go index 27ab07a6f..a2fb53711 100644 --- a/internal/quic/gate.go +++ b/internal/quic/gate.go @@ -47,13 +47,11 @@ func (g *gate) lock() (set bool) { } // waitAndLock waits until the condition is set before acquiring the gate. -func (g *gate) waitAndLock() { - <-g.set -} - -// waitAndLockContext waits until the condition is set before acquiring the gate. -// If the context expires, waitAndLockContext returns an error and does not acquire the gate. -func (g *gate) waitAndLockContext(ctx context.Context) error { +// If the context expires, waitAndLock returns an error and does not acquire the gate. +func (g *gate) waitAndLock(ctx context.Context, testHooks connTestHooks) error { + if testHooks != nil { + return testHooks.waitUntil(ctx, g.lockIfSet) + } select { case <-g.set: return nil @@ -67,23 +65,6 @@ func (g *gate) waitAndLockContext(ctx context.Context) error { } } -// waitWithLock releases an acquired gate until the condition is set. -// The caller must have previously acquired the gate. -// Upon return from waitWithLock, the gate will still be held. -// If waitWithLock returns nil, the condition is set. -func (g *gate) waitWithLock(ctx context.Context) error { - g.unlock(false) - err := g.waitAndLockContext(ctx) - if err != nil { - if g.lock() { - // The condition was set in between the context expiring - // and us reacquiring the gate. - err = nil - } - } - return err -} - // lockIfSet acquires the gate if and only if the condition is set. func (g *gate) lockIfSet() (acquired bool) { select { diff --git a/internal/quic/gate_test.go b/internal/quic/gate_test.go index 0122e3986..9e84a84bd 100644 --- a/internal/quic/gate_test.go +++ b/internal/quic/gate_test.go @@ -41,37 +41,18 @@ func TestGateLockAndUnlock(t *testing.T) { } } -func TestGateWaitAndLock(t *testing.T) { - g := newGate() - set := false - go func() { - for i := 0; i < 3; i++ { - g.lock() - g.unlock(false) - time.Sleep(1 * time.Millisecond) - } - g.lock() - set = true - g.unlock(true) - }() - g.waitAndLock() - if !set { - t.Errorf("g.waitAndLock() returned before gate was set") - } -} - func TestGateWaitAndLockContext(t *testing.T) { g := newGate() - // waitAndLockContext is canceled + // waitAndLock is canceled ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(1 * time.Millisecond) cancel() }() - if err := g.waitAndLockContext(ctx); err != context.Canceled { - t.Errorf("g.waitAndLockContext() = %v, want context.Canceled", err) + if err := g.waitAndLock(ctx, nil); err != context.Canceled { + t.Errorf("g.waitAndLock() = %v, want context.Canceled", err) } - // waitAndLockContext succeeds + // waitAndLock succeeds set := false go func() { time.Sleep(1 * time.Millisecond) @@ -79,44 +60,16 @@ func TestGateWaitAndLockContext(t *testing.T) { set = true g.unlock(true) }() - if err := g.waitAndLockContext(context.Background()); err != nil { - t.Errorf("g.waitAndLockContext() = %v, want nil", err) + if err := g.waitAndLock(context.Background(), nil); err != nil { + t.Errorf("g.waitAndLock() = %v, want nil", err) } if !set { - t.Errorf("g.waitAndLockContext() returned before gate was set") + t.Errorf("g.waitAndLock() returned before gate was set") } g.unlock(true) - // waitAndLockContext succeeds when the gate is set and the context is canceled - if err := g.waitAndLockContext(ctx); err != nil { - t.Errorf("g.waitAndLockContext() = %v, want nil", err) - } -} - -func TestGateWaitWithLock(t *testing.T) { - g := newGate() - // waitWithLock is canceled - ctx, cancel := context.WithCancel(context.Background()) - go func() { - time.Sleep(1 * time.Millisecond) - cancel() - }() - g.lock() - if err := g.waitWithLock(ctx); err != context.Canceled { - t.Errorf("g.waitWithLock() = %v, want context.Canceled", err) - } - // waitWithLock succeeds - set := false - go func() { - g.lock() - set = true - g.unlock(true) - }() - time.Sleep(1 * time.Millisecond) - if err := g.waitWithLock(context.Background()); err != nil { - t.Errorf("g.waitWithLock() = %v, want nil", err) - } - if !set { - t.Errorf("g.waitWithLock() returned before gate was set") + // waitAndLock succeeds when the gate is set and the context is canceled + if err := g.waitAndLock(ctx, nil); err != nil { + t.Errorf("g.waitAndLock() = %v, want nil", err) } } @@ -138,5 +91,5 @@ func TestGateUnlockFunc(t *testing.T) { g.lock() defer g.unlockFunc(func() bool { return true }) }() - g.waitAndLock() + g.waitAndLock(context.Background(), nil) } diff --git a/internal/quic/key_update_test.go b/internal/quic/key_update_test.go new file mode 100644 index 000000000..4a4d67771 --- /dev/null +++ b/internal/quic/key_update_test.go @@ -0,0 +1,234 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "testing" +) + +func TestKeyUpdatePeerUpdates(t *testing.T) { + tc := newTestConn(t, serverSide) + tc.handshake() + tc.ignoreFrames = nil // ignore nothing + + // Peer initiates a key update. + tc.sendKeyNumber = 1 + tc.sendKeyPhaseBit = true + tc.writeFrames(packetType1RTT, debugFramePing{}) + + // We update to the new key. + tc.advanceToTimer() + tc.wantFrameType("conn ACKs last packet", + packetType1RTT, debugFrameAck{}) + tc.wantFrame("first packet after a key update is always ack-eliciting", + packetType1RTT, debugFramePing{}) + if got, want := tc.lastPacket.keyNumber, 1; got != want { + t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want) + } + if !tc.lastPacket.keyPhaseBit { + t.Errorf("after key rotation, conn failed to change Key Phase bit") + } + tc.wantIdle("conn has nothing to send") + + // Peer's ACK of a packet we sent in the new phase completes the update. + tc.writeAckForAll() + + // Peer initiates a second key update. + tc.sendKeyNumber = 2 + tc.sendKeyPhaseBit = false + tc.writeFrames(packetType1RTT, debugFramePing{}) + + // We update to the new key. + tc.advanceToTimer() + tc.wantFrameType("conn ACKs last packet", + packetType1RTT, debugFrameAck{}) + tc.wantFrame("first packet after a key update is always ack-eliciting", + packetType1RTT, debugFramePing{}) + if got, want := tc.lastPacket.keyNumber, 2; got != want { + t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want) + } + if tc.lastPacket.keyPhaseBit { + t.Errorf("after second key rotation, conn failed to change Key Phase bit") + } + tc.wantIdle("conn has nothing to send") +} + +func TestKeyUpdateAcceptPreviousPhaseKeys(t *testing.T) { + // "An endpoint SHOULD retain old keys for some time after + // unprotecting a packet sent using the new keys." + // https://www.rfc-editor.org/rfc/rfc9001#section-6.1-8 + tc := newTestConn(t, serverSide) + tc.handshake() + tc.ignoreFrames = nil // ignore nothing + + // Peer initiates a key update, skipping one packet number. + pnum0 := tc.peerNextPacketNum[appDataSpace] + tc.peerNextPacketNum[appDataSpace]++ + tc.sendKeyNumber = 1 + tc.sendKeyPhaseBit = true + tc.writeFrames(packetType1RTT, debugFramePing{}) + + // We update to the new key. + // This ACK is not delayed, because we've skipped a packet number. + tc.wantFrame("conn ACKs last packet", + packetType1RTT, debugFrameAck{ + ranges: []i64range[packetNumber]{ + {0, pnum0}, + {pnum0 + 1, pnum0 + 2}, + }, + }) + tc.wantFrame("first packet after a key update is always ack-eliciting", + packetType1RTT, debugFramePing{}) + if got, want := tc.lastPacket.keyNumber, 1; got != want { + t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want) + } + if !tc.lastPacket.keyPhaseBit { + t.Errorf("after key rotation, conn failed to change Key Phase bit") + } + tc.wantIdle("conn has nothing to send") + + // We receive the previously-skipped packet in the earlier key phase. + tc.peerNextPacketNum[appDataSpace] = pnum0 + tc.sendKeyNumber = 0 + tc.sendKeyPhaseBit = false + tc.writeFrames(packetType1RTT, debugFramePing{}) + + // We ack the reordered packet immediately, still in the new key phase. + tc.wantFrame("conn ACKs reordered packet", + packetType1RTT, debugFrameAck{ + ranges: []i64range[packetNumber]{ + {0, pnum0 + 2}, + }, + }) + tc.wantIdle("packet is not ack-eliciting") + if got, want := tc.lastPacket.keyNumber, 1; got != want { + t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want) + } + if !tc.lastPacket.keyPhaseBit { + t.Errorf("after key rotation, conn failed to change Key Phase bit") + } +} + +func TestKeyUpdateRejectPacketFromPriorPhase(t *testing.T) { + // "Packets with higher packet numbers MUST be protected with either + // the same or newer packet protection keys than packets with lower packet numbers." + // https://www.rfc-editor.org/rfc/rfc9001#section-6.4-2 + tc := newTestConn(t, serverSide) + tc.handshake() + tc.ignoreFrames = nil // ignore nothing + + // Peer initiates a key update. + tc.sendKeyNumber = 1 + tc.sendKeyPhaseBit = true + tc.writeFrames(packetType1RTT, debugFramePing{}) + + // We update to the new key. + tc.advanceToTimer() + tc.wantFrameType("conn ACKs last packet", + packetType1RTT, debugFrameAck{}) + tc.wantFrame("first packet after a key update is always ack-eliciting", + packetType1RTT, debugFramePing{}) + if got, want := tc.lastPacket.keyNumber, 1; got != want { + t.Errorf("after key rotation, conn sent packet with key %v, want %v", got, want) + } + if !tc.lastPacket.keyPhaseBit { + t.Errorf("after key rotation, conn failed to change Key Phase bit") + } + tc.wantIdle("conn has nothing to send") + + // Peer sends an ack-eliciting packet using the prior phase keys. + // We fail to unprotect the packet and ignore it. + skipped := tc.peerNextPacketNum[appDataSpace] + tc.sendKeyNumber = 0 + tc.sendKeyPhaseBit = false + tc.writeFrames(packetType1RTT, debugFramePing{}) + + // Peer sends an ack-eliciting packet using the current phase keys. + tc.sendKeyNumber = 1 + tc.sendKeyPhaseBit = true + tc.writeFrames(packetType1RTT, debugFramePing{}) + + // We ack the peer's packets, not including the one sent with the wrong keys. + tc.wantFrame("conn ACKs packets, not including packet sent with wrong keys", + packetType1RTT, debugFrameAck{ + ranges: []i64range[packetNumber]{ + {0, skipped}, + {skipped + 1, skipped + 2}, + }, + }) +} + +func TestKeyUpdateLocallyInitiated(t *testing.T) { + const updateAfter = 4 // initiate key update after 1-RTT packet 4 + tc := newTestConn(t, serverSide) + tc.conn.keysAppData.updateAfter = updateAfter + tc.handshake() + + for { + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.advanceToTimer() + tc.wantFrameType("conn ACKs last packet", + packetType1RTT, debugFrameAck{}) + if tc.lastPacket.num > updateAfter { + break + } + if got, want := tc.lastPacket.keyNumber, 0; got != want { + t.Errorf("before key update, conn sent packet with key %v, want %v", got, want) + } + if tc.lastPacket.keyPhaseBit { + t.Errorf("before key update, keyPhaseBit is set, want unset") + } + } + if got, want := tc.lastPacket.keyNumber, 1; got != want { + t.Errorf("after key update, conn sent packet with key %v, want %v", got, want) + } + if !tc.lastPacket.keyPhaseBit { + t.Errorf("after key update, keyPhaseBit is unset, want set") + } + tc.wantFrame("first packet after a key update is always ack-eliciting", + packetType1RTT, debugFramePing{}) + tc.wantIdle("no more frames") + + // Peer sends another packet using the prior phase keys. + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.advanceToTimer() + tc.wantFrameType("conn ACKs packet in prior phase", + packetType1RTT, debugFrameAck{}) + tc.wantIdle("packet is not ack-eliciting") + if got, want := tc.lastPacket.keyNumber, 1; got != want { + t.Errorf("after key update, conn sent packet with key %v, want %v", got, want) + } + + // Peer updates to the next phase. + tc.sendKeyNumber = 1 + tc.sendKeyPhaseBit = true + tc.writeAckForAll() + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.advanceToTimer() + tc.wantFrameType("conn ACKs packet in current phase", + packetType1RTT, debugFrameAck{}) + tc.wantIdle("packet is not ack-eliciting") + if got, want := tc.lastPacket.keyNumber, 1; got != want { + t.Errorf("after key update, conn sent packet with key %v, want %v", got, want) + } + + // Peer initiates its own update. + tc.sendKeyNumber = 2 + tc.sendKeyPhaseBit = false + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.advanceToTimer() + tc.wantFrameType("conn ACKs packet in current phase", + packetType1RTT, debugFrameAck{}) + tc.wantFrame("first packet after a key update is always ack-eliciting", + packetType1RTT, debugFramePing{}) + if got, want := tc.lastPacket.keyNumber, 2; got != want { + t.Errorf("after peer key update, conn sent packet with key %v, want %v", got, want) + } + if tc.lastPacket.keyPhaseBit { + t.Errorf("after peer key update, keyPhaseBit is unset, want set") + } +} diff --git a/internal/quic/listener.go b/internal/quic/listener.go new file mode 100644 index 000000000..96b1e4593 --- /dev/null +++ b/internal/quic/listener.go @@ -0,0 +1,322 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "errors" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" +) + +// A Listener listens for QUIC traffic on a network address. +// It can accept inbound connections or create outbound ones. +// +// Multiple goroutines may invoke methods on a Listener simultaneously. +type Listener struct { + config *Config + udpConn udpConn + testHooks connTestHooks + + acceptQueue queue[*Conn] // new inbound connections + + connsMu sync.Mutex + conns map[*Conn]struct{} + closing bool // set when Close is called + closec chan struct{} // closed when the listen loop exits + + // The datagram receive loop keeps a mapping of connection IDs to conns. + // When a conn's connection IDs change, we add it to connIDUpdates and set + // connIDUpdateNeeded, and the receive loop updates its map. + connIDUpdateMu sync.Mutex + connIDUpdateNeeded atomic.Bool + connIDUpdates []connIDUpdate +} + +// A udpConn is a UDP connection. +// It is implemented by net.UDPConn. +type udpConn interface { + Close() error + LocalAddr() net.Addr + ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) + WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) +} + +type connIDUpdate struct { + conn *Conn + retired bool + cid []byte +} + +// Listen listens on a local network address. +// The configuration config must be non-nil. +func Listen(network, address string, config *Config) (*Listener, error) { + if config.TLSConfig == nil { + return nil, errors.New("TLSConfig is not set") + } + a, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + udpConn, err := net.ListenUDP(network, a) + if err != nil { + return nil, err + } + return newListener(udpConn, config, nil), nil +} + +func newListener(udpConn udpConn, config *Config, hooks connTestHooks) *Listener { + l := &Listener{ + config: config, + udpConn: udpConn, + testHooks: hooks, + conns: make(map[*Conn]struct{}), + acceptQueue: newQueue[*Conn](), + closec: make(chan struct{}), + } + go l.listen() + return l +} + +// LocalAddr returns the local network address. +func (l *Listener) LocalAddr() netip.AddrPort { + a, _ := l.udpConn.LocalAddr().(*net.UDPAddr) + return a.AddrPort() +} + +// Close closes the listener. +// Any blocked operations on the Listener or associated Conns and Stream will be unblocked +// and return errors. +// +// Close aborts every open connection. +// Data in stream read and write buffers is discarded. +// It waits for the peers of any open connection to acknowledge the connection has been closed. +func (l *Listener) Close(ctx context.Context) error { + l.acceptQueue.close(errors.New("listener closed")) + l.connsMu.Lock() + if !l.closing { + l.closing = true + for c := range l.conns { + c.Abort(errors.New("listener closed")) + } + if len(l.conns) == 0 { + l.udpConn.Close() + } + } + l.connsMu.Unlock() + select { + case <-l.closec: + case <-ctx.Done(): + l.connsMu.Lock() + for c := range l.conns { + c.exit() + } + l.connsMu.Unlock() + return ctx.Err() + } + return nil +} + +// Accept waits for and returns the next connection to the listener. +func (l *Listener) Accept(ctx context.Context) (*Conn, error) { + return l.acceptQueue.get(ctx, nil) +} + +// Dial creates and returns a connection to a network address. +func (l *Listener) Dial(ctx context.Context, network, address string) (*Conn, error) { + u, err := net.ResolveUDPAddr(network, address) + if err != nil { + return nil, err + } + addr := u.AddrPort() + addr = netip.AddrPortFrom(addr.Addr().Unmap(), addr.Port()) + c, err := l.newConn(time.Now(), clientSide, nil, addr) + if err != nil { + return nil, err + } + if err := c.waitReady(ctx); err != nil { + c.Abort(nil) + return nil, err + } + return c, nil +} + +func (l *Listener) newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort) (*Conn, error) { + l.connsMu.Lock() + defer l.connsMu.Unlock() + if l.closing { + return nil, errors.New("listener closed") + } + c, err := newConn(now, side, initialConnID, peerAddr, l.config, l, l.testHooks) + if err != nil { + return nil, err + } + l.conns[c] = struct{}{} + return c, nil +} + +// serverConnEstablished is called by a conn when the handshake completes +// for an inbound (serverSide) connection. +func (l *Listener) serverConnEstablished(c *Conn) { + l.acceptQueue.put(c) +} + +// connDrained is called by a conn when it leaves the draining state, +// either when the peer acknowledges connection closure or the drain timeout expires. +func (l *Listener) connDrained(c *Conn) { + l.connsMu.Lock() + defer l.connsMu.Unlock() + delete(l.conns, c) + if l.closing && len(l.conns) == 0 { + l.udpConn.Close() + } +} + +// connIDsChanged is called by a conn when its connection IDs change. +func (l *Listener) connIDsChanged(c *Conn, retired bool, cids []connID) { + l.connIDUpdateMu.Lock() + defer l.connIDUpdateMu.Unlock() + for _, cid := range cids { + l.connIDUpdates = append(l.connIDUpdates, connIDUpdate{ + conn: c, + retired: retired, + cid: cid.cid, + }) + } + l.connIDUpdateNeeded.Store(true) +} + +// updateConnIDs is called by the datagram receive loop to update its connection ID map. +func (l *Listener) updateConnIDs(conns map[string]*Conn) { + l.connIDUpdateMu.Lock() + defer l.connIDUpdateMu.Unlock() + for i, u := range l.connIDUpdates { + if u.retired { + delete(conns, string(u.cid)) + } else { + conns[string(u.cid)] = u.conn + } + l.connIDUpdates[i] = connIDUpdate{} // drop refs + } + l.connIDUpdates = l.connIDUpdates[:0] + l.connIDUpdateNeeded.Store(false) +} + +func (l *Listener) listen() { + defer close(l.closec) + conns := map[string]*Conn{} + for { + m := newDatagram() + // TODO: Read and process the ECN (explicit congestion notification) field. + // https://tools.ietf.org/html/draft-ietf-quic-transport-32#section-13.4 + n, _, _, addr, err := l.udpConn.ReadMsgUDPAddrPort(m.b, nil) + if err != nil { + // The user has probably closed the listener. + // We currently don't surface errors from other causes; + // we could check to see if the listener has been closed and + // record the unexpected error if it has not. + return + } + if n == 0 { + continue + } + if l.connIDUpdateNeeded.Load() { + l.updateConnIDs(conns) + } + m.addr = addr + m.b = m.b[:n] + l.handleDatagram(m, conns) + } +} + +func (l *Listener) handleDatagram(m *datagram, conns map[string]*Conn) { + dstConnID, ok := dstConnIDForDatagram(m.b) + if !ok { + m.recycle() + return + } + c := conns[string(dstConnID)] + if c == nil { + // TODO: Move this branch into a separate goroutine to avoid blocking + // the listener while processing packets. + l.handleUnknownDestinationDatagram(m) + return + } + + // TODO: This can block the listener while waiting for the conn to accept the dgram. + // Think about buffering between the receive loop and the conn. + c.sendMsg(m) +} + +func (l *Listener) handleUnknownDestinationDatagram(m *datagram) { + defer func() { + if m != nil { + m.recycle() + } + }() + if len(m.b) < minimumClientInitialDatagramSize { + return + } + p, ok := parseGenericLongHeaderPacket(m.b) + if !ok { + // Not a long header packet, or not parseable. + // Short header (1-RTT) packets don't contain enough information + // to do anything useful with if we don't recognize the + // connection ID. + return + } + + switch p.version { + case quicVersion1: + case 0: + // Version Negotiation for an unknown connection. + return + default: + // Unknown version. + l.sendVersionNegotiation(p, m.addr) + return + } + if getPacketType(m.b) != packetTypeInitial { + // This packet isn't trying to create a new connection. + // It might be associated with some connection we've lost state for. + // TODO: Send a stateless reset when appropriate. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-10.3 + return + } + var now time.Time + if l.testHooks != nil { + now = l.testHooks.timeNow() + } else { + now = time.Now() + } + var err error + c, err := l.newConn(now, serverSide, p.dstConnID, m.addr) + if err != nil { + // The accept queue is probably full. + // We could send a CONNECTION_CLOSE to the peer to reject the connection. + // Currently, we just drop the datagram. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.2.2-5 + return + } + c.sendMsg(m) + m = nil // don't recycle, sendMsg takes ownership +} + +func (l *Listener) sendVersionNegotiation(p genericLongPacket, addr netip.AddrPort) { + m := newDatagram() + m.b = appendVersionNegotiation(m.b[:0], p.srcConnID, p.dstConnID, quicVersion1) + l.sendDatagram(m.b, addr) + m.recycle() +} + +func (l *Listener) sendDatagram(p []byte, addr netip.AddrPort) error { + _, err := l.udpConn.WriteToUDPAddrPort(p, addr) + return err +} diff --git a/internal/quic/listener_test.go b/internal/quic/listener_test.go new file mode 100644 index 000000000..9d0f314ec --- /dev/null +++ b/internal/quic/listener_test.go @@ -0,0 +1,163 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "bytes" + "context" + "io" + "net" + "net/netip" + "testing" +) + +func TestConnect(t *testing.T) { + newLocalConnPair(t, &Config{}, &Config{}) +} + +func TestStreamTransfer(t *testing.T) { + ctx := context.Background() + cli, srv := newLocalConnPair(t, &Config{}, &Config{}) + data := makeTestData(1 << 20) + + srvdone := make(chan struct{}) + go func() { + defer close(srvdone) + s, err := srv.AcceptStream(ctx) + if err != nil { + t.Errorf("AcceptStream: %v", err) + return + } + b, err := io.ReadAll(s) + if err != nil { + t.Errorf("io.ReadAll(s): %v", err) + return + } + if !bytes.Equal(b, data) { + t.Errorf("read data mismatch (got %v bytes, want %v", len(b), len(data)) + } + if err := s.Close(); err != nil { + t.Errorf("s.Close() = %v", err) + } + }() + + s, err := cli.NewStream(ctx) + if err != nil { + t.Fatalf("NewStream: %v", err) + } + n, err := io.Copy(s, bytes.NewBuffer(data)) + if n != int64(len(data)) || err != nil { + t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data)) + } + if err := s.Close(); err != nil { + t.Fatalf("s.Close() = %v", err) + } +} + +func newLocalConnPair(t *testing.T, conf1, conf2 *Config) (clientConn, serverConn *Conn) { + t.Helper() + ctx := context.Background() + l1 := newLocalListener(t, serverSide, conf1) + l2 := newLocalListener(t, clientSide, conf2) + c2, err := l2.Dial(ctx, "udp", l1.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + c1, err := l1.Accept(ctx) + if err != nil { + t.Fatal(err) + } + return c2, c1 +} + +func newLocalListener(t *testing.T, side connSide, conf *Config) *Listener { + t.Helper() + if conf.TLSConfig == nil { + conf.TLSConfig = newTestTLSConfig(side) + } + l, err := Listen("udp", "127.0.0.1:0", conf) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + l.Close(context.Background()) + }) + return l +} + +type testListener struct { + t *testing.T + l *Listener + recvc chan *datagram + idlec chan struct{} + sentDatagrams [][]byte +} + +func newTestListener(t *testing.T, config *Config, testHooks connTestHooks) *testListener { + tl := &testListener{ + t: t, + recvc: make(chan *datagram), + idlec: make(chan struct{}), + } + tl.l = newListener((*testListenerUDPConn)(tl), config, testHooks) + t.Cleanup(tl.cleanup) + return tl +} + +func (tl *testListener) cleanup() { + tl.l.Close(canceledContext()) +} + +func (tl *testListener) wait() { + tl.idlec <- struct{}{} +} + +func (tl *testListener) write(d *datagram) { + tl.recvc <- d + tl.wait() +} + +func (tl *testListener) read() []byte { + tl.wait() + if len(tl.sentDatagrams) == 0 { + return nil + } + d := tl.sentDatagrams[0] + tl.sentDatagrams = tl.sentDatagrams[1:] + return d +} + +// testListenerUDPConn implements UDPConn. +type testListenerUDPConn testListener + +func (tl *testListenerUDPConn) Close() error { + close(tl.recvc) + return nil +} + +func (tl *testListenerUDPConn) LocalAddr() net.Addr { + return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:443")) +} + +func (tl *testListenerUDPConn) ReadMsgUDPAddrPort(b, control []byte) (n, controln, flags int, _ netip.AddrPort, _ error) { + for { + select { + case d, ok := <-tl.recvc: + if !ok { + return 0, 0, 0, netip.AddrPort{}, io.EOF + } + n = copy(b, d.b) + return n, 0, 0, d.addr, nil + case <-tl.idlec: + } + } +} + +func (tl *testListenerUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + tl.sentDatagrams = append(tl.sentDatagrams, append([]byte(nil), b...)) + return len(b), nil +} diff --git a/internal/quic/packet.go b/internal/quic/packet.go index 93a9102e8..7d69f96d2 100644 --- a/internal/quic/packet.go +++ b/internal/quic/packet.go @@ -6,6 +6,11 @@ package quic +import ( + "encoding/binary" + "fmt" +) + // packetType is a QUIC packet type. // https://www.rfc-editor.org/rfc/rfc9000.html#section-17 type packetType byte @@ -20,12 +25,30 @@ const ( packetTypeVersionNegotiation ) +func (p packetType) String() string { + switch p { + case packetTypeInitial: + return "Initial" + case packetType0RTT: + return "0-RTT" + case packetTypeHandshake: + return "Handshake" + case packetTypeRetry: + return "Retry" + case packetType1RTT: + return "1-RTT" + } + return fmt.Sprintf("unknown packet type %v", byte(p)) +} + // Bits set in the first byte of a packet. const ( - headerFormLong = 0x80 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.2.1 - headerFormShort = 0x00 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.3.1-4.2.1 - fixedBit = 0x40 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.4.1 - reservedBits = 0x0c // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1 + headerFormLong = 0x80 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.2.1 + headerFormShort = 0x00 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.3.1-4.2.1 + fixedBit = 0x40 // https://www.rfc-editor.org/rfc/rfc9000.html#section-17.2-3.4.1 + reservedLongBits = 0x0c // https://www.rfc-editor.org/rfc/rfc9000#section-17.2-8.2.1 + reserved1RTTBits = 0x18 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.8.1 + keyPhaseBit = 0x04 // https://www.rfc-editor.org/rfc/rfc9000#section-17.3.1-4.10.1 ) // Long Packet Type bits. @@ -137,15 +160,41 @@ func dstConnIDForDatagram(pkt []byte) (id []byte, ok bool) { return b[:n], true } +// parseVersionNegotiation parses a Version Negotiation packet. +// The returned versions is a slice of big-endian uint32s. +// It returns (nil, nil, nil) for an invalid packet. +func parseVersionNegotiation(pkt []byte) (dstConnID, srcConnID, versions []byte) { + p, ok := parseGenericLongHeaderPacket(pkt) + if !ok { + return nil, nil, nil + } + if len(p.data)%4 != 0 { + return nil, nil, nil + } + return p.dstConnID, p.srcConnID, p.data +} + +// appendVersionNegotiation appends a Version Negotiation packet to pkt, +// returning the result. +func appendVersionNegotiation(pkt, dstConnID, srcConnID []byte, versions ...uint32) []byte { + pkt = append(pkt, headerFormLong|fixedBit) // header byte + pkt = append(pkt, 0, 0, 0, 0) // Version (0 for Version Negotiation) + pkt = appendUint8Bytes(pkt, dstConnID) // Destination Connection ID + pkt = appendUint8Bytes(pkt, srcConnID) // Source Connection ID + for _, v := range versions { + pkt = binary.BigEndian.AppendUint32(pkt, v) // Supported Version + } + return pkt +} + // A longPacket is a long header packet. type longPacket struct { - ptype packetType - reservedBits uint8 - version uint32 - num packetNumber - dstConnID []byte - srcConnID []byte - payload []byte + ptype packetType + version uint32 + num packetNumber + dstConnID []byte + srcConnID []byte + payload []byte // The extra data depends on the packet type: // Initial: Token. @@ -155,7 +204,45 @@ type longPacket struct { // A shortPacket is a short header (1-RTT) packet. type shortPacket struct { - reservedBits uint8 - num packetNumber - payload []byte + num packetNumber + payload []byte +} + +// A genericLongPacket is a long header packet of an arbitrary QUIC version. +// https://www.rfc-editor.org/rfc/rfc8999#section-5.1 +type genericLongPacket struct { + version uint32 + dstConnID []byte + srcConnID []byte + data []byte +} + +func parseGenericLongHeaderPacket(b []byte) (p genericLongPacket, ok bool) { + if len(b) < 5 || !isLongHeader(b[0]) { + return genericLongPacket{}, false + } + b = b[1:] + // Version (32), + var n int + p.version, n = consumeUint32(b) + if n < 0 { + return genericLongPacket{}, false + } + b = b[n:] + // Destination Connection ID Length (8), + // Destination Connection ID (0..2048), + p.dstConnID, n = consumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 2048/8 { + return genericLongPacket{}, false + } + b = b[n:] + // Source Connection ID Length (8), + // Source Connection ID (0..2048), + p.srcConnID, n = consumeUint8Bytes(b) + if n < 0 || len(p.dstConnID) > 2048/8 { + return genericLongPacket{}, false + } + b = b[n:] + p.data = b + return p, true } diff --git a/internal/quic/packet_codec_test.go b/internal/quic/packet_codec_test.go index 3503d2431..7b01bb00d 100644 --- a/internal/quic/packet_codec_test.go +++ b/internal/quic/packet_codec_test.go @@ -17,7 +17,7 @@ func TestParseLongHeaderPacket(t *testing.T) { // Example Initial packet from: // https://www.rfc-editor.org/rfc/rfc9001.html#section-a.3 cid := unhex(`8394c8f03e515708`) - _, initialServerKeys := initialKeys(cid) + initialServerKeys := initialKeys(cid, clientSide).r pkt := unhex(` cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 @@ -65,20 +65,21 @@ func TestParseLongHeaderPacket(t *testing.T) { } // Parse with the wrong keys. - _, invalidKeys := initialKeys([]byte{}) + invalidKeys := initialKeys([]byte{}, clientSide).w if _, n := parseLongHeaderPacket(pkt, invalidKeys, 0); n != -1 { t.Fatalf("parse long header packet with wrong keys: n=%v, want -1", n) } } func TestRoundtripEncodeLongPacket(t *testing.T) { - aes128Keys, _ := newKeys(tls.TLS_AES_128_GCM_SHA256, []byte("secret")) - aes256Keys, _ := newKeys(tls.TLS_AES_256_GCM_SHA384, []byte("secret")) - chachaKeys, _ := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret")) + var aes128Keys, aes256Keys, chachaKeys fixedKeys + aes128Keys.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret")) + aes256Keys.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret")) + chachaKeys.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret")) for _, test := range []struct { desc string p longPacket - k keys + k fixedKeys }{{ desc: "Initial, 1-byte number, AES128", p: longPacket{ @@ -145,9 +146,16 @@ func TestRoundtripEncodeLongPacket(t *testing.T) { } func TestRoundtripEncodeShortPacket(t *testing.T) { - aes128Keys, _ := newKeys(tls.TLS_AES_128_GCM_SHA256, []byte("secret")) - aes256Keys, _ := newKeys(tls.TLS_AES_256_GCM_SHA384, []byte("secret")) - chachaKeys, _ := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret")) + var aes128Keys, aes256Keys, chachaKeys updatingKeyPair + aes128Keys.r.init(tls.TLS_AES_128_GCM_SHA256, []byte("secret")) + aes256Keys.r.init(tls.TLS_AES_256_GCM_SHA384, []byte("secret")) + chachaKeys.r.init(tls.TLS_CHACHA20_POLY1305_SHA256, []byte("secret")) + aes128Keys.w = aes128Keys.r + aes256Keys.w = aes256Keys.r + chachaKeys.w = chachaKeys.r + aes128Keys.updateAfter = maxPacketNumber + aes256Keys.updateAfter = maxPacketNumber + chachaKeys.updateAfter = maxPacketNumber connID := make([]byte, connIDLen) for i := range connID { connID[i] = byte(i) @@ -156,7 +164,7 @@ func TestRoundtripEncodeShortPacket(t *testing.T) { desc string num packetNumber payload []byte - k keys + k updatingKeyPair }{{ desc: "1-byte number, AES128", num: 0, // 1-byte encoding, @@ -183,11 +191,11 @@ func TestRoundtripEncodeShortPacket(t *testing.T) { w.reset(1200) w.start1RTTPacket(test.num, 0, connID) w.b = append(w.b, test.payload...) - w.finish1RTTPacket(test.num, 0, connID, test.k) + w.finish1RTTPacket(test.num, 0, connID, &test.k) pkt := w.datagram() - p, n := parse1RTTPacket(pkt, test.k, connIDLen, 0) - if n != len(pkt) { - t.Errorf("parse1RTTPacket: n=%v, want %v", n, len(pkt)) + p, err := parse1RTTPacket(pkt, &test.k, connIDLen, 0) + if err != nil { + t.Errorf("parse1RTTPacket: err=%v, want nil", err) } if p.num != test.num || !bytes.Equal(p.payload, test.payload) { t.Errorf("Round-trip encode/decode did not preserve packet.\nsent: num=%v, payload={%x}\ngot: num=%v, payload={%x}", test.num, test.payload, p.num, p.payload) @@ -700,7 +708,7 @@ func TestFrameDecodeErrors(t *testing.T) { func FuzzParseLongHeaderPacket(f *testing.F) { cid := unhex(`0000000000000000`) - _, initialServerKeys := initialKeys(cid) + initialServerKeys := initialKeys(cid, clientSide).r f.Fuzz(func(t *testing.T, in []byte) { parseLongHeaderPacket(in, initialServerKeys, 0) }) diff --git a/internal/quic/packet_parser.go b/internal/quic/packet_parser.go index ca5b37b2b..ce0433902 100644 --- a/internal/quic/packet_parser.go +++ b/internal/quic/packet_parser.go @@ -18,7 +18,7 @@ package quic // and its length in bytes. // // It returns an empty packet and -1 if the packet could not be parsed. -func parseLongHeaderPacket(pkt []byte, k keys, pnumMax packetNumber) (p longPacket, n int) { +func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p longPacket, n int) { if len(pkt) < 5 || !isLongHeader(pkt[0]) { return longPacket{}, -1 } @@ -97,9 +97,6 @@ func parseLongHeaderPacket(pkt []byte, k keys, pnumMax packetNumber) (p longPack if err != nil { return longPacket{}, -1 } - // Reserved bits should always be zero, but this is handled - // as a protocol-level violation by the caller rather than a parse error. - p.reservedBits = pkt[0] & reservedBits } return p, len(pkt) } @@ -146,16 +143,14 @@ func skipLongHeaderPacket(pkt []byte) int { // // On input, pkt contains a short header packet, k the decryption keys for the packet, // and pnumMax the largest packet number seen in the number space of this packet. -func parse1RTTPacket(pkt []byte, k keys, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, n int) { - var err error - p.payload, p.num, err = k.unprotect(pkt, 1+dstConnIDLen, pnumMax) +func parse1RTTPacket(pkt []byte, k *updatingKeyPair, dstConnIDLen int, pnumMax packetNumber) (p shortPacket, err error) { + pay, pnum, err := k.unprotect(pkt, 1+dstConnIDLen, pnumMax) if err != nil { - return shortPacket{}, -1 + return shortPacket{}, err } - // Reserved bits should always be zero, but this is handled - // as a protocol-level violation by the caller rather than a parse error. - p.reservedBits = pkt[0] & reservedBits - return p, len(pkt) + p.num = pnum + p.payload = pay + return p, nil } // Consume functions return n=-1 on conditions which result in FRAME_ENCODING_ERROR, diff --git a/internal/quic/packet_protection.go b/internal/quic/packet_protection.go index 18470536f..7b141ac49 100644 --- a/internal/quic/packet_protection.go +++ b/internal/quic/packet_protection.go @@ -13,7 +13,6 @@ import ( "crypto/sha256" "crypto/tls" "errors" - "fmt" "hash" "golang.org/x/crypto/chacha20" @@ -24,135 +23,183 @@ import ( var errInvalidPacket = errors.New("quic: invalid packet") -// keys holds the cryptographic material used to protect packets -// at an encryption level and direction. (e.g., Initial client keys.) -// -// keys are not safe for concurrent use. -type keys struct { - // AEAD function used for packet protection. - aead cipher.AEAD +// headerProtectionSampleSize is the size of the ciphertext sample used for header protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.2 +const headerProtectionSampleSize = 16 - // The header_protection function as defined in: - // https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1 - // - // This function takes a sample of the packet ciphertext - // and returns a 5-byte mask which will be applied to the - // protected portions of the packet header. - headerProtection func(sample []byte) (mask [5]byte) +// aeadOverhead is the difference in size between the AEAD output and input. +// All cipher suites defined for use with QUIC have 16 bytes of overhead. +const aeadOverhead = 16 - // IV used to construct the AEAD nonce. - iv []byte +// A headerKey applies or removes header protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4 +type headerKey struct { + hp headerProtection } -// newKeys creates keys for a given cipher suite and secret. -// -// It returns an error if the suite is unknown. -func newKeys(suite uint16, secret []byte) (keys, error) { +func (k headerKey) isSet() bool { + return k.hp != nil +} + +func (k *headerKey) init(suite uint16, secret []byte) { + h, keySize := hashForSuite(suite) + hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keySize) switch suite { - case tls.TLS_AES_128_GCM_SHA256: - return newAESKeys(secret, crypto.SHA256, 128/8), nil - case tls.TLS_AES_256_GCM_SHA384: - return newAESKeys(secret, crypto.SHA384, 256/8), nil + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + c, err := aes.NewCipher(hpKey) + if err != nil { + panic(err) + } + k.hp = &aesHeaderProtection{cipher: c} case tls.TLS_CHACHA20_POLY1305_SHA256: - return newChaCha20Keys(secret), nil + k.hp = chaCha20HeaderProtection{hpKey} + default: + panic("BUG: unknown cipher suite") } - return keys{}, fmt.Errorf("unknown cipher suite %x", suite) } -func newAESKeys(secret []byte, h crypto.Hash, keyBytes int) keys { - // https://www.rfc-editor.org/rfc/rfc9001#section-5.1 - key := hkdfExpandLabel(h.New, secret, "quic key", nil, keyBytes) - c, err := aes.NewCipher(key) - if err != nil { - panic(err) +// protect applies header protection. +// pnumOff is the offset of the packet number in the packet. +func (k headerKey) protect(hdr []byte, pnumOff int) { + // Apply header protection. + pnumSize := int(hdr[0]&0x03) + 1 + sample := hdr[pnumOff+4:][:headerProtectionSampleSize] + mask := k.hp.headerProtection(sample) + if isLongHeader(hdr[0]) { + hdr[0] ^= mask[0] & 0x0f + } else { + hdr[0] ^= mask[0] & 0x1f } - aead, err := cipher.NewGCM(c) - if err != nil { - panic(err) + for i := 0; i < pnumSize; i++ { + hdr[pnumOff+i] ^= mask[1+i] } - iv := hkdfExpandLabel(h.New, secret, "quic iv", nil, aead.NonceSize()) - // https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3 - hpKey := hkdfExpandLabel(h.New, secret, "quic hp", nil, keyBytes) - hp, err := aes.NewCipher(hpKey) - if err != nil { - panic(err) +} + +// unprotect removes header protection. +// pnumOff is the offset of the packet number in the packet. +// pnumMax is the largest packet number seen in the number space of this packet. +func (k headerKey) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (hdr, pay []byte, pnum packetNumber, _ error) { + if len(pkt) < pnumOff+4+headerProtectionSampleSize { + return nil, nil, 0, errInvalidPacket } - var scratch [aes.BlockSize]byte - headerProtection := func(sample []byte) (mask [5]byte) { - hp.Encrypt(scratch[:], sample) - copy(mask[:], scratch[:]) - return mask + numpay := pkt[pnumOff:] + sample := numpay[4:][:headerProtectionSampleSize] + mask := k.hp.headerProtection(sample) + if isLongHeader(pkt[0]) { + pkt[0] ^= mask[0] & 0x0f + } else { + pkt[0] ^= mask[0] & 0x1f } - return keys{ - aead: aead, - iv: iv, - headerProtection: headerProtection, + pnumLen := int(pkt[0]&0x03) + 1 + pnum = packetNumber(0) + for i := 0; i < pnumLen; i++ { + numpay[i] ^= mask[1+i] + pnum = (pnum << 8) | packetNumber(numpay[i]) } + pnum = decodePacketNumber(pnumMax, pnum, pnumLen) + hdr = pkt[:pnumOff+pnumLen] + pay = numpay[pnumLen:] + return hdr, pay, pnum, nil } -func newChaCha20Keys(secret []byte) keys { - // https://www.rfc-editor.org/rfc/rfc9001#section-5.1 - key := hkdfExpandLabel(sha256.New, secret, "quic key", nil, chacha20poly1305.KeySize) - aead, err := chacha20poly1305.New(key) +// headerProtection is the header_protection function as defined in: +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.1 +// +// This function takes a sample of the packet ciphertext +// and returns a 5-byte mask which will be applied to the +// protected portions of the packet header. +type headerProtection interface { + headerProtection(sample []byte) (mask [5]byte) +} + +// AES-based header protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.3 +type aesHeaderProtection struct { + cipher cipher.Block + scratch [aes.BlockSize]byte +} + +func (hp *aesHeaderProtection) headerProtection(sample []byte) (mask [5]byte) { + hp.cipher.Encrypt(hp.scratch[:], sample) + copy(mask[:], hp.scratch[:]) + return mask +} + +// ChaCha20-based header protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4 +type chaCha20HeaderProtection struct { + key []byte +} + +func (hp chaCha20HeaderProtection) headerProtection(sample []byte) (mask [5]byte) { + counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0]) + nonce := sample[4:16] + c, err := chacha20.NewUnauthenticatedCipher(hp.key, nonce) if err != nil { panic(err) } - iv := hkdfExpandLabel(sha256.New, secret, "quic iv", nil, aead.NonceSize()) - // https://www.rfc-editor.org/rfc/rfc9001#section-5.4.4 - hpKey := hkdfExpandLabel(sha256.New, secret, "quic hp", nil, chacha20.KeySize) - headerProtection := func(sample []byte) [5]byte { - counter := uint32(sample[3])<<24 | uint32(sample[2])<<16 | uint32(sample[1])<<8 | uint32(sample[0]) - nonce := sample[4:16] - c, err := chacha20.NewUnauthenticatedCipher(hpKey, nonce) - if err != nil { - panic(err) - } - c.SetCounter(counter) - var mask [5]byte - c.XORKeyStream(mask[:], mask[:]) - return mask - } - return keys{ - aead: aead, - iv: iv, - headerProtection: headerProtection, - } + c.SetCounter(counter) + c.XORKeyStream(mask[:], mask[:]) + return mask } -// https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2 -var initialSalt = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} +// A packetKey applies or removes packet protection. +// https://www.rfc-editor.org/rfc/rfc9001#section-5.1 +type packetKey struct { + aead cipher.AEAD // AEAD function used for packet protection. + iv []byte // IV used to construct the AEAD nonce. +} -// initialKeys returns the keys used to protect Initial packets. -// -// The Initial packet keys are derived from the Destination Connection ID -// field in the client's first Initial packet. -// -// https://www.rfc-editor.org/rfc/rfc9001#section-5.2 -func initialKeys(cid []byte) (clientKeys, serverKeys keys) { - initialSecret := hkdf.Extract(sha256.New, cid, initialSalt) - clientInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size) - clientKeys, err := newKeys(tls.TLS_AES_128_GCM_SHA256, clientInitialSecret) +func (k *packetKey) init(suite uint16, secret []byte) { + // https://www.rfc-editor.org/rfc/rfc9001#section-5.1 + h, keySize := hashForSuite(suite) + key := hkdfExpandLabel(h.New, secret, "quic key", nil, keySize) + switch suite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + k.aead = newAESAEAD(key) + case tls.TLS_CHACHA20_POLY1305_SHA256: + k.aead = newChaCha20AEAD(key) + default: + panic("BUG: unknown cipher suite") + } + k.iv = hkdfExpandLabel(h.New, secret, "quic iv", nil, k.aead.NonceSize()) +} + +func newAESAEAD(key []byte) cipher.AEAD { + c, err := aes.NewCipher(key) if err != nil { panic(err) } - - serverInitialSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size) - serverKeys, err = newKeys(tls.TLS_AES_128_GCM_SHA256, serverInitialSecret) + aead, err := cipher.NewGCM(c) if err != nil { panic(err) } + return aead +} - return clientKeys, serverKeys +func newChaCha20AEAD(key []byte) cipher.AEAD { + var err error + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + return aead } -const headerProtectionSampleSize = 16 +func (k packetKey) protect(hdr, pay []byte, pnum packetNumber) []byte { + k.xorIV(pnum) + defer k.xorIV(pnum) + return k.aead.Seal(hdr, k.iv, pay, hdr) +} -// aeadOverhead is the difference in size between the AEAD output and input. -// All cipher suites defined for use with QUIC have 16 bytes of overhead. -const aeadOverhead = 16 +func (k packetKey) unprotect(hdr, pay []byte, pnum packetNumber) (dec []byte, err error) { + k.xorIV(pnum) + defer k.xorIV(pnum) + return k.aead.Open(pay[:0], k.iv, pay, hdr) +} // xorIV xors the packet protection IV with the packet number. -func (k keys) xorIV(pnum packetNumber) { +func (k packetKey) xorIV(pnum packetNumber) { k.iv[len(k.iv)-8] ^= uint8(pnum >> 56) k.iv[len(k.iv)-7] ^= uint8(pnum >> 48) k.iv[len(k.iv)-6] ^= uint8(pnum >> 40) @@ -163,17 +210,22 @@ func (k keys) xorIV(pnum packetNumber) { k.iv[len(k.iv)-1] ^= uint8(pnum) } -// isSet returns true if valid keys are available. -func (k keys) isSet() bool { - return k.aead != nil +// A fixedKeys is a header protection key and fixed packet protection key. +// The packet protection key is fixed (it does not update). +// +// Fixed keys are used for Initial and Handshake keys, which do not update. +type fixedKeys struct { + hdr headerKey + pkt packetKey } -// discard discards the keys (in the sense that we won't use them any more, -// not that the keys are securely erased). -// -// https://www.rfc-editor.org/rfc/rfc9001.html#section-4.9 -func (k *keys) discard() { - *k = keys{} +func (k *fixedKeys) init(suite uint16, secret []byte) { + k.hdr.init(suite, secret) + k.pkt.init(suite, secret) +} + +func (k fixedKeys) isSet() bool { + return k.hdr.hp != nil } // protect applies packet protection to a packet. @@ -184,25 +236,10 @@ func (k *keys) discard() { // // protect returns the result of appending the encrypted payload to hdr and // applying header protection. -func (k keys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte { - k.xorIV(pnum) - hdr = k.aead.Seal(hdr, k.iv, pay, hdr) - k.xorIV(pnum) - - // Apply header protection. - pnumSize := int(hdr[0]&0x03) + 1 - sample := hdr[pnumOff+4:][:headerProtectionSampleSize] - mask := k.headerProtection(sample) - if isLongHeader(hdr[0]) { - hdr[0] ^= mask[0] & 0x0f - } else { - hdr[0] ^= mask[0] & 0x1f - } - for i := 0; i < pnumSize; i++ { - hdr[pnumOff+i] ^= mask[1+i] - } - - return hdr +func (k fixedKeys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte { + pkt := k.pkt.protect(hdr, pay, pnum) + k.hdr.protect(pkt, pnumOff) + return pkt } // unprotect removes packet protection from a packet. @@ -213,38 +250,269 @@ func (k keys) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte { // // unprotect removes header protection from the header in pkt, and returns // the unprotected payload and packet number. -func (k keys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) { - if len(pkt) < pnumOff+4+headerProtectionSampleSize { - return nil, 0, errInvalidPacket +func (k fixedKeys) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, num packetNumber, err error) { + hdr, pay, pnum, err := k.hdr.unprotect(pkt, pnumOff, pnumMax) + if err != nil { + return nil, 0, err } - numpay := pkt[pnumOff:] - sample := numpay[4:][:headerProtectionSampleSize] - mask := k.headerProtection(sample) - if isLongHeader(pkt[0]) { - pkt[0] ^= mask[0] & 0x0f - } else { - pkt[0] ^= mask[0] & 0x1f + pay, err = k.pkt.unprotect(hdr, pay, pnum) + if err != nil { + return nil, 0, err } - pnumLen := int(pkt[0]&0x03) + 1 - pnum := packetNumber(0) - for i := 0; i < pnumLen; i++ { - numpay[i] ^= mask[1+i] - pnum = (pnum << 8) | packetNumber(numpay[i]) + return pay, pnum, nil +} + +// A fixedKeyPair is a read/write pair of fixed keys. +type fixedKeyPair struct { + r, w fixedKeys +} + +func (k *fixedKeyPair) discard() { + *k = fixedKeyPair{} +} + +func (k *fixedKeyPair) canRead() bool { + return k.r.isSet() +} + +func (k *fixedKeyPair) canWrite() bool { + return k.w.isSet() +} + +// An updatingKeys is a header protection key and updatable packet protection key. +// updatingKeys are used for 1-RTT keys, where the packet protection key changes +// over the lifetime of a connection. +// https://www.rfc-editor.org/rfc/rfc9001#section-6 +type updatingKeys struct { + suite uint16 + hdr headerKey + pkt [2]packetKey // current, next + nextSecret []byte // secret used to generate pkt[1] +} + +func (k *updatingKeys) init(suite uint16, secret []byte) { + k.suite = suite + k.hdr.init(suite, secret) + // Initialize pkt[1] with secret_0, and then call update to generate secret_1. + k.pkt[1].init(suite, secret) + k.nextSecret = secret + k.update() +} + +// update performs a key update. +// The current key in pkt[0] is discarded. +// The next key in pkt[1] becomes the current key. +// A new next key is generated in pkt[1]. +func (k *updatingKeys) update() { + k.nextSecret = updateSecret(k.suite, k.nextSecret) + k.pkt[0] = k.pkt[1] + k.pkt[1].init(k.suite, k.nextSecret) +} + +func updateSecret(suite uint16, secret []byte) (nextSecret []byte) { + h, _ := hashForSuite(suite) + return hkdfExpandLabel(h.New, secret, "quic ku", nil, len(secret)) +} + +// An updatingKeyPair is a read/write pair of updating keys. +// +// We keep two keys (current and next) in both read and write directions. +// When an incoming packet's phase matches the current phase bit, +// we unprotect it using the current keys; otherwise we use the next keys. +// +// When updating=false, outgoing packets are protected using the current phase. +// +// An update is initiated and updating is set to true when: +// - we decide to initiate a key update; or +// - we successfully unprotect a packet using the next keys, +// indicating the peer has initiated a key update. +// +// When updating=true, outgoing packets are protected using the next phase. +// We do not change the current phase bit or generate new keys yet. +// +// The update concludes when we receive an ACK frame for a packet sent +// with the next keys. At this time, we set updating to false, flip the +// phase bit, and update the keys. This permits us to handle up to 1-RTT +// of reordered packets before discarding the previous phase's keys after +// an update. +type updatingKeyPair struct { + phase uint8 // current key phase (r.pkt[0], w.pkt[0]) + updating bool + authFailures int64 // total packet unprotect failures + minSent packetNumber // min packet number sent since entering the updating state + minReceived packetNumber // min packet number received in the next phase + updateAfter packetNumber // packet number after which to initiate key update + r, w updatingKeys +} + +func (k *updatingKeyPair) init() { + // 1-RTT packets until the first key update. + // + // We perform the first key update early in the connection so a peer + // which does not support key updates will fail rapidly, + // rather than after the connection has been long established. + k.updateAfter = 1000 +} + +func (k *updatingKeyPair) canRead() bool { + return k.r.hdr.hp != nil +} + +func (k *updatingKeyPair) canWrite() bool { + return k.w.hdr.hp != nil +} + +// handleAckFor finishes a key update after receiving an ACK for a packet in the next phase. +func (k *updatingKeyPair) handleAckFor(pnum packetNumber) { + if k.updating && pnum >= k.minSent { + k.updating = false + k.phase ^= keyPhaseBit + k.r.update() + k.w.update() } - pnum = decodePacketNumber(pnumMax, pnum, pnumLen) +} - hdr := pkt[:pnumOff+pnumLen] - pay = numpay[pnumLen:] - k.xorIV(pnum) - pay, err = k.aead.Open(pay[:0], k.iv, pay, hdr) - k.xorIV(pnum) +// needAckEliciting reports whether we should send an ack-eliciting packet in the next phase. +// The first packet sent in a phase is ack-eliciting, since the peer must acknowledge a +// packet in the new phase for us to finish the update. +func (k *updatingKeyPair) needAckEliciting() bool { + return k.updating && k.minSent == maxPacketNumber +} + +// protect applies packet protection to a packet. +// Parameters and returns are as for fixedKeyPair.protect. +func (k *updatingKeyPair) protect(hdr, pay []byte, pnumOff int, pnum packetNumber) []byte { + var pkt []byte + if k.updating { + hdr[0] |= k.phase ^ keyPhaseBit + pkt = k.w.pkt[1].protect(hdr, pay, pnum) + k.minSent = min(pnum, k.minSent) + } else { + hdr[0] |= k.phase + pkt = k.w.pkt[0].protect(hdr, pay, pnum) + if pnum >= k.updateAfter { + // Initiate a key update, starting with the next packet we send. + // + // We do this after protecting the current packet + // to allow Conn.appendFrames to ensure that the first packet sent + // in the new phase is ack-eliciting. + k.updating = true + k.minSent = maxPacketNumber + k.minReceived = maxPacketNumber + // The lowest confidentiality limit for a supported AEAD is 2^23 packets. + // https://www.rfc-editor.org/rfc/rfc9001#section-6.6-5 + // + // Schedule our next update for half that. + k.updateAfter += (1 << 22) + } + } + k.w.hdr.protect(pkt, pnumOff) + return pkt +} + +// unprotect removes packet protection from a packet. +// Parameters and returns are as for fixedKeyPair.unprotect. +func (k *updatingKeyPair) unprotect(pkt []byte, pnumOff int, pnumMax packetNumber) (pay []byte, pnum packetNumber, err error) { + hdr, pay, pnum, err := k.r.hdr.unprotect(pkt, pnumOff, pnumMax) if err != nil { return nil, 0, err } - + // To avoid timing signals that might indicate the key phase bit is invalid, + // we always attempt to unprotect the packet with one key. + // + // If the key phase bit matches and the packet number doesn't come after + // the start of an in-progress update, use the current phase. + // Otherwise, use the next phase. + if hdr[0]&keyPhaseBit == k.phase && (!k.updating || pnum < k.minReceived) { + pay, err = k.r.pkt[0].unprotect(hdr, pay, pnum) + } else { + pay, err = k.r.pkt[1].unprotect(hdr, pay, pnum) + if err == nil { + if !k.updating { + // The peer has initiated a key update. + k.updating = true + k.minSent = maxPacketNumber + k.minReceived = pnum + } else { + k.minReceived = min(pnum, k.minReceived) + } + } + } + if err != nil { + k.authFailures++ + if k.authFailures >= aeadIntegrityLimit(k.r.suite) { + return nil, 0, localTransportError(errAEADLimitReached) + } + return nil, 0, err + } return pay, pnum, nil } +// aeadIntegrityLimit returns the integrity limit for an AEAD: +// The maximum number of received packets that may fail authentication +// before closing the connection. +// +// https://www.rfc-editor.org/rfc/rfc9001#section-6.6-4 +func aeadIntegrityLimit(suite uint16) int64 { + switch suite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + return 1 << 52 + case tls.TLS_CHACHA20_POLY1305_SHA256: + return 1 << 36 + default: + panic("BUG: unknown cipher suite") + } +} + +// https://www.rfc-editor.org/rfc/rfc9001#section-5.2-2 +var initialSalt = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} + +// initialKeys returns the keys used to protect Initial packets. +// +// The Initial packet keys are derived from the Destination Connection ID +// field in the client's first Initial packet. +// +// https://www.rfc-editor.org/rfc/rfc9001#section-5.2 +func initialKeys(cid []byte, side connSide) fixedKeyPair { + initialSecret := hkdf.Extract(sha256.New, cid, initialSalt) + var clientKeys fixedKeys + clientSecret := hkdfExpandLabel(sha256.New, initialSecret, "client in", nil, sha256.Size) + clientKeys.init(tls.TLS_AES_128_GCM_SHA256, clientSecret) + var serverKeys fixedKeys + serverSecret := hkdfExpandLabel(sha256.New, initialSecret, "server in", nil, sha256.Size) + serverKeys.init(tls.TLS_AES_128_GCM_SHA256, serverSecret) + if side == clientSide { + return fixedKeyPair{r: serverKeys, w: clientKeys} + } else { + return fixedKeyPair{w: serverKeys, r: clientKeys} + } +} + +// checkCipherSuite returns an error if suite is not a supported cipher suite. +func checkCipherSuite(suite uint16) error { + switch suite { + case tls.TLS_AES_128_GCM_SHA256: + case tls.TLS_AES_256_GCM_SHA384: + case tls.TLS_CHACHA20_POLY1305_SHA256: + default: + return errors.New("invalid cipher suite") + } + return nil +} + +func hashForSuite(suite uint16) (h crypto.Hash, keySize int) { + switch suite { + case tls.TLS_AES_128_GCM_SHA256: + return crypto.SHA256, 128 / 8 + case tls.TLS_AES_256_GCM_SHA384: + return crypto.SHA384, 256 / 8 + case tls.TLS_CHACHA20_POLY1305_SHA256: + return crypto.SHA256, chacha20.KeySize + default: + panic("BUG: unknown cipher suite") + } +} + // hdkfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. // // Copied from crypto/tls/key_schedule.go. diff --git a/internal/quic/packet_protection_test.go b/internal/quic/packet_protection_test.go index 6495360a3..1fe130731 100644 --- a/internal/quic/packet_protection_test.go +++ b/internal/quic/packet_protection_test.go @@ -16,10 +16,11 @@ func TestPacketProtection(t *testing.T) { // Test cases from: // https://www.rfc-editor.org/rfc/rfc9001#section-appendix.a cid := unhex(`8394c8f03e515708`) - initialClientKeys, initialServerKeys := initialKeys(cid) + k := initialKeys(cid, clientSide) + initialClientKeys, initialServerKeys := k.w, k.r for _, test := range []struct { name string - k keys + k fixedKeys pnum packetNumber hdr []byte pay []byte @@ -103,15 +104,13 @@ func TestPacketProtection(t *testing.T) { `), }, { name: "ChaCha20_Poly1305 Short Header", - k: func() keys { + k: func() fixedKeys { secret := unhex(` 9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b `) - k, err := newKeys(tls.TLS_CHACHA20_POLY1305_SHA256, secret) - if err != nil { - t.Fatal(err) - } + var k fixedKeys + k.init(tls.TLS_CHACHA20_POLY1305_SHA256, secret) return k }(), pnum: 654360564, diff --git a/internal/quic/packet_test.go b/internal/quic/packet_test.go index f3a8b7d57..58c584e16 100644 --- a/internal/quic/packet_test.go +++ b/internal/quic/packet_test.go @@ -8,28 +8,13 @@ package quic import ( "bytes" + "encoding/binary" "encoding/hex" - "fmt" + "reflect" "strings" "testing" ) -func (p packetType) String() string { - switch p { - case packetTypeInitial: - return "Initial" - case packetType0RTT: - return "0-RTT" - case packetTypeHandshake: - return "Handshake" - case packetTypeRetry: - return "Retry" - case packetType1RTT: - return "1-RTT" - } - return fmt.Sprintf("unknown packet type %v", byte(p)) -} - func TestPacketHeader(t *testing.T) { for _, test := range []struct { name string @@ -129,6 +114,124 @@ func TestPacketHeader(t *testing.T) { } } +func TestEncodeDecodeVersionNegotiation(t *testing.T) { + dstConnID := []byte("this is a very long destination connection id") + srcConnID := []byte("this is a very long source connection id") + versions := []uint32{1, 0xffffffff} + got := appendVersionNegotiation([]byte{}, dstConnID, srcConnID, versions...) + want := bytes.Join([][]byte{{ + 0b1100_0000, // header byte + 0, 0, 0, 0, // Version + byte(len(dstConnID)), + }, dstConnID, { + byte(len(srcConnID)), + }, srcConnID, { + 0x00, 0x00, 0x00, 0x01, + 0xff, 0xff, 0xff, 0xff, + }}, nil) + if !bytes.Equal(got, want) { + t.Fatalf("appendVersionNegotiation(nil, %x, %x, %v):\ngot %x\nwant %x", + dstConnID, srcConnID, versions, got, want) + } + gotDst, gotSrc, gotVersionBytes := parseVersionNegotiation(got) + if got, want := gotDst, dstConnID; !bytes.Equal(got, want) { + t.Errorf("parseVersionNegotiation: got dstConnID = %x, want %x", got, want) + } + if got, want := gotSrc, srcConnID; !bytes.Equal(got, want) { + t.Errorf("parseVersionNegotiation: got srcConnID = %x, want %x", got, want) + } + var gotVersions []uint32 + for len(gotVersionBytes) >= 4 { + gotVersions = append(gotVersions, binary.BigEndian.Uint32(gotVersionBytes)) + gotVersionBytes = gotVersionBytes[4:] + } + if got, want := gotVersions, versions; !reflect.DeepEqual(got, want) { + t.Errorf("parseVersionNegotiation: got versions = %v, want %v", got, want) + } +} + +func TestParseGenericLongHeaderPacket(t *testing.T) { + for _, test := range []struct { + name string + packet []byte + version uint32 + dstConnID []byte + srcConnID []byte + data []byte + }{{ + name: "long header packet", + packet: unhex(` + 80 01020304 04a1a2a3a4 05b1b2b3b4b5 c1 + `), + version: 0x01020304, + dstConnID: unhex(`a1a2a3a4`), + srcConnID: unhex(`b1b2b3b4b5`), + data: unhex(`c1`), + }, { + name: "zero everything", + packet: unhex(` + 80 00000000 00 00 + `), + version: 0, + dstConnID: []byte{}, + srcConnID: []byte{}, + data: []byte{}, + }} { + t.Run(test.name, func(t *testing.T) { + p, ok := parseGenericLongHeaderPacket(test.packet) + if !ok { + t.Fatalf("parseGenericLongHeaderPacket() = _, false; want true") + } + if got, want := p.version, test.version; got != want { + t.Errorf("version = %v, want %v", got, want) + } + if got, want := p.dstConnID, test.dstConnID; !bytes.Equal(got, want) { + t.Errorf("Destination Connection ID = {%x}, want {%x}", got, want) + } + if got, want := p.srcConnID, test.srcConnID; !bytes.Equal(got, want) { + t.Errorf("Source Connection ID = {%x}, want {%x}", got, want) + } + if got, want := p.data, test.data; !bytes.Equal(got, want) { + t.Errorf("Data = {%x}, want {%x}", got, want) + } + }) + } +} + +func TestParseGenericLongHeaderPacketErrors(t *testing.T) { + for _, test := range []struct { + name string + packet []byte + }{{ + name: "short header packet", + packet: unhex(` + 00 01020304 04a1a2a3a4 05b1b2b3b4b5 c1 + `), + }, { + name: "packet too short", + packet: unhex(` + 80 000000 + `), + }, { + name: "destination id too long", + packet: unhex(` + 80 00000000 02 00 + `), + }, { + name: "source id too long", + packet: unhex(` + 80 00000000 00 01 + `), + }} { + t.Run(test.name, func(t *testing.T) { + _, ok := parseGenericLongHeaderPacket(test.packet) + if ok { + t.Fatalf("parseGenericLongHeaderPacket() = _, true; want false") + } + }) + } +} + func unhex(s string) []byte { b, err := hex.DecodeString(strings.Map(func(c rune) rune { switch c { diff --git a/internal/quic/packet_writer.go b/internal/quic/packet_writer.go index a80b4711e..0c2b2ee41 100644 --- a/internal/quic/packet_writer.go +++ b/internal/quic/packet_writer.go @@ -100,7 +100,7 @@ func (w *packetWriter) startProtectedLongHeaderPacket(pnumMaxAcked packetNumber, // finishProtectedLongHeaderPacket finishes writing an Initial, 0-RTT, or Handshake packet, // canceling the packet if it contains no payload. // It returns a sentPacket describing the packet, or nil if no packet was written. -func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber, k keys, p longPacket) *sentPacket { +func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber, k fixedKeys, p longPacket) *sentPacket { if len(w.b) == w.payOff { // The payload is empty, so just abandon the packet. w.b = w.b[:w.pktOff] @@ -135,7 +135,8 @@ func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber pnumOff := len(hdr) hdr = appendPacketNumber(hdr, p.num, pnumMaxAcked) - return w.protect(hdr[w.pktOff:], p.num, pnumOff, k) + k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, p.num) + return w.finish(p.num) } // start1RTTPacket starts writing a 1-RTT (short header) packet. @@ -162,14 +163,13 @@ func (w *packetWriter) start1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnI // finish1RTTPacket finishes writing a 1-RTT packet, // canceling the packet if it contains no payload. // It returns a sentPacket describing the packet, or nil if no packet was written. -func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k keys) *sentPacket { +func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConnID []byte, k *updatingKeyPair) *sentPacket { if len(w.b) == w.payOff { // The payload is empty, so just abandon the packet. w.b = w.b[:w.pktOff] return nil } // TODO: Spin - // TODO: Key phase pnumLen := packetNumberLength(pnum, pnumMaxAcked) hdr := w.b[:w.pktOff] hdr = append(hdr, 0x40|byte(pnumLen-1)) @@ -177,7 +177,8 @@ func (w *packetWriter) finish1RTTPacket(pnum, pnumMaxAcked packetNumber, dstConn pnumOff := len(hdr) hdr = appendPacketNumber(hdr, pnum, pnumMaxAcked) w.padPacketLength(pnumLen) - return w.protect(hdr[w.pktOff:], pnum, pnumOff, k) + k.protect(hdr[w.pktOff:], w.b[len(hdr):], pnumOff-w.pktOff, pnum) + return w.finish(pnum) } // padPacketLength pads out the payload of the current packet to the minimum size, @@ -197,9 +198,8 @@ func (w *packetWriter) padPacketLength(pnumLen int) int { return plen } -// protect applies packet protection and finishes the current packet. -func (w *packetWriter) protect(hdr []byte, pnum packetNumber, pnumOff int, k keys) *sentPacket { - k.protect(hdr, w.b[w.pktOff+len(hdr):], pnumOff-w.pktOff, pnum) +// finish finishes the current packet after protection is applied. +func (w *packetWriter) finish(pnum packetNumber) *sentPacket { w.b = w.b[:len(w.b)+aeadOverhead] w.sent.size = len(w.b) - w.pktOff w.sent.num = pnum diff --git a/internal/quic/queue.go b/internal/quic/queue.go index 489721a8a..7085e578b 100644 --- a/internal/quic/queue.go +++ b/internal/quic/queue.go @@ -44,21 +44,9 @@ func (q *queue[T]) put(v T) bool { // get removes the first item from the queue, blocking until ctx is done, an item is available, // or the queue is closed. -func (q *queue[T]) get(ctx context.Context) (T, error) { - return q.getWithHooks(ctx, nil) -} - -// getWithHooks is get, but uses testHooks for locking when non-nil. -// This is a bit of an layer violation, but a simplification overall. -func (q *queue[T]) getWithHooks(ctx context.Context, testHooks connTestHooks) (T, error) { +func (q *queue[T]) get(ctx context.Context, testHooks connTestHooks) (T, error) { var zero T - var err error - if testHooks != nil { - err = testHooks.waitAndLockGate(ctx, &q.gate) - } else { - err = q.gate.waitAndLockContext(ctx) - } - if err != nil { + if err := q.gate.waitAndLock(ctx, testHooks); err != nil { return zero, err } defer q.unlock() diff --git a/internal/quic/queue_test.go b/internal/quic/queue_test.go index 8debeff11..d78216b0e 100644 --- a/internal/quic/queue_test.go +++ b/internal/quic/queue_test.go @@ -18,7 +18,7 @@ func TestQueue(t *testing.T) { cancel() q := newQueue[int]() - if got, err := q.get(nonblocking); err != context.Canceled { + if got, err := q.get(nonblocking, nil); err != context.Canceled { t.Fatalf("q.get() = %v, %v, want nil, contex.Canceled", got, err) } @@ -28,13 +28,13 @@ func TestQueue(t *testing.T) { if !q.put(2) { t.Fatalf("q.put(2) = false, want true") } - if got, err := q.get(nonblocking); got != 1 || err != nil { + if got, err := q.get(nonblocking, nil); got != 1 || err != nil { t.Fatalf("q.get() = %v, %v, want 1, nil", got, err) } - if got, err := q.get(nonblocking); got != 2 || err != nil { + if got, err := q.get(nonblocking, nil); got != 2 || err != nil { t.Fatalf("q.get() = %v, %v, want 2, nil", got, err) } - if got, err := q.get(nonblocking); err != context.Canceled { + if got, err := q.get(nonblocking, nil); err != context.Canceled { t.Fatalf("q.get() = %v, %v, want nil, contex.Canceled", got, err) } @@ -42,7 +42,7 @@ func TestQueue(t *testing.T) { time.Sleep(1 * time.Millisecond) q.put(3) }() - if got, err := q.get(context.Background()); got != 3 || err != nil { + if got, err := q.get(context.Background(), nil); got != 3 || err != nil { t.Fatalf("q.get() = %v, %v, want 3, nil", got, err) } @@ -50,7 +50,7 @@ func TestQueue(t *testing.T) { t.Fatalf("q.put(2) = false, want true") } q.close(io.EOF) - if got, err := q.get(context.Background()); got != 0 || err != io.EOF { + if got, err := q.get(context.Background(), nil); got != 0 || err != io.EOF { t.Fatalf("q.get() = %v, %v, want 0, io.EOF", got, err) } if q.put(5) { diff --git a/internal/quic/quic.go b/internal/quic/quic.go index 71738e129..9de97b6d8 100644 --- a/internal/quic/quic.go +++ b/internal/quic/quic.go @@ -10,6 +10,13 @@ import ( "time" ) +// QUIC versions. +// We only support v1 at this time. +const ( + quicVersion1 = 1 + quicVersion2 = 0x6b3343cf // https://www.rfc-editor.org/rfc/rfc9369 +) + // connIDLen is the length in bytes of connection IDs chosen by this package. // Since 1-RTT packets don't include a connection ID length field, // we use a consistent length for all our IDs. @@ -59,6 +66,13 @@ const minimumClientInitialDatagramSize = 1200 // https://www.rfc-editor.org/rfc/rfc9000.html#section-4.6-2 const maxStreamsLimit = 1 << 60 +// Maximum number of streams we will allow the peer to create implicitly. +// A stream ID that is used out of order results in all streams of that type +// with lower-numbered IDs also being opened. To limit the amount of work we +// will do in response to a single frame, we cap the peer's stream limit to +// this value. +const implicitStreamLimit = 100 + // A connSide distinguishes between the client and server sides of a connection. type connSide int8 diff --git a/internal/quic/stream.go b/internal/quic/stream.go index 2dbf4461b..89036b19b 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -39,6 +39,7 @@ type Stream struct { outgate gate out pipe // buffered data to send outwin int64 // maximum MAX_STREAM_DATA received from the peer + outmaxsent int64 // maximum data offset we've sent to the peer outmaxbuf int64 // maximum amount of data we will buffer outunsent rangeset[int64] // ranges buffered but not yet sent outacked rangeset[int64] // ranges sent and acknowledged @@ -57,6 +58,7 @@ type Stream struct { // streamIn* bits must be set with ingate held. // streamOut* bits must be set with outgate held. // streamConn* bits are set by the conn's loop. + // streamQueue* bits must be set with streamsState.sendMu held. state atomicBits[streamState] prev, next *Stream // guarded by streamsState.sendMu @@ -65,11 +67,19 @@ type Stream struct { type streamState uint32 const ( - // streamInSend and streamOutSend are set when there are - // frames to send for the inbound or outbound sides of the stream. - // For example, MAX_STREAM_DATA or STREAM_DATA_BLOCKED. - streamInSend = streamState(1 << iota) - streamOutSend + // streamInSendMeta is set when there are frames to send for the + // inbound side of the stream. For example, MAX_STREAM_DATA. + // Inbound frames are never flow-controlled. + streamInSendMeta = streamState(1 << iota) + + // streamOutSendMeta is set when there are non-flow-controlled frames + // to send for the outbound side of the stream. For example, STREAM_DATA_BLOCKED. + // streamOutSendData is set when there are no non-flow-controlled outbound frames + // and the stream has data to send. + // + // At most one of streamOutSendMeta and streamOutSendData is set at any time. + streamOutSendMeta + streamOutSendData // streamInDone and streamOutDone are set when the inbound or outbound // sides of the stream are finished. When both are set, the stream @@ -79,8 +89,48 @@ const ( // streamConnRemoved is set when the stream has been removed from the conn. streamConnRemoved + + // streamQueueMeta and streamQueueData indicate which of the streamsState + // send queues the conn is currently on. + streamQueueMeta + streamQueueData +) + +type streamQueue int + +const ( + noQueue = streamQueue(iota) + metaQueue // streamsState.queueMeta + dataQueue // streamsState.queueData ) +// wantQueue returns the send queue the stream should be on. +func (s streamState) wantQueue() streamQueue { + switch { + case s&(streamInSendMeta|streamOutSendMeta) != 0: + return metaQueue + case s&(streamInDone|streamOutDone|streamConnRemoved) == streamInDone|streamOutDone: + return metaQueue + case s&streamOutSendData != 0: + // The stream has no non-flow-controlled frames to send, + // but does have data. Put it on the data queue, which is only + // processed when flow control is available. + return dataQueue + } + return noQueue +} + +// inQueue returns the send queue the stream is currently on. +func (s streamState) inQueue() streamQueue { + switch { + case s&streamQueueMeta != 0: + return metaQueue + case s&streamQueueData != 0: + return dataQueue + } + return noQueue +} + // newStream returns a new stream. // // The stream's ingate and outgate are locked. @@ -132,11 +182,13 @@ func (s *Stream) ReadContext(ctx context.Context, b []byte) (n int, err error) { if s.IsWriteOnly() { return 0, errors.New("read from write-only stream") } - // Wait until data is available. - if err := s.conn.waitAndLockGate(ctx, &s.ingate); err != nil { + if err := s.ingate.waitAndLock(ctx, s.conn.testHooks); err != nil { return 0, err } - defer s.inUnlock() + defer func() { + s.inUnlock() + s.conn.handleStreamBytesReadOffLoop(int64(n)) // must be done with ingate unlocked + }() if s.inresetcode != -1 { return 0, fmt.Errorf("stream reset by peer: %w", StreamErrorCode(s.inresetcode)) } @@ -158,7 +210,7 @@ func (s *Stream) ReadContext(ctx context.Context, b []byte) (n int, err error) { s.in.copy(start, b) s.in.discardBefore(end) if s.insize == -1 || s.insize > s.inwin { - if shouldUpdateFlowControl(s.inwin-s.in.start, s.inmaxbuf) { + if shouldUpdateFlowControl(s.inmaxbuf, s.in.start+s.inmaxbuf-s.inwin) { // Update stream flow control with a STREAM_MAX_DATA frame. s.insendmax.setUnsent() } @@ -173,10 +225,8 @@ func (s *Stream) ReadContext(ctx context.Context, b []byte) (n int, err error) { // // We want to balance keeping the peer well-supplied with flow control with not sending // many small updates. -func shouldUpdateFlowControl(curwin, maxwin int64) bool { - // Update flow control if doing so gives the peer at least 64k tokens, - // or if it will double the current window. - return maxwin-curwin >= 64<<10 || curwin*2 < maxwin +func shouldUpdateFlowControl(maxWindow, addedWindow int64) bool { + return addedWindow >= maxWindow/8 } // Write writes data to the stream. @@ -202,16 +252,9 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) // We exit the loop after writing all data, so on subsequent passes through // the loop we are always write blocked. if len(b) > 0 && !canWrite { - // We're blocked, either by flow control or by our own buffer limit. - // We either need the peer to extend our flow control window, - // or ack some of our outstanding packets. - if s.out.end == s.outwin { - // We're blocked by flow control. - // Send a STREAM_DATA_BLOCKED frame to let the peer know. - s.outblocked.setUnsent() - } + // Our send buffer is full. Wait for the peer to ack some data. s.outUnlock() - if err := s.conn.waitAndLockGate(ctx, &s.outgate); err != nil { + if err := s.outgate.waitAndLock(ctx, s.conn.testHooks); err != nil { return n, err } // Successfully returning from waitAndLockGate means we are no longer @@ -233,18 +276,24 @@ func (s *Stream) WriteContext(ctx context.Context, b []byte) (n int, err error) if len(b) == 0 { break } - s.outblocked.clear() - // Write limit is min(our own buffer limit, the peer-provided flow control window). + // Write limit is our send buffer limit. // This is a stream offset. - lim := min(s.out.start+s.outmaxbuf, s.outwin) + lim := s.out.start + s.outmaxbuf // Amount to write is min(the full buffer, data up to the write limit). // This is a number of bytes. nn := min(int64(len(b)), lim-s.out.end) // Copy the data into the output buffer and mark it as unsent. - s.outunsent.add(s.out.end, s.out.end+nn) + if s.out.end <= s.outwin { + s.outunsent.add(s.out.end, min(s.out.end+nn, s.outwin)) + } s.out.writeAt(b[:nn], s.out.end) b = b[nn:] n += int(nn) + if s.out.end > s.outwin { + // We're blocked by flow control. + // Send a STREAM_DATA_BLOCKED frame to let the peer know. + s.outblocked.set() + } // If we have bytes left to send, we're blocked. canWrite = false } @@ -287,7 +336,6 @@ func (s *Stream) CloseRead() { return } s.ingate.lock() - defer s.inUnlock() if s.inset.isrange(0, s.insize) || s.inresetcode != -1 { // We've already received all data from the peer, // so there's no need to send STOP_SENDING. @@ -296,7 +344,10 @@ func (s *Stream) CloseRead() { } else { s.inclosed.set() } + discarded := s.in.end - s.in.start s.in.discardBefore(s.in.end) + s.inUnlock() + s.conn.handleStreamBytesReadOffLoop(discarded) // must be done with ingate unlocked } // CloseWrite aborts writes on the stream. @@ -330,6 +381,10 @@ func (s *Stream) Reset(code uint64) { s.resetInternal(code, userClosed) } +// resetInternal resets the send side of the stream. +// +// If userClosed is true, this is s.Reset. +// If userClosed is false, this is a reaction to a STOP_SENDING frame. func (s *Stream) resetInternal(code uint64, userClosed bool) { s.outgate.lock() defer s.outUnlock() @@ -362,9 +417,7 @@ func (s *Stream) resetInternal(code uint64, userClosed bool) { // are done and the stream should be removed, it notifies the Conn. func (s *Stream) inUnlock() { state := s.inUnlockNoQueue() - if state&streamInSend != 0 || state == streamInDone|streamOutDone { - s.conn.queueStreamForSend(s) - } + s.conn.maybeQueueStreamForSend(s, state) } // inUnlockNoQueue is inUnlock, @@ -388,11 +441,11 @@ func (s *Stream) inUnlockNoQueue() streamState { state = streamInDone } case s.insendmax.shouldSend(): // STREAM_MAX_DATA - state = streamInSend + state = streamInSendMeta case s.inclosed.shouldSend(): // STOP_SENDING - state = streamInSend + state = streamInSendMeta } - const mask = streamInDone | streamInSend + const mask = streamInDone | streamInSendMeta return s.state.set(state, mask) } @@ -402,9 +455,7 @@ func (s *Stream) inUnlockNoQueue() streamState { // are done and the stream should be removed, it notifies the Conn. func (s *Stream) outUnlock() { state := s.outUnlockNoQueue() - if state&streamOutSend != 0 || state == streamInDone|streamOutDone { - s.conn.queueStreamForSend(s) - } + s.conn.maybeQueueStreamForSend(s, state) } // outUnlockNoQueue is outUnlock, @@ -421,8 +472,8 @@ func (s *Stream) outUnlockNoQueue() streamState { } } } - lim := min(s.out.start+s.outmaxbuf, s.outwin) - canWrite := lim > s.out.end || // available flow control + lim := s.out.start + s.outmaxbuf + canWrite := lim > s.out.end || // available send buffer s.outclosed.isSet() || // closed locally s.outreset.isSet() // reset locally defer s.outgate.unlock(canWrite) @@ -439,18 +490,22 @@ func (s *Stream) outUnlockNoQueue() streamState { state = streamOutDone } case s.outreset.shouldSend(): // RESET_STREAM - state = streamOutSend + state = streamOutSendMeta case s.outreset.isSet(): // RESET_STREAM sent but not acknowledged + case s.outblocked.shouldSend(): // STREAM_DATA_BLOCKED + state = streamOutSendMeta case len(s.outunsent) > 0: // STREAM frame with data - state = streamOutSend - case s.outclosed.shouldSend(): // STREAM frame with FIN bit - state = streamOutSend + if s.outunsent.min() < s.outmaxsent { + state = streamOutSendMeta // resent data, will not consume flow control + } else { + state = streamOutSendData // new data, requires flow control + } + case s.outclosed.shouldSend() && s.out.end == s.outmaxsent: // empty STREAM frame with FIN bit + state = streamOutSendMeta case s.outopened.shouldSend(): // STREAM frame with no data - state = streamOutSend - case s.outblocked.shouldSend(): // STREAM_DATA_BLOCKED - state = streamOutSend + state = streamOutSendMeta } - const mask = streamOutDone | streamOutSend + const mask = streamOutDone | streamOutSendMeta | streamOutSendData return s.state.set(state, mask) } @@ -467,6 +522,12 @@ func (s *Stream) handleData(off int64, b []byte, fin bool) error { // Either way, we can discard this frame. return nil } + if s.insize == -1 && end > s.in.end { + added := end - s.in.end + if err := s.conn.handleStreamBytesReceived(added); err != nil { + return err + } + } s.in.writeAt(b, off) s.inset.add(off, end) if fin { @@ -489,6 +550,13 @@ func (s *Stream) handleReset(code uint64, finalSize int64) error { // The stream was already reset. return nil } + if s.insize == -1 { + added := finalSize - s.in.end + if err := s.conn.handleStreamBytesReceived(added); err != nil { + return err + } + } + s.conn.handleStreamBytesReadOnLoop(finalSize - s.in.start) s.in.discardBefore(s.in.end) s.inresetcode = int64(code) s.insize = finalSize @@ -529,7 +597,19 @@ func (s *Stream) handleStopSending(code uint64) error { func (s *Stream) handleMaxStreamData(maxStreamData int64) error { s.outgate.lock() defer s.outUnlock() - s.outwin = max(maxStreamData, s.outwin) + if maxStreamData <= s.outwin { + return nil + } + if s.out.end > s.outwin { + s.outunsent.add(s.outwin, min(maxStreamData, s.out.end)) + } + s.outwin = maxStreamData + if s.out.end > s.outwin { + // We've still got more data than flow control window. + s.outblocked.setUnsent() + } else { + s.outblocked.clear() + } return nil } @@ -631,7 +711,7 @@ func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto b if s.outreset.isSet() { // RESET_STREAM if s.outreset.shouldSendPTO(pto) { - if !w.appendResetStreamFrame(s.id, s.outresetcode, s.out.end) { + if !w.appendResetStreamFrame(s.id, s.outresetcode, min(s.outwin, s.out.end)) { return false } s.outreset.setSent(pnum) @@ -641,15 +721,20 @@ func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto b } if s.outblocked.shouldSendPTO(pto) { // STREAM_DATA_BLOCKED - if !w.appendStreamDataBlockedFrame(s.id, s.out.end) { + if !w.appendStreamDataBlockedFrame(s.id, s.outwin) { return false } s.outblocked.setSent(pnum) s.frameOpensStream(pnum) } - // STREAM for { - off, size := dataToSend(s.out, s.outunsent, s.outacked, pto) + // STREAM + off, size := dataToSend(min(s.out.start, s.outwin), min(s.out.end, s.outwin), s.outunsent, s.outacked, pto) + if end := off + size; end > s.outmaxsent { + // This will require connection-level flow control to send. + end = min(end, s.outmaxsent+s.conn.streams.outflow.avail()) + size = end - off + } fin := s.outclosed.isSet() && off+size == s.out.end shouldSend := size > 0 || // have data to send s.outopened.shouldSendPTO(pto) || // should open the stream @@ -662,7 +747,12 @@ func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto b return false } s.out.copy(off, b) - s.outunsent.sub(off, off+int64(len(b))) + end := off + int64(len(b)) + if end > s.outmaxsent { + s.conn.streams.outflow.consume(end - s.outmaxsent) + s.outmaxsent = end + } + s.outunsent.sub(off, end) s.frameOpensStream(pnum) if fin { s.outclosed.setSent(pnum) @@ -687,7 +777,7 @@ func (s *Stream) frameOpensStream(pnum packetNumber) { } // dataToSend returns the next range of data to send in a STREAM or CRYPTO_STREAM. -func dataToSend(out pipe, outunsent, outacked rangeset[int64], pto bool) (start, size int64) { +func dataToSend(start, end int64, outunsent, outacked rangeset[int64], pto bool) (sendStart, size int64) { switch { case pto: // On PTO, resend unacked data that fits in the probe packet. @@ -698,14 +788,14 @@ func dataToSend(out pipe, outunsent, outacked rangeset[int64], pto bool) (start, // This may miss unacked data starting after that acked byte, // but avoids resending data the peer has acked. for _, r := range outacked { - if r.start > out.start { - return out.start, r.start - out.start + if r.start > start { + return start, r.start - start } } - return out.start, out.end - out.start + return start, end - start case outunsent.numRanges() > 0: return outunsent.min(), outunsent[0].size() default: - return out.end, 0 + return end, 0 } } diff --git a/internal/quic/stream_limits.go b/internal/quic/stream_limits.go new file mode 100644 index 000000000..6eda7883b --- /dev/null +++ b/internal/quic/stream_limits.go @@ -0,0 +1,113 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" +) + +// Limits on the number of open streams. +// Every connection has separate limits for bidirectional and unidirectional streams. +// +// Note that the MAX_STREAMS limit includes closed as well as open streams. +// Closing a stream doesn't enable an endpoint to open a new one; +// only an increase in the MAX_STREAMS limit does. + +// localStreamLimits are limits on the number of open streams created by us. +type localStreamLimits struct { + gate gate + max int64 // peer-provided MAX_STREAMS + opened int64 // number of streams opened by us +} + +func (lim *localStreamLimits) init() { + lim.gate = newGate() +} + +// open creates a new local stream, blocking until MAX_STREAMS quota is available. +func (lim *localStreamLimits) open(ctx context.Context, c *Conn) (num int64, err error) { + // TODO: Send a STREAMS_BLOCKED when blocked. + if err := lim.gate.waitAndLock(ctx, c.testHooks); err != nil { + return 0, err + } + n := lim.opened + lim.opened++ + lim.gate.unlock(lim.opened < lim.max) + return n, nil +} + +// setMax sets the MAX_STREAMS provided by the peer. +func (lim *localStreamLimits) setMax(maxStreams int64) { + lim.gate.lock() + lim.max = max(lim.max, maxStreams) + lim.gate.unlock(lim.opened < lim.max) +} + +// remoteStreamLimits are limits on the number of open streams created by the peer. +type remoteStreamLimits struct { + max int64 // last MAX_STREAMS sent to the peer + opened int64 // number of streams opened by the peer (including subsequently closed ones) + closed int64 // number of peer streams in the "closed" state + maxOpen int64 // how many streams we want to let the peer simultaneously open + sendMax sentVal // set when we should send MAX_STREAMS +} + +func (lim *remoteStreamLimits) init(maxOpen int64) { + lim.maxOpen = maxOpen + lim.max = min(maxOpen, implicitStreamLimit) // initial limit sent in transport parameters + lim.opened = 0 +} + +// open handles the peer opening a new stream. +func (lim *remoteStreamLimits) open(id streamID) error { + num := id.num() + if num >= lim.max { + return localTransportError(errStreamLimit) + } + if num >= lim.opened { + lim.opened = num + 1 + lim.maybeUpdateMax() + } + return nil +} + +// close handles the peer closing an open stream. +func (lim *remoteStreamLimits) close() { + lim.closed++ + lim.maybeUpdateMax() +} + +// maybeUpdateMax updates the MAX_STREAMS value we will send to the peer. +func (lim *remoteStreamLimits) maybeUpdateMax() { + newMax := min( + // Max streams the peer can have open at once. + lim.closed+lim.maxOpen, + // Max streams the peer can open with a single frame. + lim.opened+implicitStreamLimit, + ) + avail := lim.max - lim.opened + if newMax > lim.max && (avail < 8 || newMax-lim.max >= 2*avail) { + // If the peer has less than 8 streams, or if increasing the peer's + // stream limit would double it, then send a MAX_STREAMS. + lim.max = newMax + lim.sendMax.setUnsent() + } +} + +// appendFrame appends a MAX_STREAMS frame to the current packet, if necessary. +// +// It returns true if no more frames need appending, +// false if not everything fit in the current packet. +func (lim *remoteStreamLimits) appendFrame(w *packetWriter, typ streamType, pnum packetNumber, pto bool) bool { + if lim.sendMax.shouldSendPTO(pto) { + if !w.appendMaxStreamsFrame(typ, lim.max) { + return false + } + lim.sendMax.setSent(pnum) + } + return true +} diff --git a/internal/quic/stream_limits_test.go b/internal/quic/stream_limits_test.go new file mode 100644 index 000000000..3f291e9f4 --- /dev/null +++ b/internal/quic/stream_limits_test.go @@ -0,0 +1,269 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "crypto/tls" + "testing" +) + +func TestStreamLimitNewStreamBlocked(t *testing.T) { + // "An endpoint that receives a frame with a stream ID exceeding the limit + // it has sent MUST treat this as a connection error of type STREAM_LIMIT_ERROR [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-4.6-3 + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + ctx := canceledContext() + tc := newTestConn(t, clientSide, + permissiveTransportParameters, + func(p *transportParameters) { + p.initialMaxStreamsBidi = 0 + p.initialMaxStreamsUni = 0 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + opening := runAsync(tc, func(ctx context.Context) (*Stream, error) { + return tc.conn.newLocalStream(ctx, styp) + }) + if _, err := opening.result(); err != errNotDone { + t.Fatalf("new stream blocked by limit: %v, want errNotDone", err) + } + tc.writeFrames(packetType1RTT, debugFrameMaxStreams{ + streamType: styp, + max: 1, + }) + if _, err := opening.result(); err != nil { + t.Fatalf("new stream not created after limit raised: %v", err) + } + if _, err := tc.conn.newLocalStream(ctx, styp); err == nil { + t.Fatalf("new stream blocked by raised limit: %v, want error", err) + } + }) +} + +func TestStreamLimitMaxStreamsDecreases(t *testing.T) { + // "MAX_STREAMS frames that do not increase the stream limit MUST be ignored." + // https://www.rfc-editor.org/rfc/rfc9000#section-4.6-4 + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + ctx := canceledContext() + tc := newTestConn(t, clientSide, + permissiveTransportParameters, + func(p *transportParameters) { + p.initialMaxStreamsBidi = 0 + p.initialMaxStreamsUni = 0 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.writeFrames(packetType1RTT, debugFrameMaxStreams{ + streamType: styp, + max: 2, + }) + tc.writeFrames(packetType1RTT, debugFrameMaxStreams{ + streamType: styp, + max: 1, + }) + if _, err := tc.conn.newLocalStream(ctx, styp); err != nil { + t.Fatalf("open stream 1, limit 2, got error: %v", err) + } + if _, err := tc.conn.newLocalStream(ctx, styp); err != nil { + t.Fatalf("open stream 2, limit 2, got error: %v", err) + } + if _, err := tc.conn.newLocalStream(ctx, styp); err == nil { + t.Fatalf("open stream 3, limit 2, got error: %v", err) + } + }) +} + +func TestStreamLimitViolated(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + tc := newTestConn(t, serverSide, + func(c *Config) { + if styp == bidiStream { + c.MaxBidiRemoteStreams = 10 + } else { + c.MaxUniRemoteStreams = 10 + } + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, styp, 9), + }) + tc.wantIdle("stream number 9 is within the limit") + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, styp, 10), + }) + tc.wantFrame("stream number 10 is beyond the limit", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errStreamLimit, + }, + ) + }) +} + +func TestStreamLimitImplicitStreams(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + tc := newTestConn(t, serverSide, + func(c *Config) { + c.MaxBidiRemoteStreams = 1 << 60 + c.MaxUniRemoteStreams = 1 << 60 + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + if got, want := tc.sentTransportParameters.initialMaxStreamsBidi, int64(implicitStreamLimit); got != want { + t.Errorf("sent initial_max_streams_bidi = %v, want %v", got, want) + } + if got, want := tc.sentTransportParameters.initialMaxStreamsUni, int64(implicitStreamLimit); got != want { + t.Errorf("sent initial_max_streams_uni = %v, want %v", got, want) + } + + // Create stream 0. + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, styp, 0), + }) + tc.wantIdle("max streams not increased enough to send a new frame") + + // Create streams [0, implicitStreamLimit). + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, styp, implicitStreamLimit-1), + }) + tc.wantFrame("max streams increases to implicit stream limit", + packetType1RTT, debugFrameMaxStreams{ + streamType: styp, + max: 2 * implicitStreamLimit, + }) + + // Create a stream past the limit. + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, styp, 2*implicitStreamLimit), + }) + tc.wantFrame("stream is past the limit", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errStreamLimit, + }, + ) + }) +} + +func TestStreamLimitMaxStreamsTransportParameterTooLarge(t *testing.T) { + // "If a max_streams transport parameter [...] is received with + // a value greater than 2^60 [...] the connection MUST be closed + // immediately with a connection error of type TRANSPORT_PARAMETER_ERROR [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-4.6-2 + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + tc := newTestConn(t, serverSide, + func(p *transportParameters) { + if styp == bidiStream { + p.initialMaxStreamsBidi = 1<<60 + 1 + } else { + p.initialMaxStreamsUni = 1<<60 + 1 + } + }) + tc.writeFrames(packetTypeInitial, debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.wantFrame("max streams transport parameter is too large", + packetTypeInitial, debugFrameConnectionCloseTransport{ + code: errTransportParameter, + }, + ) + }) +} + +func TestStreamLimitMaxStreamsFrameTooLarge(t *testing.T) { + // "If [...] a MAX_STREAMS frame is received with a value + // greater than 2^60 [...] the connection MUST be closed immediately + // with a connection error [...] of type FRAME_ENCODING_ERROR [...]" + // https://www.rfc-editor.org/rfc/rfc9000#section-4.6-2 + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + tc := newTestConn(t, serverSide) + tc.handshake() + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetType1RTT, debugFrameMaxStreams{ + streamType: styp, + max: 1<<60 + 1, + }) + tc.wantFrame("MAX_STREAMS value is too large", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errFrameEncoding, + }, + ) + }) +} + +func TestStreamLimitSendUpdatesMaxStreams(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + ctx := canceledContext() + tc := newTestConn(t, serverSide, func(c *Config) { + if styp == uniStream { + c.MaxUniRemoteStreams = 4 + c.MaxBidiRemoteStreams = 0 + } else { + c.MaxUniRemoteStreams = 0 + c.MaxBidiRemoteStreams = 4 + } + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + var streams []*Stream + for i := 0; i < 4; i++ { + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: newStreamID(clientSide, styp, int64(i)), + fin: true, + }) + s, err := tc.conn.AcceptStream(ctx) + if err != nil { + t.Fatalf("AcceptStream = %v", err) + } + streams = append(streams, s) + } + streams[3].CloseContext(ctx) + if styp == bidiStream { + tc.wantFrame("stream is closed", + packetType1RTT, debugFrameStream{ + id: streams[3].id, + fin: true, + data: []byte{}, + }) + tc.writeAckForAll() + } + tc.wantFrame("closing a stream when peer is at limit immediately extends the limit", + packetType1RTT, debugFrameMaxStreams{ + streamType: styp, + max: 5, + }) + }) +} + +func TestStreamLimitStopSendingDoesNotUpdateMaxStreams(t *testing.T) { + tc, s := newTestConnAndRemoteStream(t, serverSide, bidiStream, func(c *Config) { + c.MaxBidiRemoteStreams = 1 + }) + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + fin: true, + }) + s.CloseRead() + tc.writeFrames(packetType1RTT, debugFrameStopSending{ + id: s.id, + }) + tc.wantFrame("recieved STOP_SENDING, send RESET_STREAM", + packetType1RTT, debugFrameResetStream{ + id: s.id, + }) + tc.writeAckForAll() + tc.wantIdle("MAX_STREAMS is not extended until the user fully closes the stream") + s.CloseWrite() + tc.wantFrame("user closing the stream triggers MAX_STREAMS update", + packetType1RTT, debugFrameMaxStreams{ + streamType: bidiStream, + max: 2, + }) +} diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go index e22e0432e..7c1377fae 100644 --- a/internal/quic/stream_test.go +++ b/internal/quic/stream_test.go @@ -18,6 +18,67 @@ import ( "testing" ) +func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { + testStreamTypes(t, "", func(t *testing.T, styp streamType) { + ctx := canceledContext() + want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + const writeBufferSize = 4 + tc := newTestConn(t, clientSide, permissiveTransportParameters, func(c *Config) { + c.MaxStreamWriteBufferSize = writeBufferSize + }) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + s, err := tc.conn.newLocalStream(ctx, styp) + if err != nil { + t.Fatal(err) + } + + // Non-blocking write. + n, err := s.WriteContext(ctx, want) + if n != writeBufferSize || err != context.Canceled { + t.Fatalf("s.WriteContext() = %v, %v; want %v, context.Canceled", n, err, writeBufferSize) + } + tc.wantFrame("first write buffer of data sent", + packetType1RTT, debugFrameStream{ + id: s.id, + data: want[:writeBufferSize], + }) + off := int64(writeBufferSize) + + // Blocking write, which must wait for buffer space. + w := runAsync(tc, func(ctx context.Context) (int, error) { + return s.WriteContext(ctx, want[writeBufferSize:]) + }) + tc.wantIdle("write buffer is full, no more data can be sent") + + // The peer's ack of the STREAM frame allows progress. + tc.writeAckForAll() + tc.wantFrame("second write buffer of data sent", + packetType1RTT, debugFrameStream{ + id: s.id, + off: off, + data: want[off:][:writeBufferSize], + }) + off += writeBufferSize + tc.wantIdle("write buffer is full, no more data can be sent") + + // The peer's ack of the second STREAM frame allows sending the remaining data. + tc.writeAckForAll() + tc.wantFrame("remaining data sent", + packetType1RTT, debugFrameStream{ + id: s.id, + off: off, + data: want[off:], + }) + + if n, err := w.result(); n != len(want)-writeBufferSize || err != nil { + t.Fatalf("s.WriteContext() = %v, %v; want %v, nil", + len(want)-writeBufferSize, err, writeBufferSize) + } + }) +} + func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { ctx := canceledContext() @@ -30,14 +91,15 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { tc.handshake() tc.ignoreFrame(frameTypeAck) - // Non-blocking write with no flow control. s, err := tc.conn.newLocalStream(ctx, styp) if err != nil { t.Fatal(err) } - _, err = s.WriteContext(ctx, want) - if err != context.Canceled { - t.Fatalf("write to stream with no flow control: err = %v, want context.Canceled", err) + + // Data is written to the stream output buffer, but we have no flow control. + _, err = s.WriteContext(ctx, want[:1]) + if err != nil { + t.Fatalf("write with available output buffer: unexpected error: %v", err) } tc.wantFrame("write blocked by flow control triggers a STREAM_DATA_BLOCKED frame", packetType1RTT, debugFrameStreamDataBlocked{ @@ -45,15 +107,14 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { max: 0, }) - // Blocking write waiting for flow control. - w := runAsync(tc, func(ctx context.Context) (int, error) { - return s.WriteContext(ctx, want) - }) - tc.wantFrame("second blocked write triggers another STREAM_DATA_BLOCKED", - packetType1RTT, debugFrameStreamDataBlocked{ - id: s.id, - max: 0, - }) + // Write more data. + _, err = s.WriteContext(ctx, want[1:]) + if err != nil { + t.Fatalf("write with available output buffer: unexpected error: %v", err) + } + tc.wantIdle("adding more blocked data does not trigger another STREAM_DATA_BLOCKED") + + // Provide some flow control window. tc.writeFrames(packetType1RTT, debugFrameMaxStreamData{ id: s.id, max: 4, @@ -69,6 +130,7 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { data: want[:4], }) + // Provide more flow control window. tc.writeFrames(packetType1RTT, debugFrameMaxStreamData{ id: s.id, max: int64(len(want)), @@ -79,10 +141,6 @@ func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { off: 4, data: want[4:], }) - n, err := w.result() - if n != len(want) || err != nil { - t.Errorf("Write() = %v, %v; want %v, nil", n, err, len(want)) - } }) } @@ -169,7 +227,7 @@ func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) { p.initialMaxStreamDataBidiRemote = 1 << 20 p.initialMaxStreamDataUni = 1 << 20 }, func(c *Config) { - c.StreamWriteBufferSize = maxWriteBuffer + c.MaxStreamWriteBufferSize = maxWriteBuffer }) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -391,7 +449,7 @@ func TestStreamReceiveExtendsStreamWindow(t *testing.T) { const maxWindowSize = 20 ctx := canceledContext() tc := newTestConn(t, serverSide, func(c *Config) { - c.StreamReadBufferSize = maxWindowSize + c.MaxStreamReadBufferSize = maxWindowSize }) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -448,7 +506,7 @@ func TestStreamReceiveViolatesStreamDataLimit(t *testing.T) { size: 2, }} { tc := newTestConn(t, serverSide, func(c *Config) { - c.StreamReadBufferSize = maxStreamData + c.MaxStreamReadBufferSize = maxStreamData }) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -473,7 +531,7 @@ func TestStreamReceiveDuplicateDataDoesNotViolateLimits(t *testing.T) { const maxData = 10 tc := newTestConn(t, serverSide, func(c *Config) { // TODO: Add connection-level maximum data here as well. - c.StreamReadBufferSize = maxData + c.MaxStreamReadBufferSize = maxData }) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -557,7 +615,7 @@ func TestStreamFinalSizePastMaxStreamData(t *testing.T) { finalSizeTest(t, errFlowControl, func(tc *testConn, sid streamID) (finalSize int64) { return 11 }, func(c *Config) { - c.StreamReadBufferSize = 10 + c.MaxStreamReadBufferSize = 10 }) } @@ -649,7 +707,7 @@ func TestStreamReceiveUnblocksReader(t *testing.T) { // to the conn and expects a STREAM_STATE_ERROR. func testStreamSendFrameInvalidState(t *testing.T, f func(sid streamID) debugFrame) { testSides(t, "stream_not_created", func(t *testing.T, side connSide) { - tc := newTestConn(t, side) + tc := newTestConn(t, side, permissiveTransportParameters) tc.handshake() tc.writeFrames(packetType1RTT, f(newStreamID(side, bidiStream, 0))) tc.wantFrame("frame for local stream which has not been created", @@ -659,7 +717,7 @@ func testStreamSendFrameInvalidState(t *testing.T, f func(sid streamID) debugFra }) testSides(t, "uni_stream", func(t *testing.T, side connSide) { ctx := canceledContext() - tc := newTestConn(t, side) + tc := newTestConn(t, side, permissiveTransportParameters) tc.handshake() sid := newStreamID(side, uniStream, 0) s, err := tc.conn.NewSendOnlyStream(ctx) @@ -796,7 +854,7 @@ func TestStreamOffsetTooLarge(t *testing.T) { } func TestStreamReadFromWriteOnlyStream(t *testing.T) { - _, s := newTestConnAndLocalStream(t, serverSide, uniStream) + _, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) buf := make([]byte, 10) wantErr := "read from write-only stream" if n, err := s.Read(buf); err == nil || !strings.Contains(err.Error(), wantErr) { @@ -868,16 +926,15 @@ func TestStreamWriteToClosedStream(t *testing.T) { } func TestStreamResetBlockedStream(t *testing.T) { - tc, s := newTestConnAndLocalStream(t, serverSide, bidiStream, func(p *transportParameters) { - p.initialMaxStreamsBidi = 1 - p.initialMaxData = 1 << 20 - p.initialMaxStreamDataBidiRemote = 4 - }) + tc, s := newTestConnAndLocalStream(t, serverSide, bidiStream, permissiveTransportParameters, + func(c *Config) { + c.MaxStreamWriteBufferSize = 4 + }) tc.ignoreFrame(frameTypeStreamDataBlocked) writing := runAsync(tc, func(ctx context.Context) (int, error) { return s.WriteContext(ctx, []byte{0, 1, 2, 3, 4, 5, 6, 7}) }) - tc.wantFrame("stream writes data until blocked by flow control", + tc.wantFrame("stream writes data until write buffer fills", packetType1RTT, debugFrameStream{ id: s.id, off: 0, @@ -894,11 +951,8 @@ func TestStreamResetBlockedStream(t *testing.T) { if n, err := writing.result(); n != 4 || !strings.Contains(err.Error(), wantErr) { t.Errorf("s.Write() interrupted by Reset: %v, %q; want 4, %q", n, err, wantErr) } - tc.writeFrames(packetType1RTT, debugFrameMaxStreamData{ - id: s.id, - max: 1 << 20, - }) - tc.wantIdle("flow control is available, but stream has been reset") + tc.writeAckForAll() + tc.wantIdle("buffer space is available, but stream has been reset") s.Reset(100) tc.wantIdle("resetting stream a second time has no effect") if n, err := s.Write([]byte{}); err == nil || !strings.Contains(err.Error(), wantErr) { @@ -1040,6 +1094,44 @@ func TestStreamCloseUnblocked(t *testing.T) { } } +func TestStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) { + ctx := canceledContext() + tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters, + func(p *transportParameters) { + //p.initialMaxData = 0 + p.initialMaxStreamDataUni = 0 + }) + tc.ignoreFrame(frameTypeStreamDataBlocked) + if _, err := s.WriteContext(ctx, []byte{0, 1}); err != nil { + t.Fatalf("s.Write = %v", err) + } + s.CloseWrite() + tc.wantIdle("stream write is blocked by flow control") + + tc.writeFrames(packetType1RTT, debugFrameMaxStreamData{ + id: s.id, + max: 1, + }) + tc.wantFrame("send data up to flow control limit", + packetType1RTT, debugFrameStream{ + id: s.id, + data: []byte{0}, + }) + tc.wantIdle("stream write is again blocked by flow control") + + tc.writeFrames(packetType1RTT, debugFrameMaxStreamData{ + id: s.id, + max: 2, + }) + tc.wantFrame("send remaining data and FIN", + packetType1RTT, debugFrameStream{ + id: s.id, + off: 1, + data: []byte{1}, + fin: true, + }) +} + func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { testStreamTypes(t, "", func(t *testing.T, styp streamType) { ctx := canceledContext() @@ -1112,7 +1204,7 @@ func TestStreamPeerResetFollowedByData(t *testing.T) { } func TestStreamResetInvalidCode(t *testing.T) { - tc, s := newTestConnAndLocalStream(t, serverSide, uniStream) + tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) s.Reset(1 << 62) tc.wantFrame("reset with invalid code sends a RESET_STREAM anyway", packetType1RTT, debugFrameResetStream{ @@ -1163,6 +1255,23 @@ func TestStreamPeerStopSendingForActiveStream(t *testing.T) { }) } +func TestStreamReceiveDataBlocked(t *testing.T) { + tc := newTestConn(t, serverSide, permissiveTransportParameters) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + + // We don't do anything with these frames, + // but should accept them if the peer sends one. + tc.writeFrames(packetType1RTT, debugFrameStreamDataBlocked{ + id: newStreamID(clientSide, bidiStream, 0), + max: 100, + }) + tc.writeFrames(packetType1RTT, debugFrameDataBlocked{ + max: 100, + }) + tc.wantIdle("no response to STREAM_DATA_BLOCKED and DATA_BLOCKED") +} + type streamSide string const ( @@ -1216,3 +1325,11 @@ func permissiveTransportParameters(p *transportParameters) { p.initialMaxStreamDataBidiLocal = maxVarint p.initialMaxStreamDataUni = maxVarint } + +func makeTestData(n int) []byte { + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = byte(i) + } + return b +} diff --git a/internal/quic/tls.go b/internal/quic/tls.go index 584316f0e..a37e26fb8 100644 --- a/internal/quic/tls.go +++ b/internal/quic/tls.go @@ -16,12 +16,7 @@ import ( // startTLS starts the TLS handshake. func (c *Conn) startTLS(now time.Time, initialConnID []byte, params transportParameters) error { - clientKeys, serverKeys := initialKeys(initialConnID) - if c.side == clientSide { - c.wkeys[initialSpace], c.rkeys[initialSpace] = clientKeys, serverKeys - } else { - c.wkeys[initialSpace], c.rkeys[initialSpace] = serverKeys, clientKeys - } + c.keysInitial = initialKeys(initialConnID, c.side) qconfig := &tls.QUICConfig{TLSConfig: c.config.TLSConfig} if c.side == clientSide { @@ -49,21 +44,36 @@ func (c *Conn) handleTLSEvents(now time.Time) error { case tls.QUICNoEvent: return nil case tls.QUICSetReadSecret: - space, k, err := tlsKey(e) - if err != nil { + if err := checkCipherSuite(e.Suite); err != nil { return err } - c.rkeys[space] = k + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + c.keysHandshake.r.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + c.keysAppData.r.init(e.Suite, e.Data) + } case tls.QUICSetWriteSecret: - space, k, err := tlsKey(e) - if err != nil { + if err := checkCipherSuite(e.Suite); err != nil { return err } - c.wkeys[space] = k + switch e.Level { + case tls.QUICEncryptionLevelHandshake: + c.keysHandshake.w.init(e.Suite, e.Data) + case tls.QUICEncryptionLevelApplication: + c.keysAppData.w.init(e.Suite, e.Data) + } case tls.QUICWriteData: - space, err := spaceForLevel(e.Level) - if err != nil { - return err + var space numberSpace + switch e.Level { + case tls.QUICEncryptionLevelInitial: + space = initialSpace + case tls.QUICEncryptionLevelHandshake: + space = handshakeSpace + case tls.QUICEncryptionLevelApplication: + space = appDataSpace + default: + return fmt.Errorf("quic: internal error: write handshake data at level %v", e.Level) } c.crypto[space].write(e.Data) case tls.QUICHandshakeDone: @@ -73,6 +83,7 @@ func (c *Conn) handleTLSEvents(now time.Time) error { // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2-1 c.confirmHandshake(now) } + c.handshakeDone() case tls.QUICTransportParameters: params, err := unmarshalTransportParams(e.Data) if err != nil { @@ -85,32 +96,6 @@ func (c *Conn) handleTLSEvents(now time.Time) error { } } -// tlsKey returns the keys in a QUICSetReadSecret or QUICSetWriteSecret event. -func tlsKey(e tls.QUICEvent) (numberSpace, keys, error) { - space, err := spaceForLevel(e.Level) - if err != nil { - return 0, keys{}, err - } - k, err := newKeys(e.Suite, e.Data) - if err != nil { - return 0, keys{}, err - } - return space, k, nil -} - -func spaceForLevel(level tls.QUICEncryptionLevel) (numberSpace, error) { - switch level { - case tls.QUICEncryptionLevelInitial: - return initialSpace, nil - case tls.QUICEncryptionLevelHandshake: - return handshakeSpace, nil - case tls.QUICEncryptionLevelApplication: - return appDataSpace, nil - default: - return 0, fmt.Errorf("quic: internal error: write handshake data at level %v", level) - } -} - // handleCrypto processes data received in a CRYPTO frame. func (c *Conn) handleCrypto(now time.Time, space numberSpace, off int64, data []byte) error { var level tls.QUICEncryptionLevel diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go index 180ea8bee..81d17b858 100644 --- a/internal/quic/tls_test.go +++ b/internal/quic/tls_test.go @@ -21,6 +21,7 @@ func (tc *testConn) handshake() { if *testVV { *testVV = false defer func() { + tc.t.Helper() *testVV = true tc.t.Logf("performed connection handshake") }() @@ -96,7 +97,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { packets: []*testPacket{{ ptype: packetTypeInitial, num: 0, - version: 1, + version: quicVersion1, srcConnID: clientConnIDs[0], dstConnID: transientConnID, frames: []debugFrame{ @@ -109,7 +110,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { packets: []*testPacket{{ ptype: packetTypeInitial, num: 0, - version: 1, + version: quicVersion1, srcConnID: serverConnIDs[0], dstConnID: clientConnIDs[0], frames: []debugFrame{ @@ -121,7 +122,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { }, { ptype: packetTypeHandshake, num: 0, - version: 1, + version: quicVersion1, srcConnID: serverConnIDs[0], dstConnID: clientConnIDs[0], frames: []debugFrame{ @@ -143,7 +144,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { packets: []*testPacket{{ ptype: packetTypeInitial, num: 1, - version: 1, + version: quicVersion1, srcConnID: clientConnIDs[0], dstConnID: serverConnIDs[0], frames: []debugFrame{ @@ -154,7 +155,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { }, { ptype: packetTypeHandshake, num: 0, - version: 1, + version: quicVersion1, srcConnID: clientConnIDs[0], dstConnID: serverConnIDs[0], frames: []debugFrame{ @@ -352,6 +353,7 @@ func TestConnKeysDiscardedClient(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{code: errInternal}) + tc.conn.Abort(nil) tc.wantFrame("client closes connection after 1-RTT CONNECTION_CLOSE", packetType1RTT, debugFrameConnectionCloseTransport{ code: errNo, @@ -405,6 +407,7 @@ func TestConnKeysDiscardedServer(t *testing.T) { tc.writeFrames(packetType1RTT, debugFrameConnectionCloseTransport{code: errInternal}) + tc.conn.Abort(nil) tc.wantFrame("server closes connection after 1-RTT CONNECTION_CLOSE", packetType1RTT, debugFrameConnectionCloseTransport{ code: errNo, @@ -535,3 +538,66 @@ func TestConnCryptoBufferSizeExceeded(t *testing.T) { code: errCryptoBufferExceeded, }) } + +func TestConnAEADLimitReached(t *testing.T) { + // "[...] endpoints MUST count the number of received packets that + // fail authentication during the lifetime of a connection. + // If the total number of received packets that fail authentication [...] + // exceeds the integrity limit for the selected AEAD, + // the endpoint MUST immediately close the connection [...]" + // https://www.rfc-editor.org/rfc/rfc9001#section-6.6-6 + tc := newTestConn(t, clientSide) + tc.handshake() + + var limit int64 + switch suite := tc.conn.keysAppData.r.suite; suite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: + limit = 1 << 52 + case tls.TLS_CHACHA20_POLY1305_SHA256: + limit = 1 << 36 + default: + t.Fatalf("conn.keysAppData.r.suite = %v, unknown suite", suite) + } + + dstConnID := tc.conn.connIDState.local[0].cid + if tc.conn.connIDState.local[0].seq == -1 { + // Only use the transient connection ID in Initial packets. + dstConnID = tc.conn.connIDState.local[1].cid + } + invalid := tc.encodeTestPacket(&testPacket{ + ptype: packetType1RTT, + num: 1000, + frames: []debugFrame{debugFramePing{}}, + version: quicVersion1, + dstConnID: dstConnID, + srcConnID: tc.peerConnID, + }, 0) + invalid[len(invalid)-1] ^= 1 + sendInvalid := func() { + t.Logf("<- conn under test receives invalid datagram") + tc.conn.sendMsg(&datagram{ + b: invalid, + }) + tc.wait() + } + + // Set the conn's auth failure count to just before the AEAD integrity limit. + tc.conn.keysAppData.authFailures = limit - 1 + + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.advanceToTimer() + tc.wantFrameType("auth failures less than limit: conn ACKs packet", + packetType1RTT, debugFrameAck{}) + + sendInvalid() + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.advanceToTimer() + tc.wantFrameType("auth failures at limit: conn closes", + packetType1RTT, debugFrameConnectionCloseTransport{ + code: errAEADLimitReached, + }) + + tc.writeFrames(packetType1RTT, debugFramePing{}) + tc.advance(1 * time.Second) + tc.wantIdle("auth failures at limit: conn does not process additional packets") +} diff --git a/internal/quic/version_test.go b/internal/quic/version_test.go new file mode 100644 index 000000000..cfb7ce4be --- /dev/null +++ b/internal/quic/version_test.go @@ -0,0 +1,110 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "bytes" + "context" + "crypto/tls" + "testing" +) + +func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { + config := &Config{ + TLSConfig: newTestTLSConfig(serverSide), + } + tl := newTestListener(t, config, nil) + + // Packet of unknown contents for some unrecognized QUIC version. + dstConnID := []byte{1, 2, 3, 4} + srcConnID := []byte{5, 6, 7, 8} + pkt := []byte{ + 0b1000_0000, + 0x00, 0x00, 0x00, 0x0f, + } + pkt = append(pkt, byte(len(dstConnID))) + pkt = append(pkt, dstConnID...) + pkt = append(pkt, byte(len(srcConnID))) + pkt = append(pkt, srcConnID...) + for len(pkt) < minimumClientInitialDatagramSize { + pkt = append(pkt, 0) + } + + tl.write(&datagram{ + b: pkt, + }) + gotPkt := tl.read() + if gotPkt == nil { + t.Fatalf("got no response; want Version Negotiaion") + } + if got := getPacketType(gotPkt); got != packetTypeVersionNegotiation { + t.Fatalf("got packet type %v; want Version Negotiaion", got) + } + gotDst, gotSrc, versions := parseVersionNegotiation(gotPkt) + if got, want := gotDst, srcConnID; !bytes.Equal(got, want) { + t.Errorf("got Destination Connection ID %x, want %x", got, want) + } + if got, want := gotSrc, dstConnID; !bytes.Equal(got, want) { + t.Errorf("got Source Connection ID %x, want %x", got, want) + } + if got, want := versions, []byte{0, 0, 0, 1}; !bytes.Equal(got, want) { + t.Errorf("got Supported Version %x, want %x", got, want) + } +} + +func TestVersionNegotiationClientAborts(t *testing.T) { + tc := newTestConn(t, clientSide) + p := tc.readPacket() // client Initial packet + tc.listener.write(&datagram{ + b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10), + }) + tc.wantIdle("connection does not send a CONNECTION_CLOSE") + if err := tc.conn.waitReady(canceledContext()); err != errVersionNegotiation { + t.Errorf("conn.waitReady() = %v, want errVersionNegotiation", err) + } +} + +func TestVersionNegotiationClientIgnoresAfterProcessingPacket(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + p := tc.readPacket() // client Initial packet + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.listener.write(&datagram{ + b: appendVersionNegotiation(nil, p.srcConnID, p.dstConnID, 10), + }) + if err := tc.conn.waitReady(canceledContext()); err != context.Canceled { + t.Errorf("conn.waitReady() = %v, want context.Canceled", err) + } + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrameType("conn ignores Version Negotiation and continues with handshake", + packetTypeHandshake, debugFrameCrypto{}) +} + +func TestVersionNegotiationClientIgnoresMismatchingSourceConnID(t *testing.T) { + tc := newTestConn(t, clientSide) + tc.ignoreFrame(frameTypeAck) + p := tc.readPacket() // client Initial packet + tc.listener.write(&datagram{ + b: appendVersionNegotiation(nil, p.srcConnID, []byte("mismatch"), 10), + }) + tc.writeFrames(packetTypeInitial, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial], + }) + tc.writeFrames(packetTypeHandshake, + debugFrameCrypto{ + data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake], + }) + tc.wantFrameType("conn ignores Version Negotiation and continues with handshake", + packetTypeHandshake, debugFrameCrypto{}) +}