Skip to content

Commit

Permalink
write procedures
Browse files Browse the repository at this point in the history
  • Loading branch information
IoannisPanagiotas committed Dec 6, 2024
1 parent 4c9a986 commit 4365a49
Show file tree
Hide file tree
Showing 11 changed files with 460 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import org.neo4j.gds.api.properties.nodes.LongNodePropertyValues;
import org.neo4j.gds.spanningtree.SpanningTree;

public class SpanningTreeBackedNodePropertyValues implements LongNodePropertyValues {
public class KSpanningTreeBackedNodePropertyValues implements LongNodePropertyValues {
private final SpanningTree spanningTree;
private final long nodeCount;

public SpanningTreeBackedNodePropertyValues(
public KSpanningTreeBackedNodePropertyValues(
SpanningTree spanningTree,
long nodeCount
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public Void execute(
SpanningTree spanningTree,
JobId jobId
) {
var properties = new SpanningTreeBackedNodePropertyValues(spanningTree, graph.nodeCount());
var properties = new KSpanningTreeBackedNodePropertyValues(spanningTree, graph.nodeCount());

var progressTracker = new TaskProgressTracker(
NodePropertyExporter.baseTask(AlgorithmLabel.KSpanningTree.asString(), graph.nodeCount()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import org.neo4j.gds.api.GraphName;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel;
import org.neo4j.gds.applications.algorithms.machinery.Computation;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTemplateConvenience;
import org.neo4j.gds.applications.algorithms.machinery.Computation;
import org.neo4j.gds.applications.algorithms.machinery.Label;
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
import org.neo4j.gds.applications.algorithms.machinery.ResultBuilder;
Expand All @@ -44,6 +44,8 @@
import org.neo4j.gds.paths.dijkstra.config.AllShortestPathsDijkstraWriteConfig;
import org.neo4j.gds.paths.dijkstra.config.ShortestPathDijkstraWriteConfig;
import org.neo4j.gds.paths.yens.config.ShortestPathYensWriteConfig;
import org.neo4j.gds.pcst.PCSTWriteConfig;
import org.neo4j.gds.pricesteiner.PrizeSteinerTreeResult;
import org.neo4j.gds.spanningtree.SpanningTree;
import org.neo4j.gds.spanningtree.SpanningTreeWriteConfig;
import org.neo4j.gds.steiner.SteinerTreeResult;
Expand All @@ -56,6 +58,7 @@
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.DeltaStepping;
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.Dijkstra;
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.KSpanningTree;
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.PCST;
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.SingleSourceDijkstra;
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.SteinerTree;
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.Yens;
Expand Down Expand Up @@ -150,6 +153,24 @@ public <RESULT> RESULT kSpanningTree(
);
}

public <RESULT> RESULT pcst(
GraphName graphName,
PCSTWriteConfig configuration,
ResultBuilder<PCSTWriteConfig, PrizeSteinerTreeResult, RESULT, RelationshipsWritten> resultBuilder
) {
var writeStep = new PrizeCollectingSteinerTreeWriteStep(requestScopedDependencies, writeContext, configuration);

return runAlgorithmAndWrite(
graphName,
configuration,
PCST,
estimationFacade::pcst,
(graph, __) -> pathFindingAlgorithms.pcst(graph, configuration),
writeStep,
resultBuilder
);
}

public <RESULT> RESULT singlePairShortestPathAStar(
GraphName graphName,
ShortestPathAStarWriteConfig configuration,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.applications.algorithms.pathfinding;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.ResultStore;
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
import org.neo4j.gds.applications.algorithms.machinery.WriteContext;
import org.neo4j.gds.applications.algorithms.machinery.WriteStep;
import org.neo4j.gds.applications.algorithms.metadata.RelationshipsWritten;
import org.neo4j.gds.core.utils.progress.JobId;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.pcst.PCSTWriteConfig;
import org.neo4j.gds.pricesteiner.PrizeSteinerTreeResult;
import org.neo4j.gds.spanningtree.SpanningGraph;
import org.neo4j.gds.spanningtree.SpanningTree;

class PrizeCollectingSteinerTreeWriteStep implements WriteStep<PrizeSteinerTreeResult, RelationshipsWritten> {
private final RequestScopedDependencies requestScopedDependencies;
private final PCSTWriteConfig configuration;
private final WriteContext writeContext;

PrizeCollectingSteinerTreeWriteStep(
RequestScopedDependencies requestScopedDependencies,
WriteContext writeContext,
PCSTWriteConfig configuration
) {
this.requestScopedDependencies = requestScopedDependencies;
this.configuration = configuration;
this.writeContext = writeContext;
}

@Override
public RelationshipsWritten execute(
Graph graph,
GraphStore graphStore,
ResultStore resultStore,
PrizeSteinerTreeResult steinerTreeResult,
JobId jobId
) {

var spanningTree = new SpanningTree(
-1,
graph.nodeCount(),
steinerTreeResult.effectiveNodeCount(),
steinerTreeResult.parentArray(),
nodeId -> steinerTreeResult.relationshipToParentCost().get(nodeId),
steinerTreeResult.totalWeight()
);
var spanningGraph = new SpanningGraph(graph, spanningTree);

var relationshipExporter = writeContext.relationshipExporterBuilder()
.withGraph(spanningGraph)
.withIdMappingOperator(spanningGraph::toOriginalNodeId)
.withTerminationFlag(requestScopedDependencies.terminationFlag())
.withProgressTracker(ProgressTracker.NULL_TRACKER)
.withResultStore(configuration.resolveResultStore(resultStore))
.withJobId(configuration.jobId())
.build();

relationshipExporter.write(configuration.writeRelationshipType(), configuration.writeProperty());

return new RelationshipsWritten(steinerTreeResult.effectiveNodeCount() - 1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.paths.prizesteiner;

import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.algorithms.pathfinding.PrizeCollectingSteinerTreeWriteResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import java.util.Map;
import java.util.stream.Stream;

import static org.neo4j.gds.procedures.ProcedureConstants.MEMORY_ESTIMATION_DESCRIPTION;
import static org.neo4j.procedure.Mode.READ;
import static org.neo4j.procedure.Mode.WRITE;

public class PrizeCollectingSteinerTreeWriteProc {
@Context
public GraphDataScienceProcedures facade;

@Procedure(value = "gds.prizeSteinerTree.write", mode = WRITE)
@Description(Constants.PRIZE_STEINER_DESCRIPTION)
public Stream<PrizeCollectingSteinerTreeWriteResult> steinerTree(
@Name(value = "graphName") String graphName,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
) {
return facade.algorithms().pathFinding().prizeCollectingSteinerTreeWrite(graphName, configuration);
}

@Procedure(value = "gds.prizeSteinerTree.write.estimate", mode = READ)
@Description(MEMORY_ESTIMATION_DESCRIPTION)
public Stream<MemoryEstimateResult> estimate(
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
) {
return facade.algorithms().pathFinding().prizeCollectingSteinerTreeWriteEstimate(graphNameOrConfiguration, algoConfiguration);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.paths.prizesteiner;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.neo4j.gds.BaseProcTest;
import org.neo4j.gds.GdsCypher;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.catalog.GraphProjectProc;
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.Neo4jGraph;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

class PrizeCollectingSteinerTreeWriteProcTest extends BaseProcTest {

@Neo4jGraph
static final String DB_CYPHER =
"CREATE(a:Node{p:100.0}) " +
"CREATE(b:Node{p:100.0}) " +
"CREATE(c:Node{p:169.0}) " +
"CREATE (a)-[:TYPE {cost:60.0}]->(b) ";

@Inject
private IdFunction idFunction;

@BeforeEach
void setup() throws Exception {

registerProcedures(PrizeCollectingSteinerTreeWriteProc.class, GraphProjectProc.class);
var createQuery = GdsCypher.call(DEFAULT_GRAPH_NAME)
.graphProject()
.withAnyLabel()
.withNodeProperty("p")
.withRelationshipType("TYPE", Orientation.UNDIRECTED)
.withRelationshipProperty("cost")
.yields();
runQuery(createQuery);
}

@Test
void testWrite() {
String query = GdsCypher.call(DEFAULT_GRAPH_NAME)
.algo("gds.prizeSteinerTree")
.writeMode()
.addParameter("prizeProperty","p")
.addParameter("relationshipWeightProperty", "cost")
.addParameter("writeRelationshipType", "PCST")
.addParameter("writeProperty", "cost")
.yields(
"effectiveNodeCount", "totalWeight", "sumOfPrizes", "relationshipsWritten"
);

var rowCount = runQueryWithRowConsumer(query,
resultRow -> {

assertThat(resultRow.get("effectiveNodeCount")).isInstanceOf(Long.class);
assertThat(resultRow.get("totalWeight")).isInstanceOf(Double.class);
assertThat(resultRow.get("sumOfPrizes")).isInstanceOf(Double.class);
assertThat(resultRow.get("relationshipsWritten")).isInstanceOf(Long.class);

assertThat((long) resultRow.get("effectiveNodeCount")).isEqualTo(2L);
assertThat((double) resultRow.get("totalWeight")).isEqualTo(60);
assertThat((double) resultRow.get("sumOfPrizes")).isEqualTo(200.0);
assertThat((long) resultRow.get("relationshipsWritten")).isEqualTo(1L);

});
assertThat(rowCount).isEqualTo(1L);

var sourceNode = idFunction.of("a");
var terminalNode = idFunction.of("b");

var rowCountCheck = runQueryWithRowConsumer(
"MATCH (a)-[r:PCST]->(b) RETURN id(a) AS a, id(b) AS b, r.cost AS cost",
row -> {
var a = row.getNumber("a").longValue();
var b = row.getNumber("b").longValue();

assertThat(a).isNotEqualTo(b);

assertThat(a).isIn(List.of(sourceNode,terminalNode));
assertThat(b).isIn(List.of(sourceNode,terminalNode));

var writtenCost = row.getNumber("cost").doubleValue();
assertThat(writtenCost)
.isEqualTo(60.0);
}
);

assertThat(rowCountCheck).isEqualTo(1L);

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.neo4j.gds.paths.yens.config.ShortestPathYensWriteConfig;
import org.neo4j.gds.pcst.PCSTStatsConfig;
import org.neo4j.gds.pcst.PCSTStreamConfig;
import org.neo4j.gds.pcst.PCSTWriteConfig;
import org.neo4j.gds.procedures.algorithms.configuration.UserSpecificConfigurationParser;
import org.neo4j.gds.procedures.algorithms.pathfinding.stubs.BFSMutateStub;
import org.neo4j.gds.procedures.algorithms.pathfinding.stubs.BellmanFordMutateStub;
Expand Down Expand Up @@ -634,6 +635,37 @@ public Stream<MemoryEstimateResult> prizeCollectingSteinerTreeStatsEstimate(
);
}

@Override
public Stream<PrizeCollectingSteinerTreeWriteResult> prizeCollectingSteinerTreeWrite(
String graphName,
Map<String, Object> configuration
) {
var config = configurationParser.parseConfiguration(configuration, PCSTWriteConfig::of);
var resultBuilder = new PrizeCollectingSteinerTreeResultBuilderForWriteMode(config);

return Stream.of(
writeModeBusinessFacade.pcst(
GraphName.parse(graphName),
config,
resultBuilder
)
);
}

@Override
public Stream<MemoryEstimateResult> prizeCollectingSteinerTreeWriteEstimate(
Object graphNameOrConfiguration,
Map<String, Object> algorithmConfiguration
) {
return
Stream.of(
estimationModeBusinessFacade.pcst(
configurationParser.parseConfiguration(algorithmConfiguration, PCSTWriteConfig::of),
graphNameOrConfiguration
)
);
}

@Override
public Stream<StandardModeResult> randomWalkStats(String graphName, Map<String, Object> rawConfiguration) {
var configuration = configurationParser.parseConfiguration(
Expand Down
Loading

0 comments on commit 4365a49

Please sign in to comment.