From 8778d5f3759986b74d0a3cc2070db3c8efeb5116 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 18 Oct 2024 09:37:47 -0400 Subject: [PATCH 01/62] Add PostgresVectorStore Memory connector. Work in progress, some methods are not implemented yet. --- .../IPostgresDbClient.cs | 5 + ...PostgresVectorStoreCollectionSqlBuilder.cs | 68 ++++ .../IPostgresVectorStoreDbClient.cs | 117 ++++++ ...tgresVectorStoreRecordCollectionFactory.cs | 23 ++ .../PostgresConstants.cs | 80 ++++ .../PostgresDbClient.cs | 5 + .../PostgresGenericDataModelMapper.cs | 141 +++++++ .../PostgresSqlCommandInfo.cs | 51 +++ .../PostgresVectorStore.cs | 91 +++++ ...tgresVectorStoreCollectionCreateMapping.cs | 119 ++++++ ...PostgresVectorStoreCollectionSqlBuilder.cs | 257 +++++++++++++ .../PostgresVectorStoreDbClient.cs | 149 ++++++++ .../PostgresVectorStoreOptions.cs | 24 ++ .../PostgresVectorStoreRecordCollection.cs | 190 +++++++++ ...tgresVectorStoreRecordCollectionOptions.cs | 35 ++ .../PostgresVectorStoreRecordMapper.cs | 102 +++++ ...ostgresVectorStoreRecordPropertyMapping.cs | 78 ++++ .../Memory/Postgres/PostgresHotel.cs | 47 +++ ...resVectorStoreCollectionSqlBuilderTests.cs | 148 +++++++ ...ostgresVectorStoreRecordCollectionTests.cs | 109 ++++++ .../Postgres/PostgresVectorStoreTests.cs | 107 ++++++ .../Memory/Postgres/PostgresHotel.cs | 51 +++ .../Postgres/PostgresMemoryStoreTests.cs | 6 +- .../PostgresVectorStoreCollectionFixture.cs | 10 + .../Postgres/PostgresVectorStoreFixture.cs | 361 ++++++++++++++++++ ...ostgresVectorStoreRecordCollectionTests.cs | 95 +++++ .../Postgres/PostgresVectorStoreTests.cs | 29 ++ 27 files changed, 2495 insertions(+), 3 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionCreateMapping.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs create mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs index 70747990e2fd..2056260eb292 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresDbClient.cs @@ -11,6 +11,11 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// Interface for client managing postgres database operations. /// +/// +/// This interface is used with the PostgresMemoryStore, which is being deprecated. +/// Use the interface with the PostgresVectorStore +/// and related classes instead. +/// public interface IPostgresDbClient { /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs new file mode 100644 index 000000000000..ed0a763b7b25 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Interface for constructing SQL commands for Postgres vector store collections. +/// +public interface IPostgresVectorStoreCollectionSqlBuilder +{ + /// + /// Builds a SQL command to check if a table exists in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The built SQL command. + /// + /// The command must return a single row with a single column named "table_name" if the table exists. + /// + PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName); + + /// + /// Builds a SQL command to fetch all tables in the Postgres vector store. + /// + /// The schema of the tables. + PostgresSqlCommandInfo BuildGetTablesCommand(string schema); + + /// + /// Builds a SQL command to create a table in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The record definition of the table. + /// Specifies whether to include IF NOT EXISTS in the command. + /// The built SQL command info. + PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, bool ifNotExists = true); + + /// + /// Builds a SQL command to drop a table in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The built SQL command info. + PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableName); + + /// + /// Builds a SQL command to upsert a record in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The row to upsert. + /// The key column of the table. + /// The built SQL command info. + PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, Dictionary row, string keyColumn); + + /// + /// Builds a SQL command to get a record from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The record definition of the table. + /// The key of the record to get. + /// Specifies whether to include vectors in the record. + /// The built SQL command info. + PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, TKey key, bool includeVectors = false) where TKey : notnull; +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs new file mode 100644 index 000000000000..4a1f7ff4e13b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Internal interface for client managing postgres database operations. +/// +public interface IPostgresVectorStoreDbClient +{ + /// + /// Check if a table exists. + /// + /// The name assigned to a table of entries. + /// The to monitor for cancellation requests. The default is . + /// + Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default); + + /// + /// Get all tables. + /// + /// The to monitor for cancellation requests. The default is . + /// A group of tables. + IAsyncEnumerable GetTablesAsync(CancellationToken cancellationToken = default); + /// + /// Create a table. + /// + /// The name assigned to a table of entries. + /// The record definition of the table. + /// Specifies whether to include IF NOT EXISTS in the command. + /// The to monitor for cancellation requests. The default is . + /// + Task CreateTableAsync(string tableName, VectorStoreRecordDefinition recordDefinition, bool ifNotExists = true, CancellationToken cancellationToken = default); + + /// + /// Drop a table. + /// + /// The name assigned to a table of entries. + /// The to monitor for cancellation requests. The default is . + Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default); + + /// + /// Upsert entry into a table. + /// + /// The name assigned to a table of entries. + /// The row to upsert into the table. + /// The key column of the table. + /// The to monitor for cancellation requests. The default is . + /// + Task UpsertAsync(string tableName, Dictionary row, string keyColumn, CancellationToken cancellationToken = default); + + /// + /// Get a entry by its key. + /// + /// The name assigned to a table of entries. + /// The key of the entry to get. + /// The record definition of the table. + /// If true, the vectors will be included in the entry. + /// The to monitor for cancellation requests. The default is . + /// The row if the key is found, otherwise null. + Task?> GetAsync(string tableName, TKey key, VectorStoreRecordDefinition recordDefinition, bool includeVectors = false, CancellationToken cancellationToken = default) + where TKey : notnull; + + // /// + // /// Gets the nearest matches to the . + // /// + // /// The name assigned to a table of entries. + // /// The to compare the table's embeddings with. + // /// The maximum number of similarity results to return. + // /// The minimum relevance threshold for returned results. + // /// If true, the embeddings will be returned in the entries. + // /// The to monitor for cancellation requests. The default is . + // /// An asynchronous stream of objects that the nearest matches to the . + // IAsyncEnumerable<(PostgresMemoryEntry, double)> GetNearestMatchesAsync(string tableName, Vector embedding, int limit, double minRelevanceScore = 0, bool withEmbeddings = false, CancellationToken cancellationToken = default); + + // /// + // /// Read a entry by its key. + // /// + // /// The name assigned to a table of entries. + // /// The key of the entry to read. + // /// If true, the embeddings will be returned in the entry. + // /// The to monitor for cancellation requests. The default is . + // /// + // Task ReadAsync(string tableName, string key, bool withEmbeddings = false, CancellationToken cancellationToken = default); + + // /// + // /// Read multiple entries by their keys. + // /// + // /// The name assigned to a table of entries. + // /// The keys of the entries to read. + // /// If true, the embeddings will be returned in the entries. + // /// The to monitor for cancellation requests. The default is . + // /// An asynchronous stream of objects that match the given keys. + // IAsyncEnumerable ReadBatchAsync(string tableName, IEnumerable keys, bool withEmbeddings = false, CancellationToken cancellationToken = default); + + // /// + // /// Delete a entry by its key. + // /// + // /// The name assigned to a table of entries. + // /// The key of the entry to delete. + // /// The to monitor for cancellation requests. The default is . + // /// + // Task DeleteAsync(string tableName, string key, CancellationToken cancellationToken = default); + + // /// + // /// Delete multiple entries by their key. + // /// + // /// The name assigned to a table of entries. + // /// The keys of the entries to delete. + // /// The to monitor for cancellation requests. The default is . + // /// + // Task DeleteBatchAsync(string tableName, IEnumerable keys, CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs new file mode 100644 index 000000000000..98b1a344c194 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Interface for constructing Postgres instances when using to retrieve these. +/// +public interface IPostgresVectorStoreRecordCollectionFactory +{ + /// + /// Constructs a new instance of the . + /// + /// The data type of the record key. + /// The data model to use for adding, updating and retrieving data from storage. + /// The Postgres client. + /// The name of the collection to connect to. + /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. + /// The new instance of . + IVectorStoreRecordCollection CreateVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + where TKey : notnull; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs new file mode 100644 index 000000000000..b5f1939291ac --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal static class PostgresConstants +{ + /// A of types that a key on the provided model may have. + public static readonly HashSet SupportedKeyTypes = + [ + typeof(string), + typeof(int), + typeof(long), + typeof(ulong), + typeof(short), + typeof(ushort), + ]; + + /// A of types that data properties on the provided model may have. + public static readonly HashSet SupportedDataTypes = + [ + typeof(bool), + typeof(bool?), + typeof(short), + typeof(short?), + typeof(ushort), + typeof(ushort?), + typeof(int), + typeof(int?), + typeof(uint), + typeof(uint?), + typeof(long), + typeof(long?), + typeof(ulong), + typeof(ulong?), + typeof(float), + typeof(float?), + typeof(double), + typeof(double?), + typeof(decimal), + typeof(decimal?), + typeof(string), + typeof(DateTimeOffset), + typeof(DateTimeOffset?), + typeof(byte[]), + typeof(List), + typeof(List), + typeof(List), + typeof(List), + typeof(List), + typeof(List), + typeof(List), + typeof(List), + typeof(List), + typeof(List), + typeof(List), + typeof(List), + typeof(bool[]), + typeof(short[]), + typeof(ushort[]), + typeof(int[]), + typeof(uint[]), + typeof(long[]), + typeof(ulong[]), + typeof(float[]), + typeof(double[]), + typeof(decimal[]), + typeof(string[]), + typeof(DateTimeOffset[]), + ]; + + /// A of types that vector properties on the provided model may have. + public static readonly HashSet SupportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?) + ]; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs index 1dc1ffef3c1d..88741c236531 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresDbClient.cs @@ -15,6 +15,11 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// An implementation of a client for Postgres. This class is used to managing postgres database operations. /// +/// +/// This class is used with the PostgresMemoryStore, which is being deprecated. +/// Use the class with the PostgresVectorStore +/// and related classes instead. +/// [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] public class PostgresDbClient : IPostgresDbClient { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs new file mode 100644 index 000000000000..c13b9a85783b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal sealed class PostgresGenericDataModelMapper : IVectorStoreRecordMapper, Dictionary>, + IVectorStoreRecordMapper, Dictionary> +{ + /// with helpers for reading vector store model properties and their attributes. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// + /// Initializes a new instance of the class. + /// + /// A that defines the schema of the data in the database. + public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyReader) + { + Verify.NotNull(propertyReader); + + this._propertyReader = propertyReader; + + // Validate property types. + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, supportEnumerable: false); + this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); + } + public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) + { + return this.InternalMapFromDataToStorageModel(dataModel); + } + + VectorStoreGenericDataModel IVectorStoreRecordMapper, Dictionary>.MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) + { + return this.InternalMapFromStorageToDataModel(storageModel, options); + } + + public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) + { + return this.InternalMapFromDataToStorageModel(dataModel); + } + + VectorStoreGenericDataModel IVectorStoreRecordMapper, Dictionary>.MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) + { + return this.InternalMapFromStorageToDataModel(storageModel, options); + } + + private Dictionary InternalMapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) + where TKey : notnull + { + var properties = new Dictionary + { + // Add key property + { this._propertyReader.KeyPropertyStoragePropertyName, dataModel.Key } + }; + + // Add data properties + if (dataModel.Data is not null) + { + foreach (var property in this._propertyReader.DataProperties) + { + if (dataModel.Data.TryGetValue(property.DataModelPropertyName, out var dataValue)) + { + properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), dataValue); + } + } + } + + // Add vector properties + if (dataModel.Vectors is not null) + { + foreach (var property in this._propertyReader.VectorProperties) + { + if (dataModel.Vectors.TryGetValue(property.DataModelPropertyName, out var vectorValue)) + { + object? result = null; + + if (vectorValue is not null) + { + var vector = (ReadOnlyMemory)vectorValue; + result = new Vector(PostgresVectorStoreRecordPropertyMapping.GetOrCreateArray(vector)); + } + + properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), result); + } + } + } + + return properties; + } + + private VectorStoreGenericDataModel InternalMapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) + where TKey : notnull + { + TKey key; + var dataProperties = new Dictionary(); + var vectorProperties = new Dictionary(); + + // Process key property. + if (storageModel.TryGetValue(this._propertyReader.KeyPropertyStoragePropertyName, out var keyObject) && keyObject is not null) + { + key = (TKey)keyObject; + } + else + { + throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); + } + + // Process data properties. + foreach (var property in this._propertyReader.DataProperties) + { + if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var dataValue)) + { + dataProperties.Add(property.DataModelPropertyName, dataValue); + } + } + + // Process vector properties + if (options.IncludeVectors) + { + foreach (var property in this._propertyReader.VectorProperties) + { + if (storageModel.TryGetValue(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), out var vectorValue)) + { + if (vectorValue is null) + { + vectorProperties.Add(property.DataModelPropertyName, ReadOnlyMemory.Empty); + } + else if (vectorValue is Vector pgVector) + { + vectorProperties.Add(property.DataModelPropertyName, pgVector.ToArray()); + } + } + } + } + + return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs new file mode 100644 index 000000000000..99dadf105fe0 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Represents a SQL command for Postgres. +/// +public class PostgresSqlCommandInfo +{ + /// + /// Gets or sets the SQL command text. + /// + public string CommandText { get; set; } + /// + /// Gets or sets the parameters for the SQL command. + /// + public List? Parameters { get; set; } = null; + + /// + /// Initializes a new instance of the class. + /// + /// The SQL command text. + /// The parameters for the SQL command. + public PostgresSqlCommandInfo(string commandText, List? parameters = null) + { + this.CommandText = commandText; + this.Parameters = parameters; + } + + /// + /// Converts this instance to an . + /// + [SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "User input is passed using command parameters.")] + public NpgsqlCommand ToNpgsqlCommand(NpgsqlConnection connection) + { + NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = this.CommandText; + if (this.Parameters != null) + { + foreach (var parameter in this.Parameters) + { + cmd.Parameters.Add(parameter); + } + } + return cmd; + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs new file mode 100644 index 000000000000..a17c8e982811 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Runtime.CompilerServices; +using Microsoft.Extensions.VectorData; +using System.Threading.Tasks; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Represents a vector store implementation using PostgreSQL. +/// +public class PostgresVectorStore : IVectorStore +{ + private readonly IPostgresVectorStoreDbClient _postgresClient; + private readonly NpgsqlDataSource? _dataSource; + private readonly PostgresVectorStoreOptions? _options; + + /// + /// Initializes a new instance of the class. + /// + /// Postgres database connection string. + /// Optional configuration options for this class + public PostgresVectorStore(string connectionString, PostgresVectorStoreOptions? options = default) + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + this._dataSource = dataSourceBuilder.Build(); + this._options = options ?? new PostgresVectorStoreOptions(); + this._postgresClient = new PostgresVectorStoreDbClient(this._dataSource, this._options.Schema); + } + + /// + /// Initializes a new instance of the class. + /// + /// Postgres data source. + /// Optional configuration options for this class + public PostgresVectorStore(NpgsqlDataSource dataSource, PostgresVectorStoreOptions? options = default) + { + this._dataSource = dataSource; + this._options = options ?? new PostgresVectorStoreOptions(); + this._postgresClient = new PostgresVectorStoreDbClient(this._dataSource, this._options.Schema); + } + + /// + /// Initializes a new instance of the class. + /// + /// An instance of . + /// Optional configuration options for this class + public PostgresVectorStore(IPostgresVectorStoreDbClient postgresDbClient, PostgresVectorStoreOptions? options = default) + { + this._postgresClient = postgresDbClient; + this._options = options ?? new PostgresVectorStoreOptions(); + } + + /// + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (string collection in this._postgresClient.GetTablesAsync(cancellationToken).ConfigureAwait(false)) + { + yield return collection; + } + } + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey : notnull + { + // Support int, long, Guid, and string keys + if (typeof(TKey) != typeof(int) && typeof(TKey) != typeof(long) && typeof(TKey) != typeof(Guid) && typeof(TKey) != typeof(string)) + { + throw new NotSupportedException($"Only int, long, {nameof(Guid)}, and {nameof(String)} keys are supported."); + } + + if (this._options?.VectorStoreCollectionFactory is not null) + { + return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._postgresClient, name, vectorStoreRecordDefinition); + } + + var recordCollection = new PostgresVectorStoreRecordCollection( + this._postgresClient, + name, + new PostgresVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition } + ); + + return recordCollection as IVectorStoreRecordCollection ?? throw new InvalidOperationException("Failed to cast record collection."); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionCreateMapping.cs new file mode 100644 index 000000000000..3bfd7910e956 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionCreateMapping.cs @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Generates the PostgreSQL vector type name based on the dimensions of the vector property. +/// /// Provides methods to generate SQL statements for creating tables in /// a PostgreSQL database +/// for storing vector data. +/// +public static class PostgresVectorStoreCollectionCreateMapping +{ + /// + /// Generates a SQL CREATE TABLE statement. + /// + /// The schema name. + /// The table name. + /// The key property. + /// The list of data properties. + /// The list of vector properties. + /// The generated SQL CREATE TABLE statement. + /// Thrown when the table name is null or whitespace. + public static string GenerateCreateTableStatement(string schema, string tableName, VectorStoreRecordKeyProperty KeyProperty, IEnumerable DataProperties, IEnumerable VectorProperties) + { + if (string.IsNullOrWhiteSpace(tableName)) + { + throw new ArgumentException("Table name cannot be null or whitespace", nameof(tableName)); + } + + var keyName = KeyProperty.StoragePropertyName ?? KeyProperty.DataModelPropertyName; + + StringBuilder createTableCommand = new(); + createTableCommand.AppendLine($"CREATE TABLE {schema}.{tableName} ("); + + // Add the key column + createTableCommand.AppendLine($" {keyName} {GetPostgresTypeName(KeyProperty.PropertyType)},"); + + // Add the data columns + foreach (var dataProperty in DataProperties) + { + string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; + createTableCommand.AppendLine($" {columnName} {GetPostgresTypeName(dataProperty.PropertyType)},"); + } + + // Add the vector columns + foreach (var vectorProperty in VectorProperties) + { + string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + createTableCommand.AppendLine($" {columnName} {GetPgVectorTypeName(vectorProperty)},"); + } + + createTableCommand.AppendLine($" PRIMARY KEY ({keyName})"); + + createTableCommand.AppendLine(");"); + + return createTableCommand.ToString(); + } + + /// + /// Maps a .NET type to a PostgreSQL type name. + /// + /// The .NET type. + /// The PostgreSQL type name. + private static string GetPostgresTypeName(Type propertyType) + { + var pgType = propertyType switch + { + Type t when t == typeof(int) => "INTEGER", + Type t when t == typeof(string) => "TEXT", + Type t when t == typeof(bool) => "BOOLEAN", + Type t when t == typeof(DateTime) => "TIMESTAMP", + Type t when t == typeof(double) => "DOUBLE PRECISION", + Type t when t == typeof(decimal) => "NUMERIC", + Type t when t == typeof(float) => "REAL", + Type t when t == typeof(byte[]) => "BYTEA", + Type t when t == typeof(Guid) => "UUID", + Type t when t == typeof(short) => "SMALLINT", + Type t when t == typeof(long) => "BIGINT", + _ => null + }; + + if (pgType != null) { return pgType; } + + // Handle arrays (PostgreSQL supports array types for most types) + if (propertyType.IsArray) + { + Type elementType = propertyType.GetElementType() ?? throw new ArgumentException("Array type must have an element type."); + return GetPostgresTypeName(elementType) + "[]"; + } + + // Handle nullable types (e.g. Nullable) + if (Nullable.GetUnderlyingType(propertyType) != null) + { + Type underlyingType = Nullable.GetUnderlyingType(propertyType) ?? throw new ArgumentException("Nullable type must have an underlying type."); + return GetPostgresTypeName(underlyingType); + } + + throw new NotSupportedException($"Type {propertyType.Name} is not supported by this store."); + } + + /// + /// Gets the PostgreSQL vector type name based on the dimensions of the vector property. + /// + /// The vector property. + /// The PostgreSQL vector type name. + private static string GetPgVectorTypeName(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.Dimensions <= 0) + { + throw new ArgumentException("Vector property must have a positive number of dimensions."); + } + + return $"VECTOR({vectorProperty.Dimensions})"; + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs new file mode 100644 index 000000000000..5fe62a0f57e1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.Extensions.VectorData; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Provides methods to build SQL commands for managing vector store collections in PostgreSQL. +/// +public class PostgresVectorStoreCollectionSqlBuilder : IPostgresVectorStoreCollectionSqlBuilder +{ + /// + public PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName) + { + return new PostgresSqlCommandInfo( + commandText: $@" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 + AND table_type = 'BASE TABLE' + AND table_name = '{tableName}'", + parameters: [new NpgsqlParameter() { Value = schema }] + ); + } + + /// + public PostgresSqlCommandInfo BuildGetTablesCommand(string schema) + { + return new PostgresSqlCommandInfo( + commandText: @" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = $1 + AND table_type = 'BASE TABLE'", + parameters: [new NpgsqlParameter() { Value = schema }] + ); + } + + /// + public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, bool ifNotExists = true) + { + if (string.IsNullOrWhiteSpace(tableName)) + { + throw new ArgumentException("Table name cannot be null or whitespace", nameof(tableName)); + } + + VectorStoreRecordKeyProperty? keyProperty = default; + List dataProperties = new(); + List vectorProperties = new(); + + foreach (var property in recordDefinition.Properties) + { + if (property is VectorStoreRecordKeyProperty keyProp) + { + if (keyProperty != null) + { + throw new ArgumentException("Record definition cannot have more than one key property."); + } + keyProperty = keyProp; + } + else if (property is VectorStoreRecordDataProperty dataProp) + { + dataProperties.Add(dataProp); + } + else if (property is VectorStoreRecordVectorProperty vectorProp) + { + vectorProperties.Add(vectorProp); + } + else + { + throw new NotSupportedException($"Property type {property.GetType().Name} is not supported by this store."); + } + } + + if (keyProperty == null) + { + throw new ArgumentException("Record definition must have a key property."); + } + + var keyName = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; + + StringBuilder createTableCommand = new(); + createTableCommand.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : "")}{schema}.\"{tableName}\" ("); + + // Add the key column + var keyPgTypeInfo = GetPostgresTypeName(keyProperty.PropertyType); + createTableCommand.AppendLine($" \"{keyName}\" {keyPgTypeInfo.PgType} {(keyPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + + // Add the data columns + foreach (var dataProperty in dataProperties) + { + string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; + var dataPgTypeInfo = GetPostgresTypeName(dataProperty.PropertyType); + createTableCommand.AppendLine($" \"{columnName}\" {dataPgTypeInfo.PgType} {(dataPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + } + + // Add the vector columns + foreach (var vectorProperty in vectorProperties) + { + string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + var vectorPgTypeInfo = GetPgVectorTypeName(vectorProperty); + createTableCommand.AppendLine($" \"{columnName}\" {vectorPgTypeInfo.PgType} {(vectorPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); + } + + createTableCommand.AppendLine($" PRIMARY KEY (\"{keyName}\")"); + + createTableCommand.AppendLine(");"); + + return new PostgresSqlCommandInfo(commandText: createTableCommand.ToString()); + } + + /// + public PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableName) + { + return new PostgresSqlCommandInfo( + commandText: $@"DROP TABLE IF EXISTS {schema}.""{tableName}""" + ); + } + + /// + public PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, Dictionary row, string keyColumn) + { + var columns = row.Keys.ToList(); + var columnNames = string.Join(", ", columns.Select(k => $"\"{k}\"")); + var valuesParams = string.Join(", ", columns.Select((k, i) => $"${i + 1}")); + var columnsWithIndex = columns.Select((k, i) => (col: k, idx: i)); + var updateColumnsWithParams = string.Join(", ", columnsWithIndex.Where(c => c.col != keyColumn).Select(c => $"\"{c.col}\"=${c.idx + 1}")); + var commandText = $@" + INSERT INTO {schema}.""{tableName}"" ({columnNames}) + VALUES({valuesParams}) + ON CONFLICT (""{keyColumn}"") + DO UPDATE SET {updateColumnsWithParams};"; + + var parameters = row.ToDictionary(kvp => $"@{kvp.Key}", kvp => kvp.Value); + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = columns.Select(c => new NpgsqlParameter() { Value = row[c] ?? DBNull.Value }).ToList() + }; + } + + /// + /// Maps a .NET type to a PostgreSQL type name. + /// + /// The .NET type. + /// Tuple of the the PostgreSQL type name and whether it can be NULL + private static (string PgType, bool IsNullable) GetPostgresTypeName(Type propertyType) + { + var (pgType, isNullable) = propertyType switch + { + Type t when t == typeof(int) => ("INTEGER", false), + Type t when t == typeof(string) => ("TEXT", true), + Type t when t == typeof(bool) => ("BOOLEAN", false), + Type t when t == typeof(DateTime) => ("TIMESTAMP", false), + Type t when t == typeof(double) => ("DOUBLE PRECISION", false), + Type t when t == typeof(decimal) => ("NUMERIC", false), + Type t when t == typeof(float) => ("REAL", false), + Type t when t == typeof(byte[]) => ("BYTEA", true), + Type t when t == typeof(Guid) => ("UUID", false), + Type t when t == typeof(short) => ("SMALLINT", false), + Type t when t == typeof(long) => ("BIGINT", false), + _ => (null, false) + }; + + if (pgType != null) + { + return (pgType, isNullable); + } + + // Handle lists and arrays (PostgreSQL supports array types for most types) + if (propertyType.IsArray) + { + Type elementType = propertyType.GetElementType() ?? throw new ArgumentException("Array type must have an element type."); + var underlyingPgType = GetPostgresTypeName(elementType); + return (underlyingPgType.PgType + "[]", true); + } + else if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) + { + Type elementType = propertyType.GetGenericArguments()[0]; + var underlyingPgType = GetPostgresTypeName(elementType); + return (underlyingPgType.PgType + "[]", true); + } + + // Handle nullable types (e.g. Nullable) + if (Nullable.GetUnderlyingType(propertyType) != null) + { + Type underlyingType = Nullable.GetUnderlyingType(propertyType) ?? throw new ArgumentException("Nullable type must have an underlying type."); + var underlyingPgType = GetPostgresTypeName(underlyingType); + return (underlyingPgType.PgType, true); + } + + throw new NotSupportedException($"Type {propertyType.Name} is not supported by this store."); + } + + /// + /// Gets the PostgreSQL vector type name based on the dimensions of the vector property. + /// + /// The vector property. + /// The PostgreSQL vector type name. + private static (string PgType, bool IsNullable) GetPgVectorTypeName(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.Dimensions <= 0) + { + throw new ArgumentException("Vector property must have a positive number of dimensions."); + } + + return ($"VECTOR({vectorProperty.Dimensions})", Nullable.GetUnderlyingType(vectorProperty.PropertyType) != null); + } + + /// + public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, TKey key, bool includeVectors = false) + where TKey : notnull + { + List queryColumns = new(); + string? keyColumn = null; + + foreach (var property in recordDefinition.Properties) + { + if (property is VectorStoreRecordKeyProperty keyProperty) + { + if (keyColumn != null) + { + throw new ArgumentException("Record definition cannot have more than one key property."); + } + keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; + queryColumns.Add($"\"{keyColumn}\""); + } + else if (property is VectorStoreRecordDataProperty dataProperty) + { + string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; + queryColumns.Add($"\"{columnName}\""); + } + else if (property is VectorStoreRecordVectorProperty vectorProperty && includeVectors) + { + string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + queryColumns.Add($"\"{columnName}\""); + } + } + + Verify.NotNull(keyColumn, "Record definition must have a key property."); + + var queryColumnList = string.Join(", ", queryColumns); + + return new PostgresSqlCommandInfo( + commandText: $@" + SELECT {queryColumnList} + FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ${1};", + parameters: [new NpgsqlParameter() { Value = key }] + ); + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs new file mode 100644 index 000000000000..172f1acd19d2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// An implementation of a client for Postgres. This class is used to managing postgres database operations. +/// +/// +/// Initializes a new instance of the class. +/// +/// Postgres data source. +/// Schema of collection tables. +/// Sql builder for collection tables. +[System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] +public class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema, IPostgresVectorStoreCollectionSqlBuilder sqlBuilder) : IPostgresVectorStoreDbClient +{ + private readonly NpgsqlDataSource _dataSource = dataSource; + private readonly IPostgresVectorStoreCollectionSqlBuilder _sqlBuilder = sqlBuilder; + private readonly string _schema = schema; + + /// + /// Initializes a new instance of the class. + /// + /// Postgres data source. + /// Schema of collection tables. + public PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = "public") : this(dataSource, schema, new PostgresVectorStoreCollectionSqlBuilder()) { } + + /// + public async Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildDoesTableExistCommand(this._schema, tableName); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + return dataReader.GetString(dataReader.GetOrdinal("table_name")) == tableName; + } + + return false; + } + } + + /// + public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetTablesCommand(this._schema); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return dataReader.GetString(dataReader.GetOrdinal("table_name")); + } + } + } + + /// + public async Task CreateTableAsync(string tableName, VectorStoreRecordDefinition recordDefinition, bool ifNotExists = true, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, recordDefinition, ifNotExists); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public async Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildDropTableCommand(this._schema, tableName); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public async Task UpsertAsync(string tableName, Dictionary row, string keyColumn, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildUpsertCommand(this._schema, tableName, row, keyColumn); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public async Task?> GetAsync(string tableName, TKey key, VectorStoreRecordDefinition recordDefinition, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetCommand(this._schema, tableName, recordDefinition, key, includeVectors); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + return this.GetRecord(dataReader, recordDefinition.Properties, includeVectors); + } + + return null; + } + } + + private Dictionary GetRecord( + NpgsqlDataReader reader, + IEnumerable properties, + bool includeVectors = false + ) + { + var storageModel = new Dictionary(); + + foreach (var property in properties) + { + var isEmbedding = property is VectorStoreRecordVectorProperty; + var propertyName = property.StoragePropertyName ?? property.DataModelPropertyName; + var propertyType = property.PropertyType; + var propertyValue = !isEmbedding || includeVectors ? PostgresVectorStoreRecordPropertyMapping.GetPropertyValue(reader, propertyName, propertyType) : null; + + storageModel.Add(propertyName, propertyValue); + } + + return storageModel; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs new file mode 100644 index 000000000000..131036d7b0c7 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Options when creating a . +/// +public sealed class PostgresVectorStoreOptions +{ + /// + /// Gets or sets the default vector size to use when creating a new vector. + /// + public int DefaultVectorSize { get; init; } = 100; + + /// + /// Gets or sets the database schema. + /// + public string Schema { get; init; } = "public"; + + /// + /// An optional factory to use for constructing instances, if a custom record collection is required. + /// + public IPostgresVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..969ddb2078d1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Represents a collection of vector store records in a Postgres database. +/// +/// The type of the key. +/// The type of the record. +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class PostgresVectorStoreRecordCollection : IVectorStoreRecordCollection, IVectorizableTextSearch +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix + where TKey : notnull +{ + /// The name of this database for telemetry purposes. + private const string DatabaseName = "Postgres"; + + /// + public string CollectionName { get; } + + /// Postgres client that is used to interact with the database. + private readonly IPostgresVectorStoreDbClient _client; + + // Optional configuration options for this class. + private readonly PostgresVectorStoreRecordCollectionOptions _options; + + /// A helper to access property information for the current data model and record definition. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// A mapper to use for converting between the data model and the Azure AI Search record. + private readonly IVectorStoreRecordMapper> _mapper; + + /// + /// Initializes a new instance of the class. + /// + /// The Postgres client used to interact with the database. + /// The name of the collection. + /// Optional configuration options for this class. + public PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) + { + // Verify. + Verify.NotNull(client); + Verify.NotNullOrWhiteSpace(collectionName); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.DictionaryCustomMapper is not null, PostgresConstants.SupportedKeyTypes); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + + // Assign. + this._client = client; + this.CollectionName = collectionName; + this._options = options ?? new PostgresVectorStoreRecordCollectionOptions(); + this._propertyReader = new VectorStoreRecordPropertyReader( + typeof(TRecord), + this._options.VectorStoreRecordDefinition, + new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true, + }); + + // Validate property types. + this._propertyReader.VerifyKeyProperties(PostgresConstants.SupportedKeyTypes); + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, supportEnumerable: true); + this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); + + // Resolve mapper. + // First, if someone has provided a custom mapper, use that. + // If they didn't provide a custom mapper, and the record type is the generic data model, use the built in mapper for that. + // Otherwise, don't set the mapper, and we'll default to just using Azure AI Search's built in json serialization and deserialization. + if (this._options.DictionaryCustomMapper is not null) + { + this._mapper = this._options.DictionaryCustomMapper; + } + else if (typeof(TRecord) == typeof(VectorStoreGenericDataModel) || typeof(TRecord) == typeof(VectorStoreGenericDataModel)) + { + this._mapper = (new PostgresGenericDataModelMapper(this._propertyReader) as IVectorStoreRecordMapper>)!; + } + else + { + this._mapper = new PostgresVectorStoreRecordMapper(this._propertyReader); + } + } + + /// + public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + return await this._client.DoesTableExistsAsync(this.CollectionName, cancellationToken).ConfigureAwait(false); + } + + /// + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + return this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition, false, cancellationToken); + } + + /// + public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + return this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition, true, cancellationToken); + } + + /// + public Task DeleteAsync(TKey key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + /// + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + { + return this._client.DeleteTableAsync(this.CollectionName, cancellationToken); + } + + /// + public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + var operationName = "Get"; + + Verify.NotNull(key); + + bool includeVectors = options?.IncludeVectors is true; + + var row = await this._client.GetAsync(this.CollectionName, key, this._propertyReader.RecordDefinition, includeVectors, cancellationToken).ConfigureAwait(false); + + if (row is null) { return default; } + + return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + operationName, + () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); + } + + /// + public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + /// + public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) + { + const string OperationName = "Upsert"; + + var storageModel = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record)); + + Verify.NotNull(storageModel); + + var keyObj = storageModel[this._propertyReader.KeyPropertyStoragePropertyName]; + Verify.NotNull(keyObj); + TKey key = (TKey)keyObj!; + + await this._client.UpsertAsync(this.CollectionName, this._mapper?.MapFromDataToStorageModel(record) ?? throw new InvalidOperationException("Failed to map record to storage model."), this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); + return key; + } + + /// + public IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + /// + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + /// + public Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..373f01a25a99 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// Options when creating a . +/// +public sealed class PostgresVectorStoreRecordCollectionOptions +{ + /// + /// Gets or sets the database schema. + /// + public string Schema { get; init; } = "public"; + + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the Postgres record. + /// + /// + /// If not set, the default mapper will be used. + /// + public IVectorStoreRecordMapper>? DictionaryCustomMapper { get; init; } = null; + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs new file mode 100644 index 000000000000..8559d1218054 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Pgvector; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +/// +/// A mapper class that handles the conversion between data models and storage models for Postgres vector store. +/// +/// The type of the data model record. +internal class PostgresVectorStoreRecordMapper : IVectorStoreRecordMapper> +{ + /// with helpers for reading vector store model properties and their attributes. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// + /// Initializes a new instance of the class. + /// + /// A that defines the schema of the data in the database. + public PostgresVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyReader) + { + Verify.NotNull(propertyReader); + + this._propertyReader = propertyReader; + + this._propertyReader.VerifyHasParameterlessConstructor(); + + // Validate property types. + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, supportEnumerable: false); + this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); + } + + public Dictionary MapFromDataToStorageModel(TRecord dataModel) + { + var properties = new Dictionary + { + // Add key property + { this._propertyReader.KeyPropertyStoragePropertyName, this._propertyReader.KeyPropertyInfo.GetValue(dataModel) } + }; + + // Add data properties + foreach (var property in this._propertyReader.DataPropertiesInfo) + { + properties.Add(this._propertyReader.GetStoragePropertyName(property.Name), property.GetValue(dataModel)); + } + + // Add vector properties + foreach (var property in this._propertyReader.VectorPropertiesInfo) + { + object? result = null; + var propertyValue = property.GetValue(dataModel); + + if (propertyValue is not null) + { + var vector = (ReadOnlyMemory)propertyValue; + result = new Vector(PostgresVectorStoreRecordPropertyMapping.GetOrCreateArray(vector)); + } + + properties.Add(this._propertyReader.GetStoragePropertyName(property.Name), result); + } + + return properties; + } + + public TRecord MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) + { + var record = (TRecord)this._propertyReader.ParameterLessConstructorInfo.Invoke(null); + + // Set key. + var keyPropertyValue = Convert.ChangeType( + storageModel[this._propertyReader.KeyPropertyStoragePropertyName], + this._propertyReader.KeyProperty.PropertyType); + + this._propertyReader.KeyPropertyInfo.SetValue(record, keyPropertyValue); + + // Process data properties. + var dataPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( + this._propertyReader.DataPropertiesInfo, + this._propertyReader.StoragePropertyNamesMap, + storageModel); + + VectorStoreRecordMapping.SetPropertiesOnRecord(record, dataPropertiesInfoWithValues); + + if (options.IncludeVectors) + { + // Process vector properties. + var vectorPropertiesInfoWithValues = VectorStoreRecordMapping.BuildPropertiesInfoWithValues( + this._propertyReader.VectorPropertiesInfo, + this._propertyReader.StoragePropertyNamesMap, + storageModel, + (object? vector, Type type) => vector is Vector pgVector ? + pgVector.ToArray() : null); + + VectorStoreRecordMapping.SetPropertiesOnRecord(record, vectorPropertiesInfoWithValues); + } + + return record; + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs new file mode 100644 index 000000000000..e19a2413864b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using Npgsql; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal static class PostgresVectorStoreRecordPropertyMapping +{ + internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => + MemoryMarshal.TryGetArray(memory, out ArraySegment array) && + array.Count == array.Array!.Length ? + array.Array : + memory.ToArray(); + + public static TPropertyType? GetPropertyValue(NpgsqlDataReader reader, string propertyName) + { + int propertyIndex = reader.GetOrdinal(propertyName); + + if (reader.IsDBNull(propertyIndex)) + { + return default; + } + + return reader.GetFieldValue(propertyIndex); + } + + public static object? GetPropertyValue(NpgsqlDataReader reader, string propertyName, Type propertyType) + { + int propertyIndex = reader.GetOrdinal(propertyName); + + if (reader.IsDBNull(propertyIndex)) + { + return null; + } + + // Check if the type is a List + if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) + { + var elementType = propertyType.GetGenericArguments()[0]; + var list = (IEnumerable)reader.GetValue(propertyIndex); + // Convert list to the correct element type + return ConvertList(list, elementType); + } + + return propertyType switch + { + Type t when t == typeof(int) || t == typeof(int?) => reader.GetInt32(propertyIndex), + Type t when t == typeof(long) || t == typeof(long?) => reader.GetInt64(propertyIndex), + Type t when t == typeof(ulong) || t == typeof(ulong?) => (ulong)reader.GetInt64(propertyIndex), + Type t when t == typeof(short) || t == typeof(short?) => reader.GetInt16(propertyIndex), + Type t when t == typeof(ushort) || t == typeof(ushort?) => (ushort)reader.GetInt16(propertyIndex), + Type t when t == typeof(bool) || t == typeof(bool?) => reader.GetBoolean(propertyIndex), + Type t when t == typeof(float) || t == typeof(float?) => reader.GetFloat(propertyIndex), + Type t when t == typeof(double) || t == typeof(double?) => reader.GetDouble(propertyIndex), + Type t when t == typeof(decimal) || t == typeof(decimal?) => reader.GetDecimal(propertyIndex), + Type t when t == typeof(string) => reader.GetString(propertyIndex), + _ => reader.GetValue(propertyIndex) + }; + } + + // Helper method to convert lists + private static object ConvertList(IEnumerable list, Type elementType) + { + var listType = typeof(List<>).MakeGenericType(elementType); + var convertedList = (IList)Activator.CreateInstance(listType)!; + + foreach (var item in list) + { + convertedList.Add(Convert.ChangeType(item, elementType)); + } + + return convertedList; + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs new file mode 100644 index 000000000000..b2357e302fda --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace SemanticKernel.Connectors.UnitTests.Postgres; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + +/// +/// A test model for the postgres vector store. +/// +public record PostgresHotel() +{ + /// The key of the record. + [VectorStoreRecordKey] + public int HotelId { get; init; } + + /// A string metadata field. + [VectorStoreRecordData()] + public string? HotelName { get; set; } + + /// An int metadata field. + [VectorStoreRecordData()] + public int HotelCode { get; set; } + + /// A float metadata field. + [VectorStoreRecordData()] + public float? HotelRating { get; set; } + + /// A bool metadata field. + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; set; } + + [VectorStoreRecordData] + public List Tags { get; set; } = []; + + /// A data field. + [VectorStoreRecordData] + public string Description { get; set; } + + /// A vector field. + [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } +} +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs new file mode 100644 index 000000000000..c159f92f8fcb --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.Connectors.UnitTests.Postgres; + +public class PostgresVectorStoreCollectionSqlBuilderTests +{ + private readonly ITestOutputHelper _output; + + public PostgresVectorStoreCollectionSqlBuilderTests(ITestOutputHelper output) + { + this._output = output; + } + + [Fact] + public void TestBuildCreateTableCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var cmdInfo = builder.BuildCreateTableCommand("public", "testcollection", recordDefinition, ifNotExists: true); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("IF NOT EXISTS", cmdInfo.CommandText); + Assert.Contains("\"name\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"code\" INTEGER NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"rating\" REAL", cmdInfo.CommandText); + Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"parking_is_included\" BOOLEAN NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"tags\" TEXT[]", cmdInfo.CommandText); + Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"embedding1\" VECTOR(10) NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"embedding2\" VECTOR(10)", cmdInfo.CommandText); + Assert.Contains("PRIMARY KEY (\"id\")", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildUpsertCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var row = new Dictionary() + { + ["id"] = 123, + ["name"] = "Hotel", + ["code"] = 456, + ["rating"] = 4.5f, + ["description"] = "Hotel description", + ["parking_is_included"] = true, + ["tags"] = new List { "tag1", "tag2" }, + ["embedding1"] = new Vector(new float[] { 1.0f, 2.0f, 3.0f }), + }; + + var keyColumn = "id"; + + var cmdInfo = builder.BuildUpsertCommand("public", "testcollection", row, keyColumn); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("ON CONFLICT (\"id\")", cmdInfo.CommandText); + Assert.Contains("DO UPDATE SET", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + + foreach (var (key, index) in row.Keys.Select((key, index) => (key, index))) + { + Assert.Equal(row[key], cmdInfo.Parameters[index].Value); + // If the key is not the key column, it should be included in the update clause. + if (key != keyColumn) + { + Assert.Contains($"\"{key}\"=${index + 1}", cmdInfo.CommandText); + } + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var key = 123; + + // Act + var cmdInfo = builder.BuildGetCommand("public", "testcollection", recordDefinition, key, includeVectors: true); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..69373123431b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Moq; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.UnitTests.Postgres; + +public class PostgresVectorStoreRecordCollectionTests +{ + private const string TestCollectionName = "testcollection"; + + private readonly Mock _postgresClientMock; + private readonly CancellationToken _testCancellationToken = new(false); + + public PostgresVectorStoreRecordCollectionTests() + { + this._postgresClientMock = new Mock(MockBehavior.Strict); + } + + [Fact] + public async Task CreatesCollectionForGenericModelAsync() + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = [ + new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 100, DistanceFunction = DistanceFunction.ManhattanDistance } + ] + }; + var options = new PostgresVectorStoreRecordCollectionOptions>() + { + VectorStoreRecordDefinition = recordDefinition + }; + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options); + this._postgresClientMock.Setup(x => x.DoesTableExistsAsync(TestCollectionName, this._testCancellationToken)).ReturnsAsync(false); + + // Act + var exists = await sut.CollectionExistsAsync(); + + // Assert. + Assert.False(exists); + } + + [Fact] + public async Task UpsertRecordAsyncProducesExpectedSqlAsync() + { + // Arrange + Dictionary? capturedArguments = null; + + var sut = new PostgresVectorStoreRecordCollection(this._postgresClientMock.Object, TestCollectionName); + var record = new PostgresHotel + { + HotelId = 1, + HotelName = "Hotel 1", + HotelCode = 1, + HotelRating = 4.5f, + ParkingIncluded = true, + Tags = ["tag1", "tag2"], + Description = "A hotel", + DescriptionEmbedding = new ReadOnlyMemory(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }) + }; + + this._postgresClientMock.Setup(x => x.UpsertAsync( + TestCollectionName, + It.IsAny>(), + "HotelId", + this._testCancellationToken)) + .Callback, string, CancellationToken>((collectionName, args, key, ct) => capturedArguments = args) + .Returns(Task.CompletedTask); + + // Act + await sut.UpsertAsync(record, cancellationToken: this._testCancellationToken); + + // Assert + Assert.NotNull(capturedArguments); + Assert.Equal(1, (int)(capturedArguments["HotelId"] ?? 0)); + Assert.Equal("Hotel 1", (string)(capturedArguments["HotelName"] ?? "")); + Assert.Equal(1, (int)(capturedArguments["HotelCode"] ?? 0)); + Assert.Equal(4.5f, (float)(capturedArguments["HotelRating"] ?? 0.0f)); + Assert.True((bool)(capturedArguments["parking_is_included"] ?? false)); + Assert.True(capturedArguments["Tags"] is List); + var tags = capturedArguments["Tags"] as List; + Assert.Equal(2, tags!.Count); + Assert.Equal("tag1", tags[0]); + Assert.Equal("tag2", tags[1]); + Assert.Equal("A hotel", (string)(capturedArguments["Description"] ?? "")); + Assert.NotNull(capturedArguments["DescriptionEmbedding"]); + Assert.IsType(capturedArguments["DescriptionEmbedding"]); + var embedding = ((Vector)capturedArguments["DescriptionEmbedding"]!).ToArray(); + Assert.Equal(1.0f, embedding[0]); + Assert.Equal(2.0f, embedding[1]); + Assert.Equal(3.0f, embedding[2]); + Assert.Equal(4.0f, embedding[3]); + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs new file mode 100644 index 000000000000..8ed97fce5f7c --- /dev/null +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Microsoft.Extensions.VectorData; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.UnitTests.Postgres; + +/// +/// Contains tests for the class. +/// +public class PostgresVectorStoreTests +{ + private const string TestCollectionName = "testcollection"; + + private readonly Mock _postgresClientMock; + private readonly CancellationToken _testCancellationToken = new(false); + + public PostgresVectorStoreTests() + { + this._postgresClientMock = new Mock(MockBehavior.Strict); + } + + [Fact] + public void GetCollectionReturnsCollection() + { + // Arrange. + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.NotNull(actual); + Assert.IsType>>(actual); + } + + [Fact] + public void GetCollectionThrowsForInvalidKeyType() + { + // Arrange. + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act and Assert. + Assert.Throws(() => sut.GetCollection>(TestCollectionName)); + } + + [Fact] + public void GetCollectionCallsFactoryIfProvided() + { + // Arrange. + var factoryMock = new Mock(MockBehavior.Strict); + var collectionMock = new Mock>>(MockBehavior.Strict); + var clientMock = new Mock(MockBehavior.Strict); + factoryMock + .Setup(x => x.CreateVectorStoreRecordCollection>(clientMock.Object, TestCollectionName, null)) + .Returns(collectionMock.Object); + var sut = new PostgresVectorStore(clientMock.Object, new() { VectorStoreCollectionFactory = factoryMock.Object }); + + // Act. + var actual = sut.GetCollection>(TestCollectionName); + + // Assert. + Assert.Equal(collectionMock.Object, actual); + } + + [Fact] + public async Task ListCollectionNamesCallsSDKAsync() + { + // Arrange + var expectedCollections = new List { "fake-collection-1", "fake-collection-2", "fake-collection-3" }; + + this._postgresClientMock + .Setup(client => client.GetTablesAsync(CancellationToken.None)) + .Returns(expectedCollections.ToAsyncEnumerable()); + + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act. + var actual = sut.ListCollectionNamesAsync(this._testCancellationToken); + + // Assert + Assert.NotNull(actual); + var actualList = await actual.ToListAsync(); + Assert.Equal(expectedCollections, actualList); + } + + public sealed class SinglePropsModel + { + [VectorStoreRecordKey] + public required TKey Key { get; set; } + + [VectorStoreRecordData] + public string Data { get; set; } = string.Empty; + + [VectorStoreRecordVector(4)] + public ReadOnlyMemory? Vector { get; set; } + + public string? NotAnnotated { get; set; } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs new file mode 100644 index 000000000000..3510280397e0 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + +/// +/// A test model for the postgres vector store. +/// +public record PostgresHotel() +{ + /// The key of the record. + [VectorStoreRecordKey] + public int HotelId { get; init; } + + /// A string metadata field. + [VectorStoreRecordData()] + public string? HotelName { get; set; } + + /// An int metadata field. + [VectorStoreRecordData()] + public int HotelCode { get; set; } + + /// A float metadata field. + [VectorStoreRecordData()] + public float? HotelRating { get; set; } + + /// A bool metadata field. + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; set; } + + [VectorStoreRecordData] + public List Tags { get; set; } = []; + + [VectorStoreRecordData] + public List? ListInts { get; set; } = null; + + /// A data field. + [VectorStoreRecordData] + public string Description { get; set; } + + /// A vector field. + [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } +} + +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. \ No newline at end of file diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs index 19126a090874..71474ff0ebc6 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs @@ -11,7 +11,7 @@ using Npgsql; using Xunit; -namespace SemanticKernel.IntegrationTests.Connectors.Postgres; +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; /// /// Integration tests of . @@ -41,6 +41,8 @@ public async Task InitializeAsync() this._connectionString = connectionString; this._databaseName = $"sk_it_{Guid.NewGuid():N}"; + await this.CreateDatabaseAsync(); + NpgsqlConnectionStringBuilder connectionStringBuilder = new(this._connectionString) { Database = this._databaseName @@ -50,8 +52,6 @@ public async Task InitializeAsync() dataSourceBuilder.UseVector(); this._dataSource = dataSourceBuilder.Build(); - - await this.CreateDatabaseAsync(); } public async Task DisposeAsync() diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs new file mode 100644 index 000000000000..5d202af5b9f5 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreCollectionFixture.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +[CollectionDefinition("PostgresVectorStoreCollection")] +public class PostgresVectorStoreCollectionFixture : ICollectionFixture +{ +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs new file mode 100644 index 000000000000..ab7afa0489c6 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -0,0 +1,361 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Docker.DotNet; +using Docker.DotNet.Models; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Microsoft.Extensions.VectorData; +using Npgsql; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +public class PostgresVectorStoreFixture : IAsyncLifetime +{ + /// The docker client we are using to create a postgres container with. + private readonly DockerClient _client; + + /// The id of the postgres container that we are testing with. + private string? _containerId = null; + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + + /// + /// Initializes a new instance of the class. + /// + public PostgresVectorStoreFixture() + { + using var dockerClientConfiguration = new DockerClientConfiguration(); + this._client = dockerClientConfiguration.CreateClient(); + this.HotelVectorStoreRecordDefinition = new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, DistanceFunction = DistanceFunction.ManhattanDistance } + } + }; + this.HotelWithGuidIdVectorStoreRecordDefinition = new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("HotelId", typeof(Guid)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, DistanceFunction = DistanceFunction.ManhattanDistance } + } + }; + } + + /// + /// Holds the Npgsql data source to use for tests. + /// + private NpgsqlDataSource? _dataSource; + + private string _connectionString = null!; + private string _databaseName = null!; + + /// Gets the postgres client connection to use for tests. + public PostgresVectorStoreDbClient PostgresClient { get; private set; } + + /// Gets the manually created vector store record definition for our test model. + public VectorStoreRecordDefinition HotelVectorStoreRecordDefinition { get; private set; } + + /// Gets the manually created vector store record definition for our test model. + public VectorStoreRecordDefinition HotelWithGuidIdVectorStoreRecordDefinition { get; private set; } + + public PostgresVectorStoreRecordCollection GetCollection( + string collectionName, + PostgresVectorStoreRecordCollectionOptions? options = default) + where TKey : notnull + where TRecord : class + { + return new PostgresVectorStoreRecordCollection( + this.PostgresClient, + collectionName, + options); + } + + /// + /// Create / Recreate postgres docker container and run it. + /// + /// An async task. + public async Task InitializeAsync() + { + this._containerId = await SetupPostgresContainerAsync(this._client); + this._connectionString = "Host=localhost;Port=5432;Username=postgres;Password=example;Database=postgres;"; + this._databaseName = $"sk_it_{Guid.NewGuid():N}"; + + // Connect to postgres. + NpgsqlConnectionStringBuilder connectionStringBuilder = new(this._connectionString) + { + Database = this._databaseName + }; + + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionStringBuilder.ToString()); + dataSourceBuilder.UseVector(); + + this._dataSource = dataSourceBuilder.Build(); + + this.PostgresClient = new PostgresVectorStoreDbClient(this._dataSource); + + // Wait for the postgres container to be ready and create the test database using the initial data source. + var initialDataSource = NpgsqlDataSource.Create(this._connectionString); + using (initialDataSource) + { + var retryCount = 0; + var exceptionCount = 0; + while (retryCount++ < 5) + { + try + { + NpgsqlConnection connection = await initialDataSource.OpenConnectionAsync().ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT count(*) FROM information_schema.tables WHERE table_schema = 'public';"; + await cmd.ExecuteScalarAsync().ConfigureAwait(false); + } + } + catch (NpgsqlException) + { + exceptionCount++; + await Task.Delay(1000); + } + } + + if (exceptionCount >= 5) + { + // Throw an exception for test setup + throw new InvalidOperationException("Postgres container did not start in time."); + } + + await this.CreateDatabaseAsync(initialDataSource); + } + + // Create the table. + await this.CreateTableAsync(); + + // await this.PostgresClient.CreateCollectionAsync( + // "singleVectorHotels", + // new VectorParams { Size = 4, Distance = Distance.Cosine }); + + // await this.PostgresClient.CreateCollectionAsync( + // "singleVectorGuidIdHotels", + // new VectorParams { Size = 4, Distance = Distance.Cosine }); + + // // Create test data common to both named and unnamed vectors. + // var tags = new ListValue(); + // tags.Values.Add("t1"); + // tags.Values.Add("t2"); + // var tagsValue = new Value(); + // tagsValue.ListValue = tags; + + // // Create some test data using named vectors. + // var embedding = new[] { 30f, 31f, 32f, 33f }; + + // var namedVectors1 = new NamedVectors(); + // var namedVectors2 = new NamedVectors(); + // var namedVectors3 = new NamedVectors(); + + // namedVectors1.Vectors.Add("DescriptionEmbedding", embedding); + // namedVectors2.Vectors.Add("DescriptionEmbedding", embedding); + // namedVectors3.Vectors.Add("DescriptionEmbedding", embedding); + + // List namedVectorPoints = + // [ + // new PointStruct + // { + // Id = 11, + // Vectors = new Vectors { Vectors_ = namedVectors1 }, + // Payload = { ["HotelName"] = "My Hotel 11", ["HotelCode"] = 11, ["parking_is_included"] = true, ["Tags"] = tagsValue, ["HotelRating"] = 4.5f, ["Description"] = "This is a great hotel." } + // }, + // new PointStruct + // { + // Id = 12, + // Vectors = new Vectors { Vectors_ = namedVectors2 }, + // Payload = { ["HotelName"] = "My Hotel 12", ["HotelCode"] = 12, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } + // }, + // new PointStruct + // { + // Id = 13, + // Vectors = new Vectors { Vectors_ = namedVectors3 }, + // Payload = { ["HotelName"] = "My Hotel 13", ["HotelCode"] = 13, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } + // }, + // ]; + + // await this.PostgresClient.UpsertAsync("namedVectorsHotels", namedVectorPoints); + + // // Create some test data using a single unnamed vector. + // List unnamedVectorPoints = + // [ + // new PointStruct + // { + // Id = 11, + // Vectors = embedding, + // Payload = { ["HotelName"] = "My Hotel 11", ["HotelCode"] = 11, ["parking_is_included"] = true, ["Tags"] = tagsValue, ["HotelRating"] = 4.5f, ["Description"] = "This is a great hotel." } + // }, + // new PointStruct + // { + // Id = 12, + // Vectors = embedding, + // Payload = { ["HotelName"] = "My Hotel 12", ["HotelCode"] = 12, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } + // }, + // new PointStruct + // { + // Id = 13, + // Vectors = embedding, + // Payload = { ["HotelName"] = "My Hotel 13", ["HotelCode"] = 13, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } + // }, + // ]; + + // await this.PostgresClient.UpsertAsync("singleVectorHotels", unnamedVectorPoints); + + // // Create some test data using a single unnamed vector and a guid id. + // List unnamedVectorGuidIdPoints = + // [ + // new PointStruct + // { + // Id = Guid.Parse("11111111-1111-1111-1111-111111111111"), + // Vectors = embedding, + // Payload = { ["HotelName"] = "My Hotel 11", ["Description"] = "This is a great hotel." } + // }, + // new PointStruct + // { + // Id = Guid.Parse("22222222-2222-2222-2222-222222222222"), + // Vectors = embedding, + // Payload = { ["HotelName"] = "My Hotel 12", ["Description"] = "This is a great hotel." } + // }, + // new PointStruct + // { + // Id = Guid.Parse("33333333-3333-3333-3333-333333333333"), + // Vectors = embedding, + // Payload = { ["HotelName"] = "My Hotel 13", ["Description"] = "This is a great hotel." } + // }, + // ]; + + // await this.PostgresClient.UpsertAsync("singleVectorGuidIdHotels", unnamedVectorGuidIdPoints); + } + + private async Task CreateTableAsync() + { + NpgsqlConnection connection = await this._dataSource!.OpenConnectionAsync().ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = @" + CREATE TABLE hotel_info ( + HotelId INTEGER NOT NULL, + HotelName TEXT, + HotelCode INTEGER NOT NULL, + HotelRating REAL, + parking_is_included BOOLEAN, + Tags TEXT[] NOT NULL, + Description TEXT NOT NULL, + DescriptionEmbedding VECTOR(4) NOT NULL, + PRIMARY KEY (HotelId));"; + await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); + } + } + + /// + /// Delete the docker container after the test run. + /// + /// An async task. + public async Task DisposeAsync() + { + if (this._dataSource != null) + { + this._dataSource.Dispose(); + } + + if (this._containerId != null) + { + await this._client.Containers.StopContainerAsync(this._containerId, new ContainerStopParameters()); + await this._client.Containers.RemoveContainerAsync(this._containerId, new ContainerRemoveParameters()); + } + } + + /// + /// Setup the postgres container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. + private static async Task SetupPostgresContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "pgvector/pgvector", + Tag = "pg16", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "pgvector/pgvector:pg16", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"5432", new List {new() {HostPort = "5432" } }}, + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "5432", default }, + }, + Env = new List + { + "POSTGRES_USER=postgres", + "POSTGRES_PASSWORD=example", + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "The database name is generated randomly, it does not support parameterized passing.")] + private async Task CreateDatabaseAsync(NpgsqlDataSource initialDataSource) + { + await using (NpgsqlConnection conn = await initialDataSource.OpenConnectionAsync()) + { + await using NpgsqlCommand command = new($"CREATE DATABASE \"{this._databaseName}\"", conn); + await command.ExecuteNonQueryAsync(); + } + + await using (NpgsqlConnection conn = await this._dataSource!.OpenConnectionAsync()) + { + await using (NpgsqlCommand command = new("CREATE EXTENSION vector", conn)) + { + await command.ExecuteNonQueryAsync(); + } + await conn.ReloadTypesAsync(); + } + } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "The database name is generated randomly, it does not support parameterized passing.")] + private async Task DropDatabaseAsync() + { + using NpgsqlDataSource dataSource = NpgsqlDataSource.Create(this._connectionString); + await using NpgsqlConnection conn = await dataSource.OpenConnectionAsync(); + await using NpgsqlCommand command = new($"DROP DATABASE IF EXISTS \"{this._databaseName}\"", conn); + await command.ExecuteNonQueryAsync(); + } +} \ No newline at end of file diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..04c6f815d1bd --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading.Tasks; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +[Collection("PostgresVectorStoreCollection")] +public sealed class PostgresVectorStoreRecordCollectionTests(PostgresVectorStoreFixture fixture) +{ + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CollectionExistsReturnsCollectionStateAsync(bool createCollection) + { + // Arrange + var sut = fixture.GetCollection("CollectionExists"); + + if (createCollection) + { + await sut.CreateCollectionAsync(); + } + + try + { + // Act + var collectionExists = await sut.CollectionExistsAsync(); + + // Assert + Assert.Equal(createCollection, collectionExists); + } + finally + { + // Cleanup + if (createCollection) + { + await sut.DeleteCollectionAsync(); + } + } + } + + [Fact] + public async Task CollectionCanUpsertAndGetAsync() + { + // Arrange + var sut = fixture.GetCollection("CollectionCanUpsertAndGet"); + if (await sut.CollectionExistsAsync()) + { + await sut.DeleteCollectionAsync(); + } + + await sut.CreateCollectionAsync(); + + try + { + // Act + await sut.UpsertAsync(new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }); + await sut.UpsertAsync(new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, ListInts = [1, 2] }); + + var hotel1 = await sut.GetAsync(1); + var hotel2 = await sut.GetAsync(2); + + // Assert + Assert.NotNull(hotel1); + Assert.Equal(1, hotel1!.HotelId); + Assert.Equal("Hotel 1", hotel1!.HotelName); + Assert.Equal(1, hotel1!.HotelCode); + Assert.True(hotel1!.ParkingIncluded); + Assert.Equal(4.5f, hotel1!.HotelRating); + Assert.NotNull(hotel1!.Tags); + Assert.Equal(2, hotel1!.Tags!.Count); + Assert.Equal("tag1", hotel1!.Tags![0]); + Assert.Equal("tag2", hotel1!.Tags![1]); + Assert.Null(hotel1!.ListInts); + + Assert.NotNull(hotel2); + Assert.Equal(2, hotel2!.HotelId); + Assert.Equal("Hotel 2", hotel2!.HotelName); + Assert.Equal(2, hotel2!.HotelCode); + Assert.False(hotel2!.ParkingIncluded); + Assert.Equal(2.5f, hotel2!.HotelRating); + Assert.NotNull(hotel2!.Tags); + Assert.Empty(hotel2!.Tags); + Assert.NotNull(hotel2!.ListInts); + Assert.Equal(2, hotel2!.ListInts!.Count); + Assert.Equal(1, hotel2!.ListInts![0]); + Assert.Equal(2, hotel2!.ListInts![1]); + } + finally + { + // Cleanup + await sut.DeleteCollectionAsync(); + } + } +} \ No newline at end of file diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs new file mode 100644 index 000000000000..0ae8fc8d5228 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +[Collection("PostgresVectorStoreCollection")] +public class PostgresVectorStoreTests(PostgresVectorStoreFixture fixture) +{ + [Fact] + public async Task ItCanGetAListOfExistingCollectionNamesAsync() + { + // Arrange + var sut = new PostgresVectorStore(fixture.PostgresClient); + + // Setup + var collection = sut.GetCollection("VS_TEST_HOTELS"); + await collection.CreateCollectionIfNotExistsAsync(); + + // Act + var collectionNames = await sut.ListCollectionNamesAsync().ToListAsync(); + + // Assert + Assert.Contains("VS_TEST_HOTELS", collectionNames); + } +} From ddad99a903494b651310d18f48cec1020a6c893d Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 18 Oct 2024 18:00:46 -0400 Subject: [PATCH 02/62] Add UpsertBatch, GetBatch, and DeleteBatch --- ...PostgresVectorStoreCollectionSqlBuilder.cs | 43 ++++- .../IPostgresVectorStoreDbClient.cs | 42 +++++ ...PostgresVectorStoreCollectionSqlBuilder.cs | 169 +++++++++++------- .../PostgresVectorStoreDbClient.cs | 60 ++++++- .../PostgresVectorStoreRecordCollection.cs | 81 ++++++--- ...ostgresVectorStoreRecordPropertyMapping.cs | 89 +++++++++ ...resVectorStoreCollectionSqlBuilderTests.cs | 41 ++++- ...ostgresVectorStoreRecordCollectionTests.cs | 69 +++++++ 8 files changed, 504 insertions(+), 90 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index ed0a763b7b25..3c23741c5599 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -50,10 +50,20 @@ public interface IPostgresVectorStoreCollectionSqlBuilder /// /// The schema of the table. /// The name of the table. + /// The key column of the table. /// The row to upsert. + /// The built SQL command info. + PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, string keyColumn, Dictionary row); + + /// + /// Builds a SQL command to upsert a batch of records in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. /// The key column of the table. + /// The rows to upsert. /// The built SQL command info. - PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, Dictionary row, string keyColumn); + PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tableName, string keyColumn, List> rows); /// /// Builds a SQL command to get a record from the Postgres vector store. @@ -65,4 +75,35 @@ public interface IPostgresVectorStoreCollectionSqlBuilder /// Specifies whether to include vectors in the record. /// The built SQL command info. PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, TKey key, bool includeVectors = false) where TKey : notnull; + + /// + /// Builds a SQL command to get a batch of records from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The record definition of the table. + /// The keys of the records to get. + /// Specifies whether to include vectors in the records. + /// The built SQL command info. + PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, List keys, bool includeVectors = false) where TKey : notnull; + + /// + /// Builds a SQL command to delete a record from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The key of the record to delete. + /// The built SQL command info. + PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key); + + /// + /// Builds a SQL command to delete a batch of records from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The key column of the table. + /// The keys of the records to delete. + /// The built SQL command info. + PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys); } \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs index 4a1f7ff4e13b..a2962215c632 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -53,6 +53,16 @@ public interface IPostgresVectorStoreDbClient /// Task UpsertAsync(string tableName, Dictionary row, string keyColumn, CancellationToken cancellationToken = default); + /// + /// Upsert multiple entries into a table. + /// + /// The name assigned to a table of entries. + /// The rows to upsert into the table. + /// The key column of the table. + /// The to monitor for cancellation requests. The default is . + /// + Task UpsertBatchAsync(string tableName, IEnumerable> rows, string keyColumn, CancellationToken cancellationToken = default); + /// /// Get a entry by its key. /// @@ -65,6 +75,38 @@ public interface IPostgresVectorStoreDbClient Task?> GetAsync(string tableName, TKey key, VectorStoreRecordDefinition recordDefinition, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull; + /// + /// Get multiple entries by their keys. + /// + /// The name assigned to a table of entries. + /// The keys of the entries to get. + /// The record definition of the table. + /// If true, the vectors will be included in the entries. + /// The to monitor for cancellation requests. The default is . + /// The rows that match the given keys. + IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, VectorStoreRecordDefinition recordDefinition, bool includeVectors = false, CancellationToken cancellationToken = default) + where TKey : notnull; + + /// + /// Delete a entry by its key. + /// + /// The name assigned to a table of entries. + /// The name of the key column. + /// The key of the entry to delete. + /// The to monitor for cancellation requests. The default is . + /// + Task DeleteAsync(string tableName, string keyColumn, TKey key, CancellationToken cancellationToken = default); + + /// + /// Delete multiple entries by their keys. + /// + /// The name assigned to a table of entries. + /// The name of the key column. + /// The keys of the entries to delete. + /// The to monitor for cancellation requests. The default is . + /// + Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default); + // /// // /// Gets the nearest matches to the . // /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 5fe62a0f57e1..7a6d2e6905fe 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -6,6 +6,7 @@ using System.Text; using Microsoft.Extensions.VectorData; using Npgsql; +using NpgsqlTypes; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -88,14 +89,14 @@ public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tabl createTableCommand.AppendLine($"CREATE TABLE {(ifNotExists ? "IF NOT EXISTS " : "")}{schema}.\"{tableName}\" ("); // Add the key column - var keyPgTypeInfo = GetPostgresTypeName(keyProperty.PropertyType); + var keyPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(keyProperty.PropertyType); createTableCommand.AppendLine($" \"{keyName}\" {keyPgTypeInfo.PgType} {(keyPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); // Add the data columns foreach (var dataProperty in dataProperties) { string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; - var dataPgTypeInfo = GetPostgresTypeName(dataProperty.PropertyType); + var dataPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(dataProperty.PropertyType); createTableCommand.AppendLine($" \"{columnName}\" {dataPgTypeInfo.PgType} {(dataPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); } @@ -103,7 +104,7 @@ public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tabl foreach (var vectorProperty in vectorProperties) { string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; - var vectorPgTypeInfo = GetPgVectorTypeName(vectorProperty); + var vectorPgTypeInfo = PostgresVectorStoreRecordPropertyMapping.GetPgVectorTypeName(vectorProperty); createTableCommand.AppendLine($" \"{columnName}\" {vectorPgTypeInfo.PgType} {(vectorPgTypeInfo.IsNullable ? "" : "NOT NULL")},"); } @@ -123,7 +124,7 @@ public PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableN } /// - public PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, Dictionary row, string keyColumn) + public PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName, string keyColumn, Dictionary row) { var columns = row.Keys.ToList(); var columnNames = string.Join(", ", columns.Select(k => $"\"{k}\"")); @@ -144,72 +145,47 @@ ON CONFLICT (""{keyColumn}"") }; } - /// - /// Maps a .NET type to a PostgreSQL type name. - /// - /// The .NET type. - /// Tuple of the the PostgreSQL type name and whether it can be NULL - private static (string PgType, bool IsNullable) GetPostgresTypeName(Type propertyType) + /// + public PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tableName, string keyColumn, List> rows) { - var (pgType, isNullable) = propertyType switch - { - Type t when t == typeof(int) => ("INTEGER", false), - Type t when t == typeof(string) => ("TEXT", true), - Type t when t == typeof(bool) => ("BOOLEAN", false), - Type t when t == typeof(DateTime) => ("TIMESTAMP", false), - Type t when t == typeof(double) => ("DOUBLE PRECISION", false), - Type t when t == typeof(decimal) => ("NUMERIC", false), - Type t when t == typeof(float) => ("REAL", false), - Type t when t == typeof(byte[]) => ("BYTEA", true), - Type t when t == typeof(Guid) => ("UUID", false), - Type t when t == typeof(short) => ("SMALLINT", false), - Type t when t == typeof(long) => ("BIGINT", false), - _ => (null, false) - }; - - if (pgType != null) + if (rows == null || rows.Count == 0) { - return (pgType, isNullable); + throw new ArgumentException("Rows cannot be null or empty", nameof(rows)); } - // Handle lists and arrays (PostgreSQL supports array types for most types) - if (propertyType.IsArray) - { - Type elementType = propertyType.GetElementType() ?? throw new ArgumentException("Array type must have an element type."); - var underlyingPgType = GetPostgresTypeName(elementType); - return (underlyingPgType.PgType + "[]", true); - } - else if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) - { - Type elementType = propertyType.GetGenericArguments()[0]; - var underlyingPgType = GetPostgresTypeName(elementType); - return (underlyingPgType.PgType + "[]", true); - } + var firstRow = rows[0]; + var columns = firstRow.Keys.ToList(); - // Handle nullable types (e.g. Nullable) - if (Nullable.GetUnderlyingType(propertyType) != null) - { - Type underlyingType = Nullable.GetUnderlyingType(propertyType) ?? throw new ArgumentException("Nullable type must have an underlying type."); - var underlyingPgType = GetPostgresTypeName(underlyingType); - return (underlyingPgType.PgType, true); - } + // Generate column names and parameter placeholders + var columnNames = string.Join(", ", columns.Select(c => $"\"{c}\"")); + var valuePlaceholders = string.Join(", ", columns.Select((c, i) => $"${i + 1}")); + var valuesRows = string.Join(", ", rows.Select((row, rowIndex) => $"({string.Join(", ", columns.Select((c, colIndex) => $"${rowIndex * columns.Count + colIndex + 1}"))})")); - throw new NotSupportedException($"Type {propertyType.Name} is not supported by this store."); - } + // Generate the update set clause + var updateSetClause = string.Join(", ", columns.Where(c => c != keyColumn).Select(c => $"\"{c}\" = EXCLUDED.\"{c}\"")); - /// - /// Gets the PostgreSQL vector type name based on the dimensions of the vector property. - /// - /// The vector property. - /// The PostgreSQL vector type name. - private static (string PgType, bool IsNullable) GetPgVectorTypeName(VectorStoreRecordVectorProperty vectorProperty) - { - if (vectorProperty.Dimensions <= 0) + // Generate the SQL command + var commandText = $@" + INSERT INTO {schema}.""{tableName}"" ({columnNames}) + VALUES {valuesRows} + ON CONFLICT(""{keyColumn}"") + DO UPDATE SET {updateSetClause}; "; + + // Generate the parameters + var parameters = new List(); + for (int rowIndex = 0; rowIndex < rows.Count; rowIndex++) { - throw new ArgumentException("Vector property must have a positive number of dimensions."); + var row = rows[rowIndex]; + foreach (var column in columns) + { + parameters.Add(new NpgsqlParameter() + { + Value = row[column] ?? DBNull.Value + }); + } } - return ($"VECTOR({vectorProperty.Dimensions})", Nullable.GetUnderlyingType(vectorProperty.PropertyType) != null); + return new PostgresSqlCommandInfo(commandText, parameters); } /// @@ -254,4 +230,77 @@ public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableN parameters: [new NpgsqlParameter() { Value = key }] ); } + + /// + public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, List keys, bool includeVectors = false) + where TKey : notnull + { + NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); + + if (keys == null || keys.Count == 0) + { + throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); + } + + var keyProperty = recordDefinition.Properties.OfType().FirstOrDefault() ?? throw new ArgumentException("Record definition must contain a key property", nameof(recordDefinition)); + var keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; + + // Generate the column names + var columns = recordDefinition.Properties + .Where(p => includeVectors || p is not VectorStoreRecordVectorProperty) + .Select(p => p.StoragePropertyName ?? p.DataModelPropertyName) + .ToList(); + + var columnNames = string.Join(", ", columns.Select(c => $"\"{c}\"")); + var keyParams = string.Join(", ", keys.Select((k, i) => $"${i + 1}")); + + // Generate the SQL command + var commandText = $@" + SELECT {columnNames} + FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ANY($1);"; + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = [new NpgsqlParameter() { Value = keys.ToArray(), NpgsqlDbType = NpgsqlDbType.Array | keyType.Value }] + }; + } + + /// + public PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key) + { + return new PostgresSqlCommandInfo( + commandText: $@" + DELETE FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ${1};", + parameters: [new NpgsqlParameter() { Value = key }] + ); + } + + /// + public PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys) + { + NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); + if (keys == null || keys.Count == 0) + { + throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); + } + + for (int i = 0; i < keys.Count; i++) + { + if (keys[i] == null) + { + throw new ArgumentException("Keys cannot contain null values", nameof(keys)); + } + } + + var commandText = $@" + DELETE FROM {schema}.""{tableName}"" + WHERE ""{keyColumn}"" = ANY($1);"; + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = [new NpgsqlParameter() { Value = keys, NpgsqlDbType = NpgsqlDbType.Array | keyType.Value }] + }; + } } \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 172f1acd19d2..adfffa86959a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -101,7 +102,20 @@ public async Task UpsertAsync(string tableName, Dictionary row, await using (connection) { - var commandInfo = this._sqlBuilder.BuildUpsertCommand(this._schema, tableName, row, keyColumn); + var commandInfo = this._sqlBuilder.BuildUpsertCommand(this._schema, tableName, keyColumn, row); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public async Task UpsertBatchAsync(string tableName, IEnumerable> rows, string keyColumn, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildUpsertBatchCommand(this._schema, tableName, keyColumn, rows.ToList()); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } @@ -126,6 +140,50 @@ public async Task UpsertAsync(string tableName, Dictionary row, } } + /// + public async IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, VectorStoreRecordDefinition recordDefinition, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TKey : notnull + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetBatchCommand(this._schema, tableName, recordDefinition, keys.ToList(), includeVectors); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return this.GetRecord(dataReader, recordDefinition.Properties, includeVectors); + } + } + } + + /// + public async Task DeleteAsync(string tableName, string keyColumn, TKey key, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildDeleteCommand(this._schema, tableName, keyColumn, key); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public async Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, keys.ToList()); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + private Dictionary GetRecord( NpgsqlDataReader reader, IEnumerable properties, diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 969ddb2078d1..30201d0718ad 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -2,6 +2,8 @@ using System; using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; @@ -106,21 +108,48 @@ public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken } /// - public Task DeleteAsync(TKey key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + return this._client.DeleteTableAsync(this.CollectionName, cancellationToken); } /// - public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + const string OperationName = "Upsert"; + + var storageModel = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record)); + + Verify.NotNull(storageModel); + + var keyObj = storageModel[this._propertyReader.KeyPropertyStoragePropertyName]; + Verify.NotNull(keyObj); + TKey key = (TKey)keyObj!; + + await this._client.UpsertAsync(this.CollectionName, this._mapper?.MapFromDataToStorageModel(record) ?? throw new InvalidOperationException("Failed to map record to storage model."), this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); + return key; } /// - public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + public async IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - return this._client.DeleteTableAsync(this.CollectionName, cancellationToken); + const string OperationName = "UpsertBatch"; + + var storageModels = records.Select(record => VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record))).ToList(); + + var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList(); + + await this._client.UpsertBatchAsync(this.CollectionName, storageModels, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); + + foreach (var key in keys) { yield return (TKey)key!; } } /// @@ -144,36 +173,34 @@ public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) } /// - public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) - { - throw new NotImplementedException(); - } - - /// - public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - const string OperationName = "Upsert"; + var operationName = "GetBatch"; - var storageModel = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, - OperationName, - () => this._mapper.MapFromDataToStorageModel(record)); + Verify.NotNull(keys); - Verify.NotNull(storageModel); + bool includeVectors = options?.IncludeVectors is true; - var keyObj = storageModel[this._propertyReader.KeyPropertyStoragePropertyName]; - Verify.NotNull(keyObj); - TKey key = (TKey)keyObj!; + await foreach (var row in this._client.GetBatchAsync(this.CollectionName, keys, this._propertyReader.RecordDefinition, includeVectors, cancellationToken).ConfigureAwait(false)) + { + yield return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + operationName, + () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); + } + } - await this._client.UpsertAsync(this.CollectionName, this._mapper?.MapFromDataToStorageModel(record) ?? throw new InvalidOperationException("Failed to map record to storage model."), this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); - return key; + /// + public async Task DeleteAsync(TKey key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + await this._client.DeleteAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, key, cancellationToken).ConfigureAwait(false); } /// - public IAsyncEnumerable UpsertBatchAsync(IEnumerable records, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) + public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + return this._client.DeleteBatchAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, keys, cancellationToken); } /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index e19a2413864b..bfedf8dd9424 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -4,7 +4,9 @@ using System.Collections; using System.Collections.Generic; using System.Runtime.InteropServices; +using Microsoft.Extensions.VectorData; using Npgsql; +using NpgsqlTypes; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -62,6 +64,93 @@ internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => }; } + public static NpgsqlDbType? GetNpgsqlDbType(Type propertyType) => + propertyType switch + { + Type t when t == typeof(int) || t == typeof(int?) => NpgsqlDbType.Integer, + Type t when t == typeof(long) || t == typeof(long?) => NpgsqlDbType.Bigint, + Type t when t == typeof(ulong) || t == typeof(ulong?) => NpgsqlDbType.Bigint, + Type t when t == typeof(short) || t == typeof(short?) => NpgsqlDbType.Smallint, + Type t when t == typeof(ushort) || t == typeof(ushort?) => NpgsqlDbType.Smallint, + Type t when t == typeof(bool) || t == typeof(bool?) => NpgsqlDbType.Boolean, + Type t when t == typeof(float) || t == typeof(float?) => NpgsqlDbType.Real, + Type t when t == typeof(double) || t == typeof(double?) => NpgsqlDbType.Double, + Type t when t == typeof(decimal) || t == typeof(decimal?) => NpgsqlDbType.Numeric, + Type t when t == typeof(string) => NpgsqlDbType.Text, + Type t when t == typeof(DateTime) || t == typeof(DateTime?) => NpgsqlDbType.Timestamp, + Type t when t == typeof(byte[]) => NpgsqlDbType.Bytea, + Type t when t == typeof(Guid) => NpgsqlDbType.Uuid, + _ => null + }; + + /// + /// Maps a .NET type to a PostgreSQL type name. + /// + /// The .NET type. + /// Tuple of the the PostgreSQL type name and whether it can be NULL + public static (string PgType, bool IsNullable) GetPostgresTypeName(Type propertyType) + { + var (pgType, isNullable) = propertyType switch + { + Type t when t == typeof(int) => ("INTEGER", false), + Type t when t == typeof(string) => ("TEXT", true), + Type t when t == typeof(bool) => ("BOOLEAN", false), + Type t when t == typeof(DateTime) => ("TIMESTAMP", false), + Type t when t == typeof(double) => ("DOUBLE PRECISION", false), + Type t when t == typeof(decimal) => ("NUMERIC", false), + Type t when t == typeof(float) => ("REAL", false), + Type t when t == typeof(byte[]) => ("BYTEA", true), + Type t when t == typeof(Guid) => ("UUID", false), + Type t when t == typeof(short) => ("SMALLINT", false), + Type t when t == typeof(long) => ("BIGINT", false), + _ => (null, false) + }; + + if (pgType != null) + { + return (pgType, isNullable); + } + + // Handle lists and arrays (PostgreSQL supports array types for most types) + if (propertyType.IsArray) + { + Type elementType = propertyType.GetElementType() ?? throw new ArgumentException("Array type must have an element type."); + var underlyingPgType = GetPostgresTypeName(elementType); + return (underlyingPgType.PgType + "[]", true); + } + else if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) + { + Type elementType = propertyType.GetGenericArguments()[0]; + var underlyingPgType = GetPostgresTypeName(elementType); + return (underlyingPgType.PgType + "[]", true); + } + + // Handle nullable types (e.g. Nullable) + if (Nullable.GetUnderlyingType(propertyType) != null) + { + Type underlyingType = Nullable.GetUnderlyingType(propertyType) ?? throw new ArgumentException("Nullable type must have an underlying type."); + var underlyingPgType = GetPostgresTypeName(underlyingType); + return (underlyingPgType.PgType, true); + } + + throw new NotSupportedException($"Type {propertyType.Name} is not supported by this store."); + } + + /// + /// Gets the PostgreSQL vector type name based on the dimensions of the vector property. + /// + /// The vector property. + /// The PostgreSQL vector type name. + public static (string PgType, bool IsNullable) GetPgVectorTypeName(VectorStoreRecordVectorProperty vectorProperty) + { + if (vectorProperty.Dimensions <= 0) + { + throw new ArgumentException("Vector property must have a positive number of dimensions."); + } + + return ($"VECTOR({vectorProperty.Dimensions})", Nullable.GetUnderlyingType(vectorProperty.PropertyType) != null); + } + // Helper method to convert lists private static object ConvertList(IEnumerable list, Type elementType) { diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs index c159f92f8fcb..45d9391f62c3 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -85,7 +85,7 @@ public void TestBuildUpsertCommand() var keyColumn = "id"; - var cmdInfo = builder.BuildUpsertCommand("public", "testcollection", row, keyColumn); + var cmdInfo = builder.BuildUpsertCommand("public", "testcollection", keyColumn, row); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); @@ -145,4 +145,43 @@ public void TestBuildGetCommand() // Output this._output.WriteLine(cmdInfo.CommandText); } + + [Fact] + public void TestBuildGetBatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var keys = new List { 123, 456, 789 }; + + // Act + var cmdInfo = builder.BuildGetBatchCommand("public", "testcollection", recordDefinition, keys, includeVectors: true); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } } \ No newline at end of file diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 04c6f815d1bd..7e1398b441ed 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Linq; using System.Threading.Tasks; using Xunit; @@ -92,4 +93,72 @@ public async Task CollectionCanUpsertAndGetAsync() await sut.DeleteCollectionAsync(); } } + + [Fact] + public async Task ItCanGetAndDeleteRecordAsync() + { + // Arrange + const int HotelId = 5; + var sut = fixture.GetCollection("DeleteRecord"); + + await sut.CreateCollectionAsync(); + + try + { + var record = new PostgresHotel { HotelId = HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(HotelId); + + Assert.Equal(HotelId, upsertResult); + Assert.NotNull(getResult); + + // Act + await sut.DeleteAsync(HotelId); + + getResult = await sut.GetAsync(HotelId); + + // Assert + Assert.Null(getResult); + } + finally + { + // Cleanup + await sut.DeleteCollectionAsync(); + } + } + + [Fact] + public async Task ItCanGetUpsertDeleteBatchAsync() + { + // Arrange + const int HotelId1 = 1; + const int HotelId2 = 2; + const int HotelId3 = 3; + + var sut = fixture.GetCollection("GetUpsertDeleteBatch"); + + await sut.CreateCollectionAsync(); + + var record1 = new PostgresHotel { HotelId = HotelId1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + var record2 = new PostgresHotel { HotelId = HotelId2, HotelName = "Hotel 2", HotelCode = 1, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag3"] }; + var record3 = new PostgresHotel { HotelId = HotelId3, HotelName = "Hotel 3", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"] }; + + var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); + var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + + Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); + + Assert.NotNull(getResults.First(l => l.HotelId == HotelId1)); + Assert.NotNull(getResults.First(l => l.HotelId == HotelId2)); + Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); + + // Act + await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + + getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + + // Assert + Assert.Empty(getResults); + } } \ No newline at end of file From 54478154d22a339f8b36f32ee95609eaccadf21a Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 18 Oct 2024 18:01:05 -0400 Subject: [PATCH 03/62] Remove unused CreateMapping --- ...tgresVectorStoreCollectionCreateMapping.cs | 119 ------------------ 1 file changed, 119 deletions(-) delete mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionCreateMapping.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionCreateMapping.cs deleted file mode 100644 index 3bfd7910e956..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionCreateMapping.cs +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Text; -using Microsoft.Extensions.VectorData; - -namespace Microsoft.SemanticKernel.Connectors.Postgres; - -/// -/// Generates the PostgreSQL vector type name based on the dimensions of the vector property. -/// /// Provides methods to generate SQL statements for creating tables in /// a PostgreSQL database -/// for storing vector data. -/// -public static class PostgresVectorStoreCollectionCreateMapping -{ - /// - /// Generates a SQL CREATE TABLE statement. - /// - /// The schema name. - /// The table name. - /// The key property. - /// The list of data properties. - /// The list of vector properties. - /// The generated SQL CREATE TABLE statement. - /// Thrown when the table name is null or whitespace. - public static string GenerateCreateTableStatement(string schema, string tableName, VectorStoreRecordKeyProperty KeyProperty, IEnumerable DataProperties, IEnumerable VectorProperties) - { - if (string.IsNullOrWhiteSpace(tableName)) - { - throw new ArgumentException("Table name cannot be null or whitespace", nameof(tableName)); - } - - var keyName = KeyProperty.StoragePropertyName ?? KeyProperty.DataModelPropertyName; - - StringBuilder createTableCommand = new(); - createTableCommand.AppendLine($"CREATE TABLE {schema}.{tableName} ("); - - // Add the key column - createTableCommand.AppendLine($" {keyName} {GetPostgresTypeName(KeyProperty.PropertyType)},"); - - // Add the data columns - foreach (var dataProperty in DataProperties) - { - string columnName = dataProperty.StoragePropertyName ?? dataProperty.DataModelPropertyName; - createTableCommand.AppendLine($" {columnName} {GetPostgresTypeName(dataProperty.PropertyType)},"); - } - - // Add the vector columns - foreach (var vectorProperty in VectorProperties) - { - string columnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; - createTableCommand.AppendLine($" {columnName} {GetPgVectorTypeName(vectorProperty)},"); - } - - createTableCommand.AppendLine($" PRIMARY KEY ({keyName})"); - - createTableCommand.AppendLine(");"); - - return createTableCommand.ToString(); - } - - /// - /// Maps a .NET type to a PostgreSQL type name. - /// - /// The .NET type. - /// The PostgreSQL type name. - private static string GetPostgresTypeName(Type propertyType) - { - var pgType = propertyType switch - { - Type t when t == typeof(int) => "INTEGER", - Type t when t == typeof(string) => "TEXT", - Type t when t == typeof(bool) => "BOOLEAN", - Type t when t == typeof(DateTime) => "TIMESTAMP", - Type t when t == typeof(double) => "DOUBLE PRECISION", - Type t when t == typeof(decimal) => "NUMERIC", - Type t when t == typeof(float) => "REAL", - Type t when t == typeof(byte[]) => "BYTEA", - Type t when t == typeof(Guid) => "UUID", - Type t when t == typeof(short) => "SMALLINT", - Type t when t == typeof(long) => "BIGINT", - _ => null - }; - - if (pgType != null) { return pgType; } - - // Handle arrays (PostgreSQL supports array types for most types) - if (propertyType.IsArray) - { - Type elementType = propertyType.GetElementType() ?? throw new ArgumentException("Array type must have an element type."); - return GetPostgresTypeName(elementType) + "[]"; - } - - // Handle nullable types (e.g. Nullable) - if (Nullable.GetUnderlyingType(propertyType) != null) - { - Type underlyingType = Nullable.GetUnderlyingType(propertyType) ?? throw new ArgumentException("Nullable type must have an underlying type."); - return GetPostgresTypeName(underlyingType); - } - - throw new NotSupportedException($"Type {propertyType.Name} is not supported by this store."); - } - - /// - /// Gets the PostgreSQL vector type name based on the dimensions of the vector property. - /// - /// The vector property. - /// The PostgreSQL vector type name. - private static string GetPgVectorTypeName(VectorStoreRecordVectorProperty vectorProperty) - { - if (vectorProperty.Dimensions <= 0) - { - throw new ArgumentException("Vector property must have a positive number of dimensions."); - } - - return $"VECTOR({vectorProperty.Dimensions})"; - } -} \ No newline at end of file From 68a000e0efa1f17fad03dbad0416a7e4b6986260 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Tue, 22 Oct 2024 16:29:38 -0400 Subject: [PATCH 04/62] Add vector search to PostgresVectorStore --- ...PostgresVectorStoreCollectionSqlBuilder.cs | 32 ++- .../IPostgresVectorStoreDbClient.cs | 35 +-- .../PostgresConstants.cs | 29 +-- .../PostgresGenericDataModelMapper.cs | 58 +---- .../PostgresSqlCommandInfo.cs | 4 +- .../PostgresVectorStore.cs | 6 +- ...PostgresVectorStoreCollectionSqlBuilder.cs | 106 +++++++- .../PostgresVectorStoreDbClient.cs | 35 ++- .../PostgresVectorStoreOptions.cs | 2 +- .../PostgresVectorStoreRecordCollection.cs | 91 ++++++- ...tgresVectorStoreRecordCollectionOptions.cs | 2 +- .../PostgresVectorStoreRecordMapper.cs | 21 +- ...ostgresVectorStoreRecordPropertyMapping.cs | 71 ++++-- ...ostgresVectorStoreRecordCollectionTests.cs | 26 +- .../Memory/Postgres/PostgresHotel.cs | 2 +- .../Postgres/PostgresVectorStoreFixture.cs | 104 +------- ...ostgresVectorStoreRecordCollectionTests.cs | 233 ++++++++++++++++++ .../src/Linq/AsyncEnumerable.cs | 35 +++ 18 files changed, 619 insertions(+), 273 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index 3c23741c5599..ec61286c595c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -1,7 +1,8 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; using Microsoft.Extensions.VectorData; +using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -32,10 +33,10 @@ public interface IPostgresVectorStoreCollectionSqlBuilder /// /// The schema of the table. /// The name of the table. - /// The record definition of the table. + /// The properties of the table. /// Specifies whether to include IF NOT EXISTS in the command. /// The built SQL command info. - PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, bool ifNotExists = true); + PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, IReadOnlyList properties, bool ifNotExists = true); /// /// Builds a SQL command to drop a table in the Postgres vector store. @@ -70,22 +71,22 @@ public interface IPostgresVectorStoreCollectionSqlBuilder /// /// The schema of the table. /// The name of the table. - /// The record definition of the table. + /// The properties of the table. /// The key of the record to get. /// Specifies whether to include vectors in the record. /// The built SQL command info. - PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, TKey key, bool includeVectors = false) where TKey : notnull; + PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, IReadOnlyList properties, TKey key, bool includeVectors = false) where TKey : notnull; /// /// Builds a SQL command to get a batch of records from the Postgres vector store. /// /// The schema of the table. /// The name of the table. - /// The record definition of the table. + /// The properties of the table. /// The keys of the records to get. /// Specifies whether to include vectors in the records. /// The built SQL command info. - PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, List keys, bool includeVectors = false) where TKey : notnull; + PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, IReadOnlyList properties, List keys, bool includeVectors = false) where TKey : notnull; /// /// Builds a SQL command to delete a record from the Postgres vector store. @@ -106,4 +107,19 @@ public interface IPostgresVectorStoreCollectionSqlBuilder /// The keys of the records to delete. /// The built SQL command info. PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, string tableName, string keyColumn, List keys); -} \ No newline at end of file + + /// + /// Builds a SQL command to get the nearest match from the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The properties of the table. + /// The property which the vectors to compare are stored in. + /// The vector to match. + /// The filter conditions for the query. + /// The number of records to skip. + /// Specifies whether to include embeddings in the result. + /// The maximum number of records to return. + /// The built SQL command info. + PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? filter, int? skip, bool withEmbeddings, int limit); +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs index a2962215c632..2d3e92c6ab84 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -4,6 +4,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; +using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -68,11 +69,11 @@ public interface IPostgresVectorStoreDbClient /// /// The name assigned to a table of entries. /// The key of the entry to get. - /// The record definition of the table. + /// The properties to include in the entry. /// If true, the vectors will be included in the entry. /// The to monitor for cancellation requests. The default is . /// The row if the key is found, otherwise null. - Task?> GetAsync(string tableName, TKey key, VectorStoreRecordDefinition recordDefinition, bool includeVectors = false, CancellationToken cancellationToken = default) + Task?> GetAsync(string tableName, TKey key, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull; /// @@ -80,11 +81,11 @@ public interface IPostgresVectorStoreDbClient /// /// The name assigned to a table of entries. /// The keys of the entries to get. - /// The record definition of the table. + /// The properties of the table. /// If true, the vectors will be included in the entries. /// The to monitor for cancellation requests. The default is . /// The rows that match the given keys. - IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, VectorStoreRecordDefinition recordDefinition, bool includeVectors = false, CancellationToken cancellationToken = default) + IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull; /// @@ -107,17 +108,21 @@ public interface IPostgresVectorStoreDbClient /// Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default); - // /// - // /// Gets the nearest matches to the . - // /// - // /// The name assigned to a table of entries. - // /// The to compare the table's embeddings with. - // /// The maximum number of similarity results to return. - // /// The minimum relevance threshold for returned results. - // /// If true, the embeddings will be returned in the entries. - // /// The to monitor for cancellation requests. The default is . - // /// An asynchronous stream of objects that the nearest matches to the . - // IAsyncEnumerable<(PostgresMemoryEntry, double)> GetNearestMatchesAsync(string tableName, Vector embedding, int limit, double minRelevanceScore = 0, bool withEmbeddings = false, CancellationToken cancellationToken = default); + /// + /// Gets the nearest matches to the . + /// + /// The name assigned to a table of entries. + /// The properties to retrieve. + /// The property which the vectors to compare are stored in. + /// The to compare the table's vector with. + /// The maximum number of similarity results to return. + /// Optional conditions to filter the results. + /// The number of entries to skip. + /// If true, the vectors will be returned in the entries. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous stream of objects that the nearest matches to the . + IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, CancellationToken cancellationToken = default); // /// // /// Read a entry by its key. diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index b5f1939291ac..27575247a73d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; @@ -13,9 +13,7 @@ internal static class PostgresConstants typeof(string), typeof(int), typeof(long), - typeof(ulong), typeof(short), - typeof(ushort), ]; /// A of types that data properties on the provided model may have. @@ -25,16 +23,10 @@ internal static class PostgresConstants typeof(bool?), typeof(short), typeof(short?), - typeof(ushort), - typeof(ushort?), typeof(int), typeof(int?), - typeof(uint), - typeof(uint?), typeof(long), typeof(long?), - typeof(ulong), - typeof(ulong?), typeof(float), typeof(float?), typeof(double), @@ -47,28 +39,13 @@ internal static class PostgresConstants typeof(byte[]), typeof(List), typeof(List), - typeof(List), typeof(List), - typeof(List), typeof(List), - typeof(List), typeof(List), typeof(List), typeof(List), typeof(List), typeof(List), - typeof(bool[]), - typeof(short[]), - typeof(ushort[]), - typeof(int[]), - typeof(uint[]), - typeof(long[]), - typeof(ulong[]), - typeof(float[]), - typeof(double[]), - typeof(decimal[]), - typeof(string[]), - typeof(DateTimeOffset[]), ]; /// A of types that vector properties on the provided model may have. @@ -77,4 +54,8 @@ internal static class PostgresConstants typeof(ReadOnlyMemory), typeof(ReadOnlyMemory?) ]; + + /// The name of the column that returns distance value in the database. + /// It is used in the similarity search query. Must not conflict with model property. + public const string DistanceColumnName = "sk_pg_distance"; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs index c13b9a85783b..6a60f3c056ce 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs @@ -1,21 +1,19 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; -using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; -internal sealed class PostgresGenericDataModelMapper : IVectorStoreRecordMapper, Dictionary>, - IVectorStoreRecordMapper, Dictionary> +internal sealed class PostgresGenericDataModelMapper : IVectorStoreRecordMapper, Dictionary> + where TKey : notnull { /// with helpers for reading vector store model properties and their attributes. private readonly VectorStoreRecordPropertyReader _propertyReader; /// - /// Initializes a new instance of the class. - /// + /// Initializes a new instance of the class. + /// /// /// A that defines the schema of the data in the database. public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyReader) { @@ -27,28 +25,7 @@ public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyRe this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, supportEnumerable: false); this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); } - public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - return this.InternalMapFromDataToStorageModel(dataModel); - } - - VectorStoreGenericDataModel IVectorStoreRecordMapper, Dictionary>.MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) - { - return this.InternalMapFromStorageToDataModel(storageModel, options); - } - - public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - { - return this.InternalMapFromDataToStorageModel(dataModel); - } - - VectorStoreGenericDataModel IVectorStoreRecordMapper, Dictionary>.MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) - { - return this.InternalMapFromStorageToDataModel(storageModel, options); - } - - private Dictionary InternalMapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) - where TKey : notnull + public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) { var properties = new Dictionary { @@ -75,14 +52,7 @@ public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyRe { if (dataModel.Vectors.TryGetValue(property.DataModelPropertyName, out var vectorValue)) { - object? result = null; - - if (vectorValue is not null) - { - var vector = (ReadOnlyMemory)vectorValue; - result = new Vector(PostgresVectorStoreRecordPropertyMapping.GetOrCreateArray(vector)); - } - + var result = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vectorValue); properties.Add(this._propertyReader.GetStoragePropertyName(property.DataModelPropertyName), result); } } @@ -91,8 +61,7 @@ public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyRe return properties; } - private VectorStoreGenericDataModel InternalMapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) - where TKey : notnull + VectorStoreGenericDataModel IVectorStoreRecordMapper, Dictionary>.MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) { TKey key; var dataProperties = new Dictionary(); @@ -124,18 +93,11 @@ private VectorStoreGenericDataModel InternalMapFromStorageToDataModel.Empty); - } - else if (vectorValue is Vector pgVector) - { - vectorProperties.Add(property.DataModelPropertyName, pgVector.ToArray()); - } + vectorProperties.Add(property.DataModelPropertyName, PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(vectorValue)); } } } return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; } -} \ No newline at end of file +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs index 99dadf105fe0..fb8c892d6bf1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; @@ -48,4 +48,4 @@ public NpgsqlCommand ToNpgsqlCommand(NpgsqlConnection connection) } return cmd; } -} \ No newline at end of file +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs index a17c8e982811..86f263a740ae 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -1,11 +1,11 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; -using System.Threading; using System.Runtime.CompilerServices; -using Microsoft.Extensions.VectorData; +using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; using Npgsql; namespace Microsoft.SemanticKernel.Connectors.Postgres; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 7a6d2e6905fe..1d681ce11614 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; @@ -7,6 +7,7 @@ using Microsoft.Extensions.VectorData; using Npgsql; using NpgsqlTypes; +using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -43,7 +44,7 @@ FROM information_schema.tables } /// - public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, bool ifNotExists = true) + public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, IReadOnlyList properties, bool ifNotExists = true) { if (string.IsNullOrWhiteSpace(tableName)) { @@ -54,7 +55,7 @@ public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tabl List dataProperties = new(); List vectorProperties = new(); - foreach (var property in recordDefinition.Properties) + foreach (var property in properties) { if (property is VectorStoreRecordKeyProperty keyProp) { @@ -168,7 +169,7 @@ public PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tabl var commandText = $@" INSERT INTO {schema}.""{tableName}"" ({columnNames}) VALUES {valuesRows} - ON CONFLICT(""{keyColumn}"") + ON CONFLICT(""{keyColumn}"") DO UPDATE SET {updateSetClause}; "; // Generate the parameters @@ -189,13 +190,13 @@ ON CONFLICT(""{keyColumn}"") } /// - public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, TKey key, bool includeVectors = false) + public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableName, IReadOnlyList properties, TKey key, bool includeVectors = false) where TKey : notnull { List queryColumns = new(); string? keyColumn = null; - foreach (var property in recordDefinition.Properties) + foreach (var property in properties) { if (property is VectorStoreRecordKeyProperty keyProperty) { @@ -232,7 +233,7 @@ public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableN } /// - public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, VectorStoreRecordDefinition recordDefinition, List keys, bool includeVectors = false) + public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string tableName, IReadOnlyList properties, List keys, bool includeVectors = false) where TKey : notnull { NpgsqlDbType? keyType = PostgresVectorStoreRecordPropertyMapping.GetNpgsqlDbType(typeof(TKey)) ?? throw new ArgumentException($"Unsupported key type {typeof(TKey).Name}"); @@ -242,11 +243,11 @@ public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string t throw new ArgumentException("Keys cannot be null or empty", nameof(keys)); } - var keyProperty = recordDefinition.Properties.OfType().FirstOrDefault() ?? throw new ArgumentException("Record definition must contain a key property", nameof(recordDefinition)); + var keyProperty = properties.OfType().FirstOrDefault() ?? throw new ArgumentException("Properties must contain a key property", nameof(properties)); var keyColumn = keyProperty.StoragePropertyName ?? keyProperty.DataModelPropertyName; // Generate the column names - var columns = recordDefinition.Properties + var columns = properties .Where(p => includeVectors || p is not VectorStoreRecordVectorProperty) .Select(p => p.StoragePropertyName ?? p.DataModelPropertyName) .ToList(); @@ -303,4 +304,89 @@ DELETE FROM {schema}.""{tableName}"" Parameters = [new NpgsqlParameter() { Value = keys, NpgsqlDbType = NpgsqlDbType.Array | keyType.Value }] }; } -} \ No newline at end of file + + /// + public PostgresSqlCommandInfo BuildGetNearestMatchCommand( + string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, + VectorSearchFilter? filter, int? skip, bool withEmbeddings, int limit) + { + var columns = string.Join(" ,", + properties + .Select(property => property.StoragePropertyName ?? property.DataModelPropertyName) + .Select(column => $"\"{column}\"") + ); + + var distanceOp = vectorProperty.DistanceFunction switch + { + DistanceFunction.CosineSimilarity => "<=>", + DistanceFunction.EuclideanDistance => "<->", + DistanceFunction.ManhattanDistance => "<+>", + DistanceFunction.DotProductSimilarity => "<#>", + _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") + }; + + var vectorColumn = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + // Start where clause params at 2, vector takes param 1. + var where = GenerateWhereClause(schema, tableName, properties, filter, startParamIndex: 2); + + var commandText = $@" + SELECT {columns}, ""{vectorColumn}"" {distanceOp} $1 AS ""{PostgresConstants.DistanceColumnName}"" + FROM {schema}.""{tableName}"" {where.Clause} + ORDER BY {PostgresConstants.DistanceColumnName} + LIMIT {limit}"; + + if (skip.HasValue) { commandText += $" OFFSET {skip.Value}"; } + + return new PostgresSqlCommandInfo(commandText) + { + Parameters = [new NpgsqlParameter() { Value = vectorValue }, .. where.Parameters.Select(p => new NpgsqlParameter() { Value = p })] + }; + } + + internal static (string Clause, List Parameters) GenerateWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter? filter, int startParamIndex) + { + if (filter == null) { return (string.Empty, new List()); } + + var whereClause = new StringBuilder("WHERE "); + var filterClauses = new List(); + var parameters = new List(); + + var paramIndex = startParamIndex; + + foreach (var filterClause in filter.FilterClauses) + { + if (filterClause is EqualToFilterClause equalTo) + { + var property = properties.FirstOrDefault(p => p.DataModelPropertyName == equalTo.FieldName || p.StoragePropertyName == equalTo.FieldName); + if (property == null) { throw new ArgumentException($"Property {equalTo.FieldName} not found in record definition."); } + + var columnName = property.StoragePropertyName ?? property.DataModelPropertyName; + filterClauses.Add($"\"{columnName}\" = ${paramIndex}"); + parameters.Add(equalTo.Value); + paramIndex++; + } + else if (filterClause is AnyTagEqualToFilterClause anyTagEqualTo) + { + var property = properties.FirstOrDefault(p => p.DataModelPropertyName == anyTagEqualTo.FieldName || p.StoragePropertyName == anyTagEqualTo.FieldName); + if (property == null) { throw new ArgumentException($"Property {anyTagEqualTo.FieldName} not found in record definition."); } + + if (property.PropertyType != typeof(List)) + { + throw new ArgumentException($"Property {anyTagEqualTo.FieldName} must be of type List to use AnyTagEqualTo filter."); + } + + var columnName = property.StoragePropertyName ?? property.DataModelPropertyName; + filterClauses.Add($"\"{columnName}\" @> ARRAY[${paramIndex}::TEXT]"); + parameters.Add(anyTagEqualTo.Value); + paramIndex++; + } + else + { + throw new NotSupportedException($"Filter clause type {filterClause.GetType().Name} is not supported."); + } + } + + whereClause.Append(string.Join(" AND ", filterClauses)); + return (whereClause.ToString(), parameters); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index adfffa86959a..161fbe1c74ff 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using Npgsql; +using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -76,7 +77,7 @@ public async Task CreateTableAsync(string tableName, VectorStoreRecordDefinition await using (connection) { - var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, recordDefinition, ifNotExists); + var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, recordDefinition.Properties, ifNotExists); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } @@ -122,18 +123,18 @@ public async Task UpsertBatchAsync(string tableName, IEnumerable - public async Task?> GetAsync(string tableName, TKey key, VectorStoreRecordDefinition recordDefinition, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull + public async Task?> GetAsync(string tableName, TKey key, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull { NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { - var commandInfo = this._sqlBuilder.BuildGetCommand(this._schema, tableName, recordDefinition, key, includeVectors); + var commandInfo = this._sqlBuilder.BuildGetCommand(this._schema, tableName, properties, key, includeVectors); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - return this.GetRecord(dataReader, recordDefinition.Properties, includeVectors); + return this.GetRecord(dataReader, properties, includeVectors); } return null; @@ -141,19 +142,19 @@ public async Task UpsertBatchAsync(string tableName, IEnumerable - public async IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, VectorStoreRecordDefinition recordDefinition, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) where TKey : notnull { NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { - var commandInfo = this._sqlBuilder.BuildGetBatchCommand(this._schema, tableName, recordDefinition, keys.ToList(), includeVectors); + var commandInfo = this._sqlBuilder.BuildGetBatchCommand(this._schema, tableName, properties, keys.ToList(), includeVectors); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { - yield return this.GetRecord(dataReader, recordDefinition.Properties, includeVectors); + yield return this.GetRecord(dataReader, properties, includeVectors); } } } @@ -171,6 +172,26 @@ public async Task DeleteAsync(string tableName, string keyColumn, TKey key } } + /// + public async IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync( + string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildGetNearestMatchCommand(this._schema, tableName, properties, vectorProperty, vectorValue, filter, skip, includeVectors, limit); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + var distance = dataReader.GetDouble(dataReader.GetOrdinal(PostgresConstants.DistanceColumnName)); + yield return (Row: this.GetRecord(dataReader, properties, includeVectors), Distance: distance); + } + } + } + /// public async Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs index 131036d7b0c7..c7959f950aaf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. namespace Microsoft.SemanticKernel.Connectors.Postgres; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 30201d0718ad..6737263d6f09 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -16,7 +16,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// The type of the key. /// The type of the record. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public sealed class PostgresVectorStoreRecordCollection : IVectorStoreRecordCollection, IVectorizableTextSearch +public sealed class PostgresVectorStoreRecordCollection : IVectorStoreRecordCollection #pragma warning restore CA1711 // Identifiers should not have incorrect suffix where TKey : notnull { @@ -38,6 +38,9 @@ public sealed class PostgresVectorStoreRecordCollection : IVector /// A mapper to use for converting between the data model and the Azure AI Search record. private readonly IVectorStoreRecordMapper> _mapper; + /// The default options for vector search. + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + /// /// Initializes a new instance of the class. /// @@ -79,9 +82,9 @@ public PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, { this._mapper = this._options.DictionaryCustomMapper; } - else if (typeof(TRecord) == typeof(VectorStoreGenericDataModel) || typeof(TRecord) == typeof(VectorStoreGenericDataModel)) + else if (typeof(TRecord).IsGenericType && typeof(TRecord).GetGenericTypeDefinition() == typeof(VectorStoreGenericDataModel<>)) { - this._mapper = (new PostgresGenericDataModelMapper(this._propertyReader) as IVectorStoreRecordMapper>)!; + this._mapper = (new PostgresGenericDataModelMapper(this._propertyReader) as IVectorStoreRecordMapper>)!; } else { @@ -161,7 +164,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record bool includeVectors = options?.IncludeVectors is true; - var row = await this._client.GetAsync(this.CollectionName, key, this._propertyReader.RecordDefinition, includeVectors, cancellationToken).ConfigureAwait(false); + var row = await this._client.GetAsync(this.CollectionName, key, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken).ConfigureAwait(false); if (row is null) { return default; } @@ -181,7 +184,7 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get bool includeVectors = options?.IncludeVectors is true; - await foreach (var row in this._client.GetBatchAsync(this.CollectionName, keys, this._propertyReader.RecordDefinition, includeVectors, cancellationToken).ConfigureAwait(false)) + await foreach (var row in this._client.GetBatchAsync(this.CollectionName, keys, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken).ConfigureAwait(false)) { yield return VectorStoreErrorHandler.RunModelConversion( DatabaseName, @@ -203,15 +206,81 @@ public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? option return this._client.DeleteBatchAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, keys, cancellationToken); } - /// - public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + /// + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + Verify.NotNull(vector); + + var vectorType = vector.GetType(); + + if (!PostgresConstants.SupportedVectorTypes.Contains(vectorType)) + { + throw new NotSupportedException( + $"The provided vector type {vectorType.FullName} is not supported by the SQLite connector. " + + $"Supported types are: {string.Join(", ", PostgresConstants.SupportedVectorTypes.Select(l => l.FullName))}"); + } + + var searchOptions = options ?? s_defaultVectorSearchOptions; + var vectorProperty = this.GetVectorPropertyForSearch(searchOptions.VectorPropertyName); + + if (vectorProperty is null) + { + throw new InvalidOperationException("The collection does not have any vector properties, so vector search is not possible."); + } + + var pgVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + Verify.NotNull(pgVector); + + // Simulating skip/offset logic locally, since OFFSET can work only with LIMIT in combination + // and LIMIT is not supported in vector search extension, instead of LIMIT - "k" parameter is used. + var limit = searchOptions.Top + searchOptions.Skip; + + var results = this._client.GetNearestMatchesAsync( + this.CollectionName, + this._propertyReader.RecordDefinition.Properties, + vectorProperty, + pgVector, + searchOptions.Top, + options.Filter, + searchOptions.Skip, + searchOptions.IncludeVectors, + cancellationToken + ).Select(result => + { + var record = this._mapper.MapFromStorageToDataModel( + result.Row, new StorageToDataModelMapperOptions() { IncludeVectors = searchOptions.IncludeVectors }); + + return new VectorSearchResult(record, result.Distance); + }, cancellationToken); + + return new VectorSearchResults(results); } - /// - public Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + /// + /// Get vector property to use for a search by using the storage name for the field name from options + /// if available, and falling back to the first vector property in if not. + /// + /// The vector field name. + /// Thrown if the provided field name is not a valid field name. + private VectorStoreRecordVectorProperty? GetVectorPropertyForSearch(string? vectorFieldName) { - throw new NotImplementedException(); + // If vector property name is provided in options, try to find it in schema or throw an exception. + if (!string.IsNullOrWhiteSpace(vectorFieldName)) + { + // Check vector properties by data model property name. + var vectorProperty = this._propertyReader.VectorProperties + .FirstOrDefault(l => l.DataModelPropertyName.Equals(vectorFieldName, StringComparison.Ordinal)); + + if (vectorProperty is not null) + { + return vectorProperty; + } + + throw new InvalidOperationException($"The {typeof(TRecord).FullName} type does not have a vector property named '{vectorFieldName}'."); + } + + // If vector property is not provided in options, return first vector property from schema. + return this._propertyReader.VectorProperty; } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs index 373f01a25a99..753713d21b3f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollectionOptions.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; using Microsoft.Extensions.VectorData; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs index 8559d1218054..2bf87e11b645 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs @@ -1,9 +1,8 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; -using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -11,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// A mapper class that handles the conversion between data models and storage models for Postgres vector store. /// /// The type of the data model record. -internal class PostgresVectorStoreRecordMapper : IVectorStoreRecordMapper> +internal sealed class PostgresVectorStoreRecordMapper : IVectorStoreRecordMapper> { /// with helpers for reading vector store model properties and their attributes. private readonly VectorStoreRecordPropertyReader _propertyReader; @@ -50,14 +49,8 @@ public PostgresVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyR // Add vector properties foreach (var property in this._propertyReader.VectorPropertiesInfo) { - object? result = null; var propertyValue = property.GetValue(dataModel); - - if (propertyValue is not null) - { - var vector = (ReadOnlyMemory)propertyValue; - result = new Vector(PostgresVectorStoreRecordPropertyMapping.GetOrCreateArray(vector)); - } + var result = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(propertyValue); properties.Add(this._propertyReader.GetStoragePropertyName(property.Name), result); } @@ -91,12 +84,14 @@ public TRecord MapFromStorageToDataModel(Dictionary storageMode this._propertyReader.VectorPropertiesInfo, this._propertyReader.StoragePropertyNamesMap, storageModel, - (object? vector, Type type) => vector is Vector pgVector ? - pgVector.ToArray() : null); + (object? vector, Type type) => + { + return PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(vector); + }); VectorStoreRecordMapping.SetPropertiesOnRecord(record, vectorPropertiesInfoWithValues); } return record; } -} \ No newline at end of file +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index bfedf8dd9424..c26c47f3127d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections; @@ -7,6 +7,7 @@ using Microsoft.Extensions.VectorData; using Npgsql; using NpgsqlTypes; +using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -18,6 +19,33 @@ internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => array.Array : memory.ToArray(); + public static Vector? MapVectorForStorageModel(TVector vector) + { + if (vector == null) + { + return null; + } + + if (vector is ReadOnlyMemory floatMemory) + { + var vecArray = MemoryMarshal.TryGetArray(floatMemory, out ArraySegment array) && + array.Count == array.Array!.Length ? + array.Array : + floatMemory.ToArray(); + return new Vector(vecArray); + } + + throw new NotSupportedException($"Mapping for type {typeof(TVector).FullName} to a vector is not supported."); + } + + public static ReadOnlyMemory? MapVectorForDataModel(object? vector) + { + var pgVector = vector is Vector pgv ? pgv : null; + if (pgVector == null) { return null; } + var vecArray = pgVector.ToArray(); + return vecArray != null && vecArray.Length != 0 ? (ReadOnlyMemory)vecArray : null; + } + public static TPropertyType? GetPropertyValue(NpgsqlDataReader reader, string propertyName) { int propertyIndex = reader.GetOrdinal(propertyName); @@ -50,16 +78,17 @@ internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => return propertyType switch { + Type t when t == typeof(bool) || t == typeof(bool?) => reader.GetBoolean(propertyIndex), + Type t when t == typeof(short) || t == typeof(short?) => reader.GetInt16(propertyIndex), Type t when t == typeof(int) || t == typeof(int?) => reader.GetInt32(propertyIndex), Type t when t == typeof(long) || t == typeof(long?) => reader.GetInt64(propertyIndex), - Type t when t == typeof(ulong) || t == typeof(ulong?) => (ulong)reader.GetInt64(propertyIndex), - Type t when t == typeof(short) || t == typeof(short?) => reader.GetInt16(propertyIndex), - Type t when t == typeof(ushort) || t == typeof(ushort?) => (ushort)reader.GetInt16(propertyIndex), - Type t when t == typeof(bool) || t == typeof(bool?) => reader.GetBoolean(propertyIndex), Type t when t == typeof(float) || t == typeof(float?) => reader.GetFloat(propertyIndex), Type t when t == typeof(double) || t == typeof(double?) => reader.GetDouble(propertyIndex), Type t when t == typeof(decimal) || t == typeof(decimal?) => reader.GetDecimal(propertyIndex), Type t when t == typeof(string) => reader.GetString(propertyIndex), + Type t when t == typeof(byte[]) => reader.GetFieldValue(propertyIndex), + Type t when t == typeof(DateTime) || t == typeof(DateTime?) => reader.GetDateTime(propertyIndex), + Type t when t == typeof(Guid) => reader.GetFieldValue(propertyIndex), _ => reader.GetValue(propertyIndex) }; } @@ -67,18 +96,16 @@ internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => public static NpgsqlDbType? GetNpgsqlDbType(Type propertyType) => propertyType switch { + Type t when t == typeof(bool) || t == typeof(bool?) => NpgsqlDbType.Boolean, + Type t when t == typeof(short) || t == typeof(short?) => NpgsqlDbType.Smallint, Type t when t == typeof(int) || t == typeof(int?) => NpgsqlDbType.Integer, Type t when t == typeof(long) || t == typeof(long?) => NpgsqlDbType.Bigint, - Type t when t == typeof(ulong) || t == typeof(ulong?) => NpgsqlDbType.Bigint, - Type t when t == typeof(short) || t == typeof(short?) => NpgsqlDbType.Smallint, - Type t when t == typeof(ushort) || t == typeof(ushort?) => NpgsqlDbType.Smallint, - Type t when t == typeof(bool) || t == typeof(bool?) => NpgsqlDbType.Boolean, Type t when t == typeof(float) || t == typeof(float?) => NpgsqlDbType.Real, Type t when t == typeof(double) || t == typeof(double?) => NpgsqlDbType.Double, Type t when t == typeof(decimal) || t == typeof(decimal?) => NpgsqlDbType.Numeric, Type t when t == typeof(string) => NpgsqlDbType.Text, - Type t when t == typeof(DateTime) || t == typeof(DateTime?) => NpgsqlDbType.Timestamp, Type t when t == typeof(byte[]) => NpgsqlDbType.Bytea, + Type t when t == typeof(DateTime) || t == typeof(DateTime?) => NpgsqlDbType.Timestamp, Type t when t == typeof(Guid) => NpgsqlDbType.Uuid, _ => null }; @@ -92,17 +119,17 @@ public static (string PgType, bool IsNullable) GetPostgresTypeName(Type property { var (pgType, isNullable) = propertyType switch { - Type t when t == typeof(int) => ("INTEGER", false), - Type t when t == typeof(string) => ("TEXT", true), Type t when t == typeof(bool) => ("BOOLEAN", false), - Type t when t == typeof(DateTime) => ("TIMESTAMP", false), + Type t when t == typeof(short) => ("SMALLINT", false), + Type t when t == typeof(int) => ("INTEGER", false), + Type t when t == typeof(long) => ("BIGINT", false), + Type t when t == typeof(float) => ("REAL", false), Type t when t == typeof(double) => ("DOUBLE PRECISION", false), Type t when t == typeof(decimal) => ("NUMERIC", false), - Type t when t == typeof(float) => ("REAL", false), + Type t when t == typeof(string) => ("TEXT", true), Type t when t == typeof(byte[]) => ("BYTEA", true), + Type t when t == typeof(DateTime) => ("TIMESTAMP", false), Type t when t == typeof(Guid) => ("UUID", false), - Type t when t == typeof(short) => ("SMALLINT", false), - Type t when t == typeof(long) => ("BIGINT", false), _ => (null, false) }; @@ -111,14 +138,8 @@ public static (string PgType, bool IsNullable) GetPostgresTypeName(Type property return (pgType, isNullable); } - // Handle lists and arrays (PostgreSQL supports array types for most types) - if (propertyType.IsArray) - { - Type elementType = propertyType.GetElementType() ?? throw new ArgumentException("Array type must have an element type."); - var underlyingPgType = GetPostgresTypeName(elementType); - return (underlyingPgType.PgType + "[]", true); - } - else if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) + // Handle lists + if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) { Type elementType = propertyType.GetGenericArguments()[0]; var underlyingPgType = GetPostgresTypeName(elementType); @@ -164,4 +185,4 @@ private static object ConvertList(IEnumerable list, Type elementType) return convertedList; } -} \ No newline at end of file +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 69373123431b..6ea820b89407 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -31,7 +31,7 @@ public async Task CreatesCollectionForGenericModelAsync() var recordDefinition = new VectorStoreRecordDefinition { Properties = [ - new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), + new VectorStoreRecordKeyProperty("HotelId", typeof(int)), new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, @@ -41,11 +41,11 @@ public async Task CreatesCollectionForGenericModelAsync() new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 100, DistanceFunction = DistanceFunction.ManhattanDistance } ] }; - var options = new PostgresVectorStoreRecordCollectionOptions>() + var options = new PostgresVectorStoreRecordCollectionOptions>() { VectorStoreRecordDefinition = recordDefinition }; - var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options); + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options); this._postgresClientMock.Setup(x => x.DoesTableExistsAsync(TestCollectionName, this._testCancellationToken)).ReturnsAsync(false); // Act @@ -55,6 +55,26 @@ public async Task CreatesCollectionForGenericModelAsync() Assert.False(exists); } + [Fact] + public void ThrowsForUnsupportedType() + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = [ + new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + ] + }; + var options = new PostgresVectorStoreRecordCollectionOptions>() + { + VectorStoreRecordDefinition = recordDefinition + }; + + // Act & Assert + Assert.Throws(() => new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName, options)); + } + [Fact] public async Task UpsertRecordAsyncProducesExpectedSqlAsync() { diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs index 3510280397e0..61fd2e386990 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs @@ -44,7 +44,7 @@ public record PostgresHotel() public string Description { get; set; } /// A vector field. - [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance, IndexKind: IndexKind.Hnsw)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs index ab7afa0489c6..b87f17f198bb 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -90,7 +90,7 @@ public PostgresVectorStoreRecordCollection GetCollectionAn async task. public async Task InitializeAsync() { - this._containerId = await SetupPostgresContainerAsync(this._client); + //this._containerId = await SetupPostgresContainerAsync(this._client); this._connectionString = "Host=localhost;Port=5432;Username=postgres;Password=example;Database=postgres;"; this._databaseName = $"sk_it_{Guid.NewGuid():N}"; @@ -144,106 +144,6 @@ public async Task InitializeAsync() // Create the table. await this.CreateTableAsync(); - - // await this.PostgresClient.CreateCollectionAsync( - // "singleVectorHotels", - // new VectorParams { Size = 4, Distance = Distance.Cosine }); - - // await this.PostgresClient.CreateCollectionAsync( - // "singleVectorGuidIdHotels", - // new VectorParams { Size = 4, Distance = Distance.Cosine }); - - // // Create test data common to both named and unnamed vectors. - // var tags = new ListValue(); - // tags.Values.Add("t1"); - // tags.Values.Add("t2"); - // var tagsValue = new Value(); - // tagsValue.ListValue = tags; - - // // Create some test data using named vectors. - // var embedding = new[] { 30f, 31f, 32f, 33f }; - - // var namedVectors1 = new NamedVectors(); - // var namedVectors2 = new NamedVectors(); - // var namedVectors3 = new NamedVectors(); - - // namedVectors1.Vectors.Add("DescriptionEmbedding", embedding); - // namedVectors2.Vectors.Add("DescriptionEmbedding", embedding); - // namedVectors3.Vectors.Add("DescriptionEmbedding", embedding); - - // List namedVectorPoints = - // [ - // new PointStruct - // { - // Id = 11, - // Vectors = new Vectors { Vectors_ = namedVectors1 }, - // Payload = { ["HotelName"] = "My Hotel 11", ["HotelCode"] = 11, ["parking_is_included"] = true, ["Tags"] = tagsValue, ["HotelRating"] = 4.5f, ["Description"] = "This is a great hotel." } - // }, - // new PointStruct - // { - // Id = 12, - // Vectors = new Vectors { Vectors_ = namedVectors2 }, - // Payload = { ["HotelName"] = "My Hotel 12", ["HotelCode"] = 12, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } - // }, - // new PointStruct - // { - // Id = 13, - // Vectors = new Vectors { Vectors_ = namedVectors3 }, - // Payload = { ["HotelName"] = "My Hotel 13", ["HotelCode"] = 13, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } - // }, - // ]; - - // await this.PostgresClient.UpsertAsync("namedVectorsHotels", namedVectorPoints); - - // // Create some test data using a single unnamed vector. - // List unnamedVectorPoints = - // [ - // new PointStruct - // { - // Id = 11, - // Vectors = embedding, - // Payload = { ["HotelName"] = "My Hotel 11", ["HotelCode"] = 11, ["parking_is_included"] = true, ["Tags"] = tagsValue, ["HotelRating"] = 4.5f, ["Description"] = "This is a great hotel." } - // }, - // new PointStruct - // { - // Id = 12, - // Vectors = embedding, - // Payload = { ["HotelName"] = "My Hotel 12", ["HotelCode"] = 12, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } - // }, - // new PointStruct - // { - // Id = 13, - // Vectors = embedding, - // Payload = { ["HotelName"] = "My Hotel 13", ["HotelCode"] = 13, ["parking_is_included"] = false, ["Description"] = "This is a great hotel." } - // }, - // ]; - - // await this.PostgresClient.UpsertAsync("singleVectorHotels", unnamedVectorPoints); - - // // Create some test data using a single unnamed vector and a guid id. - // List unnamedVectorGuidIdPoints = - // [ - // new PointStruct - // { - // Id = Guid.Parse("11111111-1111-1111-1111-111111111111"), - // Vectors = embedding, - // Payload = { ["HotelName"] = "My Hotel 11", ["Description"] = "This is a great hotel." } - // }, - // new PointStruct - // { - // Id = Guid.Parse("22222222-2222-2222-2222-222222222222"), - // Vectors = embedding, - // Payload = { ["HotelName"] = "My Hotel 12", ["Description"] = "This is a great hotel." } - // }, - // new PointStruct - // { - // Id = Guid.Parse("33333333-3333-3333-3333-333333333333"), - // Vectors = embedding, - // Payload = { ["HotelName"] = "My Hotel 13", ["Description"] = "This is a great hotel." } - // }, - // ]; - - // await this.PostgresClient.UpsertAsync("singleVectorGuidIdHotels", unnamedVectorGuidIdPoints); } private async Task CreateTableAsync() @@ -279,6 +179,8 @@ public async Task DisposeAsync() this._dataSource.Dispose(); } + await this.DropDatabaseAsync(); + if (this._containerId != null) { await this._client.Containers.StopContainerAsync(this._containerId, new ContainerStopParameters()); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 7e1398b441ed..f6bfe584d452 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -1,7 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Linq; using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; @@ -161,4 +164,234 @@ public async Task ItCanGetUpsertDeleteBatchAsync() // Assert Assert.Empty(getResults); } + + [Fact] + public async Task ItCanUpsertExistingRecordAsync() + { + // Arrange + const int HotelId = 5; + var sut = fixture.GetCollection("UpsertRecord"); + + await sut.CreateCollectionAsync(); + + var record = new PostgresHotel { HotelId = HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(HotelId, new() { IncludeVectors = true }); + + Assert.Equal(HotelId, upsertResult); + Assert.NotNull(getResult); + Assert.Null(getResult!.DescriptionEmbedding); + + // Act + record.HotelName = "Updated name"; + record.HotelRating = 10; + record.DescriptionEmbedding = new[] { 1f, 2f, 3f, 4f }; + + upsertResult = await sut.UpsertAsync(record); + getResult = await sut.GetAsync(HotelId, new() { IncludeVectors = true }); + + // Assert + Assert.NotNull(getResult); + Assert.Equal("Updated name", getResult.HotelName); + Assert.Equal(10, getResult.HotelRating); + + Assert.NotNull(getResult.DescriptionEmbedding); + Assert.Equal(record.DescriptionEmbedding!.Value.ToArray(), getResult.DescriptionEmbedding.Value.ToArray()); + } + + [Fact] + public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + { + const int HotelId = 5; + + var options = new PostgresVectorStoreRecordCollectionOptions> + { + VectorStoreRecordDefinition = GetVectorStoreRecordDefinition() + }; + + var sut = fixture.GetCollection>("GenericMapperWithNumericKey", options); + + await sut.CreateCollectionAsync(); + + var record = new PostgresHotel { HotelId = (int)HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + + // Act + var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + { + Data = + { + { "HotelName", "Generic Mapper Hotel" }, + { "Description", "This is a generic mapper hotel" }, + { "HotelCode", 1 }, + { "ParkingIncluded", true }, + { "HotelRating", 3.6f } + }, + Vectors = + { + { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } + } + }); + + var localGetResult = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.Equal(HotelId, upsertResult); + + Assert.NotNull(localGetResult); + Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); + Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); + Assert.True((bool?)localGetResult.Data["ParkingIncluded"]); + Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); + Assert.Equal([30f, 31f, 32f, 33f], ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + + // Act - update with null embeddings + // Act + var upsertResult2 = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) + { + Data = + { + { "HotelName", "Generic Mapper Hotel" }, + { "Description", "This is a generic mapper hotel" }, + { "HotelCode", 1 }, + { "ParkingIncluded", true }, + { "HotelRating", 3.6f } + }, + Vectors = + { + { "DescriptionEmbedding", null } + } + }); + + var localGetResult2 = await sut.GetAsync(HotelId, new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(localGetResult2); + Assert.Null(localGetResult2.Vectors["DescriptionEmbedding"]); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool includeVectors) + { + // Arrange + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 3.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 1.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + + var sut = fixture.GetCollection($"VectorizedSearch_{includeVectors}"); + + await sut.CreateCollectionAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + { + IncludeVectors = includeVectors + }); + + var results = await searchResults.Results.ToListAsync(); + + // Assert + var ids = results.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal(1, ids[0]); + Assert.Equal(4, ids[1]); + Assert.Equal(3, ids[2]); + + // Default limit is 3 + Assert.DoesNotContain(2, ids); + + Assert.Equal(0, results.First(l => l.Record.HotelId == 1).Score); + + Assert.Equal(includeVectors, results.All(result => result.Record.DescriptionEmbedding is not null)); + } + + [Fact] + public async Task VectorizedSearchWithEqualToFilterReturnsValidResultsAsync() + { + // Arrange + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + + var sut = fixture.GetCollection("VectorizedSearchWithEqualToFilter"); + + await sut.CreateCollectionAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), new() + { + IncludeVectors = false, + Top = 5, + Filter = new([ + new EqualToFilterClause("HotelRating", 2.5f) + ]) + }); + + var results = await searchResults.Results.ToListAsync(); + + // Assert + var ids = results.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal([1, 3, 2], ids); + } + + [Fact] + public async Task VectorizedSearchWithAnyTagFilterReturnsValidResultsAsync() + { + // Arrange + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag2", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + + var sut = fixture.GetCollection("VectorizedSearchWithEqualToFilter"); + + await sut.CreateCollectionAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 29f, 28f, 27f]), new() + { + IncludeVectors = false, + Top = 5, + Filter = new([ + new AnyTagEqualToFilterClause("Tags", "tag2") + ]) + }); + + var results = await searchResults.Results.ToListAsync(); + + // Assert + var ids = results.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal([1, 3], ids); + } + + #region private ================================================================================== + + private static VectorStoreRecordDefinition GetVectorStoreRecordDefinition() => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("HotelId", typeof(TKey)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)), + new VectorStoreRecordDataProperty("HotelCode", typeof(int)), + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineDistance } + ] + }; + + #endregion + } \ No newline at end of file diff --git a/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs b/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs index 844ae7e2f573..4b236fea2285 100644 --- a/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs +++ b/dotnet/src/InternalUtilities/src/Linq/AsyncEnumerable.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -135,6 +136,40 @@ static async ValueTask Core(IAsyncEnumerable source, Func + /// Projects each element of an into a new form by incorporating + /// an asynchronous transformation function. + /// + /// The type of the elements of the source sequence. + /// The type of the elements of the resulting sequence. + /// An to invoke a transform function on. + /// + /// A transform function to apply to each element. This function takes an element of + /// type TSource and returns an element of type TResult. + /// + /// + /// A CancellationToken to observe while iterating through the sequence. + /// + /// + /// An whose elements are the result of invoking the transform + /// function on each element of the original sequence. + /// + /// Thrown when the source or selector is null. + public static async IAsyncEnumerable Select( + this IAsyncEnumerable source, + Func selector, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return selector(item); + } + } + +#pragma warning restore IDE1006 // Naming rule violation: Missing suffix: 'Async' + private sealed class EmptyAsyncEnumerable : IAsyncEnumerable, IAsyncEnumerator { public static readonly EmptyAsyncEnumerable Instance = new(); From 317f6af0b40b32fca1df0f8d488ab5f5c94e210a Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Wed, 23 Oct 2024 11:24:45 -0400 Subject: [PATCH 05/62] create index on collection creation --- ...PostgresVectorStoreCollectionSqlBuilder.cs | 9 ++++ .../IPostgresVectorStoreDbClient.cs | 46 +++-------------- ...PostgresVectorStoreCollectionSqlBuilder.cs | 51 +++++++++++++++++++ .../PostgresVectorStoreDbClient.cs | 18 +++++++ .../PostgresVectorStoreRecordCollection.cs | 15 ++++-- 5 files changed, 96 insertions(+), 43 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index ec61286c595c..47aa135e9f93 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -38,6 +38,15 @@ public interface IPostgresVectorStoreCollectionSqlBuilder /// The built SQL command info. PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tableName, IReadOnlyList properties, bool ifNotExists = true); + /// + /// Builds a SQL command to create a vector index in the Postgres vector store. + /// + /// The schema of the table. + /// The name of the table. + /// The vector property to create an index for. + /// The built SQL command info. + PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, VectorStoreRecordVectorProperty vectorProperty); + /// /// Builds a SQL command to drop a table in the Postgres vector store. /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs index 2d3e92c6ab84..6da6ecc0ffc9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -37,6 +37,14 @@ public interface IPostgresVectorStoreDbClient /// Task CreateTableAsync(string tableName, VectorStoreRecordDefinition recordDefinition, bool ifNotExists = true, CancellationToken cancellationToken = default); + /// + /// Create a vector index. + /// + /// The name assigned to a table of entries. + /// The vector property to create an index for. + /// The to monitor for cancellation requests. The default is . + Task CreateVectorIndexAsync(string tableName, VectorStoreRecordVectorProperty vectorProperty, CancellationToken cancellationToken = default); + /// /// Drop a table. /// @@ -123,42 +131,4 @@ public interface IPostgresVectorStoreDbClient /// An asynchronous stream of objects that the nearest matches to the . IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, CancellationToken cancellationToken = default); - - // /// - // /// Read a entry by its key. - // /// - // /// The name assigned to a table of entries. - // /// The key of the entry to read. - // /// If true, the embeddings will be returned in the entry. - // /// The to monitor for cancellation requests. The default is . - // /// - // Task ReadAsync(string tableName, string key, bool withEmbeddings = false, CancellationToken cancellationToken = default); - - // /// - // /// Read multiple entries by their keys. - // /// - // /// The name assigned to a table of entries. - // /// The keys of the entries to read. - // /// If true, the embeddings will be returned in the entries. - // /// The to monitor for cancellation requests. The default is . - // /// An asynchronous stream of objects that match the given keys. - // IAsyncEnumerable ReadBatchAsync(string tableName, IEnumerable keys, bool withEmbeddings = false, CancellationToken cancellationToken = default); - - // /// - // /// Delete a entry by its key. - // /// - // /// The name assigned to a table of entries. - // /// The key of the entry to delete. - // /// The to monitor for cancellation requests. The default is . - // /// - // Task DeleteAsync(string tableName, string key, CancellationToken cancellationToken = default); - - // /// - // /// Delete multiple entries by their key. - // /// - // /// The name assigned to a table of entries. - // /// The keys of the entries to delete. - // /// The to monitor for cancellation requests. The default is . - // /// - // Task DeleteBatchAsync(string tableName, IEnumerable keys, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 1d681ce11614..5cecaa0a9a60 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -116,6 +116,36 @@ public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tabl return new PostgresSqlCommandInfo(commandText: createTableCommand.ToString()); } + /// + public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, VectorStoreRecordVectorProperty vectorProperty) + { + var vectorColumnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + // Only support creating HNSW index creation through the connector. + var indexTypeName = vectorProperty.IndexKind switch + { + IndexKind.Hnsw => "hnsw", + _ => throw new NotSupportedException($"Index kind '{vectorProperty.IndexKind}' is not supported for table creation. If you need to create an index of this type, please do so manually. Only HNSW indexes are supported through the vector store.") + }; + + var indexOps = vectorProperty.DistanceFunction switch + { + DistanceFunction.CosineDistance => "vector_cosine_ops", + DistanceFunction.CosineSimilarity => "vector_cosine_ops", + DistanceFunction.DotProductSimilarity => "vector_ip_ops", + DistanceFunction.EuclideanDistance => "vector_l2_ops", + DistanceFunction.ManhattanDistance => "vector_l1_ops", + null => throw new ArgumentException("Distance function must be specified for HNSW index."), + _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") + }; + + var indexName = $"{tableName}_{vectorColumnName}_index"; + + return new PostgresSqlCommandInfo( + commandText: $@" + CREATE INDEX {indexName} ON {schema}.""{tableName}"" USING {indexTypeName} (""{vectorColumnName}"" {indexOps});" + ); + } + /// public PostgresSqlCommandInfo BuildDropTableCommand(string schema, string tableName) { @@ -318,6 +348,7 @@ public PostgresSqlCommandInfo BuildGetNearestMatchCommand( var distanceOp = vectorProperty.DistanceFunction switch { + DistanceFunction.CosineDistance => "<=>", DistanceFunction.CosineSimilarity => "<=>", DistanceFunction.EuclideanDistance => "<->", DistanceFunction.ManhattanDistance => "<+>", @@ -337,6 +368,26 @@ ORDER BY {PostgresConstants.DistanceColumnName} if (skip.HasValue) { commandText += $" OFFSET {skip.Value}"; } + // For cosine similarity, we need to take 1 - cosine distance. + // However, we can't use an expression in the ORDER BY clause or else the index won't be used. + // Instead we'll wrap the query in a subquery and modify the distance in the outer query. + if (vectorProperty.DistanceFunction == DistanceFunction.CosineSimilarity) + { + commandText = $@" + SELECT {columns}, 1 - ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" + FROM ({commandText}) AS subquery"; + } + + // For inner product, we need to take -1 * inner product. + // However, we can't use an expression in the ORDER BY clause or else the index won't be used. + // Instead we'll wrap the query in a subquery and modify the distance in the outer query. + if (vectorProperty.DistanceFunction == DistanceFunction.DotProductSimilarity) + { + commandText = $@" + SELECT {columns}, -1 * ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" + FROM ({commandText}) AS subquery"; + } + return new PostgresSqlCommandInfo(commandText) { Parameters = [new NpgsqlParameter() { Value = vectorValue }, .. where.Parameters.Select(p => new NpgsqlParameter() { Value = p })] diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 161fbe1c74ff..84e72b467df2 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -83,6 +83,24 @@ public async Task CreateTableAsync(string tableName, VectorStoreRecordDefinition } } + /// + public async Task CreateVectorIndexAsync(string tableName, VectorStoreRecordVectorProperty vectorProperty, CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(vectorProperty.IndexKind)) + { + return; + } + + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + var commandInfo = this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, vectorProperty); + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + /// public async Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 6737263d6f09..b31d5007a51d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -99,9 +99,14 @@ public async Task CollectionExistsAsync(CancellationToken cancellationToke } /// - public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) { - return this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition, false, cancellationToken); + await this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition, false, cancellationToken).ConfigureAwait(false); + // Create indexes for vector properties. + foreach (var vectorProperty in this._propertyReader.VectorProperties) + { + await this._client.CreateVectorIndexAsync(this.CollectionName, vectorProperty, cancellationToken).ConfigureAwait(false); + } } /// @@ -207,7 +212,7 @@ public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? option } /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -242,7 +247,7 @@ public async Task> VectorizedSearchAsync(T vectorProperty, pgVector, searchOptions.Top, - options.Filter, + searchOptions.Filter, searchOptions.Skip, searchOptions.IncludeVectors, cancellationToken @@ -254,7 +259,7 @@ public async Task> VectorizedSearchAsync(T return new VectorSearchResult(record, result.Distance); }, cancellationToken); - return new VectorSearchResults(results); + return Task.FromResult(new VectorSearchResults(results)); } /// From f4f5ba2df3cfae55bc590b45c373edafbab6ec59 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Wed, 23 Oct 2024 11:26:40 -0400 Subject: [PATCH 06/62] Support Guid, test distance functions --- .../PostgresConstants.cs | 7 +- ...resVectorStoreCollectionSqlBuilderTests.cs | 45 +------ .../Memory/Postgres/PostgresHotel.cs | 9 +- .../Postgres/PostgresVectorStoreFixture.cs | 2 +- ...ostgresVectorStoreRecordCollectionTests.cs | 127 ++++++++++++------ .../Postgres/PostgresVectorStoreTests.cs | 2 +- 6 files changed, 101 insertions(+), 91 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index 27575247a73d..5a192ca8f268 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -10,10 +10,11 @@ internal static class PostgresConstants /// A of types that a key on the provided model may have. public static readonly HashSet SupportedKeyTypes = [ - typeof(string), + typeof(short), typeof(int), typeof(long), - typeof(short), + typeof(string), + typeof(Guid), ]; /// A of types that data properties on the provided model may have. @@ -36,6 +37,8 @@ internal static class PostgresConstants typeof(string), typeof(DateTimeOffset), typeof(DateTimeOffset?), + typeof(Guid), + typeof(Guid?), typeof(byte[]), typeof(List), typeof(List), diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs index 45d9391f62c3..2df54addf7d7 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -46,7 +46,7 @@ public void TestBuildCreateTableCommand() ] }; - var cmdInfo = builder.BuildCreateTableCommand("public", "testcollection", recordDefinition, ifNotExists: true); + var cmdInfo = builder.BuildCreateTableCommand("public", "testcollection", recordDefinition.Properties, ifNotExists: true); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("public.\"testcollection\" (", cmdInfo.CommandText); @@ -137,49 +137,12 @@ public void TestBuildGetCommand() var key = 123; // Act - var cmdInfo = builder.BuildGetCommand("public", "testcollection", recordDefinition, key, includeVectors: true); - - // Assert - Assert.Contains("SELECT", cmdInfo.CommandText); - - // Output - this._output.WriteLine(cmdInfo.CommandText); - } - - [Fact] - public void TestBuildGetBatchCommand() - { - // Arrange - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - - var recordDefinition = new VectorStoreRecordDefinition() - { - Properties = [ - new VectorStoreRecordKeyProperty("id", typeof(long)), - new VectorStoreRecordDataProperty("name", typeof(string)), - new VectorStoreRecordDataProperty("code", typeof(int)), - new VectorStoreRecordDataProperty("rating", typeof(float?)), - new VectorStoreRecordDataProperty("description", typeof(string)), - new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), - new VectorStoreRecordDataProperty("tags", typeof(List)), - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) { - Dimensions = 10, - IndexKind = "hnsw", - }, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) { - Dimensions = 10, - IndexKind = "hnsw", - } - ] - }; - - var keys = new List { 123, 456, 789 }; - - // Act - var cmdInfo = builder.BuildGetBatchCommand("public", "testcollection", recordDefinition, keys, includeVectors: true); + var cmdInfo = builder.BuildGetCommand("public", "testcollection", recordDefinition.Properties, key, includeVectors: true); // Assert Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = $1", cmdInfo.CommandText); // Output this._output.WriteLine(cmdInfo.CommandText); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs index 61fd2e386990..f60429560a72 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs @@ -11,11 +11,11 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; /// /// A test model for the postgres vector store. /// -public record PostgresHotel() +public record PostgresHotel() { /// The key of the record. [VectorStoreRecordKey] - public int HotelId { get; init; } + public T HotelId { get; init; } /// A string metadata field. [VectorStoreRecordData()] @@ -46,6 +46,11 @@ public record PostgresHotel() /// A vector field. [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance, IndexKind: IndexKind.Hnsw)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } + + public PostgresHotel(T key) : this() + { + this.HotelId = key; + } } #pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. \ No newline at end of file diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs index b87f17f198bb..c44a6601e25a 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -18,7 +18,7 @@ public class PostgresVectorStoreFixture : IAsyncLifetime private readonly DockerClient _client; /// The id of the postgres container that we are testing with. - private string? _containerId = null; + private readonly string? _containerId = null; #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index f6bfe584d452..ab6715e2b372 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; @@ -18,7 +19,7 @@ public sealed class PostgresVectorStoreRecordCollectionTests(PostgresVectorStore public async Task CollectionExistsReturnsCollectionStateAsync(bool createCollection) { // Arrange - var sut = fixture.GetCollection("CollectionExists"); + var sut = fixture.GetCollection>("CollectionExists"); if (createCollection) { @@ -47,7 +48,7 @@ public async Task CollectionExistsReturnsCollectionStateAsync(bool createCollect public async Task CollectionCanUpsertAndGetAsync() { // Arrange - var sut = fixture.GetCollection("CollectionCanUpsertAndGet"); + var sut = fixture.GetCollection>("CollectionCanUpsertAndGet"); if (await sut.CollectionExistsAsync()) { await sut.DeleteCollectionAsync(); @@ -58,8 +59,8 @@ public async Task CollectionCanUpsertAndGetAsync() try { // Act - await sut.UpsertAsync(new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }); - await sut.UpsertAsync(new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, ListInts = [1, 2] }); + await sut.UpsertAsync(new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }); + await sut.UpsertAsync(new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, ListInts = [1, 2] }); var hotel1 = await sut.GetAsync(1); var hotel2 = await sut.GetAsync(2); @@ -97,29 +98,39 @@ public async Task CollectionCanUpsertAndGetAsync() } } - [Fact] - public async Task ItCanGetAndDeleteRecordAsync() + [Theory] + [InlineData(typeof(short), (short)3)] + [InlineData(typeof(int), 5)] + [InlineData(typeof(long), 7L)] + [InlineData(typeof(string), "key1")] + [InlineData(typeof(Guid), null)] + public async Task ItCanGetAndDeleteRecordAsync(Type idType, object? key) { + if (idType == typeof(Guid)) + { + key = Guid.NewGuid(); + } + // Arrange - const int HotelId = 5; - var sut = fixture.GetCollection("DeleteRecord"); + var collectionName = "DeleteRecord"; + dynamic sut = this.GetCollection(idType, collectionName); await sut.CreateCollectionAsync(); try { - var record = new PostgresHotel { HotelId = HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; - + dynamic record = this.CreateRecord(idType, key!); + dynamic recordKey = record.HotelId; var upsertResult = await sut.UpsertAsync(record); - var getResult = await sut.GetAsync(HotelId); + var getResult = await sut.GetAsync(recordKey); - Assert.Equal(HotelId, upsertResult); + Assert.Equal(key, upsertResult); Assert.NotNull(getResult); // Act - await sut.DeleteAsync(HotelId); + await sut.DeleteAsync(recordKey); - getResult = await sut.GetAsync(HotelId); + getResult = await sut.GetAsync(recordKey); // Assert Assert.Null(getResult); @@ -139,13 +150,13 @@ public async Task ItCanGetUpsertDeleteBatchAsync() const int HotelId2 = 2; const int HotelId3 = 3; - var sut = fixture.GetCollection("GetUpsertDeleteBatch"); + var sut = fixture.GetCollection>("GetUpsertDeleteBatch"); await sut.CreateCollectionAsync(); - var record1 = new PostgresHotel { HotelId = HotelId1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; - var record2 = new PostgresHotel { HotelId = HotelId2, HotelName = "Hotel 2", HotelCode = 1, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag3"] }; - var record3 = new PostgresHotel { HotelId = HotelId3, HotelName = "Hotel 3", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"] }; + var record1 = new PostgresHotel { HotelId = HotelId1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + var record2 = new PostgresHotel { HotelId = HotelId2, HotelName = "Hotel 2", HotelCode = 1, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag3"] }; + var record3 = new PostgresHotel { HotelId = HotelId3, HotelName = "Hotel 3", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"] }; var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); @@ -170,11 +181,11 @@ public async Task ItCanUpsertExistingRecordAsync() { // Arrange const int HotelId = 5; - var sut = fixture.GetCollection("UpsertRecord"); + var sut = fixture.GetCollection>("UpsertRecord"); await sut.CreateCollectionAsync(); - var record = new PostgresHotel { HotelId = HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + var record = new PostgresHotel { HotelId = HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; var upsertResult = await sut.UpsertAsync(record); var getResult = await sut.GetAsync(HotelId, new() { IncludeVectors = true }); @@ -214,7 +225,7 @@ public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() await sut.CreateCollectionAsync(); - var record = new PostgresHotel { HotelId = (int)HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + var record = new PostgresHotel { HotelId = (int)HotelId, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; // Act var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel(HotelId) @@ -271,24 +282,31 @@ public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool includeVectors) + [InlineData(true, DistanceFunction.CosineDistance)] + [InlineData(false, DistanceFunction.CosineDistance)] + [InlineData(false, DistanceFunction.CosineSimilarity)] + [InlineData(false, DistanceFunction.EuclideanDistance)] + [InlineData(false, DistanceFunction.ManhattanDistance)] + [InlineData(false, DistanceFunction.DotProductSimilarity)] + public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool includeVectors, string distanceFunction) { // Arrange - var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; - var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; - var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 3.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; - var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 1.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 1f, 0f, 0f, 0f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 0f, 1f, 0f, 0f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 3.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 0f, 0f, 1f, 0f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 1.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 0f, 0f, 0f, 1f } }; - var sut = fixture.GetCollection($"VectorizedSearch_{includeVectors}"); + var sut = fixture.GetCollection>($"VectorizedSearch_{includeVectors}_{distanceFunction}", new() + { + VectorStoreRecordDefinition = GetVectorStoreRecordDefinition(distanceFunction) + }); await sut.CreateCollectionAsync(); await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); // Act - var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + var searchResults = await sut.VectorizedSearchAsync(new ReadOnlyMemory([0.9f, 0.1f, 0.5f, 0.8f]), new() { IncludeVectors = includeVectors }); @@ -305,7 +323,7 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool include // Default limit is 3 Assert.DoesNotContain(2, ids); - Assert.Equal(0, results.First(l => l.Record.HotelId == 1).Score); + Assert.True(0 < results.First(l => l.Record.HotelId == 1).Score); Assert.Equal(includeVectors, results.All(result => result.Record.DescriptionEmbedding is not null)); } @@ -314,12 +332,12 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool include public async Task VectorizedSearchWithEqualToFilterReturnsValidResultsAsync() { // Arrange - var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; - var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; - var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; - var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; - var sut = fixture.GetCollection("VectorizedSearchWithEqualToFilter"); + var sut = fixture.GetCollection>("VectorizedSearchWithEqualToFilter"); await sut.CreateCollectionAsync(); @@ -347,12 +365,12 @@ public async Task VectorizedSearchWithEqualToFilterReturnsValidResultsAsync() public async Task VectorizedSearchWithAnyTagFilterReturnsValidResultsAsync() { // Arrange - var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; - var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; - var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag2", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; - var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; + var hotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag1", "tag2"], DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f } }; + var hotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, Tags = ["tag1", "tag3"], DescriptionEmbedding = new[] { 10f, 10f, 10f, 10f } }; + var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 2.5f, Tags = ["tag2", "tag4"], DescriptionEmbedding = new[] { 20f, 20f, 20f, 20f } }; + var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 3.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 40f, 40f, 40f, 40f } }; - var sut = fixture.GetCollection("VectorizedSearchWithEqualToFilter"); + var sut = fixture.GetCollection>("VectorizedSearchWithAnyTagEqualToFilter"); await sut.CreateCollectionAsync(); @@ -378,20 +396,41 @@ public async Task VectorizedSearchWithAnyTagFilterReturnsValidResultsAsync() #region private ================================================================================== - private static VectorStoreRecordDefinition GetVectorStoreRecordDefinition() => new() + private static VectorStoreRecordDefinition GetVectorStoreRecordDefinition(string distanceFunction = DistanceFunction.CosineDistance) => new() { Properties = [ new VectorStoreRecordKeyProperty("HotelId", typeof(TKey)), new VectorStoreRecordDataProperty("HotelName", typeof(string)), new VectorStoreRecordDataProperty("HotelCode", typeof(int)), + new VectorStoreRecordDataProperty("HotelRating", typeof(float?)), new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, - new VectorStoreRecordDataProperty("HotelRating", typeof(float)), + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("ListInts", typeof(List)), new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineDistance } + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.Hnsw, DistanceFunction = distanceFunction } ] }; + private dynamic GetCollection(Type idType, string collectionName) + { + var method = typeof(PostgresVectorStoreFixture).GetMethod("GetCollection"); + var genericMethod = method!.MakeGenericMethod(idType, typeof(PostgresHotel<>).MakeGenericType(idType)); + return genericMethod.Invoke(fixture, [collectionName, null])!; + } + + private dynamic CreateRecord(Type idType, object key) + { + var recordType = typeof(PostgresHotel<>).MakeGenericType(idType); + dynamic record = Activator.CreateInstance(recordType, key)!; + record.HotelName = "Hotel 1"; + record.HotelCode = 1; + record.ParkingIncluded = true; + record.HotelRating = 4.5f; + record.Tags = new List { "tag1", "tag2" }; + return record; + } + #endregion } \ No newline at end of file diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs index 0ae8fc8d5228..cd3d5781d8fc 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs @@ -17,7 +17,7 @@ public async Task ItCanGetAListOfExistingCollectionNamesAsync() var sut = new PostgresVectorStore(fixture.PostgresClient); // Setup - var collection = sut.GetCollection("VS_TEST_HOTELS"); + var collection = sut.GetCollection>("VS_TEST_HOTELS"); await collection.CreateCollectionIfNotExistsAsync(); // Act From 2acf11806d31ba02f39cd5c04155a37d4349b51e Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Wed, 23 Oct 2024 11:44:55 -0400 Subject: [PATCH 07/62] Format tests --- .../Memory/Postgres/PostgresHotel.cs | 4 ++-- ...stgresVectorStoreCollectionSqlBuilderTests.cs | 16 ++++++++++------ .../PostgresVectorStoreRecordCollectionTests.cs | 4 ++-- .../Memory/Postgres/PostgresVectorStoreTests.cs | 4 ++-- .../Connectors/Memory/Postgres/PostgresHotel.cs | 4 ++-- .../Postgres/PostgresVectorStoreFixture.cs | 6 +++--- .../PostgresVectorStoreRecordCollectionTests.cs | 4 ++-- .../Memory/Postgres/PostgresVectorStoreTests.cs | 2 +- 8 files changed, 24 insertions(+), 20 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs index b2357e302fda..353992333783 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; @@ -44,4 +44,4 @@ public record PostgresHotel() [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } } -#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. \ No newline at end of file +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs index 2df54addf7d7..c7649b8aef32 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; @@ -35,11 +35,13 @@ public void TestBuildCreateTableCommand() new VectorStoreRecordDataProperty("description", typeof(string)), new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), new VectorStoreRecordDataProperty("tags", typeof(List)), - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) { + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { Dimensions = 10, IndexKind = "hnsw", }, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) { + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { Dimensions = 10, IndexKind = "hnsw", } @@ -123,11 +125,13 @@ public void TestBuildGetCommand() new VectorStoreRecordDataProperty("description", typeof(string)), new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), new VectorStoreRecordDataProperty("tags", typeof(List)), - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) { + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { Dimensions = 10, IndexKind = "hnsw", }, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) { + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { Dimensions = 10, IndexKind = "hnsw", } @@ -147,4 +151,4 @@ public void TestBuildGetCommand() // Output this._output.WriteLine(cmdInfo.CommandText); } -} \ No newline at end of file +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 6ea820b89407..f4c518854a87 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; @@ -126,4 +126,4 @@ public async Task UpsertRecordAsyncProducesExpectedSqlAsync() Assert.Equal(3.0f, embedding[2]); Assert.Equal(4.0f, embedding[3]); } -} \ No newline at end of file +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs index 8ed97fce5f7c..9f6bbfd556f6 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs @@ -1,12 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Linq; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; -using Microsoft.SemanticKernel.Connectors.Postgres; using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; using Moq; using Xunit; diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs index f60429560a72..ff88638e8bdc 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; @@ -53,4 +53,4 @@ public PostgresHotel(T key) : this() } } -#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. \ No newline at end of file +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs index c44a6601e25a..a7ddf78ac0d9 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -1,12 +1,12 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; using System.Threading.Tasks; using Docker.DotNet; using Docker.DotNet.Models; -using Microsoft.SemanticKernel.Connectors.Postgres; using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; using Npgsql; using Xunit; @@ -260,4 +260,4 @@ private async Task DropDatabaseAsync() await using NpgsqlCommand command = new($"DROP DATABASE IF EXISTS \"{this._databaseName}\"", conn); await command.ExecuteNonQueryAsync(); } -} \ No newline at end of file +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index ab6715e2b372..95487c730772 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; @@ -433,4 +433,4 @@ private dynamic CreateRecord(Type idType, object key) #endregion -} \ No newline at end of file +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs index cd3d5781d8fc..2495a8558f22 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System.Linq; using System.Threading.Tasks; From f4b4dc5ddc57d17282ee959530f99d216f481f24 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 24 Oct 2024 11:11:25 -0400 Subject: [PATCH 08/62] Add service and kernel extensions --- .../PostgresConstants.cs | 3 + .../PostgresKernelBuilderExtensions.cs | 84 +++++++ .../PostgresServiceCollectionExtensions.cs | 223 ++++++++++++++++++ .../PostgresVectorStoreDbClient.cs | 2 +- 4 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index 5a192ca8f268..74dfc1e5cfb1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -58,6 +58,9 @@ internal static class PostgresConstants typeof(ReadOnlyMemory?) ]; + /// The default schema name. + public const string DefaultSchema = "public"; + /// The name of the column that returns distance value in the database. /// It is used in the similarity search query. Must not conflict with model property. public const string DistanceColumnName = "sk_pg_distance"; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs new file mode 100644 index 000000000000..dda47d1a9930 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods to register the Postgres instances on the . +/// +public static class PostgresKernelBuilderExtensions +{ + /// + /// Register a Postgres with the specified service ID and where is retrieved from the dependency injection container. + /// + /// The builder to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddPostgresVectorStore(this IKernelBuilder builder, PostgresVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddPostgresVectorStore(options, serviceId); + return builder; + } + /// + /// Register a Postgres with the specified service ID and where is constructed using the provided parameters. + /// + /// The builder to register the on. + /// The Postgres connection string. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddPostgresVectorStore(this IKernelBuilder builder, string connectionString, PostgresVectorStoreOptions? options = default, string? serviceId = default) + { + builder.Services.AddPostgresVectorStore(connectionString, options, serviceId); + return builder; + } + + /// + /// Register a Postgres and with the specified service ID + /// and where the Postgres is retrieved from the dependency injection container. + /// + /// The type of the key. + /// The type of the record. + /// The builder to register the on. + /// The name of the collection. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddPostgresVectorStoreRecordCollection( + this IKernelBuilder builder, + string collectionName, + PostgresVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + where TKey : notnull + { + builder.Services.AddPostgresVectorStoreRecordCollection(collectionName, options, serviceId); + return builder; + } + + /// + /// Register a Postgres and with the specified service ID + /// and where the Postgres is constructed using the provided parameters. + /// + /// The type of the key. + /// The type of the record. + /// The builder to register the on. + /// The name of the collection. + /// The Postgres connection string. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The kernel builder. + public static IKernelBuilder AddPostgresVectorStoreRecordCollection( + this IKernelBuilder builder, + string collectionName, + string connectionString, + PostgresVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + where TKey : notnull + { + builder.Services.AddPostgresVectorStoreRecordCollection(collectionName, connectionString, options, serviceId); + return builder; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs new file mode 100644 index 000000000000..23f1f00538f9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods to register Postgres instances on an . +/// +public static class PostgresServiceCollectionExtensions +{ + /// + /// Register a with the specified service ID and where the NpgsqlDataSource is retrieved from the dependency injection container. + /// + /// The to register the on. + /// The schema to use. + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) + { + // Since we are not constructing the client, add the IVectorStore as transient, since we + // cannot make assumptions about how client is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredService(); + return new PostgresVectorStoreDbClient(dataSource, schema); + }); + + return services; + } + + /// + /// Register a with the specified service ID and where NpgsqlDataSource is constructed using the provided parameters. + /// + /// The to register the on. + /// Postgres database connection string. + /// The schema to use. + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, string connectionString, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) + { + // Register NpgsqlDataSource to ensure proper disposal. + services.AddSingleton( + sp => + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + return dataSourceBuilder.Build(); + }); + + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredService(); + return new PostgresVectorStoreDbClient(dataSource, schema); + }); + + return services; + } + + /// + /// Register a Postgres with the specified service ID and where is retrieved from the dependency injection container. + /// + /// The to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, PostgresVectorStoreOptions? options = default, string? serviceId = default) + { + // Since we are not constructing the client, add the IVectorStore as transient, since we + // cannot make assumptions about how client is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + var client = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService(); + + return new PostgresVectorStore( + client, + selectedOptions); + }); + + return services; + } + + /// + /// Register a Postgres with the specified service ID and where is constructed using the provided parameters. + /// + /// The to register the on. + /// Postgres database connection string. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, string connectionString, PostgresVectorStoreOptions? options = default, string? serviceId = default) + { + // Register NpgsqlDataSource to ensure proper disposal. + services.AddSingleton( + sp => + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + return dataSourceBuilder.Build(); + }); + + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredService(); + var client = new PostgresVectorStoreDbClient(dataSource); + var selectedOptions = options ?? sp.GetService(); + + return new PostgresVectorStore( + client, + selectedOptions); + }); + + return services; + } + + /// + /// Register a Postgres and with the specified service ID + /// and where the Postgres is retrieved from the dependency injection container. + /// + /// The type of the key. + /// The type of the record. + /// The to register the on. + /// The name of the collection. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddPostgresVectorStoreRecordCollection( + this IServiceCollection services, + string collectionName, + PostgresVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + where TKey : notnull + { + services.AddKeyedTransient>( + serviceId, + (sp, obj) => + { + var PostgresClient = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService>(); + + return (new PostgresVectorStoreRecordCollection(PostgresClient, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + }); + + AddVectorizedSearch(services, serviceId); + + return services; + } + + /// + /// Register a Postgres and with the specified service ID + /// and where the Postgres is constructed using the provided parameters. + /// + /// The type of the key. + /// The type of the record. + /// The to register the on. + /// The name of the collection. + /// Postgres database connection string. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddPostgresVectorStoreRecordCollection( + this IServiceCollection services, + string collectionName, + string connectionString, + PostgresVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + where TKey : notnull + { + // Register NpgsqlDataSource to ensure proper disposal. + services.AddSingleton( + sp => + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + return dataSourceBuilder.Build(); + }); + + services.AddKeyedSingleton>( + serviceId, + (sp, obj) => + { + var dataSource = sp.GetRequiredService(); + var client = new PostgresVectorStoreDbClient(dataSource); + var selectedOptions = options ?? sp.GetService>(); + + return (new PostgresVectorStoreRecordCollection(client, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + }); + + AddVectorizedSearch(services, serviceId); + + return services; + } + + /// + /// Also register the with the given as a . + /// + /// The type of the key. + /// The type of the data model that the collection should contain. + /// The service collection to register on. + /// The service id that the registrations should use. + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + where TKey : notnull + { + services.AddKeyedTransient>( + serviceId, + (sp, obj) => + { + return sp.GetRequiredKeyedService>(serviceId); + }); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 84e72b467df2..67de73fda489 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -32,7 +32,7 @@ public class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string sch /// /// Postgres data source. /// Schema of collection tables. - public PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = "public") : this(dataSource, schema, new PostgresVectorStoreCollectionSqlBuilder()) { } + public PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema) : this(dataSource, schema, new PostgresVectorStoreCollectionSqlBuilder()) { } /// public async Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default) From 5c58400a07d048bc074028573e5439dd1f8f3f6b Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 24 Oct 2024 11:11:45 -0400 Subject: [PATCH 09/62] Default to Euclidean distance if no distance function is specified --- .../PostgresVectorStoreCollectionSqlBuilder.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 5cecaa0a9a60..caa0af9e8fda 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -353,6 +353,7 @@ public PostgresSqlCommandInfo BuildGetNearestMatchCommand( DistanceFunction.EuclideanDistance => "<->", DistanceFunction.ManhattanDistance => "<+>", DistanceFunction.DotProductSimilarity => "<#>", + null or "" => "<->", // Default to Euclidean distance _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") }; From 8ea21cd2c23458f819aefb9be30b9dc251e567d8 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 24 Oct 2024 11:12:13 -0400 Subject: [PATCH 10/62] Add Postgres sample to concepts --- .../VectorStoreFixtures/VectorStoreInfra.cs | 45 ++++++++++ .../VectorStorePostgresContainerFixture.cs | 67 +++++++++++++++ ...rStore_VectorSearch_MultiStore_Postgres.cs | 86 +++++++++++++++++++ 3 files changed, 198 insertions(+) create mode 100644 dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs create mode 100644 dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs diff --git a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs index ea498f20c5ab..2681231c80d7 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStoreInfra.cs @@ -10,6 +10,51 @@ namespace Memory.VectorStoreFixtures; /// internal static class VectorStoreInfra { + /// + /// Setup the postgres pgvector container by pulling the image and running it. + /// + /// The docker client to create the container with. + /// The id of the container. + public static async Task SetupPostgresContainerAsync(DockerClient client) + { + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = "pgvector/pgvector", + Tag = "pg16", + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = "pgvector/pgvector:pg16", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + {"5432", new List {new() {HostPort = "5432" } }}, + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "5432", default }, + }, + Env = new List + { + "POSTGRES_USER=postgres", + "POSTGRES_PASSWORD=example", + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + /// /// Setup the qdrant container by pulling the image and running it. /// diff --git a/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs new file mode 100644 index 000000000000..a8d4738fe133 --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStoreFixtures/VectorStorePostgresContainerFixture.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet; +using Npgsql; + +namespace Memory.VectorStoreFixtures; + +/// +/// Fixture to use for creating a Postgres container before tests and delete it after tests. +/// +public class VectorStorePostgresContainerFixture : IAsyncLifetime +{ + private DockerClient? _dockerClient; + private string? _postgresContainerId; + + public async Task InitializeAsync() + { + } + + public async Task ManualInitializeAsync() + { + if (this._postgresContainerId == null) + { + // Connect to docker and start the docker container. + using var dockerClientConfiguration = new DockerClientConfiguration(); + this._dockerClient = dockerClientConfiguration.CreateClient(); + this._postgresContainerId = await VectorStoreInfra.SetupPostgresContainerAsync(this._dockerClient); + + // Delay until the Postgres server is ready. + var connectionString = "Host=localhost;Port=5432;Username=postgres;Password=example;Database=postgres;"; + var succeeded = false; + var attemptCount = 0; + while (!succeeded && attemptCount++ < 10) + { + try + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + using var dataSource = dataSourceBuilder.Build(); + NpgsqlConnection connection = await dataSource.OpenConnectionAsync().ConfigureAwait(false); + + await using (connection) + { + // Create extension vector if it doesn't exist + await using (NpgsqlCommand command = new("CREATE EXTENSION IF NOT EXISTS vector", connection)) + { + await command.ExecuteNonQueryAsync(); + } + } + } + catch (Exception) + { + await Task.Delay(1000); + } + } + } + } + + public async Task DisposeAsync() + { + if (this._dockerClient != null && this._postgresContainerId != null) + { + // Delete docker container. + await VectorStoreInfra.DeleteContainerAsync(this._dockerClient, this._postgresContainerId); + } + } +} diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs new file mode 100644 index 000000000000..29bb11cfe163 --- /dev/null +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Azure.Identity; +using Memory.VectorStoreFixtures; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.Connectors.Postgres; + +namespace Memory; + +/// +/// An example showing how to use common code, that can work with any vector database, with a Postgres database. +/// The common code is in the class. +/// The common code ingests data into the vector store and then searches over that data. +/// This example is part of a set of examples each showing a different vector database. +/// +/// For other databases, see the following classes: +/// +/// +/// +/// +/// To run this sample, you need a local instance of Docker running, since the associated fixture will try and start a Postgres container in the local docker instance. +/// +public class VectorStore_VectorSearch_MultiStore_Postgres(ITestOutputHelper output, VectorStorePostgresContainerFixture PostgresFixture) : BaseTest(output), IClassFixture +{ + /// + /// The connection string to the Postgres database hosted in the docker container. + /// + private const string ConnectionString = "Host=localhost;Port=5432;Username=postgres;Password=example;Database=postgres;"; + + [Fact] + public async Task ExampleWithDIAsync() + { + // Use the kernel for DI purposes. + var kernelBuilder = Kernel + .CreateBuilder(); + + // Register an embedding generation service with the DI container. + kernelBuilder.AddAzureOpenAITextEmbeddingGeneration( + deploymentName: TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + endpoint: TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + credential: new AzureCliCredential()); + + // Initialize the Postgres docker container via the fixtures and register the Postgres VectorStore. + await PostgresFixture.ManualInitializeAsync(); + kernelBuilder.AddPostgresVectorStore(ConnectionString); + + // Register the test output helper common processor with the DI container. + kernelBuilder.Services.AddSingleton(this.Output); + kernelBuilder.Services.AddTransient(); + + // Build the kernel. + var kernel = kernelBuilder.Build(); + + // Build a common processor object using the DI container. + var processor = kernel.GetRequiredService(); + + // Run the process and pass a key generator function to it, to generate unique record keys. + // The key generator function is required, since different vector stores may require different key types. + // E.g. Postgres supports Guid and ulong keys, but others may support strings only. + await processor.IngestDataAndSearchAsync("skglossaryWithDI", () => Guid.NewGuid()); + } + + [Fact] + public async Task ExampleWithoutDIAsync() + { + // Create an embedding generation service. + var textEmbeddingGenerationService = new AzureOpenAITextEmbeddingGenerationService( + TestConfiguration.AzureOpenAIEmbeddings.DeploymentName, + TestConfiguration.AzureOpenAIEmbeddings.Endpoint, + new AzureCliCredential()); + + // Initialize the Postgres docker container via the fixtures and construct the Postgres VectorStore. + await PostgresFixture.ManualInitializeAsync(); + var vectorStore = new PostgresVectorStore(ConnectionString); + + // Create the common processor that works for any vector store. + var processor = new VectorStore_VectorSearch_MultiStore_Common(vectorStore, textEmbeddingGenerationService, this.Output); + + // Run the process and pass a key generator function to it, to generate unique record keys. + // The key generator function is required, since different vector stores may require different key types. + // E.g. Postgres supports Guid and ulong keys, but others may support strings only. + await processor.IngestDataAndSearchAsync("skglossaryWithoutDI", () => Guid.NewGuid()); + } +} From 4dcd2225155a2a69b8494dcf89e3fe8346913ded Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 24 Oct 2024 11:15:00 -0400 Subject: [PATCH 11/62] Add docs for setting configuration in samples\Concepts --- dotnet/samples/Concepts/Concepts.csproj | 3 + dotnet/samples/Concepts/README.md | 85 +++++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index e35c2bef0dca..90ce6af1d1e0 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -102,6 +102,9 @@ + + Always + PreserveNewest diff --git a/dotnet/samples/Concepts/README.md b/dotnet/samples/Concepts/README.md index 15584b88685c..710ef6dfa5bb 100644 --- a/dotnet/samples/Concepts/README.md +++ b/dotnet/samples/Concepts/README.md @@ -204,3 +204,88 @@ dotnet test -l "console;verbosity=detailed" --filter "FullyQualifiedName=ChatCom ### TextToImage - Using [`TextToImage`](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/SemanticKernel.Abstractions/AI/TextToImage/ITextToImageService.cs) services to generate images - [OpenAI_TextToImage](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/TextToImage/OpenAI_TextToImageDalle3.cs) + +## Configuration + +### Option 1: Use Secret Manager + +Concept samples will require secrets and credentials, to access OpenAI, Azure OpenAI, +Bing and other resources. + +We suggest using .NET [Secret Manager](https://learn.microsoft.com/en-us/aspnet/core/security/app-secrets) +to avoid the risk of leaking secrets into the repository, branches and pull requests. +You can also use environment variables if you prefer. + +To set your secrets with Secret Manager: + +``` +cd dotnet/src/samples/Concepts + +dotnet user-secrets init + +dotnet user-secrets set "OpenAI:ServiceId" "gpt-3.5-turbo-instruct" +dotnet user-secrets set "OpenAI:ModelId" "gpt-3.5-turbo-instruct" +dotnet user-secrets set "OpenAI:ChatModelId" "gpt-4" +dotnet user-secrets set "OpenAI:ApiKey" "..." + +... +``` + +### Option 2: Use Configuration File +1. Create a `appsettings.Development.json` file next to the `Concepts.csproj` file. This file will be ignored by git, + the content will not end up in pull requests, so it's safe for personal settings. Keep the file safe. +2. Edit `appsettings.Development.json` and set the appropriate configuration for the samples you are running. + +For example: + +```json +{ + "OpenAI": { + "ServiceId": "gpt-3.5-turbo-instruct", + "ModelId": "gpt-3.5-turbo-instruct", + "ChatModelId": "gpt-4", + "ApiKey": "sk-...." + }, + "AzureOpenAI": { + "ServiceId": "azure-gpt-35-turbo-instruct", + "DeploymentName": "gpt-35-turbo-instruct", + "ChatDeploymentName": "gpt-4", + "Endpoint": "https://contoso.openai.azure.com/", + "ApiKey": "...." + }, + // etc. +} +``` + +### Option 3: Use Environment Variables +You may also set the settings in your environment variables. The environment variables will override the settings in the `appsettings.Development.json` file. + +When setting environment variables, use a double underscore (i.e. "\_\_") to delineate between parent and child properties. For example: + +- bash: + + ```bash + export OpenAI__ApiKey="sk-...." + export AzureOpenAI__ApiKey="...." + export AzureOpenAI__DeploymentName="gpt-35-turbo-instruct" + export AzureOpenAI__ChatDeploymentName="gpt-4" + export AzureOpenAIEmbeddings__DeploymentName="azure-text-embedding-ada-002" + export AzureOpenAI__Endpoint="https://contoso.openai.azure.com/" + export HuggingFace__ApiKey="...." + export Bing__ApiKey="...." + export Postgres__ConnectionString="...." + ``` + +- PowerShell: + + ```ps + $env:OpenAI__ApiKey = "sk-...." + $env:AzureOpenAI__ApiKey = "...." + $env:AzureOpenAI__DeploymentName = "gpt-35-turbo-instruct" + $env:AzureOpenAI__ChatDeploymentName = "gpt-4" + $env:AzureOpenAIEmbeddings__DeploymentName = "azure-text-embedding-ada-002" + $env:AzureOpenAI__Endpoint = "https://contoso.openai.azure.com/" + $env:HuggingFace__ApiKey = "...." + $env:Bing__ApiKey = "...." + $env:Postgres__ConnectionString = "...." + ``` From 5c3e63f71e521f12859c62fe7809fb48ee1ef537 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 24 Oct 2024 11:27:05 -0400 Subject: [PATCH 12/62] Enforce dimension size in index creation --- .../PostgresVectorStoreRecordCollection.cs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index b31d5007a51d..540317a62bad 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -105,6 +105,14 @@ public async Task CreateCollectionAsync(CancellationToken cancellationToken = de // Create indexes for vector properties. foreach (var vectorProperty in this._propertyReader.VectorProperties) { + // Ensure the dimensionality of the vector is supported for indexing. + if (vectorProperty.IndexKind == IndexKind.Hnsw) + { + if (vectorProperty.Dimensions > 2000) + { + throw new NotSupportedException($"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, which is not supported by the HNSW index. The maximum number of dimensions supported by the HNSW index is 2000."); + } + } await this._client.CreateVectorIndexAsync(this.CollectionName, vectorProperty, cancellationToken).ConfigureAwait(false); } } From 6d9f1fd4b1609c5f0e1860807ad7b1b75fd81d6a Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 24 Oct 2024 11:32:05 -0400 Subject: [PATCH 13/62] Create index for CreateTableIfNotExistsAsyc --- .../PostgresVectorStoreRecordCollection.cs | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 540317a62bad..eb3bd1c04f5e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -101,26 +101,13 @@ public async Task CollectionExistsAsync(CancellationToken cancellationToke /// public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) { - await this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition, false, cancellationToken).ConfigureAwait(false); - // Create indexes for vector properties. - foreach (var vectorProperty in this._propertyReader.VectorProperties) - { - // Ensure the dimensionality of the vector is supported for indexing. - if (vectorProperty.IndexKind == IndexKind.Hnsw) - { - if (vectorProperty.Dimensions > 2000) - { - throw new NotSupportedException($"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, which is not supported by the HNSW index. The maximum number of dimensions supported by the HNSW index is 2000."); - } - } - await this._client.CreateVectorIndexAsync(this.CollectionName, vectorProperty, cancellationToken).ConfigureAwait(false); - } + await this.InternalCreateCollectionAsync(false, cancellationToken).ConfigureAwait(false); } /// public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { - return this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition, true, cancellationToken); + return this.InternalCreateCollectionAsync(true, cancellationToken); } /// @@ -270,6 +257,24 @@ public Task> VectorizedSearchAsync(TVector return Task.FromResult(new VectorSearchResults(results)); } + private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken = default) + { + await this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition, ifNotExists, cancellationToken).ConfigureAwait(false); + // Create indexes for vector properties. + foreach (var vectorProperty in this._propertyReader.VectorProperties) + { + // Ensure the dimensionality of the vector is supported for indexing. + if (vectorProperty.IndexKind == IndexKind.Hnsw) + { + if (vectorProperty.Dimensions > 2000) + { + throw new NotSupportedException($"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, which is not supported by the HNSW index. The maximum number of dimensions supported by the HNSW index is 2000."); + } + } + await this._client.CreateVectorIndexAsync(this.CollectionName, vectorProperty, cancellationToken).ConfigureAwait(false); + } + } + /// /// Get vector property to use for a search by using the storage name for the field name from options /// if available, and falling back to the first vector property in if not. From b4266cc7ae26b8bc3526ec2e7a6c88d9efe1f617 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 24 Oct 2024 12:44:50 -0400 Subject: [PATCH 14/62] Log warning when index not created due to dimensions --- .../PostgresVectorStoreRecordCollection.cs | 17 ++++++- ...ostgresVectorStoreRecordCollectionTests.cs | 48 +++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index eb3bd1c04f5e..e55d1645811c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -6,6 +6,8 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -32,6 +34,9 @@ public sealed class PostgresVectorStoreRecordCollection : IVector // Optional configuration options for this class. private readonly PostgresVectorStoreRecordCollectionOptions _options; + /// The logger to use for logging. + private readonly ILogger> _logger; + /// A helper to access property information for the current data model and record definition. private readonly VectorStoreRecordPropertyReader _propertyReader; @@ -47,7 +52,9 @@ public sealed class PostgresVectorStoreRecordCollection : IVector /// The Postgres client used to interact with the database. /// The name of the collection. /// Optional configuration options for this class. - public PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) + /// The logger to use for logging. + public PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default, + ILogger>? logger = null) { // Verify. Verify.NotNull(client); @@ -68,6 +75,7 @@ public PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, SupportsMultipleKeys = false, SupportsMultipleVectors = true, }); + this._logger = logger ?? NullLogger>.Instance; // Validate property types. this._propertyReader.VerifyKeyProperties(PostgresConstants.SupportedKeyTypes); @@ -268,7 +276,12 @@ private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationT { if (vectorProperty.Dimensions > 2000) { - throw new NotSupportedException($"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, which is not supported by the HNSW index. The maximum number of dimensions supported by the HNSW index is 2000."); + this._logger.LogWarning( + "The provided vector property {VectorPropertyName} has {Dimensions} dimensions, which is not supported by the HNSW index. The maximum number of dimensions supported by the HNSW index is 2000. Index not created.", + vectorProperty.DataModelPropertyName, + vectorProperty.Dimensions + ); + continue; } } await this._client.CreateVectorIndexAsync(this.CollectionName, vectorProperty, cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index f4c518854a87..52c14ecca77d 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Postgres; using Moq; @@ -126,4 +127,51 @@ public async Task UpsertRecordAsyncProducesExpectedSqlAsync() Assert.Equal(3.0f, embedding[2]); Assert.Equal(4.0f, embedding[3]); } + + [Fact] + public async Task CreateCollectionAsyncLogsWarningWhenDimensionsTooLargeAsync() + { + // Arrange + var recordDefinition = new VectorStoreRecordDefinition + { + Properties = [ + new VectorStoreRecordKeyProperty("HotelId", typeof(int)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, + new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 2001, IndexKind = IndexKind.Hnsw, DistanceFunction = DistanceFunction.ManhattanDistance } + ] + }; + var mockLogger = new Mock>>(); + mockLogger.Setup(x => x.Log( + LogLevel.Warning, + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>())); + var sut = new PostgresVectorStoreRecordCollection( + this._postgresClientMock.Object, + TestCollectionName, + logger: mockLogger.Object, + options: new PostgresVectorStoreRecordCollectionOptions { VectorStoreRecordDefinition = recordDefinition } + ); + + this._postgresClientMock.Setup(x => x.CreateTableAsync(TestCollectionName, It.IsAny(), It.IsAny(), It.IsAny())).Returns(Task.CompletedTask); + + // Act + await sut.CreateCollectionAsync(cancellationToken: this._testCancellationToken); + + // Assert + mockLogger.Verify( + x => x.Log( + LogLevel.Warning, + It.IsAny(), + It.Is((v, t) => v.ToString()!.Contains("2001")), + It.IsAny(), + It.IsAny>()), + Times.Once); + } } From f86613a8699fde01146ecc10f74cd26e3c722362 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 24 Oct 2024 18:29:06 -0400 Subject: [PATCH 15/62] Refactor and tests; make SqlBuilder internal --- dotnet/SK-dotnet.sln | 9 + .../Connectors.Memory.Postgres.csproj | 5 + ...PostgresVectorStoreCollectionSqlBuilder.cs | 2 +- .../PostgresGenericDataModelMapper.cs | 2 +- .../PostgresServiceCollectionExtensions.cs | 2 +- .../PostgresVectorStore.cs | 10 +- ...PostgresVectorStoreCollectionSqlBuilder.cs | 9 +- .../PostgresVectorStoreDbClient.cs | 32 +- .../Connectors.Memory.Postgres/README.md | 15 +- .../Connectors.Postgres.UnitTests.csproj | 32 ++ .../PostgresGenericDataModelMapperTests.cs | 190 +++++++++ .../PostgresHotel.cs | 6 +- ...ostgresServiceCollectionExtensionsTests.cs | 89 +++++ ...resVectorStoreCollectionSqlBuilderTests.cs | 376 ++++++++++++++++++ ...ostgresVectorStoreRecordCollectionTests.cs | 94 ++++- .../PostgresVectorStoreRecordMapperTests.cs | 202 ++++++++++ ...esVectorStoreRecordPropertyMappingTests.cs | 101 +++++ .../PostgresVectorStoreTests.cs | 2 +- ...resVectorStoreCollectionSqlBuilderTests.cs | 154 ------- .../Postgres/PostgresVectorStoreFixture.cs | 34 +- ...ostgresVectorStoreRecordCollectionTests.cs | 13 +- .../Postgres/PostgresVectorStoreTests.cs | 4 +- 22 files changed, 1170 insertions(+), 213 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs rename dotnet/src/Connectors/{Connectors.UnitTests/Memory/Postgres => Connectors.Postgres.UnitTests}/PostgresHotel.cs (92%) create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs rename dotnet/src/Connectors/{Connectors.UnitTests/Memory/Postgres => Connectors.Postgres.UnitTests}/PostgresVectorStoreRecordCollectionTests.cs (73%) create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs rename dotnet/src/Connectors/{Connectors.UnitTests/Memory/Postgres => Connectors.Postgres.UnitTests}/PostgresVectorStoreTests.cs (98%) delete mode 100644 dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 98e20b1976a4..248b39d13d57 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -398,6 +398,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AotCompatibility", "samples EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SemanticKernel.AotTests", "src\SemanticKernel.AotTests\SemanticKernel.AotTests.csproj", "{39EAB599-742F-417D-AF80-95F90376BB18}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Postgres.UnitTests", "src\Connectors\Connectors.Postgres.UnitTests\Connectors.Postgres.UnitTests.csproj", "{232E1153-6366-4175-A982-D66B30AAD610}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -1029,6 +1031,12 @@ Global {6F591D05-5F7F-4211-9042-42D8BCE60415}.Publish|Any CPU.Build.0 = Debug|Any CPU {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.ActiveCfg = Release|Any CPU {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.Build.0 = Release|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Debug|Any CPU.Build.0 = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Publish|Any CPU.Build.0 = Debug|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.ActiveCfg = Release|Any CPU + {232E1153-6366-4175-A982-D66B30AAD610}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1171,6 +1179,7 @@ Global {AF7F68FD-ADB0-4941-90AE-88EAAB53BEEB} = {077928EA-2C61-4667-82FC-6A5120B7AC45} {6ECFDF04-2237-4A85-B114-DAA34923E9E6} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {39EAB599-742F-417D-AF80-95F90376BB18} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} + {232E1153-6366-4175-A982-D66B30AAD610} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj index ad132bde113d..7673d1e92a93 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj @@ -27,4 +27,9 @@ + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index 47aa135e9f93..421b36a1ab05 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// Interface for constructing SQL commands for Postgres vector store collections. /// -public interface IPostgresVectorStoreCollectionSqlBuilder +internal interface IPostgresVectorStoreCollectionSqlBuilder { /// /// Builds a SQL command to check if a table exists in the Postgres vector store. diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs index 6a60f3c056ce..268a3538b365 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs @@ -61,7 +61,7 @@ public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyRe return properties; } - VectorStoreGenericDataModel IVectorStoreRecordMapper, Dictionary>.MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) + public VectorStoreGenericDataModel MapFromStorageToDataModel(Dictionary storageModel, StorageToDataModelMapperOptions options) { TKey key; var dataProperties = new Dictionary(); diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs index 23f1f00538f9..375fb709447b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs @@ -127,7 +127,7 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection /// /// Register a Postgres and with the specified service ID - /// and where the Postgres is retrieved from the dependency injection container. + /// and where the Postgres is retrieved from the dependency injection container. /// /// The type of the key. /// The type of the record. diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs index 86f263a740ae..5a7395c8d40b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -69,10 +69,14 @@ public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancel public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull { - // Support int, long, Guid, and string keys - if (typeof(TKey) != typeof(int) && typeof(TKey) != typeof(long) && typeof(TKey) != typeof(Guid) && typeof(TKey) != typeof(string)) + // Support short, int, long, Guid, and string keys + if (typeof(TKey) != typeof(short) && + typeof(TKey) != typeof(int) && + typeof(TKey) != typeof(long) && + typeof(TKey) != typeof(Guid) && + typeof(TKey) != typeof(string)) { - throw new NotSupportedException($"Only int, long, {nameof(Guid)}, and {nameof(String)} keys are supported."); + throw new NotSupportedException($"Only short, int, long, {nameof(Guid)}, and {nameof(String)} keys are supported."); } if (this._options?.VectorStoreCollectionFactory is not null) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index caa0af9e8fda..0d5f2d15270e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -14,7 +14,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// Provides methods to build SQL commands for managing vector store collections in PostgreSQL. /// -public class PostgresVectorStoreCollectionSqlBuilder : IPostgresVectorStoreCollectionSqlBuilder +internal class PostgresVectorStoreCollectionSqlBuilder : IPostgresVectorStoreCollectionSqlBuilder { /// public PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName) @@ -190,7 +190,10 @@ public PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tabl // Generate column names and parameter placeholders var columnNames = string.Join(", ", columns.Select(c => $"\"{c}\"")); var valuePlaceholders = string.Join(", ", columns.Select((c, i) => $"${i + 1}")); - var valuesRows = string.Join(", ", rows.Select((row, rowIndex) => $"({string.Join(", ", columns.Select((c, colIndex) => $"${rowIndex * columns.Count + colIndex + 1}"))})")); + var valuesRows = string.Join(", ", + rows.Select((row, rowIndex) => + $"({string.Join(", ", + columns.Select((c, colIndex) => $"${rowIndex * columns.Count + colIndex + 1}"))})")); // Generate the update set clause var updateSetClause = string.Join(", ", columns.Where(c => c != keyColumn).Select(c => $"\"{c}\" = EXCLUDED.\"{c}\"")); @@ -199,7 +202,7 @@ public PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tabl var commandText = $@" INSERT INTO {schema}.""{tableName}"" ({columnNames}) VALUES {valuesRows} - ON CONFLICT(""{keyColumn}"") + ON CONFLICT (""{keyColumn}"") DO UPDATE SET {updateSetClause}; "; // Generate the parameters diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 67de73fda489..1d8d913687e7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -19,20 +19,13 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// Postgres data source. /// Schema of collection tables. -/// Sql builder for collection tables. [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] -public class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema, IPostgresVectorStoreCollectionSqlBuilder sqlBuilder) : IPostgresVectorStoreDbClient +public class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema) : IPostgresVectorStoreDbClient { private readonly NpgsqlDataSource _dataSource = dataSource; - private readonly IPostgresVectorStoreCollectionSqlBuilder _sqlBuilder = sqlBuilder; private readonly string _schema = schema; - /// - /// Initializes a new instance of the class. - /// - /// Postgres data source. - /// Schema of collection tables. - public PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema) : this(dataSource, schema, new PostgresVectorStoreCollectionSqlBuilder()) { } + private IPostgresVectorStoreCollectionSqlBuilder _sqlBuilder = new PostgresVectorStoreCollectionSqlBuilder(); /// public async Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default) @@ -223,6 +216,25 @@ public async Task DeleteBatchAsync(string tableName, string keyColumn, IEn } } + #region internal =============================================================================== + + /// + /// Sets the SQL builder for the client. + /// + /// + /// + /// This method is used for other Semnatic Kernel connectors that may need to override the default SQL + /// used by this client. + /// + internal void SetSqlBuilder(IPostgresVectorStoreCollectionSqlBuilder sqlBuilder) + { + this._sqlBuilder = sqlBuilder; + } + + #endregion + + #region private ================================================================================ + private Dictionary GetRecord( NpgsqlDataReader reader, IEnumerable properties, @@ -243,4 +255,6 @@ public async Task DeleteBatchAsync(string tableName, string keyColumn, IEn return storageModel; } + + #endregion } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md index 35c80a45087a..ee1ebca74031 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md @@ -18,7 +18,7 @@ This extension is also available for **Azure Database for PostgreSQL - Flexible 1. To install pgvector using Docker: ```bash -docker run -d --name postgres-pgvector -p 5432:5432 -e POSTGRES_PASSWORD=mysecretpassword ankane/pgvector +docker run -d --name postgres-pgvector -p 5432:5432 -e POSTGRES_PASSWORD=mysecretpassword pgvector/pgvector ``` 2. Create a database and enable pgvector extension on this database @@ -33,8 +33,13 @@ sk_demo=# CREATE EXTENSION vector; > Note, "Azure Cosmos DB for PostgreSQL" uses `SELECT CREATE_EXTENSION('vector');` to enable the extension. -3. To use Postgres as a semantic memory store: - > See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. +### Using PostgresVectorStore + +See [this sample](../../../samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs) for an example of using the vector store. + +### Using PostgresMemoryStore + +> See [Example 14](../../../samples/Concepts/Memory/SemanticTextMemory_Building.cs) and [Example 15](../../../samples/Concepts/Memory/TextMemoryPlugin_MultipleMemoryStore.cs) for more memory usage examples with the kernel. ```csharp NpgsqlDataSourceBuilder dataSourceBuilder = new NpgsqlDataSourceBuilder("Host=localhost;Port=5432;Database=sk_demo;User Id=postgres;Password=mysecretpassword"); @@ -88,7 +93,9 @@ BEGIN END $$; ``` -## Migration from older versions +## Migration the MemoryStore from older versions + +> Note: The MemoryStore components are being deprecated in a future version in favor of the vector store components. This section is about migrating to the lates MemoryStore implementation, but users are encouraged to migrate to the vector store components if possible. Since Postgres Memory connector has been re-implemented, the new implementation uses a separate table to store each Collection. diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj new file mode 100644 index 000000000000..5698a909022e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/Connectors.Postgres.UnitTests.csproj @@ -0,0 +1,32 @@ + + + + SemanticKernel.Connectors.Postgres.UnitTests + SemanticKernel.Connectors.Postgres.UnitTests + net8.0 + true + enable + disable + false + $(NoWarn);SKEXP0001,SKEXP0020,VSTHRD111,CA2007,CS1591 + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs new file mode 100644 index 000000000000..77a6d955469f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class PostgresGenericDataModelMapperTests +{ + [Fact] + public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetGenericDataModel("key"); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + + var vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Fact] + public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetGenericDataModel("key"); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + + var vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = "key", + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal("key", result.Key); + Assert.Equal("Value1", result.Data["StringProperty"]); + Assert.Equal(5, result.Data["IntProperty"]); + + if (includeVectors) + { + Assert.NotNull(result.Vectors["FloatVector"]); + Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); + } + else + { + Assert.False(result.Vectors.ContainsKey("FloatVector")); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = "key", + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + IVectorStoreRecordMapper, Dictionary> mapper = new PostgresGenericDataModelMapper(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal("key", result.Key); + Assert.Equal("Value1", result.Data["StringProperty"]); + Assert.Equal(5, result.Data["IntProperty"]); + + if (includeVectors) + { + Assert.NotNull(result.Vectors["FloatVector"]); + Assert.Equal(vector.ToArray(), ((ReadOnlyMemory)result.Vectors["FloatVector"]!).ToArray()); + } + else + { + Assert.False(result.Vectors.ContainsKey("FloatVector")); + } + } + + #region private + + private static VectorStoreRecordDefinition GetRecordDefinition() + { + return new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("StringProperty", typeof(string)), + new VectorStoreRecordDataProperty("IntProperty", typeof(int)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + } + }; + } + + private static VectorStoreGenericDataModel GetGenericDataModel(TKey key) + { + return new VectorStoreGenericDataModel(key) + { + Data = new() + { + ["StringProperty"] = "Value1", + ["IntProperty"] = 5 + }, + Vectors = new() + { + ["FloatVector"] = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]) + } + }; + } + + private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) + { + return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true + }); + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs similarity index 92% rename from dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs rename to dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs index 353992333783..daeb83fbb13c 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresHotel.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs @@ -4,18 +4,18 @@ using System.Collections.Generic; using Microsoft.Extensions.VectorData; -namespace SemanticKernel.Connectors.UnitTests.Postgres; +namespace SemanticKernel.Connectors.Postgres.UnitTests; #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. /// /// A test model for the postgres vector store. /// -public record PostgresHotel() +public record PostgresHotel() { /// The key of the record. [VectorStoreRecordKey] - public int HotelId { get; init; } + public T HotelId { get; init; } /// A string metadata field. [VectorStoreRecordData()] diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..c6409c698792 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class PostgresServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _serviceCollection = new ServiceCollection(); + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange + this._serviceCollection.AddSingleton(Mock.Of()); + + // Act + this._serviceCollection.AddPostgresVectorStore(); + + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + var vectorStore = serviceProvider.GetRequiredService(); + + // Assert + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } + + [Fact] + public void AddVectorStoreRecordCollectionWithStringKeyRegistersClass() + { + // Arrange + this._serviceCollection.AddSingleton(Mock.Of()); + + // Act + this._serviceCollection.AddPostgresVectorStoreRecordCollection("testcollection"); + + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + + // Assert + var collection = serviceProvider.GetRequiredService>(); + Assert.NotNull(collection); + Assert.IsType>(collection); + + var vectorizedSearch = serviceProvider.GetRequiredService>(); + Assert.NotNull(vectorizedSearch); + Assert.IsType>(vectorizedSearch); + } + + [Fact] + public void AddVectorStoreRecordCollectionWithNumericKeyRegistersClass() + { + // Arrange + this._serviceCollection.AddSingleton(Mock.Of()); + + // Act + this._serviceCollection.AddPostgresVectorStoreRecordCollection("testcollection"); + + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + + // Assert + var collection = serviceProvider.GetRequiredService>(); + Assert.NotNull(collection); + Assert.IsType>(collection); + + var vectorizedSearch = serviceProvider.GetRequiredService>(); + Assert.NotNull(vectorizedSearch); + Assert.IsType>(vectorizedSearch); + } + + #region private + +#pragma warning disable CA1812 // Avoid uninstantiated internal classes + private sealed class TestRecord +#pragma warning restore CA1812 // Avoid uninstantiated internal classes + { + [VectorStoreRecordKey] + public string Id { get; set; } = string.Empty; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs new file mode 100644 index 000000000000..70f52ff7e630 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -0,0 +1,376 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +public class PostgresVectorStoreCollectionSqlBuilderTests +{ + private readonly ITestOutputHelper _output; + private static readonly float[] s_vector = new float[] { 1.0f, 2.0f, 3.0f }; + + public PostgresVectorStoreCollectionSqlBuilderTests(ITestOutputHelper output) + { + this._output = output; + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void TestBuildCreateTableCommand(bool ifNotExists) + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var cmdInfo = builder.BuildCreateTableCommand("public", "testcollection", recordDefinition.Properties, ifNotExists: ifNotExists); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("\"name\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"code\" INTEGER NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"rating\" REAL", cmdInfo.CommandText); + Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"parking_is_included\" BOOLEAN NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"tags\" TEXT[]", cmdInfo.CommandText); + Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); + Assert.Contains("\"embedding1\" VECTOR(10) NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"embedding2\" VECTOR(10)", cmdInfo.CommandText); + Assert.Contains("PRIMARY KEY (\"id\")", cmdInfo.CommandText); + + if (ifNotExists) + { + Assert.Contains("IF NOT EXISTS", cmdInfo.CommandText); + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildDropTableCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var cmdInfo = builder.BuildDropTableCommand("public", "testcollection"); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("DROP TABLE IF EXISTS public.\"testcollection\"", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildUpsertCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var row = new Dictionary() + { + ["id"] = 123, + ["name"] = "Hotel", + ["code"] = 456, + ["rating"] = 4.5f, + ["description"] = "Hotel description", + ["parking_is_included"] = true, + ["tags"] = new List { "tag1", "tag2" }, + ["embedding1"] = new Vector(s_vector), + }; + + var keyColumn = "id"; + + var cmdInfo = builder.BuildUpsertCommand("public", "testcollection", keyColumn, row); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("ON CONFLICT (\"id\")", cmdInfo.CommandText); + Assert.Contains("DO UPDATE SET", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + + foreach (var (key, index) in row.Keys.Select((key, index) => (key, index))) + { + Assert.Equal(row[key], cmdInfo.Parameters[index].Value); + // If the key is not the key column, it should be included in the update clause. + if (key != keyColumn) + { + Assert.Contains($"\"{key}\"=${index + 1}", cmdInfo.CommandText); + } + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildUpsertBatchCommand() + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var rows = new List>() + { + new() + { + ["id"] = 123, + ["name"] = "Hotel", + ["code"] = 456, + ["rating"] = 4.5f, + ["description"] = "Hotel description", + ["parking_is_included"] = true, + ["tags"] = new List { "tag1", "tag2" }, + ["embedding1"] = new Vector(s_vector), + }, + new() + { + ["id"] = 124, + ["name"] = "Motel", + ["code"] = 457, + ["rating"] = 4.6f, + ["description"] = "Motel description", + ["parking_is_included"] = false, + ["tags"] = new List { "tag3", "tag4" }, + ["embedding1"] = new Vector(s_vector), + }, + }; + + var keyColumn = "id"; + var columnCount = rows.First().Count; + + var cmdInfo = builder.BuildUpsertBatchCommand("public", "testcollection", keyColumn, rows); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); + Assert.Contains("ON CONFLICT (\"id\")", cmdInfo.CommandText); + Assert.Contains("DO UPDATE SET", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + + foreach (var (row, rowIndex) in rows.Select((row, rowIndex) => (row, rowIndex))) + { + foreach (var (column, columnIndex) in row.Keys.Select((key, index) => (key, index))) + { + Assert.Equal(row[column], cmdInfo.Parameters[columnIndex + (rowIndex * columnCount)].Value); + // If the key is not the key column, it should be included in the update clause. + if (column != keyColumn) + { + Assert.Contains($"\"{column}\" = EXCLUDED.\"{column}\"", cmdInfo.CommandText); + } + } + } + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var key = 123; + + // Act + var cmdInfo = builder.BuildGetCommand("public", "testcollection", recordDefinition.Properties, key, includeVectors: true); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = $1", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetBatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("tags", typeof(List)), + new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var keys = new List { 123, 124 }; + + // Act + var cmdInfo = builder.BuildGetBatchCommand("public", "testcollection", recordDefinition.Properties, keys, includeVectors: true); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = ANY($1)", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + Assert.Single(cmdInfo.Parameters); + Assert.Equal(keys, cmdInfo.Parameters[0].Value); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildDeleteCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var key = 123; + + // Act + var cmdInfo = builder.BuildDeleteCommand("public", "testcollection", "id", key); + + // Assert + Assert.Contains("DELETE", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = $1", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildDeleteBatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var keys = new List { 123, 124 }; + + // Act + var cmdInfo = builder.BuildDeleteBatchCommand("public", "testcollection", "id", keys); + + // Assert + Assert.Contains("DELETE", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("WHERE \"id\" = ANY($1)", cmdInfo.CommandText); + Assert.NotNull(cmdInfo.Parameters); + Assert.Single(cmdInfo.Parameters); + Assert.Equal(keys, cmdInfo.Parameters[0].Value); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + + [Fact] + public void TestBuildGetNearestMatchCommand() + { + // Arrange + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var vectorProperty = new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = "hnsw", + }; + + var recordDefinition = new VectorStoreRecordDefinition() + { + Properties = [ + new VectorStoreRecordKeyProperty("id", typeof(long)), + new VectorStoreRecordDataProperty("name", typeof(string)), + new VectorStoreRecordDataProperty("code", typeof(int)), + new VectorStoreRecordDataProperty("rating", typeof(float?)), + new VectorStoreRecordDataProperty("description", typeof(string)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("tags", typeof(List)), + vectorProperty, + new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) + { + Dimensions = 10, + IndexKind = "hnsw", + } + ] + }; + + var vector = new Vector(s_vector); + + // Act + var cmdInfo = builder.BuildGetNearestMatchCommand("public", "testcollection", + properties: recordDefinition.Properties, + vectorProperty: vectorProperty, + vectorValue: vector, + filter: null, + skip: null, + withEmbeddings: true, + limit: 10); + + // Assert + Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); + Assert.Contains("ORDER BY", cmdInfo.CommandText); + Assert.Contains("LIMIT 10", cmdInfo.CommandText); + + // Output + this._output.WriteLine(cmdInfo.CommandText); + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs similarity index 73% rename from dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs rename to dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs index 52c14ecca77d..48b7ce4d4a87 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs @@ -11,7 +11,7 @@ using Pgvector; using Xunit; -namespace SemanticKernel.Connectors.UnitTests.Postgres; +namespace SemanticKernel.Connectors.Postgres.UnitTests; public class PostgresVectorStoreRecordCollectionTests { @@ -77,13 +77,13 @@ public void ThrowsForUnsupportedType() } [Fact] - public async Task UpsertRecordAsyncProducesExpectedSqlAsync() + public async Task UpsertRecordAsyncProducesExpectedClientCallAsync() { // Arrange Dictionary? capturedArguments = null; - var sut = new PostgresVectorStoreRecordCollection(this._postgresClientMock.Object, TestCollectionName); - var record = new PostgresHotel + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TestCollectionName); + var record = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", @@ -92,7 +92,7 @@ public async Task UpsertRecordAsyncProducesExpectedSqlAsync() ParkingIncluded = true, Tags = ["tag1", "tag2"], Description = "A hotel", - DescriptionEmbedding = new ReadOnlyMemory(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }) + DescriptionEmbedding = new ReadOnlyMemory([1.0f, 2.0f, 3.0f, 4.0f]) }; this._postgresClientMock.Setup(x => x.UpsertAsync( @@ -145,18 +145,18 @@ public async Task CreateCollectionAsyncLogsWarningWhenDimensionsTooLargeAsync() new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 2001, IndexKind = IndexKind.Hnsw, DistanceFunction = DistanceFunction.ManhattanDistance } ] }; - var mockLogger = new Mock>>(); + var mockLogger = new Mock>>>(); mockLogger.Setup(x => x.Log( LogLevel.Warning, It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny>())); - var sut = new PostgresVectorStoreRecordCollection( + var sut = new PostgresVectorStoreRecordCollection>( this._postgresClientMock.Object, TestCollectionName, logger: mockLogger.Object, - options: new PostgresVectorStoreRecordCollectionOptions { VectorStoreRecordDefinition = recordDefinition } + options: new PostgresVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = recordDefinition } ); this._postgresClientMock.Setup(x => x.CreateTableAsync(TestCollectionName, It.IsAny(), It.IsAny(), It.IsAny())).Returns(Task.CompletedTask); @@ -174,4 +174,82 @@ public async Task CreateCollectionAsyncLogsWarningWhenDimensionsTooLargeAsync() It.IsAny>()), Times.Once); } + + [Fact] + public async Task CollectionExistsReturnsValidResultAsync() + { + // Arrange + const string TableName = "CollectionExists"; + + this._postgresClientMock.Setup(x => x.DoesTableExistsAsync(TableName, this._testCancellationToken)).ReturnsAsync(true); + + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TableName); + + // Act + var result = await sut.CollectionExistsAsync(); + + Assert.True(result); + } + + [Fact] + public async Task DeleteCollectionCallsClientDeleteAsync() + { + // Arrange + const string TableName = "DeleteCollection"; + + this._postgresClientMock.Setup(x => x.DeleteTableAsync(TableName, this._testCancellationToken)).Returns(Task.CompletedTask); + + var sut = new PostgresVectorStoreRecordCollection>(this._postgresClientMock.Object, TableName); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + this._postgresClientMock.Verify(x => x.DeleteTableAsync(TableName, this._testCancellationToken), Times.Once); + } + + #region private + + private static void AssertRecord(TestRecord expectedRecord, TestRecord? actualRecord, bool includeVectors) + { + Assert.NotNull(actualRecord); + + Assert.Equal(expectedRecord.Key, actualRecord.Key); + Assert.Equal(expectedRecord.Data, actualRecord.Data); + + if (includeVectors) + { + Assert.NotNull(actualRecord.Vector); + Assert.Equal(expectedRecord.Vector!.Value.ToArray(), actualRecord.Vector.Value.Span.ToArray()); + } + else + { + Assert.Null(actualRecord.Vector); + } + } + +#pragma warning disable CA1812 + private sealed class TestRecord + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string? Data { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + public ReadOnlyMemory? Vector { get; set; } + } + + private sealed class TestRecordWithoutVectorProperty + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string? Data { get; set; } + } +#pragma warning restore CA1812 + + #endregion } diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs new file mode 100644 index 000000000000..201c93f53db8 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class PostgresVectorStoreRecordMapperTests +{ + [Fact] + public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetDataModel("key"); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + + Vector? vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Fact] + public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() + { + // Arrange + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + var dataModel = GetDataModel(1); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal((ulong)1, result["Key"]); + Assert.Equal("Value1", result["StringProperty"]); + Assert.Equal(5, result["IntProperty"]); + + var vector = result["FloatVector"] as Vector; + + Assert.NotNull(vector); + Assert.True(vector.ToArray().Length > 0); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = "key", + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal("key", result.Key); + Assert.Equal("Value1", result.StringProperty); + Assert.Equal(5, result.IntProperty); + + if (includeVectors) + { + Assert.NotNull(result.FloatVector); + Assert.Equal(vector.Span.ToArray(), result.FloatVector.Value.Span.ToArray()); + } + else + { + Assert.Null(result.FloatVector); + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool includeVectors) + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var storageVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + var storageModel = new Dictionary + { + ["Key"] = (ulong)1, + ["StringProperty"] = "Value1", + ["IntProperty"] = 5, + ["FloatVector"] = storageVector + }; + + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); + + var mapper = new PostgresVectorStoreRecordMapper>(propertyReader); + + // Act + var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.Equal((ulong)1, result.Key); + Assert.Equal("Value1", result.StringProperty); + Assert.Equal(5, result.IntProperty); + + if (includeVectors) + { + Assert.NotNull(result.FloatVector); + Assert.Equal(vector.Span.ToArray(), result.FloatVector.Value.Span.ToArray()); + } + else + { + Assert.Null(result.FloatVector); + } + } + + #region private + + private static VectorStoreRecordDefinition GetRecordDefinition() + { + return new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(TKey)), + new VectorStoreRecordDataProperty("StringProperty", typeof(string)), + new VectorStoreRecordDataProperty("IntProperty", typeof(int)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + } + }; + } + + private static TestRecord GetDataModel(TKey key) + { + return new TestRecord + { + Key = key, + StringProperty = "Value1", + IntProperty = 5, + FloatVector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]) + }; + } + + private static VectorStoreRecordPropertyReader GetPropertyReader(VectorStoreRecordDefinition definition) + { + return new VectorStoreRecordPropertyReader(typeof(TRecord), definition, new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleKeys = false, + SupportsMultipleVectors = true + }); + } + +#pragma warning disable CA1812 + private sealed class TestRecord + { + [VectorStoreRecordKey] + public TKey? Key { get; set; } + + [VectorStoreRecordData] + public string? StringProperty { get; set; } + + [VectorStoreRecordData] + public int? IntProperty { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + public ReadOnlyMemory? FloatVector { get; set; } + } +#pragma warning restore CA1812 + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs new file mode 100644 index 000000000000..5005901c6ad6 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Pgvector; +using Xunit; + +namespace SemanticKernel.Connectors.Postgres.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class PostgresVectorStoreRecordPropertyMappingTests +{ + [Fact] + public void MapVectorForStorageModelWithInvalidVectorTypeThrowsException() + { + // Arrange + var vector = new float[] { 1f, 2f, 3f }; + + // Act & Assert + Assert.Throws(() => PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector)); + } + + [Fact] + public void MapVectorForStorageModelReturnsVector() + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + + // Act + var storageModelVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + // Assert + Assert.IsType(storageModelVector); + Assert.True(storageModelVector.ToArray().Length > 0); + } + + [Fact] + public void MapVectorForDataModelReturnsReadOnlyMemory() + { + // Arrange + var vector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]); + var pgVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForStorageModel(vector); + + // Act + var dataModelVector = PostgresVectorStoreRecordPropertyMapping.MapVectorForDataModel(pgVector); + + // Assert + Assert.NotNull(dataModelVector); + Assert.Equal(vector.ToArray(), dataModelVector!.Value.ToArray()); + } + + [Fact] + public void GetPropertyValueReturnsCorrectValuesForLists() + { + // Arrange + var typesAndExpectedValues = new List<(Type, object)> + { + (typeof(List), "INTEGER[]"), + (typeof(List), "REAL[]"), + (typeof(List), "DOUBLE PRECISION[]"), + (typeof(List), "TEXT[]"), + (typeof(List), "BOOLEAN[]"), + (typeof(List), "TIMESTAMP[]"), + (typeof(List), "UUID[]"), + }; + + // Act & Assert + foreach (var (type, expectedValue) in typesAndExpectedValues) + { + var (pgType, _) = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(type); + Assert.Equal(expectedValue, pgType); + } + } + + [Fact] + public void GetPropertyValueReturnsCorrectNullableValue() + { + // Arrange + var typesAndExpectedValues = new List<(Type, object)> + { + (typeof(short), false), + (typeof(short?), true), + (typeof(int?), true), + (typeof(long), false), + (typeof(string), true), + (typeof(bool?), true), + (typeof(DateTime?), true), + (typeof(Guid), false), + }; + + // Act & Assert + foreach (var (type, expectedValue) in typesAndExpectedValues) + { + var (_, isNullable) = PostgresVectorStoreRecordPropertyMapping.GetPostgresTypeName(type); + Assert.Equal(expectedValue, isNullable); + } + } +} diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs similarity index 98% rename from dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs rename to dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs index 9f6bbfd556f6..3b3d407fc035 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs @@ -10,7 +10,7 @@ using Moq; using Xunit; -namespace SemanticKernel.Connectors.UnitTests.Postgres; +namespace SemanticKernel.Connectors.Postgres.UnitTests; /// /// Contains tests for the class. diff --git a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs deleted file mode 100644 index c7649b8aef32..000000000000 --- a/dotnet/src/Connectors/Connectors.UnitTests/Memory/Postgres/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Postgres; -using Pgvector; -using Xunit; -using Xunit.Abstractions; - -namespace SemanticKernel.Connectors.UnitTests.Postgres; - -public class PostgresVectorStoreCollectionSqlBuilderTests -{ - private readonly ITestOutputHelper _output; - - public PostgresVectorStoreCollectionSqlBuilderTests(ITestOutputHelper output) - { - this._output = output; - } - - [Fact] - public void TestBuildCreateTableCommand() - { - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - - var recordDefinition = new VectorStoreRecordDefinition() - { - Properties = [ - new VectorStoreRecordKeyProperty("id", typeof(long)), - new VectorStoreRecordDataProperty("name", typeof(string)), - new VectorStoreRecordDataProperty("code", typeof(int)), - new VectorStoreRecordDataProperty("rating", typeof(float?)), - new VectorStoreRecordDataProperty("description", typeof(string)), - new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), - new VectorStoreRecordDataProperty("tags", typeof(List)), - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) - { - Dimensions = 10, - IndexKind = "hnsw", - }, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) - { - Dimensions = 10, - IndexKind = "hnsw", - } - ] - }; - - var cmdInfo = builder.BuildCreateTableCommand("public", "testcollection", recordDefinition.Properties, ifNotExists: true); - - // Check for expected properties; integration tests will validate the actual SQL. - Assert.Contains("public.\"testcollection\" (", cmdInfo.CommandText); - Assert.Contains("IF NOT EXISTS", cmdInfo.CommandText); - Assert.Contains("\"name\" TEXT", cmdInfo.CommandText); - Assert.Contains("\"code\" INTEGER NOT NULL", cmdInfo.CommandText); - Assert.Contains("\"rating\" REAL", cmdInfo.CommandText); - Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); - Assert.Contains("\"parking_is_included\" BOOLEAN NOT NULL", cmdInfo.CommandText); - Assert.Contains("\"tags\" TEXT[]", cmdInfo.CommandText); - Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); - Assert.Contains("\"embedding1\" VECTOR(10) NOT NULL", cmdInfo.CommandText); - Assert.Contains("\"embedding2\" VECTOR(10)", cmdInfo.CommandText); - Assert.Contains("PRIMARY KEY (\"id\")", cmdInfo.CommandText); - - // Output - this._output.WriteLine(cmdInfo.CommandText); - } - - [Fact] - public void TestBuildUpsertCommand() - { - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - - var row = new Dictionary() - { - ["id"] = 123, - ["name"] = "Hotel", - ["code"] = 456, - ["rating"] = 4.5f, - ["description"] = "Hotel description", - ["parking_is_included"] = true, - ["tags"] = new List { "tag1", "tag2" }, - ["embedding1"] = new Vector(new float[] { 1.0f, 2.0f, 3.0f }), - }; - - var keyColumn = "id"; - - var cmdInfo = builder.BuildUpsertCommand("public", "testcollection", keyColumn, row); - - // Check for expected properties; integration tests will validate the actual SQL. - Assert.Contains("INSERT INTO public.\"testcollection\" (", cmdInfo.CommandText); - Assert.Contains("ON CONFLICT (\"id\")", cmdInfo.CommandText); - Assert.Contains("DO UPDATE SET", cmdInfo.CommandText); - Assert.NotNull(cmdInfo.Parameters); - - foreach (var (key, index) in row.Keys.Select((key, index) => (key, index))) - { - Assert.Equal(row[key], cmdInfo.Parameters[index].Value); - // If the key is not the key column, it should be included in the update clause. - if (key != keyColumn) - { - Assert.Contains($"\"{key}\"=${index + 1}", cmdInfo.CommandText); - } - } - - // Output - this._output.WriteLine(cmdInfo.CommandText); - } - - [Fact] - public void TestBuildGetCommand() - { - // Arrange - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - - var recordDefinition = new VectorStoreRecordDefinition() - { - Properties = [ - new VectorStoreRecordKeyProperty("id", typeof(long)), - new VectorStoreRecordDataProperty("name", typeof(string)), - new VectorStoreRecordDataProperty("code", typeof(int)), - new VectorStoreRecordDataProperty("rating", typeof(float?)), - new VectorStoreRecordDataProperty("description", typeof(string)), - new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), - new VectorStoreRecordDataProperty("tags", typeof(List)), - new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) - { - Dimensions = 10, - IndexKind = "hnsw", - }, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) - { - Dimensions = 10, - IndexKind = "hnsw", - } - ] - }; - - var key = 123; - - // Act - var cmdInfo = builder.BuildGetCommand("public", "testcollection", recordDefinition.Properties, key, includeVectors: true); - - // Assert - Assert.Contains("SELECT", cmdInfo.CommandText); - Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); - Assert.Contains("WHERE \"id\" = $1", cmdInfo.CommandText); - - // Output - this._output.WriteLine(cmdInfo.CommandText); - } -} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs index a7ddf78ac0d9..751fe2cb3688 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -5,8 +5,9 @@ using System.Threading.Tasks; using Docker.DotNet; using Docker.DotNet.Models; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Postgres; +using Microsoft.SemanticKernel; using Npgsql; using Xunit; @@ -18,7 +19,7 @@ public class PostgresVectorStoreFixture : IAsyncLifetime private readonly DockerClient _client; /// The id of the postgres container that we are testing with. - private readonly string? _containerId = null; + private string? _containerId = null; #pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. @@ -63,8 +64,8 @@ public PostgresVectorStoreFixture() private string _connectionString = null!; private string _databaseName = null!; - /// Gets the postgres client connection to use for tests. - public PostgresVectorStoreDbClient PostgresClient { get; private set; } + /// Gets the Kernel that holds the vector store. + public Kernel Kernel { get; private set; } /// Gets the manually created vector store record definition for our test model. public VectorStoreRecordDefinition HotelVectorStoreRecordDefinition { get; private set; } @@ -72,16 +73,14 @@ public PostgresVectorStoreFixture() /// Gets the manually created vector store record definition for our test model. public VectorStoreRecordDefinition HotelWithGuidIdVectorStoreRecordDefinition { get; private set; } - public PostgresVectorStoreRecordCollection GetCollection( + public IVectorStoreRecordCollection GetCollection( string collectionName, - PostgresVectorStoreRecordCollectionOptions? options = default) + VectorStoreRecordDefinition? recordDefinition = default) where TKey : notnull where TRecord : class { - return new PostgresVectorStoreRecordCollection( - this.PostgresClient, - collectionName, - options); + var vectorStore = this.Kernel.GetRequiredService(); + return vectorStore.GetCollection(collectionName, recordDefinition); } /// @@ -90,7 +89,7 @@ public PostgresVectorStoreRecordCollection GetCollectionAn async task. public async Task InitializeAsync() { - //this._containerId = await SetupPostgresContainerAsync(this._client); + this._containerId = await SetupPostgresContainerAsync(this._client); this._connectionString = "Host=localhost;Port=5432;Username=postgres;Password=example;Database=postgres;"; this._databaseName = $"sk_it_{Guid.NewGuid():N}"; @@ -105,7 +104,9 @@ public async Task InitializeAsync() this._dataSource = dataSourceBuilder.Build(); - this.PostgresClient = new PostgresVectorStoreDbClient(this._dataSource); + this.Kernel = Kernel.CreateBuilder() + .AddPostgresVectorStore(connectionStringBuilder.ToString()) + .Build(); // Wait for the postgres container to be ready and create the test database using the initial data source. var initialDataSource = NpgsqlDataSource.Create(this._connectionString); @@ -174,6 +175,15 @@ DescriptionEmbedding VECTOR(4) NOT NULL, /// An async task. public async Task DisposeAsync() { + if (this.Kernel != null) + { + var dataSource = this.Kernel.Services.GetService(); + if (dataSource != null) + { + dataSource.Dispose(); + } + } + if (this._dataSource != null) { this._dataSource.Dispose(); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 95487c730772..1e696093c139 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -5,7 +5,6 @@ using System.Linq; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Postgres; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; @@ -216,12 +215,7 @@ public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() { const int HotelId = 5; - var options = new PostgresVectorStoreRecordCollectionOptions> - { - VectorStoreRecordDefinition = GetVectorStoreRecordDefinition() - }; - - var sut = fixture.GetCollection>("GenericMapperWithNumericKey", options); + var sut = fixture.GetCollection>("GenericMapperWithNumericKey", GetVectorStoreRecordDefinition()); await sut.CreateCollectionAsync(); @@ -296,10 +290,7 @@ public async Task VectorizedSearchReturnsValidResultsByDefaultAsync(bool include var hotel3 = new PostgresHotel { HotelId = 3, HotelName = "Hotel 3", HotelCode = 3, ParkingIncluded = true, HotelRating = 3.5f, Tags = ["tag1", "tag4"], DescriptionEmbedding = new[] { 0f, 0f, 1f, 0f } }; var hotel4 = new PostgresHotel { HotelId = 4, HotelName = "Hotel 4", HotelCode = 4, ParkingIncluded = false, HotelRating = 1.5f, Tags = ["tag1", "tag5"], DescriptionEmbedding = new[] { 0f, 0f, 0f, 1f } }; - var sut = fixture.GetCollection>($"VectorizedSearch_{includeVectors}_{distanceFunction}", new() - { - VectorStoreRecordDefinition = GetVectorStoreRecordDefinition(distanceFunction) - }); + var sut = fixture.GetCollection>($"VectorizedSearch_{includeVectors}_{distanceFunction}", GetVectorStoreRecordDefinition(distanceFunction)); await sut.CreateCollectionAsync(); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs index 2495a8558f22..8291591872f7 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs @@ -2,7 +2,7 @@ using System.Linq; using System.Threading.Tasks; -using Microsoft.SemanticKernel.Connectors.Postgres; +using Microsoft.Extensions.VectorData; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; @@ -14,7 +14,7 @@ public class PostgresVectorStoreTests(PostgresVectorStoreFixture fixture) public async Task ItCanGetAListOfExistingCollectionNamesAsync() { // Arrange - var sut = new PostgresVectorStore(fixture.PostgresClient); + var sut = fixture.Kernel.GetRequiredService(); // Setup var collection = sut.GetCollection>("VS_TEST_HOTELS"); From 8d8283bac54031cb3372c15792b8179e6c04e984 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 10:24:26 -0400 Subject: [PATCH 16/62] Remove old migration note --- .../Connectors.Memory.Postgres/README.md | 66 ------------------- 1 file changed, 66 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md index ee1ebca74031..e9ed71109fbb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/README.md @@ -92,69 +92,3 @@ BEGIN END IF; END $$; ``` - -## Migration the MemoryStore from older versions - -> Note: The MemoryStore components are being deprecated in a future version in favor of the vector store components. This section is about migrating to the lates MemoryStore implementation, but users are encouraged to migrate to the vector store components if possible. - -Since Postgres Memory connector has been re-implemented, the new implementation uses a separate table to store each Collection. - -We provide the following migration script to help you migrate to the new structure. However, please note that due to the use of collections as table names, you need to make sure that all Collections conform to the [Postgres naming convention](https://www.postgresql.org/docs/15/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS) before migrating. - -- Table names may only consist of ASCII letters, digits, and underscores. -- Table names must start with a letter or an underscore. -- Table names may not exceed 63 characters in length. -- Table names are case-insensitive, but it is recommended to use lowercase letters. - -```sql --- Create new tables, each with the name of the collection field value -DO $$ -DECLARE - r record; - c_count integer; -BEGIN - FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP - - -- Drop Table (This will delete the table that already exists. Please consider carefully if you think you need to cancel this comment!) - -- EXECUTE format('DROP TABLE IF EXISTS %I;', r.collection); - - -- Create Table (Modify vector size on demand) - EXECUTE format('CREATE TABLE public.%I ( - key TEXT NOT NULL, - metadata JSONB, - embedding vector(1536), - timestamp TIMESTAMP WITH TIME ZONE, - PRIMARY KEY (key) - );', r.collection); - - -- Get count of records in collection - SELECT count(*) INTO c_count FROM sk_memory_table WHERE collection = r.collection AND key <> ''; - - -- Create Index (https://github.com/pgvector/pgvector#indexing) - IF c_count > 10000000 THEN - EXECUTE format('CREATE INDEX %I - ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', - r.collection || '_ix', r.collection, ROUND(sqrt(c_count))); - ELSIF c_count > 10000 THEN - EXECUTE format('CREATE INDEX %I - ON public.%I USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s);', - r.collection || '_ix', r.collection, c_count / 1000); - END IF; - END LOOP; -END $$; - --- Copy data from the old table to the new table -DO $$ -DECLARE - r record; -BEGIN - FOR r IN SELECT DISTINCT collection FROM sk_memory_table LOOP - EXECUTE format('INSERT INTO public.%I (key, metadata, embedding, timestamp) - SELECT key, metadata::JSONB, embedding, to_timestamp(timestamp / 1000.0) AT TIME ZONE ''UTC'' - FROM sk_memory_table WHERE collection = %L AND key <> '''';', r.collection, r.collection); - END LOOP; -END $$; - --- Drop old table (After ensuring successful execution, you can remove the following comments to remove sk_memory_table.) --- DROP TABLE IF EXISTS sk_memory_table; -``` From 89027fc3d9fd8185660373674d990b7fc0f73fc0 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 10:55:43 -0400 Subject: [PATCH 17/62] Fix docstring --- .../PostgresGenericDataModelMapper.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs index 268a3538b365..eb5b857185a4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs @@ -14,7 +14,7 @@ internal sealed class PostgresGenericDataModelMapper : IVectorStoreRecordM /// /// Initializes a new instance of the class. /// /// - /// A that defines the schema of the data in the database. + /// with helpers for reading vector store model properties and their attributes. public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyReader) { Verify.NotNull(propertyReader); From 8f45d9c902580886a7b246885ad44180bb5c995d Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:07:19 -0400 Subject: [PATCH 18/62] Use parameter for tableName --- .../PostgresVectorStoreCollectionSqlBuilder.cs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 0d5f2d15270e..7ad4b0270328 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -20,13 +20,16 @@ internal class PostgresVectorStoreCollectionSqlBuilder : IPostgresVectorStoreCol public PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName) { return new PostgresSqlCommandInfo( - commandText: $@" + commandText: @" SELECT table_name FROM information_schema.tables WHERE table_schema = $1 AND table_type = 'BASE TABLE' - AND table_name = '{tableName}'", - parameters: [new NpgsqlParameter() { Value = schema }] + AND table_name = $2", + parameters: [ + new NpgsqlParameter() { Value = schema }, + new NpgsqlParameter() { Value = tableName } + ] ); } From 48811bdc39ec352ecce7e9a8f04a2e2b4e691737 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:08:09 -0400 Subject: [PATCH 19/62] Fix support for DateTime, DateTimeOffset --- .../PostgresConstants.cs | 2 + ...ostgresVectorStoreRecordPropertyMapping.cs | 3 + .../PostgresHotel.cs | 4 + .../Memory/Postgres/PostgresHotel.cs | 4 + ...ostgresVectorStoreRecordCollectionTests.cs | 74 ++++++++++++------- 5 files changed, 59 insertions(+), 28 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index 74dfc1e5cfb1..b77c3eca35d0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -35,6 +35,8 @@ internal static class PostgresConstants typeof(decimal), typeof(decimal?), typeof(string), + typeof(DateTime), + typeof(DateTime?), typeof(DateTimeOffset), typeof(DateTimeOffset?), typeof(Guid), diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index c26c47f3127d..008a64cca73e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -88,6 +88,7 @@ internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => Type t when t == typeof(string) => reader.GetString(propertyIndex), Type t when t == typeof(byte[]) => reader.GetFieldValue(propertyIndex), Type t when t == typeof(DateTime) || t == typeof(DateTime?) => reader.GetDateTime(propertyIndex), + Type t when t == typeof(DateTimeOffset) || t == typeof(DateTimeOffset?) => reader.GetFieldValue(propertyIndex), Type t when t == typeof(Guid) => reader.GetFieldValue(propertyIndex), _ => reader.GetValue(propertyIndex) }; @@ -106,6 +107,7 @@ internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => Type t when t == typeof(string) => NpgsqlDbType.Text, Type t when t == typeof(byte[]) => NpgsqlDbType.Bytea, Type t when t == typeof(DateTime) || t == typeof(DateTime?) => NpgsqlDbType.Timestamp, + Type t when t == typeof(DateTimeOffset) || t == typeof(DateTimeOffset?) => NpgsqlDbType.TimestampTz, Type t when t == typeof(Guid) => NpgsqlDbType.Uuid, _ => null }; @@ -129,6 +131,7 @@ public static (string PgType, bool IsNullable) GetPostgresTypeName(Type property Type t when t == typeof(string) => ("TEXT", true), Type t when t == typeof(byte[]) => ("BYTEA", true), Type t when t == typeof(DateTime) => ("TIMESTAMP", false), + Type t when t == typeof(DateTimeOffset) => ("TIMESTAMPTZ", false), Type t when t == typeof(Guid) => ("UUID", false), _ => (null, false) }; diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs index daeb83fbb13c..e8e84badf292 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresHotel.cs @@ -40,6 +40,10 @@ public record PostgresHotel() [VectorStoreRecordData] public string Description { get; set; } + public DateTime CreatedAt { get; set; } = DateTime.UtcNow; + + public DateTimeOffset UpdatedAt { get; set; } = DateTimeOffset.UtcNow; + /// A vector field. [VectorStoreRecordVector(4, IndexKind.Hnsw, DistanceFunction.ManhattanDistance)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs index ff88638e8bdc..48a8f5f36a41 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresHotel.cs @@ -47,6 +47,10 @@ public record PostgresHotel() [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.EuclideanDistance, IndexKind: IndexKind.Hnsw)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } + public DateTime CreatedAt { get; set; } = DateTime.UtcNow; + + public DateTimeOffset UpdatedAt { get; set; } = DateTimeOffset.UtcNow; + public PostgresHotel(T key) : this() { this.HotelId = key; diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 1e696093c139..df45044aff54 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -55,40 +55,49 @@ public async Task CollectionCanUpsertAndGetAsync() await sut.CreateCollectionAsync(); + var writtenHotel1 = new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }; + var writtenHotel2 = new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, ListInts = [1, 2] }; + try { // Act - await sut.UpsertAsync(new PostgresHotel { HotelId = 1, HotelName = "Hotel 1", HotelCode = 1, ParkingIncluded = true, HotelRating = 4.5f, Tags = ["tag1", "tag2"] }); - await sut.UpsertAsync(new PostgresHotel { HotelId = 2, HotelName = "Hotel 2", HotelCode = 2, ParkingIncluded = false, HotelRating = 2.5f, ListInts = [1, 2] }); - var hotel1 = await sut.GetAsync(1); - var hotel2 = await sut.GetAsync(2); + await sut.UpsertAsync(writtenHotel1); + + await sut.UpsertAsync(writtenHotel2); + + var fetchedHotel1 = await sut.GetAsync(1); + var fetchedHotel2 = await sut.GetAsync(2); // Assert - Assert.NotNull(hotel1); - Assert.Equal(1, hotel1!.HotelId); - Assert.Equal("Hotel 1", hotel1!.HotelName); - Assert.Equal(1, hotel1!.HotelCode); - Assert.True(hotel1!.ParkingIncluded); - Assert.Equal(4.5f, hotel1!.HotelRating); - Assert.NotNull(hotel1!.Tags); - Assert.Equal(2, hotel1!.Tags!.Count); - Assert.Equal("tag1", hotel1!.Tags![0]); - Assert.Equal("tag2", hotel1!.Tags![1]); - Assert.Null(hotel1!.ListInts); - - Assert.NotNull(hotel2); - Assert.Equal(2, hotel2!.HotelId); - Assert.Equal("Hotel 2", hotel2!.HotelName); - Assert.Equal(2, hotel2!.HotelCode); - Assert.False(hotel2!.ParkingIncluded); - Assert.Equal(2.5f, hotel2!.HotelRating); - Assert.NotNull(hotel2!.Tags); - Assert.Empty(hotel2!.Tags); - Assert.NotNull(hotel2!.ListInts); - Assert.Equal(2, hotel2!.ListInts!.Count); - Assert.Equal(1, hotel2!.ListInts![0]); - Assert.Equal(2, hotel2!.ListInts![1]); + Assert.NotNull(fetchedHotel1); + Assert.Equal(1, fetchedHotel1!.HotelId); + Assert.Equal("Hotel 1", fetchedHotel1!.HotelName); + Assert.Equal(1, fetchedHotel1!.HotelCode); + Assert.True(fetchedHotel1!.ParkingIncluded); + Assert.Equal(4.5f, fetchedHotel1!.HotelRating); + Assert.NotNull(fetchedHotel1!.Tags); + Assert.Equal(2, fetchedHotel1!.Tags!.Count); + Assert.Equal("tag1", fetchedHotel1!.Tags![0]); + Assert.Equal("tag2", fetchedHotel1!.Tags![1]); + Assert.Null(fetchedHotel1!.ListInts); + Assert.Equal(TruncateMilliseconds(fetchedHotel1.CreatedAt), TruncateMilliseconds(writtenHotel1.CreatedAt)); + Assert.Equal(TruncateMilliseconds(fetchedHotel1.UpdatedAt), TruncateMilliseconds(writtenHotel1.UpdatedAt)); + + Assert.NotNull(fetchedHotel2); + Assert.Equal(2, fetchedHotel2!.HotelId); + Assert.Equal("Hotel 2", fetchedHotel2!.HotelName); + Assert.Equal(2, fetchedHotel2!.HotelCode); + Assert.False(fetchedHotel2!.ParkingIncluded); + Assert.Equal(2.5f, fetchedHotel2!.HotelRating); + Assert.NotNull(fetchedHotel2!.Tags); + Assert.Empty(fetchedHotel2!.Tags); + Assert.NotNull(fetchedHotel2!.ListInts); + Assert.Equal(2, fetchedHotel2!.ListInts!.Count); + Assert.Equal(1, fetchedHotel2!.ListInts![0]); + Assert.Equal(2, fetchedHotel2!.ListInts![1]); + Assert.Equal(TruncateMilliseconds(fetchedHotel2.CreatedAt), TruncateMilliseconds(writtenHotel2.CreatedAt)); + Assert.Equal(TruncateMilliseconds(fetchedHotel2.UpdatedAt), TruncateMilliseconds(writtenHotel2.UpdatedAt)); } finally { @@ -421,6 +430,15 @@ private dynamic CreateRecord(Type idType, object key) record.Tags = new List { "tag1", "tag2" }; return record; } + private static DateTime TruncateMilliseconds(DateTime dateTime) + { + return new DateTime(dateTime.Ticks - (dateTime.Ticks % TimeSpan.TicksPerSecond), dateTime.Kind); + } + + private static DateTimeOffset TruncateMilliseconds(DateTimeOffset dateTimeOffset) + { + return new DateTimeOffset(dateTimeOffset.Ticks - (dateTimeOffset.Ticks % TimeSpan.TicksPerSecond), dateTimeOffset.Offset); + } #endregion From 1d6082d6a7f3375e25ad03f704a1ff7c4f01c99c Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:09:17 -0400 Subject: [PATCH 20/62] Fix warnings in test --- .../PostgresGenericDataModelMapperTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs index 77a6d955469f..99d9e6074428 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs @@ -10,7 +10,7 @@ namespace SemanticKernel.Connectors.Postgres.UnitTests; /// -/// Unit tests for class. +/// Unit tests for class. /// public sealed class PostgresGenericDataModelMapperTests { @@ -123,7 +123,7 @@ public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool var definition = GetRecordDefinition(); var propertyReader = GetPropertyReader>(definition); - IVectorStoreRecordMapper, Dictionary> mapper = new PostgresGenericDataModelMapper(propertyReader); + var mapper = new PostgresGenericDataModelMapper(propertyReader); // Act var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); From eb0a683b10c483db3473a9e46b351c4db85ac824 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:09:40 -0400 Subject: [PATCH 21/62] Remove kernel extensions, improve service extensions --- ...rStore_VectorSearch_MultiStore_Postgres.cs | 2 +- .../PostgresKernelBuilderExtensions.cs | 84 ----------- .../PostgresServiceCollectionExtensions.cs | 136 +++++++++++++++--- .../Postgres/PostgresVectorStoreFixture.cs | 16 +-- 4 files changed, 118 insertions(+), 120 deletions(-) delete mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs index 29bb11cfe163..045e5b2bb5e2 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs @@ -44,7 +44,7 @@ public async Task ExampleWithDIAsync() // Initialize the Postgres docker container via the fixtures and register the Postgres VectorStore. await PostgresFixture.ManualInitializeAsync(); - kernelBuilder.AddPostgresVectorStore(ConnectionString); + kernelBuilder.Services.AddPostgresVectorStore(ConnectionString); // Register the test output helper common processor with the DI container. kernelBuilder.Services.AddSingleton(this.Output); diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs deleted file mode 100644 index dda47d1a9930..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresKernelBuilderExtensions.cs +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.Postgres; - -namespace Microsoft.SemanticKernel; - -/// -/// Extension methods to register the Postgres instances on the . -/// -public static class PostgresKernelBuilderExtensions -{ - /// - /// Register a Postgres with the specified service ID and where is retrieved from the dependency injection container. - /// - /// The builder to register the on. - /// Optional options to further configure the . - /// An optional service id to use as the service key. - /// The kernel builder. - public static IKernelBuilder AddPostgresVectorStore(this IKernelBuilder builder, PostgresVectorStoreOptions? options = default, string? serviceId = default) - { - builder.Services.AddPostgresVectorStore(options, serviceId); - return builder; - } - /// - /// Register a Postgres with the specified service ID and where is constructed using the provided parameters. - /// - /// The builder to register the on. - /// The Postgres connection string. - /// Optional options to further configure the . - /// An optional service id to use as the service key. - /// The kernel builder. - public static IKernelBuilder AddPostgresVectorStore(this IKernelBuilder builder, string connectionString, PostgresVectorStoreOptions? options = default, string? serviceId = default) - { - builder.Services.AddPostgresVectorStore(connectionString, options, serviceId); - return builder; - } - - /// - /// Register a Postgres and with the specified service ID - /// and where the Postgres is retrieved from the dependency injection container. - /// - /// The type of the key. - /// The type of the record. - /// The builder to register the on. - /// The name of the collection. - /// Optional options to further configure the . - /// An optional service id to use as the service key. - /// The kernel builder. - public static IKernelBuilder AddPostgresVectorStoreRecordCollection( - this IKernelBuilder builder, - string collectionName, - PostgresVectorStoreRecordCollectionOptions? options = default, - string? serviceId = default) - where TKey : notnull - { - builder.Services.AddPostgresVectorStoreRecordCollection(collectionName, options, serviceId); - return builder; - } - - /// - /// Register a Postgres and with the specified service ID - /// and where the Postgres is constructed using the provided parameters. - /// - /// The type of the key. - /// The type of the record. - /// The builder to register the on. - /// The name of the collection. - /// The Postgres connection string. - /// Optional options to further configure the . - /// An optional service id to use as the service key. - /// The kernel builder. - public static IKernelBuilder AddPostgresVectorStoreRecordCollection( - this IKernelBuilder builder, - string collectionName, - string connectionString, - PostgresVectorStoreRecordCollectionOptions? options = default, - string? serviceId = default) - where TKey : notnull - { - builder.Services.AddPostgresVectorStoreRecordCollection(collectionName, connectionString, options, serviceId); - return builder; - } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs index 375fb709447b..aa7274aae249 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs @@ -13,21 +13,31 @@ namespace Microsoft.SemanticKernel; public static class PostgresServiceCollectionExtensions { /// - /// Register a with the specified service ID and where the NpgsqlDataSource is retrieved from the dependency injection container. + /// Register a with the specified service ID and where NpgsqlDataSource is constructed using the provided parameters. /// /// The to register the on. + /// Postgres database connection string. /// The schema to use. /// An optional service id to use as the service key. /// The service collection. - public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) + public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, string connectionString, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) { - // Since we are not constructing the client, add the IVectorStore as transient, since we - // cannot make assumptions about how client is being managed. - services.AddKeyedTransient( + string? npgsqlServiceId = serviceId == null ? default : $"{serviceId}_NpgsqlDataSource"; + // Register NpgsqlDataSource to ensure proper disposal. + services.AddKeyedSingleton( + npgsqlServiceId, + (sp, obj) => + { + NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); + dataSourceBuilder.UseVector(); + return dataSourceBuilder.Build(); + }); + + services.AddKeyedSingleton( serviceId, (sp, obj) => { - var dataSource = sp.GetRequiredService(); + var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); return new PostgresVectorStoreDbClient(dataSource, schema); }); @@ -35,25 +45,39 @@ public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCol } /// - /// Register a with the specified service ID and where NpgsqlDataSource is constructed using the provided parameters. + /// Register a with the specified service ID and where NpgsqlDataSource is passed in as parameter. /// /// The to register the on. - /// Postgres database connection string. + /// The data source to use. /// The schema to use. /// An optional service id to use as the service key. /// The service collection. - public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, string connectionString, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) + public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) { - // Register NpgsqlDataSource to ensure proper disposal. - services.AddSingleton( - sp => + // Since we are not constructing the data source, add the IVectorStore as transient, since we + // cannot make assumptions about how client is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => { - NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); - dataSourceBuilder.UseVector(); - return dataSourceBuilder.Build(); + return new PostgresVectorStoreDbClient(dataSource, schema); }); - services.AddKeyedSingleton( + return services; + } + + /// + /// Register a with the specified service ID and where the NpgsqlDataSource is retrieved from the dependency injection container. + /// + /// The to register the on. + /// The schema to use. + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) + { + // Since we are not constructing the client, add the IVectorStore as transient, since we + // cannot make assumptions about how client is being managed. + services.AddKeyedTransient( serviceId, (sp, obj) => { @@ -100,9 +124,11 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection /// The service collection. public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, string connectionString, PostgresVectorStoreOptions? options = default, string? serviceId = default) { + string? npgsqlServiceId = serviceId == null ? default : $"{serviceId}_NpgsqlDataSource"; // Register NpgsqlDataSource to ensure proper disposal. - services.AddSingleton( - sp => + services.AddKeyedSingleton( + npgsqlServiceId, + (sp, obj) => { NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); dataSourceBuilder.UseVector(); @@ -113,7 +139,34 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection serviceId, (sp, obj) => { - var dataSource = sp.GetRequiredService(); + var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); + var client = new PostgresVectorStoreDbClient(dataSource); + var selectedOptions = options ?? sp.GetService(); + + return new PostgresVectorStore( + client, + selectedOptions); + }); + + return services; + } + + /// + /// Register a Postgres with the specified service ID and where is constructed using the NpgsqlDataSource. + /// + /// The to register the on. + /// The data source to use. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// The service collection. + public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, NpgsqlDataSource dataSource, PostgresVectorStoreOptions? options = default, string? serviceId = default) + { + // Since we are not constructing the data source, add the IVectorStore as transient, since we + // cannot make assumptions about how data source is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { var client = new PostgresVectorStoreDbClient(dataSource); var selectedOptions = options ?? sp.GetService(); @@ -178,9 +231,11 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection( - sp => + services.AddKeyedSingleton( + npgsqlServiceId, + (sp, obj) => { NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); dataSourceBuilder.UseVector(); @@ -191,7 +246,44 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection { - var dataSource = sp.GetRequiredService(); + var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); + var client = new PostgresVectorStoreDbClient(dataSource); + var selectedOptions = options ?? sp.GetService>(); + + return (new PostgresVectorStoreRecordCollection(client, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + }); + + AddVectorizedSearch(services, serviceId); + + return services; + } + + /// + /// Register a Postgres and with the specified service ID + /// and where the Postgres is constructed using the data source. + /// + /// The type of the key. + /// The type of the record. + /// The to register the on. + /// The name of the collection. + /// The data source to use. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddPostgresVectorStoreRecordCollection( + this IServiceCollection services, + string collectionName, + NpgsqlDataSource dataSource, + PostgresVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + where TKey : notnull + { + // Since we are not constructing the data source, add the IVectorStore as transient, since we + // cannot make assumptions about how data source is being managed. + services.AddKeyedTransient>( + serviceId, + (sp, obj) => + { var client = new PostgresVectorStoreDbClient(dataSource); var selectedOptions = options ?? sp.GetService>(); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs index 751fe2cb3688..2346a0eab9af 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -5,7 +5,6 @@ using System.Threading.Tasks; using Docker.DotNet; using Docker.DotNet.Models; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel; using Npgsql; @@ -104,9 +103,9 @@ public async Task InitializeAsync() this._dataSource = dataSourceBuilder.Build(); - this.Kernel = Kernel.CreateBuilder() - .AddPostgresVectorStore(connectionStringBuilder.ToString()) - .Build(); + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.Services.AddPostgresVectorStore(this._dataSource); + this.Kernel = kernelBuilder.Build(); // Wait for the postgres container to be ready and create the test database using the initial data source. var initialDataSource = NpgsqlDataSource.Create(this._connectionString); @@ -175,15 +174,6 @@ DescriptionEmbedding VECTOR(4) NOT NULL, /// An async task. public async Task DisposeAsync() { - if (this.Kernel != null) - { - var dataSource = this.Kernel.Services.GetService(); - if (dataSource != null) - { - dataSource.Dispose(); - } - } - if (this._dataSource != null) { this._dataSource.Dispose(); From a66d835c4c1e5469ecad62d53484c004bd069e9a Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:16:27 -0400 Subject: [PATCH 22/62] Make PostgresSqlCommandInfo internal --- .../Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs index fb8c892d6bf1..68380d37ca2a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// Represents a SQL command for Postgres. /// -public class PostgresSqlCommandInfo +internal class PostgresSqlCommandInfo { /// /// Gets or sets the SQL command text. From 53f1009246583a2a6998b5e42c73ccfb9d5c9208 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:23:25 -0400 Subject: [PATCH 23/62] Default to a Hnsw index --- .../Connectors.Memory.Postgres/PostgresConstants.cs | 4 ++++ .../PostgresVectorStoreRecordCollection.cs | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index b77c3eca35d0..c1ad24065f0a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -66,4 +67,7 @@ internal static class PostgresConstants /// The name of the column that returns distance value in the database. /// It is used in the similarity search query. Must not conflict with model property. public const string DistanceColumnName = "sk_pg_distance"; + + /// The default index kind. + public const string DefaultIndexKind = IndexKind.Hnsw; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index e55d1645811c..9d54a9f98c98 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -271,8 +271,10 @@ private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationT // Create indexes for vector properties. foreach (var vectorProperty in this._propertyReader.VectorProperties) { + var indexKind = vectorProperty.IndexKind ?? PostgresConstants.DefaultIndexKind; + // Ensure the dimensionality of the vector is supported for indexing. - if (vectorProperty.IndexKind == IndexKind.Hnsw) + if (indexKind == IndexKind.Hnsw) { if (vectorProperty.Dimensions > 2000) { From 08ea55f9d9370c921e8960bdaf824a3ba5a7af97 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:47:32 -0400 Subject: [PATCH 24/62] Default to cosine distance --- .../PostgresConstants.cs | 3 ++ ...PostgresVectorStoreCollectionSqlBuilder.cs | 8 ++-- ...resVectorStoreCollectionSqlBuilderTests.cs | 48 +++++++++++++++++++ 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index c1ad24065f0a..6fa2d5a468d5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -70,4 +70,7 @@ internal static class PostgresConstants /// The default index kind. public const string DefaultIndexKind = IndexKind.Hnsw; + + /// The default distance function. + public const string DefaultDistanceFunction = DistanceFunction.CosineDistance; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 7ad4b0270328..b7ba658e8081 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -130,14 +130,15 @@ public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, strin _ => throw new NotSupportedException($"Index kind '{vectorProperty.IndexKind}' is not supported for table creation. If you need to create an index of this type, please do so manually. Only HNSW indexes are supported through the vector store.") }; - var indexOps = vectorProperty.DistanceFunction switch + var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; + + var indexOps = distanceFunction switch { DistanceFunction.CosineDistance => "vector_cosine_ops", DistanceFunction.CosineSimilarity => "vector_cosine_ops", DistanceFunction.DotProductSimilarity => "vector_ip_ops", DistanceFunction.EuclideanDistance => "vector_l2_ops", DistanceFunction.ManhattanDistance => "vector_l1_ops", - null => throw new ArgumentException("Distance function must be specified for HNSW index."), _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") }; @@ -352,7 +353,8 @@ public PostgresSqlCommandInfo BuildGetNearestMatchCommand( .Select(column => $"\"{column}\"") ); - var distanceOp = vectorProperty.DistanceFunction switch + var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; + var distanceOp = distanceFunction switch { DistanceFunction.CosineDistance => "<=>", DistanceFunction.CosineSimilarity => "<=>", diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs index 70f52ff7e630..6d2ab6d4bf6c 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -75,6 +75,54 @@ public void TestBuildCreateTableCommand(bool ifNotExists) this._output.WriteLine(cmdInfo.CommandText); } + [Theory] + [InlineData(IndexKind.Hnsw, null)] + [InlineData(IndexKind.IvfFlat, null)] + [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance)] + [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance)] + public void TestBuildCreateIndexCommand(string indexKind, string? distanceFunction) + { + var builder = new PostgresVectorStoreCollectionSqlBuilder(); + + var vectorProperty = new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) + { + Dimensions = 10, + IndexKind = indexKind, + DistanceFunction = distanceFunction, + }; + + if (indexKind != IndexKind.Hnsw) + { + Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorProperty)); + return; + } + + var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorProperty); + + // Check for expected properties; integration tests will validate the actual SQL. + Assert.Contains("CREATE INDEX ", cmdInfo.CommandText); + Assert.Contains("ON public.\"testcollection\" USING hnsw (\"embedding1\" ", cmdInfo.CommandText); + if (distanceFunction == null) + { + // Check for distance function defaults to cosine distance + Assert.Contains("vector_cosine_ops)", cmdInfo.CommandText); + } + else if (distanceFunction == DistanceFunction.CosineDistance) + { + Assert.Contains("vector_cosine_ops)", cmdInfo.CommandText); + } + else if (distanceFunction == DistanceFunction.EuclideanDistance) + { + Assert.Contains("vector_l2_ops)", cmdInfo.CommandText); + } + else + { + throw new NotImplementedException($"Test case for Distance function {distanceFunction} is not implemented."); + } + // Output + this._output.WriteLine(cmdInfo.CommandText); + } + [Fact] public void TestBuildDropTableCommand() { From 319648bdef11004c2f98756a974f2ddd2e4b2e31 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:48:59 -0400 Subject: [PATCH 25/62] Consistently use includeVectors --- .../IPostgresVectorStoreCollectionSqlBuilder.cs | 4 ++-- .../PostgresVectorStoreCollectionSqlBuilder.cs | 2 +- .../PostgresVectorStoreCollectionSqlBuilderTests.cs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index 421b36a1ab05..8d8364eb903d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -127,8 +127,8 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// The vector to match. /// The filter conditions for the query. /// The number of records to skip. - /// Specifies whether to include embeddings in the result. + /// Specifies whether to include vectors in the result. /// The maximum number of records to return. /// The built SQL command info. - PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? filter, int? skip, bool withEmbeddings, int limit); + PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? filter, int? skip, bool includeVectors, int limit); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index b7ba658e8081..375dfe97e14f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -345,7 +345,7 @@ DELETE FROM {schema}.""{tableName}"" /// public PostgresSqlCommandInfo BuildGetNearestMatchCommand( string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, - VectorSearchFilter? filter, int? skip, bool withEmbeddings, int limit) + VectorSearchFilter? filter, int? skip, bool includeVectors, int limit) { var columns = string.Join(" ,", properties diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs index 6d2ab6d4bf6c..bfe15e025b8f 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -409,7 +409,7 @@ public void TestBuildGetNearestMatchCommand() vectorValue: vector, filter: null, skip: null, - withEmbeddings: true, + includeVectors: true, limit: 10); // Assert From 5b52bdc98b6b5e2d5bab1892d6fe529623f5b56f Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:50:33 -0400 Subject: [PATCH 26/62] Simplify AsyncEnumerable return --- .../Connectors.Memory.Postgres/PostgresVectorStore.cs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs index 5a7395c8d40b..f00f4c11ccf5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -57,12 +57,9 @@ public PostgresVectorStore(IPostgresVectorStoreDbClient postgresDbClient, Postgr } /// - public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) { - await foreach (string collection in this._postgresClient.GetTablesAsync(cancellationToken).ConfigureAwait(false)) - { - yield return collection; - } + return this._postgresClient.GetTablesAsync(cancellationToken); } /// From cd845ee5debfbc86e4bb3ddd3084cd6e160ccfdd Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 11:55:19 -0400 Subject: [PATCH 27/62] Pass properties instead of full definition --- .../IPostgresVectorStoreDbClient.cs | 4 ++-- .../Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs | 4 ++-- .../PostgresVectorStoreRecordCollection.cs | 2 +- .../PostgresVectorStoreRecordCollectionTests.cs | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs index 6da6ecc0ffc9..a2db2157b84e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -31,11 +31,11 @@ public interface IPostgresVectorStoreDbClient /// Create a table. /// /// The name assigned to a table of entries. - /// The record definition of the table. + /// The properties of the record definition that define the table. /// Specifies whether to include IF NOT EXISTS in the command. /// The to monitor for cancellation requests. The default is . /// - Task CreateTableAsync(string tableName, VectorStoreRecordDefinition recordDefinition, bool ifNotExists = true, CancellationToken cancellationToken = default); + Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default); /// /// Create a vector index. diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 1d8d913687e7..f13bd351f502 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -64,13 +64,13 @@ public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] Ca } /// - public async Task CreateTableAsync(string tableName, VectorStoreRecordDefinition recordDefinition, bool ifNotExists = true, CancellationToken cancellationToken = default) + public async Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default) { NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { - var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, recordDefinition.Properties, ifNotExists); + var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, properties, ifNotExists); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 9d54a9f98c98..282528d9b3ce 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -267,7 +267,7 @@ public Task> VectorizedSearchAsync(TVector private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken = default) { - await this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition, ifNotExists, cancellationToken).ConfigureAwait(false); + await this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition.Properties, ifNotExists, cancellationToken).ConfigureAwait(false); // Create indexes for vector properties. foreach (var vectorProperty in this._propertyReader.VectorProperties) { diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs index 48b7ce4d4a87..3ca268033aa0 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs @@ -159,7 +159,7 @@ public async Task CreateCollectionAsyncLogsWarningWhenDimensionsTooLargeAsync() options: new PostgresVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = recordDefinition } ); - this._postgresClientMock.Setup(x => x.CreateTableAsync(TestCollectionName, It.IsAny(), It.IsAny(), It.IsAny())).Returns(Task.CompletedTask); + this._postgresClientMock.Setup(x => x.CreateTableAsync(TestCollectionName, It.IsAny>(), It.IsAny(), It.IsAny())).Returns(Task.CompletedTask); // Act await sut.CreateCollectionAsync(cancellationToken: this._testCancellationToken); From 1d09a21ec25af333384fac174687cc0cb05c93c8 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 12:04:26 -0400 Subject: [PATCH 28/62] Throw instead of log for too high dimensionality --- .../PostgresVectorStoreRecordCollection.cs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 282528d9b3ce..b63879437f8f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -278,12 +278,9 @@ private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationT { if (vectorProperty.Dimensions > 2000) { - this._logger.LogWarning( - "The provided vector property {VectorPropertyName} has {Dimensions} dimensions, which is not supported by the HNSW index. The maximum number of dimensions supported by the HNSW index is 2000. Index not created.", - vectorProperty.DataModelPropertyName, - vectorProperty.Dimensions + throw new NotSupportedException( + $"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, which is not supported by the HNSW index. The maximum number of dimensions supported by the HNSW index is 2000. Index not created." ); - continue; } } await this._client.CreateVectorIndexAsync(this.CollectionName, vectorProperty, cancellationToken).ConfigureAwait(false); From 74e9757eccddc8efa7bc484524232c9522cfe613 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 12:04:50 -0400 Subject: [PATCH 29/62] Remove DefaultVectorSize --- .../Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs index c7959f950aaf..013f1810e146 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreOptions.cs @@ -7,11 +7,6 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// public sealed class PostgresVectorStoreOptions { - /// - /// Gets or sets the default vector size to use when creating a new vector. - /// - public int DefaultVectorSize { get; init; } = 100; - /// /// Gets or sets the database schema. /// From ad5628c8682fdfea05a51f2a18c8444e77be788c Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 12:05:02 -0400 Subject: [PATCH 30/62] Remove unused using statements --- .../Connectors.Memory.Postgres/PostgresVectorStore.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs index f00f4c11ccf5..63e48c2b3c4b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -2,9 +2,7 @@ using System; using System.Collections.Generic; -using System.Runtime.CompilerServices; using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.VectorData; using Npgsql; From dbf1aefdcbe4a9e60526535a529346aad8042d24 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 12:10:12 -0400 Subject: [PATCH 31/62] Remove VectorStore constructor that creates datsaource --- ...VectorStore_VectorSearch_MultiStore_Postgres.cs | 6 +++++- .../PostgresVectorStore.cs | 14 -------------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs index 045e5b2bb5e2..a15cc696e094 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Postgres.cs @@ -6,6 +6,7 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; namespace Memory; @@ -73,7 +74,10 @@ public async Task ExampleWithoutDIAsync() // Initialize the Postgres docker container via the fixtures and construct the Postgres VectorStore. await PostgresFixture.ManualInitializeAsync(); - var vectorStore = new PostgresVectorStore(ConnectionString); + var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString); + dataSourceBuilder.UseVector(); + await using var dataSource = dataSourceBuilder.Build(); + var vectorStore = new PostgresVectorStore(dataSource); // Create the common processor that works for any vector store. var processor = new VectorStore_VectorSearch_MultiStore_Common(vectorStore, textEmbeddingGenerationService, this.Output); diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs index 63e48c2b3c4b..94bc7efb24cc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -17,20 +17,6 @@ public class PostgresVectorStore : IVectorStore private readonly NpgsqlDataSource? _dataSource; private readonly PostgresVectorStoreOptions? _options; - /// - /// Initializes a new instance of the class. - /// - /// Postgres database connection string. - /// Optional configuration options for this class - public PostgresVectorStore(string connectionString, PostgresVectorStoreOptions? options = default) - { - NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); - dataSourceBuilder.UseVector(); - this._dataSource = dataSourceBuilder.Build(); - this._options = options ?? new PostgresVectorStoreOptions(); - this._postgresClient = new PostgresVectorStoreDbClient(this._dataSource, this._options.Schema); - } - /// /// Initializes a new instance of the class. /// From a355bf787c8a3e386b78f8a8c0a2feaa02eedca9 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 12:12:22 -0400 Subject: [PATCH 32/62] Fix duplicate mapper call --- .../PostgresVectorStoreRecordCollection.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index b63879437f8f..5d956a69e899 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -141,7 +141,7 @@ public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options Verify.NotNull(keyObj); TKey key = (TKey)keyObj!; - await this._client.UpsertAsync(this.CollectionName, this._mapper?.MapFromDataToStorageModel(record) ?? throw new InvalidOperationException("Failed to map record to storage model."), this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); + await this._client.UpsertAsync(this.CollectionName, storageModel, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); return key; } From e499a802269ce376ff41bc5e5dc5e6ae690631af Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 12:13:03 -0400 Subject: [PATCH 33/62] Fix docstring typo --- .../PostgresVectorStoreRecordCollection.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 5d956a69e899..c48f208d876c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -85,7 +85,7 @@ public PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, // Resolve mapper. // First, if someone has provided a custom mapper, use that. // If they didn't provide a custom mapper, and the record type is the generic data model, use the built in mapper for that. - // Otherwise, don't set the mapper, and we'll default to just using Azure AI Search's built in json serialization and deserialization. + // Otherwise, use our own default mapper implementation for all other data models. if (this._options.DictionaryCustomMapper is not null) { this._mapper = this._options.DictionaryCustomMapper; From c95e2b32fabf092364b2cbe8abe877be87bca582 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 12:15:46 -0400 Subject: [PATCH 34/62] Comment clarifying that multiple keys should be previously validated --- .../PostgresVectorStoreCollectionSqlBuilder.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index 375dfe97e14f..d3119b5fce0a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -64,6 +64,8 @@ public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tabl { if (keyProperty != null) { + // Should be impossible, as property reader should have already validated that + // multiple key properties are not allowed. throw new ArgumentException("Record definition cannot have more than one key property."); } keyProperty = keyProp; From 9d972b3fe90827d9134bea41ab2569f2e23e81c5 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 12:34:25 -0400 Subject: [PATCH 35/62] Refactor ExecuteNonQueryAsync calls to reduce code dupe --- .../PostgresVectorStoreDbClient.cs | 81 ++++++------------- 1 file changed, 25 insertions(+), 56 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index f13bd351f502..3f0761ebd6cc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -66,14 +66,8 @@ public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] Ca /// public async Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - await using (connection) - { - var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, properties, ifNotExists); - using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, properties, ifNotExists); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// @@ -84,53 +78,29 @@ public async Task CreateVectorIndexAsync(string tableName, VectorStoreRecordVect return; } - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - await using (connection) - { - var commandInfo = this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, vectorProperty); - using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + var commandInfo = this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, vectorProperty); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// public async Task DeleteTableAsync(string tableName, CancellationToken cancellationToken = default) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - await using (connection) - { - var commandInfo = this._sqlBuilder.BuildDropTableCommand(this._schema, tableName); - using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + var commandInfo = this._sqlBuilder.BuildDropTableCommand(this._schema, tableName); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// public async Task UpsertAsync(string tableName, Dictionary row, string keyColumn, CancellationToken cancellationToken = default) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - await using (connection) - { - var commandInfo = this._sqlBuilder.BuildUpsertCommand(this._schema, tableName, keyColumn, row); - using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + var commandInfo = this._sqlBuilder.BuildUpsertCommand(this._schema, tableName, keyColumn, row); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// public async Task UpsertBatchAsync(string tableName, IEnumerable> rows, string keyColumn, CancellationToken cancellationToken = default) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - await using (connection) - { - var commandInfo = this._sqlBuilder.BuildUpsertBatchCommand(this._schema, tableName, keyColumn, rows.ToList()); - using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + var commandInfo = this._sqlBuilder.BuildUpsertBatchCommand(this._schema, tableName, keyColumn, rows.ToList()); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// @@ -173,14 +143,8 @@ public async Task UpsertBatchAsync(string tableName, IEnumerable public async Task DeleteAsync(string tableName, string keyColumn, TKey key, CancellationToken cancellationToken = default) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - await using (connection) - { - var commandInfo = this._sqlBuilder.BuildDeleteCommand(this._schema, tableName, keyColumn, key); - using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + var commandInfo = this._sqlBuilder.BuildDeleteCommand(this._schema, tableName, keyColumn, key); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } /// @@ -206,14 +170,8 @@ public async Task DeleteAsync(string tableName, string keyColumn, TKey key /// public async Task DeleteBatchAsync(string tableName, string keyColumn, IEnumerable keys, CancellationToken cancellationToken = default) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); - - await using (connection) - { - var commandInfo = this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, keys.ToList()); - using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); - await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + var commandInfo = this._sqlBuilder.BuildDeleteBatchCommand(this._schema, tableName, keyColumn, keys.ToList()); + await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); } #region internal =============================================================================== @@ -256,5 +214,16 @@ internal void SetSqlBuilder(IPostgresVectorStoreCollectionSqlBuilder sqlBuilder) return storageModel; } + private async Task ExecuteNonQueryAsync(PostgresSqlCommandInfo commandInfo, CancellationToken cancellationToken) + { + NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) + { + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + #endregion } From 6eb3793cd565e31f48128ec24f3ea6802a6b8472 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 12:40:28 -0400 Subject: [PATCH 36/62] Forward Schema option. Also _options shouldn't be nullable, and use PostgresConstants.SupportedKeyTypes for key type check. --- .../PostgresVectorStore.cs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs index 94bc7efb24cc..3224b5bccb9e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -15,7 +15,7 @@ public class PostgresVectorStore : IVectorStore { private readonly IPostgresVectorStoreDbClient _postgresClient; private readonly NpgsqlDataSource? _dataSource; - private readonly PostgresVectorStoreOptions? _options; + private readonly PostgresVectorStoreOptions _options; /// /// Initializes a new instance of the class. @@ -50,17 +50,12 @@ public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cance public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull { - // Support short, int, long, Guid, and string keys - if (typeof(TKey) != typeof(short) && - typeof(TKey) != typeof(int) && - typeof(TKey) != typeof(long) && - typeof(TKey) != typeof(Guid) && - typeof(TKey) != typeof(string)) + if (!PostgresConstants.SupportedKeyTypes.Contains(typeof(TKey))) { - throw new NotSupportedException($"Only short, int, long, {nameof(Guid)}, and {nameof(String)} keys are supported."); + throw new NotSupportedException($"Unsupported key type: {typeof(TKey)}"); } - if (this._options?.VectorStoreCollectionFactory is not null) + if (this._options.VectorStoreCollectionFactory is not null) { return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._postgresClient, name, vectorStoreRecordDefinition); } @@ -68,7 +63,7 @@ public IVectorStoreRecordCollection GetCollection( var recordCollection = new PostgresVectorStoreRecordCollection( this._postgresClient, name, - new PostgresVectorStoreRecordCollectionOptions() { VectorStoreRecordDefinition = vectorStoreRecordDefinition } + new PostgresVectorStoreRecordCollectionOptions() { Schema = this._options.Schema, VectorStoreRecordDefinition = vectorStoreRecordDefinition } ); return recordCollection as IVectorStoreRecordCollection ?? throw new InvalidOperationException("Failed to cast record collection."); From ed59fedec49b545037bfbcdb7c745bf0f5bd26be Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 16:03:44 -0400 Subject: [PATCH 37/62] Make PostgresVectorStoreDbClient internal --- .../IPostgresVectorStoreDbClient.cs | 8 +++++++- ...tgresVectorStoreRecordCollectionFactory.cs | 5 +++-- .../PostgresVectorStore.cs | 4 ++-- .../PostgresVectorStoreDbClient.cs | 4 +++- .../PostgresVectorStoreRecordCollection.cs | 20 +++++++++++++++++-- .../PostgresVectorStoreTests.cs | 4 +++- 6 files changed, 36 insertions(+), 9 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs index a2db2157b84e..7f047d0d9f07 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -4,6 +4,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; +using Npgsql; using Pgvector; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -11,8 +12,13 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// /// Internal interface for client managing postgres database operations. /// -public interface IPostgresVectorStoreDbClient +internal interface IPostgresVectorStoreDbClient { + /// + /// The used to connect to the database. + /// + public NpgsqlDataSource DataSource { get; } + /// /// Check if a table exists. /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs index 98b1a344c194..5bf0d9cad789 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreRecordCollectionFactory.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using Microsoft.Extensions.VectorData; +using Npgsql; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -14,10 +15,10 @@ public interface IPostgresVectorStoreRecordCollectionFactory /// /// The data type of the record key. /// The data model to use for adding, updating and retrieving data from storage. - /// The Postgres client. + /// The Postgres data source. /// The name of the collection to connect to. /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . - IVectorStoreRecordCollection CreateVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + IVectorStoreRecordCollection CreateVectorStoreRecordCollection(NpgsqlDataSource dataSource, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) where TKey : notnull; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs index 3224b5bccb9e..d44b7e44d690 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -34,7 +34,7 @@ public PostgresVectorStore(NpgsqlDataSource dataSource, PostgresVectorStoreOptio /// /// An instance of . /// Optional configuration options for this class - public PostgresVectorStore(IPostgresVectorStoreDbClient postgresDbClient, PostgresVectorStoreOptions? options = default) + internal PostgresVectorStore(IPostgresVectorStoreDbClient postgresDbClient, PostgresVectorStoreOptions? options = default) { this._postgresClient = postgresDbClient; this._options = options ?? new PostgresVectorStoreOptions(); @@ -57,7 +57,7 @@ public IVectorStoreRecordCollection GetCollection( if (this._options.VectorStoreCollectionFactory is not null) { - return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._postgresClient, name, vectorStoreRecordDefinition); + return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._postgresClient.DataSource, name, vectorStoreRecordDefinition); } var recordCollection = new PostgresVectorStoreRecordCollection( diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 3f0761ebd6cc..095ab2d6aaa3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -20,13 +20,15 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; /// Postgres data source. /// Schema of collection tables. [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] -public class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema) : IPostgresVectorStoreDbClient +internal class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema) : IPostgresVectorStoreDbClient { private readonly NpgsqlDataSource _dataSource = dataSource; private readonly string _schema = schema; private IPostgresVectorStoreCollectionSqlBuilder _sqlBuilder = new PostgresVectorStoreCollectionSqlBuilder(); + public NpgsqlDataSource DataSource => this._dataSource; + /// public async Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index c48f208d876c..5efa5ce771de 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -9,6 +9,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.VectorData; +using Npgsql; namespace Microsoft.SemanticKernel.Connectors.Postgres; @@ -49,11 +50,26 @@ public sealed class PostgresVectorStoreRecordCollection : IVector /// /// Initializes a new instance of the class. /// - /// The Postgres client used to interact with the database. + /// The data source to use for connecting to the database. /// The name of the collection. /// Optional configuration options for this class. /// The logger to use for logging. - public PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default, + public PostgresVectorStoreRecordCollection(NpgsqlDataSource dataSource, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default, + ILogger>? logger = null) : this(new PostgresVectorStoreDbClient(dataSource), collectionName, options, logger) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The client to use for interacting with the database. + /// The name of the collection. + /// Optional configuration options for this class. + /// The logger to use for logging. + /// + /// This constructor is internal. It allows internal code to create an instance of this class with a custom client. + /// + internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default, ILogger>? logger = null) { // Verify. diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs index 3b3d407fc035..7feecff44f55 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs @@ -8,6 +8,7 @@ using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Postgres; using Moq; +using Npgsql; using Xunit; namespace SemanticKernel.Connectors.Postgres.UnitTests; @@ -58,8 +59,9 @@ public void GetCollectionCallsFactoryIfProvided() var factoryMock = new Mock(MockBehavior.Strict); var collectionMock = new Mock>>(MockBehavior.Strict); var clientMock = new Mock(MockBehavior.Strict); + clientMock.Setup(x => x.DataSource).Returns(null); factoryMock - .Setup(x => x.CreateVectorStoreRecordCollection>(clientMock.Object, TestCollectionName, null)) + .Setup(x => x.CreateVectorStoreRecordCollection>(It.IsAny(), TestCollectionName, null)) .Returns(collectionMock.Object); var sut = new PostgresVectorStore(clientMock.Object, new() { VectorStoreCollectionFactory = factoryMock.Object }); From 1749adbca392223e4cb0df882810fa9a756fcb45 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 25 Oct 2024 16:04:48 -0400 Subject: [PATCH 38/62] Support more enumerable types --- .../PostgresConstants.cs | 25 ++-- .../PostgresGenericDataModelMapper.cs | 2 +- .../PostgresServiceCollectionExtensions.cs | 113 +++--------------- ...PostgresVectorStoreCollectionSqlBuilder.cs | 6 +- .../PostgresVectorStoreRecordCollection.cs | 2 +- .../PostgresVectorStoreRecordMapper.cs | 7 +- ...ostgresVectorStoreRecordPropertyMapping.cs | 68 ++++++++--- ...ostgresVectorStoreRecordCollectionTests.cs | 80 +++++++++++++ 8 files changed, 175 insertions(+), 128 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index 6fa2d5a468d5..890beabd546b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -43,15 +43,22 @@ internal static class PostgresConstants typeof(Guid), typeof(Guid?), typeof(byte[]), - typeof(List), - typeof(List), - typeof(List), - typeof(List), - typeof(List), - typeof(List), - typeof(List), - typeof(List), - typeof(List), + ]; + + /// A of types that enumerable data properties on the provided model may use as their element types. + public static readonly HashSet SupportedEnumerableDataElementTypes = + [ + typeof(bool), + typeof(short), + typeof(int), + typeof(long), + typeof(float), + typeof(double), + typeof(decimal), + typeof(string), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(Guid), ]; /// A of types that vector properties on the provided model may have. diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs index eb5b857185a4..9679bb22f44e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresGenericDataModelMapper.cs @@ -22,7 +22,7 @@ public PostgresGenericDataModelMapper(VectorStoreRecordPropertyReader propertyRe this._propertyReader = propertyReader; // Validate property types. - this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, supportEnumerable: false); + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); } public Dictionary MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs index aa7274aae249..8e46d21beb12 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs @@ -13,83 +13,7 @@ namespace Microsoft.SemanticKernel; public static class PostgresServiceCollectionExtensions { /// - /// Register a with the specified service ID and where NpgsqlDataSource is constructed using the provided parameters. - /// - /// The to register the on. - /// Postgres database connection string. - /// The schema to use. - /// An optional service id to use as the service key. - /// The service collection. - public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, string connectionString, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) - { - string? npgsqlServiceId = serviceId == null ? default : $"{serviceId}_NpgsqlDataSource"; - // Register NpgsqlDataSource to ensure proper disposal. - services.AddKeyedSingleton( - npgsqlServiceId, - (sp, obj) => - { - NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionString); - dataSourceBuilder.UseVector(); - return dataSourceBuilder.Build(); - }); - - services.AddKeyedSingleton( - serviceId, - (sp, obj) => - { - var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); - return new PostgresVectorStoreDbClient(dataSource, schema); - }); - - return services; - } - - /// - /// Register a with the specified service ID and where NpgsqlDataSource is passed in as parameter. - /// - /// The to register the on. - /// The data source to use. - /// The schema to use. - /// An optional service id to use as the service key. - /// The service collection. - public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) - { - // Since we are not constructing the data source, add the IVectorStore as transient, since we - // cannot make assumptions about how client is being managed. - services.AddKeyedTransient( - serviceId, - (sp, obj) => - { - return new PostgresVectorStoreDbClient(dataSource, schema); - }); - - return services; - } - - /// - /// Register a with the specified service ID and where the NpgsqlDataSource is retrieved from the dependency injection container. - /// - /// The to register the on. - /// The schema to use. - /// An optional service id to use as the service key. - /// The service collection. - public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCollection services, string schema = PostgresConstants.DefaultSchema, string? serviceId = default) - { - // Since we are not constructing the client, add the IVectorStore as transient, since we - // cannot make assumptions about how client is being managed. - services.AddKeyedTransient( - serviceId, - (sp, obj) => - { - var dataSource = sp.GetRequiredService(); - return new PostgresVectorStoreDbClient(dataSource, schema); - }); - - return services; - } - - /// - /// Register a Postgres with the specified service ID and where is retrieved from the dependency injection container. + /// Register a Postgres with the specified service ID and where the NpgsqlDataSource is retrieved from the dependency injection container. /// /// The to register the on. /// Optional options to further configure the . @@ -97,17 +21,17 @@ public static IServiceCollection AddPostgresVectorStoreDbClient(this IServiceCol /// The service collection. public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, PostgresVectorStoreOptions? options = default, string? serviceId = default) { - // Since we are not constructing the client, add the IVectorStore as transient, since we - // cannot make assumptions about how client is being managed. + // Since we are not constructing the data source, add the IVectorStore as transient, since we + // cannot make assumptions about how data source is being managed. services.AddKeyedTransient( serviceId, (sp, obj) => { - var client = sp.GetRequiredService(); + var dataSource = sp.GetRequiredService(); var selectedOptions = options ?? sp.GetService(); return new PostgresVectorStore( - client, + dataSource, selectedOptions); }); @@ -115,7 +39,7 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection } /// - /// Register a Postgres with the specified service ID and where is constructed using the provided parameters. + /// Register a Postgres with the specified service ID and where an NpgsqlDataSource is constructed using the provided parameters. /// /// The to register the on. /// Postgres database connection string. @@ -140,11 +64,10 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection (sp, obj) => { var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); - var client = new PostgresVectorStoreDbClient(dataSource); var selectedOptions = options ?? sp.GetService(); return new PostgresVectorStore( - client, + dataSource, selectedOptions); }); @@ -152,7 +75,7 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection } /// - /// Register a Postgres with the specified service ID and where is constructed using the NpgsqlDataSource. + /// Register a Postgres with the specified service ID and where an NpgsqlDataSource is passed in. /// /// The to register the on. /// The data source to use. @@ -167,11 +90,10 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection serviceId, (sp, obj) => { - var client = new PostgresVectorStoreDbClient(dataSource); var selectedOptions = options ?? sp.GetService(); return new PostgresVectorStore( - client, + dataSource, selectedOptions); }); @@ -180,7 +102,7 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection /// /// Register a Postgres and with the specified service ID - /// and where the Postgres is retrieved from the dependency injection container. + /// and where the NpgsqlDataSource is retrieved from the dependency injection container. /// /// The type of the key. /// The type of the record. @@ -200,10 +122,10 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection { - var PostgresClient = sp.GetRequiredService(); + var dataSource = sp.GetRequiredService(); var selectedOptions = options ?? sp.GetService>(); - return (new PostgresVectorStoreRecordCollection(PostgresClient, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; }); AddVectorizedSearch(services, serviceId); @@ -213,7 +135,7 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection /// Register a Postgres and with the specified service ID - /// and where the Postgres is constructed using the provided parameters. + /// and where the NpgsqlDataSource is constructed using the provided parameters. /// /// The type of the key. /// The type of the record. @@ -247,10 +169,8 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection { var dataSource = sp.GetRequiredKeyedService(npgsqlServiceId); - var client = new PostgresVectorStoreDbClient(dataSource); - var selectedOptions = options ?? sp.GetService>(); - return (new PostgresVectorStoreRecordCollection(client, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, options) as IVectorStoreRecordCollection)!; }); AddVectorizedSearch(services, serviceId); @@ -260,7 +180,7 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection /// Register a Postgres and with the specified service ID - /// and where the Postgres is constructed using the data source. + /// and where the NpgsqlDataSource is passed in. /// /// The type of the key. /// The type of the record. @@ -284,10 +204,9 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection { - var client = new PostgresVectorStoreDbClient(dataSource); var selectedOptions = options ?? sp.GetService>(); - return (new PostgresVectorStoreRecordCollection(client, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; + return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; }); AddVectorizedSearch(services, serviceId); diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index d3119b5fce0a..fda9cac233af 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -174,11 +174,11 @@ public PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName ON CONFLICT (""{keyColumn}"") DO UPDATE SET {updateColumnsWithParams};"; - var parameters = row.ToDictionary(kvp => $"@{kvp.Key}", kvp => kvp.Value); - return new PostgresSqlCommandInfo(commandText) { - Parameters = columns.Select(c => new NpgsqlParameter() { Value = row[c] ?? DBNull.Value }).ToList() + Parameters = columns.Select(c => + PostgresVectorStoreRecordPropertyMapping.GetNpgsqlParameter(row[c]) + ).ToList() }; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 5efa5ce771de..d01b5a63f678 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -95,7 +95,7 @@ internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client // Validate property types. this._propertyReader.VerifyKeyProperties(PostgresConstants.SupportedKeyTypes); - this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, supportEnumerable: true); + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); // Resolve mapper. diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs index 2bf87e11b645..e656678413cc 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordMapper.cs @@ -28,7 +28,7 @@ public PostgresVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyR this._propertyReader.VerifyHasParameterlessConstructor(); // Validate property types. - this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, supportEnumerable: false); + this._propertyReader.VerifyDataProperties(PostgresConstants.SupportedDataTypes, PostgresConstants.SupportedEnumerableDataElementTypes); this._propertyReader.VerifyVectorProperties(PostgresConstants.SupportedVectorTypes); } @@ -43,7 +43,10 @@ public PostgresVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyR // Add data properties foreach (var property in this._propertyReader.DataPropertiesInfo) { - properties.Add(this._propertyReader.GetStoragePropertyName(property.Name), property.GetValue(dataModel)); + properties.Add( + this._propertyReader.GetStoragePropertyName(property.Name), + property.GetValue(dataModel) + ); } // Add vector properties diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index 008a64cca73e..235f3047e52a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -3,6 +3,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Linq; using System.Runtime.InteropServices; using Microsoft.Extensions.VectorData; using Npgsql; @@ -67,13 +68,11 @@ internal static float[] GetOrCreateArray(ReadOnlyMemory memory) => return null; } - // Check if the type is a List - if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) + // Check if the type implements IEnumerable + if (propertyType.IsGenericType && propertyType.GetInterfaces().Any(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))) { - var elementType = propertyType.GetGenericArguments()[0]; - var list = (IEnumerable)reader.GetValue(propertyIndex); - // Convert list to the correct element type - return ConvertList(list, elementType); + var enumerable = (IEnumerable)reader.GetValue(propertyIndex); + return VectorStoreRecordMapping.CreateEnumerable(enumerable.Cast(), propertyType); } return propertyType switch @@ -141,8 +140,8 @@ public static (string PgType, bool IsNullable) GetPostgresTypeName(Type property return (pgType, isNullable); } - // Handle lists - if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) + // Handle enumerables + if (VectorStoreRecordPropertyVerification.IsSupportedEnumerableType(propertyType)) { Type elementType = propertyType.GetGenericArguments()[0]; var underlyingPgType = GetPostgresTypeName(elementType); @@ -175,17 +174,56 @@ public static (string PgType, bool IsNullable) GetPgVectorTypeName(VectorStoreRe return ($"VECTOR({vectorProperty.Dimensions})", Nullable.GetUnderlyingType(vectorProperty.PropertyType) != null); } - // Helper method to convert lists - private static object ConvertList(IEnumerable list, Type elementType) + public static NpgsqlParameter GetNpgsqlParameter(object? value) { - var listType = typeof(List<>).MakeGenericType(elementType); - var convertedList = (IList)Activator.CreateInstance(listType)!; + if (value == null) + { + return new NpgsqlParameter() { Value = DBNull.Value }; + } + + // If it's already a List, return it directly + if (value is IList list) + { + return new NpgsqlParameter() { Value = list }; + } + + // If it's an IEnumerable, but not a List, convert it to a List + if (value is IEnumerable enumerable && !(value is string)) + { + // Use a helper method to convert to a List if possible + return new NpgsqlParameter() { Value = ConvertToListIfNecessary(enumerable) }; + } + + // Return the value directly if it's not IEnumerable + return new NpgsqlParameter() { Value = value }; + } + + // Helper method to convert an IEnumerable to a List if necessary + private static object ConvertToListIfNecessary(IEnumerable enumerable) + { + // Get an enumerator to manually iterate over the collection + var enumerator = enumerable.GetEnumerator(); + + // Check if the collection is empty by attempting to move to the first element + if (!enumerator.MoveNext()) + { + return enumerable; // Return the original enumerable if it's empty + } + + // Determine the type of the first element + var firstItem = enumerator.Current; + var itemType = firstItem?.GetType() ?? typeof(object); + + // Create a strongly-typed List based on the type of the first element + var typedList = Activator.CreateInstance(typeof(List<>).MakeGenericType(itemType)) as IList; + typedList!.Add(firstItem); // Add the first element to the typed list - foreach (var item in list) + // Continue iterating through the rest of the enumerable and add items to the list + while (enumerator.MoveNext()) { - convertedList.Add(Convert.ChangeType(item, elementType)); + typedList.Add(enumerator.Current); } - return convertedList; + return typedList; } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index df45044aff54..76eec311c584 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -394,6 +394,60 @@ public async Task VectorizedSearchWithAnyTagFilterReturnsValidResultsAsync() Assert.Equal([1, 3], ids); } + [Fact] + public async Task ItCanUpsertAndGetEnumerableTypesAsync() + { + // Arrange + var sut = fixture.GetCollection("UpsertAndGetEnumerableTypes"); + + await sut.CreateCollectionAsync(); + + var record = new RecordWithEnumerables + { + Id = 1, + ListInts = new() { 1, 2, 3 }, + CollectionInts = new HashSet() { 4, 5, 6 }, + EnumerableInts = [7, 8, 9], + ReadOnlyCollectionInts = new List { 10, 11, 12 }, + ReadOnlyListInts = new List { 13, 14, 15 } + }; + + // Act + await sut.UpsertAsync(record); + + var getResult = await sut.GetAsync(1); + + // Assert + Assert.NotNull(getResult); + Assert.Equal(1, getResult!.Id); + Assert.NotNull(getResult.ListInts); + Assert.Equal(3, getResult.ListInts!.Count); + Assert.Equal(1, getResult.ListInts![0]); + Assert.Equal(2, getResult.ListInts![1]); + Assert.Equal(3, getResult.ListInts![2]); + Assert.NotNull(getResult.CollectionInts); + Assert.Equal(3, getResult.CollectionInts!.Count); + Assert.Contains(4, getResult.CollectionInts); + Assert.Contains(5, getResult.CollectionInts); + Assert.Contains(6, getResult.CollectionInts); + Assert.NotNull(getResult.EnumerableInts); + Assert.Equal(3, getResult.EnumerableInts!.Count()); + Assert.Equal(7, getResult.EnumerableInts.ElementAt(0)); + Assert.Equal(8, getResult.EnumerableInts.ElementAt(1)); + Assert.Equal(9, getResult.EnumerableInts.ElementAt(2)); + Assert.NotNull(getResult.ReadOnlyCollectionInts); + Assert.Equal(3, getResult.ReadOnlyCollectionInts!.Count); + var readOnlyCollectionIntsList = getResult.ReadOnlyCollectionInts.ToList(); + Assert.Equal(10, readOnlyCollectionIntsList[0]); + Assert.Equal(11, readOnlyCollectionIntsList[1]); + Assert.Equal(12, readOnlyCollectionIntsList[2]); + Assert.NotNull(getResult.ReadOnlyListInts); + Assert.Equal(3, getResult.ReadOnlyListInts!.Count); + Assert.Equal(13, getResult.ReadOnlyListInts[0]); + Assert.Equal(14, getResult.ReadOnlyListInts[1]); + Assert.Equal(15, getResult.ReadOnlyListInts[2]); + } + #region private ================================================================================== private static VectorStoreRecordDefinition GetVectorStoreRecordDefinition(string distanceFunction = DistanceFunction.CosineDistance) => new() @@ -440,6 +494,32 @@ private static DateTimeOffset TruncateMilliseconds(DateTimeOffset dateTimeOffset return new DateTimeOffset(dateTimeOffset.Ticks - (dateTimeOffset.Ticks % TimeSpan.TicksPerSecond), dateTimeOffset.Offset); } +#pragma warning disable CA1812 + private sealed class RecordWithEnumerables + { + [VectorStoreRecordKey] + public int Id { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] + public ReadOnlyMemory? Embedding { get; set; } + + [VectorStoreRecordData] + public List? ListInts { get; set; } + + [VectorStoreRecordData] + public ICollection? CollectionInts { get; set; } + + [VectorStoreRecordData] + public IEnumerable? EnumerableInts { get; set; } + + [VectorStoreRecordData] + public IReadOnlyCollection? ReadOnlyCollectionInts { get; set; } + + [VectorStoreRecordData] + public IReadOnlyList? ReadOnlyListInts { get; set; } + } +#pragma warning restore CA1812 + #endregion } From 86486d71d167000788271d2326c91ba93317fe40 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 28 Oct 2024 13:02:09 -0400 Subject: [PATCH 39/62] Refactor to support default + transactions Defaults to Hnsw for index creation, but user can specify IndexKind.None to avoid index creation. Also wrap the table and index creation in a transaction so the entire operation is atomic. --- ...PostgresVectorStoreCollectionSqlBuilder.cs | 6 ++- .../IPostgresVectorStoreDbClient.cs | 10 +--- .../PostgresConstants.cs | 5 ++ .../PostgresSqlCommandInfo.cs | 6 ++- ...PostgresVectorStoreCollectionSqlBuilder.cs | 11 ++--- .../PostgresVectorStoreDbClient.cs | 49 ++++++++++++------- .../PostgresVectorStoreRecordCollection.cs | 17 ------- ...ostgresVectorStoreRecordPropertyMapping.cs | 39 +++++++++++++++ ...resVectorStoreCollectionSqlBuilderTests.cs | 16 ++---- ...esVectorStoreRecordPropertyMappingTests.cs | 46 +++++++++++++++++ .../RecordDefinition/IndexKind.cs | 5 ++ 11 files changed, 146 insertions(+), 64 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index 8d8364eb903d..d130d2f13b44 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -43,9 +43,11 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// /// The schema of the table. /// The name of the table. - /// The vector property to create an index for. + /// The name of the vector column. + /// The kind of index to create. + /// The distance function to use for the index. /// The built SQL command info. - PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, VectorStoreRecordVectorProperty vectorProperty); + PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction); /// /// Builds a SQL command to drop a table in the Postgres vector store. diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs index 7f047d0d9f07..59aa9829c568 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -34,7 +34,7 @@ internal interface IPostgresVectorStoreDbClient /// A group of tables. IAsyncEnumerable GetTablesAsync(CancellationToken cancellationToken = default); /// - /// Create a table. + /// Create a table. Also creates an index on vector columns if the table has vector properties defined. /// /// The name assigned to a table of entries. /// The properties of the record definition that define the table. @@ -43,14 +43,6 @@ internal interface IPostgresVectorStoreDbClient /// Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default); - /// - /// Create a vector index. - /// - /// The name assigned to a table of entries. - /// The vector property to create an index for. - /// The to monitor for cancellation requests. The default is . - Task CreateVectorIndexAsync(string tableName, VectorStoreRecordVectorProperty vectorProperty, CancellationToken cancellationToken = default); - /// /// Drop a table. /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index 890beabd546b..7f76ccef5857 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -80,4 +80,9 @@ internal static class PostgresConstants /// The default distance function. public const string DefaultDistanceFunction = DistanceFunction.CosineDistance; + + public static readonly Dictionary IndexMaxDimensions = new() + { + { IndexKind.Hnsw, 2000 }, + }; } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs index 68380d37ca2a..fb520188b84b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresSqlCommandInfo.cs @@ -35,9 +35,13 @@ public PostgresSqlCommandInfo(string commandText, List? paramet /// Converts this instance to an . /// [SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "User input is passed using command parameters.")] - public NpgsqlCommand ToNpgsqlCommand(NpgsqlConnection connection) + public NpgsqlCommand ToNpgsqlCommand(NpgsqlConnection connection, NpgsqlTransaction? transaction = null) { NpgsqlCommand cmd = connection.CreateCommand(); + if (transaction != null) + { + cmd.Transaction = transaction; + } cmd.CommandText = this.CommandText; if (this.Parameters != null) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index fda9cac233af..eda009b6a1ae 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -122,17 +122,16 @@ public PostgresSqlCommandInfo BuildCreateTableCommand(string schema, string tabl } /// - public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, VectorStoreRecordVectorProperty vectorProperty) + public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, string tableName, string vectorColumnName, string indexKind, string distanceFunction) { - var vectorColumnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; // Only support creating HNSW index creation through the connector. - var indexTypeName = vectorProperty.IndexKind switch + var indexTypeName = indexKind switch { IndexKind.Hnsw => "hnsw", - _ => throw new NotSupportedException($"Index kind '{vectorProperty.IndexKind}' is not supported for table creation. If you need to create an index of this type, please do so manually. Only HNSW indexes are supported through the vector store.") + _ => throw new NotSupportedException($"Index kind '{indexKind}' is not supported for table creation. If you need to create an index of this type, please do so manually. Only HNSW indexes are supported through the vector store.") }; - var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; + distanceFunction ??= PostgresConstants.DefaultDistanceFunction; // Default to Cosine distance var indexOps = distanceFunction switch { @@ -141,7 +140,7 @@ public PostgresSqlCommandInfo BuildCreateVectorIndexCommand(string schema, strin DistanceFunction.DotProductSimilarity => "vector_ip_ops", DistanceFunction.EuclideanDistance => "vector_l2_ops", DistanceFunction.ManhattanDistance => "vector_l1_ops", - _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") + _ => throw new NotSupportedException($"Distance function {distanceFunction} is not supported.") }; var indexName = $"{tableName}_{vectorColumnName}_index"; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 095ab2d6aaa3..a183a9f32de4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -22,17 +22,16 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; [System.Diagnostics.CodeAnalysis.SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "We need to build the full table name using schema and collection, it does not support parameterized passing.")] internal class PostgresVectorStoreDbClient(NpgsqlDataSource dataSource, string schema = PostgresConstants.DefaultSchema) : IPostgresVectorStoreDbClient { - private readonly NpgsqlDataSource _dataSource = dataSource; private readonly string _schema = schema; private IPostgresVectorStoreCollectionSqlBuilder _sqlBuilder = new PostgresVectorStoreCollectionSqlBuilder(); - public NpgsqlDataSource DataSource => this._dataSource; + public NpgsqlDataSource DataSource { get; } = dataSource; /// public async Task DoesTableExistsAsync(string tableName, CancellationToken cancellationToken = default) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { @@ -51,7 +50,7 @@ public async Task DoesTableExistsAsync(string tableName, CancellationToken /// public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { @@ -68,20 +67,34 @@ public async IAsyncEnumerable GetTablesAsync([EnumeratorCancellation] Ca /// public async Task CreateTableAsync(string tableName, IReadOnlyList properties, bool ifNotExists = true, CancellationToken cancellationToken = default) { + // Prepare the SQL commands. var commandInfo = this._sqlBuilder.BuildCreateTableCommand(this._schema, tableName, properties, ifNotExists); - await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); - } + var createIndexCommands = + PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo(properties) + .Select(index => + this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, index.column, index.kind, index.function) + ); - /// - public async Task CreateVectorIndexAsync(string tableName, VectorStoreRecordVectorProperty vectorProperty, CancellationToken cancellationToken = default) - { - if (string.IsNullOrEmpty(vectorProperty.IndexKind)) + // Execute the commands in a transaction. + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + + await using (connection) { - return; - } + var transaction = await connection.BeginTransactionAsync(cancellationToken).ConfigureAwait(false); + await using (transaction) + { + using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection, transaction); + await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - var commandInfo = this._sqlBuilder.BuildCreateVectorIndexCommand(this._schema, tableName, vectorProperty); - await this.ExecuteNonQueryAsync(commandInfo, cancellationToken).ConfigureAwait(false); + foreach (var createIndexCommand in createIndexCommands) + { + using NpgsqlCommand indexCmd = createIndexCommand.ToNpgsqlCommand(connection, transaction); + await indexCmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + + await transaction.CommitAsync(cancellationToken).ConfigureAwait(false); + } + } } /// @@ -108,7 +121,7 @@ public async Task UpsertBatchAsync(string tableName, IEnumerable public async Task?> GetAsync(string tableName, TKey key, IReadOnlyList properties, bool includeVectors = false, CancellationToken cancellationToken = default) where TKey : notnull { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { @@ -128,7 +141,7 @@ public async Task UpsertBatchAsync(string tableName, IEnumerable> GetBatchAsync(string tableName, IEnumerable keys, IReadOnlyList properties, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) where TKey : notnull { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { @@ -154,7 +167,7 @@ public async Task DeleteAsync(string tableName, string keyColumn, TKey key string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { @@ -218,7 +231,7 @@ internal void SetSqlBuilder(IPostgresVectorStoreCollectionSqlBuilder sqlBuilder) private async Task ExecuteNonQueryAsync(PostgresSqlCommandInfo commandInfo, CancellationToken cancellationToken) { - NpgsqlConnection connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index d01b5a63f678..dbf55a58130b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -284,23 +284,6 @@ public Task> VectorizedSearchAsync(TVector private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken = default) { await this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition.Properties, ifNotExists, cancellationToken).ConfigureAwait(false); - // Create indexes for vector properties. - foreach (var vectorProperty in this._propertyReader.VectorProperties) - { - var indexKind = vectorProperty.IndexKind ?? PostgresConstants.DefaultIndexKind; - - // Ensure the dimensionality of the vector is supported for indexing. - if (indexKind == IndexKind.Hnsw) - { - if (vectorProperty.Dimensions > 2000) - { - throw new NotSupportedException( - $"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, which is not supported by the HNSW index. The maximum number of dimensions supported by the HNSW index is 2000. Index not created." - ); - } - } - await this._client.CreateVectorIndexAsync(this.CollectionName, vectorProperty, cancellationToken).ConfigureAwait(false); - } } /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index 235f3047e52a..6f8e96ca2d1d 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -198,6 +198,45 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) return new NpgsqlParameter() { Value = value }; } + /// + /// Returns information about vector indexes to create, validating that the dimensions of the vector are supported. + /// + /// The properties of the vector store record. + /// A list of tuples containing the column name, index kind, and distance function for each vector property. + /// + /// The user can specify an index kind of "None" to prevent the creation of an index. Otherwise the default index kind is Hnsw. + /// + public static List<(string column, string kind, string function)> GetVectorIndexInfo(IReadOnlyList properties) + { + var vectorIndexesToCreate = new List<(string column, string kind, string function)>(); + foreach (var property in properties) + { + if (property is VectorStoreRecordVectorProperty vectorProperty) + { + var vectorColumnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + var indexKind = vectorProperty.IndexKind ?? PostgresConstants.DefaultIndexKind; // Default to Hnsw + var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; + + // The user can specify an index kind of "None" to prevent the creation of an index. + if (indexKind != IndexKind.None) + { + // Ensure the dimensionality of the vector is supported for indexing. + if (PostgresConstants.IndexMaxDimensions.TryGetValue(indexKind, out int maxDimensions) && vectorProperty.Dimensions > maxDimensions) + { + throw new NotSupportedException( + $"The provided vector property {vectorProperty.DataModelPropertyName} has {vectorProperty.Dimensions} dimensions, " + + $"which is not supported by the {indexKind} index. The maximum number of dimensions supported by the {indexKind} index " + + $"is {maxDimensions}. Please reduce the number of dimensions or use a different index." + ); + } + + vectorIndexesToCreate.Add((vectorColumnName, indexKind, distanceFunction)); + } + } + } + return vectorIndexesToCreate; + } + // Helper method to convert an IEnumerable to a List if necessary private static object ConvertToListIfNecessary(IEnumerable enumerable) { diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs index bfe15e025b8f..ef5fe4a32b16 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -76,28 +76,22 @@ public void TestBuildCreateTableCommand(bool ifNotExists) } [Theory] - [InlineData(IndexKind.Hnsw, null)] - [InlineData(IndexKind.IvfFlat, null)] [InlineData(IndexKind.Hnsw, DistanceFunction.EuclideanDistance)] + [InlineData(IndexKind.IvfFlat, DistanceFunction.DotProductSimilarity)] [InlineData(IndexKind.Hnsw, DistanceFunction.CosineDistance)] - public void TestBuildCreateIndexCommand(string indexKind, string? distanceFunction) + public void TestBuildCreateIndexCommand(string indexKind, string distanceFunction) { var builder = new PostgresVectorStoreCollectionSqlBuilder(); - var vectorProperty = new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) - { - Dimensions = 10, - IndexKind = indexKind, - DistanceFunction = distanceFunction, - }; + var vectorColumn = "embedding1"; if (indexKind != IndexKind.Hnsw) { - Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorProperty)); + Assert.Throws(() => builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction)); return; } - var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorProperty); + var cmdInfo = builder.BuildCreateVectorIndexCommand("public", "testcollection", vectorColumn, indexKind, distanceFunction); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("CREATE INDEX ", cmdInfo.CommandText); diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs index 5005901c6ad6..315dc6483995 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Postgres; using Pgvector; using Xunit; @@ -98,4 +99,49 @@ public void GetPropertyValueReturnsCorrectNullableValue() Assert.Equal(expectedValue, isNullable); } } + + [Fact] + public void GetVectorIndexInfoReturnsCorrectValues() + { + // Arrange + List vectorProperties = [ + new VectorStoreRecordVectorProperty("vector1", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 1000 }, + new VectorStoreRecordVectorProperty("vector2", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.None, Dimensions = 3000 }, + new VectorStoreRecordVectorProperty("vector3", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 900, DistanceFunction = DistanceFunction.ManhattanDistance }, + ]; + + // Act + var indexInfo = PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo(vectorProperties); + + // Assert + Assert.Equal(2, indexInfo.Count); + foreach (var (columnName, indexKind, distanceFunction) in indexInfo) + { + if (columnName == "vector1") + { + Assert.Equal(IndexKind.Hnsw, indexKind); + Assert.Equal(DistanceFunction.CosineDistance, distanceFunction); + } + else if (columnName == "vector3") + { + Assert.Equal(IndexKind.Hnsw, indexKind); + Assert.Equal(DistanceFunction.ManhattanDistance, distanceFunction); + } + else + { + Assert.Fail("Unexpected column name"); + } + } + } + + [Theory] + [InlineData(IndexKind.Hnsw, 3000)] + public void GetVectorIndexInfoReturnsThrowsForInvalidDimensions(string indexKind, int dimensions) + { + // Arrange + var vectorProperty = new VectorStoreRecordVectorProperty("vector", typeof(ReadOnlyMemory?)) { IndexKind = indexKind, Dimensions = dimensions }; + + // Act & Assert + Assert.Throws(() => PostgresVectorStoreRecordPropertyMapping.GetVectorIndexInfo([vectorProperty])); + } } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/IndexKind.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/IndexKind.cs index 512b51e54c20..6afb24bcef5e 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/IndexKind.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/IndexKind.cs @@ -52,4 +52,9 @@ public static class IndexKind /// Dynamic index allows to automatically switch from to indexes. /// public const string Dynamic = nameof(Dynamic); + + /// + /// No index is used. + /// + public const string None = nameof(None); } From b9b4a44b0f97d8e9f111173716ea4e9796aa7e78 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 28 Oct 2024 13:51:33 -0400 Subject: [PATCH 40/62] Fix issue with converting readonly array on upsert --- .../PostgresVectorStoreRecordPropertyMapping.cs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index 6f8e96ca2d1d..62baa7db617b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -181,16 +181,16 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) return new NpgsqlParameter() { Value = DBNull.Value }; } - // If it's already a List, return it directly - if (value is IList list) - { - return new NpgsqlParameter() { Value = list }; - } - - // If it's an IEnumerable, but not a List, convert it to a List + // If it's an IEnumerable, use reflection to determine if it needs to be converted to a list if (value is IEnumerable enumerable && !(value is string)) { - // Use a helper method to convert to a List if possible + Type propertyType = value.GetType(); + if (propertyType.IsGenericType && propertyType.GetGenericTypeDefinition() == typeof(List<>)) + { + // If it's already a List, return it directly + return new NpgsqlParameter() { Value = value }; + } + return new NpgsqlParameter() { Value = ConvertToListIfNecessary(enumerable) }; } From 97ef60a0ff5b0d047c1e0f9ea93cc7b83aba724e Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 28 Oct 2024 14:00:36 -0400 Subject: [PATCH 41/62] Fix SLN merge error --- dotnet/SK-dotnet.sln | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 8b05ef98c60c..a69b4e9cb8fb 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -401,6 +401,7 @@ EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SemanticKernel.AotTests", "src\SemanticKernel.AotTests\SemanticKernel.AotTests.csproj", "{39EAB599-742F-417D-AF80-95F90376BB18}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Postgres.UnitTests", "src\Connectors\Connectors.Postgres.UnitTests\Connectors.Postgres.UnitTests.csproj", "{232E1153-6366-4175-A982-D66B30AAD610}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Process.Utilities.UnitTests", "src\Experimental\Process.Utilities.UnitTests\Process.Utilities.UnitTests.csproj", "{DAC54048-A39A-4739-8307-EA5A291F2EA0}" EndProject Global @@ -1195,11 +1196,8 @@ Global {E82B640C-1704-430D-8D71-FD8ED3695468} = {5A7028A7-4DDF-4E4F-84A9-37CE8F8D7E89} {6ECFDF04-2237-4A85-B114-DAA34923E9E6} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {39EAB599-742F-417D-AF80-95F90376BB18} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} -<<<<<<< HEAD {232E1153-6366-4175-A982-D66B30AAD610} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} -======= {DAC54048-A39A-4739-8307-EA5A291F2EA0} = {0D8C6358-5DAA-4EA6-A924-C268A9A21BC9} ->>>>>>> upstream/main EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} From 81e180515f2c8e84277802d2f95ebdd8411a1015 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 28 Oct 2024 14:55:30 -0400 Subject: [PATCH 42/62] Improve error handling --- .../PostgresVectorStoreRecordCollection.cs | 162 +++++++++++++----- 1 file changed, 116 insertions(+), 46 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index dbf55a58130b..b95e3c95c214 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -117,31 +117,43 @@ internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client } /// - public async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) { - return await this._client.DoesTableExistsAsync(this.CollectionName, cancellationToken).ConfigureAwait(false); + const string OperationName = "DoesTableExists"; + return this.RunOperationAsync(OperationName, () => + this._client.DoesTableExistsAsync(this.CollectionName, cancellationToken) + ); } /// - public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public Task CreateCollectionAsync(CancellationToken cancellationToken = default) { - await this.InternalCreateCollectionAsync(false, cancellationToken).ConfigureAwait(false); + const string OperationName = "CreateCollection"; + return this.RunOperationAsync(OperationName, () => + this.InternalCreateCollectionAsync(false, cancellationToken) + ); } /// public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { - return this.InternalCreateCollectionAsync(true, cancellationToken); + const string OperationName = "CreateCollectionIfNotExists"; + return this.RunOperationAsync(OperationName, () => + this.InternalCreateCollectionAsync(true, cancellationToken) + ); } /// public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { - return this._client.DeleteTableAsync(this.CollectionName, cancellationToken); + const string OperationName = "DeleteCollection"; + return this.RunOperationAsync(OperationName, () => + this._client.DeleteTableAsync(this.CollectionName, cancellationToken) + ); } /// - public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) + public Task UpsertAsync(TRecord record, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "Upsert"; @@ -157,8 +169,12 @@ public async Task UpsertAsync(TRecord record, UpsertRecordOptions? options Verify.NotNull(keyObj); TKey key = (TKey)keyObj!; - await this._client.UpsertAsync(this.CollectionName, storageModel, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); - return key; + return this.RunOperationAsync(OperationName, async () => + { + await this._client.UpsertAsync(this.CollectionName, storageModel, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); + return key; + } + ); } /// @@ -174,35 +190,39 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record var keys = storageModels.Select(model => model[this._propertyReader.KeyPropertyStoragePropertyName]!).ToList(); - await this._client.UpsertBatchAsync(this.CollectionName, storageModels, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken).ConfigureAwait(false); + await this.RunOperationAsync(OperationName, () => + this._client.UpsertBatchAsync(this.CollectionName, storageModels, this._propertyReader.KeyPropertyStoragePropertyName, cancellationToken) + ).ConfigureAwait(false); foreach (var key in keys) { yield return (TKey)key!; } } /// - public async Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public Task GetAsync(TKey key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - var operationName = "Get"; + const string OperationName = "Get"; Verify.NotNull(key); bool includeVectors = options?.IncludeVectors is true; - var row = await this._client.GetAsync(this.CollectionName, key, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken).ConfigureAwait(false); - - if (row is null) { return default; } + return this.RunOperationAsync(OperationName, async () => + { + var row = await this._client.GetAsync(this.CollectionName, key, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken).ConfigureAwait(false); - return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, - operationName, - () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); + if (row is null) { return default; } + return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); + }); } /// public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var operationName = "GetBatch"; + const string OperationName = "GetBatch"; Verify.NotNull(keys); @@ -213,26 +233,34 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get yield return VectorStoreErrorHandler.RunModelConversion( DatabaseName, this.CollectionName, - operationName, + OperationName, () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); } } /// - public async Task DeleteAsync(TKey key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + public Task DeleteAsync(TKey key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) { - await this._client.DeleteAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, key, cancellationToken).ConfigureAwait(false); + const string OperationName = "Delete"; + return this.RunOperationAsync(OperationName, () => + this._client.DeleteAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, key, cancellationToken) + ); } /// public Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) { - return this._client.DeleteBatchAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, keys, cancellationToken); + const string OperationName = "DeleteBatch"; + return this.RunOperationAsync(OperationName, () => + this._client.DeleteBatchAsync(this.CollectionName, this._propertyReader.KeyPropertyStoragePropertyName, keys, cancellationToken) + ); } /// public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { + const string OperationName = "VectorizedSearch"; + Verify.NotNull(vector); var vectorType = vector.GetType(); @@ -260,30 +288,38 @@ public Task> VectorizedSearchAsync(TVector // and LIMIT is not supported in vector search extension, instead of LIMIT - "k" parameter is used. var limit = searchOptions.Top + searchOptions.Skip; - var results = this._client.GetNearestMatchesAsync( - this.CollectionName, - this._propertyReader.RecordDefinition.Properties, - vectorProperty, - pgVector, - searchOptions.Top, - searchOptions.Filter, - searchOptions.Skip, - searchOptions.IncludeVectors, - cancellationToken - ).Select(result => - { - var record = this._mapper.MapFromStorageToDataModel( - result.Row, new StorageToDataModelMapperOptions() { IncludeVectors = searchOptions.IncludeVectors }); - - return new VectorSearchResult(record, result.Distance); - }, cancellationToken); - - return Task.FromResult(new VectorSearchResults(results)); + return this.RunOperationAsync(OperationName, () => + { + var results = this._client.GetNearestMatchesAsync( + this.CollectionName, + this._propertyReader.RecordDefinition.Properties, + vectorProperty, + pgVector, + searchOptions.Top, + searchOptions.Filter, + searchOptions.Skip, + searchOptions.IncludeVectors, + cancellationToken + ).Select(result => + { + var record = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel( + result.Row, new StorageToDataModelMapperOptions() { IncludeVectors = searchOptions.IncludeVectors }) + ); + + return new VectorSearchResult(record, result.Distance); + }, cancellationToken); + + return Task.FromResult(new VectorSearchResults(results)); + }); } - private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken = default) + private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken = default) { - await this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition.Properties, ifNotExists, cancellationToken).ConfigureAwait(false); + return this._client.CreateTableAsync(this.CollectionName, this._propertyReader.RecordDefinition.Properties, ifNotExists, cancellationToken); } /// @@ -312,4 +348,38 @@ private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationT // If vector property is not provided in options, return first vector property from schema. return this._propertyReader.VectorProperty; } + + private async Task RunOperationAsync(string operationName, Func operation) + { + try + { + await operation.Invoke().ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } + + private async Task RunOperationAsync(string operationName, Func> operation) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } } From a5872609b1aae61c7aa1018bb842f79489b6f535 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 28 Oct 2024 15:08:25 -0400 Subject: [PATCH 43/62] Avoid CA1859 in test class e.g. Change type of property 'CollectionInts' from 'System.Collections.Generic.ICollection?' to 'System.Collections.Generic.HashSet?' for improved performance --- .../Postgres/PostgresVectorStoreRecordCollectionTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 76eec311c584..8b7bb9a02ff3 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -494,7 +494,7 @@ private static DateTimeOffset TruncateMilliseconds(DateTimeOffset dateTimeOffset return new DateTimeOffset(dateTimeOffset.Ticks - (dateTimeOffset.Ticks % TimeSpan.TicksPerSecond), dateTimeOffset.Offset); } -#pragma warning disable CA1812 +#pragma warning disable CA1812, CA1859 private sealed class RecordWithEnumerables { [VectorStoreRecordKey] @@ -518,7 +518,7 @@ private sealed class RecordWithEnumerables [VectorStoreRecordData] public IReadOnlyList? ReadOnlyListInts { get; set; } } -#pragma warning restore CA1812 +#pragma warning restore CA1812, CA1859 #endregion From e8fe800bc2f73e47474a4a2bd69411d04532d14b Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 28 Oct 2024 15:49:29 -0400 Subject: [PATCH 44/62] Account for ngpsql missing func in .net std 2.0 --- .../PostgresVectorStoreDbClient.cs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index a183a9f32de4..5ef18cc88fdf 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -80,8 +80,13 @@ public async Task CreateTableAsync(string tableName, IReadOnlyList Date: Mon, 28 Oct 2024 16:21:31 -0400 Subject: [PATCH 45/62] Fix servicecollection tests --- .../PostgresServiceCollectionExtensionsTests.cs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs index c6409c698792..9104edd7c125 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs @@ -5,6 +5,7 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.Postgres; using Moq; +using Npgsql; using Xunit; namespace SemanticKernel.Connectors.Postgres.UnitTests; @@ -20,7 +21,8 @@ public sealed class PostgresServiceCollectionExtensionsTests public void AddVectorStoreRegistersClass() { // Arrange - this._serviceCollection.AddSingleton(Mock.Of()); + using var dataSource = NpgsqlDataSource.Create("Host=fake;"); + this._serviceCollection.AddSingleton(dataSource); // Act this._serviceCollection.AddPostgresVectorStore(); @@ -37,7 +39,8 @@ public void AddVectorStoreRegistersClass() public void AddVectorStoreRecordCollectionWithStringKeyRegistersClass() { // Arrange - this._serviceCollection.AddSingleton(Mock.Of()); + using var dataSource = NpgsqlDataSource.Create("Host=fake;"); + this._serviceCollection.AddSingleton(dataSource); // Act this._serviceCollection.AddPostgresVectorStoreRecordCollection("testcollection"); @@ -58,7 +61,8 @@ public void AddVectorStoreRecordCollectionWithStringKeyRegistersClass() public void AddVectorStoreRecordCollectionWithNumericKeyRegistersClass() { // Arrange - this._serviceCollection.AddSingleton(Mock.Of()); + using var dataSource = NpgsqlDataSource.Create("Host=fake;"); + this._serviceCollection.AddSingleton(dataSource); // Act this._serviceCollection.AddPostgresVectorStoreRecordCollection("testcollection"); From 0fc76f66df281b9228561117a677bffef612936a Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 28 Oct 2024 16:21:58 -0400 Subject: [PATCH 46/62] Logic for dimension max moved and tested elsewhere --- ...ostgresVectorStoreRecordCollectionTests.cs | 48 ------------------- 1 file changed, 48 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs index 3ca268033aa0..0533ab28c3f3 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordCollectionTests.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.Logging; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Postgres; using Moq; @@ -128,53 +127,6 @@ public async Task UpsertRecordAsyncProducesExpectedClientCallAsync() Assert.Equal(4.0f, embedding[3]); } - [Fact] - public async Task CreateCollectionAsyncLogsWarningWhenDimensionsTooLargeAsync() - { - // Arrange - var recordDefinition = new VectorStoreRecordDefinition - { - Properties = [ - new VectorStoreRecordKeyProperty("HotelId", typeof(int)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, - new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("Tags", typeof(List)), - new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 2001, IndexKind = IndexKind.Hnsw, DistanceFunction = DistanceFunction.ManhattanDistance } - ] - }; - var mockLogger = new Mock>>>(); - mockLogger.Setup(x => x.Log( - LogLevel.Warning, - It.IsAny(), - It.IsAny(), - It.IsAny(), - It.IsAny>())); - var sut = new PostgresVectorStoreRecordCollection>( - this._postgresClientMock.Object, - TestCollectionName, - logger: mockLogger.Object, - options: new PostgresVectorStoreRecordCollectionOptions> { VectorStoreRecordDefinition = recordDefinition } - ); - - this._postgresClientMock.Setup(x => x.CreateTableAsync(TestCollectionName, It.IsAny>(), It.IsAny(), It.IsAny())).Returns(Task.CompletedTask); - - // Act - await sut.CreateCollectionAsync(cancellationToken: this._testCancellationToken); - - // Assert - mockLogger.Verify( - x => x.Log( - LogLevel.Warning, - It.IsAny(), - It.Is((v, t) => v.ToString()!.Contains("2001")), - It.IsAny(), - It.IsAny>()), - Times.Once); - } - [Fact] public async Task CollectionExistsReturnsValidResultAsync() { From 266310b83fb0cbf6b5628554537e77a798bb36d6 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Mon, 28 Oct 2024 16:32:57 -0400 Subject: [PATCH 47/62] Remove unused using statement --- .../PostgresServiceCollectionExtensionsTests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs index 9104edd7c125..159b7312927f 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs @@ -4,7 +4,6 @@ using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.Postgres; -using Moq; using Npgsql; using Xunit; From 08f110ced41a0b2d47ee7f47d3f3e6eb8436b958 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Tue, 29 Oct 2024 10:13:40 -0400 Subject: [PATCH 48/62] Remove logger from PostgresVectorStoreRecordCollection --- .../PostgresVectorStoreRecordCollection.cs | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index b95e3c95c214..92d767764882 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -6,8 +6,6 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.VectorData; using Npgsql; @@ -35,9 +33,6 @@ public sealed class PostgresVectorStoreRecordCollection : IVector // Optional configuration options for this class. private readonly PostgresVectorStoreRecordCollectionOptions _options; - /// The logger to use for logging. - private readonly ILogger> _logger; - /// A helper to access property information for the current data model and record definition. private readonly VectorStoreRecordPropertyReader _propertyReader; @@ -53,9 +48,8 @@ public sealed class PostgresVectorStoreRecordCollection : IVector /// The data source to use for connecting to the database. /// The name of the collection. /// Optional configuration options for this class. - /// The logger to use for logging. - public PostgresVectorStoreRecordCollection(NpgsqlDataSource dataSource, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default, - ILogger>? logger = null) : this(new PostgresVectorStoreDbClient(dataSource), collectionName, options, logger) + public PostgresVectorStoreRecordCollection(NpgsqlDataSource dataSource, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) + : this(new PostgresVectorStoreDbClient(dataSource), collectionName, options) { } @@ -65,12 +59,10 @@ public PostgresVectorStoreRecordCollection(NpgsqlDataSource dataSource, string c /// The client to use for interacting with the database. /// The name of the collection. /// Optional configuration options for this class. - /// The logger to use for logging. /// /// This constructor is internal. It allows internal code to create an instance of this class with a custom client. /// - internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default, - ILogger>? logger = null) + internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client, string collectionName, PostgresVectorStoreRecordCollectionOptions? options = default) { // Verify. Verify.NotNull(client); @@ -91,7 +83,6 @@ internal PostgresVectorStoreRecordCollection(IPostgresVectorStoreDbClient client SupportsMultipleKeys = false, SupportsMultipleVectors = true, }); - this._logger = logger ?? NullLogger>.Instance; // Validate property types. this._propertyReader.VerifyKeyProperties(PostgresConstants.SupportedKeyTypes); From 5b44a80e327857743940aa37a06dd91948960eba Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Wed, 30 Oct 2024 13:38:28 -0400 Subject: [PATCH 49/62] Use Flat instead of None index kind --- .../PostgresVectorStoreRecordPropertyMapping.cs | 6 +++--- .../PostgresVectorStoreRecordPropertyMappingTests.cs | 2 +- .../VectorData.Abstractions/RecordDefinition/IndexKind.cs | 5 ----- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index 62baa7db617b..3bee0d490575 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -204,7 +204,7 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) /// The properties of the vector store record. /// A list of tuples containing the column name, index kind, and distance function for each vector property. /// - /// The user can specify an index kind of "None" to prevent the creation of an index. Otherwise the default index kind is Hnsw. + /// The user can specify an index kind of "Flat" to prevent the creation of an index. Otherwise the default index kind is Hnsw. /// public static List<(string column, string kind, string function)> GetVectorIndexInfo(IReadOnlyList properties) { @@ -217,8 +217,8 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) var indexKind = vectorProperty.IndexKind ?? PostgresConstants.DefaultIndexKind; // Default to Hnsw var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; - // The user can specify an index kind of "None" to prevent the creation of an index. - if (indexKind != IndexKind.None) + // The user can specify an index kind of "Flat" to prevent the creation of an index. + if (indexKind != IndexKind.Flat) { // Ensure the dimensionality of the vector is supported for indexing. if (PostgresConstants.IndexMaxDimensions.TryGetValue(indexKind, out int maxDimensions) && vectorProperty.Dimensions > maxDimensions) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs index 315dc6483995..0631cc2c0df4 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordPropertyMappingTests.cs @@ -106,7 +106,7 @@ public void GetVectorIndexInfoReturnsCorrectValues() // Arrange List vectorProperties = [ new VectorStoreRecordVectorProperty("vector1", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 1000 }, - new VectorStoreRecordVectorProperty("vector2", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.None, Dimensions = 3000 }, + new VectorStoreRecordVectorProperty("vector2", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Flat, Dimensions = 3000 }, new VectorStoreRecordVectorProperty("vector3", typeof(ReadOnlyMemory?)) { IndexKind = IndexKind.Hnsw, Dimensions = 900, DistanceFunction = DistanceFunction.ManhattanDistance }, ]; diff --git a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/IndexKind.cs b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/IndexKind.cs index 6afb24bcef5e..512b51e54c20 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/IndexKind.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/RecordDefinition/IndexKind.cs @@ -52,9 +52,4 @@ public static class IndexKind /// Dynamic index allows to automatically switch from to indexes. /// public const string Dynamic = nameof(Dynamic); - - /// - /// No index is used. - /// - public const string None = nameof(None); } From 24577a0136a199641a20081ca8145d90180930ca Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 31 Oct 2024 11:17:25 -0400 Subject: [PATCH 50/62] Remove unnecessary overloads --- .../PostgresServiceCollectionExtensions.cs | 62 ------------------- 1 file changed, 62 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs index 8e46d21beb12..983b8e7db443 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresServiceCollectionExtensions.cs @@ -74,32 +74,6 @@ public static IServiceCollection AddPostgresVectorStore(this IServiceCollection return services; } - /// - /// Register a Postgres with the specified service ID and where an NpgsqlDataSource is passed in. - /// - /// The to register the on. - /// The data source to use. - /// Optional options to further configure the . - /// An optional service id to use as the service key. - /// The service collection. - public static IServiceCollection AddPostgresVectorStore(this IServiceCollection services, NpgsqlDataSource dataSource, PostgresVectorStoreOptions? options = default, string? serviceId = default) - { - // Since we are not constructing the data source, add the IVectorStore as transient, since we - // cannot make assumptions about how data source is being managed. - services.AddKeyedTransient( - serviceId, - (sp, obj) => - { - var selectedOptions = options ?? sp.GetService(); - - return new PostgresVectorStore( - dataSource, - selectedOptions); - }); - - return services; - } - /// /// Register a Postgres and with the specified service ID /// and where the NpgsqlDataSource is retrieved from the dependency injection container. @@ -178,42 +152,6 @@ public static IServiceCollection AddPostgresVectorStoreRecordCollection - /// Register a Postgres and with the specified service ID - /// and where the NpgsqlDataSource is passed in. - /// - /// The type of the key. - /// The type of the record. - /// The to register the on. - /// The name of the collection. - /// The data source to use. - /// Optional options to further configure the . - /// An optional service id to use as the service key. - /// Service collection. - public static IServiceCollection AddPostgresVectorStoreRecordCollection( - this IServiceCollection services, - string collectionName, - NpgsqlDataSource dataSource, - PostgresVectorStoreRecordCollectionOptions? options = default, - string? serviceId = default) - where TKey : notnull - { - // Since we are not constructing the data source, add the IVectorStore as transient, since we - // cannot make assumptions about how data source is being managed. - services.AddKeyedTransient>( - serviceId, - (sp, obj) => - { - var selectedOptions = options ?? sp.GetService>(); - - return (new PostgresVectorStoreRecordCollection(dataSource, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; - }); - - AddVectorizedSearch(services, serviceId); - - return services; - } - /// /// Also register the with the given as a . /// From 60d65120b29424a7618829583d947e61d24f29c1 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 31 Oct 2024 11:54:22 -0400 Subject: [PATCH 51/62] Change tests to be true to name --- .../PostgresGenericDataModelMapperTests.cs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs index 99d9e6074428..79b7eeae82eb 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs @@ -42,11 +42,11 @@ public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() { // Arrange - var definition = GetRecordDefinition(); + var definition = GetRecordDefinition(); var propertyReader = GetPropertyReader>(definition); - var dataModel = GetGenericDataModel("key"); + var dataModel = GetGenericDataModel(1); - var mapper = new PostgresGenericDataModelMapper(propertyReader); + var mapper = new PostgresGenericDataModelMapper(propertyReader); // Act var result = mapper.MapFromDataToStorageModel(dataModel); @@ -114,22 +114,22 @@ public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool var storageModel = new Dictionary { - ["Key"] = "key", + ["Key"] = 1, ["StringProperty"] = "Value1", ["IntProperty"] = 5, ["FloatVector"] = storageVector }; - var definition = GetRecordDefinition(); - var propertyReader = GetPropertyReader>(definition); + var definition = GetRecordDefinition(); + var propertyReader = GetPropertyReader>(definition); - var mapper = new PostgresGenericDataModelMapper(propertyReader); + var mapper = new PostgresGenericDataModelMapper(propertyReader); // Act var result = mapper.MapFromStorageToDataModel(storageModel, new() { IncludeVectors = includeVectors }); // Assert - Assert.Equal("key", result.Key); + Assert.Equal(1, result.Key); Assert.Equal("Value1", result.Data["StringProperty"]); Assert.Equal(5, result.Data["IntProperty"]); From 5a66a13ad9658df9316fbbde1713e3cb60daa510 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 31 Oct 2024 11:58:30 -0400 Subject: [PATCH 52/62] Remove reduntant key type based test --- ...ostgresServiceCollectionExtensionsTests.cs | 24 +------------------ 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs index 159b7312927f..f667d86eee30 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresServiceCollectionExtensionsTests.cs @@ -35,29 +35,7 @@ public void AddVectorStoreRegistersClass() } [Fact] - public void AddVectorStoreRecordCollectionWithStringKeyRegistersClass() - { - // Arrange - using var dataSource = NpgsqlDataSource.Create("Host=fake;"); - this._serviceCollection.AddSingleton(dataSource); - - // Act - this._serviceCollection.AddPostgresVectorStoreRecordCollection("testcollection"); - - var serviceProvider = this._serviceCollection.BuildServiceProvider(); - - // Assert - var collection = serviceProvider.GetRequiredService>(); - Assert.NotNull(collection); - Assert.IsType>(collection); - - var vectorizedSearch = serviceProvider.GetRequiredService>(); - Assert.NotNull(vectorizedSearch); - Assert.IsType>(vectorizedSearch); - } - - [Fact] - public void AddVectorStoreRecordCollectionWithNumericKeyRegistersClass() + public void AddVectorStoreRecordCollectionRegistersClass() { // Arrange using var dataSource = NpgsqlDataSource.Create("Host=fake;"); From 581b6ab90d9b9bea944651dfbf40e03d1b7d2a3b Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 31 Oct 2024 13:46:11 -0400 Subject: [PATCH 53/62] Remove unnecessary overloads --- .../Postgres/PostgresVectorStoreFixture.cs | 16 +++++++--------- .../Memory/Postgres/PostgresVectorStoreTests.cs | 3 +-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs index 2346a0eab9af..6251af573dc0 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -6,7 +6,7 @@ using Docker.DotNet; using Docker.DotNet.Models; using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Postgres; using Npgsql; using Xunit; @@ -63,22 +63,24 @@ public PostgresVectorStoreFixture() private string _connectionString = null!; private string _databaseName = null!; - /// Gets the Kernel that holds the vector store. - public Kernel Kernel { get; private set; } - /// Gets the manually created vector store record definition for our test model. public VectorStoreRecordDefinition HotelVectorStoreRecordDefinition { get; private set; } /// Gets the manually created vector store record definition for our test model. public VectorStoreRecordDefinition HotelWithGuidIdVectorStoreRecordDefinition { get; private set; } + /// + /// Gets a vector store to use for tests. + /// + public IVectorStore VectorStore => new PostgresVectorStore(this._dataSource!); + public IVectorStoreRecordCollection GetCollection( string collectionName, VectorStoreRecordDefinition? recordDefinition = default) where TKey : notnull where TRecord : class { - var vectorStore = this.Kernel.GetRequiredService(); + var vectorStore = this.VectorStore; return vectorStore.GetCollection(collectionName, recordDefinition); } @@ -103,10 +105,6 @@ public async Task InitializeAsync() this._dataSource = dataSourceBuilder.Build(); - var kernelBuilder = Kernel.CreateBuilder(); - kernelBuilder.Services.AddPostgresVectorStore(this._dataSource); - this.Kernel = kernelBuilder.Build(); - // Wait for the postgres container to be ready and create the test database using the initial data source. var initialDataSource = NpgsqlDataSource.Create(this._connectionString); using (initialDataSource) diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs index 8291591872f7..3eb2c02d54c6 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreTests.cs @@ -2,7 +2,6 @@ using System.Linq; using System.Threading.Tasks; -using Microsoft.Extensions.VectorData; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; @@ -14,7 +13,7 @@ public class PostgresVectorStoreTests(PostgresVectorStoreFixture fixture) public async Task ItCanGetAListOfExistingCollectionNamesAsync() { // Arrange - var sut = fixture.Kernel.GetRequiredService(); + var sut = fixture.VectorStore; // Setup var collection = sut.GetCollection>("VS_TEST_HOTELS"); From 494a0d433b3c94fece5375ff20f403436b65a8bf Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 31 Oct 2024 14:03:11 -0400 Subject: [PATCH 54/62] Better error handling for IAsyncEnumerable --- .../PostgresConstants.cs | 3 + .../PostgresVectorStore.cs | 6 +- .../PostgresVectorStoreRecordCollection.cs | 38 ++++++------ .../PostgresVectorStoreUtils.cs | 59 +++++++++++++++++++ .../PostgresVectorStoreTests.cs | 34 +++++++++++ 5 files changed, 121 insertions(+), 19 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index 7f76ccef5857..0cb73c2f3ade 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -8,6 +8,9 @@ namespace Microsoft.SemanticKernel.Connectors.Postgres; internal static class PostgresConstants { + /// The name of this database for telemetry purposes. + public const string DatabaseName = "Postgres"; + /// A of types that a key on the provided model may have. public static readonly HashSet SupportedKeyTypes = [ diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs index d44b7e44d690..99bbc8e320b5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStore.cs @@ -43,7 +43,11 @@ internal PostgresVectorStore(IPostgresVectorStoreDbClient postgresDbClient, Post /// public IAsyncEnumerable ListCollectionNamesAsync(CancellationToken cancellationToken = default) { - return this._postgresClient.GetTablesAsync(cancellationToken); + const string OperationName = "ListCollectionNames"; + return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( + this._postgresClient.GetTablesAsync(cancellationToken), + OperationName + ); } /// diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index 92d767764882..f93dcf482fb8 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -21,9 +21,6 @@ public sealed class PostgresVectorStoreRecordCollection : IVector #pragma warning restore CA1711 // Identifiers should not have incorrect suffix where TKey : notnull { - /// The name of this database for telemetry purposes. - private const string DatabaseName = "Postgres"; - /// public string CollectionName { get; } @@ -149,7 +146,7 @@ public Task UpsertAsync(TRecord record, UpsertRecordOptions? options = nul const string OperationName = "Upsert"; var storageModel = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + PostgresConstants.DatabaseName, this.CollectionName, OperationName, () => this._mapper.MapFromDataToStorageModel(record)); @@ -174,7 +171,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record const string OperationName = "UpsertBatch"; var storageModels = records.Select(record => VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + PostgresConstants.DatabaseName, this.CollectionName, OperationName, () => this._mapper.MapFromDataToStorageModel(record))).ToList(); @@ -203,7 +200,7 @@ await this.RunOperationAsync(OperationName, () => if (row is null) { return default; } return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + PostgresConstants.DatabaseName, this.CollectionName, OperationName, () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); @@ -211,7 +208,7 @@ await this.RunOperationAsync(OperationName, () => } /// - public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "GetBatch"; @@ -219,14 +216,19 @@ public async IAsyncEnumerable GetBatchAsync(IEnumerable keys, Get bool includeVectors = options?.IncludeVectors is true; - await foreach (var row in this._client.GetBatchAsync(this.CollectionName, keys, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken).ConfigureAwait(false)) - { - yield return VectorStoreErrorHandler.RunModelConversion( - DatabaseName, - this.CollectionName, - OperationName, - () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })); - } + return PostgresVectorStoreUtils.WrapAsyncEnumerableAsync( + this._client.GetBatchAsync(this.CollectionName, keys, this._propertyReader.RecordDefinition.Properties, includeVectors, cancellationToken) + .Select(row => + VectorStoreErrorHandler.RunModelConversion( + PostgresConstants.DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(row, new() { IncludeVectors = includeVectors })), + cancellationToken + ), + OperationName, + this.CollectionName + ); } /// @@ -294,7 +296,7 @@ public Task> VectorizedSearchAsync(TVector ).Select(result => { var record = VectorStoreErrorHandler.RunModelConversion( - DatabaseName, + PostgresConstants.DatabaseName, this.CollectionName, OperationName, () => this._mapper.MapFromStorageToDataModel( @@ -350,7 +352,7 @@ private async Task RunOperationAsync(string operationName, Func operation) { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreType = PostgresConstants.DatabaseName, CollectionName = this.CollectionName, OperationName = operationName }; @@ -367,7 +369,7 @@ private async Task RunOperationAsync(string operationName, Func> o { throw new VectorStoreOperationException("Call to vector store failed.", ex) { - VectorStoreType = DatabaseName, + VectorStoreType = PostgresConstants.DatabaseName, CollectionName = this.CollectionName, OperationName = operationName }; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs new file mode 100644 index 000000000000..a89f5fab2c12 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal static class PostgresVectorStoreUtils +{ + /// + /// Wraps an in an that will throw a + /// if an exception is thrown while iterating over the original enumerator. + /// + /// The type of the items in the async enumerable. + /// The async enumerable to wrap. + /// The name of the operation being performed. + /// The name of the collection being operated on. + /// An async enumerable that will throw a if an exception is thrown while iterating over the original enumerator. + public static async IAsyncEnumerable WrapAsyncEnumerableAsync(IAsyncEnumerable asyncEnumerable, string operationName, string? collectionName = null) + { + var enumerator = asyncEnumerable.ConfigureAwait(false).GetAsyncEnumerator(); + + var nextResult = await GetNextAsync(enumerator, operationName, collectionName).ConfigureAwait(false); + while (nextResult.more) + { + yield return nextResult.item; + nextResult = await GetNextAsync(enumerator, operationName, collectionName).ConfigureAwait(false); + } + } + + /// + /// Helper method to get the next index name from the enumerator with a try catch around the move next call to convert + /// exceptions to . + /// + /// The enumerator to get the next result from. + /// The name of the operation being performed. + /// The name of the collection being operated on. + /// A value indicating whether there are more results and the current string if true. + public static async Task<(T item, bool more)> GetNextAsync(ConfiguredCancelableAsyncEnumerable.Enumerator enumerator, string operationName, string? collectionName = null) + { + try + { + var more = await enumerator.MoveNextAsync(); + return (enumerator.Current, more); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = PostgresConstants.DatabaseName, + CollectionName = collectionName, + OperationName = operationName + }; + } + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs index 7feecff44f55..e83e43f6963d 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs @@ -93,6 +93,40 @@ public async Task ListCollectionNamesCallsSDKAsync() Assert.Equal(expectedCollections, actualList); } + [Fact] + public async Task ListCollectionNamesThrowsCorrectExcpetionAsync() + { + // Arrange + var expectedCollections = new List { "fake-collection-1", "fake-collection-2", "fake-collection-3" }; + + this._postgresClientMock + .Setup(client => client.GetTablesAsync(CancellationToken.None)) + .Returns(this.ThrowingAsyncEnumerableAsync); + + var sut = new PostgresVectorStore(this._postgresClientMock.Object); + + // Act. + var actual = sut.ListCollectionNamesAsync(this._testCancellationToken); + + // Assert + Assert.NotNull(actual); + await Assert.ThrowsAsync(async () => await actual.ToListAsync()); + } + + private async IAsyncEnumerable ThrowingAsyncEnumerableAsync() + { + int itemIndex = 0; + await foreach (var item in new List { "item1", "item2", "item3" }.ToAsyncEnumerable()) + { + if (itemIndex == 1) + { + throw new InvalidOperationException("Test exception"); + } + yield return item; + itemIndex++; + } + } + public sealed class SinglePropsModel { [VectorStoreRecordKey] From 5f1988936cb482339b4ddce5b9017b36aa268428 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 31 Oct 2024 14:23:15 -0400 Subject: [PATCH 55/62] Default to Flat (no index) instead of Hnsw --- .../Connectors.Memory.Postgres/PostgresConstants.cs | 3 ++- .../PostgresVectorStoreRecordPropertyMapping.cs | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs index 0cb73c2f3ade..f8784890e83a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresConstants.cs @@ -79,7 +79,8 @@ internal static class PostgresConstants public const string DistanceColumnName = "sk_pg_distance"; /// The default index kind. - public const string DefaultIndexKind = IndexKind.Hnsw; + /// Defaults to "Flat", which means no indexing. + public const string DefaultIndexKind = IndexKind.Flat; /// The default distance function. public const string DefaultDistanceFunction = DistanceFunction.CosineDistance; diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index 3bee0d490575..0b36f2003bf5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -204,7 +204,7 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) /// The properties of the vector store record. /// A list of tuples containing the column name, index kind, and distance function for each vector property. /// - /// The user can specify an index kind of "Flat" to prevent the creation of an index. Otherwise the default index kind is Hnsw. + /// The default index kind is "Flat", which prevents the creation of an index. /// public static List<(string column, string kind, string function)> GetVectorIndexInfo(IReadOnlyList properties) { @@ -214,10 +214,11 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) if (property is VectorStoreRecordVectorProperty vectorProperty) { var vectorColumnName = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; - var indexKind = vectorProperty.IndexKind ?? PostgresConstants.DefaultIndexKind; // Default to Hnsw + var indexKind = vectorProperty.IndexKind ?? PostgresConstants.DefaultIndexKind; var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; - // The user can specify an index kind of "Flat" to prevent the creation of an index. + // Index kind of "Flat" to prevent the creation of an index. This is the default behavior. + // Otherwise, the index will be created with the specified index kind and distance function, if supported. if (indexKind != IndexKind.Flat) { // Ensure the dimensionality of the vector is supported for indexing. From 62ac8ebb577b738594ced7b17431d017344ed89c Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 31 Oct 2024 14:40:29 -0400 Subject: [PATCH 56/62] Add enumerable to record mapper test --- .../PostgresVectorStoreRecordMapperTests.cs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs index 201c93f53db8..11dfd2ecd564 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreRecordMapperTests.cs @@ -31,6 +31,7 @@ public void MapFromDataToStorageModelWithStringKeyReturnsValidStorageModel() Assert.Equal("key", result["Key"]); Assert.Equal("Value1", result["StringProperty"]); Assert.Equal(5, result["IntProperty"]); + Assert.Equal(new List { "Value2", "Value3" }, result["StringArray"]); Vector? vector = result["FloatVector"] as Vector; @@ -55,6 +56,7 @@ public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() Assert.Equal((ulong)1, result["Key"]); Assert.Equal("Value1", result["StringProperty"]); Assert.Equal(5, result["IntProperty"]); + Assert.Equal(new List { "Value2", "Value3" }, result["StringArray"]); var vector = result["FloatVector"] as Vector; @@ -76,7 +78,8 @@ public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool ["Key"] = "key", ["StringProperty"] = "Value1", ["IntProperty"] = 5, - ["FloatVector"] = storageVector + ["StringArray"] = new List { "Value2", "Value3" }, + ["FloatVector"] = storageVector, }; var definition = GetRecordDefinition(); @@ -91,6 +94,7 @@ public void MapFromStorageToDataModelWithStringKeyReturnsValidGenericModel(bool Assert.Equal("key", result.Key); Assert.Equal("Value1", result.StringProperty); Assert.Equal(5, result.IntProperty); + Assert.Equal(new List { "Value2", "Value3" }, result.StringArray); if (includeVectors) { @@ -117,6 +121,7 @@ public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool ["Key"] = (ulong)1, ["StringProperty"] = "Value1", ["IntProperty"] = 5, + ["StringArray"] = new List { "Value2", "Value3" }, ["FloatVector"] = storageVector }; @@ -132,6 +137,7 @@ public void MapFromStorageToDataModelWithNumericKeyReturnsValidGenericModel(bool Assert.Equal((ulong)1, result.Key); Assert.Equal("Value1", result.StringProperty); Assert.Equal(5, result.IntProperty); + Assert.Equal(new List { "Value2", "Value3" }, result.StringArray); if (includeVectors) { @@ -155,6 +161,7 @@ private static VectorStoreRecordDefinition GetRecordDefinition() new VectorStoreRecordKeyProperty("Key", typeof(TKey)), new VectorStoreRecordDataProperty("StringProperty", typeof(string)), new VectorStoreRecordDataProperty("IntProperty", typeof(int)), + new VectorStoreRecordDataProperty("StringArray", typeof(IEnumerable)), new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), } }; @@ -167,6 +174,7 @@ private static TestRecord GetDataModel(TKey key) Key = key, StringProperty = "Value1", IntProperty = 5, + StringArray = new List { "Value2", "Value3" }, FloatVector = new ReadOnlyMemory([1.1f, 2.2f, 3.3f, 4.4f]) }; } @@ -193,6 +201,9 @@ private sealed class TestRecord [VectorStoreRecordData] public int? IntProperty { get; set; } + [VectorStoreRecordData] + public IEnumerable? StringArray { get; set; } + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance)] public ReadOnlyMemory? FloatVector { get; set; } } From 364b592afc45c79fe3f45a4066858220b453fbcc Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 31 Oct 2024 15:24:20 -0400 Subject: [PATCH 57/62] Remove unused fixture properties --- .../Postgres/PostgresVectorStoreFixture.cs | 30 ------------------- 1 file changed, 30 deletions(-) diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs index 6251af573dc0..c3b5d5b89f72 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -29,30 +29,6 @@ public PostgresVectorStoreFixture() { using var dockerClientConfiguration = new DockerClientConfiguration(); this._client = dockerClientConfiguration.CreateClient(); - this.HotelVectorStoreRecordDefinition = new VectorStoreRecordDefinition - { - Properties = new List - { - new VectorStoreRecordKeyProperty("HotelId", typeof(ulong)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("HotelCode", typeof(int)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { IsFilterable = true, StoragePropertyName = "parking_is_included" }, - new VectorStoreRecordDataProperty("HotelRating", typeof(float)) { IsFilterable = true }, - new VectorStoreRecordDataProperty("Tags", typeof(List)), - new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, DistanceFunction = DistanceFunction.ManhattanDistance } - } - }; - this.HotelWithGuidIdVectorStoreRecordDefinition = new VectorStoreRecordDefinition - { - Properties = new List - { - new VectorStoreRecordKeyProperty("HotelId", typeof(Guid)), - new VectorStoreRecordDataProperty("HotelName", typeof(string)) { IsFilterable = true, IsFullTextSearchable = true }, - new VectorStoreRecordDataProperty("Description", typeof(string)), - new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, DistanceFunction = DistanceFunction.ManhattanDistance } - } - }; } /// @@ -63,12 +39,6 @@ public PostgresVectorStoreFixture() private string _connectionString = null!; private string _databaseName = null!; - /// Gets the manually created vector store record definition for our test model. - public VectorStoreRecordDefinition HotelVectorStoreRecordDefinition { get; private set; } - - /// Gets the manually created vector store record definition for our test model. - public VectorStoreRecordDefinition HotelWithGuidIdVectorStoreRecordDefinition { get; private set; } - /// /// Gets a vector store to use for tests. /// From bf58caba5363ac0470d2dcdb6f067945a65edf7d Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Thu, 31 Oct 2024 15:24:37 -0400 Subject: [PATCH 58/62] Test StoragePropertyName in sql builder tests --- .../PostgresVectorStoreCollectionSqlBuilderTests.cs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs index ef5fe4a32b16..675843a78c18 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -36,7 +36,7 @@ public void TestBuildCreateTableCommand(bool ifNotExists) new VectorStoreRecordDataProperty("code", typeof(int)), new VectorStoreRecordDataProperty("rating", typeof(float?)), new VectorStoreRecordDataProperty("description", typeof(string)), - new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, new VectorStoreRecordDataProperty("tags", typeof(List)), new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) { @@ -59,7 +59,7 @@ public void TestBuildCreateTableCommand(bool ifNotExists) Assert.Contains("\"code\" INTEGER NOT NULL", cmdInfo.CommandText); Assert.Contains("\"rating\" REAL", cmdInfo.CommandText); Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); - Assert.Contains("\"parking_is_included\" BOOLEAN NOT NULL", cmdInfo.CommandText); + Assert.Contains("\"free_parking\" BOOLEAN NOT NULL", cmdInfo.CommandText); Assert.Contains("\"tags\" TEXT[]", cmdInfo.CommandText); Assert.Contains("\"description\" TEXT", cmdInfo.CommandText); Assert.Contains("\"embedding1\" VECTOR(10) NOT NULL", cmdInfo.CommandText); @@ -245,7 +245,7 @@ public void TestBuildGetCommand() new VectorStoreRecordDataProperty("code", typeof(int)), new VectorStoreRecordDataProperty("rating", typeof(float?)), new VectorStoreRecordDataProperty("description", typeof(string)), - new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, new VectorStoreRecordDataProperty("tags", typeof(List)), new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) { @@ -267,6 +267,8 @@ public void TestBuildGetCommand() // Assert Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("\"free_parking\"", cmdInfo.CommandText); + Assert.Contains("\"embedding1\"", cmdInfo.CommandText); Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); Assert.Contains("WHERE \"id\" = $1", cmdInfo.CommandText); @@ -288,7 +290,7 @@ public void TestBuildGetBatchCommand() new VectorStoreRecordDataProperty("code", typeof(int)), new VectorStoreRecordDataProperty("rating", typeof(float?)), new VectorStoreRecordDataProperty("description", typeof(string)), - new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), + new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)) { StoragePropertyName = "free_parking" }, new VectorStoreRecordDataProperty("tags", typeof(List)), new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) { @@ -310,6 +312,8 @@ public void TestBuildGetBatchCommand() // Assert Assert.Contains("SELECT", cmdInfo.CommandText); + Assert.Contains("\"code\"", cmdInfo.CommandText); + Assert.Contains("\"free_parking\"", cmdInfo.CommandText); Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); Assert.Contains("WHERE \"id\" = ANY($1)", cmdInfo.CommandText); Assert.NotNull(cmdInfo.Parameters); From aa592de5ef1bff66fd74be4dd6253d9e7f055612 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 1 Nov 2024 11:27:53 -0400 Subject: [PATCH 59/62] Remove dynamic from integration test --- ...ostgresVectorStoreRecordCollectionTests.cs | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 8b7bb9a02ff3..55935b1961d2 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -106,29 +106,30 @@ public async Task CollectionCanUpsertAndGetAsync() } } - [Theory] - [InlineData(typeof(short), (short)3)] - [InlineData(typeof(int), 5)] - [InlineData(typeof(long), 7L)] - [InlineData(typeof(string), "key1")] - [InlineData(typeof(Guid), null)] - public async Task ItCanGetAndDeleteRecordAsync(Type idType, object? key) - { - if (idType == typeof(Guid)) + public static IEnumerable ItCanGetAndDeleteRecordParameters => + new List { - key = Guid.NewGuid(); - } + new object[] { typeof(short), (short)3 }, + new object[] { typeof(int), 5 }, + new object[] { typeof(long), 7L }, + new object[] { typeof(string), "key1" }, + new object[] { typeof(Guid), Guid.NewGuid() } + }; + [Theory] + [MemberData(nameof(ItCanGetAndDeleteRecordParameters))] + public async Task ItCanGetAndDeleteRecordAsync(Type idType, TKey? key) + { // Arrange var collectionName = "DeleteRecord"; - dynamic sut = this.GetCollection(idType, collectionName); + var sut = this.GetCollection(idType, collectionName); await sut.CreateCollectionAsync(); try { - dynamic record = this.CreateRecord(idType, key!); - dynamic recordKey = record.HotelId; + var record = this.CreateRecord(idType, key!); + var recordKey = record.HotelId; var upsertResult = await sut.UpsertAsync(record); var getResult = await sut.GetAsync(recordKey); @@ -473,10 +474,10 @@ private dynamic GetCollection(Type idType, string collectionName) return genericMethod.Invoke(fixture, [collectionName, null])!; } - private dynamic CreateRecord(Type idType, object key) + private PostgresHotel CreateRecord(Type idType, TKey key) { var recordType = typeof(PostgresHotel<>).MakeGenericType(idType); - dynamic record = Activator.CreateInstance(recordType, key)!; + var record = (PostgresHotel)Activator.CreateInstance(recordType, key)!; record.HotelName = "Hotel 1"; record.HotelCode = 1; record.ParkingIncluded = true; From 9a3b216b86f059d1d527b167a73f132461759bdb Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 1 Nov 2024 11:58:39 -0400 Subject: [PATCH 60/62] Add test to read from manually inserted record --- .../Postgres/PostgresVectorStoreFixture.cs | 8 +++++ ...ostgresVectorStoreRecordCollectionTests.cs | 36 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs index c3b5d5b89f72..5888a513ace0 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreFixture.cs @@ -44,6 +44,14 @@ public PostgresVectorStoreFixture() /// public IVectorStore VectorStore => new PostgresVectorStore(this._dataSource!); + /// + /// Get a database connection + /// + public NpgsqlConnection GetConnection() + { + return this._dataSource!.OpenConnection(); + } + public IVectorStoreRecordCollection GetCollection( string collectionName, VectorStoreRecordDefinition? recordDefinition = default) diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 55935b1961d2..ef25dfd6d439 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; +using Npgsql; using Xunit; namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; @@ -220,6 +221,41 @@ public async Task ItCanUpsertExistingRecordAsync() Assert.Equal(record.DescriptionEmbedding!.Value.ToArray(), getResult.DescriptionEmbedding.Value.ToArray()); } + [Fact] + public async Task ItCanReadManuallyInsertedRecordAsync() + { + const string CollectionName = "ItCanReadManuallyInsertedRecordAsync"; + // Arrange + var sut = fixture.GetCollection>(CollectionName); + await sut.CreateCollectionAsync().ConfigureAwait(true); + Assert.True(await sut.CollectionExistsAsync().ConfigureAwait(true)); + await using (var connection = fixture.GetConnection()) + { + using NpgsqlCommand cmd = connection.CreateCommand(); + cmd.CommandText = @$" + INSERT INTO public.""{CollectionName}"" ( + ""HotelId"", ""HotelName"", ""HotelCode"", ""HotelRating"", ""parking_is_included"", ""Tags"", ""Description"", ""DescriptionEmbedding"" + ) VALUES ( + 215, 'Devine Lorraine', 215, 5, false, ARRAY['historic', 'philly'], 'An iconic building on broad street', '[10,20,30,40]' + );"; + await cmd.ExecuteNonQueryAsync().ConfigureAwait(true); + } + + // Act + var getResult = await sut.GetAsync(215, new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(getResult); + Assert.Equal(215, getResult!.HotelId); + Assert.Equal("Devine Lorraine", getResult.HotelName); + Assert.Equal(215, getResult.HotelCode); + Assert.Equal(5, getResult.HotelRating); + Assert.False(getResult.ParkingIncluded); + Assert.Equal(new List { "historic", "philly" }, getResult.Tags); + Assert.Equal("An iconic building on broad street", getResult.Description); + Assert.Equal([10f, 20f, 30f, 40f], getResult.DescriptionEmbedding!.Value.ToArray()); + } + [Fact] public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() { From b0370751512481c0441a59dbf582117c86e9194d Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 1 Nov 2024 12:58:11 -0400 Subject: [PATCH 61/62] Formatting, spelling --- .../Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs | 4 ++-- .../Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs | 2 +- .../Postgres/PostgresVectorStoreRecordCollectionTests.cs | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs index a89f5fab2c12..27fa7181bdc5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreUtils.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Generic; @@ -56,4 +56,4 @@ public static async IAsyncEnumerable WrapAsyncEnumerableAsync(IAsyncEnumer }; } } -} \ No newline at end of file +} diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs index e83e43f6963d..b11d6a81963f 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreTests.cs @@ -94,7 +94,7 @@ public async Task ListCollectionNamesCallsSDKAsync() } [Fact] - public async Task ListCollectionNamesThrowsCorrectExcpetionAsync() + public async Task ListCollectionNamesThrowsCorrectExceptionAsync() { // Arrange var expectedCollections = new List { "fake-collection-1", "fake-collection-2", "fake-collection-3" }; diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index ef25dfd6d439..7e3ae3ad9392 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -236,7 +236,7 @@ public async Task ItCanReadManuallyInsertedRecordAsync() INSERT INTO public.""{CollectionName}"" ( ""HotelId"", ""HotelName"", ""HotelCode"", ""HotelRating"", ""parking_is_included"", ""Tags"", ""Description"", ""DescriptionEmbedding"" ) VALUES ( - 215, 'Devine Lorraine', 215, 5, false, ARRAY['historic', 'philly'], 'An iconic building on broad street', '[10,20,30,40]' + 215, 'Divine Lorraine', 215, 5, false, ARRAY['historic', 'philly'], 'An iconic building on broad street', '[10,20,30,40]' );"; await cmd.ExecuteNonQueryAsync().ConfigureAwait(true); } @@ -247,7 +247,7 @@ public async Task ItCanReadManuallyInsertedRecordAsync() // Assert Assert.NotNull(getResult); Assert.Equal(215, getResult!.HotelId); - Assert.Equal("Devine Lorraine", getResult.HotelName); + Assert.Equal("Divine Lorraine", getResult.HotelName); Assert.Equal(215, getResult.HotelCode); Assert.Equal(5, getResult.HotelRating); Assert.False(getResult.ParkingIncluded); From c2937e0cba65ee5e9a88bb8e0e57af0946cb0b63 Mon Sep 17 00:00:00 2001 From: Rob Emanuele Date: Fri, 1 Nov 2024 17:04:26 -0400 Subject: [PATCH 62/62] Fix test. --- .../PostgresGenericDataModelMapperTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs index 79b7eeae82eb..d9e97fc6b855 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresGenericDataModelMapperTests.cs @@ -52,7 +52,7 @@ public void MapFromDataToStorageModelWithNumericKeyReturnsValidStorageModel() var result = mapper.MapFromDataToStorageModel(dataModel); // Assert - Assert.Equal("key", result["Key"]); + Assert.Equal(1, result["Key"]); Assert.Equal("Value1", result["StringProperty"]); Assert.Equal(5, result["IntProperty"]);