Skip to content

Adding JGNN to maven #13

@Mohammed-Ryiad-Eiadeh

Description

@Mohammed-Ryiad-Eiadeh

Does this generate the embedding of the graph:

  package MainPackage;
  
  import mklab.JGNN.adhoc.ModelBuilder;
  import mklab.JGNN.adhoc.parsers.LayeredBuilder;
  import mklab.JGNN.core.Matrix;
  import mklab.JGNN.core.Tensor;
  import mklab.JGNN.core.matrix.DenseMatrix;
  import mklab.JGNN.core.tensor.DenseTensor;
  import mklab.JGNN.nn.Loss;
  import mklab.JGNN.nn.Model;
  import mklab.JGNN.nn.initializers.XavierNormal;
  import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
  import mklab.JGNN.nn.optimizers.BatchOptimizer;
  import mklab.JGNN.nn.optimizers.GradientDescent;
  
  import java.util.ArrayList;
  import java.util.Arrays;
  import java.util.List;
  
  public class main {
      public static void main(String[] args) {
          // Create adjacency matrix for a simple graph (A-B, B-C)
          Matrix adjacencyMatrix = new DenseMatrix(10, 10);
          adjacencyMatrix.put(0, 1, 1);  // A-B
          adjacencyMatrix.put(1, 0, 1);  // B-A
          adjacencyMatrix.put(1, 2, 1);  // B-C
          adjacencyMatrix.put(2, 1, 1);  // C-B
          adjacencyMatrix.put(2, 3, 1);  // C-D
          adjacencyMatrix.put(3, 2, 1);  // D-C
          adjacencyMatrix.put(3, 4, 1);  // D-E
          adjacencyMatrix.put(4, 3, 1);  // E-D
          adjacencyMatrix.put(4, 5, 1);  // E-F
          adjacencyMatrix.put(5, 4, 1);  // F-E
          adjacencyMatrix.put(5, 6, 1);  // F-G
          adjacencyMatrix.put(6, 5, 1);  // G-F
          adjacencyMatrix.put(6, 7, 1);  // G-H
          adjacencyMatrix.put(7, 6, 1);  // H-G
          adjacencyMatrix.put(7, 8, 1);  // H-I
          adjacencyMatrix.put(8, 7, 1);  // I-H
          adjacencyMatrix.put(8, 9, 1);  // I-J
          adjacencyMatrix.put(9, 8, 1);  // J-I
          adjacencyMatrix.put(0, 9, 1);  // A-J
          adjacencyMatrix.put(9, 0, 1);  // J-A
  
          // Create a feature set with 128 features per node
          int numNodes = 10;   // Number of nodes in the graph
          int numFeatures = 128;  // Number of features per node
          Matrix nodeFeatures = new DenseMatrix(numNodes, numFeatures);
  
          // Manually populate the feature matrix or generate random values for the features
          for (int i = 0; i < numNodes; i++) {
              for (int j = 0; j < numFeatures; j++) {
                  // You can generate the features randomly, or input them manually
                  nodeFeatures.put(i, j, Math.random());  // Random feature value between 0 and 1
              }
          }
          // Graph label (binary classification task)
          Tensor graphLabel = new DenseTensor(1);
          graphLabel.put(0, 1.0); // Label for the graph (e.g., 1 for positive class)
  
          // Prepare the data for training
          ArrayList<Matrix> adjacencyMatrices = new ArrayList<>();
          ArrayList<Matrix> nodeFeaturesList = new ArrayList<>();
          ArrayList<Tensor> graphLabels = new ArrayList<>();
  
          adjacencyMatrices.add(adjacencyMatrix);
          nodeFeaturesList.add(nodeFeatures);
          graphLabels.add(graphLabel);
  
          // Define the model using JGNN LayeredBuilder
          ModelBuilder builder = new LayeredBuilder()
                  .var("A")  // Sparse adjacency matrix input
                  .config("features", numFeatures)  // Number of features per node
                  .config("hidden", 5)  // Hidden layer size
                  .config("classes", 1)  // Number of output classes (binary classification)
                  .layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden)))")  // Hidden layer computation
                  .layer("h{l+1}=relu(A@(h{l}@matrix(hidden, classes)))")  // Output layer for classification
                  .layer("h{l+1}=softmax(mean(h{l}, dim: 'row'))")  // Pooling layer (averaging over nodes)
                  .out("h{l}");  // Final output layer to get graph-level embedding
  
          // Initialize the model
          Model model = builder.getModel();
          model.init(new XavierNormal());
  
          // Loss function and optimizer
          Loss loss = new CategoricalCrossEntropy();
          BatchOptimizer optimizer = new BatchOptimizer(new GradientDescent(0.5));
  
          // Since we have one graph, we can directly use the adjacency and feature matrices
          Matrix adjacency = adjacencyMatrices.get(0);   // Only one graph
          Matrix features = nodeFeaturesList.get(0);     // Features for the single graph
          Tensor label = graphLabels.get(0);             // Label for the graph (e.g., binary classification)
  
          // Train the model using the adjacency matrix and node features
          model.train(loss, optimizer,
                  Arrays.asList(features, adjacency),
                  List.of(label));
  
          // Train the model for 300 epochs
          for (int epoch = 0; epoch < 300; epoch++) {
              optimizer.updateAll();  // Update model parameters after each epoch
          }
  
          for (var i : model.getParameters()) {
              System.out.println(i.getPrediction());
          }
      }
  }

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions