Skip to content

Commit 23f4843

Browse files
committed
fix: Reuse parents when creating children
Closes: #510
1 parent e9f8098 commit 23f4843

File tree

6 files changed

+233
-97
lines changed

6 files changed

+233
-97
lines changed

gen/templates/factory/bobfactory_context.bob.go.tpl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ var (
1212
{{- end}}
1313

1414
{{end}}
15+
16+
modelsInCreationCtx = newContextual[map[string]any]("modelsInCreation")
1517
)
1618

1719
// Contextual is a convienience wrapper around context.WithValue and context.Value
@@ -31,4 +33,3 @@ func (k contextual[V]) Value(ctx context.Context) (V, bool) {
3133
v, ok := ctx.Value(k.key).(V)
3234
return v, ok
3335
}
34-

gen/templates/factory/bobfactory_main.bob_test.go.tpl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
1+
{{- $needsModels := false -}}
2+
{{- range $table := .Tables -}}
3+
{{- range $rel := $.Relationships.Get $table.Key -}}
4+
{{- $bridgeRels := $.Tables.NeededBridgeRels $rel -}}
5+
{{- if and $rel.IsToMany (ne $rel.Foreign $table.Key) (not $bridgeRels) (eq (len $rel.Sides) 1) -}}
6+
{{- $needsModels = true -}}
7+
{{- end -}}
8+
{{- end -}}
9+
{{- end -}}
110
{{- $.Importer.Import "context" -}}
211
{{- $.Importer.Import "testing" -}}
12+
{{- if $needsModels -}}
13+
{{- $.Importer.Import "models" (index $.OutputPackages "models") -}}
14+
{{- end -}}
315

416
{{range $table := .Tables}}{{if not $table.Constraints.Primary}}{{continue}}{{end}}
517
{{ $tAlias := $.Aliases.Table $table.Key -}}
@@ -27,4 +39,50 @@ func TestCreate{{$tAlias.UpSingular}}(t *testing.T) {
2739
}
2840
}
2941

42+
{{range $rel := $.Relationships.Get $table.Key -}}
43+
{{- if not .IsToMany -}}{{continue}}{{end -}}
44+
{{- if eq .Foreign $table.Key -}}{{continue}}{{end -}}
45+
{{- if $.Tables.NeededBridgeRels . -}}{{continue}}{{end -}}
46+
{{- if gt (len .Sides) 1 -}}{{continue}}{{end -}}
47+
{{- $relAlias := $tAlias.Relationship .Name -}}
48+
func TestCreate{{$tAlias.UpSingular}}With{{$relAlias}}DoesNotDuplicateParent(t *testing.T) {
49+
if testDB == nil {
50+
t.Skip("skipping test, no DSN provided")
51+
}
52+
53+
ctx, cancel := context.WithCancel(t.Context())
54+
t.Cleanup(cancel)
55+
56+
tx, err := testDB.Begin(ctx)
57+
if err != nil {
58+
t.Fatalf("Error starting transaction: %v", err)
59+
}
60+
61+
defer func() {
62+
if err := tx.Rollback(ctx); err != nil {
63+
t.Fatalf("Error rolling back transaction: %v", err)
64+
}
65+
}()
66+
67+
before, err := models.{{$tAlias.UpPlural}}.Query().Count(ctx, tx)
68+
if err != nil {
69+
t.Fatalf("Error counting {{$tAlias.UpPlural}}: %v", err)
70+
}
71+
72+
if _, err := New().New{{$tAlias.UpSingular}}WithContext(ctx, {{$tAlias.UpSingular}}Mods.WithNew{{$relAlias}}(2)).Create(ctx, tx); err != nil {
73+
t.Fatalf("Error creating {{$tAlias.UpSingular}} with {{$relAlias}}: %v", err)
74+
}
75+
76+
after, err := models.{{$tAlias.UpPlural}}.Query().Count(ctx, tx)
77+
if err != nil {
78+
t.Fatalf("Error counting {{$tAlias.UpPlural}}: %v", err)
79+
}
80+
81+
if got := after - before; got != 1 {
82+
t.Fatalf("Expected {{$tAlias.UpPlural}} to increase by 1, got %d", got)
83+
}
84+
}
85+
86+
{{end}}
87+
3088
{{end}}

gen/templates/factory/table/003_create.go.tpl

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,26 +110,41 @@ func (o *{{$tAlias.UpSingular}}Template) Create(ctx context.Context, exec bob.Ex
110110
opt := o.BuildSetter()
111111
ensureCreatable{{$tAlias.UpSingular}}(opt)
112112

113+
// Retrieve ancestor models from context to avoid duplicate parent creation.
114+
// Parents are keyed by "parent_table:child_table:child_rel_name".
115+
// This works regardless of NoBackReferencing since it only uses child-side metadata.
116+
mInCreation, _ := modelsInCreationCtx.Value(ctx)
117+
113118
{{range $index, $rel := $.Relationships.Get $table.Key -}}
114119
{{- if not ($table.RelIsRequired $rel)}}{{continue}}{{end -}}
115120
{{- $ftable := $.Aliases.Table .Foreign -}}
116121
{{- $relAlias := $tAlias.Relationship .Name -}}
122+
123+
var rel{{$index}} *models.{{$ftable.UpSingular}}
124+
117125
if o.r.{{$relAlias}} == nil {
118-
{{$tAlias.UpSingular}}Mods.WithNew{{$relAlias}}().Apply(ctx, o)
126+
if parentModel, found := mInCreation["{{$rel.Foreign}}:{{$table.Key}}:{{$rel.Name}}"]; found {
127+
if pModel, ok := parentModel.(*models.{{$ftable.UpSingular}}); ok {
128+
rel{{$index}} = pModel
129+
}
130+
}
119131
}
120132

121-
var rel{{$index}} *models.{{$ftable.UpSingular}}
133+
if rel{{$index}} == nil {
134+
if o.r.{{$relAlias}} == nil {
135+
{{$tAlias.UpSingular}}Mods.WithNew{{$relAlias}}().Apply(ctx, o)
136+
}
122137

123-
if o.r.{{$relAlias}}.o.alreadyPersisted {
124-
rel{{$index}} = o.r.{{$relAlias}}.o.Build()
125-
} else {
126-
rel{{$index}}, err = o.r.{{$relAlias}}.o.Create(ctx, exec)
127-
if err != nil {
128-
return nil, err
138+
if o.r.{{$relAlias}}.o.alreadyPersisted {
139+
rel{{$index}} = o.r.{{$relAlias}}.o.Build()
140+
} else {
141+
rel{{$index}}, err = o.r.{{$relAlias}}.o.Create(ctx, exec)
142+
if err != nil {
143+
return nil, err
144+
}
129145
}
130146
}
131147

132-
133148
{{range $rel.ValuedSides -}}
134149
{{- if ne .TableName $table.Key}}{{continue}}{{end -}}
135150
{{range .Mapped}}
@@ -146,6 +161,23 @@ func (o *{{$tAlias.UpSingular}}Template) Create(ctx context.Context, exec bob.Ex
146161
return nil, err
147162
}
148163

164+
// Store this model in context for child creates.
165+
// Key format: "parent_table:child_table:child_rel_name" where child_rel_name is the FK name.
166+
// We store an entry for every relationship pointing TO this table (where this table is the parent).
167+
// This works regardless of NoBackReferencing since keys use child-side metadata.
168+
newMInCreation := make(map[string]any, len(mInCreation)+1)
169+
for k, v := range mInCreation {
170+
newMInCreation[k] = v
171+
}
172+
{{range $childTable := $.Tables -}}
173+
{{- range $childRel := $.Relationships.Get $childTable.Key -}}
174+
{{- if $childRel.IsToMany -}}{{continue}}{{end -}}
175+
{{- if ne $childRel.Foreign $table.Key -}}{{continue}}{{end -}}
176+
newMInCreation["{{$table.Key}}:{{$childTable.Key}}:{{$childRel.Name}}"] = m
177+
{{end -}}
178+
{{end}}
179+
ctx = modelsInCreationCtx.WithValue(ctx, newMInCreation)
180+
149181
{{range $index, $rel := $.Relationships.Get $table.Key -}}
150182
{{- if not ($table.RelIsRequired $rel) -}}{{continue}}{{end -}}
151183
{{- $ftable := $.Aliases.Table .Foreign -}}

test/gen/templates/factory/mysql/bobfactory_runtime.bob_test.go.tpl

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -490,9 +490,10 @@ func TestPreloadCountUserVideos(t *testing.T) {
490490
defer tx.Rollback(ctx)
491491

492492
// Create users with different numbers of videos
493-
expectedCounts := []int{2, 3, 1}
494-
for _, count := range expectedCounts {
493+
expectedCounts := map[int64]int{}
494+
for _, count := range []int{2, 3, 1} {
495495
user := New().NewUserWithContext(ctx).CreateOrFail(ctx, t, tx)
496+
expectedCounts[int64(user.ID)] = count
496497
for j := 0; j < count; j++ {
497498
New().NewVideoWithContext(ctx,
498499
VideoMods.WithExistingUser(user),
@@ -508,20 +509,24 @@ func TestPreloadCountUserVideos(t *testing.T) {
508509
t.Fatalf("Error querying users with PreloadCount: %v", err)
509510
}
510511

511-
if len(users) != 3 {
512-
t.Fatalf("Expected 3 users, got %d", len(users))
513-
}
514-
515-
// Verify counts are loaded
512+
found := 0
516513
for i, user := range users {
517514
if user.C.Videos == nil {
518515
t.Fatalf("Expected Videos count to be set for user %d, got nil", i)
519516
}
520-
expectedCount := int64(expectedCounts[i])
521-
if *user.C.Videos != expectedCount {
522-
t.Fatalf("Expected Videos count for user %d to be %d, got %d", i, expectedCount, *user.C.Videos)
517+
expectedCount, ok := expectedCounts[int64(user.ID)]
518+
if !ok {
519+
continue
520+
}
521+
found++
522+
count := int64(expectedCount)
523+
if *user.C.Videos != count {
524+
t.Fatalf("Expected Videos count for user %d to be %d, got %d", i, count, *user.C.Videos)
523525
}
524526
}
527+
if found != len(expectedCounts) {
528+
t.Fatalf("Expected %d users, got %d", len(expectedCounts), found)
529+
}
525530
}
526531

527532
// TestThenLoadCountUserVideos tests using ThenLoadCount to load counts in a separate query
@@ -538,9 +543,10 @@ func TestThenLoadCountUserVideos(t *testing.T) {
538543
defer tx.Rollback(ctx)
539544

540545
// Create users with different numbers of videos
541-
expectedCounts := []int{2, 3, 1}
542-
for _, count := range expectedCounts {
546+
expectedCounts := map[int64]int{}
547+
for _, count := range []int{2, 3, 1} {
543548
user := New().NewUserWithContext(ctx).CreateOrFail(ctx, t, tx)
549+
expectedCounts[int64(user.ID)] = count
544550
for j := 0; j < count; j++ {
545551
New().NewVideoWithContext(ctx,
546552
VideoMods.WithExistingUser(user),
@@ -556,20 +562,24 @@ func TestThenLoadCountUserVideos(t *testing.T) {
556562
t.Fatalf("Error querying users with ThenLoadCount: %v", err)
557563
}
558564

559-
if len(users) != 3 {
560-
t.Fatalf("Expected 3 users, got %d", len(users))
561-
}
562-
563-
// Verify counts are loaded
565+
found := 0
564566
for i, user := range users {
565567
if user.C.Videos == nil {
566568
t.Fatalf("Expected Videos count to be set for user %d, got nil", i)
567569
}
568-
expectedCount := int64(expectedCounts[i])
569-
if *user.C.Videos != expectedCount {
570-
t.Fatalf("Expected Videos count for user %d to be %d, got %d", i, expectedCount, *user.C.Videos)
570+
expectedCount, ok := expectedCounts[int64(user.ID)]
571+
if !ok {
572+
continue
573+
}
574+
found++
575+
count := int64(expectedCount)
576+
if *user.C.Videos != count {
577+
t.Fatalf("Expected Videos count for user %d to be %d, got %d", i, count, *user.C.Videos)
571578
}
572579
}
580+
if found != len(expectedCounts) {
581+
t.Fatalf("Expected %d users, got %d", len(expectedCounts), found)
582+
}
573583
}
574584

575585
// TestPreloadCountWithFilter tests PreloadCount with filtering mods
@@ -1459,9 +1469,10 @@ func TestThenLoadToMany(t *testing.T) {
14591469

14601470
// Create users with videos
14611471
var users models.UserSlice
1462-
expectedCounts := []int{2, 3, 1}
1463-
for _, count := range expectedCounts {
1472+
expectedCounts := map[int64]int{}
1473+
for _, count := range []int{2, 3, 1} {
14641474
user := New().NewUserWithContext(ctx).CreateOrFail(ctx, t, tx)
1475+
expectedCounts[int64(user.ID)] = count
14651476
for j := 0; j < count; j++ {
14661477
New().NewVideoWithContext(ctx,
14671478
VideoMods.WithExistingUser(user),
@@ -1478,18 +1489,22 @@ func TestThenLoadToMany(t *testing.T) {
14781489
t.Fatalf("Error querying users with ThenLoad: %v", err)
14791490
}
14801491

1481-
if len(users) != 3 {
1482-
t.Fatalf("Expected 3 users, got %d", len(users))
1483-
}
1484-
1485-
// Verify Videos are preloaded for each user
1492+
found := 0
14861493
for i, user := range users {
14871494
if user.R.Videos == nil {
14881495
t.Fatalf("Expected Videos to be preloaded for user %d, got nil", i)
14891496
}
1490-
if len(user.R.Videos) != expectedCounts[i] {
1491-
t.Fatalf("Expected user %d to have %d videos, got %d", i, expectedCounts[i], len(user.R.Videos))
1497+
expectedCount, ok := expectedCounts[int64(user.ID)]
1498+
if !ok {
1499+
continue
14921500
}
1501+
found++
1502+
if len(user.R.Videos) != expectedCount {
1503+
t.Fatalf("Expected user %d to have %d videos, got %d", i, expectedCount, len(user.R.Videos))
1504+
}
1505+
}
1506+
if found != len(expectedCounts) {
1507+
t.Fatalf("Expected %d users, got %d", len(expectedCounts), found)
14931508
}
14941509
}
14951510
{{- end }}

0 commit comments

Comments
 (0)