From 2d9bec5da97aa561ac68358c8205ab7a28e70297 Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Tue, 9 Jul 2024 10:15:26 -0700 Subject: [PATCH] [BUG] Make go sysdb return created flag. Respect created flag (#2476) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - In #1835 the sysdb client in python was made to not respect the `created` flag. This fixes the python code to do that, as well as updates the go code to appropriately return that flag. - New functionality - None ## Test plan *How are these changes tested?* Added a test in the go side to check that created is false when get_or_create'ing a collection that already exists - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes None --- chromadb/db/impl/grpc/client.py | 2 +- go/pkg/coordinator/apis.go | 10 +++--- go/pkg/coordinator/apis_test.go | 31 ++++++++++++++----- go/pkg/coordinator/grpc/collection_service.go | 3 +- go/pkg/metastore/catalog.go | 2 +- go/pkg/metastore/coordinator/table_catalog.go | 9 ++++-- .../coordinator/table_catalog_test.go | 2 +- 7 files changed, 40 insertions(+), 19 deletions(-) diff --git a/chromadb/db/impl/grpc/client.py b/chromadb/db/impl/grpc/client.py index 8ff3069ddd2..4b0fdecd877 100644 --- a/chromadb/db/impl/grpc/client.py +++ b/chromadb/db/impl/grpc/client.py @@ -222,7 +222,7 @@ def create_collection( if response.status.code == 409: raise UniqueConstraintError() collection = from_proto_collection(response.collection) - return collection, response.status.code == 200 + return collection, response.created @overrides def delete_collection( diff --git a/go/pkg/coordinator/apis.go b/go/pkg/coordinator/apis.go index 40ec49f5d50..a8340e43536 100644 --- a/go/pkg/coordinator/apis.go +++ b/go/pkg/coordinator/apis.go @@ -17,7 +17,7 @@ import ( type ICoordinator interface { common.Component ResetState(ctx context.Context) error - CreateCollection(ctx context.Context, createCollection *model.CreateCollection) (*model.Collection, error) + CreateCollection(ctx context.Context, createCollection *model.CreateCollection) (*model.Collection, bool, error) GetCollections(ctx context.Context, collectionID types.UniqueID, collectionName *string, tenantID string, dataName string, limit *int32, offset *int32) ([]*model.Collection, error) DeleteCollection(ctx context.Context, deleteCollection *model.DeleteCollection) error UpdateCollection(ctx context.Context, updateCollection *model.UpdateCollection) (*model.Collection, error) @@ -70,13 +70,13 @@ func (s *Coordinator) GetTenant(ctx context.Context, getTenant *model.GetTenant) return tenant, nil } -func (s *Coordinator) CreateCollection(ctx context.Context, createCollection *model.CreateCollection) (*model.Collection, error) { +func (s *Coordinator) CreateCollection(ctx context.Context, createCollection *model.CreateCollection) (*model.Collection, bool, error) { log.Info("create collection", zap.Any("createCollection", createCollection)) - collection, err := s.catalog.CreateCollection(ctx, createCollection, createCollection.Ts) + collection, created, err := s.catalog.CreateCollection(ctx, createCollection, createCollection.Ts) if err != nil { - return nil, err + return nil, false, err } - return collection, nil + return collection, created, nil } func (s *Coordinator) GetCollections(ctx context.Context, collectionID types.UniqueID, collectionName *string, tenantID string, databaseName string, limit *int32, offset *int32) ([]*model.Collection, error) { diff --git a/go/pkg/coordinator/apis_test.go b/go/pkg/coordinator/apis_test.go index f45a2fe6208..4803ce11a5a 100644 --- a/go/pkg/coordinator/apis_test.go +++ b/go/pkg/coordinator/apis_test.go @@ -56,7 +56,7 @@ func (suite *APIsTestSuite) SetupTest() { } suite.coordinator = c for _, collection := range suite.sampleCollections { - _, errCollectionCreation := c.CreateCollection(ctx, &model.CreateCollection{ + _, _, errCollectionCreation := c.CreateCollection(ctx, &model.CreateCollection{ ID: collection.ID, Name: collection.Name, Metadata: collection.Metadata, @@ -104,7 +104,7 @@ func testCollection(t *rapid.T) { } }).Draw(t, "collection") - _, err := c.CreateCollection(ctx, collection) + _, _, err := c.CreateCollection(ctx, collection) if err != nil { if err == common.ErrCollectionNameEmpty && collection.Name == "" { t.Logf("expected error for empty collection name") @@ -265,7 +265,7 @@ func (suite *APIsTestSuite) TestCreateGetDeleteCollections() { suite.Equal(suite.sampleCollections, results) // Duplicate create fails - _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ + _, _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ ID: suite.sampleCollections[0].ID, Name: suite.sampleCollections[0].Name, TenantID: suite.tenantName, @@ -363,6 +363,23 @@ func (suite *APIsTestSuite) TestUpdateCollections() { suite.Equal([]*model.Collection{coll}, resultList) } +func (suite *APIsTestSuite) TestGetOrCreateCollectionsTwice() { + // GetOrCreateCollection already existing collection returns false for created + ctx := context.Background() + coll := suite.sampleCollections[0] + _, created, err := suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ + ID: coll.ID, + Name: coll.Name, + Metadata: coll.Metadata, + Dimension: coll.Dimension, + GetOrCreate: true, + TenantID: coll.TenantID, + DatabaseName: coll.DatabaseName, + }) + suite.NoError(err) + suite.False(created) +} + func (suite *APIsTestSuite) TestCreateUpdateWithDatabase() { ctx := context.Background() newDatabaseName := "test_apis_CreateUpdateWithDatabase" @@ -376,7 +393,7 @@ func (suite *APIsTestSuite) TestCreateUpdateWithDatabase() { suite.sampleCollections[0].ID = types.NewUniqueID() suite.sampleCollections[0].Name = suite.sampleCollections[0].Name + "1" - _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ + _, _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ ID: suite.sampleCollections[0].ID, Name: suite.sampleCollections[0].Name, Metadata: suite.sampleCollections[0].Metadata, @@ -430,7 +447,7 @@ func (suite *APIsTestSuite) TestGetMultipleWithDatabase() { collection.Name = collection.Name + "1" collection.TenantID = suite.tenantName collection.DatabaseName = newDatabaseName - _, err := suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ + _, _, err := suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ ID: collection.ID, Name: collection.Name, Metadata: collection.Metadata, @@ -499,7 +516,7 @@ func (suite *APIsTestSuite) TestCreateDatabaseWithTenants() { // Create a new collection in the new tenant suite.sampleCollections[0].ID = types.NewUniqueID() suite.sampleCollections[0].Name = suite.sampleCollections[0].Name + "1" - _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ + _, _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ ID: suite.sampleCollections[0].ID, Name: suite.sampleCollections[0].Name, Metadata: suite.sampleCollections[0].Metadata, @@ -512,7 +529,7 @@ func (suite *APIsTestSuite) TestCreateDatabaseWithTenants() { // Create a new collection in the default tenant suite.sampleCollections[1].ID = types.NewUniqueID() suite.sampleCollections[1].Name = suite.sampleCollections[1].Name + "2" - _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ + _, _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{ ID: suite.sampleCollections[1].ID, Name: suite.sampleCollections[1].Name, Metadata: suite.sampleCollections[1].Metadata, diff --git a/go/pkg/coordinator/grpc/collection_service.go b/go/pkg/coordinator/grpc/collection_service.go index 0ea3bbe5ecf..19fd8591f4a 100644 --- a/go/pkg/coordinator/grpc/collection_service.go +++ b/go/pkg/coordinator/grpc/collection_service.go @@ -72,7 +72,7 @@ func (s *Server) CreateCollection(ctx context.Context, req *coordinatorpb.Create res.Status = failResponseWithError(err, successCode) return res, nil } - collection, err := s.coordinator.CreateCollection(ctx, createCollection) + collection, created, err := s.coordinator.CreateCollection(ctx, createCollection) if err != nil { log.Error("error creating collection", zap.Error(err)) res.Collection = &coordinatorpb.Collection{ @@ -94,6 +94,7 @@ func (s *Server) CreateCollection(ctx context.Context, req *coordinatorpb.Create } res.Collection = convertCollectionToProto(collection) res.Status = setResponseStatus(successCode) + res.Created = created return res, nil } diff --git a/go/pkg/metastore/catalog.go b/go/pkg/metastore/catalog.go index 3c0958974f5..961b13b1b86 100644 --- a/go/pkg/metastore/catalog.go +++ b/go/pkg/metastore/catalog.go @@ -14,7 +14,7 @@ import ( //go:generate mockery --name=Catalog type Catalog interface { ResetState(ctx context.Context) error - CreateCollection(ctx context.Context, createCollection *model.CreateCollection, ts types.Timestamp) (*model.Collection, error) + CreateCollection(ctx context.Context, createCollection *model.CreateCollection, ts types.Timestamp) (*model.Collection, bool, error) GetCollections(ctx context.Context, collectionID types.UniqueID, collectionName *string, tenantID string, databaseName string, limit *int32, offset *int32) ([]*model.Collection, error) DeleteCollection(ctx context.Context, deleteCollection *model.DeleteCollection) error UpdateCollection(ctx context.Context, updateCollection *model.UpdateCollection, ts types.Timestamp) (*model.Collection, error) diff --git a/go/pkg/metastore/coordinator/table_catalog.go b/go/pkg/metastore/coordinator/table_catalog.go index 3da3ac6dc78..3dac3e89e14 100644 --- a/go/pkg/metastore/coordinator/table_catalog.go +++ b/go/pkg/metastore/coordinator/table_catalog.go @@ -210,9 +210,10 @@ func (tc *Catalog) GetAllTenants(ctx context.Context, ts types.Timestamp) ([]*mo return result, nil } -func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model.CreateCollection, ts types.Timestamp) (*model.Collection, error) { +func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model.CreateCollection, ts types.Timestamp) (*model.Collection, bool, error) { var result *model.Collection + created := false err := tc.txImpl.Transaction(ctx, func(txCtx context.Context) error { // insert collection databaseName := createCollection.DatabaseName @@ -254,6 +255,8 @@ func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model } else { return common.ErrCollectionUniqueConstraintViolation } + } else { + created = true } dbCollection := &dbmodel.Collection{ @@ -301,10 +304,10 @@ func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model }) if err != nil { log.Error("error creating collection", zap.Error(err)) - return nil, err + return nil, false, err } log.Info("collection created", zap.Any("collection", result)) - return result, nil + return result, created, nil } func (tc *Catalog) GetCollections(ctx context.Context, collectionID types.UniqueID, collectionName *string, tenantID string, databaseName string, limit *int32, offset *int32) ([]*model.Collection, error) { diff --git a/go/pkg/metastore/coordinator/table_catalog_test.go b/go/pkg/metastore/coordinator/table_catalog_test.go index 7b2f54d8eac..e0359eb58ea 100644 --- a/go/pkg/metastore/coordinator/table_catalog_test.go +++ b/go/pkg/metastore/coordinator/table_catalog_test.go @@ -64,7 +64,7 @@ func TestCatalog_CreateCollection(t *testing.T) { }).Return(nil) // call the CreateCollection method - _, err := catalog.CreateCollection(context.Background(), collection, ts) + _, _, err := catalog.CreateCollection(context.Background(), collection, ts) // assert that the method returned no error assert.NoError(t, err)