
API Reference

Main model class

The main model class GNNModel ties together the graph neural network backbone and a multilayer perceptron classifier, and can be configured for various tasks. A short usage sketch follows.
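Below is a minimal usage sketch. The import paths are assumptions inferred from the src/ layout shown in this reference (adjust them to however the package actually exposes these classes), torch_geometric's global_mean_pool is used as an example pooling layer, and all dimensions are illustrative.

```python
import torch
from torch_geometric.nn import global_mean_pool

# Hypothetical import paths; adjust to the installed package layout.
from QuantumGrav.gnn_block import GNNBlock
from QuantumGrav.classifier_block import ClassifierBlock
from QuantumGrav.gnn_model import GNNModel

# Backbone: two GNN blocks mapping 16 -> 32 -> 32 node features.
# The first block needs projection_args because in_dim != out_dim (residual projection).
encoder = [
    GNNBlock(in_dim=16, out_dim=32, projection_args=[16, 32]),
    GNNBlock(in_dim=32, out_dim=32),
]
classifier = ClassifierBlock(input_dim=32, output_dims=[4], hidden_dims=[64])
model = GNNModel(encoder=encoder, classifier=classifier, pooling_layer=global_mean_pool)

x = torch.randn(10, 16)                     # 10 nodes, 16 features each
edge_index = torch.randint(0, 10, (2, 40))  # random connectivity for illustration
batch = torch.zeros(10, dtype=torch.long)   # all nodes belong to graph 0
logits = model(x, edge_index, batch)        # list with one (1, 4) logits tensor
```

The result is a list of logit tensors, one per classification head of the classifier block.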

GNNModel

Bases: Module

Torch module for the full GCN model, which consists of a GCN backbone, a classifier, and a pooling layer, augmented with an optional graph features network.

Source code in src/QuantumGrav/gnn_model.py
class GNNModel(torch.nn.Module):
    """Torch module for the full GCN model, which consists of a GCN backbone, a classifier, and a pooling layer, augmented with optional graph features network.
    Args:
        torch.nn.Module: base class
    """

    def __init__(
        self,
        encoder: list[QGGNN.GNNBlock],
        classifier: QGC.ClassifierBlock | None,
        pooling_layer: torch.nn.Module | None,
        graph_features_net: torch.nn.Module | None = torch.nn.Identity(),
    ):
        """Initialize the GNNModel.

        Args:
            encoder (list[QGGNN.GNNBlock]): List of GNN blocks forming the GCN backbone.
            classifier (ClassifierBlock): Classifier block.
            pooling_layer (torch.nn.Module): Pooling layer.
            graph_features_net (torch.nn.Module, optional): Graph features network. Defaults to torch.nn.Identity.
        """
        super().__init__()
        self.encoder = torch.nn.ModuleList(encoder)
        self.classifier = classifier
        self.graph_features_net = graph_features_net
        self.pooling_layer = pooling_layer

    def _eval_encoder(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        gcn_kwargs: dict[Any, Any] | None = None,
    ) -> torch.Tensor:
        """Evaluate the GCN network on the input data.

        Args:
            x (torch.Tensor): Input node features.
            edge_index (torch.Tensor): Graph connectivity information.
            gcn_kwargs (dict[Any, Any], optional): Additional arguments for the GCN. Defaults to None.

        Returns:
            torch.Tensor: Output of the GCN network.
        """
        # Apply each GCN layer to the input features
        features = x
        for gnn_layer in self.encoder:
            features = gnn_layer(
                features, edge_index, **(gcn_kwargs if gcn_kwargs else {})
            )
        return features

    def get_embeddings(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor | None = None,
        gcn_kwargs: dict | None = None,
    ) -> torch.Tensor:
        """Get embeddings from the GCN model.

        Args:
            x (torch.Tensor): Input node features.
            edge_index (torch.Tensor): Graph connectivity information.
            batch (torch.Tensor): Batch vector for pooling.
            gcn_kwargs (dict, optional): Additional arguments for the GCN. Defaults to None.

        Returns:
            torch.Tensor: Embedding vector for the graph features.
        """
        # apply the GCN backbone to the node features
        embeddings = self._eval_encoder(x, edge_index, gcn_kwargs=gcn_kwargs)

        # pool everything together into a single graph representation
        embeddings = self.pooling_layer(embeddings, batch)
        return embeddings

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
        graph_features: torch.Tensor | None = None,
        gcn_kwargs: dict[Any, Any] | None = None,
    ) -> torch.Tensor | Collection[torch.Tensor]:
        """Forward run of the gnn model with optional graph features.
        First execute the graph-neural network backbone, then process the graph features, and finally apply the classifier.

        Args:
            x (torch.Tensor): Input node features.
            edge_index (torch.Tensor): Graph connectivity information.
            batch (torch.Tensor): Batch vector for pooling.
            graph_features (torch.Tensor, optional): Additional graph features. Defaults to None.
            gcn_kwargs (dict[Any, Any], optional): Additional arguments for the GCN. Defaults to None.

        Returns:
            torch.Tensor | Collection[torch.Tensor]: Class predictions.
        """

        # apply the GCN backbone to the node features
        embeddings = self.get_embeddings(x, edge_index, batch, gcn_kwargs=gcn_kwargs)

        # If we have graph features, we need to process them and concatenate them with the node features
        if graph_features is not None:
            graph_features = self.graph_features_net(graph_features)
            embeddings = torch.cat(
                (embeddings, graph_features), dim=-1
            )  # -1 -> last dim. This concatenates, but we also could sum them

        # Classifier creates the raw logits
        # no softmax or sigmoid is applied here, as we want to keep the logits for loss calculation
        class_predictions = self.classifier(embeddings)

        return class_predictions

    @classmethod
    def from_config(cls, config: dict) -> "GNNModel":
        """Create a GNNModel from a configuration dictionary.

        Args:
            config (dict): Configuration dictionary containing parameters for the model.

        Returns:
            GNNModel: An instance of GNNModel.
        """
        encoder = [QGGNN.GNNBlock.from_config(cfg) for cfg in config["encoder"]]
        classifier = QGC.ClassifierBlock.from_config(config["classifier"])
        pooling_layer = utils.get_registered_pooling_layer(config["pooling_layer"])
        graph_features_net = (
            QGF.GraphFeaturesBlock.from_config(config["graph_features_net"])
            if "graph_features_net" in config
            and config["graph_features_net"] is not None
            else torch.nn.Identity()
        )

        return cls(
            encoder=encoder,
            classifier=classifier,
            pooling_layer=pooling_layer,
            graph_features_net=graph_features_net,
        )

    def save(self, path: str | Path) -> None:
        """Save the model state to file. This saves a dictionary structured like this:
         'encoder': self.encoder,
         'classifier': self.classifier,
         'pooling_layer': self.pooling_layer,
         'graph_features_net': self.graph_features_net

        Args:
            path (str | Path): Path to save the model to
        """
        torch.save(
            {
                "encoder": self.encoder,
                "classifier": self.classifier,
                "pooling_layer": self.pooling_layer,
                "graph_features_net": self.graph_features_net,
            },
            path,
        )

    @classmethod
    def load(
        cls, path: str | Path, device: torch.device = torch.device("cpu")
    ) -> "GNNModel":
        """Load a model from file that has previously been save with the function 'save'.

        Args:
            path (str | Path): path to load the model from.
            device (torch.device): device to put the model to. Defaults to torch.device("cpu")
        Returns:
            GNNModel: model instance initialized with the sub-models loaded from file.
        """
        model_dict = torch.load(path, map_location=device, weights_only=False)

        return cls(
            model_dict["encoder"],
            model_dict["classifier"],
            model_dict["pooling_layer"],
            model_dict["graph_features_net"],
        )

__init__(encoder, classifier, pooling_layer, graph_features_net=torch.nn.Identity())

Initialize the GNNModel.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| encoder | list[GNNBlock] | List of GNN blocks forming the GCN backbone. | required |
| classifier | ClassifierBlock | Classifier block. | required |
| pooling_layer | Module | Pooling layer. | required |
| graph_features_net | Module | Graph features network. Defaults to torch.nn.Identity. | Identity() |
Source code in src/QuantumGrav/gnn_model.py
def __init__(
    self,
    encoder: list[QGGNN.GNNBlock],
    classifier: QGC.ClassifierBlock | None,
    pooling_layer: torch.nn.Module | None,
    graph_features_net: torch.nn.Module | None = torch.nn.Identity(),
):
    """Initialize the GNNModel.

    Args:
        encoder (list[QGGNN.GNNBlock]): List of GNN blocks forming the GCN backbone.
        classifier (ClassifierBlock): Classifier block.
        pooling_layer (torch.nn.Module): Pooling layer.
        graph_features_net (torch.nn.Module, optional): Graph features network. Defaults to torch.nn.Identity.
    """
    super().__init__()
    self.encoder = torch.nn.ModuleList(encoder)
    self.classifier = classifier
    self.graph_features_net = graph_features_net
    self.pooling_layer = pooling_layer

forward(x, edge_index, batch, graph_features=None, gcn_kwargs=None)

Forward run of the gnn model with optional graph features. First execute the graph-neural network backbone, then process the graph features, and finally apply the classifier.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input node features. | required |
| edge_index | Tensor | Graph connectivity information. | required |
| batch | Tensor | Batch vector for pooling. | required |
| graph_features | Tensor | Additional graph features. Defaults to None. | None |
| gcn_kwargs | dict[Any, Any] | Additional arguments for the GCN. Defaults to None. | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor \| Collection[Tensor] | Class predictions. |

Source code in src/QuantumGrav/gnn_model.py
def forward(
    self,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    batch: torch.Tensor,
    graph_features: torch.Tensor | None = None,
    gcn_kwargs: dict[Any, Any] | None = None,
) -> torch.Tensor | Collection[torch.Tensor]:
    """Forward run of the gnn model with optional graph features.
    First execute the graph-neural network backbone, then process the graph features, and finally apply the classifier.

    Args:
        x (torch.Tensor): Input node features.
        edge_index (torch.Tensor): Graph connectivity information.
        batch (torch.Tensor): Batch vector for pooling.
        graph_features (torch.Tensor, optional): Additional graph features. Defaults to None.
        gcn_kwargs (dict[Any, Any], optional): Additional arguments for the GCN. Defaults to None.

    Returns:
        torch.Tensor | Collection[torch.Tensor]: Class predictions.
    """

    # apply the GCN backbone to the node features
    embeddings = self.get_embeddings(x, edge_index, batch, gcn_kwargs=gcn_kwargs)

    # If we have graph features, we need to process them and concatenate them with the node features
    if graph_features is not None:
        graph_features = self.graph_features_net(graph_features)
        embeddings = torch.cat(
            (embeddings, graph_features), dim=-1
        )  # -1 -> last dim. This concatenates, but we also could sum them

    # Classifier creates the raw logits
    # no softmax or sigmoid is applied here, as we want to keep the logits for loss calculation
    class_predictions = self.classifier(embeddings)

    return class_predictions
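When graph-level features are supplied, their processed representation is concatenated with the pooled embedding, so the classifier's input_dim must account for both. A hedged continuation of the earlier sketch, with the GraphFeaturesBlock import path again assumed:

```python
from QuantumGrav.graphfeatures_block import GraphFeaturesBlock  # hypothetical import path

# Maps 6 raw graph-level features to an 8-dimensional representation.
gf_net = GraphFeaturesBlock(input_dim=6, output_dim=8, hidden_dims=[16])

model_gf = GNNModel(
    encoder=encoder,
    classifier=ClassifierBlock(input_dim=32 + 8, output_dims=[4], hidden_dims=[64]),
    pooling_layer=global_mean_pool,
    graph_features_net=gf_net,
)
graph_features = torch.randn(1, 6)  # one feature vector per graph in the batch
logits = model_gf(x, edge_index, batch, graph_features=graph_features)
```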

from_config(config) classmethod

Create a GNNModel from a configuration dictionary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | dict | Configuration dictionary containing parameters for the model. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| GNNModel | GNNModel | An instance of GNNModel. |

Source code in src/QuantumGrav/gnn_model.py
@classmethod
def from_config(cls, config: dict) -> "GNNModel":
    """Create a GNNModel from a configuration dictionary.

    Args:
        config (dict): Configuration dictionary containing parameters for the model.

    Returns:
        GNNModel: An instance of GNNModel.
    """
    encoder = [QGGNN.GNNBlock.from_config(cfg) for cfg in config["encoder"]]
    classifier = QGC.ClassifierBlock.from_config(config["classifier"])
    pooling_layer = utils.get_registered_pooling_layer(config["pooling_layer"])
    graph_features_net = (
        QGF.GraphFeaturesBlock.from_config(config["graph_features_net"])
        if "graph_features_net" in config
        and config["graph_features_net"] is not None
        else torch.nn.Identity()
    )

    return cls(
        encoder=encoder,
        classifier=classifier,
        pooling_layer=pooling_layer,
        graph_features_net=graph_features_net,
    )
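A sketch of a configuration dictionary that from_config could consume. The string values ("GCNConv", "identity", "relu", "mean") are placeholders and must match names actually registered in the utils registries (gnn_layers, normalizer_layers, activation_layers, and the pooling-layer registry); the numeric values are illustrative.

```python
config = {
    "encoder": [
        {"in_dim": 16, "out_dim": 32, "dropout": 0.2,
         "gnn_layer_type": "GCNConv", "normalizer": "identity", "activation": "relu",
         "projection_args": [16, 32]},   # residual projection needed since in_dim != out_dim
        {"in_dim": 32, "out_dim": 32,
         "gnn_layer_type": "GCNConv", "normalizer": "identity", "activation": "relu"},
    ],
    "classifier": {
        "input_dim": 32, "output_dims": [4], "hidden_dims": [64], "activation": "relu",
        # explicit (empty) per-layer kwargs lists, one entry per layer / output head
        "backbone_kwargs": [{}], "output_kwargs": [{}], "activation_kwargs": [{}],
    },
    "pooling_layer": "mean",             # placeholder name for a registered pooling layer
    # "graph_features_net": {...},       # optional; omitted here
}
model = GNNModel.from_config(config)
```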

get_embeddings(x, edge_index, batch=None, gcn_kwargs=None)

Get embeddings from the GCN model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input node features. | required |
| edge_index | Tensor | Graph connectivity information. | required |
| batch | Tensor | Batch vector for pooling. | None |
| gcn_kwargs | dict | Additional arguments for the GCN. Defaults to None. | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Embedding vector for the graph features. |

Source code in src/QuantumGrav/gnn_model.py
def get_embeddings(
    self,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    batch: torch.Tensor | None = None,
    gcn_kwargs: dict | None = None,
) -> torch.Tensor:
    """Get embeddings from the GCN model.

    Args:
        x (torch.Tensor): Input node features.
        edge_index (torch.Tensor): Graph connectivity information.
        batch (torch.Tensor): Batch vector for pooling.
        gcn_kwargs (dict, optional): Additional arguments for the GCN. Defaults to None.

    Returns:
        torch.Tensor: Embedding vector for the graph features.
    """
    # apply the GCN backbone to the node features
    embeddings = self._eval_encoder(x, edge_index, gcn_kwargs=gcn_kwargs)

    # pool everything together into a single graph representation
    embeddings = self.pooling_layer(embeddings, batch)
    return embeddings
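The pooled graph-level embedding can also be retrieved on its own, for example for clustering or visualization. Continuing the earlier sketch:

```python
emb = model.get_embeddings(x, edge_index, batch)  # pooled embedding, shape (num_graphs, 32) here
```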

load(path, device=torch.device('cpu')) classmethod

Load a model from file that has previously been saved with the 'save' method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str \| Path | Path to load the model from. | required |
| device | device | Device to put the model to. Defaults to torch.device("cpu"). | device('cpu') |

Returns:

| Type | Description |
| --- | --- |
| GNNModel | Model instance initialized with the sub-models loaded from file. |

Source code in src/QuantumGrav/gnn_model.py
@classmethod
def load(
    cls, path: str | Path, device: torch.device = torch.device("cpu")
) -> "GNNModel":
    """Load a model from file that has previously been save with the function 'save'.

    Args:
        path (str | Path): path to load the model from.
        device (torch.device): device to put the model to. Defaults to torch.device("cpu")
    Returns:
        GNNModel: model instance initialized with the sub-models loaded from file.
    """
    model_dict = torch.load(path, map_location=device, weights_only=False)

    return cls(
        model_dict["encoder"],
        model_dict["classifier"],
        model_dict["pooling_layer"],
        model_dict["graph_features_net"],
    )

save(path)

Save the model state to file. This writes a dictionary with the keys 'encoder', 'classifier', 'pooling_layer', and 'graph_features_net', each mapped to the corresponding sub-module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str \| Path | Path to save the model to. | required |
Source code in src/QuantumGrav/gnn_model.py
def save(self, path: str | Path) -> None:
    """Save the model state to file. This saves a dictionary structured like this:
     'encoder': self.encoder,
     'classifier': self.classifier,
     'pooling_layer': self.pooling_layer,
     'graph_features_net': self.graph_features_net

    Args:
        path (str | Path): Path to save the model to
    """
    torch.save(
        {
            "encoder": self.encoder,
            "classifier": self.classifier,
            "pooling_layer": self.pooling_layer,
            "graph_features_net": self.graph_features_net,
        },
        path,
    )
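save stores the sub-modules themselves (not just a state_dict), and load rebuilds the model from that file. Note that loading uses weights_only=False, so only load files from trusted sources. Continuing the earlier sketch:

```python
model.save("gnn_model.pt")                # writes the sub-module dictionary to disk
restored = GNNModel.load("gnn_model.pt")  # reconstructs the model, on the CPU by default
```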

Multilayer Perceptron submodels

These classes provide sequences of linear (affine) layers in various configurations; classifiers and similar submodels are built by deriving from the base class below.

Base class for models composed of linear layers

LinearSequential

Bases: Module

This class implements a neural network block consisting of a backbone (a sequence of linear layers with activation functions) and multiple output layers for classification tasks. It supports multi-objective classification by allowing multiple output layers, each corresponding to a different classification task, but can also be used for any other type of sequential processing that involves linear layers.
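A minimal sketch of direct use, assuming the class is importable from the package (import path assumed): the backbone is shared and each entry in output_dims gets its own output head.

```python
import torch
from QuantumGrav.linear_sequential import LinearSequential  # hypothetical import path

# Shared 8 -> 32 -> 16 backbone with two heads: a 3-way and a 2-way task.
mlp = LinearSequential(input_dim=8, output_dims=[3, 2], hidden_dims=[32, 16])
outputs = mlp(torch.randn(5, 8))  # list of two tensors with shapes (5, 3) and (5, 2)
```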

Source code in src/QuantumGrav/linear_sequential.py
class LinearSequential(torch.nn.Module):
    """This class implements a neural network block consisting of a backbone
    (a sequence of linear layers with activation functions) and multiple
    output layers for classification tasks. It supports multi-objective
    classification by allowing multiple output layers, each corresponding
    to a different classification task, but can also be used for any other type of sequential processing that involves linear layers.
    """

    def __init__(
        self,
        input_dim: int,
        output_dims: list[int],
        hidden_dims: list[int] | None = None,
        activation: type[torch.nn.Module] = torch.nn.ReLU,
        backbone_kwargs: list[dict] | None = None,
        output_kwargs: list[dict] | None = None,
        activation_kwargs: list[dict] | None = None,
    ):
        """Create a LinearSequential object with a backbone and multiple output layers. All layers are of type `Linear` with an activation function in between (the backbone) and a set of linear output layers.

        Args:
            input_dim (int): input dimension of the LinearSequential object
            output_dims (list[int]): list of output dimensions for each output layer, i.e., each classification task
            hidden_dims (list[int]): list of hidden dimensions for the backbone
            activation (type[torch.nn.Module], optional): activation function to use. Defaults to torch.nn.ReLU.
            backbone_kwargs (list[dict], optional): additional arguments for the backbone layers. Defaults to None.
            output_kwargs (list[dict], optional): additional arguments for the output layers. Defaults to None.

        Raises:
            ValueError: If hidden_dims contains non-positive integers.
            ValueError: If output_dims is empty or contains non-positive integers.
            ValueError: If any output_dim is not a positive integer.
        """
        super().__init__()

        # validate input parameters
        if hidden_dims is None:
            raise ValueError("hidden_dims must not be None")

        if not all(h > 0 for h in hidden_dims):
            raise ValueError("hidden_dims must be a list of positive integers")

        if len(output_dims) == 0:
            raise ValueError("output_dims must be a non-empty list of integers")

        if not all(o > 0 for o in output_dims):
            raise ValueError("output_dims must be a list of positive integers")

        # manage kwargs for the different parts of the network
        processed_backbone_kwargs = self._handle_kwargs(
            backbone_kwargs, "backbone_kwargs", len(hidden_dims)
        )

        processed_activation_kwargs = self._handle_kwargs(
            activation_kwargs, "activation_kwargs", len(hidden_dims)
        )

        processed_output_kwargs = self._handle_kwargs(
            output_kwargs, "output_kwargs", len(output_dims)
        )

        # build backbone with Sequential
        layers = []
        in_dim = input_dim
        for i, hidden_dim in enumerate(hidden_dims):
            layers.append(
                torch_geometric.nn.dense.Linear(
                    in_dim,
                    hidden_dim,
                    **processed_backbone_kwargs[i],
                )
            )
            layers.append(
                activation(
                    **processed_activation_kwargs[i],
                )
            )
            in_dim = hidden_dim

        if len(layers) > 0:
            self.backbone = torch.nn.Sequential(*layers)
        else:
            self.backbone = torch.nn.Identity()

        # build the final layers - take care of possible multi-objective classification
        output_layers = []

        final_in_dim = (
            hidden_dims[-1] if hidden_dims and len(hidden_dims) > 0 else input_dim
        )

        for i, output_dim in enumerate(output_dims):
            output_layer = torch_geometric.nn.dense.Linear(
                final_in_dim,
                output_dim,
                **(
                    processed_output_kwargs[i]
                    if processed_output_kwargs and processed_output_kwargs[i]
                    else {}
                ),
            )
            output_layers.append(output_layer)
        self.output_layers = torch.nn.ModuleList(output_layers)

    def _handle_kwargs(
        self, kwarglist: list[dict] | None, name: str, needed: int
    ) -> list[dict]:
        """
        handle kwargs for the backbone or activation functions.
        """
        if kwarglist is None:
            kwarglist = [{}] * needed
        elif len(kwarglist) == 1:
            kwarglist = kwarglist * needed
        elif len(kwarglist) != needed:
            raise ValueError(
                f"{name} must be a list of dictionaries with the same length as hidden_dims"
            )
        return kwarglist

    def forward(
        self,
        x: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Forward pass through the LinearSequential object.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            list[torch.Tensor]: List of output tensors from each classifier layer.
        """
        # Sequential handles passing output from one layer to the next
        features = self.backbone(x)  # No need for manual looping or cloning

        # Apply each output layer to the backbone output
        logits = [layer(features) for layer in self.output_layers]

        return logits

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "LinearSequential":
        """Create a LinearSequential from a configuration dictionary.

        Args:
            config (dict[str, Any]): Configuration dictionary containing parameters for the LinearSequential.

        Returns:
            LinearSequential: An instance of LinearSequential initialized with the provided configuration.
        """
        return cls(
            input_dim=config["input_dim"],
            output_dims=config["output_dims"],
            hidden_dims=config["hidden_dims"],
            activation=utils.get_registered_activation(
                config.get("activation", "relu")
            ),
            backbone_kwargs=config.get("backbone_kwargs", None),
            output_kwargs=config.get("output_kwargs", None),
            activation_kwargs=config.get("activation_kwargs", None),
        )

    def save(self, path: str | Path) -> None:
        """Save the model's state to file.

        Args:
            path (str | Path): path to save the model to.
        """

        torch.save(self, path)

    @classmethod
    def load(
        cls, path: str | Path, device: torch.device = torch.device("cpu")
    ) -> "LinearSequential":
        """Load a LinearSequential instance from file

        Args:
            path (str | Path): path to the file to load the model from
            device (torch.device): device to put the model to. Defaults to torch.device("cpu")
        Returns:
            LinearSequential: An instance of LinearSequential initialized from the loaded data.
        """
        model = torch.load(path, map_location=device, weights_only=False)
        return model

__init__(input_dim, output_dims, hidden_dims=None, activation=torch.nn.ReLU, backbone_kwargs=None, output_kwargs=None, activation_kwargs=None)

Create a LinearSequential object with a backbone and multiple output layers. All layers are of type Linear with an activation function in between (the backbone) and a set of linear output layers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_dim | int | Input dimension of the LinearSequential object. | required |
| output_dims | list[int] | List of output dimensions for each output layer, i.e., each classification task. | required |
| hidden_dims | list[int] | List of hidden dimensions for the backbone. | None |
| activation | type[Module] | Activation function to use. Defaults to torch.nn.ReLU. | ReLU |
| backbone_kwargs | list[dict] | Additional arguments for the backbone layers. Defaults to None. | None |
| output_kwargs | list[dict] | Additional arguments for the output layers. Defaults to None. | None |
| activation_kwargs | list[dict] | Additional arguments for the activation functions. Defaults to None. | None |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If hidden_dims contains non-positive integers. |
| ValueError | If output_dims is empty or contains non-positive integers. |
| ValueError | If any output_dim is not a positive integer. |

Source code in src/QuantumGrav/linear_sequential.py
def __init__(
    self,
    input_dim: int,
    output_dims: list[int],
    hidden_dims: list[int] | None = None,
    activation: type[torch.nn.Module] = torch.nn.ReLU,
    backbone_kwargs: list[dict] | None = None,
    output_kwargs: list[dict] | None = None,
    activation_kwargs: list[dict] | None = None,
):
    """Create a LinearSequential object with a backbone and multiple output layers. All layers are of type `Linear` with an activation function in between (the backbone) and a set of linear output layers.

    Args:
        input_dim (int): input dimension of the LinearSequential object
        output_dims (list[int]): list of output dimensions for each output layer, i.e., each classification task
        hidden_dims (list[int]): list of hidden dimensions for the backbone
        activation (type[torch.nn.Module], optional): activation function to use. Defaults to torch.nn.ReLU.
        backbone_kwargs (list[dict], optional): additional arguments for the backbone layers. Defaults to None.
        output_kwargs (list[dict], optional): additional arguments for the output layers. Defaults to None.

    Raises:
        ValueError: If hidden_dims contains non-positive integers.
        ValueError: If output_dims is empty or contains non-positive integers.
        ValueError: If any output_dim is not a positive integer.
    """
    super().__init__()

    # validate input parameters
    if hidden_dims is None:
        raise ValueError("hidden_dims must not be None")

    if not all(h > 0 for h in hidden_dims):
        raise ValueError("hidden_dims must be a list of positive integers")

    if len(output_dims) == 0:
        raise ValueError("output_dims must be a non-empty list of integers")

    if not all(o > 0 for o in output_dims):
        raise ValueError("output_dims must be a list of positive integers")

    # manage kwargs for the different parts of the network
    processed_backbone_kwargs = self._handle_kwargs(
        backbone_kwargs, "backbone_kwargs", len(hidden_dims)
    )

    processed_activation_kwargs = self._handle_kwargs(
        activation_kwargs, "activation_kwargs", len(hidden_dims)
    )

    processed_output_kwargs = self._handle_kwargs(
        output_kwargs, "output_kwargs", len(output_dims)
    )

    # build backbone with Sequential
    layers = []
    in_dim = input_dim
    for i, hidden_dim in enumerate(hidden_dims):
        layers.append(
            torch_geometric.nn.dense.Linear(
                in_dim,
                hidden_dim,
                **processed_backbone_kwargs[i],
            )
        )
        layers.append(
            activation(
                **processed_activation_kwargs[i],
            )
        )
        in_dim = hidden_dim

    if len(layers) > 0:
        self.backbone = torch.nn.Sequential(*layers)
    else:
        self.backbone = torch.nn.Identity()

    # build the final layers - take care of possible multi-objective classification
    output_layers = []

    final_in_dim = (
        hidden_dims[-1] if hidden_dims and len(hidden_dims) > 0 else input_dim
    )

    for i, output_dim in enumerate(output_dims):
        output_layer = torch_geometric.nn.dense.Linear(
            final_in_dim,
            output_dim,
            **(
                processed_output_kwargs[i]
                if processed_output_kwargs and processed_output_kwargs[i]
                else {}
            ),
        )
        output_layers.append(output_layer)
    self.output_layers = torch.nn.ModuleList(output_layers)

forward(x)

Forward pass through the LinearSequential object.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| list[Tensor] | List of output tensors from each classifier layer. |

Source code in src/QuantumGrav/linear_sequential.py
def forward(
    self,
    x: torch.Tensor,
) -> list[torch.Tensor]:
    """Forward pass through the LinearSequential object.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        list[torch.Tensor]: List of output tensors from each classifier layer.
    """
    # Sequential handles passing output from one layer to the next
    features = self.backbone(x)  # No need for manual looping or cloning

    # Apply each output layer to the backbone output
    logits = [layer(features) for layer in self.output_layers]

    return logits

from_config(config) classmethod

Create a LinearSequential from a configuration dictionary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | dict[str, Any] | Configuration dictionary containing parameters for the LinearSequential. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| LinearSequential | LinearSequential | An instance of LinearSequential initialized with the provided configuration. |

Source code in src/QuantumGrav/linear_sequential.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "LinearSequential":
    """Create a LinearSequential from a configuration dictionary.

    Args:
        config (dict[str, Any]): Configuration dictionary containing parameters for the LinearSequential.

    Returns:
        LinearSequential: An instance of LinearSequential initialized with the provided configuration.
    """
    return cls(
        input_dim=config["input_dim"],
        output_dims=config["output_dims"],
        hidden_dims=config["hidden_dims"],
        activation=utils.get_registered_activation(
            config.get("activation", "relu")
        ),
        backbone_kwargs=config.get("backbone_kwargs", None),
        output_kwargs=config.get("output_kwargs", None),
        activation_kwargs=config.get("activation_kwargs", None),
    )

load(path, device=torch.device('cpu')) classmethod

Load a LinearSequential instance from file

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str \| Path | Path to the file to load the model from. | required |
| device | device | Device to put the model to. Defaults to torch.device("cpu"). | device('cpu') |

Returns:

| Type | Description |
| --- | --- |
| LinearSequential | An instance of LinearSequential initialized from the loaded data. |

Source code in src/QuantumGrav/linear_sequential.py
@classmethod
def load(
    cls, path: str | Path, device: torch.device = torch.device("cpu")
) -> "LinearSequential":
    """Load a LinearSequential instance from file

    Args:
        path (str | Path): path to the file to load the model from
        device (torch.device): device to put the model to. Defaults to torch.device("cpu")
    Returns:
        LinearSequential: An instance of LinearSequential initialized from the loaded data.
    """
    model = torch.load(path, map_location=device, weights_only=False)
    return model

save(path)

Save the model's state to file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str \| Path | Path to save the model to. | required |
Source code in src/QuantumGrav/linear_sequential.py
def save(self, path: str | Path) -> None:
    """Save the model's state to file.

    Args:
        path (str | Path): path to save the model to.
    """

    torch.save(self, path)

Classifier model based on the LinearSequential base class

ClassifierBlock

Bases: LinearSequential

This class implements a neural network block consisting of a backbone (a sequence of linear layers with activation functions) and multiple output layers for classification tasks. It supports multi-objective classification by allowing multiple output layers, each corresponding to a different classification task.
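ClassifierBlock behaves exactly like its LinearSequential parent; a minimal sketch with two classification heads (import path assumed):

```python
import torch
from QuantumGrav.classifier_block import ClassifierBlock  # hypothetical import path

# Two classification heads (4-way and binary) sharing a 32 -> 64 backbone.
clf = ClassifierBlock(input_dim=32, output_dims=[4, 2], hidden_dims=[64])
head_logits = clf(torch.randn(8, 32))  # list of two tensors: (8, 4) and (8, 2)
```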

Source code in src/QuantumGrav/classifier_block.py
class ClassifierBlock(QGLS.LinearSequential):
    """This class implements a neural network block consisting of a backbone
    (a sequence of linear layers with activation functions) and multiple
    output layers for classification tasks. It supports multi-objective
    classification by allowing multiple output layers, each corresponding
    to a different classification task.
    """

    def __init__(
        self,
        input_dim: int,
        output_dims: list[int],
        hidden_dims: list[int] | None = None,
        activation: type[torch.nn.Module] = torch.nn.ReLU,
        backbone_kwargs: list[dict] | None = None,
        output_kwargs: list[dict] | None = None,
        activation_kwargs: list[dict] | None = None,
    ):
        """Instantiate a ClassifierBlock.

        Args:
            input_dim (int): input dimension of the ClassifierBlock
            output_dims (list[int]): output dimensions for each classification task.
            hidden_dims (list[int], optional): list of hidden dimensions for the backbone network. Defaults to None.
            activation (type[torch.nn.Module], optional): activation function to use. Defaults to torch.nn.ReLU.
            backbone_kwargs (list[dict], optional): keyword arguments for the backbone network. Defaults to None.
            output_kwargs (list[dict], optional): keyword arguments for the output layers. Defaults to None.
            activation_kwargs (list[dict], optional): keyword arguments for the activation functions. Defaults to None.
        """
        super().__init__(
            input_dim=input_dim,
            output_dims=output_dims,
            hidden_dims=hidden_dims,
            activation=activation,
            backbone_kwargs=backbone_kwargs,
            output_kwargs=output_kwargs,
            activation_kwargs=activation_kwargs,
        )

    @classmethod
    def from_config(cls, config: dict) -> "ClassifierBlock":
        """Create a ClassifierBlock from a configuration dictionary.

        Args:
            config (dict): Configuration dictionary containing parameters for the block.

        Returns:
            ClassifierBlock: An instance of ClassifierBlock.
        """
        return cls(
            input_dim=config["input_dim"],
            output_dims=config["output_dims"],
            hidden_dims=config.get("hidden_dims", []),
            activation=utils.activation_layers[config["activation"]],
            backbone_kwargs=config.get("backbone_kwargs", []),
            output_kwargs=config.get("output_kwargs", []),
            activation_kwargs=config.get("activation_kwargs", {}),
        )

__init__(input_dim, output_dims, hidden_dims=None, activation=torch.nn.ReLU, backbone_kwargs=None, output_kwargs=None, activation_kwargs=None)

Instantiate a ClassifierBlock.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_dim | int | Input dimension of the ClassifierBlock. | required |
| output_dims | list[int] | Output dimensions for each classification task. | required |
| hidden_dims | list[int] | List of hidden dimensions for the backbone network. Defaults to None. | None |
| activation | type[Module] | Activation function to use. Defaults to torch.nn.ReLU. | ReLU |
| backbone_kwargs | list[dict] | Keyword arguments for the backbone network. Defaults to None. | None |
| output_kwargs | list[dict] | Keyword arguments for the output layers. Defaults to None. | None |
| activation_kwargs | list[dict] | Keyword arguments for the activation functions. Defaults to None. | None |
Source code in src/QuantumGrav/classifier_block.py
def __init__(
    self,
    input_dim: int,
    output_dims: list[int],
    hidden_dims: list[int] | None = None,
    activation: type[torch.nn.Module] = torch.nn.ReLU,
    backbone_kwargs: list[dict] | None = None,
    output_kwargs: list[dict] | None = None,
    activation_kwargs: list[dict] | None = None,
):
    """Instantiate a ClassifierBlock.

    Args:
        input_dim (int): input dimension of the ClassifierBlock
        output_dims (list[int]): output dimensions for each classification task.
        hidden_dims (list[int], optional): list of hidden dimensions for the backbone network. Defaults to None.
        activation (type[torch.nn.Module], optional): activation function to use. Defaults to torch.nn.ReLU.
        backbone_kwargs (list[dict], optional): keyword arguments for the backbone network. Defaults to None.
        output_kwargs (list[dict], optional): keyword arguments for the output layers. Defaults to None.
        activation_kwargs (list[dict], optional): keyword arguments for the activation functions. Defaults to None.
    """
    super().__init__(
        input_dim=input_dim,
        output_dims=output_dims,
        hidden_dims=hidden_dims,
        activation=activation,
        backbone_kwargs=backbone_kwargs,
        output_kwargs=output_kwargs,
        activation_kwargs=activation_kwargs,
    )

from_config(config) classmethod

Create a ClassifierBlock from a configuration dictionary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | dict | Configuration dictionary containing parameters for the block. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| ClassifierBlock | ClassifierBlock | An instance of ClassifierBlock. |

Source code in src/QuantumGrav/classifier_block.py
@classmethod
def from_config(cls, config: dict) -> "ClassifierBlock":
    """Create a ClassifierBlock from a configuration dictionary.

    Args:
        config (dict): Configuration dictionary containing parameters for the block.

    Returns:
        ClassifierBlock: An instance of ClassifierBlock.
    """
    return cls(
        input_dim=config["input_dim"],
        output_dims=config["output_dims"],
        hidden_dims=config.get("hidden_dims", []),
        activation=utils.activation_layers[config["activation"]],
        backbone_kwargs=config.get("backbone_kwargs", None),
        output_kwargs=config.get("output_kwargs", None),
        activation_kwargs=config.get("activation_kwargs", None),
    )

Graph Neural network submodels

The submodel classes in this section comprise the graph neural network backbone of a QuantumGrav model.

Graph features block

This submodel, derived from LinearSequential, integrates graph-level features into the model's representation.

GraphFeaturesBlock

Bases: LinearSequential

Graph Features Block for processing global graph features. Similarly to the classifier, this consists of a sequence of linear layers with activation functions.
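A minimal sketch (import path assumed); unlike the classifier, forward returns a single tensor rather than a list.

```python
import torch
from QuantumGrav.graphfeatures_block import GraphFeaturesBlock  # hypothetical import path

gf_net = GraphFeaturesBlock(input_dim=6, output_dim=8, hidden_dims=[16])
out = gf_net(torch.randn(4, 6))  # tensor of shape (4, 8)
```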

Source code in src/QuantumGrav/graphfeatures_block.py
class GraphFeaturesBlock(QGLS.LinearSequential):
    """Graph Features Block for processing global graph features. Similarly to the classifier, this consists of a sequence of linear layers with  activation functions."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: list[int] | None = None,
        activation: type[torch.nn.Module] = torch.nn.ReLU,
        layer_kwargs: list[dict] | None = None,
        activation_kwargs: dict | None = None,
    ):
        """Create a GraphFeaturesBlock instance. This will create at least one hidden layer and one output layer, with the specified input and output dimensions.

        Args:
            input_dim (int): input dimension of the GraphFeaturesBlock
            output_dim (int): output dimension of the GraphFeaturesBlock
            hidden_dims (list[int], optional): output dimensions of the hidden layers. Defaults to None.
            activation (torch.nn.Module, optional): activation function type, e.g., torch.nn.ReLU. Defaults to torch.nn.ReLU.
            layer_kwargs (list[dict], optional): keyword arguments for the constructors of each layer. Defaults to None.
            activation_kwargs (dict, optional): keyword arguments for the construction of each activation function. Defaults to None.
        """
        super().__init__(
            input_dim=input_dim,
            output_dims=[
                output_dim,
            ],
            hidden_dims=hidden_dims,
            activation=activation,
            backbone_kwargs=layer_kwargs,
            output_kwargs=[
                layer_kwargs[-1] if layer_kwargs and len(layer_kwargs) > 0 else None
            ],
            activation_kwargs=activation_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the GraphFeaturesBlock.
        Args:
            x (torch.Tensor): Input tensor with shape (batch_size, input_dim).
        Returns:
            torch.Tensor: Output tensor with shape (batch_size, output_dim).
        """
        res = super().forward(x)
        return res[0] if isinstance(res, list) else res

    @classmethod
    def from_config(cls, config: dict) -> "GraphFeaturesBlock":
        """Create a GraphFeaturesBlock from a configuration dictionary.

        Args:
            config (dict): Configuration dictionary containing parameters for the block.

        Returns:
            GraphFeaturesBlock: An instance of GraphFeaturesBlock.
        """
        return cls(
            input_dim=config["input_dim"],
            output_dim=config["output_dim"],
            hidden_dims=config.get("hidden_dims", []),
            activation=utils.activation_layers[config["activation"]],
            layer_kwargs=config.get("layer_kwargs", []),
            activation_kwargs=config.get("activation_kwargs", {}),
        )

__init__(input_dim, output_dim, hidden_dims=None, activation=torch.nn.ReLU, layer_kwargs=None, activation_kwargs=None)

Create a GraphFeaturesBlock instance. This will create at least one hidden layer and one output layer, with the specified input and output dimensions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_dim | int | Input dimension of the GraphFeaturesBlock. | required |
| output_dim | int | Output dimension of the GraphFeaturesBlock. | required |
| hidden_dims | list[int] | Output dimensions of the hidden layers. Defaults to None. | None |
| activation | Module | Activation function type, e.g., torch.nn.ReLU. Defaults to torch.nn.ReLU. | ReLU |
| layer_kwargs | list[dict] | Keyword arguments for the constructors of each layer. Defaults to None. | None |
| activation_kwargs | dict | Keyword arguments for the construction of each activation function. Defaults to None. | None |
Source code in src/QuantumGrav/graphfeatures_block.py
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    hidden_dims: list[int] | None = None,
    activation: type[torch.nn.Module] = torch.nn.ReLU,
    layer_kwargs: list[dict] | None = None,
    activation_kwargs: dict | None = None,
):
    """Create a GraphFeaturesBlock instance. This will create at least one hidden layer and one output layer, with the specified input and output dimensions.

    Args:
        input_dim (int): input dimension of the GraphFeaturesBlock
        output_dim (int): output dimension of the GraphFeaturesBlock
        hidden_dims (list[int], optional): output dimensions of the hidden layers. Defaults to None.
        activation (torch.nn.Module, optional): activation function type, e.g., torch.nn.ReLU. Defaults to torch.nn.ReLU.
        layer_kwargs (list[dict], optional): keyword arguments for the constructors of each layer. Defaults to None.
        activation_kwargs (dict, optional): keyword arguments for the construction of each activation function. Defaults to None.
    """
    super().__init__(
        input_dim=input_dim,
        output_dims=[
            output_dim,
        ],
        hidden_dims=hidden_dims,
        activation=activation,
        backbone_kwargs=layer_kwargs,
        output_kwargs=[
            layer_kwargs[-1] if layer_kwargs and len(layer_kwargs) > 0 else None
        ],
        activation_kwargs=activation_kwargs,
    )

forward(x)

Forward pass through the GraphFeaturesBlock.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor with shape (batch_size, input_dim). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor with shape (batch_size, output_dim). |

Source code in src/QuantumGrav/graphfeatures_block.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the GraphFeaturesBlock.
    Args:
        x (torch.Tensor): Input tensor with shape (batch_size, input_dim).
    Returns:
        torch.Tensor: Output tensor with shape (batch_size, output_dim).
    """
    res = super().forward(x)
    return res[0] if isinstance(res, list) else res

from_config(config) classmethod

Create a GraphFeaturesBlock from a configuration dictionary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | dict | Configuration dictionary containing parameters for the block. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| GraphFeaturesBlock | GraphFeaturesBlock | An instance of GraphFeaturesBlock. |

Source code in src/QuantumGrav/graphfeatures_block.py
@classmethod
def from_config(cls, config: dict) -> "GraphFeaturesBlock":
    """Create a GraphFeaturesBlock from a configuration dictionary.

    Args:
        config (dict): Configuration dictionary containing parameters for the block.

    Returns:
        GraphFeaturesBlock: An instance of GraphFeaturesBlock.
    """
    return cls(
        input_dim=config["input_dim"],
        output_dim=config["output_dim"],
        hidden_dims=config.get("hidden_dims", []),
        activation=utils.activation_layers[config["activation"]],
        layer_kwargs=config.get("layer_kwargs", None),
        activation_kwargs=config.get("activation_kwargs", None),
    )

Graph model block

This submodel is the main building block of the graph neural network backbone: a GNN layer from PyTorch Geometric combined with dropout, a configurable normalization layer, and a residual connection.

GNNBlock

Bases: Module

Graph Neural Network Block. Consists of a GNN layer, a normalizer, an activation function, and a residual connection. The GNN layer is applied first, followed by the normalizer and the activation function. The original input is then projected from the input dimension to the output dimension with a linear layer (when the two differ) and added to the result (residual connection). Finally, dropout is applied for regularization.
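A minimal sketch (import path assumed). Extra keyword arguments to forward are passed straight to the underlying GNN layer, e.g. edge_weight for GCNConv; since in_dim equals out_dim here, the residual projection is an identity.

```python
import torch
from QuantumGrav.gnn_block import GNNBlock  # hypothetical import path

block = GNNBlock(in_dim=16, out_dim=16, dropout=0.1,
                 normalizer=torch.nn.LayerNorm, norm_args=[16])
x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 40))
edge_weight = torch.rand(40)
out = block(x, edge_index, edge_weight=edge_weight)  # kwargs are forwarded to GCNConv
```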

Source code in src/QuantumGrav/gnn_block.py
class GNNBlock(torch.nn.Module):
    """Graph Neural Network Block. Consists of a GNN layer, a normalizer, an activation function,
    and a residual connection. The gnn-layer is applied first, followed by the normalizer and activation function. The result is then projected from the input dimensions to the output dimensions using a linear layer and added to the original input (residual connection). Finally, dropout is applied for regularization.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.3,
        gnn_layer_type: torch.nn.Module = tgnn.conv.GCNConv,
        normalizer: torch.nn.Module = torch.nn.Identity,
        activation: torch.nn.Module = torch.nn.ReLU,
        gnn_layer_args: list[Any] | None = None,
        gnn_layer_kwargs: dict[str, Any] | None = None,
        norm_args: list[Any] | None = None,
        norm_kwargs: dict[str, Any] | None = None,
        activation_args: list[Any] | None = None,
        activation_kwargs: dict[str, Any] | None = None,
        projection_args: list[Any] | None = None,
        projection_kwargs: dict[str, Any] | None = None,
    ):
        """Create a GNNBlock instance.

        Args:
            in_dim (int): The dimensions of the input features.
            out_dim (int): The dimensions of the output features.
            dropout (float, optional): The dropout probability. Defaults to 0.3.
            gnn_layer_type (torch.nn.Module, optional): The type of GNN-layer to use. Defaults to tgnn.conv.GCNConv.
            normalizer (torch.nn.Module, optional): The normalizer layer to use. Defaults to torch.nn.Identity.
            activation (torch.nn.Module, optional): The activation function to use. Defaults to torch.nn.ReLU.
            gnn_layer_args (list[Any], optional): Additional arguments for the GNN layer. Defaults to None.
            gnn_layer_kwargs (dict[str, Any], optional): Additional keyword arguments for the GNN layer. Defaults to None.
            norm_args (list[Any], optional): Additional arguments for the normalizer layer. Defaults to None.
            norm_kwargs (dict[str, Any], optional): Additional keyword arguments for the normalizer layer. Defaults to None.
            activation_args (list[Any], optional): Additional arguments for the activation function. Defaults to None.
            activation_kwargs (dict[str, Any], optional): Additional keyword arguments for the activation function. Defaults to None.
            projection_args (list[Any], optional): Additional arguments for the projection layer. Defaults to None.
            projection_kwargs (dict[str, Any], optional): Additional keyword arguments for the projection layer. Defaults to None.

        """
        super().__init__()

        # save parameters
        self.dropout_p = dropout
        self.in_dim = in_dim
        self.out_dim = out_dim

        # initialize layers
        self.dropout = torch.nn.Dropout(p=dropout, inplace=False)

        self.normalizer = normalizer(
            *(norm_args if norm_args is not None else []),
            **(norm_kwargs if norm_kwargs is not None else {}),
        )

        self.activation = activation(
            *(activation_args if activation_args is not None else []),
            **(activation_kwargs if activation_kwargs is not None else {}),
        )

        self.conv = gnn_layer_type(
            in_dim,
            out_dim,
            *(gnn_layer_args if gnn_layer_args is not None else []),
            **(gnn_layer_kwargs if gnn_layer_kwargs is not None else {}),
        )

        if in_dim != out_dim:
            self.projection = torch.nn.Linear(
                *(projection_args if projection_args is not None else []),
                **(projection_kwargs if projection_kwargs is not None else {}),
            )
        else:
            self.projection = torch.nn.Identity()

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Forward pass for the GNNBlock.
        First apply the graph convolution layer, then normalize and apply the activation function.
        Finally, apply a residual connection and dropout.
        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The graph connectivity information.
            edge_weight (torch.Tensor, optional): The edge weights. Defaults to None.
            kwargs (dict[Any, Any], optional): Additional keyword arguments for the GNN layer. Defaults to None.

        Returns:
            torch.Tensor: The output node features.
        """

        # convolution, then normalize and apply nonlinearity
        x_res = self.conv(x, edge_index, **kwargs)
        x_res = self.normalizer(x_res)
        x_res = self.activation(x_res)

        # Residual connection
        x_res = x_res + self.projection(x)

        # Apply dropout as regularization
        x_res = self.dropout(x_res)

        return x_res

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GNNBlock":
        """Create a GNNBlock from a configuration dictionary.
        When the config does not have 'dropout', it defaults to 0.3.

        Args:
            config (dict[str, Any]): Configuration dictionary containing the parameters for the GNNBlock.

        Returns:
            GNNBlock: An instance of GNNBlock initialized with the provided configuration.
        """
        return cls(
            in_dim=config["in_dim"],
            out_dim=config["out_dim"],
            dropout=config.get("dropout", 0.3),
            gnn_layer_type=utils.gnn_layers[config["gnn_layer_type"]],
            normalizer=utils.normalizer_layers[config["normalizer"]],
            activation=utils.activation_layers[config["activation"]],
            gnn_layer_args=config.get("gnn_layer_args", []),
            gnn_layer_kwargs=config.get("gnn_layer_kwargs", {}),
            norm_args=config.get("norm_args", []),
            norm_kwargs=config.get("norm_kwargs", {}),
            activation_args=config.get("activation_args", []),
            activation_kwargs=config.get("activation_kwargs", {}),
            projection_args=config.get("projection_args", []),
            projection_kwargs=config.get("projection_kwargs", {}),
        )

    def save(self, path: str | Path) -> None:
        """Save the model's state to file.

        Args:
            path (str | Path): path to save the model to.
        """

        torch.save(self, path)

    @classmethod
    def load(
        cls, path: str | Path, device: torch.device = torch.device("cpu")
    ) -> "GNNBlock":
        """Load a mode instance from file

        Args:
            path (str | Path): Path to the file to load.
            device (torch.device): device to put the model to. Defaults to torch.device("cpu")
        Returns:
            GNNBlock: A GNNBlock instance initialized from the data loaded from the file.
        """

        model = torch.load(path, map_location=device, weights_only=False)
        return model

__init__(in_dim, out_dim, dropout=0.3, gnn_layer_type=tgnn.conv.GCNConv, normalizer=torch.nn.Identity, activation=torch.nn.ReLU, gnn_layer_args=None, gnn_layer_kwargs=None, norm_args=None, norm_kwargs=None, activation_args=None, activation_kwargs=None, projection_args=None, projection_kwargs=None)

Create a GNNBlock instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_dim | int | The dimensions of the input features. | required |
| out_dim | int | The dimensions of the output features. | required |
| dropout | float | The dropout probability. Defaults to 0.3. | 0.3 |
| gnn_layer_type | Module | The type of GNN layer to use. Defaults to tgnn.conv.GCNConv. | GCNConv |
| normalizer | Module | The normalizer layer to use. Defaults to torch.nn.Identity. | Identity |
| activation | Module | The activation function to use. Defaults to torch.nn.ReLU. | ReLU |
| gnn_layer_args | list[Any] | Additional arguments for the GNN layer. Defaults to None. | None |
| gnn_layer_kwargs | dict[str, Any] | Additional keyword arguments for the GNN layer. Defaults to None. | None |
| norm_args | list[Any] | Additional arguments for the normalizer layer. Defaults to None. | None |
| norm_kwargs | dict[str, Any] | Additional keyword arguments for the normalizer layer. Defaults to None. | None |
| activation_args | list[Any] | Additional arguments for the activation function. Defaults to None. | None |
| activation_kwargs | dict[str, Any] | Additional keyword arguments for the activation function. Defaults to None. | None |
| projection_args | list[Any] | Additional arguments for the projection layer. Defaults to None. | None |
| projection_kwargs | dict[str, Any] | Additional keyword arguments for the projection layer. Defaults to None. | None |
Source code in src/QuantumGrav/gnn_block.py
def __init__(
    self,
    in_dim: int,
    out_dim: int,
    dropout: float = 0.3,
    gnn_layer_type: torch.nn.Module = tgnn.conv.GCNConv,
    normalizer: torch.nn.Module = torch.nn.Identity,
    activation: torch.nn.Module = torch.nn.ReLU,
    gnn_layer_args: list[Any] | None = None,
    gnn_layer_kwargs: dict[str, Any] | None = None,
    norm_args: list[Any] | None = None,
    norm_kwargs: dict[str, Any] | None = None,
    activation_args: list[Any] | None = None,
    activation_kwargs: dict[str, Any] | None = None,
    projection_args: list[Any] | None = None,
    projection_kwargs: dict[str, Any] | None = None,
):
    """Create a GNNBlock instance.

    Args:
        in_dim (int): The dimensions of the input features.
        out_dim (int): The dimensions of the output features.
        dropout (float, optional): The dropout probability. Defaults to 0.3.
        gnn_layer_type (torch.nn.Module, optional): The type of GNN-layer to use. Defaults to tgnn.conv.GCNConv.
        normalizer (torch.nn.Module, optional): The normalizer layer to use. Defaults to torch.nn.Identity.
        activation (torch.nn.Module, optional): The activation function to use. Defaults to torch.nn.ReLU.
        gnn_layer_args (list[Any], optional): Additional arguments for the GNN layer. Defaults to None.
        gnn_layer_kwargs (dict[str, Any], optional): Additional keyword arguments for the GNN layer. Defaults to None.
        norm_args (list[Any], optional): Additional arguments for the normalizer layer. Defaults to None.
        norm_kwargs (dict[str, Any], optional): Additional keyword arguments for the normalizer layer. Defaults to None.
        activation_args (list[Any], optional): Additional arguments for the activation function. Defaults to None.
        activation_kwargs (dict[str, Any], optional): Additional keyword arguments for the activation function. Defaults to None.
        projection_args (list[Any], optional): Additional arguments for the projection layer. Defaults to None.
        projection_kwargs (dict[str, Any], optional): Additional keyword arguments for the projection layer. Defaults to None.

    """
    super().__init__()

    # save parameters
    self.dropout_p = dropout
    self.in_dim = in_dim
    self.out_dim = out_dim

    # initialize layers
    self.dropout = torch.nn.Dropout(p=dropout, inplace=False)

    self.normalizer = normalizer(
        *(norm_args if norm_args is not None else []),
        **(norm_kwargs if norm_kwargs is not None else {}),
    )

    self.activation = activation(
        *(activation_args if activation_args is not None else []),
        **(activation_kwargs if activation_kwargs is not None else {}),
    )

    self.conv = gnn_layer_type(
        in_dim,
        out_dim,
        *(gnn_layer_args if gnn_layer_args is not None else []),
        **(gnn_layer_kwargs if gnn_layer_kwargs is not None else {}),
    )

    if in_dim != out_dim:
        self.projection = torch.nn.Linear(
            *(projection_args if projection_args is not None else []),
            **(projection_kwargs if projection_kwargs is not None else {}),
        )
    else:
        self.projection = torch.nn.Identity()
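
For orientation, a minimal construction sketch. It assumes the block is importable from the source path shown above (e.g. from QuantumGrav.gnn_block import GNNBlock); the layer choice and dimensions are purely illustrative.

import torch
import torch_geometric.nn as tgnn

from QuantumGrav.gnn_block import GNNBlock  # assumed import path

# A block mapping 16-dimensional node features to 32 dimensions, using
# SAGEConv and batch normalization instead of the defaults.
block = GNNBlock(
    in_dim=16,
    out_dim=32,
    dropout=0.1,
    gnn_layer_type=tgnn.SAGEConv,
    normalizer=torch.nn.BatchNorm1d,
    activation=torch.nn.ReLU,
    norm_args=[32],            # BatchNorm1d needs the feature dimension
    projection_args=[16, 32],  # in_dim != out_dim, so the residual projection needs sizes
)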

forward(x, edge_index, **kwargs)

Forward pass for the GNNBlock. First apply the graph convolution layer, then normalize and apply the activation function. Finally, apply a residual connection and dropout.

Parameters:

- x (torch.Tensor): The input node features.
- edge_index (torch.Tensor): The graph connectivity information.
- edge_weight (torch.Tensor, optional): The edge weights. Defaults to None.
- kwargs (dict[Any, Any], optional): Additional keyword arguments for the GNN layer. Defaults to None.

Returns:

- torch.Tensor: The output node features.

Source code in src/QuantumGrav/gnn_block.py
def forward(
    self, x: torch.Tensor, edge_index: torch.Tensor, **kwargs
) -> torch.Tensor:
    """Forward pass for the GNNBlock.
    First apply the graph convolution layer, then normalize and apply the activation function.
    Finally, apply a residual connection and dropout.
    Args:
        x (torch.Tensor): The input node features.
        edge_index (torch.Tensor): The graph connectivity information.
        edge_weight (torch.Tensor, optional): The edge weights. Defaults to None.
        kwargs (dict[Any, Any], optional): Additional keyword arguments for the GNN layer. Defaults to None.

    Returns:
        torch.Tensor: The output node features.
    """

    # convolution, then normalize and apply nonlinearity
    x_res = self.conv(x, edge_index, **kwargs)
    x_res = self.normalizer(x_res)
    x_res = self.activation(x_res)

    # Residual connection
    x_res = x_res + self.projection(x)

    # Apply dropout as regularization
    x_res = self.dropout(x_res)

    return x_res
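
A quick forward pass on a toy graph, continuing the construction sketch above (shapes are illustrative):

x = torch.randn(4, 16)                    # 4 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 2, 3],  # source nodes
                           [1, 2, 3, 0]]) # target nodes
out = block(x, edge_index)
print(out.shape)  # torch.Size([4, 32])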

from_config(config) classmethod

Create a GNNBlock from a configuration dictionary. When the config does not have 'dropout', it defaults to 0.3.

Parameters:

- config (dict[str, Any]): Configuration dictionary containing the parameters for the GNNBlock. Required.

Returns:

- GNNBlock: An instance of GNNBlock initialized with the provided configuration.

Source code in src/QuantumGrav/gnn_block.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GNNBlock":
    """Create a GNNBlock from a configuration dictionary.
    When the config does not have 'dropout', it defaults to 0.3.

    Args:
        config (dict[str, Any]): Configuration dictionary containing the parameters for the GNNBlock.

    Returns:
        GNNBlock: An instance of GNNBlock initialized with the provided configuration.
    """
    return cls(
        in_dim=config["in_dim"],
        out_dim=config["out_dim"],
        dropout=config.get("dropout", 0.3),
        gnn_layer_type=utils.gnn_layers[config["gnn_layer_type"]],
        normalizer=utils.normalizer_layers[config["normalizer"]],
        activation=utils.activation_layers[config["activation"]],
        gnn_layer_args=config.get("gnn_layer_args", []),
        gnn_layer_kwargs=config.get("gnn_layer_kwargs", {}),
        norm_args=config.get("norm_args", []),
        norm_kwargs=config.get("norm_kwargs", {}),
        activation_args=config.get("activation_args", []),
        activation_kwargs=config.get("activation_kwargs", {}),
        projection_args=config.get("projection_args", []),
        projection_kwargs=config.get("projection_kwargs", {}),
    )
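
A sketch of a matching configuration dictionary. The valid string keys for gnn_layer_type, normalizer and activation come from the utils.gnn_layers, utils.normalizer_layers and utils.activation_layers registries, whose contents are not shown on this page; the names below are assumptions.

config = {
    "in_dim": 16,
    "out_dim": 32,
    "dropout": 0.1,
    "gnn_layer_type": "GCNConv",   # assumed registry key
    "normalizer": "BatchNorm1d",   # assumed registry key
    "activation": "ReLU",          # assumed registry key
    "norm_args": [32],
    "projection_args": [16, 32],
}
block = GNNBlock.from_config(config)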

load(path, device=torch.device('cpu')) classmethod

Load a model instance from file.

Parameters:

- path (str | Path): Path to the file to load. Required.
- device (torch.device, optional): Device to put the model on. Defaults to torch.device("cpu").

Returns:

- GNNBlock: A GNNBlock instance initialized from the data loaded from the file.

Source code in src/QuantumGrav/gnn_block.py
@classmethod
def load(
    cls, path: str | Path, device: torch.device = torch.device("cpu")
) -> "GNNBlock":
    """Load a mode instance from file

    Args:
        path (str | Path): Path to the file to load.
        device (torch.device): device to put the model to. Defaults to torch.device("cpu")
    Returns:
        GNNBlock: A GNNBlock instance initialized from the data loaded from the file.
    """

    model = torch.load(path, map_location=device, weights_only=False)
    return model

save(path)

Save the model's state to file.

Parameters:

- path (str | Path): Path to save the model to. Required.
Source code in src/QuantumGrav/gnn_block.py
def save(self, path: str | Path) -> None:
    """Save the model's state to file.

    Args:
        path (str | Path): path to save the model to.
    """

    torch.save(self, path)
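
Saving and loading form a simple round trip. Note that load calls torch.load with weights_only=False, i.e. it unpickles the whole module, so only load files from sources you trust.

block.save("gnn_block.pt")
restored = GNNBlock.load("gnn_block.pt", device=torch.device("cpu"))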

Model evaluation

This module provides base classes that take the output of applying the model to a validation or training dataset and derive quantities useful for judging model quality. They do little on their own; instead, you derive your own class from them that implements your desired evaluation, e.g., using an F1 score.

DefaultEarlyStopping

Early stopping based on a validation metric.

Source code in src/QuantumGrav/evaluate.py
class DefaultEarlyStopping:
    """Early stopping based on a validation metric."""

    def __init__(
        self,
        patience: int,
        delta: float = 1e-4,
        window=7,
    ):
        """Early stopping initialization.

        Args:
            patience (int): Number of epochs with no improvement after which training will be stopped.
            delta (float, optional): Minimum change to consider an improvement. Defaults to 1e-4.
            window (int, optional): Size of the moving window for smoothing. Defaults to 7.
        """
        self.patience = patience
        self.current_patience = patience
        self.delta = delta
        self.best_score = np.inf
        self.window = window
        self.found_better = False
        self.logger = logging.getLogger(__name__)

    def __call__(self, data: Iterable | pd.DataFrame | pd.Series) -> bool:
        """Check if early stopping criteria are met.

        Args:
            data: Iterable of validation metrics, e.g., list of scalars, list of tuples, Dataframe, numpy array...

        Returns:
            bool: True if training should be stopped, False otherwise.
        """
        window = min(self.window, len(data))
        smoothed = pd.Series(data).rolling(window=window, min_periods=1).mean()
        if smoothed.iloc[-1] < self.best_score - self.delta:
            self.logger.info(
                f"Early stopping patience reset: {self.current_patience} -> {self.patience}, early stopping best score updated: {self.best_score} -> {smoothed.iloc[-1]}"
            )
            self.best_score = smoothed.iloc[-1]
            self.current_patience = self.patience
            self.found_better = True
        else:
            self.logger.info(
                f"Early stopping patience decreased: {self.current_patience} -> {self.current_patience - 1}"
            )
            self.current_patience -= 1
            self.found_better = False

        return self.current_patience <= 0
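
A sketch of how this might sit in a training loop. validation_losses is assumed to hold one scalar loss per epoch (lower is better); compute_validation_loss is a hypothetical helper.

stopper = DefaultEarlyStopping(patience=5, delta=1e-4, window=7)

validation_losses = []
for epoch in range(100):
    # ... train for one epoch ...
    validation_losses.append(compute_validation_loss())  # hypothetical helper
    if stopper(validation_losses):
        print(f"Stopping early after epoch {epoch}")
        break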

__call__(data)

Check if early stopping criteria are met.

Parameters:

- data (Iterable | DataFrame | Series): Iterable of validation metrics, e.g., a list of scalars, a list of tuples, a DataFrame, or a numpy array. Required.

Returns:

- bool: True if training should be stopped, False otherwise.

Source code in src/QuantumGrav/evaluate.py
def __call__(self, data: Iterable | pd.DataFrame | pd.Series) -> bool:
    """Check if early stopping criteria are met.

    Args:
        data: Iterable of validation metrics, e.g., list of scalars, list of tuples, Dataframe, numpy array...

    Returns:
        bool: True if training should be stopped, False otherwise.
    """
    window = min(self.window, len(data))
    smoothed = pd.Series(data).rolling(window=window, min_periods=1).mean()
    if smoothed.iloc[-1] < self.best_score - self.delta:
        self.logger.info(
            f"Early stopping patience reset: {self.current_patience} -> {self.patience}, early stopping best score updated: {self.best_score} -> {smoothed.iloc[-1]}"
        )
        self.best_score = smoothed.iloc[-1]
        self.current_patience = self.patience
        self.found_better = True
    else:
        self.logger.info(
            f"Early stopping patience decreased: {self.current_patience} -> {self.current_patience - 1}"
        )
        self.current_patience -= 1
        self.found_better = False

    return self.current_patience <= 0

__init__(patience, delta=0.0001, window=7)

Early stopping initialization.

Parameters:

- patience (int): Number of epochs with no improvement after which training will be stopped. Required.
- delta (float, optional): Minimum change to consider an improvement. Defaults to 1e-4.
- window (int, optional): Size of the moving window for smoothing. Defaults to 7.
Source code in src/QuantumGrav/evaluate.py
def __init__(
    self,
    patience: int,
    delta: float = 1e-4,
    window=7,
):
    """Early stopping initialization.

    Args:
        patience (int): Number of epochs with no improvement after which training will be stopped.
        delta (float, optional): Minimum change to consider an improvement. Defaults to 1e-4.
        window (int, optional): Size of the moving window for smoothing. Defaults to 7.
    """
    self.patience = patience
    self.current_patience = patience
    self.delta = delta
    self.best_score = np.inf
    self.window = window
    self.found_better = False
    self.logger = logging.getLogger(__name__)

DefaultEvaluator

Source code in src/QuantumGrav/evaluate.py
class DefaultEvaluator:
    def __init__(
        self, device, criterion: Callable, apply_model: Callable | None = None
    ):
        """Default evaluator for model evaluation.

        Args:
            device (_type_): The device to run the evaluation on.
            criterion (Callable): The loss function to use for evaluation.
            apply_model (Callable): A function to apply the model to the data.
        """
        self.criterion = criterion
        self.apply_model = apply_model
        self.device = device
        self.data = []
        self.logger = logging.getLogger(__name__)

    def evaluate(
        self,
        model: torch.nn.Module,
        data_loader: torch_geometric.loader.DataLoader,  # type: ignore
    ) -> Any:
        """Evaluate the model on the given data loader.

        Args:
            model (torch.nn.Module): Model to evaluate.
            data_loader (torch_geometric.loader.DataLoader): Data loader for evaluation.

        Returns:
             list[Any]: A list of evaluation results.
        """
        model.eval()
        current_data = []

        with torch.no_grad():
            for i, batch in enumerate(data_loader):
                data = batch.to(self.device)
                if self.apply_model:
                    outputs = self.apply_model(model, data)
                else:
                    outputs = model(data.x, data.edge_index, data.batch)
                loss = self.criterion(outputs, data)
                current_data.append(loss)

        return current_data

    def report(self, data: list | pd.Series | torch.Tensor | np.ndarray) -> None:
        """Report the evaluation results to stdout"""

        if isinstance(data, torch.Tensor):
            data = data.cpu().numpy()

        if isinstance(data, list):
            for i, d in enumerate(data):
                if isinstance(d, torch.Tensor):
                    data[i] = d.cpu().numpy()

        avg = np.mean(data)
        sigma = np.std(data)
        self.logger.info(f"Average loss: {avg}, Standard deviation: {sigma}")
        self.data.append((avg, sigma))

__init__(device, criterion, apply_model=None)

Default evaluator for model evaluation.

Parameters:

- device: The device to run the evaluation on. Required.
- criterion (Callable): The loss function to use for evaluation. Required.
- apply_model (Callable, optional): A function to apply the model to the data. Defaults to None.
Source code in src/QuantumGrav/evaluate.py
def __init__(
    self, device, criterion: Callable, apply_model: Callable | None = None
):
    """Default evaluator for model evaluation.

    Args:
        device (_type_): The device to run the evaluation on.
        criterion (Callable): The loss function to use for evaluation.
        apply_model (Callable): A function to apply the model to the data.
    """
    self.criterion = criterion
    self.apply_model = apply_model
    self.device = device
    self.data = []
    self.logger = logging.getLogger(__name__)

evaluate(model, data_loader)

Evaluate the model on the given data loader.

Parameters:

- model (torch.nn.Module): Model to evaluate. Required.
- data_loader (torch_geometric.loader.DataLoader): Data loader for evaluation. Required.

Returns:

- list[Any]: A list of evaluation results.

Source code in src/QuantumGrav/evaluate.py
def evaluate(
    self,
    model: torch.nn.Module,
    data_loader: torch_geometric.loader.DataLoader,  # type: ignore
) -> Any:
    """Evaluate the model on the given data loader.

    Args:
        model (torch.nn.Module): Model to evaluate.
        data_loader (torch_geometric.loader.DataLoader): Data loader for evaluation.

    Returns:
         list[Any]: A list of evaluation results.
    """
    model.eval()
    current_data = []

    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            data = batch.to(self.device)
            if self.apply_model:
                outputs = self.apply_model(model, data)
            else:
                outputs = model(data.x, data.edge_index, data.batch)
            loss = self.criterion(outputs, data)
            current_data.append(loss)

    return current_data

report(data)

Report the evaluation results via the logger.

Source code in src/QuantumGrav/evaluate.py
def report(self, data: list | pd.Series | torch.Tensor | np.ndarray) -> None:
    """Report the evaluation results to stdout"""

    if isinstance(data, torch.Tensor):
        data = data.cpu().numpy()

    if isinstance(data, list):
        for i, d in enumerate(data):
            if isinstance(d, torch.Tensor):
                data[i] = d.cpu().numpy()

    avg = np.mean(data)
    sigma = np.std(data)
    self.logger.info(f"Average loss: {avg}, Standard deviation: {sigma}")
    self.data.append((avg, sigma))
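
As noted in the introduction above, the default loss-based behaviour is mostly a template. Because evaluate simply calls criterion(outputs, data) for each batch, a different metric can often be dropped in without subclassing. A sketch using a macro F1 score (scikit-learn is an assumed extra dependency here, and the output/label layout of your model and data will differ):

import torch
from sklearn.metrics import f1_score  # assumed dependency for this sketch

def f1_criterion(outputs, data):
    # assumes outputs are logits of shape [num_graphs, num_classes]
    # and data.y holds one integer class label per graph
    preds = outputs.argmax(dim=-1)
    return f1_score(data.y.cpu().numpy(), preds.cpu().numpy(), average="macro")

evaluator = DefaultEvaluator(device=torch.device("cpu"), criterion=f1_criterion)
scores = evaluator.evaluate(model, val_loader)  # model and val_loader defined elsewhere
evaluator.report(scores)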

DefaultTester

Bases: DefaultEvaluator

Source code in src/QuantumGrav/evaluate.py
class DefaultTester(DefaultEvaluator):
    def __init__(
        self, device, criterion: Callable, apply_model: Callable | None = None
    ):
        """Default tester for model testing.

        Args:
            device (_type_): The device to run the testing on.
            criterion (Callable): The loss function to use for testing.
            apply_model (Callable): A function to apply the model to the data.
        """
        super().__init__(device, criterion, apply_model)

    def test(
        self,
        model: torch.nn.Module,
        data_loader: torch_geometric.loader.DataLoader,  # type: ignore
    ):
        """Test the model on the given data loader.

        Args:
            model (torch.nn.Module): Model to test.
            data_loader (torch_geometric.loader.DataLoader): Data loader for testing.

        Returns:
            list[Any]: A list of testing results.
        """
        return self.evaluate(model, data_loader)

__init__(device, criterion, apply_model=None)

Default tester for model testing.

Parameters:

- device: The device to run the testing on. Required.
- criterion (Callable): The loss function to use for testing. Required.
- apply_model (Callable, optional): A function to apply the model to the data. Defaults to None.
Source code in src/QuantumGrav/evaluate.py
def __init__(
    self, device, criterion: Callable, apply_model: Callable | None = None
):
    """Default tester for model testing.

    Args:
        device (_type_): The device to run the testing on.
        criterion (Callable): The loss function to use for testing.
        apply_model (Callable): A function to apply the model to the data.
    """
    super().__init__(device, criterion, apply_model)

test(model, data_loader)

Test the model on the given data loader.

Parameters:

- model (torch.nn.Module): Model to test. Required.
- data_loader (torch_geometric.loader.DataLoader): Data loader for testing. Required.

Returns:

- list[Any]: A list of testing results.

Source code in src/QuantumGrav/evaluate.py
def test(
    self,
    model: torch.nn.Module,
    data_loader: torch_geometric.loader.DataLoader,  # type: ignore
):
    """Test the model on the given data loader.

    Args:
        model (torch.nn.Module): Model to test.
        data_loader (torch_geometric.loader.DataLoader): Data loader for testing.

    Returns:
        list[Any]: A list of testing results.
    """
    return self.evaluate(model, data_loader)

DefaultValidator

Bases: DefaultEvaluator

Source code in src/QuantumGrav/evaluate.py
class DefaultValidator(DefaultEvaluator):
    def __init__(
        self, device, criterion: Callable, apply_model: Callable | None = None
    ):
        super().__init__(device, criterion, apply_model)

    def validate(
        self,
        model: torch.nn.Module,
        data_loader: torch_geometric.loader.DataLoader,  # type: ignore
    ):
        """Validate the model on the given data loader.

        Args:
            model (torch.nn.Module): Model to validate.
            data_loader (torch_geometric.loader.DataLoader): Data loader for validation.
        Returns:
            list[Any]: A list of validation results.
        """
        return self.evaluate(model, data_loader)

validate(model, data_loader)

Validate the model on the given data loader.

Parameters:

- model (torch.nn.Module): Model to validate. Required.
- data_loader (torch_geometric.loader.DataLoader): Data loader for validation. Required.

Returns:

- list[Any]: A list of validation results.

Source code in src/QuantumGrav/evaluate.py
def validate(
    self,
    model: torch.nn.Module,
    data_loader: torch_geometric.loader.DataLoader,  # type: ignore
):
    """Validate the model on the given data loader.

    Args:
        model (torch.nn.Module): Model to validate.
        data_loader (torch_geometric.loader.DataLoader): Data loader for validation.
    Returns:
        list[Any]: A list of validation results.
    """
    return self.evaluate(model, data_loader)
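
A rough per-epoch validation step combining the validator with early stopping (the criterion is assumed to return one scalar per batch; model, val_loader and max_epochs are defined elsewhere):

import numpy as np

validator = DefaultValidator(device, criterion)
stopper = DefaultEarlyStopping(patience=10)

epoch_losses = []
for epoch in range(max_epochs):
    # ... one training epoch ...
    batch_losses = validator.validate(model, val_loader)
    validator.report(batch_losses)  # logs mean and std, appends them to validator.data
    epoch_losses.append(float(np.mean([float(loss) for loss in batch_losses])))
    if stopper(epoch_losses):
        break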

Datasets

The package supports three kinds of datasets, all sharing the common base class QGDatasetBase. For the basics of how these work, see the PyTorch Geometric dataset documentation.

These are:

- QGDataset: A dataset that relies on on-disk storage of the processed data. It lazily loads causal sets from disk when needed.
- QGDatasetInMemory: A dataset that holds the entire processed dataset in memory at once.
- QGDatasetOnthefly: A dataset that holds nothing on disk or in memory, but creates the data on demand from user-supplied Julia code.

Dataset base class

QGDatasetBase

Mixin class that provides common functionality for the dataset classes. Works only for file-based datasets. Provides methods for processing data.

Source code in src/QuantumGrav/dataset_base.py
class QGDatasetBase:
    """Mixin class that provides common functionality for the dataset classes. Works only for file-based datasets. Provides methods for processing data."""

    def __init__(
        self,
        input: list[str | Path],
        output: str | Path,
        mode: str = "hdf5",
        reader: Callable[
            [h5py.File | zarr.Group, torch.dtype, torch.dtype, bool],
            list[Data],
        ]
        | None = None,
        float_type: torch.dtype = torch.float32,
        int_type: torch.dtype = torch.int64,
        validate_data: bool = True,
        n_processes: int = 1,
        chunksize: int = 1000,
        **kwargs,
    ):
        """Initialize a DatasetMixin instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets. It provides a common interface for both in-memory and on-disk datasets. It is not to be instantiated directly, but rather used as a mixin for other dataset classes.

        Args:
            input (list[str | Path]): The list of input files for the dataset, or a callable that generates a set of input files.
            output (str | Path): The output directory where processed data will be stored.
            mode (str): File storage mode. 'zarr' or 'hdf5'
            reader (Callable[[h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]] | None, optional): A function to load data from a file. Defaults to None.
            float_type (torch.dtype, optional): The data type to use for floating point values. Defaults to torch.float32.
            int_type (torch.dtype, optional): The data type to use for integer values. Defaults to torch.int64.
            validate_data (bool, optional): Whether to validate the data after loading. Defaults to True.
            n_processes (int, optional): The number of processes to use for parallel processing of read data. Defaults to 1.
            chunksize (int, optional): The size of the chunks to process in parallel. Defaults to 1000.

        Raises:
            ValueError: If one of the input data files is not a valid HDF5 file
            ValueError: If the metadata retrieval function is invalid.
            FileNotFoundError: If an input file does not exist.
        """
        if reader is None:
            raise ValueError("A reader function must be provided to load the data.")

        if mode not in ["hdf5", "zarr"]:
            raise ValueError("mode must be 'hdf5' or 'zarr'")

        self.mode = mode
        self.input = input
        for file in self.input:
            if Path(file).exists() is False:
                raise FileNotFoundError(f"Input file {file} does not exist.")

        self.output = output
        self.data_reader = reader
        self.metadata = {}
        self.float_type = float_type
        self.int_type = int_type
        self.validate_data = validate_data
        self.n_processes = n_processes
        self.chunksize = chunksize

        # get the number of samples in the dataset
        self._num_samples = 0

        for filepath in self.input:
            if not Path(filepath).exists():
                raise FileNotFoundError(f"Input file {filepath} does not exist.")
            self._num_samples += self._get_num_samples_per_file(filepath)

        # ensure the input is a list of paths
        if Path(self.processed_dir).exists():
            with open(Path(self.processed_dir) / "metadata.yaml", "r") as f:
                self.metadata = yaml.load(f, Loader=yaml.FullLoader)
        else:
            Path(self.processed_dir).mkdir(parents=True, exist_ok=True)
            self.metadata = {
                "num_samples": int(self._num_samples),
                "input": [str(Path(f).resolve().absolute()) for f in self.input],
                "output": str(Path(self.output).resolve().absolute()),
                "float_type": str(self.float_type),
                "int_type": str(self.int_type),
                "validate_data": self.validate_data,
                "n_processes": self.n_processes,
                "chunksize": self.chunksize,
            }

            with open(Path(self.processed_dir) / "metadata.yaml", "w") as f:
                yaml.dump(self.metadata, f)

    def _get_num_samples_per_file(self, filepath: str | Path) -> int:
        """Get the number of samples in a given file.

        Args:
            filepath (str | Path): The path to the file.

        Raises:
            ValueError: If the file is not a valid HDF5 or Zarr file.

        Returns:
            int: The number of samples in the file.
        """

        # try to find the sample number from a dedicated dataset
        def try_find_numsamples(f):
            s = None
            for name in ["num_causal_sets", "num_samples"]:
                if name in f:
                    s = f[name]
                    break
            return s

        # ... if that fails, we try to read it from any scalar dataset.
        # ... if we can't because they are of unequal sizes, we return None
        # ... to indicate an unresolvable state
        def fallback(f) -> int | None:
            # find scalar datasets and use their sizes to determine size
            shapes = [f[k].shape[0] for k in f.keys() if len(f[k].shape) == 1]
            max_shape = max(shapes)
            min_shape = min(shapes)
            if max_shape != min_shape:
                return None
            else:
                return max_shape

        # same logic for Zarr and HDF5
        if self.mode == "hdf5":
            with h5py.File(filepath, "r") as f:
                try:
                    # note that fallback returns an int directly,
                    # while for try_find_numsamples we need to index into the result
                    s = try_find_numsamples(f)
                    if s is not None:
                        return s[()]
                    else:
                        s = fallback(f)
                        if s is not None:
                            return s
                        else:
                            raise RuntimeError("Unable to determine number of samples.")
                except Exception:
                    raise
        elif self.mode == "zarr":
            try:
                group = zarr.open_group(
                    zarr.storage.LocalStore(filepath, read_only=True),
                    path="",
                    mode="r",
                )
                # note that fallback returns an int directly,
                # while for try_find_numsamples we need to index into the result
                s = try_find_numsamples(group)
                if s is not None:
                    return s[0]
                else:
                    s = fallback(group)
                    if s is not None:
                        return s
                    else:
                        raise RuntimeError("Unable to determine number of samples.")
            except Exception:
                # we need an extra fallback for zarr b/c Julia Zarr and python Zarr
                # can differ in layout - Julia Zarr does not have to have a group
                try:
                    store = zarr.storage.LocalStore(filepath, read_only=True)
                    arr = zarr.open_array(store, path="adjacency_matrix")
                    s = max(arr.shape)
                    return s
                except Exception:
                    raise
        else:
            raise ValueError("mode must be 'hdf5' or 'zarr'")

    @property
    def processed_dir(self) -> str | None:
        """Get the path to the processed directory.

        Returns:
            str: The path to the processed directory, or None if it doesn't exist.
        """
        processed_path = Path(self.output).resolve().absolute() / "processed"
        return str(processed_path)

    @property
    def raw_file_names(self) -> list[str]:
        """Get the raw file paths from the input list.

        Returns:
            list[str]: A list of raw file paths.
        """
        if self.mode == "zarr":
            suf = ".zarr"
        else:
            suf = ".h5"

        return [str(Path(f).name) for f in self.input if Path(f).suffix == suf]

    @property
    def processed_file_names(self) -> list[str]:
        """Get a list of processed files in the processed directory.

        Returns:
            list[str]: A list of processed file paths, excluding JSON files.
        """

        if not Path(self.processed_dir).exists():
            return []
        return [
            str(f.name)
            for f in Path(self.processed_dir).iterdir()
            if f.is_file() and f.suffix == ".pt" and "data" in f.name
        ]

    def process_chunk_hdf5(
        self,
        raw_file: h5py.File,
        start: int,
        pre_transform: Callable[[Data | Collection], Data] | None = None,
        pre_filter: Callable[[Data | Collection], bool] | None = None,
    ) -> list[Data]:
        """Process a chunk of data from the raw file. This method is intended to be used in the data loading pipeline to read a chunk of data, apply transformations, and filter the read data, and thus should not be called directly.

        Args:
            raw_file (h5py.File): The raw HDF5 file to read from.
            start (int): The starting index of the chunk.
            pre_transform (Callable[[Data], Data] | None, optional): Transformation that adds additional features to the data. Defaults to None.
            pre_filter (Callable[[Data], bool] | None, optional): A function that filters the data. Defaults to None.

        Returns:
            list[Data]: The processed data or None if the chunk is empty.
        """

        # we can't rely on being able to read from the raw_files in parallel, so we need to read the data sequentially first
        data = [
            self.data_reader(
                raw_file,
                i,
                self.float_type,
                self.int_type,
                self.validate_data,
            )
            for i in range(
                start, min(start + self.chunksize, raw_file["num_causal_sets"][()])
            )
        ]

        def process_item(item):
            if pre_filter is not None and not pre_filter(item):
                return None
            if pre_transform is not None:
                return pre_transform(item)
            return item

        results = []
        if self.n_processes > 1:
            results = Parallel(n_jobs=self.n_processes)(
                delayed(process_item)(datapoint) for datapoint in data
            )
        else:
            results = [process_item(datapoint) for datapoint in data]
        return [res for res in results if res is not None]

    def process_chunk_zarr(
        self,
        store: zarr.storage.LocalStore,
        start: int,
        pre_transform: Callable[[Data | Collection], Data] | None = None,
        pre_filter: Callable[[Data | Collection], bool] | None = None,
    ) -> list[Data]:
        """Process a chunk of data from the raw file. This method is intended to be used in the data loading pipeline to read a chunk of data, apply transformations, and filter the read data, and thus should not be called directly.

        Args:
            store (zarr.storage.LocalStore): local zarr storage
            start (int): start index
            pre_transform (Callable[[Data], Data] | None, optional): Transformation that adds additional features to the data. Defaults to None.
            pre_filter (Callable[[Data], bool] | None, optional): A function that filters the data. Defaults to None.

        Returns:
            list[Data]: The processed data or None if the chunk is empty.
        """
        N = self._get_num_samples_per_file(store.root)
        rootgroup = zarr.open_group(store.root)

        def process_item(i: int):
            item = self.data_reader(
                rootgroup,
                i,
                self.float_type,
                self.int_type,
                self.validate_data,
            )
            if pre_filter is not None and not pre_filter(item):
                return None
            if pre_transform is not None:
                return pre_transform(item)
            return item

        if self.n_processes > 1:
            results = Parallel(n_jobs=self.n_processes)(
                delayed(process_item)(i)
                for i in range(start, min(start + self.chunksize, N))
            )
        else:
            results = [
                process_item(i) for i in range(start, min(start + self.chunksize, N))
            ]

        return [res for res in results if res is not None]

    def process_chunk(
        self,
        store: zarr.storage.LocalStore | h5py.File,
        start: int,
        pre_transform: Callable[[Data], Data] | None = None,
        pre_filter: Callable[[Data], bool] | None = None,
    ):
        if self.mode == "hdf5":
            return self.process_chunk_hdf5(store, start, pre_transform, pre_filter)
        else:
            return self.process_chunk_zarr(store, start, pre_transform, pre_filter)

processed_dir property

Get the path to the processed directory.

Returns:

- str | None: The path to the processed directory, or None if it doesn't exist.

processed_file_names property

Get a list of processed files in the processed directory.

Returns:

- list[str]: A list of processed file paths, excluding JSON files.

raw_file_names property

Get the raw file paths from the input list.

Returns:

- list[str]: A list of raw file paths.

__init__(input, output, mode='hdf5', reader=None, float_type=torch.float32, int_type=torch.int64, validate_data=True, n_processes=1, chunksize=1000, **kwargs)

Initialize a DatasetMixin instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets. It provides a common interface for both in-memory and on-disk datasets. It is not to be instantiated directly, but rather used as a mixin for other dataset classes.

Parameters:

- input (list[str | Path]): The list of input files for the dataset, or a callable that generates a set of input files. Required.
- output (str | Path): The output directory where processed data will be stored. Required.
- mode (str, optional): File storage mode, 'zarr' or 'hdf5'. Defaults to 'hdf5'.
- reader (Callable[[h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]] | None, optional): A function to load data from a file. Defaults to None.
- float_type (torch.dtype, optional): The data type to use for floating point values. Defaults to torch.float32.
- int_type (torch.dtype, optional): The data type to use for integer values. Defaults to torch.int64.
- validate_data (bool, optional): Whether to validate the data after loading. Defaults to True.
- n_processes (int, optional): The number of processes to use for parallel processing of read data. Defaults to 1.
- chunksize (int, optional): The size of the chunks to process in parallel. Defaults to 1000.

Raises:

- ValueError: If one of the input data files is not a valid HDF5 file.
- ValueError: If the metadata retrieval function is invalid.
- FileNotFoundError: If an input file does not exist.

Source code in src/QuantumGrav/dataset_base.py
def __init__(
    self,
    input: list[str | Path],
    output: str | Path,
    mode: str = "hdf5",
    reader: Callable[
        [h5py.File | zarr.Group, torch.dtype, torch.dtype, bool],
        list[Data],
    ]
    | None = None,
    float_type: torch.dtype = torch.float32,
    int_type: torch.dtype = torch.int64,
    validate_data: bool = True,
    n_processes: int = 1,
    chunksize: int = 1000,
    **kwargs,
):
    """Initialize a DatasetMixin instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets. It provides a common interface for both in-memory and on-disk datasets. It is not to be instantiated directly, but rather used as a mixin for other dataset classes.

    Args:
        input (list[str | Path]): The list of input files for the dataset, or a callable that generates a set of input files.
        output (str | Path): The output directory where processed data will be stored.
        mode (str): File storage mode. 'zarr' or 'hdf5'
        reader (Callable[[h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]] | None, optional): A function to load data from a file. Defaults to None.
        float_type (torch.dtype, optional): The data type to use for floating point values. Defaults to torch.float32.
        int_type (torch.dtype, optional): The data type to use for integer values. Defaults to torch.int64.
        validate_data (bool, optional): Whether to validate the data after loading. Defaults to True.
        n_processes (int, optional): The number of processes to use for parallel processing of read data. Defaults to 1.
        chunksize (int, optional): The size of the chunks to process in parallel. Defaults to 1000.

    Raises:
        ValueError: If one of the input data files is not a valid HDF5 file
        ValueError: If the metadata retrieval function is invalid.
        FileNotFoundError: If an input file does not exist.
    """
    if reader is None:
        raise ValueError("A reader function must be provided to load the data.")

    if mode not in ["hdf5", "zarr"]:
        raise ValueError("mode must be 'hdf5' or 'zarr'")

    self.mode = mode
    self.input = input
    for file in self.input:
        if Path(file).exists() is False:
            raise FileNotFoundError(f"Input file {file} does not exist.")

    self.output = output
    self.data_reader = reader
    self.metadata = {}
    self.float_type = float_type
    self.int_type = int_type
    self.validate_data = validate_data
    self.n_processes = n_processes
    self.chunksize = chunksize

    # get the number of samples in the dataset
    self._num_samples = 0

    for filepath in self.input:
        if not Path(filepath).exists():
            raise FileNotFoundError(f"Input file {filepath} does not exist.")
        self._num_samples += self._get_num_samples_per_file(filepath)

    # ensure the input is a list of paths
    if Path(self.processed_dir).exists():
        with open(Path(self.processed_dir) / "metadata.yaml", "r") as f:
            self.metadata = yaml.load(f, Loader=yaml.FullLoader)
    else:
        Path(self.processed_dir).mkdir(parents=True, exist_ok=True)
        self.metadata = {
            "num_samples": int(self._num_samples),
            "input": [str(Path(f).resolve().absolute()) for f in self.input],
            "output": str(Path(self.output).resolve().absolute()),
            "float_type": str(self.float_type),
            "int_type": str(self.int_type),
            "validate_data": self.validate_data,
            "n_processes": self.n_processes,
            "chunksize": self.chunksize,
        }

        with open(Path(self.processed_dir) / "metadata.yaml", "w") as f:
            yaml.dump(self.metadata, f)
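
The reader callable is the user's hook for turning one stored sample into a torch_geometric.data.Data object. Note that the type hint above lists four arguments, while the call sites in process_chunk_hdf5 and process_chunk_zarr below pass five: the open file or group, the sample index, the float dtype, the int dtype, and the validation flag. A rough sketch of such a reader follows; the dataset names ("adjacency_matrix", "labels") are illustrative, not prescribed by the package.

import torch
from torch_geometric.data import Data

def my_reader(source, index, float_type, int_type, validate):
    # 'source' is an open h5py.File or zarr.Group
    adj = torch.as_tensor(source["adjacency_matrix"][index], dtype=float_type)
    edge_index = adj.nonzero().t().to(int_type)
    x = adj.sum(dim=1, keepdim=True).to(float_type)  # toy node feature: out-degree
    y = torch.as_tensor([source["labels"][index]], dtype=int_type)
    if validate and edge_index.numel() > 0:
        assert int(edge_index.max()) < x.size(0)
    return Data(x=x, edge_index=edge_index, y=y)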

process_chunk_hdf5(raw_file, start, pre_transform=None, pre_filter=None)

Process a chunk of data from the raw file. This method is intended to be used in the data loading pipeline to read a chunk of data, apply transformations, and filter the read data, and thus should not be called directly.

Parameters:

- raw_file (h5py.File): The raw HDF5 file to read from. Required.
- start (int): The starting index of the chunk. Required.
- pre_transform (Callable[[Data], Data] | None, optional): Transformation that adds additional features to the data. Defaults to None.
- pre_filter (Callable[[Data], bool] | None, optional): A function that filters the data. Defaults to None.

Returns:

- list[Data]: The processed data, or None if the chunk is empty.

Source code in src/QuantumGrav/dataset_base.py
def process_chunk_hdf5(
    self,
    raw_file: h5py.File,
    start: int,
    pre_transform: Callable[[Data | Collection], Data] | None = None,
    pre_filter: Callable[[Data | Collection], bool] | None = None,
) -> list[Data]:
    """Process a chunk of data from the raw file. This method is intended to be used in the data loading pipeline to read a chunk of data, apply transformations, and filter the read data, and thus should not be called directly.

    Args:
        raw_file (h5py.File): The raw HDF5 file to read from.
        start (int): The starting index of the chunk.
        pre_transform (Callable[[Data], Data] | None, optional): Transformation that adds additional features to the data. Defaults to None.
        pre_filter (Callable[[Data], bool] | None, optional): A function that filters the data. Defaults to None.

    Returns:
        list[Data]: The processed data or None if the chunk is empty.
    """

    # we can't rely on being able to read from the raw_files in parallel, so we need to read the data sequentially first
    data = [
        self.data_reader(
            raw_file,
            i,
            self.float_type,
            self.int_type,
            self.validate_data,
        )
        for i in range(
            start, min(start + self.chunksize, raw_file["num_causal_sets"][()])
        )
    ]

    def process_item(item):
        if pre_filter is not None and not pre_filter(item):
            return None
        if pre_transform is not None:
            return pre_transform(item)
        return item

    results = []
    if self.n_processes > 1:
        results = Parallel(n_jobs=self.n_processes)(
            delayed(process_item)(datapoint) for datapoint in data
        )
    else:
        results = [process_item(datapoint) for datapoint in data]
    return [res for res in results if res is not None]

process_chunk_zarr(store, start, pre_transform=None, pre_filter=None)

Process a chunk of data from the raw file. This method is intended to be used in the data loading pipeline to read a chunk of data, apply transformations, and filter the read data, and thus should not be called directly.

Parameters:

- store (zarr.storage.LocalStore): Local zarr storage. Required.
- start (int): Start index. Required.
- pre_transform (Callable[[Data], Data] | None, optional): Transformation that adds additional features to the data. Defaults to None.
- pre_filter (Callable[[Data], bool] | None, optional): A function that filters the data. Defaults to None.

Returns:

- list[Data]: The processed data, or None if the chunk is empty.

Source code in src/QuantumGrav/dataset_base.py
def process_chunk_zarr(
    self,
    store: zarr.storage.LocalStore,
    start: int,
    pre_transform: Callable[[Data | Collection], Data] | None = None,
    pre_filter: Callable[[Data | Collection], bool] | None = None,
) -> list[Data]:
    """Process a chunk of data from the raw file. This method is intended to be used in the data loading pipeline to read a chunk of data, apply transformations, and filter the read data, and thus should not be called directly.

    Args:
        store (zarr.storage.LocalStore): local zarr storage
        start (int): start index
        pre_transform (Callable[[Data], Data] | None, optional): Transformation that adds additional features to the data. Defaults to None.
        pre_filter (Callable[[Data], bool] | None, optional): A function that filters the data. Defaults to None.

    Returns:
        list[Data]: The processed data or None if the chunk is empty.
    """
    N = self._get_num_samples_per_file(store.root)
    rootgroup = zarr.open_group(store.root)

    def process_item(i: int):
        item = self.data_reader(
            rootgroup,
            i,
            self.float_type,
            self.int_type,
            self.validate_data,
        )
        if pre_filter is not None and not pre_filter(item):
            return None
        if pre_transform is not None:
            return pre_transform(item)
        return item

    if self.n_processes > 1:
        results = Parallel(n_jobs=self.n_processes)(
            delayed(process_item)(i)
            for i in range(start, min(start + self.chunksize, N))
        )
    else:
        results = [
            process_item(i) for i in range(start, min(start + self.chunksize, N))
        ]

    return [res for res in results if res is not None]
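
Both chunk processors accept optional pre_filter and pre_transform callables that are applied once per sample before the results are stored. A sketch of what they might look like (the node-count threshold and the extra degree feature are arbitrary examples):

import torch

def small_graphs_only(data):
    # keep only graphs with fewer than 512 nodes
    return data.num_nodes < 512

def add_degree_feature(data):
    # append the (in-)degree of each node as an extra feature column
    deg = torch.zeros(data.num_nodes, 1)
    deg.index_add_(0, data.edge_index[1], torch.ones(data.edge_index.size(1), 1))
    data.x = torch.cat([data.x, deg], dim=1)
    return data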

Dataset holding everything in memory

QGDatasetInMemory

Bases: QGDatasetBase, InMemoryDataset

A dataset class for QuantumGrav data that can be loaded into memory.

Source code in src/QuantumGrav/dataset_inmemory.py
class QGDatasetInMemory(QGDatasetBase, InMemoryDataset):
    """A dataset class for QuantumGrav data that can be loaded into memory."""

    def __init__(
        self,
        input: list[str | Path],
        output: str | Path,
        mode: str = "hdf5",
        reader: Callable[
            [h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]
        ]
        | None = None,
        float_type: torch.dtype = torch.float32,
        int_type: torch.dtype = torch.int64,
        validate_data: bool = True,
        chunksize: int = 1000,
        n_processes: int = 1,
        # dataset properties
        transform: Callable[[Data | Collection], Data] | None = None,
        pre_transform: Callable[[Data | Collection], Data] | None = None,
        pre_filter: Callable[[Data | Collection], bool] | None = None,
    ):
        """Initialize a QGDatasetInMemory instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets that can be loaded into memory completely.

        Args:
            input (list[str  |  Path]): A list of file paths (as strings or Path objects) to the input data files.
            output (str | Path): A file path (as a string or Path object) to the output data file.
            mode (str): File storage mode. 'zarr' or 'hdf5'
            reader (Callable[[h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]] | None, optional): A function to read the data from the input files. Defaults to None.
            float_type (torch.dtype, optional): Data type for float tensors. Defaults to torch.float32.
            int_type (torch.dtype, optional): Data type for int tensors. Defaults to torch.int64.
            validate_data (bool, optional): Whether to validate the data. Defaults to True.
            chunksize (int, optional): Size of data chunks to process at once. Defaults to 1000.
            n_processes (int, optional): Number of processes to use for data loading. Defaults to 1.
            transform (Callable[[Data], Data] | None, optional): Function to transform the data each time the data is loaded. Defaults to None.
            pre_transform (Callable[[Data], Data] | None, optional): Function to transform the read data once and store the results on the disk. Defaults to None.
            pre_filter (Callable[[Data], bool] | None, optional): Function to pre-filter the data once and store the results on the disk. Defaults to None.
        """
        QGDatasetBase.__init__(
            self,
            input,
            output,
            mode,
            reader=reader,
            float_type=float_type,
            int_type=int_type,
            validate_data=validate_data,
            chunksize=chunksize,
            n_processes=n_processes,
        )

        InMemoryDataset.__init__(
            self,
            root=output,
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
        )

        self.load(str(Path(self.processed_dir) / "data.pt"))

    def process(self) -> None:
        """Process the dataset from the read rawdata into its final form."""

        data_list = []

        for file in self.input:
            if self.mode == "hdf5":
                raw_file = h5py.File(str(Path(file).resolve().absolute()), "r")
                num_chunks = raw_file["num_causal_sets"][()] // self.chunksize

            else:
                raw_file = zarr.storage.LocalStore(
                    str(Path(file).resolve().absolute()), read_only=True
                )
                N = self._get_num_samples_per_file(Path(file).resolve().absolute())
                num_chunks = N // self.chunksize

            # read the data in chunks and process it parallelized or
            # sequentially based on the parallel_processing flag

            for i in range(0, num_chunks * self.chunksize, self.chunksize):
                data = self.process_chunk(
                    raw_file,
                    i,
                    pre_transform=self.pre_transform,
                    pre_filter=self.pre_filter,
                )

                data_list.extend(data)

            # final chunk processing
            data = self.process_chunk(
                raw_file,
                num_chunks * self.chunksize,
                pre_transform=self.pre_transform,
                pre_filter=self.pre_filter,
            )

            data_list.extend(data)

            raw_file.close()

        InMemoryDataset.save(data_list, Path(self.processed_dir) / "data.pt")

__init__(input, output, mode='hdf5', reader=None, float_type=torch.float32, int_type=torch.int64, validate_data=True, chunksize=1000, n_processes=1, transform=None, pre_transform=None, pre_filter=None)

Initialize a QGDatasetInMemory instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets that can be loaded into memory completely.

Parameters:

- input (list[str | Path]): A list of file paths (as strings or Path objects) to the input data files. Required.
- output (str | Path): A file path (as a string or Path object) to the output data file. Required.
- mode (str, optional): File storage mode, 'zarr' or 'hdf5'. Defaults to 'hdf5'.
- reader (Callable[[h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]] | None, optional): A function to read the data from the input files. Defaults to None.
- float_type (torch.dtype, optional): Data type for float tensors. Defaults to torch.float32.
- int_type (torch.dtype, optional): Data type for int tensors. Defaults to torch.int64.
- validate_data (bool, optional): Whether to validate the data. Defaults to True.
- chunksize (int, optional): Size of data chunks to process at once. Defaults to 1000.
- n_processes (int, optional): Number of processes to use for data loading. Defaults to 1.
- transform (Callable[[Data], Data] | None, optional): Function to transform the data each time the data is loaded. Defaults to None.
- pre_transform (Callable[[Data], Data] | None, optional): Function to transform the read data once and store the results on disk. Defaults to None.
- pre_filter (Callable[[Data], bool] | None, optional): Function to pre-filter the data once and store the results on disk. Defaults to None.
Source code in src/QuantumGrav/dataset_inmemory.py
def __init__(
    self,
    input: list[str | Path],
    output: str | Path,
    mode: str = "hdf5",
    reader: Callable[
        [h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]
    ]
    | None = None,
    float_type: torch.dtype = torch.float32,
    int_type: torch.dtype = torch.int64,
    validate_data: bool = True,
    chunksize: int = 1000,
    n_processes: int = 1,
    # dataset properties
    transform: Callable[[Data | Collection], Data] | None = None,
    pre_transform: Callable[[Data | Collection], Data] | None = None,
    pre_filter: Callable[[Data | Collection], bool] | None = None,
):
    """Initialize a QGDatasetInMemory instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets that can be loaded into memory completely.

    Args:
        input (list[str  |  Path]): A list of file paths (as strings or Path objects) to the input data files.
        output (str | Path): A file path (as a string or Path object) to the output data file.
        mode (str): File storage mode. 'zarr' or 'hdf5'
        reader (Callable[[h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]] | None, optional): A function to read the data from the input files. Defaults to None.
        float_type (torch.dtype, optional): Data type for float tensors. Defaults to torch.float32.
        int_type (torch.dtype, optional): Data type for int tensors. Defaults to torch.int64.
        validate_data (bool, optional): Whether to validate the data. Defaults to True.
        chunksize (int, optional): Size of data chunks to process at once. Defaults to 1000.
        n_processes (int, optional): Number of processes to use for data loading. Defaults to 1.
        transform (Callable[[Data], Data] | None, optional): Function to transform the data each time the data is loaded. Defaults to None.
        pre_transform (Callable[[Data], Data] | None, optional): Function to transform the read data once and store the results on the disk. Defaults to None.
        pre_filter (Callable[[Data], bool] | None, optional): Function to pre-filter the data once and store the results on the disk. Defaults to None.
    """
    QGDatasetBase.__init__(
        self,
        input,
        output,
        mode,
        reader=reader,
        float_type=float_type,
        int_type=int_type,
        validate_data=validate_data,
        chunksize=chunksize,
        n_processes=n_processes,
    )

    InMemoryDataset.__init__(
        self,
        root=output,
        transform=transform,
        pre_transform=pre_transform,
        pre_filter=pre_filter,
    )

    self.load(str(Path(self.processed_dir) / "data.pt"))

process()

Process the dataset from the read rawdata into its final form.

Source code in src/QuantumGrav/dataset_inmemory.py
def process(self) -> None:
    """Process the dataset from the read rawdata into its final form."""

    data_list = []

    for file in self.input:
        if self.mode == "hdf5":
            raw_file = h5py.File(str(Path(file).resolve().absolute()), "r")
            num_chunks = raw_file["num_causal_sets"][()] // self.chunksize

        else:
            raw_file = zarr.storage.LocalStore(
                str(Path(file).resolve().absolute()), read_only=True
            )
            N = self._get_num_samples_per_file(Path(file).resolve().absolute())
            num_chunks = N // self.chunksize

        # read the data in chunks and process it parallelized or
        # sequentially based on the parallel_processing flag

        for i in range(0, num_chunks * self.chunksize, self.chunksize):
            data = self.process_chunk(
                raw_file,
                i,
                pre_transform=self.pre_transform,
                pre_filter=self.pre_filter,
            )

            data_list.extend(data)

        # final chunk processing
        data = self.process_chunk(
            raw_file,
            num_chunks * self.chunksize,
            pre_transform=self.pre_transform,
            pre_filter=self.pre_filter,
        )

        data_list.extend(data)

        raw_file.close()

    InMemoryDataset.save(data_list, Path(self.processed_dir) / "data.pt")
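
A construction sketch tying the pieces together; the file names, the reader and the pre-processing callables from the sketches above are placeholders.

from torch_geometric.loader import DataLoader

dataset = QGDatasetInMemory(
    input=["causal_sets_0.h5", "causal_sets_1.h5"],  # hypothetical raw files
    output="processed_dataset",
    mode="hdf5",
    reader=my_reader,                 # see the reader sketch above
    pre_filter=small_graphs_only,
    pre_transform=add_degree_feature,
    n_processes=4,
    chunksize=500,
)
loader = DataLoader(dataset, batch_size=32, shuffle=True)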

Dataset creating csets on the fly

QGDatasetOnthefly

Bases: Dataset

A dataset that generates data on the fly using a Julia function.

Parameters:

- Dataset (Dataset): The base dataset class.
Source code in src/QuantumGrav/dataset_onthefly.py
class QGDatasetOnthefly(Dataset):
    """A dataset that generates data on the fly using a Julia function.

    Args:
        Dataset (Dataset): The base dataset class.
    """

    def __init__(
        self,
        config: dict[str, Any],
        jl_code_path: str | Path | None = None,
        jl_constructor_name: str | None = None,
        jl_base_module_path: str | Path | None = None,
        jl_dependencies: list[str] | None = None,
        transform: Callable[[dict[Any, Any]], Data] | None = None,
        converter: Callable[[Any], Any] | None = None,
    ):
        """Initialize the dataset. This will initialize the Julia worker and set up the dataset. The julia worker must be callable with a single argument which is the number of samples to generate. The worker will return a list of raw data dictionaries which will be transformed into PyTorch Geometric Data objects using the provided transform function.

        Args:
            config (dict[str, Any]): Configuration dictionary.
            jl_code_path (str | Path | None, optional): Path to the Julia code. Defaults to None.
            jl_constructor_name (str | None, optional): Name of the Julia constructor. Defaults to None.
            jl_base_module_path (str | Path | None, optional): Path to the base Julia module. Defaults to None.
            jl_dependencies (list[str] | None, optional): List of Julia dependencies. Defaults to None.
            transform (Callable[[dict[Any, Any]], Data] | None, optional): Function to transform raw data into PyTorch Geometric Data objects. Defaults to None.
            converter (Callable[[Any], Any] | None, optional): Function to convert Julia objects into standard Python objects. Defaults to None.

        Raises:
            ValueError: If the transform function is not provided.
            ValueError: If the converter function is not provided.
            RuntimeError: If there is an error initializing the Julia process.
        """
        if transform is None:
            raise ValueError(
                "Transform function must be provided to turn raw data dictionaries into PyTorch Geometric Data objects."
            )
        self.transform = transform

        if converter is None:
            raise ValueError(
                "Converter function must be provided to convert Julia objects into standard Python objects."
            )
        else:
            self.converter = converter

        self.config = config
        self.databatch: list[Data] = []  # hold a batch of generated data

        try:
            self.worker = jl_worker.JuliaWorker(
                config,
                jl_code_path,
                jl_constructor_name,
                jl_base_module_path,
                jl_dependencies,
            )
        except jcall.JuliaError as e:
            raise RuntimeError(f"Error initializing Julia worker: {e}") from e
        except (OSError, FileNotFoundError) as e:
            raise RuntimeError(f"Path to file or directory not found: {e}") from e
        except (KeyError, TypeError) as e:
            raise RuntimeError(f"Invalid configuration for Julia worker: {e}") from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected exception while initializing Julia worker: {e}"
            ) from e

        super().__init__(None, transform=transform, pre_transform=None, pre_filter=None)

    def len(self) -> int:
        """Return the length of the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return sys.maxsize

    def make_batch(self, size: int) -> list[Data]:
        """Create a batch of data.

        Args:
            size (int): The number of samples in the batch.

        Returns:
            list[Data]: The list of Data objects in the batch.
        """
        return [self.get(i) for i in range(size)]

    def get(self, _: int) -> Data:
        """Get a data point from the dataset. This relies on the Julia worker accepting a single integer argument to its call operator, which is the number of samples to generate. The index is ignored since the dataset generates data on the fly.

        Args:
            _ (int): The index of the data point to retrieve. This is ignored since the dataset generates data on the fly.

        Raises:
            RuntimeError: If the worker process is not initialized.
            RuntimeError: If there is an error retrieving raw data from the worker.
            RuntimeError: If there is an error transforming the data.

        Returns:
            Data: The transformed data point.
        """
        # this breaks the contract of the 'get' method that the base class provides, but
        # it nevertheless is useful to generate training data
        if self.worker is None:
            raise RuntimeError("Worker attribute is not initialized.")

        if len(self.databatch) == 0:
            # Call the Julia function to get the data
            try:
                raw_data = [
                    self.converter(x) for x in self.worker(self.config["batch_size"])
                ]
            except jcall.JuliaError as e:
                raise RuntimeError(f"Julia worker failed to generate data: {e}") from e
            except (KeyError, IndexError) as e:
                raise RuntimeError(f"Invalid configuration or data access: {e}") from e

            if isinstance(raw_data, Exception):
                raise RuntimeError(
                    "Unexpected error in data generation or conversion"
                ) from raw_data
            try:
                # parallel processing in Julia is handled on the Julia side
                # use primitive indexing here to avoid issues with Julia arrays
                if self.config["n_processes"] > 1:
                    self.databatch = Parallel(
                        n_jobs=self.config["n_processes"],
                    )(
                        delayed(self.transform)(raw_data[i])
                        for i in range(len(raw_data))
                    )
                else:
                    self.databatch = [
                        self.transform(raw_data[i]) for i in range(len(raw_data))
                    ]
            except (KeyError, TypeError, ValueError) as e:
                raise RuntimeError(f"Data transformation failed: {e}") from e
            except Exception as e:
                raise RuntimeError(f"Unexpected error transforming data: {e}") from e

        datapoint = self.databatch.pop()

        return datapoint

__init__(config, jl_code_path=None, jl_constructor_name=None, jl_base_module_path=None, jl_dependencies=None, transform=None, converter=None)

Initialize the dataset. This will initialize the Julia worker and set up the dataset. The julia worker must be callable with a single argument which is the number of samples to generate. The worker will return a list of raw data dictionaries which will be transformed into PyTorch Geometric Data objects using the provided transform function.

Parameters:

    config (dict[str, Any]): Configuration dictionary. (required)
    jl_code_path (str | Path | None): Path to the Julia code. Defaults to None.
    jl_constructor_name (str | None): Name of the Julia constructor. Defaults to None.
    jl_base_module_path (str | Path | None): Path to the base Julia module. Defaults to None.
    jl_dependencies (list[str] | None): List of Julia dependencies. Defaults to None.
    transform (Callable[[dict[Any, Any]], Data] | None): Function to transform raw data into PyTorch Geometric Data objects. Defaults to None.
    converter (Callable[[Any], Any] | None): Function to convert Julia objects into standard Python objects. Defaults to None.

Raises:

    ValueError: If the transform function is not provided.
    ValueError: If the converter function is not provided.
    RuntimeError: If there is an error initializing the Julia process.

Source code in src/QuantumGrav/dataset_onthefly.py
def __init__(
    self,
    config: dict[str, Any],
    jl_code_path: str | Path | None = None,
    jl_constructor_name: str | None = None,
    jl_base_module_path: str | Path | None = None,
    jl_dependencies: list[str] | None = None,
    transform: Callable[[dict[Any, Any]], Data] | None = None,
    converter: Callable[[Any], Any] | None = None,
):
    """Initialize the dataset. This will initialize the Julia worker and set up the dataset. The julia worker must be callable with a single argument which is the number of samples to generate. The worker will return a list of raw data dictionaries which will be transformed into PyTorch Geometric Data objects using the provided transform function.

    Args:
        config (dict[str, Any]): Configuration dictionary.
        jl_code_path (str | Path | None, optional): Path to the Julia code. Defaults to None.
        jl_constructor_name (str | None, optional): Name of the Julia constructor. Defaults to None.
        jl_base_module_path (str | Path | None, optional): Path to the base Julia module. Defaults to None.
        jl_dependencies (list[str] | None, optional): List of Julia dependencies. Defaults to None.
        transform (Callable[[dict[Any, Any]], Data] | None, optional): Function to transform raw data into PyTorch Geometric Data objects. Defaults to None.
        converter (Callable[[Any], Any] | None, optional): Function to convert Julia objects into standard Python objects. Defaults to None.

    Raises:
        ValueError: If the transform function is not provided.
        ValueError: If the converter function is not provided.
        RuntimeError: If there is an error initializing the Julia process.
    """
    if transform is None:
        raise ValueError(
            "Transform function must be provided to turn raw data dictionaries into PyTorch Geometric Data objects."
        )
    self.transform = transform

    if converter is None:
        raise ValueError(
            "Converter function must be provided to convert Julia objects into standard Python objects."
        )
    else:
        self.converter = converter

    self.config = config
    self.databatch: list[Data] = []  # hold a batch of generated data

    try:
        self.worker = jl_worker.JuliaWorker(
            config,
            jl_code_path,
            jl_constructor_name,
            jl_base_module_path,
            jl_dependencies,
        )
    except jcall.JuliaError as e:
        raise RuntimeError(f"Error initializing Julia worker: {e}") from e
    except (OSError, FileNotFoundError) as e:
        raise RuntimeError(f"Path to file or directory not found: {e}") from e
    except (KeyError, TypeError) as e:
        raise RuntimeError(f"Invalid configuration for Julia worker: {e}") from e
    except Exception as e:
        raise RuntimeError(
            f"Unexpected exception while initializing Julia worker: {e}"
        ) from e

    super().__init__(None, transform=transform, pre_transform=None, pre_filter=None)

get(_)

Get a data point from the dataset. This relies on the Julia worker accepting a single integer argument to its call operator, which is the number of samples to generate. The index is ignored since the dataset generates data on the fly.

Parameters:

    _ (int): The index of the data point to retrieve. This is ignored since the dataset generates data on the fly. (required)

Raises:

    RuntimeError: If the worker process is not initialized.
    RuntimeError: If there is an error retrieving raw data from the worker.
    RuntimeError: If there is an error transforming the data.

Returns:

    Data: The transformed data point.

Source code in src/QuantumGrav/dataset_onthefly.py
def get(self, _: int) -> Data:
    """Get a data point from the dataset. This relies on the Julia worker accepting a single integer argument to its call operator, which is the number of samples to generate. The index is ignored since the dataset generates data on the fly.

    Args:
        _ (int): The index of the data point to retrieve. This is ignored since the dataset generates data on the fly.

    Raises:
        RuntimeError: If the worker process is not initialized.
        RuntimeError: If there is an error retrieving raw data from the worker.
        RuntimeError: If there is an error transforming the data.

    Returns:
        Data: The transformed data point.
    """
    # this breaks the contract of the 'get' method that the base class provides, but
    # it nevertheless is useful to generate training data
    if self.worker is None:
        raise RuntimeError("Worker attribute is not initialized.")

    if len(self.databatch) == 0:
        # Call the Julia function to get the data
        try:
            raw_data = [
                self.converter(x) for x in self.worker(self.config["batch_size"])
            ]
        except jcall.JuliaError as e:
            raise RuntimeError(f"Julia worker failed to generate data: {e}") from e
        except (KeyError, IndexError) as e:
            raise RuntimeError(f"Invalid configuration or data access: {e}") from e

        if isinstance(raw_data, Exception):
            raise RuntimeError(
                "Unexpected error in data generation or conversion"
            ) from raw_data
        try:
            # parallel processing in Julia is handled on the Julia side
            # use primitive indexing here to avoid issues with Julia arrays
            if self.config["n_processes"] > 1:
                self.databatch = Parallel(
                    n_jobs=self.config["n_processes"],
                )(
                    delayed(self.transform)(raw_data[i])
                    for i in range(len(raw_data))
                )
            else:
                self.databatch = [
                    self.transform(raw_data[i]) for i in range(len(raw_data))
                ]
        except (KeyError, TypeError, ValueError) as e:
            raise RuntimeError(f"Data transformation failed: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Unexpected error transforming data: {e}") from e

    datapoint = self.databatch.pop()

    return datapoint

len()

Return the length of the dataset.

Returns:

    int: The number of samples in the dataset.

Source code in src/QuantumGrav/dataset_onthefly.py
def len(self) -> int:
    """Return the length of the dataset.

    Returns:
        int: The number of samples in the dataset.
    """
    return sys.maxsize

make_batch(size)

Create a batch of data.

Parameters:

    size (int): The number of samples in the batch. (required)

Returns:

    list[Data]: The list of Data objects in the batch.

Source code in src/QuantumGrav/dataset_onthefly.py
def make_batch(self, size: int) -> list[Data]:
    """Create a batch of data.

    Args:
        size (int): The number of samples in the batch.

    Returns:
        list[Data]: The list of Data objects in the batch.
    """
    return [self.get(i) for i in range(size)]
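
A minimal usage sketch for the on-the-fly dataset. The import path, Julia file, constructor name, and raw-data keys ("node_features", "edge_index") are assumptions for illustration; only the config keys "batch_size" and "n_processes" are actually read by get().

import torch
from torch_geometric.data import Data
from QuantumGrav.dataset_onthefly import QGDatasetOnthefly  # assumed import path


def converter(jl_obj):
    # hypothetical: turn a Julia dictionary into a plain Python dict
    return dict(jl_obj)


def transform(raw: dict) -> Data:
    # hypothetical raw-data keys; adapt to your Julia generator's output
    x = torch.tensor(raw["node_features"], dtype=torch.float32)
    edge_index = torch.tensor(raw["edge_index"], dtype=torch.int64)
    return Data(x=x, edge_index=edge_index)


dataset = QGDatasetOnthefly(
    config={"batch_size": 32, "n_processes": 1},
    jl_code_path="generators/causal_set_generator.jl",  # hypothetical file
    jl_constructor_name="make_generator",               # hypothetical name
    transform=transform,
    converter=converter,
)
batch = dataset.make_batch(8)  # eight freshly generated Data objects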

Dataset loading data from disk

QGDataset

Bases: QGDatasetBase, Dataset

A dataset class for QuantumGrav data that is designed to handle large datasets stored on disk. This class provides methods for loading, processing, and writing data that are common to both in-memory and on-disk datasets.

Source code in src/QuantumGrav/dataset_ondisk.py
class QGDataset(QGDatasetBase, Dataset):
    """A dataset class for QuantumGrav data that is designed to handle large datasets stored on disk. This class provides methods for loading, processing, and writing data that are common to both in-memory and on-disk datasets."""

    def __init__(
        self,
        input: list[str | Path],
        output: str | Path,
        mode: str = "hdf5",
        reader: Callable[
            [h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]
        ]
        | None = None,
        float_type: torch.dtype = torch.float32,
        int_type: torch.dtype = torch.int64,
        validate_data: bool = True,
        chunksize: int = 1000,
        n_processes: int = 1,
        # dataset properties
        transform: Callable[[Data | Collection], Data] | None = None,
        pre_transform: Callable[[Data | Collection], Data] | None = None,
        pre_filter: Callable[[Data | Collection], bool] | None = None,
    ):
        """Create a new QGDataset instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets that are stored on disk.

        Args:
            input (list[str  |  Path] | Callable[[Any], dict]): List of input hdf5 file paths.
            output (str | Path): Output directory where processed data will be stored.
            mode (str): File storage mode. 'zarr' or 'hdf5'
            reader (Callable[[h5py.File | zarr.Group, int], list[Data]] | None, optional): Function to read data from the hdf5 file. Defaults to None.
            float_type (torch.dtype, optional): Data type for float tensors. Defaults to torch.float32.
            int_type (torch.dtype, optional): Data type for int tensors. Defaults to torch.int64.
            validate_data (bool, optional): Whether to validate the data. Defaults to True.
            chunksize (int, optional): Size of data chunks to process at once. Defaults to 1000.
            n_processes (int, optional): Number of processes to use for data loading. Defaults to 1.
            transform (Callable[[Data], Data] | None, optional): Function to transform the data. Defaults to None.
            pre_transform (Callable[[Data], Data] | None, optional): Function to pre-transform the data. Defaults to None.
            pre_filter (Callable[[Data], bool] | None, optional): Function to pre-filter the data. Defaults to None.
        """

        QGDatasetBase.__init__(
            self,
            input,
            output,
            mode=mode,
            reader=reader,
            float_type=float_type,
            int_type=int_type,
            validate_data=validate_data,
            chunksize=chunksize,
            n_processes=n_processes,
        )

        Dataset.__init__(
            self,
            root=output,
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
        )

    def write_data(self, data: list[Data], idx: int) -> int:
        """Write the processed data to disk using `torch.save`. This is a default implementation that can be overridden by subclasses, and is intended to be used in the data loading pipeline. Thus, is not intended to be called directly.

        Args:
            data (list[Data]): The list of Data objects to write to disk.
            idx (int): The index to use for naming the files.
        """
        if not Path(self.processed_dir).exists():
            Path(self.processed_dir).mkdir(parents=True, exist_ok=True)

        for d in data:
            if d is not None:
                file_path = Path(self.processed_dir) / f"data_{idx}.pt"
                torch.save(d, file_path)
                idx += 1
        return idx

    def process(self) -> None:
        """Process the dataset from the read rawdata into its final form."""
        # process data files
        k = 0  # index to create the filenames for the processed data
        for file in self.input:
            if self.mode == "hdf5":
                raw_file = h5py.File(str(Path(file).resolve().absolute()), "r")
                num_chunks = raw_file["num_causal_sets"][()] // self.chunksize

            else:
                N = self._get_num_samples_per_file(Path(file).resolve().absolute())
                num_chunks = N // self.chunksize
                raw_file = zarr.storage.LocalStore(
                    str(Path(file).resolve().absolute()), read_only=True
                )

            for i in range(0, num_chunks * self.chunksize, self.chunksize):
                data = self.process_chunk(
                    raw_file,
                    i,
                    pre_transform=self.pre_transform,
                    pre_filter=self.pre_filter,
                )

                k = self.write_data(data, k)

            # final chunk processing
            data = self.process_chunk(
                raw_file,
                num_chunks * self.chunksize,
                pre_transform=self.pre_transform,
                pre_filter=self.pre_filter,
            )

            k = self.write_data(data, k)

            raw_file.close()

    def get(self, idx: int) -> Data:
        """Get a single data sample by index."""
        if self._num_samples is None:
            raise ValueError("Dataset has not been processed yet.")

        if idx < 0 or idx >= self._num_samples:
            raise IndexError("Index out of bounds.")
        # Load the data from the processed files
        datapoint = torch.load(
            Path(self.processed_dir) / f"data_{idx}.pt", weights_only=False
        )
        if self.transform is not None:
            datapoint = self.transform(datapoint)
        return datapoint

    def __getitem__(self, idx: int | Collection[int]) -> Data | Collection[Data]:
        if isinstance(idx, int):
            return self.get(idx)
        else:
            return [self.get(i) for i in idx]

    def len(self) -> int:
        """Get the number of samples in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.processed_file_names)

__init__(input, output, mode='hdf5', reader=None, float_type=torch.float32, int_type=torch.int64, validate_data=True, chunksize=1000, n_processes=1, transform=None, pre_transform=None, pre_filter=None)

Create a new QGDataset instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets that are stored on disk.

Parameters:

    input (list[str | Path] | Callable[[Any], dict]): List of input hdf5 file paths. (required)
    output (str | Path): Output directory where processed data will be stored. (required)
    mode (str): File storage mode, 'zarr' or 'hdf5'. Defaults to 'hdf5'.
    reader (Callable[[File | Group, int], list[Data]] | None): Function to read data from the hdf5 file. Defaults to None.
    float_type (dtype): Data type for float tensors. Defaults to torch.float32.
    int_type (dtype): Data type for int tensors. Defaults to torch.int64.
    validate_data (bool): Whether to validate the data. Defaults to True.
    chunksize (int): Size of data chunks to process at once. Defaults to 1000.
    n_processes (int): Number of processes to use for data loading. Defaults to 1.
    transform (Callable[[Data], Data] | None): Function to transform the data. Defaults to None.
    pre_transform (Callable[[Data], Data] | None): Function to pre-transform the data. Defaults to None.
    pre_filter (Callable[[Data], bool] | None): Function to pre-filter the data. Defaults to None.
Source code in src/QuantumGrav/dataset_ondisk.py
def __init__(
    self,
    input: list[str | Path],
    output: str | Path,
    mode: str = "hdf5",
    reader: Callable[
        [h5py.File | zarr.Group, torch.dtype, torch.dtype, bool], list[Data]
    ]
    | None = None,
    float_type: torch.dtype = torch.float32,
    int_type: torch.dtype = torch.int64,
    validate_data: bool = True,
    chunksize: int = 1000,
    n_processes: int = 1,
    # dataset properties
    transform: Callable[[Data | Collection], Data] | None = None,
    pre_transform: Callable[[Data | Collection], Data] | None = None,
    pre_filter: Callable[[Data | Collection], bool] | None = None,
):
    """Create a new QGDataset instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets that are stored on disk.

    Args:
        input (list[str  |  Path] | Callable[[Any], dict]): List of input hdf5 file paths.
        output (str | Path): Output directory where processed data will be stored.
        mode (str): File storage mode. 'zarr' or 'hdf5'
        reader (Callable[[h5py.File | zarr.Group, int], list[Data]] | None, optional): Function to read data from the hdf5 file. Defaults to None.
        float_type (torch.dtype, optional): Data type for float tensors. Defaults to torch.float32.
        int_type (torch.dtype, optional): Data type for int tensors. Defaults to torch.int64.
        validate_data (bool, optional): Whether to validate the data. Defaults to True.
        chunksize (int, optional): Size of data chunks to process at once. Defaults to 1000.
        n_processes (int, optional): Number of processes to use for data loading. Defaults to 1.
        transform (Callable[[Data], Data] | None, optional): Function to transform the data. Defaults to None.
        pre_transform (Callable[[Data], Data] | None, optional): Function to pre-transform the data. Defaults to None.
        pre_filter (Callable[[Data], bool] | None, optional): Function to pre-filter the data. Defaults to None.
    """

    QGDatasetBase.__init__(
        self,
        input,
        output,
        mode=mode,
        reader=reader,
        float_type=float_type,
        int_type=int_type,
        validate_data=validate_data,
        chunksize=chunksize,
        n_processes=n_processes,
    )

    Dataset.__init__(
        self,
        root=output,
        transform=transform,
        pre_transform=pre_transform,
        pre_filter=pre_filter,
    )

get(idx)

Get a single data sample by index.

Source code in src/QuantumGrav/dataset_ondisk.py
def get(self, idx: int) -> Data:
    """Get a single data sample by index."""
    if self._num_samples is None:
        raise ValueError("Dataset has not been processed yet.")

    if idx < 0 or idx >= self._num_samples:
        raise IndexError("Index out of bounds.")
    # Load the data from the processed files
    datapoint = torch.load(
        Path(self.processed_dir) / f"data_{idx}.pt", weights_only=False
    )
    if self.transform is not None:
        datapoint = self.transform(datapoint)
    return datapoint

len()

Get the number of samples in the dataset.

Returns:

    int: The number of samples in the dataset.

Source code in src/QuantumGrav/dataset_ondisk.py
def len(self) -> int:
    """Get the number of samples in the dataset.

    Returns:
        int: The number of samples in the dataset.
    """
    return len(self.processed_file_names)

process()

Process the dataset from the read raw data into its final form.

Source code in src/QuantumGrav/dataset_ondisk.py
def process(self) -> None:
    """Process the dataset from the read rawdata into its final form."""
    # process data files
    k = 0  # index to create the filenames for the processed data
    for file in self.input:
        if self.mode == "hdf5":
            raw_file = h5py.File(str(Path(file).resolve().absolute()), "r")
            num_chunks = raw_file["num_causal_sets"][()] // self.chunksize

        else:
            N = self._get_num_samples_per_file(Path(file).resolve().absolute())
            num_chunks = N // self.chunksize
            raw_file = zarr.storage.LocalStore(
                str(Path(file).resolve().absolute()), read_only=True
            )

        for i in range(0, num_chunks * self.chunksize, self.chunksize):
            data = self.process_chunk(
                raw_file,
                i,
                pre_transform=self.pre_transform,
                pre_filter=self.pre_filter,
            )

            k = self.write_data(data, k)

        # final chunk processing
        data = self.process_chunk(
            raw_file,
            num_chunks * self.chunksize,
            pre_transform=self.pre_transform,
            pre_filter=self.pre_filter,
        )

        k = self.write_data(data, k)

        raw_file.close()

write_data(data, idx)

Write the processed data to disk using torch.save. This is a default implementation that can be overridden by subclasses, and is intended to be used in the data loading pipeline. Thus, it is not intended to be called directly.

Parameters:

    data (list[Data]): The list of Data objects to write to disk. (required)
    idx (int): The index to use for naming the files. (required)
Source code in src/QuantumGrav/dataset_ondisk.py
def write_data(self, data: list[Data], idx: int) -> int:
    """Write the processed data to disk using `torch.save`. This is a default implementation that can be overridden by subclasses, and is intended to be used in the data loading pipeline. Thus, is not intended to be called directly.

    Args:
        data (list[Data]): The list of Data objects to write to disk.
        idx (int): The index to use for naming the files.
    """
    if not Path(self.processed_dir).exists():
        Path(self.processed_dir).mkdir(parents=True, exist_ok=True)

    for d in data:
        if d is not None:
            file_path = Path(self.processed_dir) / f"data_{idx}.pt"
            torch.save(d, file_path)
            idx += 1
    return idx
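
A minimal usage sketch for the on-disk dataset. The import path and file names are assumptions for illustration; processing runs through the torch_geometric Dataset machinery when the instance is created, after which samples are loaded lazily from the data_<idx>.pt files.

from QuantumGrav.dataset_ondisk import QGDataset  # assumed import path

dataset = QGDataset(
    input=["causal_sets_0.h5", "causal_sets_1.h5"],  # hypothetical input files
    output="qg_dataset",                             # root directory for processed data
    mode="hdf5",
    chunksize=1000,
)

print(dataset.len())  # number of processed data_*.pt files
sample = dataset[0]   # loads <output>/processed/data_0.pt and applies transform, if any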

Julia-Python integration

This class provides a bridge to some user-supplied Julia code and converts its output into something Python can work with.

JuliaWorker

This class runs a given Julia callable object from a given Julia code file. It additionally imports the QuantumGrav Julia module and installs the given dependencies if provided. After creation, the wrapped Julia callable can be invoked via the __call__ method of this class. Warning: This class requires the juliacall package to be installed in the Python environment. Warning: This class is in early development and may change in the future, be slow, or otherwise not be ready for high-performance production use.

Source code in src/QuantumGrav/julia_worker.py
class JuliaWorker:
    """This class runs a given Julia callable object from a given Julia code file. It additionally imports the QuantumGrav julia module and installs given dependencies if provided. After creation, the wrapped julia callable can be called via the __call__ method of this calls.
    **Warning**: This class requires the juliacall package to be installed in the Python environment.
    **Warning**: This class is in early development and may change in the future, be slow, or otherwise not ready for high performance production use.
    """

    jl_constructor_name = None

    def __init__(
        self,
        jl_kwargs: dict[str, Any] | None = None,
        jl_code_path: str | Path | None = None,
        jl_constructor_name: str | None = None,
        jl_base_module_path: str | Path | None = None,
        jl_dependencies: list[str] | None = None,
    ):
        """Initializes the JuliaWorker with the given parameters.

        Args:
            jl_kwargs (dict[str, Any] | None, optional): Keyword arguments to pass to the Julia callable object constructor. Defaults to None.
            jl_code_path (str | Path | None, optional): Path to the Julia code file in which the callable object is defined. Defaults to None.
            jl_constructor_name (str | None, optional): Name of the Julia constructor function. Defaults to None.
            jl_base_module_path (str | Path | None, optional): Path to the base Julia module 'QuantumGrav.jl'. If not given, tries to load it via a default `using QuantumGrav` import. Defaults to None.
            jl_dependencies (list[str] | None, optional): List of Julia package dependencies. Defaults to None. Will be installed via `Pkg.add` if provided upon first call.

        Raises:
            ValueError: If the Julia function name is not provided.
            ValueError: If the Julia code path is not provided.
            FileNotFoundError: If the Julia code path does not exist.
            NotImplementedError: If the base module path is not provided.
            RuntimeError: If there is an error loading the base module.
            RuntimeError: If there is an error loading Julia dependencies.
            RuntimeError: If there is an error loading the Julia code.
        """

        # we test for a bunch of needed args first
        if jl_constructor_name is None:
            raise ValueError("Julia function name must be provided.")

        if jl_code_path is None:
            raise ValueError("Julia code path must be provided.")

        jl_code_path = Path(jl_code_path).resolve().absolute()
        if not jl_code_path.exists():
            raise FileNotFoundError(f"Julia code path {jl_code_path} does not exist.")

        self.jl_constructor_name = jl_constructor_name
        self.jl_module_name = "QuantumGravPy2Jl"  # the module name is hardcoded here

        # try to initialize the new Julia module, then later do every julia call through this module
        try:
            self.jl_module = jcall.newmodule(self.jl_module_name)

        except jcall.JuliaError as e:
            raise RuntimeError(f"Error creating new julia module: {e}") from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected exception while creating Julia module {self.jl_module_name}: {e}"
            ) from e

        # add base module for dependencies if exists
        if jl_base_module_path is not None:
            jl_base_module_path = Path(jl_base_module_path).resolve().absolute()
            try:
                self.jl_module.seval(
                    f'using Pkg; Pkg.develop(path="{str(jl_base_module_path)}")'
                )  # only for now -> get from package index later
            except jcall.JuliaError as e:
                raise RuntimeError(
                    f"Error loading base module {str(jl_base_module_path)}: {e}"
                ) from e
            except Exception as e:
                raise RuntimeError(
                    f"Unexpected exception while initializing julia base module: {e}"
                ) from e

        try:
            # add dependencies if provided
            if jl_dependencies is not None:
                for dep in jl_dependencies:
                    self.jl_module.seval(f'using Pkg; Pkg.add("{dep}")')
        except jcall.JuliaError as e:
            raise RuntimeError(f"Error processing Julia dependencies: {e}") from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected exception while processing Julia dependencies: {e}"
            ) from e

        try:
            # load the Julia data generation code
            self.jl_module.seval(f'push!(LOAD_PATH, "{jl_code_path}")')
            self.jl_module.seval("using QuantumGrav")
            self.jl_module.seval(f'include("{jl_code_path}")')
            constructor_name = getattr(self.jl_module, jl_constructor_name)
            self.jl_generator = constructor_name(jl_kwargs)
        except jcall.JuliaError as e:
            raise RuntimeError(
                f"Error evaluating Julia code to activate base module: {e}"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected exception while loading Julia base module: {e}"
            ) from e

    def __call__(self, *args, **kwargs) -> Any:
        """Calls the wrapped Julia generator with the given arguments.

        Raises:
            RuntimeError: If the Julia module is not initialized.
        Args:
            *args: Positional arguments to pass to the Julia generator.
            **kwargs: Keyword arguments to pass to the Julia generator.
        Returns:
            Any: The raw data generated by the Julia generator.
        """
        if self.jl_module is None:
            raise RuntimeError("Julia module is not initialized.")
        raw_data = self.jl_generator(*args, **kwargs)
        return raw_data

__call__(*args, **kwargs)

Calls the wrapped Julia generator with the given arguments.

Raises:

    RuntimeError: If the Julia module is not initialized.

Parameters:

    *args: Positional arguments to pass to the Julia generator.
    **kwargs: Keyword arguments to pass to the Julia generator.

Returns:

    Any: The raw data generated by the Julia generator.

Source code in src/QuantumGrav/julia_worker.py
def __call__(self, *args, **kwargs) -> Any:
    """Calls the wrapped Julia generator with the given arguments.

    Raises:
        RuntimeError: If the Julia module is not initialized.
    Args:
        *args: Positional arguments to pass to the Julia generator.
        **kwargs: Keyword arguments to pass to the Julia generator.
    Returns:
        Any: The raw data generated by the Julia generator.
    """
    if self.jl_module is None:
        raise RuntimeError("Julia module is not initialized.")
    raw_data = self.jl_generator(*args, **kwargs)
    return raw_data

__init__(jl_kwargs=None, jl_code_path=None, jl_constructor_name=None, jl_base_module_path=None, jl_dependencies=None)

Initializes the JuliaWorker with the given parameters.

Parameters:

    jl_kwargs (dict[str, Any] | None): Keyword arguments to pass to the Julia callable object constructor. Defaults to None.
    jl_code_path (str | Path | None): Path to the Julia code file in which the callable object is defined. Defaults to None.
    jl_constructor_name (str | None): Name of the Julia constructor function. Defaults to None.
    jl_base_module_path (str | Path | None): Path to the base Julia module 'QuantumGrav.jl'. If not given, tries to load it via a default `using QuantumGrav` import. Defaults to None.
    jl_dependencies (list[str] | None): List of Julia package dependencies; installed via `Pkg.add` upon first call if provided. Defaults to None.

Raises:

    ValueError: If the Julia function name is not provided.
    ValueError: If the Julia code path is not provided.
    FileNotFoundError: If the Julia code path does not exist.
    NotImplementedError: If the base module path is not provided.
    RuntimeError: If there is an error loading the base module.
    RuntimeError: If there is an error loading Julia dependencies.
    RuntimeError: If there is an error loading the Julia code.

Source code in src/QuantumGrav/julia_worker.py
def __init__(
    self,
    jl_kwargs: dict[str, Any] | None = None,
    jl_code_path: str | Path | None = None,
    jl_constructor_name: str | None = None,
    jl_base_module_path: str | Path | None = None,
    jl_dependencies: list[str] | None = None,
):
    """Initializes the JuliaWorker with the given parameters.

    Args:
        jl_kwargs (dict[str, Any] | None, optional): Keyword arguments to pass to the Julia callable object constructor. Defaults to None.
        jl_code_path (str | Path | None, optional): Path to the Julia code file in which the callable object is defined. Defaults to None.
        jl_constructor_name (str | None, optional): Name of the Julia constructor function. Defaults to None.
        jl_base_module_path (str | Path | None, optional): Path to the base Julia module 'QuantumGrav.jl'. If not given, tries to load it via a default `using QuantumGrav` import. Defaults to None.
        jl_dependencies (list[str] | None, optional): List of Julia package dependencies. Defaults to None. Will be installed via `Pkg.add` if provided upon first call.

    Raises:
        ValueError: If the Julia function name is not provided.
        ValueError: If the Julia code path is not provided.
        FileNotFoundError: If the Julia code path does not exist.
        NotImplementedError: If the base module path is not provided.
        RuntimeError: If there is an error loading the base module.
        RuntimeError: If there is an error loading Julia dependencies.
        RuntimeError: If there is an error loading the Julia code.
    """

    # we test for a bunch of needed args first
    if jl_constructor_name is None:
        raise ValueError("Julia function name must be provided.")

    if jl_code_path is None:
        raise ValueError("Julia code path must be provided.")

    jl_code_path = Path(jl_code_path).resolve().absolute()
    if not jl_code_path.exists():
        raise FileNotFoundError(f"Julia code path {jl_code_path} does not exist.")

    self.jl_constructor_name = jl_constructor_name
    self.jl_module_name = "QuantumGravPy2Jl"  # the module name is hardcoded here

    # try to initialize the new Julia module, then later do every julia call through this module
    try:
        self.jl_module = jcall.newmodule(self.jl_module_name)

    except jcall.JuliaError as e:
        raise RuntimeError(f"Error creating new julia module: {e}") from e
    except Exception as e:
        raise RuntimeError(
            f"Unexpected exception while creating Julia module {self.jl_module_name}: {e}"
        ) from e

    # add base module for dependencies if exists
    if jl_base_module_path is not None:
        jl_base_module_path = Path(jl_base_module_path).resolve().absolute()
        try:
            self.jl_module.seval(
                f'using Pkg; Pkg.develop(path="{str(jl_base_module_path)}")'
            )  # only for now -> get from package index later
        except jcall.JuliaError as e:
            raise RuntimeError(
                f"Error loading base module {str(jl_base_module_path)}: {e}"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected exception while initializing julia base module: {e}"
            ) from e

    try:
        # add dependencies if provided
        if jl_dependencies is not None:
            for dep in jl_dependencies:
                self.jl_module.seval(f'using Pkg; Pkg.add("{dep}")')
    except jcall.JuliaError as e:
        raise RuntimeError(f"Error processing Julia dependencies: {e}") from e
    except Exception as e:
        raise RuntimeError(
            f"Unexpected exception while processing Julia dependencies: {e}"
        ) from e

    try:
        # load the Julia data generation code
        self.jl_module.seval(f'push!(LOAD_PATH, "{jl_code_path}")')
        self.jl_module.seval("using QuantumGrav")
        self.jl_module.seval(f'include("{jl_code_path}")')
        constructor_name = getattr(self.jl_module, jl_constructor_name)
        self.jl_generator = constructor_name(jl_kwargs)
    except jcall.JuliaError as e:
        raise RuntimeError(
            f"Error evaluating Julia code to activate base module: {e}"
        ) from e
    except Exception as e:
        raise RuntimeError(
            f"Unexpected exception while loading Julia base module: {e}"
        ) from e
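
A minimal usage sketch. The Julia file, constructor name, and its keyword arguments below are hypothetical; the wrapped callable is whatever object the named constructor returns.

from QuantumGrav import julia_worker  # assumed import path

worker = julia_worker.JuliaWorker(
    jl_kwargs={"n_points": 100},                # hypothetical constructor kwargs
    jl_code_path="generators/sprinkling.jl",    # hypothetical Julia file
    jl_constructor_name="SprinklingGenerator",  # hypothetical constructor name
)
raw_samples = worker(16)  # forwards the call to the wrapped Julia callable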

Model training

This consists of two classes: Trainer, which provides the basic training functionality, and TrainerDDP, which derives from it and adds distributed data parallel training.

Trainer

This class provides wrapper functions for setting up a model and for training and evaluating it. The basic concept is that everything is defined in a YAML file and handed to this class together with evaluator classes. After construction, the run_training and run_test methods take care of training and testing the model.
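
A minimal configuration sketch, covering only the keys that Trainer reads in the code shown below. The values are placeholders, and the "model" section must additionally contain whatever GNNModel.from_config expects, which is not reproduced here.

config = {
    "training": {
        "seed": 42,
        "device": "cpu",
        "path": "runs",          # run directories are created under this path
        "batch_size": 32,
        "num_epochs": 10,
        "learning_rate": 1e-3,   # optional, defaults to 0.001
        "weight_decay": 1e-4,    # optional, defaults to 0.0001
    },
    "validation": {"batch_size": 32},
    "testing": {"batch_size": 32},
    "model": {
        "name": "my_gnn",                    # used for run and checkpoint names
        "classifier": {"output_dims": [2]},  # read by _run_train_epoch
        # ... remaining settings expected by GNNModel.from_config
    },
}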

Trainer

Trainer class for training and evaluating GNN models.

Source code in src/QuantumGrav/train.py
class Trainer:
    """Trainer class for training and evaluating GNN models."""

    def __init__(
        self,
        config: dict[str, Any],
        # training and evaluation functions
        criterion: Callable[[Any, Data], torch.Tensor],
        apply_model: Callable | None = None,
        # training evaluation and reporting
        early_stopping: Callable[[Collection[Any] | torch.Tensor], bool] | None = None,
        validator: DefaultValidator | None = None,
        tester: DefaultTester | None = None,
    ):
        """Initialize the trainer.

        Args:
            config (dict[str, Any]): The configuration dictionary.
            criterion (Callable): The loss function to use.
            apply_model (Callable | None, optional): A function to apply the model. Defaults to None.
            early_stopping (Callable[[Collection[Any]], bool] | None, optional): A function for early stopping. Defaults to None.
            validator (DefaultValidator | None, optional): A validator for model evaluation. Defaults to None.
            tester (DefaultTester | None, optional): A tester for model evaluation. Defaults to None.

        Raises:
            ValueError: If the configuration is invalid.
        """
        if (
            all(x in config for x in ["training", "model", "validation", "testing"])
            is False
        ):
            raise ValueError(
                "Configuration must contain 'training', 'model', 'validation' and 'testing' sections."
            )

        self.config = config
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(config.get("log_level", logging.INFO))
        self.logger.info("Initializing Trainer instance")

        # functions for executing training and evaluation
        self.criterion = criterion
        self.apply_model = apply_model
        self.early_stopping = early_stopping
        self.seed = config["training"]["seed"]
        self.device = torch.device(config["training"]["device"])

        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed)

        # parameters for finding out which model is best
        self.best_score = None
        self.best_epoch = 0
        self.epoch = 0

        # date and time of run:
        run_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.data_path = (
            Path(self.config["training"]["path"])
            / f"{config['model'].get('name', 'run')}_{run_date}"
        )

        if not self.data_path.exists():
            self.data_path.mkdir(parents=True)
        self.logger.info(f"Data path set to: {self.data_path}")

        self.checkpoint_path = self.data_path / "model_checkpoints"
        self.checkpoint_at = config["training"].get("checkpoint_at", None)
        self.latest_checkpoint = None
        # training and evaluation functions
        self.validator = validator
        self.tester = tester
        self.model = None
        self.optimizer = None

        with open(self.data_path / "config.yaml", "w") as f:
            yaml.dump(self.config, f)

        self.logger.info("Trainer initialized")
        self.logger.debug(f"Configuration: {self.config}")

    def initialize_model(self) -> Any:
        """Initialize the model for training.

        Returns:
            Any: The initialized model.
        """
        if self.model is not None:
            return self.model
        # try:
        model = gnn_model.GNNModel.from_config(self.config["model"])
        model = model.to(self.device)
        self.model = model
        self.logger.info("Model initialized to device: {}".format(self.device))
        return self.model

    def initialize_optimizer(self) -> torch.optim.Optimizer | None:
        """Initialize the optimizer for training.

        Raises:
            RuntimeError: If the model is not initialized.

        Returns:
            torch.optim.Optimizer: The initialized optimizer.
        """

        if self.model is None:
            raise RuntimeError(
                "Model must be initialized before initializing optimizer."
            )

        if self.optimizer is not None:
            return self.optimizer

        try:
            lr = self.config["training"].get("learning_rate", 0.001)
            weight_decay = self.config["training"].get("weight_decay", 0.0001)
            optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=lr,
                weight_decay=weight_decay,
            )
            self.optimizer = optimizer
            self.logger.info(
                f"Optimizer initialized with learning rate: {lr} and weight decay: {weight_decay}"
            )
        except Exception as e:
            self.logger.error(f"Error initializing optimizer: {e}")
        return self.optimizer

    def prepare_dataloaders(
        self, dataset: Dataset, split: list[float] = [0.8, 0.1, 0.1]
    ) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Prepare the data loaders for training, validation, and testing.

        Args:
            dataset (Dataset): The dataset to prepare.
            split (list[float], optional): The split ratios for training, validation, and test sets. Defaults to [0.8, 0.1, 0.1].

        Returns:
            Tuple[DataLoader, DataLoader, DataLoader]: The data loaders for training, validation, and testing.
        """
        train_size = int(len(dataset) * split[0])
        val_size = int(len(dataset) * split[1])
        test_size = len(dataset) - train_size - val_size

        if not np.isclose(np.sum(split), 1.0, rtol=1e-05, atol=1e-08, equal_nan=False):
            raise ValueError(f"Split ratios must sum to 1.0. Provided split: {split}")

        self.train_dataset, self.val_dataset, self.test_dataset = (
            torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
        )

        train_loader = DataLoader(
            self.train_dataset,  # type: ignore
            batch_size=self.config["training"]["batch_size"],
            num_workers=self.config["training"].get("num_workers", 0),
            pin_memory=self.config["training"].get("pin_memory", True),
            drop_last=self.config["training"].get("drop_last", False),
            prefetch_factor=self.config["training"].get("prefetch_factor", None),
            shuffle=self.config["training"].get("shuffle", True),
        )

        val_loader = DataLoader(
            self.val_dataset,  # type: ignore
            batch_size=self.config["validation"]["batch_size"],
            num_workers=self.config["validation"].get("num_workers", 0),
            pin_memory=self.config["validation"].get("pin_memory", True),
            drop_last=self.config["validation"].get("drop_last", False),
            prefetch_factor=self.config["validation"].get("prefetch_factor", None),
            shuffle=self.config["validation"].get("shuffle", True),
        )

        test_loader = DataLoader(
            self.test_dataset,  # type: ignore
            batch_size=self.config["testing"]["batch_size"],
            num_workers=self.config["testing"].get("num_workers", 0),
            pin_memory=self.config["testing"].get("pin_memory", True),
            drop_last=self.config["testing"].get("drop_last", False),
            prefetch_factor=self.config["testing"].get("prefetch_factor", None),
            shuffle=self.config["testing"].get("shuffle", True),
        )
        self.logger.info(
            f"Data loaders prepared with splits: {split} and dataset sizes: {len(self.train_dataset)}, {len(self.val_dataset)}, {len(self.test_dataset)}"
        )
        return train_loader, val_loader, test_loader

    # training helper functions
    def _evaluate_batch(
        self,
        model: torch.nn.Module,
        data: Data,
    ) -> torch.Tensor | Collection[torch.Tensor]:
        """Evaluate a single batch of data using the model.

        Args:
            model (torch.nn.Module): The model to evaluate.
            data (Data): The input data for the model.

        Returns:
            torch.Tensor | Collection[torch.Tensor]: The output of the model.
        """
        self.logger.debug(f"  Evaluating batch on device: {self.device}")
        if self.apply_model:
            outputs = self.apply_model(model, data)
        else:
            outputs = model(data.x, data.edge_index, data.batch)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return outputs

    def _run_train_epoch(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        train_loader: DataLoader,
    ) -> torch.Tensor:
        """Run a single training epoch.

        Args:
            model (torch.nn.Module): The model to train.
            optimizer (torch.optim.Optimizer): The optimizer for the model.
            train_loader (DataLoader): The data loader for the training set.
        Raises:
            RuntimeError: If the model is not initialized.
            RuntimeError: If the optimizer is not initialized.

        Returns:
            torch.Tensor: The training loss for each batch stored in a torch.Tensor
        """

        if model is None:
            raise RuntimeError("Model must be initialized before training.")

        if optimizer is None:
            raise RuntimeError("Optimizer must be initialized before training.")

        #
        output_size = len(self.config["model"]["classifier"]["output_dims"])

        losses = torch.zeros(
            len(train_loader), output_size, dtype=torch.float32, device=self.device
        )
        self.logger.info(f"  Starting training epoch {self.epoch}")
        # training run
        for i, batch in enumerate(
            tqdm.tqdm(train_loader, desc=f"Training Epoch {self.epoch}")
        ):
            self.logger.debug(f"    Moving batch {i} to device: {self.device}")
            optimizer.zero_grad()

            data = batch.to(self.device)
            outputs = self._evaluate_batch(model, data)

            self.logger.debug("    Computing loss")
            loss = self.criterion(outputs, data)

            self.logger.debug(f"    Backpropagating loss: {loss.item()}")
            loss.backward()

            optimizer.step()

            losses[i, :] = loss

        return losses

    def _check_model_status(self, eval_data: list[Any] | torch.Tensor) -> bool:
        """Check the status of the model during training.

        Args:
            eval_data (list[Any]): The evaluation data from the training epoch.

        Returns:
            bool: Whether the training should stop early.
        """
        if (
            self.checkpoint_at is not None
            and self.epoch % self.checkpoint_at == 0
            and self.epoch > 0
        ):
            self.save_checkpoint()

        if self.early_stopping is not None:
            if self.early_stopping(eval_data):
                self.logger.debug(f"Early stopping at epoch {self.epoch}.")
                self.save_checkpoint(name_addition="early_stopping")
                return True

            if self.early_stopping.found_better:
                self.logger.debug(f"Found better model at epoch {self.epoch}.")
                self.save_checkpoint(name_addition="current_best")
                # not returning true because this is not the end of training

        return False

    def run_training(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        trial: optuna.trial.Trial | None = None,
    ) -> Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]:
        """Run the training process.

        Args:
            train_loader (DataLoader): The data loader for the training set.
            val_loader (DataLoader): The data loader for the validation set.
            trial (optuna.trial.Trial | None, optional): An Optuna trial
                for hyperparameter tuning. Defaults to None.

        Returns:
            Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.
        """
        self.logger.info("Starting training process.")
        # training loop
        num_epochs = self.config["training"]["num_epochs"]

        self.model = self.initialize_model()

        optimizer = self.initialize_optimizer()

        if optimizer is None:
            raise AttributeError(
                "Error, optimizer must be successfully initialized before running training"
            )

        total_training_data = torch.zeros(num_epochs, 2, dtype=torch.float32)

        for epoch in range(0, num_epochs):
            self.logger.info(f"  Current epoch: {self.epoch}/{num_epochs}")

            self.model.train()

            epoch_data = self._run_train_epoch(self.model, optimizer, train_loader)

            # collect mean and std for each epoch
            total_training_data[epoch, :] = torch.Tensor(
                [epoch_data.mean(dim=0).item(), epoch_data.std(dim=0).item()]
            )

            self.logger.info(
                f"  Completed epoch {self.epoch}. training loss: {total_training_data[self.epoch, 0]} +/- {total_training_data[self.epoch, 1]}."
            )

            # evaluation run on validation set
            if self.validator is not None:
                validation_result = self.validator.validate(self.model, val_loader)
                self.validator.report(validation_result)

                # integrate Optuna here for hyperparameter tuning
                if trial is not None:
                    avg_sigma_loss = self.validator.data[self.epoch]
                    avg_loss = avg_sigma_loss[0]
                    trial.report(avg_loss, self.epoch)

                    # Handle pruning based on the intermediate value.
                    if trial.should_prune():
                        raise optuna.exceptions.TrialPruned()

            should_stop = self._check_model_status(
                self.validator.data if self.validator else total_training_data,
            )
            if should_stop:
                self.logger.info("Stopping training early.")
                break
            self.epoch += 1

        self.logger.info("Training process completed.")
        self.logger.info("Saving model")

        outpath = self.data_path / f"final_model_epoch={self.epoch}.pt"
        self.model.save(outpath)

        return total_training_data, self.validator.data if self.validator else []

    def run_test(
        self,
        test_loader: DataLoader,
    ) -> Collection[Any]:
        """Run testing phase.

        Args:
            test_loader (DataLoader): The data loader for the test set.

        Raises:
            RuntimeError: If the model is not initialized.
            RuntimeError: If the test data is not available.

        Returns:
            Collection[Any]: A collection of test results that can be scalars, tensors, lists, dictionaries or any other data type that the tester might return.
        """
        self.logger.info("Starting testing process.")
        if self.model is None:
            raise RuntimeError("Model must be initialized before testing.")
        self.model.eval()
        if self.tester is None:
            raise RuntimeError("Tester must be initialized before testing.")
        test_result = self.tester.test(self.model, test_loader)
        self.tester.report(test_result)
        self.logger.info("Testing process completed.")
        return self.tester.data

    def save_checkpoint(self, name_addition: str = ""):
        """Save model checkpoint.

        Raises:
            ValueError: If the model is not initialized.
            ValueError: If the model configuration does not contain 'name'.
            ValueError: If the training configuration does not contain 'checkpoint_path'.
        """
        if self.model is None:
            raise ValueError("Model must be initialized before saving checkpoint.")

        self.logger.info(
            f"Saving checkpoint for model {self.config['model'].get('name', ' model')} at epoch {self.epoch} to {self.checkpoint_path}"
        )
        outpath = (
            self.checkpoint_path
            / f"{self.config['model'].get('name', 'model')}_epoch_{self.epoch}_{name_addition}.pt"
        )

        if outpath.exists() is False:
            outpath.parent.mkdir(parents=True, exist_ok=True)
            self.logger.debug(f"Created directory {outpath.parent} for checkpoint.")

        self.latest_checkpoint = outpath
        self.model.save(outpath)

    def load_checkpoint(self, epoch: int, name_addition: str = "") -> None:
        """Load model checkpoint to the device given

        Args:
            epoch (int): The epoch number to load.

        Raises:
            RuntimeError: If the model is not initialized.
        """

        if self.model is None:
            raise RuntimeError("Model must be initialized before loading checkpoint.")

        loadpath = (
            Path(self.checkpoint_path)
            / f"{self.config['model'].get('name', 'model')}_epoch_{epoch}_{name_addition}.pt"
        )

        if not loadpath.exists():
            raise FileNotFoundError(f"Checkpoint file {loadpath} does not exist.")

        self.model = gnn_model.GNNModel.load(loadpath)
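
The _check_model_status method above calls the optional early_stopping callable on the accumulated evaluation data and afterwards reads a found_better attribute from it. A minimal patience-based stopper compatible with that usage might look like the sketch below; this is not part of the package, the call signature and the found_better attribute are inferred from the source above, and it assumes the entry for the current epoch is the last element and stores the monitored loss first.

class PatienceStopping:
    """Stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.stale_epochs = 0
        self.found_better = False  # read by Trainer._check_model_status after the call

    def __call__(self, eval_data) -> bool:
        # assumed layout: the last entry holds (mean_loss, std_loss) for the current epoch
        current = float(eval_data[-1][0])
        self.found_better = current < self.best - self.min_delta
        if self.found_better:
            self.best = current
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience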

__init__(config, criterion, apply_model=None, early_stopping=None, validator=None, tester=None)

Initialize the trainer.

Parameters:

    config (dict[str, Any]): The configuration dictionary. Required.
    criterion (Callable): The loss function to use. Required.
    apply_model (Callable | None): A function to apply the model. Defaults to None.
    early_stopping (Callable[[Collection[Any]], bool] | None): A function for early stopping. Defaults to None.
    validator (DefaultValidator | None): A validator for model evaluation. Defaults to None.
    tester (DefaultTester | None): A tester for model evaluation. Defaults to None.

Raises:

    ValueError: If the configuration is invalid.

Source code in src/QuantumGrav/train.py
def __init__(
    self,
    config: dict[str, Any],
    # training and evaluation functions
    criterion: Callable[[Any, Data], torch.Tensor],
    apply_model: Callable | None = None,
    # training evaluation and reporting
    early_stopping: Callable[[Collection[Any] | torch.Tensor], bool] | None = None,
    validator: DefaultValidator | None = None,
    tester: DefaultTester | None = None,
):
    """Initialize the trainer.

    Args:
        config (dict[str, Any]): The configuration dictionary.
        criterion (Callable): The loss function to use.
        apply_model (Callable | None, optional): A function to apply the model. Defaults to None.
        early_stopping (Callable[[Collection[Any]], bool] | None, optional): A function for early stopping. Defaults to None.
        validator (DefaultValidator | None, optional): A validator for model evaluation. Defaults to None.
        tester (DefaultTester | None, optional): A tester for model evaluation. Defaults to None.

    Raises:
        ValueError: If the configuration is invalid.
    """
    if (
        all(x in config for x in ["training", "model", "validation", "testing"])
        is False
    ):
        raise ValueError(
            "Configuration must contain 'training', 'model', 'validation' and 'testing' sections."
        )

    self.config = config
    self.logger = logging.getLogger(__name__)
    self.logger.setLevel(config.get("log_level", logging.INFO))
    self.logger.info("Initializing Trainer instance")

    # functions for executing training and evaluation
    self.criterion = criterion
    self.apply_model = apply_model
    self.early_stopping = early_stopping
    self.seed = config["training"]["seed"]
    self.device = torch.device(config["training"]["device"])

    torch.manual_seed(self.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(self.seed)

    # parameters for finding out which model is best
    self.best_score = None
    self.best_epoch = 0
    self.epoch = 0

    # date and time of run:
    run_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    self.data_path = (
        Path(self.config["training"]["path"])
        / f"{config['model'].get('name', 'run')}_{run_date}"
    )

    if not self.data_path.exists():
        self.data_path.mkdir(parents=True)
    self.logger.info(f"Data path set to: {self.data_path}")

    self.checkpoint_path = self.data_path / "model_checkpoints"
    self.checkpoint_at = config["training"].get("checkpoint_at", None)
    self.latest_checkpoint = None
    # training and evaluation functions
    self.validator = validator
    self.tester = tester
    self.model = None
    self.optimizer = None

    with open(self.data_path / "config.yaml", "w") as f:
        yaml.dump(self.config, f)

    self.logger.info("Trainer initialized")
    self.logger.debug(f"Configuration: {self.config}")

initialize_model()

Initialize the model for training.

Returns:

    Any: The initialized model.

Source code in src/QuantumGrav/train.py
def initialize_model(self) -> Any:
    """Initialize the model for training.

    Returns:
        Any: The initialized model.
    """
    if self.model is not None:
        return self.model
    # try:
    model = gnn_model.GNNModel.from_config(self.config["model"])
    model = model.to(self.device)
    self.model = model
    self.logger.info("Model initialized to device: {}".format(self.device))
    return self.model

initialize_optimizer()

Initialize the optimizer for training.

Raises:

    RuntimeError: If the model is not initialized.

Returns:

    torch.optim.Optimizer | None: The initialized optimizer.

Source code in src/QuantumGrav/train.py
def initialize_optimizer(self) -> torch.optim.Optimizer | None:
    """Initialize the optimizer for training.

    Raises:
        RuntimeError: If the model is not initialized.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.
    """

    if self.model is None:
        raise RuntimeError(
            "Model must be initialized before initializing optimizer."
        )

    if self.optimizer is not None:
        return self.optimizer

    try:
        lr = self.config["training"].get("learning_rate", 0.001)
        weight_decay = self.config["training"].get("weight_decay", 0.0001)
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
        )
        self.optimizer = optimizer
        self.logger.info(
            f"Optimizer initialized with learning rate: {lr} and weight decay: {weight_decay}"
        )
    except Exception as e:
        self.logger.error(f"Error initializing optimizer: {e}")
    return self.optimizer

load_checkpoint(epoch, name_addition='')

Load model checkpoint to the device given

Parameters:

    epoch (int): The epoch number to load. Required.

Raises:

    RuntimeError: If the model is not initialized.

Source code in src/QuantumGrav/train.py
def load_checkpoint(self, epoch: int, name_addition: str = "") -> None:
    """Load model checkpoint to the device given

    Args:
        epoch (int): The epoch number to load.

    Raises:
        RuntimeError: If the model is not initialized.
    """

    if self.model is None:
        raise RuntimeError("Model must be initialized before loading checkpoint.")

    loadpath = (
        Path(self.checkpoint_path)
        / f"{self.config['model'].get('name', 'model')}_epoch_{epoch}_{name_addition}.pt"
    )

    if not loadpath.exists():
        raise FileNotFoundError(f"Checkpoint file {loadpath} does not exist.")

    self.model = gnn_model.GNNModel.load(loadpath)
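
Checkpoints are written as <model name>_epoch_<epoch>_<name_addition>.pt inside the run directory's model_checkpoints folder, so reloading the best model written by the early-stopping hook could look like this brief sketch (the epoch number 12 is purely hypothetical; use whichever epoch the checkpoint was saved at):

trainer.load_checkpoint(epoch=12, name_addition="current_best")
model = trainer.model  # the freshly loaded GNNModel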

prepare_dataloaders(dataset, split=[0.8, 0.1, 0.1])

Prepare the data loaders for training, validation, and testing.

Parameters:

    dataset (Dataset): The dataset to prepare. Required.
    split (list[float]): The split ratios for training, validation, and test sets. Defaults to [0.8, 0.1, 0.1].

Returns:

    Tuple[DataLoader, DataLoader, DataLoader]: The data loaders for training, validation, and testing.

Source code in src/QuantumGrav/train.py
def prepare_dataloaders(
    self, dataset: Dataset, split: list[float] = [0.8, 0.1, 0.1]
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Prepare the data loaders for training, validation, and testing.

    Args:
        dataset (Dataset): The dataset to prepare.
        split (list[float], optional): The split ratios for training, validation, and test sets. Defaults to [0.8, 0.1, 0.1].

    Returns:
        Tuple[DataLoader, DataLoader, DataLoader]: The data loaders for training, validation, and testing.
    """
    train_size = int(len(dataset) * split[0])
    val_size = int(len(dataset) * split[1])
    test_size = len(dataset) - train_size - val_size

    if not np.isclose(np.sum(split), 1.0, rtol=1e-05, atol=1e-08, equal_nan=False):
        raise ValueError(f"Split ratios must sum to 1.0. Provided split: {split}")

    self.train_dataset, self.val_dataset, self.test_dataset = (
        torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
    )

    train_loader = DataLoader(
        self.train_dataset,  # type: ignore
        batch_size=self.config["training"]["batch_size"],
        num_workers=self.config["training"].get("num_workers", 0),
        pin_memory=self.config["training"].get("pin_memory", True),
        drop_last=self.config["training"].get("drop_last", False),
        prefetch_factor=self.config["training"].get("prefetch_factor", None),
        shuffle=self.config["training"].get("shuffle", True),
    )

    val_loader = DataLoader(
        self.val_dataset,  # type: ignore
        batch_size=self.config["validation"]["batch_size"],
        num_workers=self.config["validation"].get("num_workers", 0),
        pin_memory=self.config["validation"].get("pin_memory", True),
        drop_last=self.config["validation"].get("drop_last", False),
        prefetch_factor=self.config["validation"].get("prefetch_factor", None),
        shuffle=self.config["validation"].get("shuffle", True),
    )

    test_loader = DataLoader(
        self.test_dataset,  # type: ignore
        batch_size=self.config["testing"]["batch_size"],
        num_workers=self.config["testing"].get("num_workers", 0),
        pin_memory=self.config["testing"].get("pin_memory", True),
        drop_last=self.config["testing"].get("drop_last", False),
        prefetch_factor=self.config["testing"].get("prefetch_factor", None),
        shuffle=self.config["testing"].get("shuffle", True),
    )
    self.logger.info(
        f"Data loaders prepared with splits: {split} and dataset sizes: {len(self.train_dataset)}, {len(self.val_dataset)}, {len(self.test_dataset)}"
    )
    return train_loader, val_loader, test_loader
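
Putting the pieces together, a typical single-process run prepares the loaders from one dataset, trains, and then tests. A brief sketch, assuming the trainer from the construction example above and a graph dataset object named dataset:

train_loader, val_loader, test_loader = trainer.prepare_dataloaders(dataset, split=[0.8, 0.1, 0.1])
train_stats, val_stats = trainer.run_training(train_loader, val_loader)
test_results = trainer.run_test(test_loader)  # requires a tester to have been passed to the Trainer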

run_test(test_loader)

Run testing phase.

Parameters:

    test_loader (DataLoader): The data loader for the test set. Required.

Raises:

    RuntimeError: If the model is not initialized.
    RuntimeError: If the test data is not available.

Returns:

    Collection[Any]: A collection of test results that can be scalars, tensors, lists, dictionaries or any other data type that the tester might return.

Source code in src/QuantumGrav/train.py
def run_test(
    self,
    test_loader: DataLoader,
) -> Collection[Any]:
    """Run testing phase.

    Args:
        test_loader (DataLoader): The data loader for the test set.

    Raises:
        RuntimeError: If the model is not initialized.
        RuntimeError: If the test data is not available.

    Returns:
        Collection[Any]: A collection of test results that can be scalars, tensors, lists, dictionaries or any other data type that the tester might return.
    """
    self.logger.info("Starting testing process.")
    if self.model is None:
        raise RuntimeError("Model must be initialized before testing.")
    self.model.eval()
    if self.tester is None:
        raise RuntimeError("Tester must be initialized before testing.")
    test_result = self.tester.test(self.model, test_loader)
    self.tester.report(test_result)
    self.logger.info("Testing process completed.")
    return self.tester.data

run_training(train_loader, val_loader, trial=None)

Run the training process.

Parameters:

    train_loader (DataLoader): The data loader for the training set. Required.
    val_loader (DataLoader): The data loader for the validation set. Required.
    trial (optuna.trial.Trial | None): An Optuna trial for hyperparameter tuning. Defaults to None.

Returns:

    Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.

Source code in src/QuantumGrav/train.py
def run_training(
    self,
    train_loader: DataLoader,
    val_loader: DataLoader,
    trial: optuna.trial.Trial | None = None,
) -> Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]:
    """Run the training process.

    Args:
        train_loader (DataLoader): The data loader for the training set.
        val_loader (DataLoader): The data loader for the validation set.
        trial (optuna.trial.Trial | None, optional): An Optuna trial
            for hyperparameter tuning. Defaults to None.

    Returns:
        Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.
    """
    self.logger.info("Starting training process.")
    # training loop
    num_epochs = self.config["training"]["num_epochs"]

    self.model = self.initialize_model()

    optimizer = self.initialize_optimizer()

    if optimizer is None:
        raise AttributeError(
            "Error, optimizer must be successfully initialized before running training"
        )

    total_training_data = torch.zeros(num_epochs, 2, dtype=torch.float32)

    for epoch in range(0, num_epochs):
        self.logger.info(f"  Current epoch: {self.epoch}/{num_epochs}")

        self.model.train()

        epoch_data = self._run_train_epoch(self.model, optimizer, train_loader)

        # collect mean and std for each epoch
        total_training_data[epoch, :] = torch.Tensor(
            [epoch_data.mean(dim=0).item(), epoch_data.std(dim=0).item()]
        )

        self.logger.info(
            f"  Completed epoch {self.epoch}. training loss: {total_training_data[self.epoch, 0]} +/- {total_training_data[self.epoch, 1]}."
        )

        # evaluation run on validation set
        if self.validator is not None:
            validation_result = self.validator.validate(self.model, val_loader)
            self.validator.report(validation_result)

            # integrate Optuna here for hyperparameter tuning
            if trial is not None:
                avg_sigma_loss = self.validator.data[self.epoch]
                avg_loss = avg_sigma_loss[0]
                trial.report(avg_loss, self.epoch)

                # Handle pruning based on the intermediate value.
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()

        should_stop = self._check_model_status(
            self.validator.data if self.validator else total_training_data,
        )
        if should_stop:
            self.logger.info("Stopping training early.")
            break
        self.epoch += 1

    self.logger.info("Training process completed.")
    self.logger.info("Saving model")

    outpath = self.data_path / f"final_model_epoch={self.epoch}.pt"
    self.model.save(outpath)

    return total_training_data, self.validator.data if self.validator else []
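
Because run_training reports the average validation loss to the Optuna trial after every epoch and raises TrialPruned when the trial should be pruned, it can be dropped directly into an objective function. A sketch under some assumptions: base_config, criterion and dataset are defined as in the earlier examples, validator is a DefaultValidator instance (required for trial reporting), and validator.data holds per-epoch (mean, std) losses, mirroring how run_training itself reads it.

import copy
import optuna

def objective(trial: optuna.trial.Trial) -> float:
    cfg = copy.deepcopy(base_config)
    cfg["training"]["learning_rate"] = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    trainer = train.Trainer(cfg, criterion=criterion, validator=validator)
    train_loader, val_loader, _ = trainer.prepare_dataloaders(dataset)
    _, val_data = trainer.run_training(train_loader, val_loader, trial=trial)
    return float(val_data[-1][0])  # assumed layout: last epoch's mean validation loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=25)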

save_checkpoint(name_addition='')

Save model checkpoint.

Raises:

    ValueError: If the model is not initialized.
    ValueError: If the model configuration does not contain 'name'.
    ValueError: If the training configuration does not contain 'checkpoint_path'.

Source code in src/QuantumGrav/train.py
def save_checkpoint(self, name_addition: str = ""):
    """Save model checkpoint.

    Raises:
        ValueError: If the model is not initialized.
        ValueError: If the model configuration does not contain 'name'.
        ValueError: If the training configuration does not contain 'checkpoint_path'.
    """
    if self.model is None:
        raise ValueError("Model must be initialized before saving checkpoint.")

    self.logger.info(
        f"Saving checkpoint for model {self.config['model'].get('name', ' model')} at epoch {self.epoch} to {self.checkpoint_path}"
    )
    outpath = (
        self.checkpoint_path
        / f"{self.config['model'].get('name', 'model')}_epoch_{self.epoch}_{name_addition}.pt"
    )

    if outpath.exists() is False:
        outpath.parent.mkdir(parents=True, exist_ok=True)
        self.logger.debug(f"Created directory {outpath.parent} for checkpoint.")

    self.latest_checkpoint = outpath
    self.model.save(outpath)

Distributed data parallel Trainer class

This is based on the PyTorch documentation for distributed data parallel training and is untested at the time of writing.

TrainerDDP

Bases: Trainer

Source code in src/QuantumGrav/train_ddp.py
class TrainerDDP(train.Trainer):
    def __init__(
        self,
        rank: int,
        config: dict[str, Any],
        # training and evaluation functions
        criterion: Callable,
        apply_model: Callable | None = None,
        # training evaluation and reporting
        early_stopping: Callable[[Collection[Any] | torch.Tensor], bool] | None = None,
        validator: DefaultValidator | None = None,
        tester: DefaultTester | None = None,
    ):
        """Initialize the distributed data parallel (DDP) trainer.

        Args:
            rank (int): The rank of the current process.
            config (dict[str, Any]): The configuration dictionary.
            criterion (Callable): The loss function.
            apply_model (Callable | None, optional): The function to apply the model. Defaults to None.
            early_stopping (Callable[[list[dict[str, Any]]], bool] | None, optional): The early stopping function. Defaults to None.
            validator (DefaultValidator | None, optional): The validator for model evaluation. Defaults to None.
            tester (DefaultTester | None, optional): The tester for model testing. Defaults to None.

        Raises:
            ValueError: If the configuration is invalid.
        """
        if "parallel" not in config:
            raise ValueError("Configuration must contain 'parallel' section for DDP.")

        super().__init__(
            config,
            criterion,
            apply_model,
            early_stopping,
            validator,
            tester,
        )
        # initialize the systems differently on each process/rank
        torch.manual_seed(self.seed + rank)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed + rank)

        if torch.cuda.is_available() and config["training"]["device"] != "cpu":
            torch.cuda.set_device(rank)
            self.device = torch.device(f"cuda:{rank}")
        else:
            self.device = torch.device("cpu")

        self.rank = rank
        self.world_size = config["parallel"]["world_size"]
        self.logger.info("Initialized DDP trainer")

    def initialize_model(self) -> DDP:
        """Initialize the model for training.

        Returns:
            DDP: The initialized model.
        """
        model = gnn_model.GNNModel.from_config(self.config["model"])

        if self.device.type == "cpu" or (
            isinstance(self.device, torch.device) and self.device.type == "cpu"
        ):
            d_id = None
            o_id = None
        else:
            d_id = [
                self.device,
            ]
            o_id = self.config["parallel"].get("output_device", None)
        model = DDP(
            model,
            device_ids=d_id,
            output_device=o_id,
            find_unused_parameters=self.config["parallel"].get(
                "find_unused_parameters", False
            ),
        )
        self.model = model.to(self.device, non_blocking=True)
        self.logger.info(f"Model initialized on device: {self.device}")
        return self.model

    def prepare_dataloaders(
        self, dataset: Dataset, split: list[float] = [0.8, 0.1, 0.1]
    ) -> Tuple[
        DataLoader,
        DataLoader,
        DataLoader,
    ]:
        """Prepare the data loaders for training, validation, and testing.

        Args:
            dataset (Dataset): The dataset to split.
            split (list[float], optional): The proportions for train/val/test split. Defaults to [0.8, 0.1, 0.1].

        Returns:
            Tuple[ DataLoader, DataLoader, DataLoader, ]: The data loaders for training, validation, and testing.
        """
        train_size = int(len(dataset) * split[0])
        val_size = int(len(dataset) * split[1])
        test_size = len(dataset) - train_size - val_size
        self.logger.info(
            f"Preparing data loaders with split: {split}, train size: {train_size}, val size: {val_size}, test size: {test_size}"
        )
        if (
            np.isclose(np.sum(split), 1.0, rtol=1e-05, atol=1e-08, equal_nan=False)
            is False
        ):
            raise ValueError(
                "Split ratios must sum to 1.0. Provided split: {}".format(split)
            )

        self.train_dataset, self.val_dataset, self.test_dataset = (
            torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
        )

        # samplers are needed to distribute the data across processes in such a way that each process gets a unique subset of the data
        self.train_sampler = torch.utils.data.DistributedSampler(
            self.train_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True,
        )

        self.val_sampler = torch.utils.data.DistributedSampler(
            self.val_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=False,
        )

        self.test_sampler = torch.utils.data.DistributedSampler(
            self.test_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=False,
        )

        # make the data loaders
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["training"]["batch_size"],
            sampler=self.train_sampler,
            num_workers=self.config["training"].get("num_workers", 0),
            pin_memory=self.config["training"].get("pin_memory", True),
            drop_last=self.config["training"].get("drop_last", False),
            prefetch_factor=self.config["training"].get("prefetch_factor", None),
        )

        val_loader = DataLoader(
            self.val_dataset,
            sampler=self.val_sampler,
            batch_size=self.config["validation"]["batch_size"],
            num_workers=self.config["validation"].get("num_workers", 0),
            pin_memory=self.config["validation"].get("pin_memory", True),
            drop_last=self.config["validation"].get("drop_last", False),
            prefetch_factor=self.config["validation"].get("prefetch_factor", None),
        )

        test_loader = DataLoader(
            self.test_dataset,
            sampler=self.test_sampler,
            batch_size=self.config["testing"]["batch_size"],
            num_workers=self.config["testing"].get("num_workers", 0),
            pin_memory=self.config["testing"].get("pin_memory", True),
            drop_last=self.config["testing"].get("drop_last", False),
            prefetch_factor=self.config["testing"].get("prefetch_factor", None),
        )
        self.logger.info(
            f"Data loaders prepared with splits: {split} and dataset sizes: {len(self.train_dataset)}, {len(self.val_dataset)}, {len(self.test_dataset)}"
        )
        return train_loader, val_loader, test_loader

    def _check_model_status(self, eval_data: list[Any] | torch.Tensor) -> bool:
        """Check the status of the model during evaluation.

        Args:
            eval_data (list[Any] | torch.Tensor): The evaluation data to check.

        Returns:
            bool: Whether the model training should stop.
        """
        should_stop = False
        if self.rank == 0:
            should_stop = super()._check_model_status(eval_data)
        return should_stop

    def save_checkpoint(self, name_addition: str = ""):
        """Save model checkpoint.

        Raises:
            ValueError: If the model is not initialized.
            ValueError: If the model configuration does not contain 'name'.
            ValueError: If the training configuration does not contain 'checkpoint_path'.
        """
        if self.rank == 0:
            if self.model is None:
                raise ValueError("Model must be initialized before saving checkpoint.")

            if "name" not in self.config["model"]:
                raise ValueError(
                    "Model configuration must contain 'name' to save checkpoint."
                )

            self.logger.info(
                f"Saving checkpoint for model {self.config['model']['name']} at epoch {self.epoch} to {self.checkpoint_path}"
            )
            outpath = (
                self.checkpoint_path
                / f"{self.config['model']['name']}_epoch_{self.epoch}_{name_addition}.pt"
            )

            if outpath.exists() is False:
                outpath.parent.mkdir(parents=True, exist_ok=True)
                self.logger.info(f"Created directory {outpath.parent} for checkpoint.")

            self.latest_checkpoint = outpath
            torch.save(self.model, outpath)

    def run_training(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        trial: optuna.trial.Trial | None = None,
    ) -> Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]:
        """
        Run the training loop for the distributed model. This will synchronize for validation. No testing is performed in this function. The model will only be checkpointed and early stopped on the 'rank' 0 process.

        Args:
            train_loader (DataLoader): The training data loader.
            val_loader (DataLoader): The validation data loader.
            trial (optuna.trial.Trial | None, optional): An Optuna trial for hyperparameter optimization.
                Defaults to None.

        Returns:
            Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.
        """

        self.model = self.initialize_model()
        self.optimizer = self.initialize_optimizer()

        num_epochs = self.config["training"]["num_epochs"]
        self.logger.info("Starting training process.")

        total_training_data = []
        all_training_data: list[Any] = [None for _ in range(self.world_size)]
        all_validation_data: list[Any] = [None for _ in range(self.world_size)]
        for _ in range(0, num_epochs):
            self.logger.info(f"  Current epoch: {self.epoch}/{num_epochs}")
            self.model.train()
            train_loader.sampler.set_epoch(self.epoch)
            epoch_data = self._run_train_epoch(self.model, self.optimizer, train_loader)
            total_training_data.append(epoch_data)

            # evaluation run on validation set
            self.model.eval()
            if self.validator is not None:
                validation_result = self.validator.validate(self.model, val_loader)
                if self.rank == 0:
                    self.validator.report(validation_result)

                    # integrate Optuna here for hyperparameter tuning
                    if trial is not None:
                        avg_sigma_loss = self.validator.data[self.epoch]
                        avg_loss = avg_sigma_loss[0]
                        trial.report(avg_loss, self.epoch)

                        # Handle pruning based on the intermediate value.
                        if trial.should_prune():
                            raise optuna.exceptions.TrialPruned()

            dist.barrier()  # Ensure all processes have completed the epoch before checking status
            should_stop = self._check_model_status(
                self.validator.data if self.validator else total_training_data,
            )

            object_list = [should_stop]

            should_stop = dist.broadcast_object_list(
                object_list, src=0, device=self.device
            )
            should_stop = object_list[0]

            if should_stop:
                break

            self.epoch += 1

        dist.barrier()
        dist.all_gather_object(all_training_data, total_training_data)
        dist.all_gather_object(
            all_validation_data, self.validator.data if self.validator else []
        )
        self.logger.info("Training process completed.")

        return all_training_data, all_validation_data

__init__(rank, config, criterion, apply_model=None, early_stopping=None, validator=None, tester=None)

Initialize the distributed data parallel (DDP) trainer.

Parameters:

    rank (int): The rank of the current process. Required.
    config (dict[str, Any]): The configuration dictionary. Required.
    criterion (Callable): The loss function. Required.
    apply_model (Callable | None): The function to apply the model. Defaults to None.
    early_stopping (Callable[[list[dict[str, Any]]], bool] | None): The early stopping function. Defaults to None.
    validator (DefaultValidator | None): The validator for model evaluation. Defaults to None.
    tester (DefaultTester | None): The tester for model testing. Defaults to None.

Raises:

    ValueError: If the configuration is invalid.

Source code in src/QuantumGrav/train_ddp.py
def __init__(
    self,
    rank: int,
    config: dict[str, Any],
    # training and evaluation functions
    criterion: Callable,
    apply_model: Callable | None = None,
    # training evaluation and reporting
    early_stopping: Callable[[Collection[Any] | torch.Tensor], bool] | None = None,
    validator: DefaultValidator | None = None,
    tester: DefaultTester | None = None,
):
    """Initialize the distributed data parallel (DDP) trainer.

    Args:
        rank (int): The rank of the current process.
        config (dict[str, Any]): The configuration dictionary.
        criterion (Callable): The loss function.
        apply_model (Callable | None, optional): The function to apply the model. Defaults to None.
        early_stopping (Callable[[list[dict[str, Any]]], bool] | None, optional): The early stopping function. Defaults to None.
        validator (DefaultValidator | None, optional): The validator for model evaluation. Defaults to None.
        tester (DefaultTester | None, optional): The tester for model testing. Defaults to None.

    Raises:
        ValueError: If the configuration is invalid.
    """
    if "parallel" not in config:
        raise ValueError("Configuration must contain 'parallel' section for DDP.")

    super().__init__(
        config,
        criterion,
        apply_model,
        early_stopping,
        validator,
        tester,
    )
    # initialize the systems differently on each process/rank
    torch.manual_seed(self.seed + rank)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(self.seed + rank)

    if torch.cuda.is_available() and config["training"]["device"] != "cpu":
        torch.cuda.set_device(rank)
        self.device = torch.device(f"cuda:{rank}")
    else:
        self.device = torch.device("cpu")

    self.rank = rank
    self.world_size = config["parallel"]["world_size"]
    self.logger.info("Initialized DDP trainer")

initialize_model()

Initialize the model for training.

Returns:

    DDP (DistributedDataParallel): The initialized model.

Source code in src/QuantumGrav/train_ddp.py
def initialize_model(self) -> DDP:
    """Initialize the model for training.

    Returns:
        DDP: The initialized model.
    """
    model = gnn_model.GNNModel.from_config(self.config["model"])

    if self.device.type == "cpu" or (
        isinstance(self.device, torch.device) and self.device.type == "cpu"
    ):
        d_id = None
        o_id = None
    else:
        d_id = [
            self.device,
        ]
        o_id = self.config["parallel"].get("output_device", None)
    model = DDP(
        model,
        device_ids=d_id,
        output_device=o_id,
        find_unused_parameters=self.config["parallel"].get(
            "find_unused_parameters", False
        ),
    )
    self.model = model.to(self.device, non_blocking=True)
    self.logger.info(f"Model initialized on device: {self.device}")
    return self.model

prepare_dataloaders(dataset, split=[0.8, 0.1, 0.1])

Prepare the data loaders for training, validation, and testing.

Parameters:

    dataset (Dataset): The dataset to split. Required.
    split (list[float]): The proportions for the train/val/test split. Defaults to [0.8, 0.1, 0.1].

Returns:

    Tuple[DataLoader, DataLoader, DataLoader]: The data loaders for training, validation, and testing.

Source code in src/QuantumGrav/train_ddp.py
def prepare_dataloaders(
    self, dataset: Dataset, split: list[float] = [0.8, 0.1, 0.1]
) -> Tuple[
    DataLoader,
    DataLoader,
    DataLoader,
]:
    """Prepare the data loaders for training, validation, and testing.

    Args:
        dataset (Dataset): The dataset to split.
        split (list[float], optional): The proportions for train/val/test split. Defaults to [0.8, 0.1, 0.1].

    Returns:
        Tuple[ DataLoader, DataLoader, DataLoader, ]: The data loaders for training, validation, and testing.
    """
    train_size = int(len(dataset) * split[0])
    val_size = int(len(dataset) * split[1])
    test_size = len(dataset) - train_size - val_size
    self.logger.info(
        f"Preparing data loaders with split: {split}, train size: {train_size}, val size: {val_size}, test size: {test_size}"
    )
    if (
        np.isclose(np.sum(split), 1.0, rtol=1e-05, atol=1e-08, equal_nan=False)
        is False
    ):
        raise ValueError(
            "Split ratios must sum to 1.0. Provided split: {}".format(split)
        )

    self.train_dataset, self.val_dataset, self.test_dataset = (
        torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
    )

    # samplers are needed to distribute the data across processes in such a way that each process gets a unique subset of the data
    self.train_sampler = torch.utils.data.DistributedSampler(
        self.train_dataset,
        num_replicas=self.world_size,
        rank=self.rank,
        shuffle=True,
    )

    self.val_sampler = torch.utils.data.DistributedSampler(
        self.val_dataset,
        num_replicas=self.world_size,
        rank=self.rank,
        shuffle=False,
    )

    self.test_sampler = torch.utils.data.DistributedSampler(
        self.test_dataset,
        num_replicas=self.world_size,
        rank=self.rank,
        shuffle=False,
    )

    # make the data loaders
    train_loader = DataLoader(
        self.train_dataset,
        batch_size=self.config["training"]["batch_size"],
        sampler=self.train_sampler,
        num_workers=self.config["training"].get("num_workers", 0),
        pin_memory=self.config["training"].get("pin_memory", True),
        drop_last=self.config["training"].get("drop_last", False),
        prefetch_factor=self.config["training"].get("prefetch_factor", None),
    )

    val_loader = DataLoader(
        self.val_dataset,
        sampler=self.val_sampler,
        batch_size=self.config["validation"]["batch_size"],
        num_workers=self.config["validation"].get("num_workers", 0),
        pin_memory=self.config["validation"].get("pin_memory", True),
        drop_last=self.config["validation"].get("drop_last", False),
        prefetch_factor=self.config["validation"].get("prefetch_factor", None),
    )

    test_loader = DataLoader(
        self.test_dataset,
        sampler=self.test_sampler,
        batch_size=self.config["testing"]["batch_size"],
        num_workers=self.config["testing"].get("num_workers", 0),
        pin_memory=self.config["testing"].get("pin_memory", True),
        drop_last=self.config["testing"].get("drop_last", False),
        prefetch_factor=self.config["testing"].get("prefetch_factor", None),
    )
    self.logger.info(
        f"Data loaders prepared with splits: {split} and dataset sizes: {len(self.train_dataset)}, {len(self.val_dataset)}, {len(self.test_dataset)}"
    )
    return train_loader, val_loader, test_loader

run_training(train_loader, val_loader, trial=None)

Run the training loop for the distributed model. This will synchronize for validation. No testing is performed in this function. The model will only be checkpointed and early stopped on the 'rank' 0 process.

Parameters:

    train_loader (DataLoader): The training data loader. Required.
    val_loader (DataLoader): The validation data loader. Required.
    trial (optuna.trial.Trial | None): An Optuna trial for hyperparameter optimization. Defaults to None.

Returns:

    Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.

Source code in src/QuantumGrav/train_ddp.py
def run_training(
    self,
    train_loader: DataLoader,
    val_loader: DataLoader,
    trial: optuna.trial.Trial | None = None,
) -> Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]:
    """
    Run the training loop for the distributed model. This will synchronize for validation. No testing is performed in this function. The model will only be checkpointed and early stopped on the 'rank' 0 process.

    Args:
        train_loader (DataLoader): The training data loader.
        val_loader (DataLoader): The validation data loader.
        trial (optuna.trial.Trial | None, optional): An Optuna trial for hyperparameter optimization.
            Defaults to None.

    Returns:
        Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.
    """

    self.model = self.initialize_model()
    self.optimizer = self.initialize_optimizer()

    num_epochs = self.config["training"]["num_epochs"]
    self.logger.info("Starting training process.")

    total_training_data = []
    all_training_data: list[Any] = [None for _ in range(self.world_size)]
    all_validation_data: list[Any] = [None for _ in range(self.world_size)]
    for _ in range(0, num_epochs):
        self.logger.info(f"  Current epoch: {self.epoch}/{num_epochs}")
        self.model.train()
        train_loader.sampler.set_epoch(self.epoch)
        epoch_data = self._run_train_epoch(self.model, self.optimizer, train_loader)
        total_training_data.append(epoch_data)

        # evaluation run on validation set
        self.model.eval()
        if self.validator is not None:
            validation_result = self.validator.validate(self.model, val_loader)
            if self.rank == 0:
                self.validator.report(validation_result)

                # integrate Optuna here for hyperparameter tuning
                if trial is not None:
                    avg_sigma_loss = self.validator.data[self.epoch]
                    avg_loss = avg_sigma_loss[0]
                    trial.report(avg_loss, self.epoch)

                    # Handle pruning based on the intermediate value.
                    if trial.should_prune():
                        raise optuna.exceptions.TrialPruned()

        dist.barrier()  # Ensure all processes have completed the epoch before checking status
        should_stop = self._check_model_status(
            self.validator.data if self.validator else total_training_data,
        )

        object_list = [should_stop]

        should_stop = dist.broadcast_object_list(
            object_list, src=0, device=self.device
        )
        should_stop = object_list[0]

        if should_stop:
            break

        self.epoch += 1

    dist.barrier()
    dist.all_gather_object(all_training_data, total_training_data)
    dist.all_gather_object(
        all_validation_data, self.validator.data if self.validator else []
    )
    self.logger.info("Training process completed.")

    return all_training_data, all_validation_data

save_checkpoint(name_addition='')

Save model checkpoint.

Raises:

    ValueError: If the model is not initialized.
    ValueError: If the model configuration does not contain 'name'.
    ValueError: If the training configuration does not contain 'checkpoint_path'.

Source code in src/QuantumGrav/train_ddp.py
def save_checkpoint(self, name_addition: str = ""):
    """Save model checkpoint.

    Raises:
        ValueError: If the model is not initialized.
        ValueError: If the model configuration does not contain 'name'.
        ValueError: If the training configuration does not contain 'checkpoint_path'.
    """
    if self.rank == 0:
        if self.model is None:
            raise ValueError("Model must be initialized before saving checkpoint.")

        if "name" not in self.config["model"]:
            raise ValueError(
                "Model configuration must contain 'name' to save checkpoint."
            )

        self.logger.info(
            f"Saving checkpoint for model {self.config['model']['name']} at epoch {self.epoch} to {self.checkpoint_path}"
        )
        outpath = (
            self.checkpoint_path
            / f"{self.config['model']['name']}_epoch_{self.epoch}_{name_addition}.pt"
        )

        if outpath.exists() is False:
            outpath.parent.mkdir(parents=True, exist_ok=True)
            self.logger.info(f"Created directory {outpath.parent} for checkpoint.")

        self.latest_checkpoint = outpath
        torch.save(self.model, outpath)

cleanup_ddp()

Clean up the distributed process group.

Source code in src/QuantumGrav/train_ddp.py
def cleanup_ddp() -> None:
    """Clean up the distributed process group."""
    if dist.is_initialized():
        dist.destroy_process_group()
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

initialize_ddp(rank, worldsize, master_addr='localhost', master_port='12345', backend='nccl')

Initialize the distributed process group. This assumes one process per GPU.

Parameters:

    rank (int): The rank of the current process. Required.
    worldsize (int): The total number of processes. Required.
    master_addr (str): The address of the master process. Defaults to "localhost". This needs to be the IP of the master node if you are running on a cluster.
    master_port (str): The port of the master process. Defaults to "12345". Choose a high port if you are running multiple jobs on the same machine to avoid conflicts. If running on a cluster, this should be the port that the master node is listening on.
    backend (str): The backend to use for distributed training. Defaults to "nccl".

Raises:

    RuntimeError: If the distributed process group is already initialized.

Source code in src/QuantumGrav/train_ddp.py
def initialize_ddp(
    rank: int,
    worldsize: int,
    master_addr: str = "localhost",
    master_port: str = "12345",
    backend: str = "nccl",
) -> None:
    """Initialize the distributed process group. This assumes one process per GPU.

    Args:
        rank (int): The rank of the current process.
        worldsize (int): The total number of processes.
        master_addr (str, optional): The address of the master process. Defaults to "localhost". This needs to be the ip of the master node if you are running on a cluster.
        master_port (str, optional): The port of the master process. Defaults to "12345". Choose a high port if you are running multiple jobs on the same machine to avoid conflicts. If running on a cluster, this should be the port that the master node is listening on.
        backend (str, optional): The backend to use for distributed training. Defaults to "nccl".

    Raises:
        RuntimeError: If the distributed process group is already initialized.
    """
    if dist.is_initialized():
        raise RuntimeError("The distributed process group is already initialized.")
    else:
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = master_port
        dist.init_process_group(backend=backend, rank=rank, world_size=worldsize)
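
A complete DDP run needs one process per rank; each process initializes the process group, builds its own TrainerDDP, trains, and tears the group down again. The sketch below uses torch.multiprocessing.spawn; the train_ddp import path is assumed from the source locations shown above, and config, dataset and criterion are assumed to be defined as in the earlier single-process example (use the "gloo" backend when no GPUs are available).

import torch
import torch.multiprocessing as mp
from QuantumGrav import train_ddp  # assumed import path

def worker(rank: int, world_size: int):
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    train_ddp.initialize_ddp(rank, world_size, backend=backend)
    trainer = train_ddp.TrainerDDP(rank, config, criterion=criterion)
    train_loader, val_loader, _ = trainer.prepare_dataloaders(dataset)
    trainer.run_training(train_loader, val_loader)
    train_ddp.cleanup_ddp()

if __name__ == "__main__":
    world_size = config["parallel"]["world_size"]
    mp.spawn(worker, args=(world_size,), nprocs=world_size)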

Utilities

General utilities that are used throughout this package.

get_registered_activation(name)

Get a registered activation layer by name.

Parameters:

    name (str): The name of the activation layer. Required.

Returns:

    type[torch.nn.Module] | None: The registered activation layer named name, or None if not found.

Source code in src/QuantumGrav/utils.py
def get_registered_activation(name: str) -> type[torch.nn.Module] | None:
    """Get a registered activation layer by name.

    Args:
        name (str): The name of the activation layer.

    Returns:
        type[torch.nn.Module] | None: The registered activation layer named `name`, or None if not found.
    """
    return activation_layers[name] if name in activation_layers else None

get_registered_gnn_layer(name)

Get a registered GNN layer by name.

Parameters:

    name (str): The name of the GNN layer. Required.

Returns:

    type[torch.nn.Module] | None: The registered GNN layer named name, or None if not found.

Source code in src/QuantumGrav/utils.py
def get_registered_gnn_layer(name: str) -> type[torch.nn.Module] | None:
    """Get a registered GNN layer by name.
    Args:
        name (str): The name of the GNN layer.

    Returns:
        type[torch.nn.Module] | None: The registered GNN layer named `name`, or None if not found.
    """
    return gnn_layers[name] if name in gnn_layers else None

get_registered_normalizer(name)

Get a registered normalizer layer by name.

Parameters:

    name (str): The name of the normalizer layer. Required.

Returns:

    type[torch.nn.Module] | None: The registered normalizer layer named name, or None if not found.

Source code in src/QuantumGrav/utils.py
def get_registered_normalizer(name: str) -> type[torch.nn.Module] | None:
    """Get a registered normalizer layer by name.

    Args:
        name (str): The name of the normalizer layer.

    Returns:
        type[torch.nn.Module]| None: The registered normalizer layer named `name`, or None if not found.
    """
    return normalizer_layers[name] if name in normalizer_layers else None

get_registered_pooling_layer(name)

Get a registered pooling layer by name.

Parameters:

    name (str): The name of the pooling layer. Required.

Returns:

    torch.nn.Module | None: The registered pooling layer named name, or None if not found.

Source code in src/QuantumGrav/utils.py
def get_registered_pooling_layer(name: str) -> torch.nn.Module | None:
    """Get a registered pooling layer by name.

    Args:
        name (str): The name of the pooling layer.

    Returns:
        torch.nn.Module | None: The registered pooling layer named `name`, or None if not found.
    """
    return pooling_layers[name] if name in pooling_layers else None

list_registered_activations()

List all registered activation layers.

Source code in src/QuantumGrav/utils.py
def list_registered_activations() -> list[str]:
    """List all registered activation layers."""
    return list(activation_layers.keys())

list_registered_gnn_layers()

List all registered GNN layers.

Source code in src/QuantumGrav/utils.py
def list_registered_gnn_layers() -> list[str]:
    """List all registered GNN layers."""
    return list(gnn_layers.keys())

list_registered_normalizers()

List all registered normalizer layers.

Source code in src/QuantumGrav/utils.py
def list_registered_normalizers() -> list[str]:
    """List all registered normalizer layers."""
    return list(normalizer_layers.keys())

list_registered_pooling_layers()

List all registered pooling layers.

Source code in src/QuantumGrav/utils.py
def list_registered_pooling_layers() -> list[str]:
    """List all registered pooling layers."""
    return list(pooling_layers.keys())

register_activation(activation_name, activation_layer)

Register an activation layer with the module

Parameters:

    activation_name (str): The name of the activation layer. Required.
    activation_layer (type[torch.nn.Module]): The activation layer to register. Required.

Raises:

    ValueError: If the activation layer is already registered.

Source code in src/QuantumGrav/utils.py
def register_activation(
    activation_name: str, activation_layer: type[torch.nn.Module]
) -> None:
    """Register an activation layer with the module

    Args:
        activation_name (str): The name of the activation layer.
        activation_layer (type[torch.nn.Module]): The activation layer to register.

    Raises:
        ValueError: If the activation layer is already registered.
    """
    if activation_name in activation_layers:
        raise ValueError(f"Activation '{activation_name}' is already registered.")
    activation_layers[activation_name] = activation_layer
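
The registries let configuration files refer to layers by name. Registering a custom activation and retrieving it again could look like this sketch; the QuantumGrav.utils import path is assumed from the source locations shown above, and the name "silu" is only an example (register_activation raises ValueError if the name is already taken).

import torch
from QuantumGrav import utils  # assumed import path

utils.register_activation("silu", torch.nn.SiLU)          # raises ValueError if "silu" is already registered
activation_cls = utils.get_registered_activation("silu")  # type[torch.nn.Module] | None
activation = activation_cls() if activation_cls is not None else torch.nn.Identity()
print(utils.list_registered_activations())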

register_gnn_layer(gnn_layer_name, gnn_layer)

Register a GNN layer with the module

Parameters:

    gnn_layer_name (str): The name of the GNN layer. Required.
    gnn_layer (type[torch.nn.Module]): The GNN layer to register. Required.

Raises:

    ValueError: If the GNN layer is already registered.
Source code in src/QuantumGrav/utils.py
def register_gnn_layer(gnn_layer_name: str, gnn_layer: type[torch.nn.Module]) -> None:
    """Register a GNN layer with the module

    Args:
        gnn_layer_name (str): The name of the GNN layer.
        gnn_layer (type[torch.nn.Module]): The GNN layer to register.
    """
    if gnn_layer_name in gnn_layers:
        raise ValueError(f"GNN layer '{gnn_layer_name}' is already registered.")
    gnn_layers[gnn_layer_name] = gnn_layer

register_normalizer(normalizer_name, normalizer_layer)

Register a normalizer layer with the module

Parameters:

    normalizer_name (str): The name of the normalizer. Required.
    normalizer_layer (type[torch.nn.Module]): The normalizer layer to register. Required.

Raises:

    ValueError: If the normalizer layer is already registered.

Source code in src/QuantumGrav/utils.py
def register_normalizer(
    normalizer_name: str, normalizer_layer: type[torch.nn.Module]
) -> None:
    """Register a normalizer layer with the module

    Args:
        normalizer_name (str): The name of the normalizer.
        normalizer_layer (type[torch.nn.Module]): The normalizer layer to register.

    Raises:
        ValueError: If the normalizer layer is already registered.
    """
    if normalizer_name in normalizer_layers:
        raise ValueError(f"Normalizer '{normalizer_name}' is already registered.")
    normalizer_layers[normalizer_name] = normalizer_layer

register_pooling_layer(pooling_layer_name, pooling_layer)

Register a pooling layer with the module

Parameters:

    pooling_layer_name (str): The name of the pooling layer. Required.
    pooling_layer (torch.nn.Module): The pooling layer to register. Required.

Raises:

    ValueError: If the pooling layer is already registered.
Source code in src/QuantumGrav/utils.py
def register_pooling_layer(
    pooling_layer_name: str, pooling_layer: torch.nn.Module
) -> None:
    """Register a pooling layer with the module

    Args:
        pooling_layer_name (str): The name of the pooling layer.
        pooling_layer (torch.nn.Module): The pooling layer to register.
    """
    if pooling_layer_name in pooling_layers:
        raise ValueError(f"Pooling layer '{pooling_layer_name}' is already registered.")
    pooling_layers[pooling_layer_name] = pooling_layer