< Summary

Information
Class: GistBackend.Handlers.ChromaDbHandler.ChromaDbHandler
Assembly: GistBackend
File(s): /home/runner/work/the-gist-of-it-sec/the-gist-of-it-sec/backend/GistBackend/Handlers/ChromaDbHandler/ChromaDbHandler.cs
Line coverage
89%
Covered lines: 175
Uncovered lines: 21
Coverable lines: 196
Total lines: 350
Line coverage: 89.2%
Branch coverage
67%
Covered branches: 53
Total branches: 78
Branch coverage: 67.9%
Method coverage

Feature is only available for sponsors

Upgrade to PRO version

Metrics

File(s)

/home/runner/work/the-gist-of-it-sec/the-gist-of-it-sec/backend/GistBackend/Handlers/ChromaDbHandler/ChromaDbHandler.cs

#  Line  Line coverage
 1using System.Net;
 2using System.Text;
 3using System.Text.Json;
 4using GistBackend.Exceptions;
 5using GistBackend.Handlers.AIHandler;
 6using GistBackend.Types;
 7using GistBackend.Utils;
 8using Microsoft.Extensions.Logging;
 9using Microsoft.Extensions.Options;
 10using static GistBackend.Utils.LogEvents;
 11
 12namespace GistBackend.Handlers.ChromaDbHandler;
 13
/// <summary>
/// Operations for storing gist embeddings in ChromaDB and querying them by similarity.
/// </summary>
public interface IChromaDbHandler
{
    /// <summary>
    /// Stores (or updates, if the reference is already present) the embedding of
    /// <paramref name="summary"/> for the given feed entry.
    /// </summary>
    /// <param name="entry">Feed entry whose reference identifies the stored document.</param>
    /// <param name="summary">Text from which the embedding is generated.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    Task UpsertEntryAsync(
        RssEntry entry,
        string summary,
        CancellationToken ct);

    /// <summary>
    /// Makes sure the metadata stored for <paramref name="gist"/> matches its feed id and the
    /// requested <paramref name="disabled"/> flag, updating the stored metadata when it differs.
    /// </summary>
    /// <param name="gist">Gist whose stored metadata is checked.</param>
    /// <param name="disabled">Desired value of the stored "disabled" flag.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    /// <returns>
    /// <c>true</c> when the stored metadata was already correct; <c>false</c> when it had to be updated.
    /// </returns>
    Task<bool> EnsureGistHasCorrectMetadataAsync(
        Gist gist,
        bool disabled,
        CancellationToken ct);

    /// <summary>
    /// Finds entries similar to the already stored entry identified by <paramref name="reference"/>.
    /// </summary>
    /// <param name="reference">Reference of the stored entry to compare against.</param>
    /// <param name="nResults">Requested number of similar entries.</param>
    /// <param name="disabledFeedIds">Ids of feeds whose entries must not appear in the result.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    /// <returns>References of similar entries together with their similarity score.</returns>
    Task<List<SimilarDocument>> GetReferenceAndScoreOfSimilarEntriesAsync(
        string reference,
        int nResults,
        IEnumerable<int> disabledFeedIds,
        CancellationToken ct);

    /// <summary>
    /// Finds entries similar to a free-text <paramref name="query"/>.
    /// </summary>
    /// <param name="query">Text that is embedded and used as the search vector.</param>
    /// <param name="nResults">Requested number of similar entries.</param>
    /// <param name="disabledFeedIds">Ids of feeds whose entries must not appear in the result.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    /// <returns>References of similar entries together with their similarity score.</returns>
    Task<List<SimilarDocument>> SearchSimilarEntriesByQueryAsync(
        string query,
        int nResults,
        IEnumerable<int> disabledFeedIds,
        CancellationToken ct);
}
 23
/// <summary>
/// <see cref="IChromaDbHandler"/> implementation that talks to a ChromaDB server through its
/// HTTP API (<c>/api/v2</c>). Tenant, database and collection are created on demand before
/// each operation.
/// </summary>
public class ChromaDbHandler : IChromaDbHandler
{
    /// <summary>Exclusive upper bound for the length of references and search queries.</summary>
    private const int MaxTextLength = 1000000;

    /// <summary>Fields requested from ChromaDB when running a similarity query.</summary>
    private static readonly string[] IncludeOnGet = ["metadatas", "distances"];

    private readonly Uri _chromaDbUri;
    private readonly string _tenantName;
    private readonly string _databaseName;
    private readonly string _collectionName;
    private readonly IAIHandler _aiHandler;
    private readonly HttpClient _httpClient;
    private readonly string _credentialsHeaderName;
    private readonly string _serverAuthnCredentials;
    private readonly ILogger<ChromaDbHandler>? _logger;

    /// <summary>
    /// Creates a handler from the supplied options.
    /// </summary>
    /// <param name="aiHandler">Used to generate embeddings for summaries and search queries.</param>
    /// <param name="httpClient">Client used for all requests to the ChromaDB server.</param>
    /// <param name="options">Server address, credentials and tenant/database/collection names.</param>
    /// <param name="logger">Optional logger; nothing is logged when it is <c>null</c>.</param>
    /// <exception cref="ArgumentException">
    /// The server or the server authentication credentials are missing from the options.
    /// </exception>
    public ChromaDbHandler(IAIHandler aiHandler,
        HttpClient httpClient,
        IOptions<ChromaDbHandlerOptions> options,
        ILogger<ChromaDbHandler>? logger)
    {
        var settings = options.Value;
        if (string.IsNullOrWhiteSpace(settings.Server))
            throw new ArgumentException("Server is not set in the options.");
        if (string.IsNullOrWhiteSpace(settings.ServerAuthnCredentials))
            throw new ArgumentException("Server authentication credentials are not set in the options.");

        _aiHandler = aiHandler;
        _httpClient = httpClient;
        _logger = logger;
        _chromaDbUri = new Uri($"http://{settings.Server}:{settings.Port}/");
        _credentialsHeaderName = settings.CredentialsHeaderName;
        _serverAuthnCredentials = settings.ServerAuthnCredentials;
        _tenantName = settings.GistsTenantName;
        _databaseName = settings.GistsDatabaseName;
        _collectionName = settings.GistsCollectionName;
    }

    /// <summary>
    /// Finds entries similar to the stored entry identified by <paramref name="reference"/>,
    /// using that entry's stored embedding as the query vector.
    /// </summary>
    /// <param name="reference">Reference of the stored entry to compare against.</param>
    /// <param name="nResults">Maximum number of similar entries to return.</param>
    /// <param name="disabledFeedIds">Ids of feeds whose entries are excluded from the result.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    /// <returns>At most <paramref name="nResults"/> similar entries, never the reference itself.</returns>
    /// <exception cref="ArgumentException">The reference is empty or too long.</exception>
    /// <exception cref="DatabaseOperationException">
    /// The entry does not exist or a request to ChromaDB failed.
    /// </exception>
    public async Task<List<SimilarDocument>> GetReferenceAndScoreOfSimilarEntriesAsync(string reference,
        int nResults, IEnumerable<int> disabledFeedIds, CancellationToken ct)
    {
        ValidateReference(reference);
        var collectionId = await GetOrCreateCollectionAsync(ct);
        if (!await EntryExistsByReferenceAsync(reference, ct, collectionId))
        {
            throw new DatabaseOperationException("Entry does not exist in database");
        }

        var document = await GetDocumentByReferenceAsync(reference, collectionId, true, false, ct);
        var similarDocuments = await GetSimilarDocumentsByEmbeddingsAsync(
            document.Embeddings!.Single(),
            nResults + 1,  // +1 to exclude the original entry
            disabledFeedIds,
            ct
        );

        // Exclude the original entry from the results. The original is not guaranteed to be
        // part of the query result (the where-filter drops it when it is disabled or belongs
        // to a disabled feed), so cap the list to keep the nResults contract in that case too.
        return similarDocuments
            .Where(similarDocument => similarDocument.Reference != reference)
            .Take(nResults)
            .ToList();
    }

    /// <summary>
    /// Finds entries similar to a free-text query by embedding the query and searching with it.
    /// </summary>
    /// <param name="query">Text to search for.</param>
    /// <param name="nResults">Number of results requested from ChromaDB.</param>
    /// <param name="disabledFeedIds">Ids of feeds whose entries are excluded from the result.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    /// <returns>Similar entries with their similarity score.</returns>
    /// <exception cref="ArgumentException">The query is empty or too long.</exception>
    /// <exception cref="DatabaseOperationException">A request to ChromaDB failed.</exception>
    public async Task<List<SimilarDocument>> SearchSimilarEntriesByQueryAsync(string query, int nResults,
        IEnumerable<int> disabledFeedIds, CancellationToken ct)
    {
        ValidateSearchQuery(query);
        var embeddings = await _aiHandler.GenerateEmbeddingAsync(query, ct);
        return await GetSimilarDocumentsByEmbeddingsAsync(embeddings, nResults, disabledFeedIds, ct);
    }

    /// <summary>
    /// Runs a similarity query against the collection with the given embedding vector.
    /// </summary>
    /// <exception cref="DatabaseOperationException">
    /// ChromaDB did not answer with 200 OK or the response body could not be deserialized.
    /// </exception>
    private async Task<List<SimilarDocument>> GetSimilarDocumentsByEmbeddingsAsync(float[] embeddings, int nResults,
        IEnumerable<int> disabledFeedIds, CancellationToken ct)
    {
        var collectionId = await GetOrCreateCollectionAsync(ct);
        var content = CreateStringContent(new {
            QueryEmbeddings = new[] {embeddings},
            NResults = nResults,
            Where = GenerateWhere(disabledFeedIds),
            Include = IncludeOnGet
        });
        var response = await SendPostRequestAsync(
            $"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionId}/query", content, ct);
        if (response.StatusCode != HttpStatusCode.OK)
        {
            throw await CreateDatabaseOperationExceptionAsync("Could not query similar entries", response, ct);
        }

        var responseContent = await response.Content.ReadAsStreamAsync(ct);
        var queryResponse =
            await JsonSerializer.DeserializeAsync<QueryResponse>(responseContent, SerializerDefaults.JsonOptions, ct);
        return queryResponse is null
            ? throw new DatabaseOperationException("Could not get similar entries")
            : ExtractReferencesAndScores(queryResponse);
    }

    /// <summary>
    /// Builds the metadata filter for similarity queries: documents whose "disabled" flag is
    /// true are always excluded; when <paramref name="disabledFeedIds"/> is not empty,
    /// documents whose "feed_id" is in that set are excluded as well.
    /// </summary>
    private static Dictionary<string, object> GenerateWhere(IEnumerable<int> disabledFeedIds)
    {
        var whereNotDisabled = new Dictionary<string, object> {
            { "disabled", new Dictionary<string, object> { { "$ne", true } } }
        };
        var disabledFeedIdsArray = disabledFeedIds.ToArray();
        if (disabledFeedIdsArray.Length == 0)
        {
            return whereNotDisabled;
        }

        var whereNotInDisabledFeeds = new Dictionary<string, object> {
            { "feed_id", new Dictionary<string, object> { { "$nin", disabledFeedIdsArray } } }
        };
        return new Dictionary<string, object> {
            { "$and", new[] {
                whereNotDisabled,
                whereNotInDisabledFeeds
            } }
        };
    }

    /// <summary>
    /// Pairs the reference from each returned metadata entry with the similarity derived from
    /// its distance. Only the first result list is read, because a single embedding is queried.
    /// </summary>
    private static List<SimilarDocument> ExtractReferencesAndScores(QueryResponse queryResponse) =>
        Enumerable.Range(0, queryResponse.Ids.First().Length).Select(i =>
            new SimilarDocument(
                queryResponse.Metadatas.First()[i].Reference,
                ConvertCosineDistanceToSimilarity(queryResponse.Distances.First()[i])
            ))
            .ToList();

    /// <summary>
    /// Maps a distance <c>d</c> to the similarity <c>1 - d / 2</c>, clamped to the range [0, 1].
    /// </summary>
    private static float ConvertCosineDistanceToSimilarity(float distance) => float.Clamp(1 - distance/2, 0, 1);

    /// <summary>
    /// Adds the entry to the collection, or updates it when a document with the same reference
    /// already exists. The embedding is generated from <paramref name="summary"/>.
    /// </summary>
    /// <param name="entry">Feed entry; its reference is the document id and is stored in the metadata.</param>
    /// <param name="summary">Text from which the embedding is generated.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    /// <exception cref="ArgumentException">The entry reference is empty or too long.</exception>
    /// <exception cref="DatabaseOperationException">ChromaDB rejected the add or update request.</exception>
    public async Task UpsertEntryAsync(RssEntry entry, string summary, CancellationToken ct)
    {
        ValidateReference(entry.Reference);
        var collectionId = await GetOrCreateCollectionAsync(ct);
        var mode = "add";
        if (await EntryExistsByReferenceAsync(entry.Reference, ct, collectionId))
        {
            _logger?.LogInformation(EntryAlreadyExistsInChromaDb,
                "Entry with reference {Reference} already exists in database", entry.Reference);
            mode = "update";
        }

        var metadata = new Metadata(entry.Reference, entry.FeedId);
        var embedding = await _aiHandler.GenerateEmbeddingAsync(summary, ct);
        var content = CreateStringContent(new Document([entry.Reference], [metadata], [embedding]));
        var response = await SendPostRequestAsync(
            $"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionId}/{mode}", content, ct);

        // "add" is expected to answer 201 Created, "update" 200 OK.
        if ((mode == "add" && response.StatusCode != HttpStatusCode.Created) ||
            (mode == "update" && response.StatusCode != HttpStatusCode.OK))
        {
            throw await CreateDatabaseOperationExceptionAsync($"Could not {mode} entry", response, ct);
        }
        _logger?.LogInformation(DocumentInserted,
            "Upserted ({Mode}) document with metadata {Metadata} for entry with reference {Reference}",
            mode, metadata, entry.Reference);
    }

    /// <summary>
    /// Compares the stored metadata of <paramref name="gist"/> with its feed id and the requested
    /// <paramref name="disabled"/> flag and rewrites the stored metadata when they differ.
    /// </summary>
    /// <param name="gist">Gist whose stored metadata is checked.</param>
    /// <param name="disabled">Desired value of the stored "disabled" flag.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    /// <returns><c>true</c> when nothing had to change; <c>false</c> when the metadata was updated.</returns>
    /// <exception cref="ArgumentException">The gist reference is empty or too long.</exception>
    /// <exception cref="DatabaseOperationException">
    /// No metadata is stored for the gist or a request to ChromaDB failed.
    /// </exception>
    public async Task<bool> EnsureGistHasCorrectMetadataAsync(Gist gist, bool disabled, CancellationToken ct)
    {
        ValidateReference(gist.Reference);
        var collectionId = await GetOrCreateCollectionAsync(ct);
        var document = await GetDocumentByReferenceAsync(gist.Reference, collectionId, false, true, ct);
        var oldMetadata = document.Metadatas.FirstOrDefault();
        if (oldMetadata is null)
        {
            throw new DatabaseOperationException($"Entry with reference {gist.Reference} does not exist in ChromaDb");
        }
        if (oldMetadata.Disabled == disabled && oldMetadata.FeedId == gist.FeedId) return true;
        var newMetaData = new Metadata(gist.Reference, gist.FeedId, disabled);
        await UpdateMetadataAsync(gist.Reference, newMetaData, ct);
        _logger?.LogInformation(ChangedMetadataOfGistInChromaDb,
            "Changed metadata from {OldMetadata} to {NewMetadata} for gist with reference {GistReference}",
            oldMetadata, newMetaData, gist.Reference);
        return false;
    }

    /// <summary>
    /// Replaces the metadata of an existing document, leaving its embedding untouched.
    /// </summary>
    /// <exception cref="DatabaseOperationException">
    /// The document does not exist or ChromaDB did not answer with 200 OK.
    /// </exception>
    private async Task UpdateMetadataAsync(string reference, Metadata metadata, CancellationToken ct)
    {
        ValidateReference(reference);
        var collectionId = await GetOrCreateCollectionAsync(ct);
        if (!await EntryExistsByReferenceAsync(reference, ct, collectionId))
        {
            throw new DatabaseOperationException("Entry to update does not exist");
        }
        var content = CreateStringContent(new Document([reference], [metadata]));
        var response = await SendPostRequestAsync(
            $"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionId}/update", content, ct);

        if (response.StatusCode != HttpStatusCode.OK)
        {
            throw await CreateDatabaseOperationExceptionAsync("Could not update entry", response, ct);
        }
    }

    /// <summary>
    /// Checks whether a document with the given reference is stored in the collection.
    /// </summary>
    /// <param name="reference">Reference (document id) to look up.</param>
    /// <param name="ct">Token used to cancel the operation.</param>
    /// <param name="collectionId">
    /// Id of the collection to search; resolved (and created if missing) when <c>null</c>.
    /// </param>
    /// <returns><c>true</c> when ChromaDB returns at least one id for the reference.</returns>
    public async Task<bool> EntryExistsByReferenceAsync(string reference, CancellationToken ct, string? collectionId = null)
    {
        collectionId ??= await GetOrCreateCollectionAsync(ct);
        var document = await GetDocumentByReferenceAsync(reference, collectionId, false, false, ct);
        return document.Ids.Length != 0;
    }

    /// <summary>
    /// Fetches the document with the given reference, optionally including embeddings and metadata.
    /// </summary>
    /// <exception cref="DatabaseOperationException">
    /// ChromaDB answered with a non-success status, the body could not be deserialized, or
    /// embeddings were requested but are missing from the response.
    /// </exception>
    private async Task<Document> GetDocumentByReferenceAsync(string reference, string collectionId,
        bool includeEmbeddings, bool includeMetadata, CancellationToken ct)
    {
        var include = new List<string>();
        if (includeEmbeddings) include.Add("embeddings");
        if (includeMetadata) include.Add("metadatas");
        var content = CreateStringContent(new { Ids = new[] { reference }, Include = include });
        var response = await SendPostRequestAsync(
            $"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionId}/get", content, ct);

        // Check the status before touching the body: an error payload must not be
        // deserialized as if it were a (possibly half-populated) document.
        if (!response.IsSuccessStatusCode)
        {
            throw await CreateDatabaseOperationExceptionAsync("Could not get entry", response, ct);
        }

        var responseContent = await response.Content.ReadAsStreamAsync(ct);
        var document =
            await JsonSerializer.DeserializeAsync<Document>(responseContent, SerializerDefaults.JsonOptions, ct);
        if (document is null || (includeEmbeddings && document.Embeddings is null))
        {
            throw await CreateDatabaseOperationExceptionAsync("Could not get entry", response, ct);
        }
        return document;
    }

    /// <summary>
    /// Returns the id of the configured collection, creating tenant, database and collection
    /// first when they do not exist yet.
    /// </summary>
    /// <exception cref="DatabaseOperationException">The collection could not be created.</exception>
    private async Task<string> GetOrCreateCollectionAsync(CancellationToken ct)
    {
        await CreateDatabaseIfNotExistsAsync(ct);
        var existingCollectionId = await GetCollectionIdAsync(_collectionName, ct);
        if (existingCollectionId is not null) return existingCollectionId;

        var requestContent = CreateStringContent(new CollectionDefinition(_collectionName));
        var response =
            await SendPostRequestAsync($"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections",
                requestContent, ct);

        if (response.StatusCode != HttpStatusCode.OK)
        {
            throw await CreateDatabaseOperationExceptionAsync("Could not create collection", response, ct);
        }
        return await ExtractCollectionIdAsync(response, ct);
    }

    /// <summary>
    /// Looks up the id of the collection with the given name.
    /// </summary>
    /// <returns>The collection id, or <c>null</c> when ChromaDB answers 404 Not Found.</returns>
    /// <exception cref="DatabaseOperationException">
    /// ChromaDB answered with a non-success status other than 404, or the body could not be deserialized.
    /// </exception>
    private async Task<string?> GetCollectionIdAsync(string collectionName, CancellationToken ct)
    {
        var response = await SendGetRequestAsync(
            $"api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionName}", ct);
        if (response.StatusCode == HttpStatusCode.NotFound) return null;

        // Any other failure (authentication, server error, ...) is a real error and must not
        // be parsed as a collection.
        if (!response.IsSuccessStatusCode)
        {
            throw await CreateDatabaseOperationExceptionAsync("Could not get collection", response, ct);
        }
        return await ExtractCollectionIdAsync(response, ct);
    }

    /// <summary>Deserializes a collection from the response body and returns its id.</summary>
    /// <exception cref="DatabaseOperationException">The body deserialized to <c>null</c>.</exception>
    private static async Task<string> ExtractCollectionIdAsync(HttpResponseMessage response, CancellationToken ct)
    {
        var content = await response.Content.ReadAsStreamAsync(ct);
        var collection =
            await JsonSerializer.DeserializeAsync<Collection>(content, SerializerDefaults.JsonOptions, ct);
        if (collection is null)
        {
            throw await CreateDatabaseOperationExceptionAsync("Could not extract collection ID", response, ct);
        }
        return collection.Id;
    }

    /// <summary>Creates the configured database (and its tenant) unless it already exists.</summary>
    /// <exception cref="DatabaseOperationException">The database could not be created.</exception>
    private async Task CreateDatabaseIfNotExistsAsync(CancellationToken ct)
    {
        await CreateTenantIfNotExistsAsync(ct);
        if (await DatabaseExistsAsync(ct)) return;
        var content = CreateStringContent(new { Name = _databaseName });
        var response = await SendPostRequestAsync($"/api/v2/tenants/{_tenantName}/databases", content, ct);
        if (response.StatusCode != HttpStatusCode.OK)
        {
            throw await CreateDatabaseOperationExceptionAsync("Could not create database", response, ct);
        }
    }

    /// <summary>Returns <c>true</c> when a GET for the configured database answers 200 OK.</summary>
    private async Task<bool> DatabaseExistsAsync(CancellationToken ct)
    {
        var response = await SendGetRequestAsync($"api/v2/tenants/{_tenantName}/databases/{_databaseName}", ct);
        return response.StatusCode == HttpStatusCode.OK;
    }

    /// <summary>Creates the configured tenant unless it already exists.</summary>
    /// <exception cref="DatabaseOperationException">The tenant could not be created.</exception>
    private async Task CreateTenantIfNotExistsAsync(CancellationToken ct)
    {
        if (await TenantExistsAsync(ct)) return;
        var content = CreateStringContent(new { Name = _tenantName });
        var response = await SendPostRequestAsync("/api/v2/tenants", content, ct);
        if (response.StatusCode != HttpStatusCode.OK)
        {
            throw await CreateDatabaseOperationExceptionAsync("Could not create tenant", response, ct);
        }
    }

    /// <summary>Returns <c>true</c> when a GET for the configured tenant answers 200 OK.</summary>
    private async Task<bool> TenantExistsAsync(CancellationToken ct)
    {
        var response = await SendGetRequestAsync($"api/v2/tenants/{_tenantName}", ct);
        return response.StatusCode == HttpStatusCode.OK;
    }

    /// <summary>Sends an authenticated GET request to the given path on the ChromaDB server.</summary>
    private Task<HttpResponseMessage> SendGetRequestAsync(string relativeUri, CancellationToken ct) =>
        SendRequestAsync(HttpMethod.Get, relativeUri, ct);

    /// <summary>Sends an authenticated POST request with the given body to the ChromaDB server.</summary>
    private Task<HttpResponseMessage> SendPostRequestAsync(string relativeUri, HttpContent content,
        CancellationToken ct) => SendRequestAsync(HttpMethod.Post, relativeUri, ct, content);

    /// <summary>
    /// Resolves <paramref name="relativeUri"/> against the server base address and sends the request.
    /// </summary>
    private async Task<HttpResponseMessage> SendRequestAsync(HttpMethod method, string relativeUri,
        CancellationToken ct, HttpContent? content = null)
    {
        var uri = new Uri(_chromaDbUri, relativeUri);
        var request = CreateHttpRequestMessage(method, uri, content);
        return await _httpClient.SendAsync(request, ct);
    }

    /// <summary>Serializes the object to JSON with the shared serializer options (UTF-8, application/json).</summary>
    private static StringContent CreateStringContent(object objectToSerialize) =>
        new(JsonSerializer.Serialize(objectToSerialize, SerializerDefaults.JsonOptions), Encoding.UTF8,
            "application/json");

    /// <summary>Builds a request that carries the configured credentials header.</summary>
    private HttpRequestMessage CreateHttpRequestMessage(HttpMethod method, Uri uri, HttpContent? content = null)
    {
        var request = new HttpRequestMessage(method, uri);
        request.Headers.Add(_credentialsHeaderName, _serverAuthnCredentials);
        request.Content = content;
        return request;
    }

    /// <summary>
    /// Builds a <see cref="DatabaseOperationException"/> whose message contains the status code
    /// and the body of the failed response.
    /// </summary>
    private static async Task<DatabaseOperationException> CreateDatabaseOperationExceptionAsync(string message,
        HttpResponseMessage response, CancellationToken ct)
    {
        var responseContent = await response.Content.ReadAsStringAsync(ct);
        return new DatabaseOperationException($"{message}. Code: {response.StatusCode}. Response: {responseContent}");
    }

    /// <summary>Rejects references that are empty or have <see cref="MaxTextLength"/> or more characters.</summary>
    /// <exception cref="ArgumentException">The reference is invalid.</exception>
    private static void ValidateReference(string reference)
    {
        if (reference.Length is 0 or >= MaxTextLength) throw new ArgumentException("Reference is invalid.");
    }

    /// <summary>Rejects queries that are empty or have <see cref="MaxTextLength"/> or more characters.</summary>
    /// <exception cref="ArgumentException">The query is invalid.</exception>
    private static void ValidateSearchQuery(string query)
    {
        if (query.Length is 0 or >= MaxTextLength) throw new ArgumentException("Query is invalid.");
    }
}