Skip to content

Commit

Permalink
[BUG] Make go sysdb return created flag. Respect created flag (chroma…
Browse files Browse the repository at this point in the history
…-core#2476)

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- In chroma-core#1835 the sysdb client in python was made to not respect the
`created` flag. This fixes the python code to do that, as well as
updates the go code to appropriately return that flag.
 - New functionality
	 - None

## Test plan
*How are these changes tested?*
Added a test in the go side to check that created is false when
get_or_create'ing a collection that already exists
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Jul 9, 2024
1 parent f097ba1 commit 2d9bec5
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 19 deletions.
2 changes: 1 addition & 1 deletion chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def create_collection(
if response.status.code == 409:
raise UniqueConstraintError()
collection = from_proto_collection(response.collection)
return collection, response.status.code == 200
return collection, response.created

@overrides
def delete_collection(
Expand Down
10 changes: 5 additions & 5 deletions go/pkg/coordinator/apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
type ICoordinator interface {
common.Component
ResetState(ctx context.Context) error
CreateCollection(ctx context.Context, createCollection *model.CreateCollection) (*model.Collection, error)
CreateCollection(ctx context.Context, createCollection *model.CreateCollection) (*model.Collection, bool, error)
GetCollections(ctx context.Context, collectionID types.UniqueID, collectionName *string, tenantID string, dataName string, limit *int32, offset *int32) ([]*model.Collection, error)
DeleteCollection(ctx context.Context, deleteCollection *model.DeleteCollection) error
UpdateCollection(ctx context.Context, updateCollection *model.UpdateCollection) (*model.Collection, error)
Expand Down Expand Up @@ -70,13 +70,13 @@ func (s *Coordinator) GetTenant(ctx context.Context, getTenant *model.GetTenant)
return tenant, nil
}

func (s *Coordinator) CreateCollection(ctx context.Context, createCollection *model.CreateCollection) (*model.Collection, error) {
func (s *Coordinator) CreateCollection(ctx context.Context, createCollection *model.CreateCollection) (*model.Collection, bool, error) {
log.Info("create collection", zap.Any("createCollection", createCollection))
collection, err := s.catalog.CreateCollection(ctx, createCollection, createCollection.Ts)
collection, created, err := s.catalog.CreateCollection(ctx, createCollection, createCollection.Ts)
if err != nil {
return nil, err
return nil, false, err
}
return collection, nil
return collection, created, nil
}

func (s *Coordinator) GetCollections(ctx context.Context, collectionID types.UniqueID, collectionName *string, tenantID string, databaseName string, limit *int32, offset *int32) ([]*model.Collection, error) {
Expand Down
31 changes: 24 additions & 7 deletions go/pkg/coordinator/apis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (suite *APIsTestSuite) SetupTest() {
}
suite.coordinator = c
for _, collection := range suite.sampleCollections {
_, errCollectionCreation := c.CreateCollection(ctx, &model.CreateCollection{
_, _, errCollectionCreation := c.CreateCollection(ctx, &model.CreateCollection{
ID: collection.ID,
Name: collection.Name,
Metadata: collection.Metadata,
Expand Down Expand Up @@ -104,7 +104,7 @@ func testCollection(t *rapid.T) {
}
}).Draw(t, "collection")

_, err := c.CreateCollection(ctx, collection)
_, _, err := c.CreateCollection(ctx, collection)
if err != nil {
if err == common.ErrCollectionNameEmpty && collection.Name == "" {
t.Logf("expected error for empty collection name")
Expand Down Expand Up @@ -265,7 +265,7 @@ func (suite *APIsTestSuite) TestCreateGetDeleteCollections() {
suite.Equal(suite.sampleCollections, results)

// Duplicate create fails
_, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
_, _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
ID: suite.sampleCollections[0].ID,
Name: suite.sampleCollections[0].Name,
TenantID: suite.tenantName,
Expand Down Expand Up @@ -363,6 +363,23 @@ func (suite *APIsTestSuite) TestUpdateCollections() {
suite.Equal([]*model.Collection{coll}, resultList)
}

func (suite *APIsTestSuite) TestGetOrCreateCollectionsTwice() {
// GetOrCreateCollection already existing collection returns false for created
ctx := context.Background()
coll := suite.sampleCollections[0]
_, created, err := suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
ID: coll.ID,
Name: coll.Name,
Metadata: coll.Metadata,
Dimension: coll.Dimension,
GetOrCreate: true,
TenantID: coll.TenantID,
DatabaseName: coll.DatabaseName,
})
suite.NoError(err)
suite.False(created)
}

func (suite *APIsTestSuite) TestCreateUpdateWithDatabase() {
ctx := context.Background()
newDatabaseName := "test_apis_CreateUpdateWithDatabase"
Expand All @@ -376,7 +393,7 @@ func (suite *APIsTestSuite) TestCreateUpdateWithDatabase() {

suite.sampleCollections[0].ID = types.NewUniqueID()
suite.sampleCollections[0].Name = suite.sampleCollections[0].Name + "1"
_, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
_, _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
ID: suite.sampleCollections[0].ID,
Name: suite.sampleCollections[0].Name,
Metadata: suite.sampleCollections[0].Metadata,
Expand Down Expand Up @@ -430,7 +447,7 @@ func (suite *APIsTestSuite) TestGetMultipleWithDatabase() {
collection.Name = collection.Name + "1"
collection.TenantID = suite.tenantName
collection.DatabaseName = newDatabaseName
_, err := suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
_, _, err := suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
ID: collection.ID,
Name: collection.Name,
Metadata: collection.Metadata,
Expand Down Expand Up @@ -499,7 +516,7 @@ func (suite *APIsTestSuite) TestCreateDatabaseWithTenants() {
// Create a new collection in the new tenant
suite.sampleCollections[0].ID = types.NewUniqueID()
suite.sampleCollections[0].Name = suite.sampleCollections[0].Name + "1"
_, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
_, _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
ID: suite.sampleCollections[0].ID,
Name: suite.sampleCollections[0].Name,
Metadata: suite.sampleCollections[0].Metadata,
Expand All @@ -512,7 +529,7 @@ func (suite *APIsTestSuite) TestCreateDatabaseWithTenants() {
// Create a new collection in the default tenant
suite.sampleCollections[1].ID = types.NewUniqueID()
suite.sampleCollections[1].Name = suite.sampleCollections[1].Name + "2"
_, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
_, _, err = suite.coordinator.CreateCollection(ctx, &model.CreateCollection{
ID: suite.sampleCollections[1].ID,
Name: suite.sampleCollections[1].Name,
Metadata: suite.sampleCollections[1].Metadata,
Expand Down
3 changes: 2 additions & 1 deletion go/pkg/coordinator/grpc/collection_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (s *Server) CreateCollection(ctx context.Context, req *coordinatorpb.Create
res.Status = failResponseWithError(err, successCode)
return res, nil
}
collection, err := s.coordinator.CreateCollection(ctx, createCollection)
collection, created, err := s.coordinator.CreateCollection(ctx, createCollection)
if err != nil {
log.Error("error creating collection", zap.Error(err))
res.Collection = &coordinatorpb.Collection{
Expand All @@ -94,6 +94,7 @@ func (s *Server) CreateCollection(ctx context.Context, req *coordinatorpb.Create
}
res.Collection = convertCollectionToProto(collection)
res.Status = setResponseStatus(successCode)
res.Created = created
return res, nil
}

Expand Down
2 changes: 1 addition & 1 deletion go/pkg/metastore/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
//go:generate mockery --name=Catalog
type Catalog interface {
ResetState(ctx context.Context) error
CreateCollection(ctx context.Context, createCollection *model.CreateCollection, ts types.Timestamp) (*model.Collection, error)
CreateCollection(ctx context.Context, createCollection *model.CreateCollection, ts types.Timestamp) (*model.Collection, bool, error)
GetCollections(ctx context.Context, collectionID types.UniqueID, collectionName *string, tenantID string, databaseName string, limit *int32, offset *int32) ([]*model.Collection, error)
DeleteCollection(ctx context.Context, deleteCollection *model.DeleteCollection) error
UpdateCollection(ctx context.Context, updateCollection *model.UpdateCollection, ts types.Timestamp) (*model.Collection, error)
Expand Down
9 changes: 6 additions & 3 deletions go/pkg/metastore/coordinator/table_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,10 @@ func (tc *Catalog) GetAllTenants(ctx context.Context, ts types.Timestamp) ([]*mo
return result, nil
}

func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model.CreateCollection, ts types.Timestamp) (*model.Collection, error) {
func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model.CreateCollection, ts types.Timestamp) (*model.Collection, bool, error) {
var result *model.Collection

created := false
err := tc.txImpl.Transaction(ctx, func(txCtx context.Context) error {
// insert collection
databaseName := createCollection.DatabaseName
Expand Down Expand Up @@ -254,6 +255,8 @@ func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model
} else {
return common.ErrCollectionUniqueConstraintViolation
}
} else {
created = true
}

dbCollection := &dbmodel.Collection{
Expand Down Expand Up @@ -301,10 +304,10 @@ func (tc *Catalog) CreateCollection(ctx context.Context, createCollection *model
})
if err != nil {
log.Error("error creating collection", zap.Error(err))
return nil, err
return nil, false, err
}
log.Info("collection created", zap.Any("collection", result))
return result, nil
return result, created, nil
}

func (tc *Catalog) GetCollections(ctx context.Context, collectionID types.UniqueID, collectionName *string, tenantID string, databaseName string, limit *int32, offset *int32) ([]*model.Collection, error) {
Expand Down
2 changes: 1 addition & 1 deletion go/pkg/metastore/coordinator/table_catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestCatalog_CreateCollection(t *testing.T) {
}).Return(nil)

// call the CreateCollection method
_, err := catalog.CreateCollection(context.Background(), collection, ts)
_, _, err := catalog.CreateCollection(context.Background(), collection, ts)

// assert that the method returned no error
assert.NoError(t, err)
Expand Down

0 comments on commit 2d9bec5

Please sign in to comment.