Skip to content

Commit 344a991

Browse files
mutate procedures
1 parent 4365a49 commit 344a991

File tree

17 files changed

+588
-12
lines changed

17 files changed

+588
-12
lines changed

applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithmsEstimationModeBusinessFacade.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ public MemoryEstimateResult pcst(
132132
return runEstimation(configuration, graphNameOrConfiguration, memoryEstimation);
133133
}
134134

135-
MemoryEstimation pcst() {
135+
public MemoryEstimation pcst() {
136136
return new PrizeSteinerTreeMemoryEstimateDefinition().memoryEstimation();
137137
}
138138

applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithmsMutateModeBusinessFacade.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import org.neo4j.gds.paths.traverse.BfsMutateConfig;
4040
import org.neo4j.gds.paths.traverse.DfsMutateConfig;
4141
import org.neo4j.gds.paths.yens.config.ShortestPathYensMutateConfig;
42+
import org.neo4j.gds.pcst.PCSTMutateConfig;
43+
import org.neo4j.gds.pricesteiner.PrizeSteinerTreeResult;
4244
import org.neo4j.gds.spanningtree.SpanningTree;
4345
import org.neo4j.gds.spanningtree.SpanningTreeMutateConfig;
4446
import org.neo4j.gds.steiner.SteinerTreeMutateConfig;
@@ -152,6 +154,24 @@ public <RESULT> RESULT depthFirstSearch(
152154
);
153155
}
154156

157+
public <RESULT> RESULT pcst(
158+
GraphName graphName,
159+
PCSTMutateConfig configuration,
160+
ResultBuilder<PCSTMutateConfig, PrizeSteinerTreeResult, RESULT, RelationshipsWritten> resultBuilder
161+
) {
162+
var mutateStep = new PrizeCollectingSteinerTreeMutateStep(configuration);
163+
164+
return algorithmProcessingTemplateConvenience.processRegularAlgorithmInMutateMode(
165+
graphName,
166+
configuration,
167+
SteinerTree,
168+
estimationFacade::pcst,
169+
(graph, __) -> pathFindingAlgorithms.pcst(graph, configuration),
170+
mutateStep,
171+
resultBuilder
172+
);
173+
}
174+
155175
public <RESULT> RESULT randomWalk(
156176
GraphName graphName,
157177
RandomWalkMutateConfig configuration,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.applications.algorithms.pathfinding;
21+
22+
import org.neo4j.gds.Orientation;
23+
import org.neo4j.gds.RelationshipType;
24+
import org.neo4j.gds.api.Graph;
25+
import org.neo4j.gds.api.GraphStore;
26+
import org.neo4j.gds.applications.algorithms.machinery.MutateStep;
27+
import org.neo4j.gds.applications.algorithms.metadata.RelationshipsWritten;
28+
import org.neo4j.gds.core.loading.construction.GraphFactory;
29+
import org.neo4j.gds.pcst.PCSTMutateConfig;
30+
import org.neo4j.gds.pricesteiner.PrizeSteinerTreeResult;
31+
32+
import java.util.stream.LongStream;
33+
34+
class PrizeCollectingSteinerTreeMutateStep implements MutateStep<PrizeSteinerTreeResult, RelationshipsWritten> {
35+
private final PCSTMutateConfig configuration;
36+
37+
PrizeCollectingSteinerTreeMutateStep(PCSTMutateConfig configuration) {
38+
this.configuration = configuration;
39+
}
40+
41+
@Override
42+
public RelationshipsWritten execute(
43+
Graph graph,
44+
GraphStore graphStore,
45+
PrizeSteinerTreeResult treeResult
46+
) {
47+
var mutateRelationshipType = RelationshipType.of(configuration.mutateRelationshipType());
48+
49+
var relationshipsBuilder = GraphFactory.initRelationshipsBuilder()
50+
.nodes(graph)
51+
.relationshipType(mutateRelationshipType)
52+
.addPropertyConfig(GraphFactory.PropertyConfig.of(configuration.mutateProperty()))
53+
.orientation(Orientation.NATURAL)
54+
.build();
55+
56+
var parentArray = treeResult.parentArray();
57+
var costArray = treeResult.relationshipToParentCost();
58+
LongStream.range(0, graph.nodeCount())
59+
.filter(nodeId -> parentArray.get(nodeId) != PrizeSteinerTreeResult.PRUNED )
60+
.filter(nodeId -> parentArray.get(nodeId) != PrizeSteinerTreeResult.ROOT )
61+
.forEach(nodeId -> {
62+
var parentId = parentArray.get(nodeId);
63+
relationshipsBuilder.addFromInternal(parentId, nodeId, costArray.get(nodeId));
64+
65+
});
66+
67+
var relationships = relationshipsBuilder.build();
68+
69+
// the effect
70+
graphStore.addRelationshipType(relationships);
71+
72+
// the reporting
73+
return new RelationshipsWritten(treeResult.effectiveNodeCount() - 1);
74+
}
75+
}

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/ConfigurationParsersForMutateMode.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.neo4j.gds.paths.traverse.BfsMutateConfig;
5757
import org.neo4j.gds.paths.traverse.DfsMutateConfig;
5858
import org.neo4j.gds.paths.yens.config.ShortestPathYensMutateConfig;
59+
import org.neo4j.gds.pcst.PCSTMutateConfig;
5960
import org.neo4j.gds.scaleproperties.ScalePropertiesMutateConfig;
6061
import org.neo4j.gds.scc.SccMutateConfig;
6162
import org.neo4j.gds.similarity.filteredknn.FilteredKnnMutateConfig;
@@ -126,7 +127,7 @@ public Function<CypherMapWrapper, AlgoBaseConfig> lookup(Algorithm algorithm) {
126127
case NodeSimilarity -> NodeSimilarityMutateConfig::of;
127128
case Node2Vec -> Node2VecMutateConfig::of;
128129
case PageRank -> PageRankMutateConfig::of;
129-
case PCST -> null;
130+
case PCST -> PCSTMutateConfig::of;
130131
case RandomWalk -> null;
131132
case ScaleProperties -> ScalePropertiesMutateConfig::of;
132133
case SCC -> SccMutateConfig::of;

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/MutateModeAlgorithmLibrary.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ static CanonicalProcedureName algorithmToName(Algorithm algorithm) {
101101
case NodeSimilarity -> CanonicalProcedureName.parse("gds.nodeSimilarity");
102102
case Node2Vec -> CanonicalProcedureName.parse("gds.node2vec");
103103
case PageRank -> CanonicalProcedureName.parse("gds.pageRank");
104-
case PCST -> null;
104+
case PCST -> CanonicalProcedureName.parse("gds.prizeSteinerTree");
105105
case RandomWalk -> null;
106106
case ScaleProperties -> CanonicalProcedureName.parse("gds.scaleProperties");
107107
case SCC -> CanonicalProcedureName.parse("gds.scc");

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/StubbyHolder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import org.neo4j.gds.ml.pipeline.stubs.Node2VecStub;
5555
import org.neo4j.gds.ml.pipeline.stubs.NodeSimilarityStub;
5656
import org.neo4j.gds.ml.pipeline.stubs.PageRankStub;
57+
import org.neo4j.gds.ml.pipeline.stubs.PrizeCollectingSteinerTreeStub;
5758
import org.neo4j.gds.ml.pipeline.stubs.RandomWalkStub;
5859
import org.neo4j.gds.ml.pipeline.stubs.ScalePropertiesStub;
5960
import org.neo4j.gds.ml.pipeline.stubs.SccStub;
@@ -125,7 +126,7 @@ Stub get(Algorithm algorithm) {
125126
case NodeSimilarity -> new NodeSimilarityStub();
126127
case Node2Vec -> new Node2VecStub();
127128
case PageRank -> new PageRankStub();
128-
case PCST -> null;
129+
case PCST -> new PrizeCollectingSteinerTreeStub();
129130
case RandomWalk -> new RandomWalkStub();
130131
case ScaleProperties -> new ScalePropertiesStub();
131132
case SCC -> new SccStub();
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.ml.pipeline.stubs;
21+
22+
import org.neo4j.gds.pcst.PCSTMutateConfig;
23+
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
24+
import org.neo4j.gds.procedures.algorithms.pathfinding.PrizeCollectingSteinerTreeMutateResult;
25+
import org.neo4j.gds.procedures.algorithms.stubs.MutateStub;
26+
27+
public class PrizeCollectingSteinerTreeStub extends AbstractStub<PCSTMutateConfig, PrizeCollectingSteinerTreeMutateResult> {
28+
protected MutateStub<PCSTMutateConfig, PrizeCollectingSteinerTreeMutateResult> stub(AlgorithmsProcedureFacade facade) {
29+
return facade.pathFinding().prizeCollectingSteinerTreeMutateStub();
30+
}
31+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.paths.prizesteiner;
21+
22+
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
23+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
24+
import org.neo4j.gds.procedures.algorithms.pathfinding.PrizeCollectingSteinerTreeMutateResult;
25+
import org.neo4j.procedure.Context;
26+
import org.neo4j.procedure.Description;
27+
import org.neo4j.procedure.Name;
28+
import org.neo4j.procedure.Procedure;
29+
30+
import java.util.Map;
31+
import java.util.stream.Stream;
32+
33+
import static org.neo4j.gds.procedures.ProcedureConstants.MEMORY_ESTIMATION_DESCRIPTION;
34+
import static org.neo4j.procedure.Mode.READ;
35+
36+
public class PrizeCollectingSteinerTreeMutateProc {
37+
@Context
38+
public GraphDataScienceProcedures facade;
39+
40+
@Procedure(value = "gds.prizeSteinerTree.mutate", mode = READ)
41+
@Description(Constants.PRIZE_STEINER_DESCRIPTION)
42+
public Stream<PrizeCollectingSteinerTreeMutateResult> steinerTree(
43+
@Name(value = "graphName") String graphName,
44+
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
45+
) {
46+
return facade.algorithms().pathFinding().prizeCollectingSteinerTreeMutateStub().execute(graphName, configuration);
47+
}
48+
49+
@Procedure(value = "gds.prizeSteinerTree.mutate.estimate", mode = READ)
50+
@Description(MEMORY_ESTIMATION_DESCRIPTION)
51+
public Stream<MemoryEstimateResult> estimate(
52+
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
53+
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
54+
) {
55+
return facade.algorithms().pathFinding().prizeCollectingSteinerTreeMutateStub().estimate(graphNameOrConfiguration, algoConfiguration);
56+
}
57+
58+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.paths.prizesteiner;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.BaseProcTest;
25+
import org.neo4j.gds.GdsCypher;
26+
import org.neo4j.gds.Orientation;
27+
import org.neo4j.gds.RelationshipType;
28+
import org.neo4j.gds.catalog.GraphProjectProc;
29+
import org.neo4j.gds.core.Username;
30+
import org.neo4j.gds.core.loading.GraphStoreCatalog;
31+
import org.neo4j.gds.extension.IdFunction;
32+
import org.neo4j.gds.extension.Inject;
33+
import org.neo4j.gds.extension.Neo4jGraph;
34+
35+
import java.util.Optional;
36+
import java.util.concurrent.atomic.LongAdder;
37+
38+
import static org.assertj.core.api.Assertions.assertThat;
39+
40+
class PrizeCollectingSteinerTreeMutateProcTest extends BaseProcTest {
41+
42+
@Neo4jGraph
43+
static final String DB_CYPHER =
44+
"CREATE(a:Node{p:100.0}) " +
45+
"CREATE(b:Node{p:100.0}) " +
46+
"CREATE(c:Node{p:169.0}) " +
47+
"CREATE (a)-[:TYPE {cost:60.0}]->(b) ";
48+
49+
@Inject
50+
private IdFunction idFunction;
51+
52+
@BeforeEach
53+
void setup() throws Exception {
54+
55+
registerProcedures(PrizeCollectingSteinerTreeMutateProc.class, GraphProjectProc.class);
56+
var createQuery = GdsCypher.call(DEFAULT_GRAPH_NAME)
57+
.graphProject()
58+
.withAnyLabel()
59+
.withNodeProperty("p")
60+
.withRelationshipType("TYPE", Orientation.UNDIRECTED)
61+
.withRelationshipProperty("cost")
62+
.yields();
63+
runQuery(createQuery);
64+
}
65+
66+
@Test
67+
void testMutate() {
68+
String query = GdsCypher.call(DEFAULT_GRAPH_NAME)
69+
.algo("gds.prizeSteinerTree")
70+
.mutateMode()
71+
.addParameter("prizeProperty","p")
72+
.addParameter("relationshipWeightProperty", "cost")
73+
.addParameter("mutateRelationshipType", "PCST")
74+
.addParameter("mutateProperty", "cost")
75+
.yields(
76+
"effectiveNodeCount", "totalWeight", "sumOfPrizes", "relationshipsWritten"
77+
);
78+
79+
var rowCount = runQueryWithRowConsumer(query,
80+
resultRow -> {
81+
82+
assertThat(resultRow.get("effectiveNodeCount")).isInstanceOf(Long.class);
83+
assertThat(resultRow.get("totalWeight")).isInstanceOf(Double.class);
84+
assertThat(resultRow.get("sumOfPrizes")).isInstanceOf(Double.class);
85+
assertThat(resultRow.get("relationshipsWritten")).isInstanceOf(Long.class);
86+
87+
assertThat((long) resultRow.get("effectiveNodeCount")).isEqualTo(2L);
88+
assertThat((double) resultRow.get("totalWeight")).isEqualTo(60);
89+
assertThat((double) resultRow.get("sumOfPrizes")).isEqualTo(200.0);
90+
assertThat((long) resultRow.get("relationshipsWritten")).isEqualTo(1L);
91+
92+
});
93+
assertThat(rowCount).isEqualTo(1L);
94+
95+
var mutatedGraph = GraphStoreCatalog
96+
.get(Username.EMPTY_USERNAME.username(), db.databaseName(), DEFAULT_GRAPH_NAME)
97+
.graphStore()
98+
.getGraph(RelationshipType.of("PCST"), Optional.of("cost"));
99+
100+
assertThat(mutatedGraph.relationshipCount()).isEqualTo(1L);
101+
102+
var relationshipCounter = new LongAdder();
103+
mutatedGraph.forEachRelationship(mutatedGraph.toMappedNodeId(idFunction.of("b")), -1, (s, t, w) -> {
104+
assertThat(t).isEqualTo(mutatedGraph.toMappedNodeId(idFunction.of("a")));
105+
assertThat(w).isEqualTo(60);
106+
relationshipCounter.increment();
107+
return true;
108+
});
109+
assertThat(relationshipCounter.longValue()).isEqualTo(1L);
110+
111+
}
112+
}

0 commit comments

Comments
 (0)