Skip to content

Commit

Permalink
migrate estimation cli hashgnn to application layer
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Dec 4, 2024
1 parent 0be494c commit 5228383
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 231 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,48 +19,16 @@
*/
package org.neo4j.gds.embeddings.hashgnn;

import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.ArrayList;
import java.util.List;

public class HashGNNFactory<CONFIG extends HashGNNConfig> extends GraphAlgorithmFactory<HashGNN, CONFIG> {

@Override
public String taskName() {
return "HashGNN";
}

public HashGNN build(
Graph graph,
HashGNNParameters parameters,
ProgressTracker progressTracker
) {
return new HashGNN(
graph,
parameters,
progressTracker,
TerminationFlag.RUNNING_TRUE
);
}

@Override
public HashGNN build(
Graph graph,
CONFIG configuration,
ProgressTracker progressTracker
) {
return build(graph, HashGNNConfigTransformer.toParameters(configuration), progressTracker);
}

@Override
public Task progressTask(Graph graph, CONFIG config) {
public class HashGNNTask {
public static Task create(Graph graph, HashGNNConfig config) {
var tasks = new ArrayList<Task>();

if (config.generateFeatures().isPresent()) {
Expand Down Expand Up @@ -93,17 +61,8 @@ public Task progressTask(Graph graph, CONFIG config) {
}

return Tasks.task(
taskName(),
AlgorithmLabel.HashGNN.asString(),
tasks
);
}

public MemoryEstimation memoryEstimation(HashGNNParameters parameters) {
return new HashGNNMemoryEstimateDefinition(parameters).memoryEstimation();
}

@Override
public MemoryEstimation memoryEstimation(CONFIG config) {
return memoryEstimation(HashGNNConfigTransformer.toParameters(config));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ void estimationShouldUseGeneratedDimensionIfOutputIsMissing() {
Optional.of(GenerateFeaturesConfigImpl.builder().dimension(inputDimension).densityLevel(1).build()),
Optional.empty()
);
var bigEstimation = new HashGNNFactory<>()
.memoryEstimation(bigParameters)
var bigEstimation = new HashGNNMemoryEstimateDefinition(bigParameters)
.memoryEstimation()
.estimate(graphDims, concurrency)
.memoryUsage();

Expand All @@ -128,8 +128,8 @@ void estimationShouldUseGeneratedDimensionIfOutputIsMissing() {
Optional.of(GenerateFeaturesConfigImpl.builder().dimension((int) (inputRatio * inputDimension)).densityLevel(1).build()),
Optional.empty()
);
var smallEstimation = new HashGNNFactory<>()
.memoryEstimation(smallParameters)
var smallEstimation = new HashGNNMemoryEstimateDefinition(smallParameters)
.memoryEstimation()
.estimate(graphDims, concurrency)
.memoryUsage();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.ResourceUtil;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.applications.algorithms.embeddings.NodeEmbeddingAlgorithms;
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.hsa.HugeSparseLongArray;
import org.neo4j.gds.compat.TestLog;
Expand Down Expand Up @@ -272,6 +275,13 @@ void outputDimensionIsApplied() {
@ParameterizedTest
@CsvSource(value = {"true", "false"})
void shouldLogProgress(boolean dense) {
var log = new GdsTestLog();
var requestScopedDependencies = RequestScopedDependencies.builder()
.terminationFlag(TerminationFlag.RUNNING_TRUE)
.build();
var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies);
var nodeEmbeddingAlgorithms = new NodeEmbeddingAlgorithms(null, progressTrackerCreator, requestScopedDependencies.terminationFlag());

var g = dense ? doubleGraph : binaryGraph;

int embeddingDensity = 200;
Expand All @@ -290,12 +300,10 @@ void shouldLogProgress(boolean dense) {
}
var config = configBuilder.build();

var factory = new HashGNNFactory<>();
var progressTask = factory.progressTask(g, config);
var log = new GdsTestLog();
var progressTask = HashGNNTask.create(g, config);
var progressTracker = new TaskProgressTracker(progressTask, log, new Concurrency(4), EmptyTaskRegistryFactory.INSTANCE);

factory.build(g, config, progressTracker).compute();
nodeEmbeddingAlgorithms.hashGnn(g, config, progressTracker);

String logResource;
if (dense) {
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfigTransformer;
import org.neo4j.gds.embeddings.hashgnn.HashGNNResult;
import org.neo4j.gds.embeddings.hashgnn.HashGNNTask;
import org.neo4j.gds.embeddings.node2vec.Node2Vec;
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
import org.neo4j.gds.embeddings.node2vec.Node2VecConfigTransformer;
Expand Down Expand Up @@ -196,10 +197,16 @@ private static GraphSageTrain constructGraphSageTrainAlgorithm(
}

HashGNNResult hashGnn(Graph graph, HashGNNConfig configuration) {
var task = createHashGnnTask(graph, configuration);
var task = HashGNNTask.create(graph, configuration);
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);

var algorithm = new HashGNN(graph, HashGNNConfigTransformer.toParameters(configuration), progressTracker, terminationFlag);
return hashGnn(graph, configuration, progressTracker);
}

public HashGNNResult hashGnn(Graph graph, HashGNNConfig configuration, ProgressTracker progressTracker) {
var parameters = HashGNNConfigTransformer.toParameters(configuration);

var algorithm = new HashGNN(graph, parameters, progressTracker, terminationFlag);

return algorithmMachinery.runAlgorithmsAndManageProgressTracker(
algorithm,
Expand Down Expand Up @@ -246,41 +253,6 @@ private Task createFastRPTask(Graph graph, Number nodeSelfInfluence, int iterati
return Tasks.task(AlgorithmLabel.FastRP.asString(), tasks);
}

private static Task createHashGnnTask(Graph graph, HashGNNConfig configuration) {
var tasks = new ArrayList<Task>();

if (configuration.generateFeatures().isPresent()) {
tasks.add(Tasks.leaf("Generate base node property features", graph.nodeCount()));
} else if (configuration.binarizeFeatures().isPresent()) {
tasks.add(Tasks.leaf("Binarize node property features", graph.nodeCount()));
} else {
tasks.add(Tasks.leaf("Extract raw node property features", graph.nodeCount()));
}

int numRelTypes = configuration.heterogeneous() ? configuration.relationshipTypes().size() : 1;

tasks.add(Tasks.iterativeFixed(
"Propagate embeddings",
() -> List.of(
Tasks.leaf(
"Precompute hashes",
configuration.embeddingDensity() * (1L + 1 + numRelTypes)
),
Tasks.leaf(
"Perform min-hashing",
(2 * graph.nodeCount() + graph.relationshipCount()) * configuration.embeddingDensity()
)
),
configuration.iterations()
));

if (configuration.outputDimension().isPresent()) {
tasks.add(Tasks.leaf("Densify output embeddings", graph.nodeCount()));
}

return Tasks.task(AlgorithmLabel.HashGNN.asString(), tasks);
}

private Task createNode2VecTask(Graph graph, Node2VecBaseConfig configuration) {
var randomWalkTasks = new ArrayList<Task>();
if (graph.hasRelationshipProperty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfigTransformer;
import org.neo4j.gds.embeddings.hashgnn.HashGNNMemoryEstimateDefinition;
import org.neo4j.gds.embeddings.hashgnn.HashGNNParameters;
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
import org.neo4j.gds.embeddings.node2vec.Node2VecConfigTransformer;
import org.neo4j.gds.embeddings.node2vec.Node2VecMemoryEstimateDefinition;
Expand Down Expand Up @@ -97,7 +98,13 @@ public MemoryEstimateResult graphSageTrain(GraphSageTrainConfig configuration, O
}

public MemoryEstimation hashGnn(HashGNNConfig configuration) {
return new HashGNNMemoryEstimateDefinition(HashGNNConfigTransformer.toParameters(configuration)).memoryEstimation();
var parameters = HashGNNConfigTransformer.toParameters(configuration);

return hashGnn(parameters);
}

private MemoryEstimation hashGnn(HashGNNParameters parameters) {
return new HashGNNMemoryEstimateDefinition(parameters).memoryEstimation();
}

public MemoryEstimateResult hashGnn(HashGNNConfig configuration, Object graphNameOrConfiguration) {
Expand Down

0 comments on commit 5228383

Please sign in to comment.