Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gen/templates/factory/bobfactory_context.bob.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ var (
{{- end}}

{{end}}

modelsInCreationCtx = newContextual[map[string]any]("modelsInCreation")
)

// Contextual is a convienience wrapper around context.WithValue and context.Value
Expand All @@ -31,4 +33,3 @@ func (k contextual[V]) Value(ctx context.Context) (V, bool) {
v, ok := ctx.Value(k.key).(V)
return v, ok
}

58 changes: 58 additions & 0 deletions gen/templates/factory/bobfactory_main.bob_test.go.tpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
{{- $needsModels := false -}}
{{- range $table := .Tables -}}
{{- range $rel := $.Relationships.Get $table.Key -}}
{{- $bridgeRels := $.Tables.NeededBridgeRels $rel -}}
{{- if and $rel.IsToMany (ne $rel.Foreign $table.Key) (not $bridgeRels) (eq (len $rel.Sides) 1) -}}
{{- $needsModels = true -}}
{{- end -}}
{{- end -}}
{{- end -}}
{{- $.Importer.Import "context" -}}
{{- $.Importer.Import "testing" -}}
{{- if $needsModels -}}
{{- $.Importer.Import "models" (index $.OutputPackages "models") -}}
{{- end -}}

{{range $table := .Tables}}{{if not $table.Constraints.Primary}}{{continue}}{{end}}
{{ $tAlias := $.Aliases.Table $table.Key -}}
Expand Down Expand Up @@ -27,4 +39,50 @@ func TestCreate{{$tAlias.UpSingular}}(t *testing.T) {
}
}

{{range $rel := $.Relationships.Get $table.Key -}}
{{- if not .IsToMany -}}{{continue}}{{end -}}
{{- if eq .Foreign $table.Key -}}{{continue}}{{end -}}
{{- if $.Tables.NeededBridgeRels . -}}{{continue}}{{end -}}
{{- if gt (len .Sides) 1 -}}{{continue}}{{end -}}
{{- $relAlias := $tAlias.Relationship .Name -}}
func TestCreate{{$tAlias.UpSingular}}With{{$relAlias}}DoesNotDuplicateParent(t *testing.T) {
if testDB == nil {
t.Skip("skipping test, no DSN provided")
}

ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)

tx, err := testDB.Begin(ctx)
if err != nil {
t.Fatalf("Error starting transaction: %v", err)
}

defer func() {
if err := tx.Rollback(ctx); err != nil {
t.Fatalf("Error rolling back transaction: %v", err)
}
}()

before, err := models.{{$tAlias.UpPlural}}.Query().Count(ctx, tx)
if err != nil {
t.Fatalf("Error counting {{$tAlias.UpPlural}}: %v", err)
}

if _, err := New().New{{$tAlias.UpSingular}}WithContext(ctx, {{$tAlias.UpSingular}}Mods.WithNew{{$relAlias}}(2)).Create(ctx, tx); err != nil {
t.Fatalf("Error creating {{$tAlias.UpSingular}} with {{$relAlias}}: %v", err)
}

after, err := models.{{$tAlias.UpPlural}}.Query().Count(ctx, tx)
if err != nil {
t.Fatalf("Error counting {{$tAlias.UpPlural}}: %v", err)
}

if got := after - before; got != 1 {
t.Fatalf("Expected {{$tAlias.UpPlural}} to increase by 1, got %d", got)
}
}

{{end}}

{{end}}
50 changes: 41 additions & 9 deletions gen/templates/factory/table/003_create.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,41 @@ func (o *{{$tAlias.UpSingular}}Template) Create(ctx context.Context, exec bob.Ex
opt := o.BuildSetter()
ensureCreatable{{$tAlias.UpSingular}}(opt)

// Retrieve ancestor models from context to avoid duplicate parent creation.
// Parents are keyed by "parent_table:child_table:child_rel_name".
// This works regardless of NoBackReferencing since it only uses child-side metadata.
mInCreation, _ := modelsInCreationCtx.Value(ctx)

{{range $index, $rel := $.Relationships.Get $table.Key -}}
{{- if not ($table.RelIsRequired $rel)}}{{continue}}{{end -}}
{{- $ftable := $.Aliases.Table .Foreign -}}
{{- $relAlias := $tAlias.Relationship .Name -}}

var rel{{$index}} *models.{{$ftable.UpSingular}}

if o.r.{{$relAlias}} == nil {
{{$tAlias.UpSingular}}Mods.WithNew{{$relAlias}}().Apply(ctx, o)
if parentModel, found := mInCreation["{{$rel.Foreign}}:{{$table.Key}}:{{$rel.Name}}"]; found {
if pModel, ok := parentModel.(*models.{{$ftable.UpSingular}}); ok {
rel{{$index}} = pModel
}
}
}

var rel{{$index}} *models.{{$ftable.UpSingular}}
if rel{{$index}} == nil {
if o.r.{{$relAlias}} == nil {
{{$tAlias.UpSingular}}Mods.WithNew{{$relAlias}}().Apply(ctx, o)
}

if o.r.{{$relAlias}}.o.alreadyPersisted {
rel{{$index}} = o.r.{{$relAlias}}.o.Build()
} else {
rel{{$index}}, err = o.r.{{$relAlias}}.o.Create(ctx, exec)
if err != nil {
return nil, err
if o.r.{{$relAlias}}.o.alreadyPersisted {
rel{{$index}} = o.r.{{$relAlias}}.o.Build()
} else {
rel{{$index}}, err = o.r.{{$relAlias}}.o.Create(ctx, exec)
if err != nil {
return nil, err
}
}
}


{{range $rel.ValuedSides -}}
{{- if ne .TableName $table.Key}}{{continue}}{{end -}}
{{range .Mapped}}
Expand All @@ -146,6 +161,23 @@ func (o *{{$tAlias.UpSingular}}Template) Create(ctx context.Context, exec bob.Ex
return nil, err
}

// Store this model in context for child creates.
// Key format: "parent_table:child_table:child_rel_name" where child_rel_name is the FK name.
// We store an entry for every relationship pointing TO this table (where this table is the parent).
// This works regardless of NoBackReferencing since keys use child-side metadata.
newMInCreation := make(map[string]any, len(mInCreation)+1)
for k, v := range mInCreation {
newMInCreation[k] = v
}
{{range $childTable := $.Tables -}}
{{- range $childRel := $.Relationships.Get $childTable.Key -}}
{{- if $childRel.IsToMany -}}{{continue}}{{end -}}
{{- if ne $childRel.Foreign $table.Key -}}{{continue}}{{end -}}
newMInCreation["{{$table.Key}}:{{$childTable.Key}}:{{$childRel.Name}}"] = m
{{end -}}
{{end}}
ctx = modelsInCreationCtx.WithValue(ctx, newMInCreation)

{{range $index, $rel := $.Relationships.Get $table.Key -}}
{{- if not ($table.RelIsRequired $rel) -}}{{continue}}{{end -}}
{{- $ftable := $.Aliases.Table .Foreign -}}
Expand Down
73 changes: 44 additions & 29 deletions test/gen/templates/factory/mysql/bobfactory_runtime.bob_test.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -490,9 +490,10 @@ func TestPreloadCountUserVideos(t *testing.T) {
defer tx.Rollback(ctx)

// Create users with different numbers of videos
expectedCounts := []int{2, 3, 1}
for _, count := range expectedCounts {
expectedCounts := map[int64]int{}
for _, count := range []int{2, 3, 1} {
user := New().NewUserWithContext(ctx).CreateOrFail(ctx, t, tx)
expectedCounts[int64(user.ID)] = count
for j := 0; j < count; j++ {
New().NewVideoWithContext(ctx,
VideoMods.WithExistingUser(user),
Expand All @@ -508,20 +509,24 @@ func TestPreloadCountUserVideos(t *testing.T) {
t.Fatalf("Error querying users with PreloadCount: %v", err)
}

if len(users) != 3 {
t.Fatalf("Expected 3 users, got %d", len(users))
}

// Verify counts are loaded
found := 0
for i, user := range users {
if user.C.Videos == nil {
t.Fatalf("Expected Videos count to be set for user %d, got nil", i)
}
expectedCount := int64(expectedCounts[i])
if *user.C.Videos != expectedCount {
t.Fatalf("Expected Videos count for user %d to be %d, got %d", i, expectedCount, *user.C.Videos)
expectedCount, ok := expectedCounts[int64(user.ID)]
if !ok {
continue
}
found++
count := int64(expectedCount)
if *user.C.Videos != count {
t.Fatalf("Expected Videos count for user %d to be %d, got %d", i, count, *user.C.Videos)
}
}
if found != len(expectedCounts) {
t.Fatalf("Expected %d users, got %d", len(expectedCounts), found)
}
}

// TestThenLoadCountUserVideos tests using ThenLoadCount to load counts in a separate query
Expand All @@ -538,9 +543,10 @@ func TestThenLoadCountUserVideos(t *testing.T) {
defer tx.Rollback(ctx)

// Create users with different numbers of videos
expectedCounts := []int{2, 3, 1}
for _, count := range expectedCounts {
expectedCounts := map[int64]int{}
for _, count := range []int{2, 3, 1} {
user := New().NewUserWithContext(ctx).CreateOrFail(ctx, t, tx)
expectedCounts[int64(user.ID)] = count
for j := 0; j < count; j++ {
New().NewVideoWithContext(ctx,
VideoMods.WithExistingUser(user),
Expand All @@ -556,20 +562,24 @@ func TestThenLoadCountUserVideos(t *testing.T) {
t.Fatalf("Error querying users with ThenLoadCount: %v", err)
}

if len(users) != 3 {
t.Fatalf("Expected 3 users, got %d", len(users))
}

// Verify counts are loaded
found := 0
for i, user := range users {
if user.C.Videos == nil {
t.Fatalf("Expected Videos count to be set for user %d, got nil", i)
}
expectedCount := int64(expectedCounts[i])
if *user.C.Videos != expectedCount {
t.Fatalf("Expected Videos count for user %d to be %d, got %d", i, expectedCount, *user.C.Videos)
expectedCount, ok := expectedCounts[int64(user.ID)]
if !ok {
continue
}
found++
count := int64(expectedCount)
if *user.C.Videos != count {
t.Fatalf("Expected Videos count for user %d to be %d, got %d", i, count, *user.C.Videos)
}
}
if found != len(expectedCounts) {
t.Fatalf("Expected %d users, got %d", len(expectedCounts), found)
}
}

// TestPreloadCountWithFilter tests PreloadCount with filtering mods
Expand Down Expand Up @@ -1459,9 +1469,10 @@ func TestThenLoadToMany(t *testing.T) {

// Create users with videos
var users models.UserSlice
expectedCounts := []int{2, 3, 1}
for _, count := range expectedCounts {
expectedCounts := map[int64]int{}
for _, count := range []int{2, 3, 1} {
user := New().NewUserWithContext(ctx).CreateOrFail(ctx, t, tx)
expectedCounts[int64(user.ID)] = count
for j := 0; j < count; j++ {
New().NewVideoWithContext(ctx,
VideoMods.WithExistingUser(user),
Expand All @@ -1478,18 +1489,22 @@ func TestThenLoadToMany(t *testing.T) {
t.Fatalf("Error querying users with ThenLoad: %v", err)
}

if len(users) != 3 {
t.Fatalf("Expected 3 users, got %d", len(users))
}

// Verify Videos are preloaded for each user
found := 0
for i, user := range users {
if user.R.Videos == nil {
t.Fatalf("Expected Videos to be preloaded for user %d, got nil", i)
}
if len(user.R.Videos) != expectedCounts[i] {
t.Fatalf("Expected user %d to have %d videos, got %d", i, expectedCounts[i], len(user.R.Videos))
expectedCount, ok := expectedCounts[int64(user.ID)]
if !ok {
continue
}
found++
if len(user.R.Videos) != expectedCount {
t.Fatalf("Expected user %d to have %d videos, got %d", i, expectedCount, len(user.R.Videos))
}
}
if found != len(expectedCounts) {
t.Fatalf("Expected %d users, got %d", len(expectedCounts), found)
}
}
{{- end }}
Expand Down
Loading