diff --git a/pkg/cache/v3/linear.go b/pkg/cache/v3/linear.go index f7786ac4f9..b704d6bfc3 100644 --- a/pkg/cache/v3/linear.go +++ b/pkg/cache/v3/linear.go @@ -27,7 +27,7 @@ import ( "github.com/envoyproxy/go-control-plane/pkg/server/stream/v3" ) -type watches = map[chan Response]struct{} +type watches = map[ResponseWatch]struct{} // LinearCache supports collections of opaque resources. This cache has a // single collection indexed by resource names and manages resource versions @@ -113,7 +113,7 @@ func NewLinearCache(typeURL string, opts ...LinearCacheOption) *LinearCache { return out } -func (cache *LinearCache) respond(value chan Response, staleResources []string) { +func (cache *LinearCache) respond(watch ResponseWatch, staleResources []string) { var resources []types.ResourceWithTTL // TODO: optimize the resources slice creations across different clients if len(staleResources) == 0 { @@ -130,8 +130,8 @@ func (cache *LinearCache) respond(value chan Response, staleResources []string) } } } - value <- &RawResponse{ - Request: &Request{TypeUrl: cache.typeURL}, + watch.Response <- &RawResponse{ + Request: watch.Request, Resources: resources, Version: cache.getVersion(), Ctx: context.Background(), @@ -140,18 +140,18 @@ func (cache *LinearCache) respond(value chan Response, staleResources []string) func (cache *LinearCache) notifyAll(modified map[string]struct{}) { // de-duplicate watches that need to be responded - notifyList := make(map[chan Response][]string) + notifyList := make(map[ResponseWatch][]string) for name := range modified { for watch := range cache.watches[name] { notifyList[watch] = append(notifyList[watch], name) } - delete(cache.watches, name) } - for value, stale := range notifyList { - cache.respond(value, stale) + for watch, stale := range notifyList { + cache.removeWatch(watch) + cache.respond(watch, stale) } - for value := range cache.watchAll { - cache.respond(value, nil) + for watch := range cache.watchAll { + cache.respond(watch, nil) } cache.watchAll = make(watches) @@ -318,6 +318,8 @@ func (cache *LinearCache) CreateWatch(request *Request, _ stream.StreamState, va err = errors.New("mis-matched version prefix") } + watch := ResponseWatch{Request: request, Response: value} + cache.mu.Lock() defer cache.mu.Unlock() @@ -337,16 +339,16 @@ func (cache *LinearCache) CreateWatch(request *Request, _ stream.StreamState, va } } if stale { - cache.respond(value, staleResources) + cache.respond(watch, staleResources) return nil } // Create open watches since versions are up to date. if len(request.GetResourceNames()) == 0 { - cache.watchAll[value] = struct{}{} + cache.watchAll[watch] = struct{}{} return func() { cache.mu.Lock() defer cache.mu.Unlock() - delete(cache.watchAll, value) + delete(cache.watchAll, watch) } } for _, name := range request.GetResourceNames() { @@ -355,19 +357,24 @@ func (cache *LinearCache) CreateWatch(request *Request, _ stream.StreamState, va set = make(watches) cache.watches[name] = set } - set[value] = struct{}{} + set[watch] = struct{}{} } return func() { cache.mu.Lock() defer cache.mu.Unlock() - for _, name := range request.GetResourceNames() { - set, exists := cache.watches[name] - if exists { - delete(set, value) - } - if len(set) == 0 { - delete(cache.watches, name) - } + cache.removeWatch(watch) + } +} + +// Must be called under lock +func (cache *LinearCache) removeWatch(watch ResponseWatch) { + // Make sure we clean the watch for ALL resources it might be associated with, + // as the channel will no longer be listened to + for _, resource := range watch.Request.ResourceNames { + resourceWatches := cache.watches[resource] + delete(resourceWatches, watch) + if len(resourceWatches) == 0 { + delete(cache.watches, resource) } } } diff --git a/pkg/cache/v3/linear_test.go b/pkg/cache/v3/linear_test.go index 82478ca3aa..be6a413039 100644 --- a/pkg/cache/v3/linear_test.go +++ b/pkg/cache/v3/linear_test.go @@ -40,7 +40,14 @@ func testResource(s string) types.Resource { func verifyResponse(t *testing.T, ch <-chan Response, version string, num int) { t.Helper() - r := <-ch + var r Response + select { + case r = <-ch: + case <-time.After(1 * time.Second): + t.Error("failed to receive response after 1 second") + return + } + if r.GetRequest().GetTypeUrl() != testType { t.Errorf("unexpected empty request type URL: %q", r.GetRequest().GetTypeUrl()) } @@ -63,6 +70,9 @@ func verifyResponse(t *testing.T, ch <-chan Response, version string, num int) { if out.GetTypeUrl() != testType { t.Errorf("unexpected type URL: %q", out.GetTypeUrl()) } + if len(r.GetRequest().GetResourceNames()) != 0 && len(r.GetRequest().GetResourceNames()) < len(out.Resources) { + t.Errorf("received more resources (%d) than requested (%d)", len(r.GetRequest().GetResourceNames()), len(out.Resources)) + } } type resourceInfo struct { @@ -773,3 +783,81 @@ func TestLinearMixedWatches(t *testing.T) { verifyResponse(t, w, c.getVersion(), 0) verifyDeltaResponse(t, wd, nil, []string{"b"}) } + +func TestLinearSotwWatches(t *testing.T) { + t.Run("watches are properly removed from all objects", func(t *testing.T) { + cache := NewLinearCache(testType) + a := &endpoint.ClusterLoadAssignment{ClusterName: "a"} + err := cache.UpdateResource("a", a) + require.NoError(t, err) + b := &endpoint.ClusterLoadAssignment{ClusterName: "b"} + err = cache.UpdateResource("b", b) + require.NoError(t, err) + assert.Equal(t, 2, cache.NumResources()) + + // A watch tracks three different objects. + // An update is done for the three objects in a row + // If the watches are no properly purged, all three updates will send responses in the channel, but only the first one is tracked + // The buffer will therefore saturate and the third request will deadlock the entire cache as occurring under the mutex + sotwState := stream.NewStreamState(false, nil) + w := make(chan Response, 1) + _ = cache.CreateWatch(&Request{ResourceNames: []string{"a", "b", "c"}, TypeUrl: testType, VersionInfo: cache.getVersion()}, sotwState, w) + mustBlock(t, w) + checkVersionMapNotSet(t, cache) + + assert.Len(t, cache.watches["a"], 1) + assert.Len(t, cache.watches["b"], 1) + assert.Len(t, cache.watches["c"], 1) + + // Update a and c without touching b + a = &endpoint.ClusterLoadAssignment{ClusterName: "a", Endpoints: []*endpoint.LocalityLbEndpoints{ // resource update + {Priority: 25}, + }} + err = cache.UpdateResources(map[string]types.Resource{"a": a}, nil) + require.NoError(t, err) + verifyResponse(t, w, cache.getVersion(), 1) + checkVersionMapNotSet(t, cache) + + assert.Empty(t, cache.watches["a"]) + assert.Empty(t, cache.watches["b"]) + assert.Empty(t, cache.watches["c"]) + + // c no longer watched + w = make(chan Response, 1) + _ = cache.CreateWatch(&Request{ResourceNames: []string{"a", "b"}, TypeUrl: testType, VersionInfo: cache.getVersion()}, sotwState, w) + require.NoError(t, err) + mustBlock(t, w) + checkVersionMapNotSet(t, cache) + + b = &endpoint.ClusterLoadAssignment{ClusterName: "b", Endpoints: []*endpoint.LocalityLbEndpoints{ // resource update + {Priority: 15}, + }} + err = cache.UpdateResources(map[string]types.Resource{"b": b}, nil) + + assert.Empty(t, cache.watches["a"]) + assert.Empty(t, cache.watches["b"]) + assert.Empty(t, cache.watches["c"]) + + require.NoError(t, err) + verifyResponse(t, w, cache.getVersion(), 1) + checkVersionMapNotSet(t, cache) + + w = make(chan Response, 1) + _ = cache.CreateWatch(&Request{ResourceNames: []string{"c"}, TypeUrl: testType, VersionInfo: cache.getVersion()}, sotwState, w) + require.NoError(t, err) + mustBlock(t, w) + checkVersionMapNotSet(t, cache) + + c := &endpoint.ClusterLoadAssignment{ClusterName: "c", Endpoints: []*endpoint.LocalityLbEndpoints{ // resource update + {Priority: 15}, + }} + err = cache.UpdateResources(map[string]types.Resource{"c": c}, nil) + require.NoError(t, err) + verifyResponse(t, w, cache.getVersion(), 1) + checkVersionMapNotSet(t, cache) + + assert.Empty(t, cache.watches["a"]) + assert.Empty(t, cache.watches["b"]) + assert.Empty(t, cache.watches["c"]) + }) +}