Does the following code generate an embedding of the whole graph?
package MainPackage;
import mklab.JGNN.adhoc.ModelBuilder;
import mklab.JGNN.adhoc.parsers.LayeredBuilder;
import mklab.JGNN.core.Matrix;
import mklab.JGNN.core.Tensor;
import mklab.JGNN.core.matrix.DenseMatrix;
import mklab.JGNN.core.tensor.DenseTensor;
import mklab.JGNN.nn.Loss;
import mklab.JGNN.nn.Model;
import mklab.JGNN.nn.initializers.XavierNormal;
import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
import mklab.JGNN.nn.optimizers.BatchOptimizer;
import mklab.JGNN.nn.optimizers.GradientDescent;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class main {
    public static void main(String[] args) {
        // Create the adjacency matrix for a 10-node ring graph (A-B-C-...-J-A),
        // adding both directions of each edge so the graph is undirected
        Matrix adjacencyMatrix = new DenseMatrix(10, 10);
        adjacencyMatrix.put(0, 1, 1); // A-B
        adjacencyMatrix.put(1, 0, 1); // B-A
        adjacencyMatrix.put(1, 2, 1); // B-C
        adjacencyMatrix.put(2, 1, 1); // C-B
        adjacencyMatrix.put(2, 3, 1); // C-D
        adjacencyMatrix.put(3, 2, 1); // D-C
        adjacencyMatrix.put(3, 4, 1); // D-E
        adjacencyMatrix.put(4, 3, 1); // E-D
        adjacencyMatrix.put(4, 5, 1); // E-F
        adjacencyMatrix.put(5, 4, 1); // F-E
        adjacencyMatrix.put(5, 6, 1); // F-G
        adjacencyMatrix.put(6, 5, 1); // G-F
        adjacencyMatrix.put(6, 7, 1); // G-H
        adjacencyMatrix.put(7, 6, 1); // H-G
        adjacencyMatrix.put(7, 8, 1); // H-I
        adjacencyMatrix.put(8, 7, 1); // I-H
        adjacencyMatrix.put(8, 9, 1); // I-J
        adjacencyMatrix.put(9, 8, 1); // J-I
        adjacencyMatrix.put(0, 9, 1); // A-J
        adjacencyMatrix.put(9, 0, 1); // J-A
        // Create a feature matrix with 128 features per node
        int numNodes = 10;      // Number of nodes in the graph
        int numFeatures = 128;  // Number of features per node
        Matrix nodeFeatures = new DenseMatrix(numNodes, numFeatures);
        // Populate the feature matrix; here random values stand in for real features
        for (int i = 0; i < numNodes; i++) {
            for (int j = 0; j < numFeatures; j++) {
                nodeFeatures.put(i, j, Math.random()); // Random feature value in [0, 1)
            }
        }
        // Graph label (binary classification task)
        Tensor graphLabel = new DenseTensor(1);
        graphLabel.put(0, 1.0); // Label for the graph (e.g., 1 for the positive class)
        // Prepare the data for training
        ArrayList<Matrix> adjacencyMatrices = new ArrayList<>();
        ArrayList<Matrix> nodeFeaturesList = new ArrayList<>();
        ArrayList<Tensor> graphLabels = new ArrayList<>();
        adjacencyMatrices.add(adjacencyMatrix);
        nodeFeaturesList.add(nodeFeatures);
        graphLabels.add(graphLabel);
        // Define the model using the JGNN LayeredBuilder
        ModelBuilder builder = new LayeredBuilder()
                .var("A")                        // Sparse adjacency matrix input
                .config("features", numFeatures) // Number of features per node
                .config("hidden", 5)             // Hidden layer size
                .config("classes", 1)            // Number of output classes (binary classification)
                .layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden)))") // Hidden layer computation
                .layer("h{l+1}=relu(A@(h{l}@matrix(hidden, classes)))")  // Output layer for classification
                .layer("h{l+1}=softmax(mean(h{l}, dim: 'row'))")         // Pooling layer (averaging over nodes)
                .out("h{l}"); // Final output: the graph-level representation
        // Initialize the model
        Model model = builder.getModel();
        model.init(new XavierNormal());
        // Loss function and optimizer
        Loss loss = new CategoricalCrossEntropy();
        BatchOptimizer optimizer = new BatchOptimizer(new GradientDescent(0.5));
        // Since there is only one graph, use its adjacency and feature matrices directly
        Matrix adjacency = adjacencyMatrices.get(0); // Only one graph
        Matrix features = nodeFeaturesList.get(0);   // Features for the single graph
        Tensor label = graphLabels.get(0);           // Label for the graph
        // Train for 300 epochs: each train() call accumulates gradients in the
        // BatchOptimizer, and updateAll() applies them at the end of the epoch
        for (int epoch = 0; epoch < 300; epoch++) {
            model.train(loss, optimizer,
                    Arrays.asList(features, adjacency),
                    List.of(label));
            optimizer.updateAll();
        }
        // Print the current value of each learned parameter
        for (var i : model.getParameters()) {
            System.out.println(i.getPrediction());
        }
    }
}