diff --git a/pkg/server/sotw/v3/server.go b/pkg/server/sotw/v3/server.go index ef40e8ceaf..d3b167379a 100644 --- a/pkg/server/sotw/v3/server.go +++ b/pkg/server/sotw/v3/server.go @@ -18,10 +18,10 @@ package sotw import ( "context" "errors" + "reflect" "strconv" "sync/atomic" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -29,11 +29,11 @@ import ( discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" "github.com/envoyproxy/go-control-plane/pkg/cache/v3" "github.com/envoyproxy/go-control-plane/pkg/resource/v3" - streamv3 "github.com/envoyproxy/go-control-plane/pkg/server/stream/v3" + "github.com/envoyproxy/go-control-plane/pkg/server/stream/v3" ) type Server interface { - StreamHandler(stream Stream, typeURL string) error + StreamHandler(stream stream.Stream, typeURL string) error } type Callbacks interface { @@ -63,49 +63,6 @@ type server struct { streamCount int64 } -// Generic RPC stream. -type Stream interface { - grpc.ServerStream - - Send(*discovery.DiscoveryResponse) error - Recv() (*discovery.DiscoveryRequest, error) -} - -// watches for all xDS resource types -type watches struct { - endpoints chan cache.Response - clusters chan cache.Response - routes chan cache.Response - scopedRoutes chan cache.Response - listeners chan cache.Response - secrets chan cache.Response - runtimes chan cache.Response - extensionConfigs chan cache.Response - - endpointCancel func() - clusterCancel func() - routeCancel func() - scopedRouteCancel func() - listenerCancel func() - secretCancel func() - runtimeCancel func() - extensionConfigCancel func() - - endpointNonce string - clusterNonce string - routeNonce string - scopedRouteNonce string - listenerNonce string - secretNonce string - runtimeNonce string - extensionConfigNonce string - - // Opaque resources share a muxed channel. Nonces and watch cancellations are indexed by type URL. - responses chan cache.Response - cancellations map[string]func() - nonces map[string]string -} - // Discovery response that is sent over GRPC stream // We need to record what resource names are already sent to a client // So if the client requests a new name we can respond back @@ -115,52 +72,8 @@ type lastDiscoveryResponse struct { resources map[string]struct{} } -// Initialize all watches -func (values *watches) Init() { - // muxed channel needs a buffer to release go-routines populating it - values.responses = make(chan cache.Response, 5) - values.cancellations = make(map[string]func()) - values.nonces = make(map[string]string) -} - -// Token response value used to signal a watch failure in muxed watches. -var errorResponse = &cache.RawResponse{} - -// Cancel all watches -func (values *watches) Cancel() { - if values.endpointCancel != nil { - values.endpointCancel() - } - if values.clusterCancel != nil { - values.clusterCancel() - } - if values.routeCancel != nil { - values.routeCancel() - } - if values.scopedRouteCancel != nil { - values.scopedRouteCancel() - } - if values.listenerCancel != nil { - values.listenerCancel() - } - if values.secretCancel != nil { - values.secretCancel() - } - if values.runtimeCancel != nil { - values.runtimeCancel() - } - if values.extensionConfigCancel != nil { - values.extensionConfigCancel() - } - for _, cancel := range values.cancellations { - if cancel != nil { - cancel() - } - } -} - // process handles a bi-di stream request -func (s *server) process(stream Stream, reqCh <-chan *discovery.DiscoveryRequest, defaultTypeURL string) error { +func (s *server) process(str stream.Stream, reqCh <-chan *discovery.DiscoveryRequest, defaultTypeURL string) error { // increment stream count streamID := atomic.AddInt64(&s.streamCount, 1) @@ -168,14 +81,14 @@ func (s *server) process(stream Stream, reqCh <-chan *discovery.DiscoveryRequest // ignores stale nonces. nonce is only modified within send() function. var streamNonce int64 - streamState := streamv3.NewStreamState(false, map[string]string{}) + streamState := stream.NewStreamState(false, map[string]string{}) lastDiscoveryResponses := map[string]lastDiscoveryResponse{} // a collection of stack allocated watches per request type - var values watches - values.Init() + watches := newWatches() + defer func() { - values.Cancel() + watches.close() if s.callbacks != nil { s.callbacks.OnStreamClosed(streamID) } @@ -208,11 +121,11 @@ func (s *server) process(stream Stream, reqCh <-chan *discovery.DiscoveryRequest if s.callbacks != nil { s.callbacks.OnStreamResponse(resp.GetContext(), streamID, resp.GetRequest(), out) } - return out.Nonce, stream.Send(out) + return out.Nonce, str.Send(out) } if s.callbacks != nil { - if err := s.callbacks.OnStreamOpen(stream.Context(), streamID, defaultTypeURL); err != nil { + if err := s.callbacks.OnStreamOpen(str.Context(), streamID, defaultTypeURL); err != nil { return err } } @@ -220,109 +133,27 @@ func (s *server) process(stream Stream, reqCh <-chan *discovery.DiscoveryRequest // node may only be set on the first discovery request var node = &core.Node{} + // recompute dynamic channels for this stream + watches.recompute(s.ctx, reqCh) + for { - select { - case <-s.ctx.Done(): + // The list of select cases looks like this: + // 0: <- ctx.Done + // 1: <- reqCh + // 2...: per type watches + index, value, ok := reflect.Select(watches.cases) + switch index { + // ctx.Done() -> if we receive a value here we return as no further computation is needed + case 0: return nil - // config watcher can send the requested resources types in any order - case resp, more := <-values.endpoints: - if !more { - return status.Errorf(codes.Unavailable, "endpoints watch failed") - } - nonce, err := send(resp) - if err != nil { - return err - } - values.endpointNonce = nonce - - case resp, more := <-values.clusters: - if !more { - return status.Errorf(codes.Unavailable, "clusters watch failed") - } - nonce, err := send(resp) - if err != nil { - return err - } - values.clusterNonce = nonce - - case resp, more := <-values.routes: - if !more { - return status.Errorf(codes.Unavailable, "routes watch failed") - } - nonce, err := send(resp) - if err != nil { - return err - } - values.routeNonce = nonce - - case resp, more := <-values.scopedRoutes: - if !more { - return status.Errorf(codes.Unavailable, "scopedRoutes watch failed") - } - nonce, err := send(resp) - if err != nil { - return err - } - values.scopedRouteNonce = nonce - - case resp, more := <-values.listeners: - if !more { - return status.Errorf(codes.Unavailable, "listeners watch failed") - } - nonce, err := send(resp) - if err != nil { - return err - } - values.listenerNonce = nonce - - case resp, more := <-values.secrets: - if !more { - return status.Errorf(codes.Unavailable, "secrets watch failed") - } - nonce, err := send(resp) - if err != nil { - return err - } - values.secretNonce = nonce - - case resp, more := <-values.runtimes: - if !more { - return status.Errorf(codes.Unavailable, "runtimes watch failed") - } - nonce, err := send(resp) - if err != nil { - return err - } - values.runtimeNonce = nonce - - case resp, more := <-values.extensionConfigs: - if !more { - return status.Errorf(codes.Unavailable, "extensionConfigs watch failed") - } - nonce, err := send(resp) - if err != nil { - return err - } - values.extensionConfigNonce = nonce - - case resp, more := <-values.responses: - if more { - if resp == errorResponse { - return status.Errorf(codes.Unavailable, "resource watch failed") - } - typeURL := resp.GetRequest().TypeUrl - nonce, err := send(resp) - if err != nil { - return err - } - values.nonces[typeURL] = nonce - } - - case req, more := <-reqCh: + // Case 1 handles any request inbound on the stream and handles all initialization as needed + case 1: // input stream ended or errored out - if !more { + if !ok { return nil } + + req := value.Interface().(*discovery.DiscoveryRequest) if req == nil { return status.Errorf(codes.Unavailable, "empty request") } @@ -359,88 +190,50 @@ func (s *server) process(stream Stream, reqCh <-chan *discovery.DiscoveryRequest } } - // cancel existing watches to (re-)request a newer version - switch { - case req.TypeUrl == resource.EndpointType: - if values.endpointNonce == "" || values.endpointNonce == nonce { - if values.endpointCancel != nil { - values.endpointCancel() - } - values.endpoints = make(chan cache.Response, 1) - values.endpointCancel = s.cache.CreateWatch(req, streamState, values.endpoints) - } - case req.TypeUrl == resource.ClusterType: - if values.clusterNonce == "" || values.clusterNonce == nonce { - if values.clusterCancel != nil { - values.clusterCancel() - } - values.clusters = make(chan cache.Response, 1) - values.clusterCancel = s.cache.CreateWatch(req, streamState, values.clusters) - } - case req.TypeUrl == resource.RouteType: - if values.routeNonce == "" || values.routeNonce == nonce { - if values.routeCancel != nil { - values.routeCancel() - } - values.routes = make(chan cache.Response, 1) - values.routeCancel = s.cache.CreateWatch(req, streamState, values.routes) - } - case req.TypeUrl == resource.ScopedRouteType: - if values.scopedRouteNonce == "" || values.scopedRouteNonce == nonce { - if values.scopedRouteCancel != nil { - values.scopedRouteCancel() - } - values.scopedRoutes = make(chan cache.Response, 1) - values.scopedRouteCancel = s.cache.CreateWatch(req, streamState, values.scopedRoutes) - } - case req.TypeUrl == resource.ListenerType: - if values.listenerNonce == "" || values.listenerNonce == nonce { - if values.listenerCancel != nil { - values.listenerCancel() - } - values.listeners = make(chan cache.Response, 1) - values.listenerCancel = s.cache.CreateWatch(req, streamState, values.listeners) - } - case req.TypeUrl == resource.SecretType: - if values.secretNonce == "" || values.secretNonce == nonce { - if values.secretCancel != nil { - values.secretCancel() - } - values.secrets = make(chan cache.Response, 1) - values.secretCancel = s.cache.CreateWatch(req, streamState, values.secrets) - } - case req.TypeUrl == resource.RuntimeType: - if values.runtimeNonce == "" || values.runtimeNonce == nonce { - if values.runtimeCancel != nil { - values.runtimeCancel() - } - values.runtimes = make(chan cache.Response, 1) - values.runtimeCancel = s.cache.CreateWatch(req, streamState, values.runtimes) - } - case req.TypeUrl == resource.ExtensionConfigType: - if values.extensionConfigNonce == "" || values.extensionConfigNonce == nonce { - if values.extensionConfigCancel != nil { - values.extensionConfigCancel() - } - values.extensionConfigs = make(chan cache.Response, 1) - values.extensionConfigCancel = s.cache.CreateWatch(req, streamState, values.extensionConfigs) - } - default: - typeURL := req.TypeUrl - responseNonce, seen := values.nonces[typeURL] - if !seen || responseNonce == nonce { - if cancel, seen := values.cancellations[typeURL]; seen && cancel != nil { - cancel() - } - values.cancellations[typeURL] = s.cache.CreateWatch(req, streamState, values.responses) + typeURL := req.GetTypeUrl() + responder := make(chan cache.Response, 1) + if w, ok := watches.responders[typeURL]; ok { + // We've found a pre-existing watch, lets check and update if needed. + // If these requirements aren't satisfied, leave an open watch. + if w.nonce == "" || w.nonce == nonce { + w.close() + + watches.addWatch(typeURL, &watch{ + cancel: s.cache.CreateWatch(req, streamState, responder), + response: responder, + }) } + } else { + // No pre-existing watch exists, let's create one. + // We need to precompute the watches first then open a watch in the cache. + watches.addWatch(typeURL, &watch{ + cancel: s.cache.CreateWatch(req, streamState, responder), + response: responder, + }) + } + + // Recompute the dynamic select cases for this stream. + watches.recompute(s.ctx, reqCh) + default: + // Channel n -> these are the dynamic list of responders that correspond to the stream request typeURL + if !ok { + // Receiver channel was closed. TODO(jpeach): probably cancel the watch or something? + return status.Errorf(codes.Unavailable, "resource watch %d -> failed", index) } + + res := value.Interface().(cache.Response) + nonce, err := send(res) + if err != nil { + return err + } + + watches.responders[res.GetRequest().TypeUrl].nonce = nonce } } } // StreamHandler converts a blocking read call to channels and initiates stream processing -func (s *server) StreamHandler(stream Stream, typeURL string) error { +func (s *server) StreamHandler(stream stream.Stream, typeURL string) error { // a channel for receiving incoming requests reqCh := make(chan *discovery.DiscoveryRequest) go func() { diff --git a/pkg/server/sotw/v3/watches.go b/pkg/server/sotw/v3/watches.go new file mode 100644 index 0000000000..45670d6a91 --- /dev/null +++ b/pkg/server/sotw/v3/watches.go @@ -0,0 +1,75 @@ +package sotw + +import ( + "context" + "reflect" + + discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/envoyproxy/go-control-plane/pkg/cache/types" + "github.com/envoyproxy/go-control-plane/pkg/cache/v3" +) + +// watches for all xDS resource types +type watches struct { + responders map[string]*watch + + // cases is a dynamic select case for the watched channels. + cases []reflect.SelectCase +} + +// newWatches creates and initializes watches. +func newWatches() watches { + return watches{ + responders: make(map[string]*watch, int(types.UnknownType)), + cases: make([]reflect.SelectCase, 0), + } +} + +// addWatch creates a new watch entry in the watches map. +// Watches are sorted by typeURL. +func (w *watches) addWatch(typeURL string, watch *watch) { + w.responders[typeURL] = watch +} + +// close all open watches +func (w *watches) close() { + for _, watch := range w.responders { + watch.close() + } +} + +// recomputeWatches rebuilds the known list of dynamic channels if needed +func (w *watches) recompute(ctx context.Context, req <-chan *discovery.DiscoveryRequest) { + w.cases = w.cases[:0] // Clear the existing cases while retaining capacity. + + w.cases = append(w.cases, + reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ctx.Done()), + }, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(req), + }, + ) + + for _, watch := range w.responders { + w.cases = append(w.cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(watch.response), + }) + } +} + +// watch contains the necessary modifiables for receiving resource responses +type watch struct { + cancel func() + nonce string + response chan cache.Response +} + +// close cancels an open watch +func (w *watch) close() { + if w.cancel != nil { + w.cancel() + } +} diff --git a/pkg/server/v3/server.go b/pkg/server/v3/server.go index 80cb44e7f4..67c4fbc5ab 100644 --- a/pkg/server/v3/server.go +++ b/pkg/server/v3/server.go @@ -178,7 +178,7 @@ type server struct { delta delta.Server } -func (s *server) StreamHandler(stream sotw.Stream, typeURL string) error { +func (s *server) StreamHandler(stream stream.Stream, typeURL string) error { return s.sotw.StreamHandler(stream, typeURL) } diff --git a/pkg/server/v3/server_test.go b/pkg/server/v3/server_test.go index 5186cdac56..af349e0c8a 100644 --- a/pkg/server/v3/server_test.go +++ b/pkg/server/v3/server_test.go @@ -96,26 +96,18 @@ func (stream *mockStream) Context() context.Context { func (stream *mockStream) Send(resp *discovery.DiscoveryResponse) error { // check that nonce is monotonically incrementing stream.nonce = stream.nonce + 1 - if resp.Nonce != fmt.Sprintf("%d", stream.nonce) { - stream.t.Errorf("Nonce => got %q, want %d", resp.Nonce, stream.nonce) - } + assert.Equal(stream.t, resp.Nonce, fmt.Sprintf("%d", stream.nonce)) // check that version is set - if resp.VersionInfo == "" { - stream.t.Error("VersionInfo => got none, want non-empty") - } + assert.NotEmpty(stream.t, resp.VersionInfo) // check resources are non-empty - if len(resp.Resources) == 0 { - stream.t.Error("Resources => got none, want non-empty") - } + assert.NotEmpty(stream.t, resp.Resources) // check that type URL matches in resources - if resp.TypeUrl == "" { - stream.t.Error("TypeUrl => got none, want non-empty") - } + assert.NotEmpty(stream.t, resp.TypeUrl) + for _, res := range resp.Resources { - if res.TypeUrl != resp.TypeUrl { - stream.t.Errorf("TypeUrl => got %q, want %q", res.TypeUrl, resp.TypeUrl) - } + assert.Equal(stream.t, res.TypeUrl, resp.TypeUrl) } + stream.sent <- resp if stream.sendError { return errors.New("send error") @@ -387,99 +379,93 @@ func TestFetch(t *testing.T) { } s := server.NewServer(context.Background(), config, cb) - if out, err := s.FetchEndpoints(context.Background(), &discovery.DiscoveryRequest{Node: node}); out == nil || err != nil { - t.Errorf("unexpected empty or error for endpoints: %v", err) - } - if out, err := s.FetchClusters(context.Background(), &discovery.DiscoveryRequest{Node: node}); out == nil || err != nil { - t.Errorf("unexpected empty or error for clusters: %v", err) - } - if out, err := s.FetchRoutes(context.Background(), &discovery.DiscoveryRequest{Node: node}); out == nil || err != nil { - t.Errorf("unexpected empty or error for routes: %v", err) - } - if out, err := s.FetchScopedRoutes(context.Background(), &discovery.DiscoveryRequest{Node: node}); out == nil || err != nil { - t.Errorf("unexpected empty or error for scopedRoutes: %v", err) - } - if out, err := s.FetchListeners(context.Background(), &discovery.DiscoveryRequest{Node: node}); out == nil || err != nil { - t.Errorf("unexpected empty or error for listeners: %v", err) - } - if out, err := s.FetchSecrets(context.Background(), &discovery.DiscoveryRequest{Node: node}); out == nil || err != nil { - t.Errorf("unexpected empty or error for secrets: %v", err) - } - if out, err := s.FetchRuntime(context.Background(), &discovery.DiscoveryRequest{Node: node}); out == nil || err != nil { - t.Errorf("unexpected empty or error for runtime: %v", err) - } - if out, err := s.FetchExtensionConfigs(context.Background(), &discovery.DiscoveryRequest{Node: node}); out == nil || err != nil { - t.Errorf("unexpected empty or error for extensionConfigs: %v", err) - } + out, err := s.FetchEndpoints(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.NotNil(t, out) + assert.NoError(t, err) + + out, err = s.FetchClusters(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.NotNil(t, out) + assert.NoError(t, err) + + out, err = s.FetchRoutes(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.NotNil(t, out) + assert.NoError(t, err) + + out, err = s.FetchListeners(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.NotNil(t, out) + assert.NoError(t, err) + + out, err = s.FetchSecrets(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.NotNil(t, out) + assert.NoError(t, err) + + out, err = s.FetchRuntime(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.NotNil(t, out) + assert.NoError(t, err) // try again and expect empty results - if out, err := s.FetchEndpoints(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil { - t.Errorf("expected empty or error for endpoints: %v", err) - } - if out, err := s.FetchClusters(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil { - t.Errorf("expected empty or error for clusters: %v", err) - } - if out, err := s.FetchRoutes(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil { - t.Errorf("expected empty or error for routes: %v", err) - } - if out, err := s.FetchScopedRoutes(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil { - t.Errorf("expected empty or error for routes: %v", err) - } - if out, err := s.FetchListeners(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil { - t.Errorf("expected empty or error for listeners: %v", err) - } + out, err = s.FetchEndpoints(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchClusters(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchRoutes(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchListeners(context.Background(), &discovery.DiscoveryRequest{Node: node}) + assert.Nil(t, out) + assert.Error(t, err) // try empty requests: not valid in a real gRPC server - if out, err := s.FetchEndpoints(context.Background(), nil); out != nil { - t.Errorf("expected empty on empty request: %v", err) - } - if out, err := s.FetchClusters(context.Background(), nil); out != nil { - t.Errorf("expected empty on empty request: %v", err) - } - if out, err := s.FetchRoutes(context.Background(), nil); out != nil { - t.Errorf("expected empty on empty request: %v", err) - } - if out, err := s.FetchScopedRoutes(context.Background(), nil); out != nil { - t.Errorf("expected empty on empty request: %v", err) - } - if out, err := s.FetchListeners(context.Background(), nil); out != nil { - t.Errorf("expected empty on empty request: %v", err) - } - if out, err := s.FetchSecrets(context.Background(), nil); out != nil { - t.Errorf("expected empty on empty request: %v", err) - } - if out, err := s.FetchRuntime(context.Background(), nil); out != nil { - t.Errorf("expected empty on empty request: %v", err) - } - if out, err := s.FetchExtensionConfigs(context.Background(), nil); out != nil { - t.Errorf("expected empty on empty request: %v", err) - } + out, err = s.FetchEndpoints(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchClusters(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchRoutes(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchListeners(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchSecrets(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchRuntime(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) // send error from callback callbackError = true - if out, err := s.FetchEndpoints(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil || err == nil { - t.Errorf("expected empty or error due to callback error") - } - if out, err := s.FetchClusters(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil || err == nil { - t.Errorf("expected empty or error due to callback error") - } - if out, err := s.FetchRoutes(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil || err == nil { - t.Errorf("expected empty or error due to callback error") - } - if out, err := s.FetchScopedRoutes(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil || err == nil { - t.Errorf("expected empty or error due to callback error") - } - if out, err := s.FetchListeners(context.Background(), &discovery.DiscoveryRequest{Node: node}); out != nil || err == nil { - t.Errorf("expected empty or error due to callback error") - } + out, err = s.FetchEndpoints(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchClusters(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchRoutes(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) + + out, err = s.FetchListeners(context.Background(), nil) + assert.Nil(t, out) + assert.Error(t, err) // verify fetch callbacks - if want := 13; requestCount != want { - t.Errorf("unexpected number of fetch requests: got %d, want %d", requestCount, want) - } - if want := 8; responseCount != want { - t.Errorf("unexpected number of fetch responses: got %d, want %d", responseCount, want) - } + assert.Equal(t, requestCount, 10) + assert.Equal(t, responseCount, 6) } func TestSendError(t *testing.T) { @@ -498,9 +484,8 @@ func TestSendError(t *testing.T) { } // check that response fails since send returns error - if err := s.StreamAggregatedResources(resp); err == nil { - t.Error("Stream() => got no error, want send error") - } + err := s.StreamAggregatedResources(resp) + assert.Error(t, err) close(resp.recv) }) @@ -521,13 +506,10 @@ func TestStaleNonce(t *testing.T) { } stop := make(chan struct{}) go func() { - if err := s.StreamAggregatedResources(resp); err != nil { - t.Errorf("StreamAggregatedResources() => got %v, want no error", err) - } + err := s.StreamAggregatedResources(resp) + assert.NoError(t, err) // should be two watches called - if want := map[string]int{typ: 2}; !reflect.DeepEqual(want, config.counts) { - t.Errorf("watch counts => got %v, want %v", config.counts, want) - } + assert.False(t, !reflect.DeepEqual(map[string]int{typ: 2}, config.counts)) close(stop) }() select { @@ -575,6 +557,10 @@ func TestAggregatedHandlers(t *testing.T) { TypeUrl: rsrc.RouteType, ResourceNames: []string{routeName}, } + resp.recv <- &discovery.DiscoveryRequest{ + TypeUrl: rsrc.ExtensionConfigType, + ResourceNames: []string{extensionConfigName}, + } resp.recv <- &discovery.DiscoveryRequest{ TypeUrl: rsrc.ScopedRouteType, ResourceNames: []string{scopedRouteName}, @@ -582,9 +568,8 @@ func TestAggregatedHandlers(t *testing.T) { s := server.NewServer(context.Background(), config, server.CallbackFuncs{}) go func() { - if err := s.StreamAggregatedResources(resp); err != nil { - t.Errorf("StreamAggregatedResources() => got %v, want no error", err) - } + err := s.StreamAggregatedResources(resp) + assert.NoError(t, err) }() count := 0 @@ -592,17 +577,16 @@ func TestAggregatedHandlers(t *testing.T) { select { case <-resp.sent: count++ - if count >= 5 { + if count >= 6 { close(resp.recv) - if want := map[string]int{ - rsrc.EndpointType: 1, - rsrc.ClusterType: 1, - rsrc.RouteType: 1, - rsrc.ScopedRouteType: 1, - rsrc.ListenerType: 1, - }; !reflect.DeepEqual(want, config.counts) { - t.Errorf("watch counts => got %v, want %v", config.counts, want) - } + assert.False(t, !reflect.DeepEqual(map[string]int{ + rsrc.EndpointType: 1, + rsrc.ClusterType: 1, + rsrc.RouteType: 1, + rsrc.ListenerType: 1, + rsrc.ExtensionConfigType: 1, + rsrc.ScopedRouteType: 1, + }, config.counts)) // got all messages return @@ -618,9 +602,8 @@ func TestAggregateRequestType(t *testing.T) { s := server.NewServer(context.Background(), config, server.CallbackFuncs{}) resp := makeMockStream(t) resp.recv <- &discovery.DiscoveryRequest{Node: node} - if err := s.StreamAggregatedResources(resp); err == nil { - t.Error("StreamAggregatedResources() => got nil, want an error") - } + err := s.StreamAggregatedResources(resp) + assert.Error(t, err) } func TestCancellations(t *testing.T) { @@ -634,12 +617,9 @@ func TestCancellations(t *testing.T) { } close(resp.recv) s := server.NewServer(context.Background(), config, server.CallbackFuncs{}) - if err := s.StreamAggregatedResources(resp); err != nil { - t.Errorf("StreamAggregatedResources() => got %v, want no error", err) - } - if config.watches != 0 { - t.Errorf("Expect all watches canceled, got %q", config.watches) - } + err := s.StreamAggregatedResources(resp) + assert.NoError(t, err) + assert.Equal(t, config.watches, 0) } func TestOpaqueRequestsChannelMuxing(t *testing.T) { @@ -655,12 +635,9 @@ func TestOpaqueRequestsChannelMuxing(t *testing.T) { } close(resp.recv) s := server.NewServer(context.Background(), config, server.CallbackFuncs{}) - if err := s.StreamAggregatedResources(resp); err != nil { - t.Errorf("StreamAggregatedResources() => got %v, want no error", err) - } - if config.watches != 0 { - t.Errorf("Expect all watches canceled, got %q", config.watches) - } + err := s.StreamAggregatedResources(resp) + assert.NoError(t, err) + assert.Equal(t, config.watches, 0) } func TestCallbackError(t *testing.T) { @@ -683,9 +660,8 @@ func TestCallbackError(t *testing.T) { } // check that response fails since stream open returns error - if err := s.StreamAggregatedResources(resp); err == nil { - t.Error("Stream() => got no error, want error") - } + err := s.StreamAggregatedResources(resp) + assert.Error(t, err) close(resp.recv) })