From 648521deec24d736fe8f74dccbfa2184b5c9063d Mon Sep 17 00:00:00 2001 From: Ayoub Aarrasse Date: Tue, 12 Aug 2025 15:42:15 +0100 Subject: [PATCH 1/5] Fixing null time returned value --- oracle/common.go | 137 +++++++++++++++++++++++++++++++---------------- oracle/create.go | 4 +- oracle/delete.go | 2 +- oracle/update.go | 2 +- 4 files changed, 96 insertions(+), 49 deletions(-) diff --git a/oracle/common.go b/oracle/common.go index 3758324..5d37a23 100644 --- a/oracle/common.go +++ b/oracle/common.go @@ -100,18 +100,30 @@ func findFieldByDBName(schema *schema.Schema, dbName string) *schema.Field { } // Create typed destination for OUT parameters -func createTypedDestination(fieldType reflect.Type) interface{} { - // Handle pointer types - if fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() +func createTypedDestination(f *schema.Field) interface{} { + if f == nil { + var s string + return &s } - // Type-safe handling for known GORM types and SQL null types - switch fieldType { - case reflect.TypeOf(gorm.DeletedAt{}): + ft := f.FieldType + for ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + + if ft == reflect.TypeOf(gorm.DeletedAt{}) { return new(sql.NullTime) - case reflect.TypeOf(time.Time{}): + } + if ft == reflect.TypeOf(time.Time{}) { + if !f.NotNull { // nullable column => keep NULLs + return new(sql.NullTime) + } return new(time.Time) + } + + switch ft { + case reflect.TypeOf(sql.NullTime{}): + return new(sql.NullTime) case reflect.TypeOf(sql.NullInt64{}): return new(sql.NullInt64) case reflect.TypeOf(sql.NullInt32{}): @@ -120,33 +132,28 @@ func createTypedDestination(fieldType reflect.Type) interface{} { return new(sql.NullFloat64) case reflect.TypeOf(sql.NullBool{}): return new(sql.NullBool) - case reflect.TypeOf(sql.NullTime{}): - return new(sql.NullTime) } - // Handle primitive types by Kind - switch fieldType.Kind() { + switch ft.Kind() { + case reflect.String: + return new(string) + + case reflect.Bool: + return new(int64) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return new(int64) // Oracle returns NUMBER as int64 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return new(int64) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: return new(uint64) + case reflect.Float32, reflect.Float64: - return new(float64) // Oracle returns FLOAT as float64 - case reflect.Bool: - return new(int64) // Oracle NUMBER(1) for boolean - case reflect.String: - return new(string) - case reflect.Struct: - // For time.Time specifically - if fieldType == reflect.TypeOf(time.Time{}) { - return new(time.Time) - } - // For other structs, use string as safe fallback - return new(string) - default: - // For unknown types, use string as safe fallback - return new(string) + return new(float64) } + + // Fallback + var s string + return &s } // Convert values for Oracle-specific types @@ -182,7 +189,7 @@ func convertValue(val interface{}) interface{} { // Convert Oracle values back to Go types func convertFromOracleToField(value interface{}, field *schema.Field) interface{} { - if value == nil { + if value == nil || field == nil { return nil } @@ -194,7 +201,6 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{ var converted interface{} - // Handle special types first using type-safe comparisons switch targetType { case reflect.TypeOf(gorm.DeletedAt{}): if nullTime, ok := value.(sql.NullTime); ok { @@ -203,7 +209,31 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{ converted = gorm.DeletedAt{} } case reflect.TypeOf(time.Time{}): - converted = value + switch vv := value.(type) { + case time.Time: + converted = vv + case sql.NullTime: + if vv.Valid { + converted = vv.Time + } else { + // DB returned NULL + if isPtr { + return nil // -> *time.Time(nil) + } + // non-pointer time.Time: represent NULL as zero time + return time.Time{} + } + default: + converted = value + } + + case reflect.TypeOf(sql.NullTime{}): + if nullTime, ok := value.(sql.NullTime); ok { + converted = nullTime + } else { + converted = sql.NullTime{} + } + case reflect.TypeOf(sql.NullInt64{}): if nullInt, ok := value.(sql.NullInt64); ok { converted = nullInt @@ -228,25 +258,19 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{ } else { converted = sql.NullBool{} } - case reflect.TypeOf(sql.NullTime{}): - if nullTime, ok := value.(sql.NullTime); ok { - converted = nullTime - } else { - converted = sql.NullTime{} - } default: - // Handle primitive types + // primitives and everything else converted = convertPrimitiveType(value, targetType) } - // Handle pointer types - if isPtr && converted != nil { - if isZeroValueForPointer(converted, targetType) { + // Pointer targets: nil for "zero-ish", else allocate and set. + if isPtr { + if isZeroFor(targetType, converted) { return nil } ptr := reflect.New(targetType) ptr.Elem().Set(reflect.ValueOf(converted)) - converted = ptr.Interface() + return ptr.Interface() } return converted @@ -426,8 +450,6 @@ func isNullValue(value interface{}) bool { // Check for different NULL types switch v := value.(type) { - case sql.NullString: - return !v.Valid case sql.NullInt64: return !v.Valid case sql.NullInt32: @@ -442,3 +464,28 @@ func isNullValue(value interface{}) bool { return false } } + +func isZeroFor(t reflect.Type, v interface{}) bool { + if v == nil { + return true + } + rv := reflect.ValueOf(v) + if !rv.IsValid() { + return true + } + // exact type match? + if rv.Type() == t { + // special-case time.Time + if t == reflect.TypeOf(time.Time{}) { + return rv.Interface().(time.Time).IsZero() + } + // generic zero check + z := reflect.Zero(t) + return reflect.DeepEqual(rv.Interface(), z.Interface()) + } + // If types differ (e.g., sql.NullTime), treat invalid as zero + if nt, ok := v.(sql.NullTime); ok { + return !nt.Valid + } + return false +} diff --git a/oracle/create.go b/oracle/create.go index b8999c1..eac8670 100644 --- a/oracle/create.go +++ b/oracle/create.go @@ -474,7 +474,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ { for _, column := range allColumns { if field := findFieldByDBName(schema, column); field != nil { - stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)}) + stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)}) plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1)) writeQuotedIdentifier(&plsqlBuilder, column) plsqlBuilder.WriteString("; END IF;\n") @@ -586,7 +586,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) { quotedColumn := columnBuilder.String() if field := findFieldByDBName(schema, column); field != nil { - stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)}) + stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)}) plsqlBuilder.WriteString(fmt.Sprintf(" IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n", rowIdx, outParamIndex+1, rowIdx+1, quotedColumn)) outParamIndex++ diff --git a/oracle/delete.go b/oracle/delete.go index aaa29aa..d3ab888 100644 --- a/oracle/delete.go +++ b/oracle/delete.go @@ -254,7 +254,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) { for _, column := range allColumns { field := findFieldByDBName(schema, column) if field != nil { - dest := createTypedDestination(field.FieldType) + dest := createTypedDestination(field) stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest}) plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx)) diff --git a/oracle/update.go b/oracle/update.go index 5b498b1..1389fdc 100644 --- a/oracle/update.go +++ b/oracle/update.go @@ -522,7 +522,7 @@ func buildUpdatePLSQL(db *gorm.DB) { for _, column := range allColumns { field := findFieldByDBName(schema, column) if field != nil { - dest := createTypedDestination(field.FieldType) + dest := createTypedDestination(field) stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest}) } } From 6ba9abe6aedee0aa22e1ce524b9ca59f9e29d94d Mon Sep 17 00:00:00 2001 From: Ayoub Aarrasse Date: Tue, 12 Aug 2025 15:51:01 +0100 Subject: [PATCH 2/5] missing changes --- oracle/clause_builder.go | 2 +- tests/joins_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/oracle/clause_builder.go b/oracle/clause_builder.go index 23cda2b..2036c03 100644 --- a/oracle/clause_builder.go +++ b/oracle/clause_builder.go @@ -161,7 +161,7 @@ func ReturningClauseBuilder(c clause.Clause, builder clause.Builder) { var dest interface{} if stmt.Schema != nil { if field := findFieldByDBName(stmt.Schema, column.Name); field != nil { - dest = createTypedDestination(field.FieldType) + dest = createTypedDestination(field) } else { dest = new(string) // Default to string for unknown fields } diff --git a/tests/joins_test.go b/tests/joins_test.go index 22d28e1..e62f44e 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -410,7 +410,7 @@ func TestNestedJoins(t *testing.T) { Joins("Manager.NamedPet.Toy"). Joins("NamedPet"). Joins("NamedPet.Toy"). - Find(&users2, "users.id IN ?", userIDs).Error; err != nil { + Find(&users2, "\"users\".\"id\" IN ?", userIDs).Error; err != nil { t.Fatalf("Failed to load with joins, got error: %v", err) } else if len(users2) != len(users) { t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) From 06a49b423b23ef1f32011d53b161f41b6f0cef11 Mon Sep 17 00:00:00 2001 From: Ayoub Aarrasse Date: Tue, 12 Aug 2025 23:48:29 +0100 Subject: [PATCH 3/5] fixed tests --- oracle/common.go | 2 ++ tests/passed-tests.txt | 10 +++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/oracle/common.go b/oracle/common.go index 5d37a23..aa276a0 100644 --- a/oracle/common.go +++ b/oracle/common.go @@ -450,6 +450,8 @@ func isNullValue(value interface{}) bool { // Check for different NULL types switch v := value.(type) { + case sql.NullString: + return !v.Valid case sql.NullInt64: return !v.Valid case sql.NullInt32: diff --git a/tests/passed-tests.txt b/tests/passed-tests.txt index 047c79c..09f98f6 100644 --- a/tests/passed-tests.txt +++ b/tests/passed-tests.txt @@ -21,9 +21,9 @@ TestMany2ManyOmitAssociations TestMany2ManyAssociationForSlice #TestSingleTableMany2ManyAssociation #TestSingleTableMany2ManyAssociationForSlice -#TestDuplicateMany2ManyAssociation +TestDuplicateMany2ManyAssociation TestConcurrentMany2ManyAssociation -#TestMany2ManyDuplicateBelongsToAssociation +TestMany2ManyDuplicateBelongsToAssociation TestInvalidAssociation TestAssociationNotNullClear #TestForeignKeyConstraints @@ -112,7 +112,7 @@ TestGenericsDelete TestGenericsFindInBatches TestGenericsScopes #TestGenericsJoins -#TestGenericsNestedJoins +TestGenericsNestedJoins #TestGenericsPreloads #TestGenericsNestedPreloads TestGenericsDistinct @@ -146,7 +146,7 @@ TestJoinCount #TestInnerJoins TestJoinWithSameColumnName #TestJoinArgsWithDB -#TestNestedJoins +TestNestedJoins TestJoinsPreload_Issue7013 TestJoinsPreload_Issue7013_RelationEmpty TestJoinsPreload_Issue7013_NoEntries @@ -267,7 +267,7 @@ TestSubQueryWithHaving TestQueryWithTableAndConditions TestQueryWithTableAndConditionsAndAllFields #TestQueryScannerWithSingleColumn -#TestQueryResetNullValue +TestQueryResetNullValue TestQueryError TestQueryScanToArray TestRownum From a7f410d4fb7f03854f797a9e6fdf64d7fdb0aad5 Mon Sep 17 00:00:00 2001 From: Ayoub Aarrasse Date: Wed, 13 Aug 2025 15:23:06 +0100 Subject: [PATCH 4/5] Removing unnecessary function --- oracle/common.go | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/oracle/common.go b/oracle/common.go index aa276a0..347c016 100644 --- a/oracle/common.go +++ b/oracle/common.go @@ -276,24 +276,6 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{ return converted } -// Helper function to check if a value should be treated as nil for pointer fields -func isZeroValueForPointer(value interface{}, targetType reflect.Type) bool { - v := reflect.ValueOf(value) - if !v.IsValid() || v.Kind() != targetType.Kind() { - return false - } - - switch targetType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return v.Uint() == 0 - case reflect.Float32, reflect.Float64: - return v.Float() == 0.0 - } - return false -} - // Helper function to handle primitive type conversions func convertPrimitiveType(value interface{}, targetType reflect.Type) interface{} { switch targetType.Kind() { From 488bbae98f460a03494c21e42b2c66c4d699de88 Mon Sep 17 00:00:00 2001 From: Ayoub Aarrasse Date: Fri, 15 Aug 2025 10:24:13 +0100 Subject: [PATCH 5/5] removing "skip" from passing tests --- tests/associations_many2many_test.go | 2 -- tests/generics_test.go | 1 - tests/joins_test.go | 2 -- tests/query_test.go | 1 - 4 files changed, 6 deletions(-) diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 86dcd5e..3bfd1e2 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -372,7 +372,6 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { } func TestDuplicateMany2ManyAssociation(t *testing.T) { - t.Skip() user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ {Code: "TestDuplicateMany2ManyAssociation-language-1"}, {Code: "TestDuplicateMany2ManyAssociation-language-2"}, @@ -436,7 +435,6 @@ func TestConcurrentMany2ManyAssociation(t *testing.T) { } func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) { - t.Skip() user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{ {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{ ID: 1, diff --git a/tests/generics_test.go b/tests/generics_test.go index bd53acb..3900eff 100644 --- a/tests/generics_test.go +++ b/tests/generics_test.go @@ -422,7 +422,6 @@ func TestGenericsJoins(t *testing.T) { } func TestGenericsNestedJoins(t *testing.T) { - t.Skip() users := []User{ { Name: "generics-nested-joins-1", diff --git a/tests/joins_test.go b/tests/joins_test.go index 8662133..a48b8f9 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -383,8 +383,6 @@ func TestJoinArgsWithDB(t *testing.T) { } func TestNestedJoins(t *testing.T) { - t.Skip() - users := []User{ { Name: "nested-joins-1", diff --git a/tests/query_test.go b/tests/query_test.go index 7119e87..31d5fe8 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1393,7 +1393,6 @@ func TestQueryScannerWithSingleColumn(t *testing.T) { } func TestQueryResetNullValue(t *testing.T) { - t.Skip() type QueryResetItem struct { ID string `gorm:"type:varchar(5)"` Name string