Skip to content

Commit 2a7ee65

Browse files
committed
Fix Issue #116
1 parent 7e0365a commit 2a7ee65

File tree

5 files changed

+100
-39
lines changed

5 files changed

+100
-39
lines changed

oracle/common.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,6 @@ func getOracleArrayType(values []any) string {
9797
return arrayType
9898
}
9999

100-
// Helper function to get all column names for a table
101-
func getAllTableColumns(schema *schema.Schema) []string {
102-
var columns []string
103-
for _, field := range schema.Fields {
104-
if field.DBName != "" {
105-
columns = append(columns, field.DBName)
106-
}
107-
}
108-
return columns
109-
}
110-
111100
// Helper to check if a variable is an OUT parameter
112101
func isOutParam(v interface{}) bool {
113102
_, ok := v.(sql.Out)
@@ -713,3 +702,32 @@ func isZeroFor(t reflect.Type, v interface{}) bool {
713702
}
714703
return false
715704
}
705+
706+
// generic helper that filters fields based on a predicate
707+
func filterFields(s *schema.Schema, predicate func(f *schema.Field) bool) []string {
708+
var fields []string
709+
for _, f := range s.Fields {
710+
if predicate(f) {
711+
fields = append(fields, f.DBName)
712+
}
713+
}
714+
return fields
715+
}
716+
717+
func getCreatableFields(s *schema.Schema) []string {
718+
return filterFields(s, func(f *schema.Field) bool {
719+
return f.Creatable
720+
})
721+
}
722+
723+
func getUpdatableFields(s *schema.Schema) []string {
724+
return filterFields(s, func(f *schema.Field) bool {
725+
return f.Updatable
726+
})
727+
}
728+
729+
func getMergableFields(s *schema.Schema) []string {
730+
return filterFields(s, func(f *schema.Field) bool {
731+
return f.Creatable && f.Updatable
732+
})
733+
}

oracle/create.go

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
311311
sanitizeCreateValuesForBulkArrays(db.Statement, &createValues)
312312

313313
stmt := db.Statement
314-
schema := stmt.Schema
314+
sch := stmt.Schema
315315

316316
onConflict, ok := onConflictClause.Expression.(clause.OnConflict)
317317
if !ok {
@@ -322,10 +322,10 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
322322
// Determine conflict columns (use primary key if not specified)
323323
conflictColumns := onConflict.Columns
324324
if len(conflictColumns) == 0 {
325-
if schema == nil || len(schema.PrimaryFields) == 0 {
325+
if sch == nil || len(sch.PrimaryFields) == 0 {
326326
return
327327
}
328-
for _, primaryField := range schema.PrimaryFields {
328+
for _, primaryField := range sch.PrimaryFields {
329329
conflictColumns = append(conflictColumns, clause.Column{Name: primaryField.DBName})
330330
}
331331
}
@@ -340,7 +340,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
340340
var filteredConflictColumns []clause.Column
341341
for _, conflictCol := range conflictColumns {
342342
field := stmt.Schema.LookUpField(conflictCol.Name)
343-
if valuesColumnMap[strings.ToUpper(conflictCol.Name)] && fieldCanConflict(field, schema) {
343+
if valuesColumnMap[strings.ToUpper(conflictCol.Name)] && fieldCanConflict(field, sch) {
344344
filteredConflictColumns = append(filteredConflictColumns, conflictCol)
345345
}
346346
}
@@ -358,7 +358,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
358358

359359
// Start PL/SQL block
360360
plsqlBuilder.WriteString("DECLARE\n")
361-
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
361+
writeTableRecordCollectionDecl(db, &plsqlBuilder, getCreatableFields(stmt.Schema), stmt.Table)
362362
plsqlBuilder.WriteString(" l_affected_records t_records;\n")
363363

364364
// Create array types and variables for each column
@@ -457,9 +457,9 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
457457
}
458458

459459
isAutoIncrement := false
460-
if schema.PrioritizedPrimaryField != nil &&
461-
schema.PrioritizedPrimaryField.AutoIncrement &&
462-
strings.EqualFold(schema.PrioritizedPrimaryField.DBName, column.Name) {
460+
if sch.PrioritizedPrimaryField != nil &&
461+
sch.PrioritizedPrimaryField.AutoIncrement &&
462+
strings.EqualFold(sch.PrioritizedPrimaryField.DBName, column.Name) {
463463
isAutoIncrement = true
464464
} else if stmt.Schema.LookUpField(column.Name).AutoIncrement {
465465
isAutoIncrement = true
@@ -563,7 +563,8 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
563563

564564
// Add RETURNING clause with BULK COLLECT INTO
565565
plsqlBuilder.WriteString(" RETURNING ")
566-
allColumns := getAllTableColumns(schema)
566+
allColumns := getMergableFields(sch)
567+
567568
for i, column := range allColumns {
568569
if i > 0 {
569570
plsqlBuilder.WriteString(", ")
@@ -576,7 +577,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
576577
outParamIndex := len(stmt.Vars)
577578
for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ {
578579
for _, column := range allColumns {
579-
if field := findFieldByDBName(schema, column); field != nil {
580+
if field := findFieldByDBName(sch, column); field != nil {
580581
if isJSONField(field) {
581582
if isRawMessageField(field) {
582583
// Column is a BLOB, return raw bytes; no JSON_SERIALIZE
@@ -638,13 +639,13 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
638639
// Build PL/SQL block for bulk INSERT only (no conflict handling)
639640
func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values, bindMap plsqlBindVariableMap) {
640641
stmt := db.Statement
641-
schema := stmt.Schema
642+
sch := stmt.Schema
642643

643644
var plsqlBuilder strings.Builder
644645

645646
// Start PL/SQL block
646647
plsqlBuilder.WriteString("DECLARE\n")
647-
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
648+
writeTableRecordCollectionDecl(db, &plsqlBuilder, getCreatableFields(stmt.Schema), stmt.Table)
648649
plsqlBuilder.WriteString(" l_inserted_records t_records;\n")
649650

650651
// Create array types and variables for each column
@@ -694,7 +695,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values, bindMap p
694695

695696
// Add RETURNING clause with BULK COLLECT INTO
696697
plsqlBuilder.WriteString(" RETURNING ")
697-
allColumns := getAllTableColumns(schema)
698+
allColumns := getCreatableFields(sch)
698699
for i, column := range allColumns {
699700
if i > 0 {
700701
plsqlBuilder.WriteString(", ")
@@ -711,7 +712,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values, bindMap p
711712
db.QuoteTo(&columnBuilder, column)
712713
quotedColumn := columnBuilder.String()
713714

714-
if field := findFieldByDBName(schema, column); field != nil {
715+
if field := findFieldByDBName(sch, column); field != nil {
715716
if isJSONField(field) {
716717
if isRawMessageField(field) {
717718
// Column is a BLOB, return raw bytes; no JSON_SERIALIZE
@@ -959,7 +960,7 @@ func getBulkReturningValues(db *gorm.DB, rowCount int) {
959960
}
960961

961962
// Get all table columns
962-
allColumns := getAllTableColumns(db.Statement.Schema)
963+
allColumns := db.Statement.Schema.DBNames
963964

964965
// Find the actual starting index of OUT parameters
965966
actualStartIndex := -1

oracle/delete.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,15 +236,15 @@ func buildStandardDeleteSQL(db *gorm.DB) {
236236
// Build PL/SQL block for bulk DELETE with RETURNING
237237
func buildBulkDeletePLSQL(db *gorm.DB) {
238238
stmt := db.Statement
239-
schema := stmt.Schema
239+
sch := stmt.Schema
240240

241-
if schema == nil {
241+
if sch == nil {
242242
db.AddError(fmt.Errorf("schema required for bulk delete with returning"))
243243
return
244244
}
245245

246246
// Check if this is a soft delete model and we're not using Unscoped
247-
if deletedAtField := schema.LookUpField("deleted_at"); deletedAtField != nil && !stmt.Unscoped {
247+
if deletedAtField := sch.LookUpField("deleted_at"); deletedAtField != nil && !stmt.Unscoped {
248248
// For soft delete with RETURNING, let GORM handle it normally
249249
stmt.Build(stmt.BuildClauses...)
250250
return
@@ -273,7 +273,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
273273

274274
// Add RETURNING clause
275275
plsqlBuilder.WriteString("\n RETURNING ")
276-
allColumns := getAllTableColumns(schema)
276+
allColumns := sch.DBNames
277277
for i, column := range allColumns {
278278
if i > 0 {
279279
plsqlBuilder.WriteString(", ")
@@ -290,7 +290,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
290290

291291
for rowIdx := 0; rowIdx < estimatedRows; rowIdx++ {
292292
for _, column := range allColumns {
293-
if field := findFieldByDBName(schema, column); field != nil {
293+
if field := findFieldByDBName(sch, column); field != nil {
294294
if isJSONField(field) {
295295
if isRawMessageField(field) {
296296
// Column is a BLOB, return raw bytes; no JSON_SERIALIZE
@@ -527,7 +527,7 @@ func getDeleteBulkReturningValues(db *gorm.DB) {
527527
return
528528
}
529529

530-
allColumns := getAllTableColumns(db.Statement.Schema)
530+
allColumns := db.Statement.Schema.DBNames
531531

532532
// Count OUT parameters and calculate max rows
533533
outParamCount := 0

oracle/update.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -451,9 +451,9 @@ func addPrimaryKeyWhereClauseForUpdate(stmt *gorm.Statement) {
451451
// Build PL/SQL block for UPDATE with RETURNING
452452
func buildUpdatePLSQL(db *gorm.DB) {
453453
stmt := db.Statement
454-
schema := stmt.Schema
454+
sch := stmt.Schema
455455

456-
if schema == nil {
456+
if sch == nil {
457457
db.AddError(fmt.Errorf("schema required for update with returning"))
458458
return
459459
}
@@ -477,7 +477,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
477477

478478
// Start PL/SQL block
479479
plsqlBuilder.WriteString("DECLARE\n")
480-
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
480+
writeTableRecordCollectionDecl(db, &plsqlBuilder, getUpdatableFields(stmt.Schema), stmt.Table)
481481
plsqlBuilder.WriteString(" l_updated_records t_records;\n")
482482
plsqlBuilder.WriteString("BEGIN\n")
483483

@@ -524,7 +524,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
524524

525525
// Add RETURNING clause
526526
plsqlBuilder.WriteString("\n RETURNING ")
527-
allColumns := getAllTableColumns(schema)
527+
allColumns := getUpdatableFields(sch)
528528
for i, column := range allColumns {
529529
if i > 0 {
530530
plsqlBuilder.WriteString(", ")
@@ -541,7 +541,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
541541
// First, create all OUT parameters
542542
for rowIdx := 0; rowIdx < estimatedRows; rowIdx++ {
543543
for _, column := range allColumns {
544-
field := findFieldByDBName(schema, column)
544+
field := findFieldByDBName(sch, column)
545545
if field != nil {
546546
var dest interface{}
547547
if isJSONField(field) {
@@ -563,7 +563,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
563563
// Then, generate PL/SQL assignments with correct parameter indices
564564
for rowIdx := 0; rowIdx < estimatedRows; rowIdx++ {
565565
for colIdx, column := range allColumns {
566-
field := findFieldByDBName(schema, column)
566+
field := findFieldByDBName(sch, column)
567567
if field != nil {
568568
paramIndex := outParamStartIndex + (rowIdx * len(allColumns)) + colIdx + 1
569569

@@ -671,7 +671,7 @@ func getUpdateReturningValues(db *gorm.DB) {
671671
return
672672
}
673673

674-
allColumns := getAllTableColumns(db.Statement.Schema)
674+
allColumns := getUpdatableFields(db.Statement.Schema)
675675

676676
if len(allColumns) == 0 {
677677
return

tests/create_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ import (
5050

5151
"time"
5252

53+
"gorm.io/datatypes"
5354
"gorm.io/gorm"
5455
"gorm.io/gorm/clause"
5556
"gorm.io/gorm/utils/tests"
@@ -1002,3 +1003,44 @@ func TestCreateChildrenWithMixedPointers(t *testing.T) {
10021003
}
10031004

10041005
}
1006+
1007+
func TestCreateReadOnlyJson(t *testing.T) {
1008+
type record struct {
1009+
ID string `gorm:"column:child_id;primaryKey;type:varchar(36)"`
1010+
ReadOnlyField string `gorm:"->;-:migration;column:read_only_field;type:varchar(100)"`
1011+
JsonTypeField datatypes.JSON `gorm:"column:json_type_field;type:json"`
1012+
}
1013+
1014+
json := datatypes.JSON(fmt.Sprintf(`{"key":"%s"}`, strings.Repeat("x", 4000)))
1015+
records := []record{
1016+
{
1017+
ID: "1",
1018+
JsonTypeField: json,
1019+
},
1020+
{
1021+
ID: "2",
1022+
},
1023+
}
1024+
1025+
DB.Migrator().DropTable(&record{})
1026+
err := DB.AutoMigrate(&record{})
1027+
if err != nil {
1028+
t.Fatalf("errors happened with migrate: %v", err)
1029+
}
1030+
1031+
err = DB.Model(&records).Create(&records).Error
1032+
if err != nil {
1033+
t.Fatalf("errors happened when create: %v", err)
1034+
}
1035+
1036+
// verify records are created
1037+
var results []record
1038+
err = DB.Find(&results).Error
1039+
if err != nil {
1040+
t.Fatalf("errors happened when querying after create: %v", err)
1041+
}
1042+
1043+
if len(results) != 2 {
1044+
t.Fatalf("expected 1 parent, got %d", len(results))
1045+
}
1046+
}

0 commit comments

Comments
 (0)