Skip to content

Commit 7c106c4

Browse files
authored
[BUG]: Allow repairing partially created attached functions (#5981)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - The attach_function codepath did not allow repairing a nonready function that was a result of a previously failed attach. This diff fixes that. - This diff also adds code to the FinishCreateAttachedFunction method to ensure the invariant that there can only be one function attached to a collection. - New functionality - ... Many changes and tests here were ported from a change @rescrv had made. ## Test plan _How are these changes tested?_ test_task_api.py has been edited for this change. - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the_ [_docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 54d8c8c commit 7c106c4

File tree

6 files changed

+202
-27
lines changed

6 files changed

+202
-27
lines changed

chromadb/test/distributed/test_task_api.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,3 +384,127 @@ def test_delete_orphaned_output_collection(basic_http_client: System) -> None:
384384
with pytest.raises(NotFoundError):
385385
# Try to use the function - it should fail since it's detached
386386
client.get_collection("output_collection")
387+
388+
389+
def test_partial_attach_function_repair(
390+
basic_http_client: System,
391+
) -> None:
392+
"""Test creating and removing a function with the record_counter operator"""
393+
client = ClientCreator.from_system(basic_http_client)
394+
client.reset()
395+
396+
# Create a collection
397+
collection = client.get_or_create_collection(
398+
name="my_document",
399+
)
400+
401+
# Create a task that counts records in the collection
402+
attached_fn = collection.attach_function(
403+
name="count_my_docs",
404+
function=RECORD_COUNTER_FUNCTION,
405+
output_collection="my_documents_counts",
406+
params=None,
407+
)
408+
409+
# Verify task creation succeeded
410+
assert attached_fn is not None
411+
412+
collection2 = client.get_or_create_collection(
413+
name="my_document2",
414+
)
415+
416+
# Create a task that counts records in the collection
417+
# This should fail
418+
with pytest.raises(
419+
ChromaError, match=r"Output collection \[my_documents_counts\] already exists"
420+
):
421+
attached_fn = collection2.attach_function(
422+
name="count_my_docs",
423+
function=RECORD_COUNTER_FUNCTION,
424+
output_collection="my_documents_counts",
425+
params=None,
426+
)
427+
428+
# Detach the function
429+
assert (
430+
collection.detach_function(attached_fn.name, delete_output_collection=True)
431+
is True
432+
)
433+
434+
# Create a task that counts records in the collection
435+
attached_fn = collection2.attach_function(
436+
name="count_my_docs",
437+
function=RECORD_COUNTER_FUNCTION,
438+
output_collection="my_documents_counts",
439+
params=None,
440+
)
441+
assert attached_fn is not None
442+
443+
444+
def test_count_function_attach_and_detach_attach_attach(
445+
basic_http_client: System,
446+
) -> None:
447+
"""Test creating and removing a function with the record_counter operator"""
448+
client = ClientCreator.from_system(basic_http_client)
449+
client.reset()
450+
451+
# Create a collection
452+
collection = client.get_or_create_collection(
453+
name="my_document",
454+
metadata={"description": "Sample documents for task processing"},
455+
)
456+
457+
# Create a task that counts records in the collection
458+
attached_fn = collection.attach_function(
459+
name="count_my_docs",
460+
function=RECORD_COUNTER_FUNCTION,
461+
output_collection="my_documents_counts",
462+
params=None,
463+
)
464+
465+
# Verify task creation succeeded
466+
assert attached_fn is not None
467+
initial_version = get_collection_version(client, collection.name)
468+
469+
# Add documents
470+
collection.add(
471+
ids=["doc_{}".format(i) for i in range(0, 300)],
472+
documents=["test document"] * 300,
473+
)
474+
475+
# Verify documents were added
476+
assert collection.count() == 300
477+
478+
wait_for_version_increase(client, collection.name, initial_version)
479+
# Give some time to invalidate the frontend query cache
480+
sleep(60)
481+
482+
result = client.get_collection("my_documents_counts").get("function_output")
483+
assert result["metadatas"] is not None
484+
assert result["metadatas"][0]["total_count"] == 300
485+
486+
# Remove the task
487+
success = collection.detach_function(
488+
attached_fn.name, delete_output_collection=True
489+
)
490+
491+
# Verify task removal succeeded
492+
assert success is True
493+
494+
# Create a task that counts records in the collection
495+
attached_fn = collection.attach_function(
496+
name="count_my_docs",
497+
function=RECORD_COUNTER_FUNCTION,
498+
output_collection="my_documents_counts",
499+
params=None,
500+
)
501+
assert attached_fn is not None
502+
503+
# Create a task that counts records in the collection
504+
attached_fn = collection.attach_function(
505+
name="count_my_docs",
506+
function=RECORD_COUNTER_FUNCTION,
507+
output_collection="my_documents_counts",
508+
params=None,
509+
)
510+
assert attached_fn is not None

go/pkg/sysdb/coordinator/create_task_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (suite *AttachFunctionTestSuite) setupAttachFunctionMocks(ctx context.Conte
6969
// Phase 1: Create attached function in transaction
7070
// Check if any attached function exists for this collection
7171
suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once()
72-
suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID).
72+
suite.mockAttachedFunctionDb.On("GetAnyByCollectionID", inputCollectionID).
7373
Return([]*dbmodel.AttachedFunction{}, nil).Once()
7474

7575
suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once()
@@ -165,7 +165,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation() {
165165
// Setup mocks that will be called within the transaction (using mock.Anything for context)
166166
// Check if any attached function exists for this collection
167167
suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once()
168-
suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID).
168+
suite.mockAttachedFunctionDb.On("GetAnyByCollectionID", inputCollectionID).
169169
Return([]*dbmodel.AttachedFunction{}, nil).Once()
170170

171171
// Look up database
@@ -283,7 +283,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Alrea
283283

284284
// Inside transaction: check for existing attached functions
285285
suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once()
286-
suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID).
286+
suite.mockAttachedFunctionDb.On("GetAnyByCollectionID", inputCollectionID).
287287
Return([]*dbmodel.AttachedFunction{existingAttachedFunction}, nil).Once()
288288

289289
// Note: validateAttachedFunctionMatchesRequest uses dbmodel.GetFunctionNameByID (static lookup),
@@ -352,7 +352,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() {
352352

353353
// Phase 1: Create attached function in transaction
354354
suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once()
355-
suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID).
355+
suite.mockAttachedFunctionDb.On("GetAnyByCollectionID", inputCollectionID).
356356
Return([]*dbmodel.AttachedFunction{}, nil).Once()
357357

358358
suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once()
@@ -408,7 +408,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() {
408408

409409
// Inside transaction: check for existing attached functions
410410
suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once()
411-
suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID).
411+
suite.mockAttachedFunctionDb.On("GetAnyByCollectionID", inputCollectionID).
412412
Return([]*dbmodel.AttachedFunction{incompleteAttachedFunction}, nil).Once()
413413

414414
// Note: validateAttachedFunctionMatchesRequest uses dbmodel.GetFunctionNameByID (static lookup),
@@ -484,6 +484,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Param
484484
OutputCollectionName: outputCollectionName,
485485
FunctionID: existingOperatorID,
486486
MinRecordsForInvocation: int64(MinRecordsForInvocation),
487+
IsReady: true,
487488
CreatedAt: now,
488489
UpdatedAt: now,
489490
}
@@ -497,7 +498,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Param
497498

498499
// Inside transaction: check for existing attached functions
499500
suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once()
500-
suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID).
501+
suite.mockAttachedFunctionDb.On("GetAnyByCollectionID", inputCollectionID).
501502
Return([]*dbmodel.AttachedFunction{existingAttachedFunction}, nil).Once()
502503

503504
// Note: validateAttachedFunctionMatchesRequest uses dbmodel.GetFunctionNameByID (static lookup)

go/pkg/sysdb/coordinator/task.go

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -99,43 +99,36 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att
9999
err := s.catalog.txImpl.Transaction(ctx, func(txCtx context.Context) error {
100100
// Check if there's any active (ready, non-deleted) attached function for this collection
101101
// We only allow one active attached function per collection
102-
existingAttachedFunctions, err := s.catalog.metaDomain.AttachedFunctionDb(txCtx).GetByCollectionID(req.InputCollectionId)
102+
existingAttachedFunctions, err := s.catalog.metaDomain.AttachedFunctionDb(txCtx).GetAnyByCollectionID(req.InputCollectionId)
103103
if err != nil {
104104
log.Error("AttachFunction: failed to check for existing attached function", zap.Error(err))
105105
return err
106106
}
107-
if len(existingAttachedFunctions) > 0 {
108-
if len(existingAttachedFunctions) > 1 {
109-
log.Error("AttachFunction: collection has multiple attached functions")
110-
return status.Errorf(codes.Internal, "data inconsistency: collection has %d attached functions, expected at most 1", len(existingAttachedFunctions))
111-
}
112-
113-
existingAttachedFunction := existingAttachedFunctions[0]
114-
115-
// Same name - validate it matches our request (idempotency)
116-
log.Info("AttachFunction: attached function exists with same name, validating parameters",
117-
zap.String("attached_function_id", existingAttachedFunction.ID.String()))
118107

119-
matches, err := s.validateAttachedFunctionMatchesRequest(txCtx, existingAttachedFunction, req)
108+
for _, attachedFunction := range existingAttachedFunctions {
109+
matches, err := s.validateAttachedFunctionMatchesRequest(txCtx, attachedFunction, req)
120110
if err != nil {
121111
return err
122112
}
123-
if !matches {
124-
functionName, err := dbmodel.GetFunctionNameByID(existingAttachedFunction.FunctionID)
113+
if matches {
114+
// If the attached function matches the request, use it
115+
attachedFunctionID = attachedFunction.ID
116+
return nil
117+
}
118+
119+
if attachedFunction.IsReady {
120+
log.Error("AttachFunction: collection already has an attached function", zap.String("name", attachedFunction.Name))
121+
functionName, err := dbmodel.GetFunctionNameByID(attachedFunction.FunctionID)
125122
if err != nil {
126123
log.Error("AttachFunction: unknown function ID", zap.Error(err))
127124
return err
128125
}
129126
return status.Errorf(codes.AlreadyExists,
130127
"collection already has an attached function: name=%s, function=%s, output_collection=%s",
131-
existingAttachedFunction.Name,
128+
attachedFunction.Name,
132129
functionName,
133-
existingAttachedFunction.OutputCollectionName)
130+
attachedFunction.OutputCollectionName)
134131
}
135-
136-
// Validation passed, reuse the existing attached function ID (idempotent)
137-
attachedFunctionID = existingAttachedFunction.ID
138-
return nil
139132
}
140133

141134
// Look up database_id
@@ -626,6 +619,17 @@ func (s *Coordinator) FinishCreateAttachedFunction(ctx context.Context, req *coo
626619
return err
627620
}
628621

622+
// 7. Validate that there is only one ready attached function for this collection
623+
existingAttachedFunctions, err := s.catalog.metaDomain.AttachedFunctionDb(txCtx).GetByCollectionID(attachedFunction.InputCollectionID)
624+
if err != nil {
625+
log.Error("FinishCreateAttachedFunction: failed to get attached functions", zap.Error(err))
626+
return err
627+
}
628+
if len(existingAttachedFunctions) > 1 {
629+
log.Error("FinishCreateAttachedFunction: multiple attached functions found for collection", zap.String("collection_id", attachedFunction.InputCollectionID))
630+
return common.ErrAttachedFunctionAlreadyExists
631+
}
632+
629633
log.Info("FinishCreateAttachedFunction: successfully created output collection and set is_ready=true",
630634
zap.String("attached_function_id", attachedFunctionID.String()),
631635
zap.String("output_collection_id", collectionID.String()))

go/pkg/sysdb/metastore/db/dao/task.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,21 @@ func (s *attachedFunctionDb) GetByCollectionID(inputCollectionID string) ([]*dbm
153153
return attachedFunctions, nil
154154
}
155155

156+
func (s *attachedFunctionDb) GetAnyByCollectionID(inputCollectionID string) ([]*dbmodel.AttachedFunction, error) {
157+
var attachedFunctions []*dbmodel.AttachedFunction
158+
err := s.db.
159+
Where("input_collection_id = ?", inputCollectionID).
160+
Where("is_deleted = ?", false).
161+
Find(&attachedFunctions).Error
162+
163+
if err != nil {
164+
log.Error("GetAnyByCollectionID failed", zap.Error(err), zap.String("input_collection_id", inputCollectionID))
165+
return nil, err
166+
}
167+
168+
return attachedFunctions, nil
169+
}
170+
156171
func (s *attachedFunctionDb) SoftDelete(inputCollectionID string, name string) error {
157172
// Update name and is_deleted in a single query
158173
// Format: _deleted_<original_name>_<id>

go/pkg/sysdb/metastore/db/dbmodel/mocks/IAttachedFunctionDb.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

go/pkg/sysdb/metastore/db/dbmodel/task.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type IAttachedFunctionDb interface {
4242
GetByID(id uuid.UUID) (*AttachedFunction, error)
4343
GetAnyByID(id uuid.UUID) (*AttachedFunction, error) // TODO(tanujnay112): Consolidate all the getters.
4444
GetByCollectionID(inputCollectionID string) ([]*AttachedFunction, error)
45+
GetAnyByCollectionID(inputCollectionID string) ([]*AttachedFunction, error)
4546
Update(attachedFunction *AttachedFunction) error
4647
Finish(id uuid.UUID) error
4748
SoftDelete(inputCollectionID string, name string) error

0 commit comments

Comments
 (0)