Skip to content

Commit

Permalink
Rework cache key handling in caching client / generator (#5641)
Browse files Browse the repository at this point in the history
* Rework cache key handling in caching client / generator

- Expose the default cache key helper so that customization doesn't require re-implementing the whole thing.
- Make it easy to incorporate additional state into the cache key.
- Avoid serializing all of the values for the key into a new byte[], at least on .NET 8+. There, we can serialize directly into a stream that targets an IncrementalHash.
- Include Chat/EmbeddingGenerationOptions in the cache key by default.

* Update test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs

Co-authored-by: Shyam N <[email protected]>

---------

Co-authored-by: Shyam N <[email protected]>
  • Loading branch information
stephentoub and shyamnamboodiripad authored Nov 14, 2024
1 parent 73962c6 commit 430065c
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 55 deletions.
133 changes: 101 additions & 32 deletions src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,129 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using System.IO;
using System.Security.Cryptography;
using System.Text.Json;
using Microsoft.Shared.Diagnostics;
#if NET
using System.Threading;
using System.Threading.Tasks;
#endif

#pragma warning disable S109 // Magic numbers should not be used
#pragma warning disable SA1202 // Elements should be ordered by access
#pragma warning disable SA1502 // Element should not be on a single line

namespace Microsoft.Extensions.AI;

/// <summary>Provides internal helpers for implementing caching services.</summary>
internal static class CachingHelpers
{
/// <summary>Computes a default cache key for the specified parameters.</summary>
/// <typeparam name="TValue">Specifies the type of the data being used to compute the key.</typeparam>
/// <param name="value">The data with which to compute the key.</param>
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/>.</param>
/// <returns>A string that will be used as a cache key.</returns>
public static string GetCacheKey<TValue>(TValue value, JsonSerializerOptions serializerOptions)
=> GetCacheKey(value, false, serializerOptions);

/// <summary>Computes a default cache key for the specified parameters.</summary>
/// <typeparam name="TValue">Specifies the type of the data being used to compute the key.</typeparam>
/// <param name="value">The data with which to compute the key.</param>
/// <param name="flag">Another data item that causes the key to vary.</param>
/// <param name="values">The data with which to compute the key.</param>
/// <param name="serializerOptions">The <see cref="JsonSerializerOptions"/>.</param>
/// <returns>A string that will be used as a cache key.</returns>
public static string GetCacheKey<TValue>(TValue value, bool flag, JsonSerializerOptions serializerOptions)
public static string GetCacheKey(ReadOnlySpan<object?> values, JsonSerializerOptions serializerOptions)
{
_ = Throw.IfNull(value);
_ = Throw.IfNull(serializerOptions);
serializerOptions.MakeReadOnly();

var jsonKeyBytes = JsonSerializer.SerializeToUtf8Bytes(value, serializerOptions.GetTypeInfo(typeof(TValue)));

if (flag && jsonKeyBytes.Length > 0)
{
// Make an arbitrary change to the hash input based on the flag
// The alternative would be including the flag in "value" in the
// first place, but that's likely to require an extra allocation
// or the inclusion of another type in the JsonSerializerContext.
// This is a micro-optimization we can change at any time.
jsonKeyBytes[0] = (byte)(byte.MaxValue - jsonKeyBytes[0]);
}
Debug.Assert(serializerOptions is not null, "Expected serializer options to be non-null");
Debug.Assert(serializerOptions!.IsReadOnly, "Expected serializer options to already be read-only.");

// The complete JSON representation is excessively long for a cache key, duplicating much of the content
// from the value. So we use a hash of it as the default key, and we rely on collision resistance for security purposes.
// If a collision occurs, we'd serve the cached LLM response for a potentially unrelated prompt, leading to information
// disclosure. Use of SHA256 is an implementation detail and can be easily swapped in the future if needed, albeit
// invalidating any existing cache entries that may exist in whatever IDistributedCache was in use.
#if NET8_0_OR_GREATER

#if NET
IncrementalHashStream? stream = IncrementalHashStream.ThreadStaticInstance ?? new();
IncrementalHashStream.ThreadStaticInstance = null;

foreach (object? value in values)
{
JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
}

Span<byte> hashData = stackalloc byte[SHA256.HashSizeInBytes];
SHA256.HashData(jsonKeyBytes, hashData);
stream.GetHashAndReset(hashData);
IncrementalHashStream.ThreadStaticInstance = stream;

return Convert.ToHexString(hashData);
#else
MemoryStream stream = new();
foreach (object? value in values)
{
JsonSerializer.Serialize(stream, value, serializerOptions.GetTypeInfo(typeof(object)));
}

using var sha256 = SHA256.Create();
var hashData = sha256.ComputeHash(jsonKeyBytes);
return BitConverter.ToString(hashData).Replace("-", string.Empty);
stream.Position = 0;
var hashData = sha256.ComputeHash(stream.GetBuffer(), 0, (int)stream.Length);

var chars = new char[hashData.Length * 2];
int destPos = 0;
foreach (byte b in hashData)
{
int div = Math.DivRem(b, 16, out int rem);
chars[destPos++] = ToHexChar(div);
chars[destPos++] = ToHexChar(rem);

static char ToHexChar(int i) => (char)(i < 10 ? i + '0' : i - 10 + 'A');
}

Debug.Assert(destPos == chars.Length, "Expected to have filled the entire array.");

return new string(chars);
#endif
}

#if NET
/// <summary>Provides a stream that writes to an <see cref="IncrementalHash"/>.</summary>
private sealed class IncrementalHashStream : Stream
{
/// <summary>A per-thread instance of <see cref="IncrementalHashStream"/>.</summary>
/// <remarks>An instance stored must be in a reset state ready to be used by another consumer.</remarks>
[ThreadStatic]
public static IncrementalHashStream? ThreadStaticInstance;

/// <summary>Gets the current hash and resets.</summary>
public void GetHashAndReset(Span<byte> bytes) => _hash.GetHashAndReset(bytes);

/// <summary>The <see cref="IncrementalHash"/> used by this instance.</summary>
private readonly IncrementalHash _hash = IncrementalHash.CreateHash(HashAlgorithmName.SHA256);

protected override void Dispose(bool disposing)
{
_hash.Dispose();
base.Dispose(disposing);
}

public override void WriteByte(byte value) => Write(new ReadOnlySpan<byte>(in value));
public override void Write(byte[] buffer, int offset, int count) => _hash.AppendData(buffer, offset, count);
public override void Write(ReadOnlySpan<byte> buffer) => _hash.AppendData(buffer);

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
Write(buffer, offset, count);
return Task.CompletedTask;
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
Write(buffer.Span);
return ValueTask.CompletedTask;
}

public override void Flush() { }
public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;

public override bool CanWrite => true;
public override bool CanRead => false;
public override bool CanSeek => false;
public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
}
#endif
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;
Expand All @@ -19,8 +20,17 @@ namespace Microsoft.Extensions.AI;
/// </remarks>
public class DistributedCachingChatClient : CachingChatClient
{
/// <summary>A boxed <see langword="true"/> value.</summary>
private static readonly object _boxedTrue = true;

/// <summary>A boxed <see langword="false"/> value.</summary>
private static readonly object _boxedFalse = false;

/// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
private readonly IDistributedCache _storage;
private JsonSerializerOptions _jsonSerializerOptions;

/// <summary>The <see cref="JsonSerializerOptions"/> to use when serializing cache data.</summary>
private JsonSerializerOptions _jsonSerializerOptions = AIJsonUtilities.DefaultOptions;

/// <summary>Initializes a new instance of the <see cref="DistributedCachingChatClient"/> class.</summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>.</param>
Expand All @@ -29,7 +39,6 @@ public DistributedCachingChatClient(IChatClient innerClient, IDistributedCache s
: base(innerClient)
{
_storage = Throw.IfNull(storage);
_jsonSerializerOptions = AIJsonUtilities.DefaultOptions;
}

/// <summary>Gets or sets JSON serialization options to use when serializing cache data.</summary>
Expand Down Expand Up @@ -90,13 +99,16 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
}

/// <inheritdoc />
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options)
protected override string GetCacheKey(bool streaming, IList<ChatMessage> chatMessages, ChatOptions? options) =>
GetCacheKey([streaming ? _boxedTrue : _boxedFalse, chatMessages, options]);

/// <summary>Gets a cache key based on the supplied values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(bool, IList{ChatMessage}, ChatOptions?)"/>.</remarks>
protected string GetCacheKey(ReadOnlySpan<object?> values)
{
// While it might be desirable to include ChatOptions in the cache key, it's not always possible,
// since ChatOptions can contain types that are not guaranteed to be serializable or have a stable
// hashcode across multiple calls. So the default cache key is simply the JSON representation of
// the chat contents. Developers may subclass and override this to provide custom rules.
_jsonSerializerOptions.MakeReadOnly();
return CachingHelpers.GetCacheKey(chatMessages, streaming, _jsonSerializerOptions);
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
Expand Down Expand Up @@ -74,12 +75,16 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc
}

/// <inheritdoc />
protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options)
protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) =>
GetCacheKey([value, options]);

/// <summary>Gets a cache key based on the supplied values.</summary>
/// <param name="values">The values to inform the key.</param>
/// <returns>The computed key.</returns>
/// <remarks>This provides the default implementation for <see cref="GetCacheKey(TInput, EmbeddingGenerationOptions?)"/>.</remarks>
protected string GetCacheKey(ReadOnlySpan<object?> values)
{
// While it might be desirable to include options in the cache key, it's not always possible,
// since options can contain types that are not guaranteed to be serializable or have a stable
// hashcode across multiple calls. So the default cache key is simply the JSON representation of
// the value. Developers may subclass and override this to provide custom rules.
return CachingHelpers.GetCacheKey(value, _jsonSerializerOptions);
_jsonSerializerOptions.MakeReadOnly();
return CachingHelpers.GetCacheKey(values, _jsonSerializerOptions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ public async Task StreamingDoesNotCacheCanceledResultsAsync()
}

[Fact]
public async Task CacheKeyDoesNotVaryByChatOptionsAsync()
public async Task CacheKeyVariesByChatOptionsAsync()
{
// Arrange
var innerCallCount = 0;
Expand All @@ -546,20 +546,35 @@ public async Task CacheKeyDoesNotVaryByChatOptionsAsync()
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
};

// Act: Call with two different ChatOptions
// Act: Call with two different ChatOptions that have the same values
var result1 = await outer.CompleteAsync([], new ChatOptions
{
AdditionalProperties = new() { { "someKey", "value 1" } }
});
var result2 = await outer.CompleteAsync([], new ChatOptions
{
AdditionalProperties = new() { { "someKey", "value 2" } }
AdditionalProperties = new() { { "someKey", "value 1" } }
});

// Assert: Same result
Assert.Equal(1, innerCallCount);
Assert.Equal("value 1", result1.Message.Text);
Assert.Equal("value 1", result2.Message.Text);

// Act: Call with two different ChatOptions that have different values
var result3 = await outer.CompleteAsync([], new ChatOptions
{
AdditionalProperties = new() { { "someKey", "value 1" } }
});
var result4 = await outer.CompleteAsync([], new ChatOptions
{
AdditionalProperties = new() { { "someKey", "value 2" } }
});

// Assert: Different results
Assert.Equal(2, innerCallCount);
Assert.Equal("value 1", result3.Message.Text);
Assert.Equal("value 2", result4.Message.Text);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public async Task DoesNotCacheCanceledResultsAsync()
}

[Fact]
public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
public async Task CacheKeyVariesByEmbeddingOptionsAsync()
{
// Arrange
var innerCallCount = 0;
Expand All @@ -232,28 +232,43 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
{
innerCallCount++;
await Task.Yield();
return [_expectedEmbedding];
return [new(((string)options!.AdditionalProperties!["someKey"]!).Select(c => (float)c).ToArray())];
}
};
using var outer = new DistributedCachingEmbeddingGenerator<string, Embedding<float>>(innerGenerator, _storage)
{
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
};

// Act: Call with two different options
// Act: Call with two different EmbeddingGenerationOptions that have the same values
var result1 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result2 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
AdditionalProperties = new() { ["someKey"] = "value 1" }
});

// Assert: Same result
Assert.Equal(1, innerCallCount);
AssertEmbeddingsEqual(_expectedEmbedding, result1);
AssertEmbeddingsEqual(_expectedEmbedding, result2);
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result1);
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result2);

// Act: Call with two different EmbeddingGenerationOptions that have different values
var result3 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 1" }
});
var result4 = await outer.GenerateEmbeddingAsync("abc", new EmbeddingGenerationOptions
{
AdditionalProperties = new() { ["someKey"] = "value 2" }
});

// Assert: Different result
Assert.Equal(2, innerCallCount);
AssertEmbeddingsEqual(new("value 1".Select(c => (float)c).ToArray()), result3);
AssertEmbeddingsEqual(new("value 2".Select(c => (float)c).ToArray()), result4);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(Dictionary<string, string>))]
[JsonSerializable(typeof(DayOfWeek[]))]
[JsonSerializable(typeof(Guid))]
[JsonSerializable(typeof(ChatOptions))]
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;

0 comments on commit 430065c

Please sign in to comment.