Graph Neural Network models
In this section, we'll explore the architecture of the Graph Neural Network (GNN) models used in QuantumGravPy. These models are designed to learn meaningful representations from causal set graphs, which can then be used for various downstream tasks such as classification or regression.
Structure
The GNN models in QuantumGravPy follow a modular architecture that consists of three main components:
- Backbone: A sequence of GNN blocks that processes node features and graph connectivity to produce node embeddings.
- Graph features network (optional): A network for processing additional graph-level features.
- Frontend: A linear block that takes the embeddings from the backbone (and, optionally, the graph features) and produces predictions for specific tasks. It can be replaced with another task-specific frontend by overwriting the respective parts of GNNModel with your own logic.
This modular design allows for flexibility in model configuration, making it easy to adapt the architecture to different types of causal set data and analysis tasks. The separation of the backbone from the frontend enables transfer learning approaches, where a pre-trained backbone can be reused for multiple different downstream tasks:
Input Graph (Node features + topology)
│
▼
┌──────────────┐ ┌─────────────────┐
│ Backbone │ │ Graph Features │
│ (GNN Blocks)│ │ Network │
└──────┬───────┘ └────────┬────────┘
│ │
▼ ▼
┌────────────────────────────────────┐
│ Concatenate │
└──────────────────┬─────────────────┘
│
▼
┌────────────────────────────────────┐
│ Frontend │
└──────────────────┬─────────────────┘
│
▼
Output
Backbone
The backbone consists of a sequence of GNNBlock modules that transform the input node features through multiple graph convolutional layers. Each GNNBlock includes:
- A graph convolution layer (GCNConv, GraphConv, SAGEConv, etc.)
- A batch normalization layer
- An activation function
- A residual connection
- Dropout for regularization
Input graph (Node features + topology)
        │                    │
        ▼                    │
┌───────────────┐            │
│   GNN layer   │            │
│ (GCNConv,...) │            │
└───────┬───────┘            │
        │                    │
        ▼                    │
┌───────────────┐            │
│   Normalize   │            │
│ (BatchNorm..) │            │
└───────┬───────┘            │
        │                    │
        ▼                    │
┌───────────────┐            │
│  Activation   │            │
│  (ReLU...)    │            │
└───────┬───────┘            │
        │                    │
        ▼                    ▼
┌─────────────────────────────┐
│              +              │ residual connection
└──────────────┬──────────────┘
               │
               ▼
┌─────────────────────────────┐
│           Dropout           │ only active during training
└──────────────┬──────────────┘
               │
               ▼
             Output
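To make this data flow concrete, the following is a minimal sketch of the pattern shown in the diagram. It is not the actual GNNBlock implementation; the class name and the projection used for mismatched dimensions are illustrative assumptions.

import torch
from torch_geometric.nn import GCNConv

class MiniGNNBlock(torch.nn.Module):
    # Illustrative sketch of the conv -> norm -> activation -> residual -> dropout pattern.
    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.3):
        super().__init__()
        self.conv = GCNConv(in_dim, out_dim)
        self.norm = torch.nn.BatchNorm1d(out_dim)
        self.act = torch.nn.ReLU()
        # Assumption: project the input when dimensions differ so the residual can be added.
        self.proj = (
            torch.nn.Identity() if in_dim == out_dim else torch.nn.Linear(in_dim, out_dim)
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        h = self.act(self.norm(self.conv(x, edge_index)))
        h = h + self.proj(x)    # residual connection
        return self.dropout(h)  # only active in training mode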
This implementation follows modern deep learning practices, with residual connections helping to mitigate the vanishing gradient problem and enabling the training of deeper networks. Each block is configured with a set of parameters:
# Example GNNBlock configuration
{
    "in_dim": 64,
    "out_dim": 64,
    "dropout": 0.3,
    "gnn_layer_type": "gcn",
    "normalizer": "batch_norm",
    "activation": "relu",
    "gnn_layer_args": [arg1, arg2, ...],
    "gnn_layer_kwargs": {
        key1: kwarg1,
        key2: kwarg2,
        ...
    },
    "norm_args": [arg1, arg2, ...],
    "norm_kwargs": {
        key1: kwarg1,
        key2: kwarg2,
        ...
    },
    "activation_args": [arg1, arg2, ...],
    "activation_kwargs": {
        key1: kwarg1,
        key2: kwarg2,
        ...
    }
}
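Assuming the from_config class methods described further below, such a dictionary could then be turned into a block roughly like this (a sketch; the exact import path of GNNBlock is not shown, consult the API documentation):

# Sketch: build a block from a configuration dictionary.
# GNNBlock is provided by QuantumGravPy; see the API docs for the import path.
cfg = {
    "in_dim": 64,
    "out_dim": 64,
    "dropout": 0.3,
    "gnn_layer_type": "gcn",
    "normalizer": "batch_norm",
    "activation": "relu",
}
block = GNNBlock.from_config(cfg)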
The backbone processes the node features and graph connectivity through these blocks sequentially, and the final node embeddings are pooled (using a configurable pooling operation such as mean, sum, or max pooling) to obtain a graph-level representation. It makes sense to familiarize yourself with the pytorch_geometric documentation if you haven't already, to learn how the GNN layers, activation functions, etc. work.
The gnn_layer_type, activation and normalizer can be chosen from a set of predefined layers:
import torch
import torch_geometric.nn as tgnn

gnn_layers: dict[str, torch.nn.Module] = {
"gcn": tgnn.conv.GCNConv,
"gat": tgnn.conv.GATConv,
"sage": tgnn.conv.SAGEConv,
"gco": tgnn.conv.GraphConv,
}
normalizer_layers: dict[str, torch.nn.Module] = {
"identity": torch.nn.Identity,
"batch_norm": torch.nn.BatchNorm1d,
"layer_norm": torch.nn.LayerNorm,
}
activation_layers: dict[str, torch.nn.Module] = {
"relu": torch.nn.ReLU,
"leaky_relu": torch.nn.LeakyReLU,
"sigmoid": torch.nn.Sigmoid,
"tanh": torch.nn.Tanh,
"identity": torch.nn.Identity,
}
pooling_layers: dict[str, torch.nn.Module] = {
"mean": tgnn.global_mean_pool,
"max": tgnn.global_max_pool,
"sum": tgnn.global_add_pool,
}
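For illustration, resolving entries from these registries by hand could look as follows (a sketch; in practice the blocks perform this lookup internally):

# Look up classes by their registry keys, then instantiate them.
conv = gnn_layers["sage"](in_channels=12, out_channels=128, aggr="mean")  # tgnn.conv.SAGEConv
norm = normalizer_layers["batch_norm"](128)                               # torch.nn.BatchNorm1d
act = activation_layers["relu"]()                                         # torch.nn.ReLU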
These predefined registries can be extended with custom layers via the register_ functions in the utils module. Consult the API documentation for more details.
Optional graph-level features
In many real-world scenarios, we have additional graph-level features that can't be naturally represented as node features. The GraphFeaturesBlock allows for processing these auxiliary features and integrating them with the learned graph representation from the backbone.
The GraphFeaturesBlock is a sequence of linear layers with activation functions that processes these global graph features. The outputs of the GNN backbone and the graph features network are concatenated before being passed to the frontend classifier. Its constructor parameters are similar to those of the GNNBlock:
# Example GraphFeaturesBlock configuration
{
    "input_dim": 10,
    "output_dim": 32,
    "hidden_dims": [64, 64],
    "activation": "relu",
    "norm_args": [arg1, arg2, ...],
    "norm_kwargs": {
        key1: kwarg1,
        key2: kwarg2,
        ...
    },
    "activation_args": [arg1, arg2, ...],
    "activation_kwargs": {
        key1: kwarg1,
        key2: kwarg2,
        ...
    }
}
This approach allows the model to incorporate both local (node-level) and global (graph-level) information when making predictions. By customizing hidden_dims you can make the network shallower or deeper, and wider or narrower.
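The dimensional bookkeeping is straightforward: the pooled backbone output and the graph-features output are concatenated along the feature dimension. A toy illustration with made-up shapes:

import torch

pooled = torch.randn(8, 64)       # pooled node embeddings for a batch of 8 graphs
graph_feats = torch.randn(8, 32)  # output of the GraphFeaturesBlock
combined = torch.cat([pooled, graph_feats], dim=-1)
print(combined.shape)             # torch.Size([8, 96]) -> frontend input_dim must be 96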
Edge features are currently not supported out of the box.
Frontends
The frontend model takes the combined embeddings from the backbone and the optional graph features network and produces predictions for specific tasks. The ClassifierBlock is based on the LinearSequential module, which supports multi-objective classification by allowing multiple output layers, each corresponding to a different task. Its constructor parameters are once more similar to those of the GNNBlock and GraphFeaturesBlock, e.g.:
# Example ClassifierBlock configuration
{
    "input_dim": 96,        # combined dimension from backbone and graph features
    "output_dims": [2, 3],  # two tasks: binary and 3-class classification
    "hidden_dims": [128, 64],
    "activation": "relu",
    "backbone_kwargs": [
        {key1: kwarg1, key2: kwarg2, ...},
        {key1: kwarg1, key2: kwarg2, ...},
        {key1: kwarg1, key2: kwarg2, ...}
    ],
    "output_kwargs": [
        {key1: kwarg1, key2: kwarg2, ...},
        {key1: kwarg1, key2: kwarg2, ...}
    ],
    "activation_args": [arg1, arg2, ...],
    "activation_kwargs": {
        key1: kwarg1,
        key2: kwarg2,
        ...
    }
}
Again, by customizing hidden_dims you can make the network shallower or deeper, and wider or narrower. In this way, transfer learning is supported: the frontend can be changed depending on the task, while keeping the model that produces the embeddings.
The output layers are implemented as one linear layer per task, and each of them can be given a dictionary of kwargs. The same holds for the input and hidden layers: for each of them, the backbone_kwargs config node can contain a set of kwargs, too. You can leave out some of them if desired, but to skip one you need to put in an empty dictionary, because the code maps kwargs to layers by index in the list.
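For example, to pass a kwarg only to the second of three backbone layers, the preceding entry must still be present as an empty dictionary (the bias kwarg here is just an illustration):

"backbone_kwargs": [
    {},               # input layer: defaults
    {"bias": False},  # first hidden layer: custom kwarg
    {}                # second hidden layer: defaults
]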
Configuration-driven model creation
All model components can be created from configuration dictionaries, making it easy to experiment with different architectures without changing the code. The from_config class methods in each component allow for declarative model definition through YAML configuration files.
# Example full model configuration
{
    "encoder": [
        {
            "in_dim": 8,
            "out_dim": 64,
            "gnn_layer_type": "gcn",
            "normalizer": "batch_norm",
            "activation": "relu"
        },
        {
            "in_dim": 64,
            "out_dim": 64,
            "gnn_layer_type": "gcn",
            "normalizer": "batch_norm",
            "activation": "relu"
        }
    ],
    "pooling_layer": "mean",
    "graph_features_net": {
        "input_dim": 10,
        "output_dim": 32,
        "hidden_dims": [64, 64],
        "activation": "relu"
    },
    "classifier": {
        "input_dim": 96,  # 64 from backbone + 32 from graph features
        "output_dims": [2],
        "hidden_dims": [128, 64],
        "activation": "relu"
    }
}
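Loading such a configuration and instantiating the model could then look like the following sketch (the exact import path of GNNModel is an assumption; see the API documentation):

import yaml

with open("model_config.yaml") as f:
    config = yaml.safe_load(f)

# GNNModel is provided by QuantumGravPy; consult the API docs for the import path.
model = GNNModel.from_config(config)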
This configuration-based approach aligns with the overall philosophy of QuantumGravPy, which emphasizes a separation between configuration and code to facilitate rapid experimentation and reproducibility. Note how the number of blocks in the encoder part of the config determines the architecture of the backbone model: each block creates a GNNBlock instance, and the data is processed through them sequentially. In the same way, the number of hidden_dims entries makes the linear models deeper or shallower, and the supplied numbers determine the model width.
The number of entries in output_dims in the classifier determines the number of classification tasks, and each entry sets the number of output classes for that task.
Note that right now, only the ClassifierBlock is explicitly supported.
Transfer learning and model reuse
The separation of backbone and frontend enables effective transfer learning strategies. A backbone trained on a large dataset of causal sets can capture general properties of causal set graphs, which can then be reused for multiple downstream tasks by attaching different frontend classifiers. Currently, only the classifier block is supported, but later on, we plan to generalize this so that the system can accommodate generative models as well.
To extract the embeddings from a trained model for reuse, you can use the get_embeddings method of the GNNModel class, which returns the output of the backbone after pooling, ready to be used for other tasks.
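As a sketch, a new task head could then be trained on top of the frozen embeddings (assuming get_embeddings takes a batch of graphs; the 4-class head is a hypothetical new task):

import torch

embeddings = model.get_embeddings(batch)             # pooled backbone output, e.g. shape (N, 128)
new_head = torch.nn.Linear(embeddings.shape[-1], 4)  # hypothetical new 4-class task
logits = new_head(embeddings.detach())               # detach() keeps the backbone frozen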
Full config example
A full configuration that defines a model, as it would appear in a YAML config file, could read:
encoder:
  - in_dim: 12
    out_dim: 128
    dropout: 0.3
    gnn_layer_type: "sage"
    normalizer: "batch_norm"
    activation: "relu"
    norm_args:
      - 128
    gnn_layer_kwargs:
      normalize: False
      bias: True
      project: False
      root_weight: False
      aggr: "mean"
  - in_dim: 128
    out_dim: 256
    dropout: 0.3
    gnn_layer_type: "sage"
    normalizer: "batch_norm"
    activation: "relu"
    norm_args:
      - 256
    gnn_layer_kwargs:
      normalize: False
      bias: True
      project: False
      root_weight: False
      aggr: "mean"
  - in_dim: 256
    out_dim: 128
    dropout: 0.3
    gnn_layer_type: "sage"
    normalizer: "batch_norm"
    activation: "relu"
    norm_args:
      - 128
    gnn_layer_kwargs:
      normalize: False
      bias: True
      project: False
      root_weight: False
      aggr: "mean"
pooling_layer: mean
classifier:
  input_dim: 128
  output_dims:
    - 2
  hidden_dims:
    - 48
    - 18
  activation: "relu"
  backbone_kwargs: [{}, {}]
  output_kwargs: [{}]
  activation_kwargs: [{ "inplace": False }]
Here, we use a chain of three GraphSAGE-based GNNBlocks with batch normalization, a dropout probability of 0.3, and a ReLU activation function, followed by a global mean pooling layer and a classifier that has an input layer of size 128, two hidden layers of sizes 48 and 18, and one output layer of size 2, i.e., a single binary classification task. We also add some arguments for the constructor of the activation (activation_kwargs). The first GNNBlock must have an input dimension that corresponds to the dimensionality of the node features per node. We have no GraphFeaturesBlock in this case, but if we did, its configuration would look similar to the classifier block; the config node would just be called graph_features_net instead of classifier.
Armed with this knowledge, we can now proceed to model training.