Skip to content

Commit

Permalink
patched error from graph connect
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric Solender committed Jun 13, 2022
1 parent df20169 commit 86167b8
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 45 deletions.
2 changes: 1 addition & 1 deletion container_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func setupNeo4jContainer(ctx context.Context, version neo4jContainerVersion) (*n
return nil, err
}

req.Image = "neo4j:4.3.2-enterprise"
req.Image = "neo4j:4.4-enterprise"

req.Env["NEO4J_dbms_default__listen__address"] = "0.0.0.0"

Expand Down
78 changes: 52 additions & 26 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func traverseResultRecordValues(values []interface{}) ([]neo4j.Path, []neo4j.Rel
return paths, strictRels, isolatedNodes
}

func ptrToBool(b bool) *bool {
return &b
}

//decodes raw path response from driver
//example query `match p=(n)-[*0..5]-() return p`
func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
Expand All @@ -76,16 +80,23 @@ func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
return fmt.Errorf("response object can not be nil - %w", ErrInvalidParams)
}

rv := reflect.ValueOf(respObj)
rt := reflect.TypeOf(respObj)

primaryLabel := getPrimaryLabel(rt)
returnValue := reflect.ValueOf(respObj)
returnType := reflect.TypeOf(respObj)

if rv.Kind() != reflect.Ptr || rv.IsNil() {
// check type is valid
if returnValue.Kind() != reflect.Ptr || returnValue.IsNil() {
return fmt.Errorf("invalid resp type %T - %w", respObj, ErrInvalidParams)
}

//todo optimize with set array size
primaryLabel, err := getPrimaryLabel(returnType)
if err != nil {
return fmt.Errorf("failed to get primary label from returnType: %w", err)
}

if primaryLabel == "" {
return errors.New("label was empty for returnType")
}

var paths []neo4j.Path
var strictRels []neo4j.Relationship
var isolatedNodes []neo4j.Node
Expand All @@ -102,16 +113,18 @@ func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
var pks []int64
rels := make(map[int64]*neoEdgeConfig)
labelLookup := map[int64]string{}
returnIsSingle := returnType.Elem().Kind() != reflect.Slice
returnUsed := ptrToBool(false)

if len(paths) != 0 {
err = sortPaths(gogm, paths, nodeLookup, rels, &pks, primaryLabel, relMaps)
err = sortPaths(gogm, paths, nodeLookup, rels, &pks, returnIsSingle, returnUsed, &returnValue, primaryLabel, relMaps)
if err != nil {
return err
}
}

if len(isolatedNodes) != 0 {
err = sortIsolatedNodes(gogm, isolatedNodes, labelLookup, nodeLookup, &pks, primaryLabel, relMaps)
err = sortIsolatedNodes(gogm, isolatedNodes, labelLookup, nodeLookup, &pks, returnIsSingle, returnUsed, &returnValue, primaryLabel, relMaps)
if err != nil {
return err
}
Expand Down Expand Up @@ -229,7 +242,7 @@ func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
}

//create value
specialEdgeValue, err := convertToValue(gogm, relationConfig.Id, typeConfig, relationConfig.Obj, specialEdgeType)
specialEdgeValue, err := convertToValue(gogm, relationConfig.Id, typeConfig, relationConfig.Obj, specialEdgeType, false, returnUsed, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -302,8 +315,8 @@ func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
}

//handle if its returning a slice -- validation has been done at an earlier step
if rt.Elem().Kind() == reflect.Slice {
reflection := reflect.MakeSlice(rt.Elem(), 0, cap(pks))
if !returnIsSingle {
reflection := reflect.MakeSlice(returnType.Elem(), 0, cap(pks))

reflectionValue := reflect.New(reflection.Type())
reflectionValue.Elem().Set(reflection)
Expand All @@ -312,7 +325,7 @@ func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {

sliceValuePtr := slicePtr.Elem()

sliceType := rt.Elem().Elem()
sliceType := returnType.Elem().Elem()

for _, id := range pks {
val, ok := nodeLookup[id]
Expand All @@ -328,19 +341,19 @@ func decode(gogm *Gogm, result neo4j.Result, respObj interface{}) (err error) {
}
}

reflect.Indirect(rv).Set(sliceValuePtr)
reflect.Indirect(returnValue).Set(sliceValuePtr)

return err
} else {
//handles single -- already checked to make sure p2 is at least 1
reflect.Indirect(rv).Set(*nodeLookup[pks[0]])
// reflect.Indirect(returnValue).Set(*nodeLookup[pks[0]])

return err
}
}

// getPrimaryLabel gets the label from a reflect type
func getPrimaryLabel(rt reflect.Type) string {
func getPrimaryLabel(rt reflect.Type) (string, error) {
//assume its already a pointer
rt = rt.Elem()

Expand All @@ -349,13 +362,15 @@ func getPrimaryLabel(rt reflect.Type) string {
if rt.Kind() == reflect.Ptr {
rt = rt.Elem()
}
} else if rt.Kind() == reflect.Ptr {
return "", errors.New("must be pointer to struct of slice, can not be a pointer to another pointer")
}

return rt.Name()
return rt.Name(), nil
}

// sortIsolatedNodes process nodes that are returned individually from bolt driver
func sortIsolatedNodes(gogm *Gogm, isolatedNodes []neo4j.Node, labelLookup map[int64]string, nodeLookup map[int64]*reflect.Value, pks *[]int64, pkLabel string, relMaps map[int64]map[string]*RelationConfig) error {
func sortIsolatedNodes(gogm *Gogm, isolatedNodes []neo4j.Node, labelLookup map[int64]string, nodeLookup map[int64]*reflect.Value, pks *[]int64, pkSingle bool, passTypeUsed *bool, passValue *reflect.Value, pkLabel string, relMaps map[int64]map[string]*RelationConfig) error {
if isolatedNodes == nil {
return fmt.Errorf("isolatedNodes can not be nil, %w", ErrInternal)
}
Expand All @@ -364,7 +379,7 @@ func sortIsolatedNodes(gogm *Gogm, isolatedNodes []neo4j.Node, labelLookup map[i
//check if node has already been found by another process
if _, ok := nodeLookup[node.Id]; !ok {
//if it hasn't, map it
val, err := convertNodeToValue(gogm, node)
val, err := convertNodeToValue(gogm, node, pkSingle, passTypeUsed, passValue)
if err != nil {
return err
}
Expand Down Expand Up @@ -421,7 +436,7 @@ func sortStrictRels(strictRels []neo4j.Relationship, labelLookup map[int64]strin
}

// sortPaths sorts nodes and relationships from bolt driver that dont specify the direction explicitly, instead uses the bolt spec to determine direction
func sortPaths(gogm *Gogm, paths []neo4j.Path, nodeLookup map[int64]*reflect.Value, rels map[int64]*neoEdgeConfig, pks *[]int64, pkLabel string, relMaps map[int64]map[string]*RelationConfig) error {
func sortPaths(gogm *Gogm, paths []neo4j.Path, nodeLookup map[int64]*reflect.Value, rels map[int64]*neoEdgeConfig, pks *[]int64, pkSingle bool, passTypeUsed *bool, passValue *reflect.Value, pkLabel string, relMaps map[int64]map[string]*RelationConfig) error {
if paths == nil {
return fmt.Errorf("paths is empty, that shouldn't have happened, %w", ErrInternal)
}
Expand All @@ -439,9 +454,9 @@ func sortPaths(gogm *Gogm, paths []neo4j.Path, nodeLookup map[int64]*reflect.Val
}
if _, ok := nodeLookup[node.Id]; !ok {
//we haven't parsed this one yet, lets do that now
val, err := convertNodeToValue(gogm, node)
val, err := convertNodeToValue(gogm, node, pkSingle, passTypeUsed, passValue)
if err != nil {
return err
return fmt.Errorf("failed to convert node to value, %w", err)
}

nodeLookup[node.Id] = val
Expand Down Expand Up @@ -508,10 +523,10 @@ var sliceOfEmptyInterface []interface{}
var emptyInterfaceType = reflect.TypeOf(sliceOfEmptyInterface).Elem()

// convertToValue converts properties map from neo4j to golang reflect value
func convertToValue(gogm *Gogm, graphId int64, conf structDecoratorConfig, props map[string]interface{}, rtype reflect.Type) (valss *reflect.Value, err error) {
func convertToValue(gogm *Gogm, graphId int64, conf structDecoratorConfig, props map[string]interface{}, rtype reflect.Type, pkSingle bool, passValueUsed *bool, passValue *reflect.Value) (valPtr *reflect.Value, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("%v", r)
err = fmt.Errorf("recovered converToValue: %v", r)
}
}()

Expand All @@ -524,7 +539,6 @@ func convertToValue(gogm *Gogm, graphId int64, conf structDecoratorConfig, props
isPtr = true
rtype = rtype.Elem()
}

val := reflect.New(rtype)

if graphId >= 0 {
Expand Down Expand Up @@ -646,6 +660,18 @@ func convertToValue(gogm *Gogm, graphId int64, conf structDecoratorConfig, props
}
}

// handle pk case
if pkSingle && !*passValueUsed {
// pass value will always be a pointer
if isPtr {
reflect.Indirect(*passValue).Set(val)
} else {
reflect.Indirect(*passValue).Set(val.Elem())
}
*passValueUsed = true
return passValue, nil
}

//if its not a pointer, dereference it
if !isPtr {
retV := reflect.Indirect(val)
Expand All @@ -656,7 +682,7 @@ func convertToValue(gogm *Gogm, graphId int64, conf structDecoratorConfig, props
}

// convertNodeToValue converts raw bolt node to reflect value
func convertNodeToValue(gogm *Gogm, boltNode neo4j.Node) (*reflect.Value, error) {
func convertNodeToValue(gogm *Gogm, boltNode neo4j.Node, pkSingle bool, passTypeUsed *bool, passValue *reflect.Value) (*reflect.Value, error) {

if boltNode.Labels == nil || len(boltNode.Labels) == 0 {
return nil, errors.New("boltNode has no labels")
Expand All @@ -674,5 +700,5 @@ func convertNodeToValue(gogm *Gogm, boltNode neo4j.Node) (*reflect.Value, error)
return nil, errors.New("unable to cast to struct decorator config")
}

return convertToValue(gogm, boltNode.Id, typeConfig, boltNode.Props, typeConfig.Type)
return convertToValue(gogm, boltNode.Id, typeConfig, boltNode.Props, typeConfig.Type, pkSingle, passTypeUsed, passValue)
}
4 changes: 2 additions & 2 deletions decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestConvertNodeToValue(t *testing.T) {
Labels: []string{"TestStruct"},
}

val, err := convertNodeToValue(gogm, bn)
val, err := convertNodeToValue(gogm, bn, false, ptrToBool(false), nil)
req.Nil(err)
req.NotNil(val)
req.EqualValues(TestStruct{
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestConvertNodeToValue(t *testing.T) {
Name: "test",
}
mappedTypes.Set("TestStruct", te)
val, err = convertNodeToValue(gogm, bn)
val, err = convertNodeToValue(gogm, bn, false, ptrToBool(false), nil)
req.Nil(err)
req.NotNil(val)
}
Expand Down
47 changes: 47 additions & 0 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"context"
"fmt"
"log"
"reflect"
"sync"

uuid2 "github.com/google/uuid"
Expand Down Expand Up @@ -160,6 +161,52 @@ func (integrationTest *IntegrationTestSuite) TestSecureConnection() {
integrationTest.Require().Nil(err)
}

func (integrationTest *IntegrationTestSuite) TestFirstLayerSpecialEdgeLoad() {
req := integrationTest.Require()

/*
a
/ \
EdgeC EdgeC
/ \
b b
verifying that loading from A places pointers correctly so that a change in edge is working
*/

// build base graph
_a := a{}
b1, b2 := b{}, b{}
c1, c2 := c{Start: &_a, End: &b1}, c{Start: &_a, End: &b2}
_a.MultiSpecA = []*c{&c1, &c2}
b1.SingleSpec = &c1
b2.SingleSpec = &c2
sess1, err := integrationTest.gogm.NewSessionV2(SessionConfig{AccessMode: neo4j.AccessModeWrite})
req.NoError(err)
req.NotNil(sess1)

req.NoError(sess1.SaveDepth(context.Background(), &_a, 1))

req.NoError(sess1.Close())

// now load stuff and verify that it loaded correctly
sess2, err := integrationTest.gogm.NewSessionV2(SessionConfig{AccessMode: neo4j.AccessModeWrite})
req.NoError(err)
req.NotNil(sess2)
copyA := a{}
req.NoError(sess2.LoadDepth(context.Background(), &copyA, _a.UUID, 1))
// now we need to verify that the pointers match
// check that the slice length is correct
req.Equal(2, len(copyA.MultiSpecA), "length of C edges")
copyC1, copyC2 := copyA.MultiSpecA[0], copyA.MultiSpecA[1]
// check pointers are correct
req.Equal(reflect.ValueOf(&copyA).Pointer(), reflect.ValueOf(copyC1.Start).Pointer(), "edge pointer C1 Start back to A should match A pointer")
req.Equal(reflect.ValueOf(&copyA).Pointer(), reflect.ValueOf(copyC2.Start).Pointer(), "edge pointer C2 Start back to A should match A pointer")

// now to replicate the original error, clear the edges and try to save, this should pass if the issue is fixed
copyA.MultiSpecA = []*c{}
req.NoError(sess2.SaveDepth(context.Background(), &copyA, 1))
}

func (integrationTest *IntegrationTestSuite) TestManagedTx() {
//req := integrationTest.Require()
if integrationTest.gogm.boltMajorVersion < 4 {
Expand Down
19 changes: 3 additions & 16 deletions save.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,6 @@ func generateCurRels(gogm *Gogm, parentPtr uintptr, current *reflect.Value, curr
for i := 0; i < slLen; i++ {
relVal := relField.Index(i)

fmt.Println("calling process struct many from generate cur rels")
newParentId, _, _, _, _, followVal, err := processStruct(gogm, conf, &relVal, curPtr)
if err != nil {
return err
Expand Down Expand Up @@ -472,7 +471,6 @@ func generateCurRels(gogm *Gogm, parentPtr uintptr, current *reflect.Value, curr
}
}
} else {
fmt.Println("calling processStruct many from generateCurRels")
newParentId, _, _, _, _, followVal, err := processStruct(gogm, conf, &relField, curPtr)
if err != nil {
return err
Expand Down Expand Up @@ -784,7 +782,6 @@ func parseStruct(gogm *Gogm, parentPtr uintptr, edgeLabel string, parentIsStart

for i := 0; i < slLen; i++ {
relVal := relField.Index(i)
fmt.Printf("calling process struct many from parse struct\ntype=%v\n", current.Type().String())
newParentId, newEdgeLabel, newParentIsStart, newDirection, newEdgeParams, followVal, err := processStruct(gogm, conf, &relVal, curPtr)
if err != nil {
return err
Expand All @@ -796,7 +793,6 @@ func parseStruct(gogm *Gogm, parentPtr uintptr, edgeLabel string, parentIsStart
}
}
} else {
fmt.Printf("calling process struct single from parse struct\n")
newParentId, newEdgeLabel, newParentIsStart, newDirection, newEdgeParams, followVal, err := processStruct(gogm, conf, &relField, curPtr)
if err != nil {
return err
Expand Down Expand Up @@ -853,20 +849,11 @@ func processStruct(gogm *Gogm, fieldConf decoratorConfig, relValue *reflect.Valu
params = map[string]interface{}{}
}

startPtr, endPtr, relPTR := startVal.Pointer(), endVal.Pointer(), relValue.Pointer()
fmt.Printf("startPTR=%v\nendPTR=%v\nrelPTR=%v\ncurPTR=%v\n", startPtr, endPtr, relPTR, curPtr)
fmt.Printf("startType=%v\nendType=%v\ncurType=%v\n", startVal.Type().String(), endVal.Type().String(), relValue.Type().String())
startPtr, endPtr := startVal.Pointer(), endVal.Pointer()
if startPtr == curPtr {

//follow the end
retVal := endValSlice[0].Elem()

return curPtr, edgeLabel, true, fieldConf.Direction, params, &retVal, nil
return startPtr, edgeLabel, true, fieldConf.Direction, params, &endVal, nil
} else if endPtr == curPtr {
///follow the start
retVal := startValSlice[0].Elem()

return curPtr, edgeLabel, false, fieldConf.Direction, params, &retVal, nil
return endPtr, edgeLabel, false, fieldConf.Direction, params, &startVal, nil
} else {
return 0, "", false, 0, nil, nil, errors.New("edge is invalid, doesn't point to parent vertex")
}
Expand Down

0 comments on commit 86167b8

Please sign in to comment.