< Summary

Information
Class: GistBackend.Handlers.ChromaDbHandler.ChromaDbHandler
Assembly: GistBackend
File(s): /home/runner/work/the-gist-of-it-sec/the-gist-of-it-sec/backend/GistBackend/Handlers/ChromaDbHandler/ChromaDbHandler.cs
Line coverage
91%
Covered lines: 166
Uncovered lines: 15
Coverable lines: 181
Total lines: 322
Line coverage: 91.7%
Branch coverage
73%
Covered branches: 53
Total branches: 72
Branch coverage: 73.6%
Method coverage

Feature is only available for sponsors

Upgrade to PRO version

Metrics

File(s)

/home/runner/work/the-gist-of-it-sec/the-gist-of-it-sec/backend/GistBackend/Handlers/ChromaDbHandler/ChromaDbHandler.cs

# | Line | Line coverage
 1using System.Net;
 2using System.Text;
 3using System.Text.Json;
 4using GistBackend.Exceptions;
 5using GistBackend.Handlers.AIHandler;
 6using GistBackend.Types;
 7using GistBackend.Utils;
 8using Microsoft.Extensions.Logging;
 9using Microsoft.Extensions.Options;
 10using static GistBackend.Utils.LogEvents;
 11
 12namespace GistBackend.Handlers.ChromaDbHandler;
 13
 14public interface IChromaDbHandler
 15{
 16    Task UpsertEntryAsync(RssEntry entry, string summary, CancellationToken ct);
 17    Task<bool> EnsureGistHasCorrectMetadataAsync(Gist gist, bool disabled, CancellationToken ct);
 18    Task<List<SimilarDocument>> GetReferenceAndScoreOfSimilarEntriesAsync(
 19        string reference, int nResults, IEnumerable<int> disabledFeedIds, CancellationToken ct);
 20}
 21
 22public class ChromaDbHandler : IChromaDbHandler
 23{
 24    private readonly Uri _chromaDbUri;
 25    private readonly string _tenantName;
 26    private readonly string _databaseName;
 27    private readonly string _collectionName;
 28    private readonly IAIHandler _aiHandler;
 29    private readonly HttpClient _httpClient;
 30    private readonly string _credentialsHeaderName;
 31    private readonly string _serverAuthnCredentials;
 32    private readonly ILogger<ChromaDbHandler>? _logger;
 33
 2834    public ChromaDbHandler(IAIHandler aiHandler,
 2835        HttpClient httpClient,
 2836        IOptions<ChromaDbHandlerOptions> options,
 2837        ILogger<ChromaDbHandler>? logger)
 38    {
 2839        if (string.IsNullOrWhiteSpace(options.Value.Server))
 040            throw new ArgumentException("Server is not set in the options.");
 2841        if (string.IsNullOrWhiteSpace(options.Value.ServerAuthnCredentials))
 042            throw new ArgumentException("Server authentication credentials are not set in the options.");
 2843        _aiHandler = aiHandler;
 2844        _httpClient = httpClient;
 2845        _logger = logger;
 2846        _chromaDbUri = new Uri($"http://{options.Value.Server}:{options.Value.Port}/");
 2847        _credentialsHeaderName = options.Value.CredentialsHeaderName;
 2848        _serverAuthnCredentials = options.Value.ServerAuthnCredentials;
 2849        _tenantName = options.Value.GistsTenantName;
 2850        _databaseName = options.Value.GistsDatabaseName;
 2851        _collectionName = options.Value.GistsCollectionName;
 2852    }
 53
 154    private static readonly string[] IncludeOnGet = ["metadatas", "distances"];
 55
 56    public async Task<List<SimilarDocument>> GetReferenceAndScoreOfSimilarEntriesAsync(string reference,
 57        int nResults, IEnumerable<int> disabledFeedIds, CancellationToken ct)
 58    {
 659        ValidateReference(reference);
 660        var collectionId = await GetOrCreateCollectionAsync(ct);
 661        if (!await EntryExistsByReferenceAsync(reference, ct, collectionId))
 62        {
 063            throw new DatabaseOperationException("Entry does not exist in database");
 64        }
 65
 666        var document = await GetDocumentByReferenceAsync(reference, collectionId, true, false, ct);
 667        var content = CreateStringContent(new {
 668            QueryEmbeddings = new[] {document.Embeddings!.Single()},
 669            NResults = nResults+1, // +1 to exclude the original entry
 670            Where = GenerateWhere(disabledFeedIds),
 671            Include = IncludeOnGet
 672        });
 673        var response = await SendPostRequestAsync(
 674            $"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionId}/query", content, ct);
 675        if (response.StatusCode != HttpStatusCode.OK)
 76        {
 077            throw await CreateDatabaseOperationExceptionAsync("Could not query similar entries", response, ct);
 78        }
 79
 680        var responseContent = await response.Content.ReadAsStreamAsync(ct);
 681        var queryResponse =
 682            await JsonSerializer.DeserializeAsync<QueryResponse>(responseContent, SerializerDefaults.JsonOptions, ct);
 683        if (queryResponse is null) throw new DatabaseOperationException("Could not get similar entries");
 684        var referencesAndScores = ExtractReferencesAndScores(queryResponse);
 85
 86        // Exclude the original entry from the results
 2087        return referencesAndScores.Where(referenceAndScore => referenceAndScore.Reference != reference).ToList();
 688    }
 89
 90    private static Dictionary<string, object> GenerateWhere(IEnumerable<int> disabledFeedIds)
 91    {
 692        var whereNotDisabled = new Dictionary<string, object> {
 693            { "disabled", new Dictionary<string, object> { { "$ne", true } } }
 694        };
 695        var disabledFeedIdsArray = disabledFeedIds.ToArray();
 696        if (disabledFeedIdsArray.Length == 0)
 97        {
 498            return whereNotDisabled;
 99        }
 100
 2101        var whereNotInDisabledFeeds = new Dictionary<string, object> {
 2102            { "feed_id", new Dictionary<string, object> { { "$nin", disabledFeedIdsArray } } }
 2103        };
 2104        return new Dictionary<string, object> {
 2105            { "$and", new[] {
 2106                whereNotDisabled,
 2107                whereNotInDisabledFeeds
 2108            } }
 2109        };
 110    }
 111
 112    private static List<SimilarDocument> ExtractReferencesAndScores(QueryResponse queryResponse) =>
 6113        Enumerable.Range(0, queryResponse.Ids.First().Length).Select(i =>
 14114            new SimilarDocument(
 14115                queryResponse.Metadatas.First()[i].Reference,
 14116                ConvertCosineDistanceToSimilarity(queryResponse.Distances.First()[i])
 14117            ))
 6118            .ToList();
 119
 14120    private static float ConvertCosineDistanceToSimilarity(float distance) => float.Clamp(1 - distance/2, 0, 1);
 121
 122    public async Task UpsertEntryAsync(RssEntry entry, string summary, CancellationToken ct)
 123    {
 27124        ValidateReference(entry.Reference);
 25125        var collectionId = await GetOrCreateCollectionAsync(ct);
 25126        var mode = "add";
 25127        if (await EntryExistsByReferenceAsync(entry.Reference, ct, collectionId))
 128        {
 1129            _logger?.LogInformation(EntryAlreadyExistsInChromaDb,
 1130                "Entry with reference {Reference} already exists in database", entry.Reference);
 1131            mode = "update";
 132        }
 133
 25134        var metadata = new Metadata(entry.Reference, entry.FeedId);
 25135        var embedding = await _aiHandler.GenerateEmbeddingAsync(summary, ct);
 25136        var content = CreateStringContent(new Document([entry.Reference], [metadata], [embedding]));
 25137        var response = await SendPostRequestAsync(
 25138            $"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionId}/{mode}", content, ct);
 139
 25140        if (mode == "add" && response.StatusCode != HttpStatusCode.Created ||
 25141            mode == "update" && response.StatusCode != HttpStatusCode.OK)
 142        {
 0143            throw await CreateDatabaseOperationExceptionAsync($"Could not {mode} entry", response, ct);
 144        }
 25145        _logger?.LogInformation(DocumentInserted,
 25146            "Upserted ({Mode}) document with metadata {Metadata} for entry with reference {Reference}",
 25147            mode, metadata, entry.Reference);
 25148    }
 149
 150    public async Task<bool> EnsureGistHasCorrectMetadataAsync(Gist gist, bool disabled, CancellationToken ct)
 151    {
 13152        ValidateReference(gist.Reference);
 9153        var collectionId = await GetOrCreateCollectionAsync(ct);
 9154        var document = await GetDocumentByReferenceAsync(gist.Reference, collectionId, false, true, ct);
 9155        var oldMetadata = document.Metadatas.FirstOrDefault();
 9156        if (oldMetadata is null)
 157        {
 2158            throw new DatabaseOperationException($"Entry with reference {gist.Reference} does not exist in ChromaDb");
 159        }
 9160        if (oldMetadata.Disabled == disabled && oldMetadata.FeedId == gist.FeedId) return true;
 5161        var newMetaData = new Metadata(gist.Reference, gist.FeedId, disabled);
 5162        await UpdateMetadataAsync(gist.Reference, newMetaData, ct);
 5163        _logger?.LogInformation(ChangedMetadataOfGistInChromaDb,
 5164            "Changed metadata from {OldMetadata} to {NewMetadata} for gist with reference {GistReference}",
 5165            oldMetadata, newMetaData, gist.Reference);
 5166        return false;
 7167    }
 168
 169    private async Task UpdateMetadataAsync(string reference, Metadata metadata, CancellationToken ct)
 170    {
 5171        ValidateReference(reference);
 5172        var collectionId = await GetOrCreateCollectionAsync(ct);
 5173        if (!await EntryExistsByReferenceAsync(reference, ct, collectionId))
 174        {
 0175            throw new DatabaseOperationException("Entry to update does not exist");
 176        }
 5177        var content = CreateStringContent(new Document([reference], [metadata]));
 5178        var response = await SendPostRequestAsync(
 5179            $"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionId}/update", content, ct);
 180
 5181        if (response.StatusCode != HttpStatusCode.OK)
 182        {
 0183            throw await CreateDatabaseOperationExceptionAsync("Could not update entry", response, ct);
 184        }
 5185    }
 186
 187    public async Task<bool> EntryExistsByReferenceAsync(string reference, CancellationToken ct, string? collectionId = n
 188    {
 36189        collectionId ??= await GetOrCreateCollectionAsync(ct);
 36190        var document = await GetDocumentByReferenceAsync(reference, collectionId, false, false, ct);
 36191        return document.Ids.Length != 0;
 36192    }
 193
 194    private async Task<Document> GetDocumentByReferenceAsync(string reference, string collectionId,
 195        bool includeEmbeddings, bool includeMetadata, CancellationToken ct)
 196    {
 51197        var include = new List<string>();
 57198        if (includeEmbeddings) include.Add("embeddings");
 60199        if (includeMetadata) include.Add("metadatas");
 51200        var content = CreateStringContent(new { Ids = new[] { reference }, Include = include });
 51201        var response = await SendPostRequestAsync(
 51202            $"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionId}/get", content, ct);
 51203        var responseContent = await response.Content.ReadAsStreamAsync(ct);
 51204        var document =
 51205            await JsonSerializer.DeserializeAsync<Document>(responseContent, SerializerDefaults.JsonOptions, ct);
 51206        if (document is null || (includeEmbeddings && document.Embeddings is null))
 207        {
 0208            throw await CreateDatabaseOperationExceptionAsync("Could not get entry", response, ct);
 209        }
 51210        return document;
 51211    }
 212
 213    private async Task<string> GetOrCreateCollectionAsync(CancellationToken ct)
 214    {
 45215        await CreateDatabaseIfNotExistsAsync(ct);
 45216        var existingCollectionId = await GetCollectionIdAsync(_collectionName, ct);
 83217        if (existingCollectionId is not null) return existingCollectionId;
 218
 7219        var requestContent = CreateStringContent(new CollectionDefinition(_collectionName));
 7220        var response =
 7221            await SendPostRequestAsync($"/api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections",
 7222                requestContent, ct);
 223
 7224        if (response.StatusCode != HttpStatusCode.OK)
 225        {
 0226            throw await CreateDatabaseOperationExceptionAsync("Could not create collection", response, ct);
 227        }
 7228        return await ExtractCollectionIdAsync(response, ct);
 45229    }
 230
 231    private async Task<string?> GetCollectionIdAsync(string collectionName, CancellationToken ct)
 232    {
 45233        var response = await SendGetRequestAsync(
 45234            $"api/v2/tenants/{_tenantName}/databases/{_databaseName}/collections/{collectionName}", ct);
 45235        return response.StatusCode == HttpStatusCode.NotFound ? null : await ExtractCollectionIdAsync(response, ct);
 45236    }
 237
 238    private static async Task<string> ExtractCollectionIdAsync(HttpResponseMessage response, CancellationToken ct)
 239    {
 45240        var content = await response.Content.ReadAsStreamAsync(ct);
 45241        var collection =
 45242            await JsonSerializer.DeserializeAsync<Collection>(content, SerializerDefaults.JsonOptions, ct);
 45243        if (collection is null)
 244        {
 0245            throw await CreateDatabaseOperationExceptionAsync("Could not extract collection ID", response, ct);
 246        }
 45247        return collection.Id;
 45248    }
 249
 250    private async Task CreateDatabaseIfNotExistsAsync(CancellationToken ct)
 251    {
 45252        await CreateTenantIfNotExistsAsync(ct);
 87253        if (await DatabaseExistsAsync(ct)) return;
 3254        var content = CreateStringContent(new { Name = _databaseName });
 3255        var response = await SendPostRequestAsync($"/api/v2/tenants/{_tenantName}/databases", content, ct);
 3256        if (response.StatusCode != HttpStatusCode.OK)
 257        {
 0258            throw await CreateDatabaseOperationExceptionAsync("Could not create database", response, ct);
 259        }
 45260    }
 261
 262    private async Task<bool> DatabaseExistsAsync(CancellationToken ct)
 263    {
 45264        var response = await SendGetRequestAsync($"api/v2/tenants/{_tenantName}/databases/{_databaseName}", ct);
 45265        return response.StatusCode == HttpStatusCode.OK;
 45266    }
 267
 268    private async Task CreateTenantIfNotExistsAsync(CancellationToken ct)
 269    {
 87270        if (await TenantExistsAsync(ct)) return;
 3271        var content = CreateStringContent(new { Name = _tenantName });
 3272        var response = await SendPostRequestAsync("/api/v2/tenants", content, ct);
 3273        if (response.StatusCode != HttpStatusCode.OK)
 274        {
 0275            throw await CreateDatabaseOperationExceptionAsync("Could not create tenant", response, ct);
 276        }
 45277    }
 278
 279    private async Task<bool> TenantExistsAsync(CancellationToken ct)
 280    {
 45281        var response = await SendGetRequestAsync($"api/v2/tenants/{_tenantName}", ct);
 45282        return response.StatusCode == HttpStatusCode.OK;
 45283    }
 284
 285    private Task<HttpResponseMessage> SendGetRequestAsync(string relativeUri, CancellationToken ct) =>
 135286        SendRequestAsync(HttpMethod.Get, relativeUri, ct);
 287
 288    private Task<HttpResponseMessage> SendPostRequestAsync(string relativeUri, HttpContent content,
 100289        CancellationToken ct) => SendRequestAsync(HttpMethod.Post, relativeUri, ct, content);
 290
 291    private async Task<HttpResponseMessage> SendRequestAsync(HttpMethod method, string relativeUri,
 292        CancellationToken ct, HttpContent? content = null)
 293    {
 235294        var uri = new Uri(_chromaDbUri, relativeUri);
 235295        var request = CreateHttpRequestMessage(method, uri, content);
 235296        return await _httpClient.SendAsync(request, ct);
 235297    }
 298
 299    private static StringContent CreateStringContent(object objectToSerialize) =>
 100300        new(JsonSerializer.Serialize(objectToSerialize, SerializerDefaults.JsonOptions), Encoding.UTF8,
 100301            "application/json");
 302
 303    private HttpRequestMessage CreateHttpRequestMessage(HttpMethod method, Uri uri, HttpContent? content = null)
 304    {
 235305        var request = new HttpRequestMessage(method, uri);
 235306        request.Headers.Add(_credentialsHeaderName, _serverAuthnCredentials);
 235307        request.Content = content;
 235308        return request;
 309    }
 310
 311    private static async Task<DatabaseOperationException> CreateDatabaseOperationExceptionAsync(string message,
 312        HttpResponseMessage response, CancellationToken ct)
 313    {
 0314        var responseContent = await response.Content.ReadAsStringAsync(ct);
 0315        return new DatabaseOperationException($"{message}. Code: {response.StatusCode}. Response: {responseContent}");
 0316    }
 317
 318    private static void ValidateReference(string reference)
 319    {
 57320        if (reference.Length is 0 or >= 1000000) throw new ArgumentException("Reference is invalid.");
 45321    }
 322}