Skip to content

Commit

Permalink
mutate procedures
Browse files Browse the repository at this point in the history
  • Loading branch information
IoannisPanagiotas committed Dec 6, 2024
1 parent 4365a49 commit 344a991
Show file tree
Hide file tree
Showing 17 changed files with 588 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public MemoryEstimateResult pcst(
return runEstimation(configuration, graphNameOrConfiguration, memoryEstimation);
}

MemoryEstimation pcst() {
public MemoryEstimation pcst() {
return new PrizeSteinerTreeMemoryEstimateDefinition().memoryEstimation();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import org.neo4j.gds.paths.traverse.BfsMutateConfig;
import org.neo4j.gds.paths.traverse.DfsMutateConfig;
import org.neo4j.gds.paths.yens.config.ShortestPathYensMutateConfig;
import org.neo4j.gds.pcst.PCSTMutateConfig;
import org.neo4j.gds.pricesteiner.PrizeSteinerTreeResult;
import org.neo4j.gds.spanningtree.SpanningTree;
import org.neo4j.gds.spanningtree.SpanningTreeMutateConfig;
import org.neo4j.gds.steiner.SteinerTreeMutateConfig;
Expand Down Expand Up @@ -152,6 +154,24 @@ public <RESULT> RESULT depthFirstSearch(
);
}

public <RESULT> RESULT pcst(
GraphName graphName,
PCSTMutateConfig configuration,
ResultBuilder<PCSTMutateConfig, PrizeSteinerTreeResult, RESULT, RelationshipsWritten> resultBuilder
) {
var mutateStep = new PrizeCollectingSteinerTreeMutateStep(configuration);

return algorithmProcessingTemplateConvenience.processRegularAlgorithmInMutateMode(
graphName,
configuration,
SteinerTree,
estimationFacade::pcst,
(graph, __) -> pathFindingAlgorithms.pcst(graph, configuration),
mutateStep,
resultBuilder
);
}

public <RESULT> RESULT randomWalk(
GraphName graphName,
RandomWalkMutateConfig configuration,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.algorithms.pathfinding;

import org.neo4j.gds.Orientation;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.applications.algorithms.machinery.MutateStep;
import org.neo4j.gds.applications.algorithms.metadata.RelationshipsWritten;
import org.neo4j.gds.core.loading.construction.GraphFactory;
import org.neo4j.gds.pcst.PCSTMutateConfig;
import org.neo4j.gds.pricesteiner.PrizeSteinerTreeResult;

import java.util.stream.LongStream;

class PrizeCollectingSteinerTreeMutateStep implements MutateStep<PrizeSteinerTreeResult, RelationshipsWritten> {
private final PCSTMutateConfig configuration;

PrizeCollectingSteinerTreeMutateStep(PCSTMutateConfig configuration) {
this.configuration = configuration;
}

@Override
public RelationshipsWritten execute(
Graph graph,
GraphStore graphStore,
PrizeSteinerTreeResult treeResult
) {
var mutateRelationshipType = RelationshipType.of(configuration.mutateRelationshipType());

var relationshipsBuilder = GraphFactory.initRelationshipsBuilder()
.nodes(graph)
.relationshipType(mutateRelationshipType)
.addPropertyConfig(GraphFactory.PropertyConfig.of(configuration.mutateProperty()))
.orientation(Orientation.NATURAL)
.build();

var parentArray = treeResult.parentArray();
var costArray = treeResult.relationshipToParentCost();
LongStream.range(0, graph.nodeCount())
.filter(nodeId -> parentArray.get(nodeId) != PrizeSteinerTreeResult.PRUNED )
.filter(nodeId -> parentArray.get(nodeId) != PrizeSteinerTreeResult.ROOT )
.forEach(nodeId -> {
var parentId = parentArray.get(nodeId);
relationshipsBuilder.addFromInternal(parentId, nodeId, costArray.get(nodeId));

});

var relationships = relationshipsBuilder.build();

// the effect
graphStore.addRelationshipType(relationships);

// the reporting
return new RelationshipsWritten(treeResult.effectiveNodeCount() - 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.neo4j.gds.paths.traverse.BfsMutateConfig;
import org.neo4j.gds.paths.traverse.DfsMutateConfig;
import org.neo4j.gds.paths.yens.config.ShortestPathYensMutateConfig;
import org.neo4j.gds.pcst.PCSTMutateConfig;
import org.neo4j.gds.scaleproperties.ScalePropertiesMutateConfig;
import org.neo4j.gds.scc.SccMutateConfig;
import org.neo4j.gds.similarity.filteredknn.FilteredKnnMutateConfig;
Expand Down Expand Up @@ -126,7 +127,7 @@ public Function<CypherMapWrapper, AlgoBaseConfig> lookup(Algorithm algorithm) {
case NodeSimilarity -> NodeSimilarityMutateConfig::of;
case Node2Vec -> Node2VecMutateConfig::of;
case PageRank -> PageRankMutateConfig::of;
case PCST -> null;
case PCST -> PCSTMutateConfig::of;
case RandomWalk -> null;
case ScaleProperties -> ScalePropertiesMutateConfig::of;
case SCC -> SccMutateConfig::of;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ static CanonicalProcedureName algorithmToName(Algorithm algorithm) {
case NodeSimilarity -> CanonicalProcedureName.parse("gds.nodeSimilarity");
case Node2Vec -> CanonicalProcedureName.parse("gds.node2vec");
case PageRank -> CanonicalProcedureName.parse("gds.pageRank");
case PCST -> null;
case PCST -> CanonicalProcedureName.parse("gds.prizeSteinerTree");
case RandomWalk -> null;
case ScaleProperties -> CanonicalProcedureName.parse("gds.scaleProperties");
case SCC -> CanonicalProcedureName.parse("gds.scc");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.neo4j.gds.ml.pipeline.stubs.Node2VecStub;
import org.neo4j.gds.ml.pipeline.stubs.NodeSimilarityStub;
import org.neo4j.gds.ml.pipeline.stubs.PageRankStub;
import org.neo4j.gds.ml.pipeline.stubs.PrizeCollectingSteinerTreeStub;
import org.neo4j.gds.ml.pipeline.stubs.RandomWalkStub;
import org.neo4j.gds.ml.pipeline.stubs.ScalePropertiesStub;
import org.neo4j.gds.ml.pipeline.stubs.SccStub;
Expand Down Expand Up @@ -125,7 +126,7 @@ Stub get(Algorithm algorithm) {
case NodeSimilarity -> new NodeSimilarityStub();
case Node2Vec -> new Node2VecStub();
case PageRank -> new PageRankStub();
case PCST -> null;
case PCST -> new PrizeCollectingSteinerTreeStub();
case RandomWalk -> new RandomWalkStub();
case ScaleProperties -> new ScalePropertiesStub();
case SCC -> new SccStub();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.ml.pipeline.stubs;

import org.neo4j.gds.pcst.PCSTMutateConfig;
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
import org.neo4j.gds.procedures.algorithms.pathfinding.PrizeCollectingSteinerTreeMutateResult;
import org.neo4j.gds.procedures.algorithms.stubs.MutateStub;

public class PrizeCollectingSteinerTreeStub extends AbstractStub<PCSTMutateConfig, PrizeCollectingSteinerTreeMutateResult> {
protected MutateStub<PCSTMutateConfig, PrizeCollectingSteinerTreeMutateResult> stub(AlgorithmsProcedureFacade facade) {
return facade.pathFinding().prizeCollectingSteinerTreeMutateStub();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.paths.prizesteiner;

import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.algorithms.pathfinding.PrizeCollectingSteinerTreeMutateResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import java.util.Map;
import java.util.stream.Stream;

import static org.neo4j.gds.procedures.ProcedureConstants.MEMORY_ESTIMATION_DESCRIPTION;
import static org.neo4j.procedure.Mode.READ;

public class PrizeCollectingSteinerTreeMutateProc {
@Context
public GraphDataScienceProcedures facade;

@Procedure(value = "gds.prizeSteinerTree.mutate", mode = READ)
@Description(Constants.PRIZE_STEINER_DESCRIPTION)
public Stream<PrizeCollectingSteinerTreeMutateResult> steinerTree(
@Name(value = "graphName") String graphName,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
) {
return facade.algorithms().pathFinding().prizeCollectingSteinerTreeMutateStub().execute(graphName, configuration);
}

@Procedure(value = "gds.prizeSteinerTree.mutate.estimate", mode = READ)
@Description(MEMORY_ESTIMATION_DESCRIPTION)
public Stream<MemoryEstimateResult> estimate(
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
) {
return facade.algorithms().pathFinding().prizeCollectingSteinerTreeMutateStub().estimate(graphNameOrConfiguration, algoConfiguration);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.paths.prizesteiner;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.neo4j.gds.BaseProcTest;
import org.neo4j.gds.GdsCypher;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.catalog.GraphProjectProc;
import org.neo4j.gds.core.Username;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.Neo4jGraph;

import java.util.Optional;
import java.util.concurrent.atomic.LongAdder;

import static org.assertj.core.api.Assertions.assertThat;

class PrizeCollectingSteinerTreeMutateProcTest extends BaseProcTest {

@Neo4jGraph
static final String DB_CYPHER =
"CREATE(a:Node{p:100.0}) " +
"CREATE(b:Node{p:100.0}) " +
"CREATE(c:Node{p:169.0}) " +
"CREATE (a)-[:TYPE {cost:60.0}]->(b) ";

@Inject
private IdFunction idFunction;

@BeforeEach
void setup() throws Exception {

registerProcedures(PrizeCollectingSteinerTreeMutateProc.class, GraphProjectProc.class);
var createQuery = GdsCypher.call(DEFAULT_GRAPH_NAME)
.graphProject()
.withAnyLabel()
.withNodeProperty("p")
.withRelationshipType("TYPE", Orientation.UNDIRECTED)
.withRelationshipProperty("cost")
.yields();
runQuery(createQuery);
}

@Test
void testMutate() {
String query = GdsCypher.call(DEFAULT_GRAPH_NAME)
.algo("gds.prizeSteinerTree")
.mutateMode()
.addParameter("prizeProperty","p")
.addParameter("relationshipWeightProperty", "cost")
.addParameter("mutateRelationshipType", "PCST")
.addParameter("mutateProperty", "cost")
.yields(
"effectiveNodeCount", "totalWeight", "sumOfPrizes", "relationshipsWritten"
);

var rowCount = runQueryWithRowConsumer(query,
resultRow -> {

assertThat(resultRow.get("effectiveNodeCount")).isInstanceOf(Long.class);
assertThat(resultRow.get("totalWeight")).isInstanceOf(Double.class);
assertThat(resultRow.get("sumOfPrizes")).isInstanceOf(Double.class);
assertThat(resultRow.get("relationshipsWritten")).isInstanceOf(Long.class);

assertThat((long) resultRow.get("effectiveNodeCount")).isEqualTo(2L);
assertThat((double) resultRow.get("totalWeight")).isEqualTo(60);
assertThat((double) resultRow.get("sumOfPrizes")).isEqualTo(200.0);
assertThat((long) resultRow.get("relationshipsWritten")).isEqualTo(1L);

});
assertThat(rowCount).isEqualTo(1L);

var mutatedGraph = GraphStoreCatalog
.get(Username.EMPTY_USERNAME.username(), db.databaseName(), DEFAULT_GRAPH_NAME)
.graphStore()
.getGraph(RelationshipType.of("PCST"), Optional.of("cost"));

assertThat(mutatedGraph.relationshipCount()).isEqualTo(1L);

var relationshipCounter = new LongAdder();
mutatedGraph.forEachRelationship(mutatedGraph.toMappedNodeId(idFunction.of("b")), -1, (s, t, w) -> {
assertThat(t).isEqualTo(mutatedGraph.toMappedNodeId(idFunction.of("a")));
assertThat(w).isEqualTo(60);
relationshipCounter.increment();
return true;
});
assertThat(relationshipCounter.longValue()).isEqualTo(1L);

}
}
Loading

0 comments on commit 344a991

Please sign in to comment.