
API Reference

Main model class

The main model class GNNModel ties together the graph neural network backbone and a multilayer perceptron classifier, and can be configured for various tasks.

GNNModel

Bases: Module, Configurable

Complete GNN model architecture with encoder, pooling, and downstream tasks.

This model combines:

- An encoder network (typically GNN layers) to process graph structure
- Optional pooling layers to aggregate node features into graph-level representations
- Multiple downstream task heads for classification, regression, etc.
- Optional graph features network for processing additional graph-level features

The model supports multi-task learning with selective task activation.
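For orientation, here is a minimal construction sketch. The task head MLPHead is hypothetical, the aggregation lambda is an illustrative choice, and the import path is inferred from the "Source code" note below; torch_geometric.nn.GCN and global_mean_pool are standard PyTorch Geometric components.

import torch
import torch_geometric

from QuantumGrav.gnn_model import GNNModel  # path assumed from src/QuantumGrav/gnn_model.py

# Hypothetical downstream task head: a small MLP classifier.
class MLPHead(torch.nn.Module):
    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, in_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(in_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

model = GNNModel(
    encoder_type=torch_geometric.nn.GCN,  # a class: instantiated with encoder_kwargs
    encoder_kwargs={"in_channels": 16, "hidden_channels": 32, "num_layers": 2},
    downstream_tasks=[(MLPHead, [32, 3], {})],  # [task_type, task_args, task_kwargs]
    pooling_layers=[(torch_geometric.nn.global_mean_pool, [], {})],  # a plain function is wrapped
    aggregate_pooling_type=lambda pooled: torch.cat(pooled, dim=-1),  # combine the pooled outputs
)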

Source code in src/QuantumGrav/gnn_model.py
class GNNModel(torch.nn.Module, base.Configurable):
    """Complete GNN model architecture with encoder, pooling, and downstream tasks.

    This model combines:
    - An encoder network (typically GNN layers) to process graph structure
    - Optional pooling layers to aggregate node features into graph-level representations
    - Multiple downstream task heads for classification, regression, etc.
    - Optional graph features network for processing additional graph-level features

    The model supports multi-task learning with selective task activation.
    """

    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "title": "GNNModel Configuration",
        "type": "object",
        "properties": {
            "encoder_type": {
                "description": "Type of the encoder network (Python class or string)",
            },
            "encoder_args": {
                "type": "array",
                "description": "Positional arguments for encoder initialization",
                "items": {},
            },
            "encoder_kwargs": {
                "type": "object",
                "description": "Keyword arguments for encoder initialization",
            },
            "downstream_tasks": {
                "type": "array",
                "description": "List of downstream tasks, each as [type, args, kwargs]",
                "items": {
                    "type": "array",
                    "minItems": 1,
                    "maxItems": 3,
                    "items": [
                        {"description": "Task type (class or string)"},
                        {"type": "array", "description": "Task args", "items": {}},
                        {"type": "object", "description": "Task kwargs"},
                    ],
                },
                "minItems": 1,
            },
            "pooling_layers": {
                "type": "array",
                "description": "List of pooling layers, each as [type, args, kwargs]",
                "items": {
                    "type": "array",
                    "minItems": 1,
                    "maxItems": 3,
                    "items": [
                        {"description": "Pooling type (class, callable, or string)"},
                        {"type": "array", "description": "Pooling args", "items": {}},
                        {"type": "object", "description": "Pooling kwargs"},
                    ],
                },
            },
            "aggregate_pooling_type": {
                "description": "Type for aggregating pooling results (class, callable, or string)",
            },
            "aggregate_pooling_args": {
                "type": "array",
                "description": "Arguments for aggregate pooling",
                "items": {},
            },
            "aggregate_pooling_kwargs": {
                "type": "object",
                "description": "Keyword arguments for aggregate pooling",
            },
            "latent_model_type": {
                "description": "Type of a general latent space model. Alternative to pooling_layers + aggregate_pooling"
            },
            "latent_model_args": {
                "type": "array",
                "description": "Arguments for latent model",
                "items": {},
            },
            "latent_model_kwargs": {
                "type": "object",
                "description": "Keyword arguments for latent model",
            },
            "graph_features_net_type": {
                "description": "Type of graph features network (class or string)",
            },
            "graph_features_net_args": {
                "type": "array",
                "description": "Arguments for graph features network",
                "items": {},
            },
            "graph_features_net_kwargs": {
                "type": "object",
                "description": "Keyword arguments for graph features network",
            },
            "aggregate_graph_features_type": {
                "description": "Type for aggregating graph features (class, callable, or string)",
            },
            "aggregate_graph_features_args": {
                "type": "array",
                "description": "Arguments for aggregate graph features",
                "items": {},
            },
            "aggregate_graph_features_kwargs": {
                "type": "object",
                "description": "Keyword arguments for aggregate graph features",
            },
            "active_tasks": {
                "type": "object",
                "description": "Dictionary mapping task indices/names to boolean active status",
            },
        },
        "required": [
            "encoder_type",
            "downstream_tasks",
        ],
        "additionalProperties": False,
    }

    def __init__(
        self,
        encoder_type: type | torch.nn.Module,
        downstream_tasks: Sequence[
            Tuple[type | torch.nn.Module, Sequence[Any] | None, Dict[str, Any] | None]
        ],
        encoder_args: Sequence[Any] | None = None,
        encoder_kwargs: Dict[str, Any] | None = None,
        pooling_layers: Sequence[
            Tuple[type | torch.nn.Module, Sequence[Any] | None, Dict[str, Any] | None]
        ]
        | None = None,
        aggregate_pooling_type: type | torch.nn.Module | Callable | None = None,
        aggregate_pooling_args: Sequence[Any] | None = None,
        aggregate_pooling_kwargs: Dict[str, Any] | None = None,
        latent_model_type: type | torch.nn.Module | None = None,
        latent_model_args: Sequence[Any] | None = None,
        latent_model_kwargs: Dict[str, Any] | None = None,
        graph_features_net_type: type | torch.nn.Module | None = None,
        graph_features_net_args: Sequence[Any] | None = None,
        graph_features_net_kwargs: Dict[str, Any] | None = None,
        aggregate_graph_features_type: type | torch.nn.Module | Callable | None = None,
        aggregate_graph_features_args: Sequence[Any] | None = None,
        aggregate_graph_features_kwargs: Dict[str, Any] | None = None,
        active_tasks: Dict[int, bool] | None = None,
    ):
        """Initialize GNNModel with encoder, pooling, and downstream task components.

        Args:
            encoder_type (type): Class type or torch Module instance for the encoder network (e.g., GNN backbone).
            encoder_args (Sequence[Any]): Positional arguments to pass to encoder_type constructor.
            encoder_kwargs (Dict[str, Any]): Keyword arguments to pass to encoder_type constructor.
            downstream_tasks (Sequence[Sequence[type, Sequence[Any], Dict[str, Any]]]): List of downstream tasks,
                where each task is specified as [task_type, task_args, task_kwargs].
            pooling_layers (Sequence[Sequence[type, Sequence[Any], Dict[str, Any]]] | None, optional): List of pooling layers,
                where each layer is specified as [pooling_type, pooling_args, pooling_kwargs]. Defaults to None.
            aggregate_pooling_type (type | Callable | None, optional): Type, Module instance or function for aggregating multiple pooling outputs.
                Required if pooling_layers is provided. Defaults to None.
            aggregate_pooling_args (Sequence[Any] | None, optional): Positional arguments for aggregate_pooling_type. Defaults to None.
            aggregate_pooling_kwargs (Dict[str, Any] | None, optional): Keyword arguments for aggregate_pooling_type. Defaults to None.
            latent_model_type (type | torch.nn.Module | None): Latent model type. Either this or pooling_layers can be used, not both.
            latent_model_args (Sequence[Any] | None, optional): Latent model args.
            latent_model_kwargs (Dict[str, Any] | None, optional): Latent model kwargs.
            graph_features_net_type (type | None, optional): Network type for processing additional graph-level features. Defaults to None.
            graph_features_net_args (Sequence[Any] | None, optional): Positional arguments for graph_features_net_type. Defaults to None.
            graph_features_net_kwargs (Dict[str, Any] | None, optional): Keyword arguments for graph_features_net_type. Defaults to None.
            aggregate_graph_features_type (type | Callable | None, optional): Type, Module instance or function for combining embeddings with graph features.
                Required if graph_features_net_type is provided. Defaults to None.
            aggregate_graph_features_args (Sequence[Any] | None, optional): Positional arguments for aggregate_graph_features_type. Defaults to None.
            aggregate_graph_features_kwargs (Dict[str, Any] | None, optional): Keyword arguments for aggregate_graph_features_type. Defaults to None.
            active_tasks (Dict[int, bool] | None, optional): Dictionary mapping task indices to active status.
                If None, all tasks are active by default. Defaults to None.

        Raises:
            ValueError: If downstream_tasks is empty (at least one task required).
            ValueError: If pooling_layers provided without aggregate_pooling_type or vice versa.
            ValueError: If pooling_layers is empty when provided.
            ValueError: If graph_features_net_type provided without aggregate_graph_features_type or vice versa.
            ValueError: If pooling_layers and latent_model_type are given at the same time.
        """

        # check consistency
        graph_processors = [graph_features_net_type, aggregate_graph_features_type]
        if any([g is not None for g in graph_processors]) and not all(
            g is not None for g in graph_processors
        ):
            raise ValueError(
                "If graph features are to be used, both a graph features network and an aggregation method must be provided."
            )

        pooling_funcs = [aggregate_pooling_type, pooling_layers]
        self.with_pooling = False
        if any(p is not None for p in pooling_funcs):
            if not all(p is not None for p in pooling_funcs):
                raise ValueError(
                    "If pooling layers are to be used, both an aggregate pooling method and pooling layers must be provided."
                )
            else:
                self.with_pooling = True

        self.with_latent = False
        if latent_model_type is not None:
            self.with_latent = True

        if self.with_latent and self.with_pooling:
            raise ValueError(
                "Either latent_model_type or pooling_layers and aggregate_pooling can be given, not both"
            )

        # set up downstream tasks. These are independent of each other, but there must be one at least
        if len(downstream_tasks) == 0:
            raise ValueError("At least one downstream task must be provided.")

        super().__init__()

        self.encoder_args = encoder_args
        self.encoder_kwargs = encoder_kwargs
        self.downstream_task_specs = downstream_tasks
        self.pooling_layer_specs = pooling_layers

        self.aggregate_pooling_type = aggregate_pooling_type
        self.aggregate_pooling_args = aggregate_pooling_args
        self.aggregate_pooling_kwargs = aggregate_pooling_kwargs

        self.graph_features_net_args = graph_features_net_args
        self.graph_features_net_kwargs = graph_features_net_kwargs

        self.aggregate_graph_features_type = aggregate_graph_features_type
        self.aggregate_graph_features_args = aggregate_graph_features_args
        self.aggregate_graph_features_kwargs = aggregate_graph_features_kwargs

        # set up encoder type and downstream tasks
        self.encoder = instantiate_type(encoder_type, encoder_args, encoder_kwargs)

        downstream_task_modules = []
        for task_type, args, kwargs in downstream_tasks:
            downstream_task_modules.append(instantiate_type(task_type, args, kwargs))

        self.downstream_tasks = torch.nn.ModuleList(downstream_task_modules)

        if self.with_pooling:
            if pooling_layers is not None:
                if len(pooling_layers) == 0:
                    raise ValueError("At least one pooling layer must be provided.")

                self.pooling_layers = torch.nn.ModuleList(
                    [
                        instantiate_type(pl_type, args, kwargs)
                        for pl_type, args, kwargs in pooling_layers
                    ]
                )
            else:
                self.pooling_layers = None

            # aggregate pooling layer
            if aggregate_pooling_type is not None:
                self.aggregate_pooling = instantiate_type(
                    aggregate_pooling_type,
                    aggregate_pooling_args,
                    aggregate_pooling_kwargs,
                )
            else:
                self.aggregate_pooling = torch.nn.Identity()  # no-op

        if self.with_latent:
            # alternatively to pooling_layer, set up latent_model
            self.latent_model = instantiate_type(
                latent_model_type, latent_model_args, latent_model_kwargs
            )

        # set up graph features processing if provided
        if graph_features_net_type is not None:
            self.graph_features_net = instantiate_type(
                graph_features_net_type,
                graph_features_net_args,
                graph_features_net_kwargs,
            )
        else:
            self.graph_features_net = torch.nn.Identity()  # no-op

        if aggregate_graph_features_type is not None:
            self.aggregate_graph_features = instantiate_type(
                aggregate_graph_features_type,
                aggregate_graph_features_args,
                aggregate_graph_features_kwargs,
            )
        else:
            self.aggregate_graph_features = torch.nn.Identity()  # no-op

        # active tasks
        if active_tasks:
            if len(active_tasks) != len(self.downstream_tasks) or set(
                active_tasks.keys()
            ) != set(range(len(self.downstream_tasks))):
                raise ValueError(
                    "active_tasks keys must match the indices of downstream tasks."
                )
            self.active_tasks: Dict[int, bool] = active_tasks
        else:
            self.active_tasks = {i: True for i in range(0, len(self.downstream_tasks))}

    def set_task_active(self, key: int) -> None:
        """Set a downstream task as active.

        Args:
            key (int): Index of the downstream task to activate.
        """
        if key not in self.active_tasks:
            raise KeyError(f"Task {key} not found in active tasks.")
        self.active_tasks[key] = True

    def set_task_inactive(self, key: int) -> None:
        """Set a downstream task as inactive.

        Args:
            key (int): Index of the downstream task to deactivate.
        """

        if key not in self.active_tasks:
            raise KeyError(f"Task {key} not found in active tasks.")
        self.active_tasks[key] = False

    def get_embeddings(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor | None = None,
        gcn_kwargs: Dict[str, Any] | None = None,
        latent_args: Sequence[Any] | None = None,
        latent_kwargs: Dict[str, Any] | None = None,
    ) -> torch.Tensor:
        """Get embeddings from the GCN model.

        Args:
            x (torch.Tensor): Input node features.
            edge_index (torch.Tensor): Graph connectivity information.
            batch (torch.Tensor | None, optional): Batch vector for pooling. Defaults to None.
            gcn_kwargs (dict, optional): Additional arguments for the GCN. Defaults to None.
            latent_args (Sequence[Any] | None, optional): Positional arguments passed to the latent model, if one is configured. Defaults to None.
            latent_kwargs (Dict[str, Any] | None, optional): Keyword arguments passed to the latent model, if one is configured. Defaults to None.

        Returns:
            torch.Tensor: Embedding tensor for the graph.
        """
        # apply the GCN backbone to the node features
        embeddings = self.encoder(x, edge_index, **(gcn_kwargs if gcn_kwargs else {}))

        # pool everything together into a single graph representation
        if self.with_pooling:
            if not self.pooling_layers or self.pooling_layers == [None]:
                # No pooling layers provided; pass embeddings directly
                pooled_embeddings = [
                    embeddings,
                ]
            else:
                pooled_embeddings = [
                    pooling_op(embeddings, batch) if pooling_op else embeddings
                    for pooling_op in self.pooling_layers
                ]

            return self.aggregate_pooling(pooled_embeddings)

        elif self.with_latent:
            return self.latent_model(
                embeddings,
                *(latent_args if latent_args is not None else []),
                **(latent_kwargs if latent_kwargs is not None else {}),
            )

        # neither pooling nor a latent model is configured:
        # return the node-level embeddings directly
        return embeddings

    def compute_downstream_tasks(
        self,
        x: torch.Tensor,
        args: Sequence[Tuple | Sequence] | None = None,
        kwargs: Sequence[Dict[str, Any]] | None = None,
    ) -> Dict[int, torch.Tensor | Collection[torch.Tensor]]:
        """Compute the outputs of the downstream tasks. Only the active tasks will be computed.

        Args:
            x (torch.Tensor): Input embeddings tensor
            args (Sequence[Tuple | Sequence] | None, optional): Arguments for downstream tasks. Defaults to None.
            kwargs (Sequence[Dict[str, Any]] | None, optional): Keyword arguments for downstream tasks. Defaults to None.

        Returns:
            Dict[int, torch.Tensor | Collection[torch.Tensor]]: Outputs of the downstream tasks.
        """
        d_args = [[] for _ in self.downstream_tasks] if args is None else args
        d_kwargs: Sequence[Dict[str, Any]] = (
            [{} for _ in self.downstream_tasks] if kwargs is None else kwargs
        )
        return {
            i: self.downstream_tasks[i](x, *d_args[i], **d_kwargs[i])
            for i in range(0, len(self.downstream_tasks))
            if self.active_tasks[i]
        }

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
        graph_features: torch.Tensor | None = None,
        latent_args: Sequence[Any] | None = None,
        latent_kwargs: Dict[str, Any] | None = None,
        downstream_task_args: Sequence[Tuple | Sequence[Any]] | None = None,
        downstream_task_kwargs: Sequence[Dict[str, Any]] | None = None,
        embedding_kwargs: Dict[Any, Any] | None = None,
    ) -> Dict[int, torch.Tensor]:
        """Forward run of the gnn model with optional graph features.
        First execute the graph-neural network backbone, then process the graph features, and finally apply the downstream tasks.

        Args:
            x (torch.Tensor): Input node features.
            edge_index (torch.Tensor): Graph connectivity information.
            batch (torch.Tensor): Batch vector for pooling.
            graph_features (torch.Tensor | None, optional): Additional graph features. Defaults to None.
            downstream_task_args ( Sequence[Tuple | Sequence[Any]] | None, optional): Arguments for downstream tasks. Defaults to None.
            downstream_task_kwargs (Sequence[Dict[str, Any]] | None, optional): Keyword arguments for downstream tasks. Defaults to None.
            embedding_kwargs (dict[Any, Any] | None, optional): Additional arguments for the GCN. Defaults to None.

        Returns:
            Dict[int, torch.Tensor]: Raw output of downstream tasks.
        """
        # apply the GCN backbone to the node features
        embeddings = self.get_embeddings(
            x,
            edge_index,
            batch,
            gcn_kwargs=embedding_kwargs,
            latent_args=latent_args,
            latent_kwargs=latent_kwargs,
        )
        # If we have graph features, we need to process them and concatenate them with the node features
        if graph_features is not None:
            graph_features = self.graph_features_net(graph_features)
            embeddings = self.aggregate_graph_features(embeddings, graph_features)

        # downstream tasks are given out as is, no softmax or other assumptions
        return self.compute_downstream_tasks(
            embeddings,
            args=downstream_task_args,
            kwargs=downstream_task_kwargs,
        )

    @classmethod
    def from_config(cls, config: dict) -> "GNNModel":
        """Create a GNNModel instance from a configuration dictionary.

        Args:
            config (dict): Configuration dictionary with keys matching __init__ parameters.
                Must include: encoder_type, downstream_tasks.
                Optional: encoder_args, encoder_kwargs, pooling_layers, aggregate_pooling_type, graph_features_net_type, etc.

        Returns:
            GNNModel: An initialized GNNModel instance.

        Raises:
            RuntimeError: If model creation fails (wraps underlying exceptions).
            jsonschema.ValidationError: If config is invalid.
        """
        try:
            jsonschema.validate(config, cls.schema)
            return cls(
                config["encoder_type"],
                config["downstream_tasks"],
                encoder_args=config.get("encoder_args", None),
                encoder_kwargs=config.get("encoder_kwargs", None),
                pooling_layers=config.get("pooling_layers", None),
                aggregate_pooling_type=config.get("aggregate_pooling_type"),
                aggregate_pooling_args=config.get("aggregate_pooling_args"),
                aggregate_pooling_kwargs=config.get("aggregate_pooling_kwargs"),
                latent_model_type=config.get("latent_model_type"),
                latent_model_args=config.get("latent_model_args"),
                latent_model_kwargs=config.get("latent_model_kwargs"),
                graph_features_net_type=config.get("graph_features_net_type"),
                graph_features_net_args=config.get("graph_features_net_args"),
                graph_features_net_kwargs=config.get("graph_features_net_kwargs"),
                aggregate_graph_features_type=config.get(
                    "aggregate_graph_features_type"
                ),
                aggregate_graph_features_args=config.get(
                    "aggregate_graph_features_args"
                ),
                aggregate_graph_features_kwargs=config.get(
                    "aggregate_graph_features_kwargs"
                ),
                active_tasks=config.get("active_tasks"),
            )
        except Exception as e:
            raise RuntimeError(
                f"Error during creation of GNNModel from config: {e}"
            ) from e

    def save(self, path: str | Path) -> None:
        """Save the model state dictionary to file.

        Args:
            path (str | Path): File path where the model will be saved.
        """

        torch.save(
            self.state_dict(),
            path,
        )

    @classmethod
    def load(
        cls,
        path: str | Path,
        config: Dict[str, Any] | None = None,
        args: Sequence[Any] | None = None,
        kwargs: Dict[str, Any] | None = None,
        device: torch.device = torch.device("cpu"),
    ) -> "GNNModel":
        """Load a GNNModel from a file saved with the save() method that's defined by the provided config.
        It is assumed that the config used to save the model is the same as the one provided here or defines the
        same model architecture. Therefore, configs should always be saved alongside the model weights.

        Args:
            path (str | Path): Path to the saved model file.
            config (Dict[str, Any] | None): Config for building the model
            args (Sequence[Any] | None): Arguments for building the model if config is not supplied
            kwargs (Dict[str, Any] | None): Keyword arguments for building the model if config is not supplied
            device (torch.device, optional): Device to load the model onto. Defaults to torch.device("cpu").

        Returns:
            GNNModel: Fully initialized model instance with loaded weights.
        """
        if config is not None:
            model = cls.from_config(config).to(device)
        else:
            model = cls(
                *(args if args is not None else []),
                **(kwargs if kwargs is not None else {}),
            )
        model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
        return model

__init__(encoder_type, downstream_tasks, encoder_args=None, encoder_kwargs=None, pooling_layers=None, aggregate_pooling_type=None, aggregate_pooling_args=None, aggregate_pooling_kwargs=None, latent_model_type=None, latent_model_args=None, latent_model_kwargs=None, graph_features_net_type=None, graph_features_net_args=None, graph_features_net_kwargs=None, aggregate_graph_features_type=None, aggregate_graph_features_args=None, aggregate_graph_features_kwargs=None, active_tasks=None)

Initialize GNNModel with encoder, pooling, and downstream task components.

Parameters:

    encoder_type (type, required): Class type or torch Module instance for the encoder network (e.g., GNN backbone).
    encoder_args (Sequence[Any], default None): Positional arguments to pass to the encoder_type constructor.
    encoder_kwargs (Dict[str, Any], default None): Keyword arguments to pass to the encoder_type constructor.
    downstream_tasks (Sequence[Sequence[type, Sequence[Any], Dict[str, Any]]], required): List of downstream tasks, where each task is specified as [task_type, task_args, task_kwargs].
    pooling_layers (Sequence[Sequence[type, Sequence[Any], Dict[str, Any]]] | None, default None): List of pooling layers, where each layer is specified as [pooling_type, pooling_args, pooling_kwargs].
    aggregate_pooling_type (type | Callable | None, default None): Type, Module instance, or function for aggregating multiple pooling outputs. Required if pooling_layers is provided.
    aggregate_pooling_args (Sequence[Any] | None, default None): Positional arguments for aggregate_pooling_type.
    aggregate_pooling_kwargs (Dict[str, Any] | None, default None): Keyword arguments for aggregate_pooling_type.
    latent_model_type (type | Module | None, default None): Latent model type. Either this or pooling_layers can be used, not both.
    latent_model_args (Sequence[Any] | None, default None): Latent model args.
    latent_model_kwargs (Dict[str, Any] | None, default None): Latent model kwargs.
    graph_features_net_type (type | None, default None): Network type for processing additional graph-level features.
    graph_features_net_args (Sequence[Any] | None, default None): Positional arguments for graph_features_net_type.
    graph_features_net_kwargs (Dict[str, Any] | None, default None): Keyword arguments for graph_features_net_type.
    aggregate_graph_features_type (type | Callable | None, default None): Type, Module instance, or function for combining embeddings with graph features. Required if graph_features_net_type is provided.
    aggregate_graph_features_args (Sequence[Any] | None, default None): Positional arguments for aggregate_graph_features_type.
    aggregate_graph_features_kwargs (Dict[str, Any] | None, default None): Keyword arguments for aggregate_graph_features_type.
    active_tasks (Dict[int, bool] | None, default None): Dictionary mapping task indices to active status. If None, all tasks are active by default.

Raises:

    ValueError: If downstream_tasks is empty (at least one task required).
    ValueError: If pooling_layers is provided without aggregate_pooling_type or vice versa.
    ValueError: If pooling_layers is empty when provided.
    ValueError: If graph_features_net_type is provided without aggregate_graph_features_type or vice versa.
    ValueError: If pooling_layers and latent_model_type are given at the same time.

Source code in src/QuantumGrav/gnn_model.py
def __init__(
    self,
    encoder_type: type | torch.nn.Module,
    downstream_tasks: Sequence[
        Tuple[type | torch.nn.Module, Sequence[Any] | None, Dict[str, Any] | None]
    ],
    encoder_args: Sequence[Any] | None = None,
    encoder_kwargs: Dict[str, Any] | None = None,
    pooling_layers: Sequence[
        Tuple[type | torch.nn.Module, Sequence[Any] | None, Dict[str, Any] | None]
    ]
    | None = None,
    aggregate_pooling_type: type | torch.nn.Module | Callable | None = None,
    aggregate_pooling_args: Sequence[Any] | None = None,
    aggregate_pooling_kwargs: Dict[str, Any] | None = None,
    latent_model_type: type | torch.nn.Module | None = None,
    latent_model_args: Sequence[Any] | None = None,
    latent_model_kwargs: Dict[str, Any] | None = None,
    graph_features_net_type: type | torch.nn.Module | None = None,
    graph_features_net_args: Sequence[Any] | None = None,
    graph_features_net_kwargs: Dict[str, Any] | None = None,
    aggregate_graph_features_type: type | torch.nn.Module | Callable | None = None,
    aggregate_graph_features_args: Sequence[Any] | None = None,
    aggregate_graph_features_kwargs: Dict[str, Any] | None = None,
    active_tasks: Dict[int, bool] | None = None,
):
    """Initialize GNNModel with encoder, pooling, and downstream task components.

    Args:
        encoder_type (type): Class type or torch Module instance for the encoder network (e.g., GNN backbone).
        encoder_args (Sequence[Any]): Positional arguments to pass to encoder_type constructor.
        encoder_kwargs (Dict[str, Any]): Keyword arguments to pass to encoder_type constructor.
        downstream_tasks (Sequence[Sequence[type, Sequence[Any], Dict[str, Any]]]): List of downstream tasks,
            where each task is specified as [task_type, task_args, task_kwargs].
        pooling_layers (Sequence[Sequence[type, Sequence[Any], Dict[str, Any]]] | None, optional): List of pooling layers,
            where each layer is specified as [pooling_type, pooling_args, pooling_kwargs]. Defaults to None.
        aggregate_pooling_type (type | Callable | None, optional): Type, Module instance or function for aggregating multiple pooling outputs.
            Required if pooling_layers is provided. Defaults to None.
        aggregate_pooling_args (Sequence[Any] | None, optional): Positional arguments for aggregate_pooling_type. Defaults to None.
        aggregate_pooling_kwargs (Dict[str, Any] | None, optional): Keyword arguments for aggregate_pooling_type. Defaults to None.
        latent_model_type (type | torch.nn.Module | None): Latent model type. Either this or pooling_layers can be used, not both.
        latent_model_args (Sequence[Any] | None, optional): Latent model args.
        latent_model_kwargs (Dict[str, Any] | None, optional): Latent model kwargs.
        graph_features_net_type (type | None, optional): Network type for processing additional graph-level features. Defaults to None.
        graph_features_net_args (Sequence[Any] | None, optional): Positional arguments for graph_features_net_type. Defaults to None.
        graph_features_net_kwargs (Dict[str, Any] | None, optional): Keyword arguments for graph_features_net_type. Defaults to None.
        aggregate_graph_features_type (type | Callable | None, optional): Type, Module instance or function for combining embeddings with graph features.
            Required if graph_features_net_type is provided. Defaults to None.
        aggregate_graph_features_args (Sequence[Any] | None, optional): Positional arguments for aggregate_graph_features_type. Defaults to None.
        aggregate_graph_features_kwargs (Dict[str, Any] | None, optional): Keyword arguments for aggregate_graph_features_type. Defaults to None.
        active_tasks (Dict[int, bool] | None, optional): Dictionary mapping task indices to active status.
            If None, all tasks are active by default. Defaults to None.

    Raises:
        ValueError: If downstream_tasks is empty (at least one task required).
        ValueError: If pooling_layers provided without aggregate_pooling_type or vice versa.
        ValueError: If pooling_layers is empty when provided.
        ValueError: If graph_features_net_type provided without aggregate_graph_features_type or vice versa.
        ValueError: If pooling_layers and latent_model_type are given at the same time.
    """

    # check consistency
    graph_processors = [graph_features_net_type, aggregate_graph_features_type]
    if any([g is not None for g in graph_processors]) and not all(
        g is not None for g in graph_processors
    ):
        raise ValueError(
            "If graph features are to be used, both a graph features network and an aggregation method must be provided."
        )

    pooling_funcs = [aggregate_pooling_type, pooling_layers]
    self.with_pooling = False
    if any(p is not None for p in pooling_funcs):
        if not all(p is not None for p in pooling_funcs):
            raise ValueError(
                "If pooling layers are to be used, both an aggregate pooling method and pooling layers must be provided."
            )
        else:
            self.with_pooling = True

    self.with_latent = False
    if latent_model_type is not None:
        self.with_latent = True

    if self.with_latent and self.with_pooling:
        raise ValueError(
            "Either latent_model_type or pooling_layers and aggregate_pooling can be given, not both"
        )

    # set up downstream tasks. These are independent of each other, but there must be one at least
    if len(downstream_tasks) == 0:
        raise ValueError("At least one downstream task must be provided.")

    super().__init__()

    self.encoder_args = encoder_args
    self.encoder_kwargs = encoder_kwargs
    self.downstream_task_specs = downstream_tasks
    self.pooling_layer_specs = pooling_layers

    self.aggregate_pooling_type = aggregate_pooling_type
    self.aggregate_pooling_args = aggregate_pooling_args
    self.aggregate_pooling_kwargs = aggregate_pooling_kwargs

    self.graph_features_net_args = graph_features_net_args
    self.graph_features_net_kwargs = graph_features_net_kwargs

    self.aggregate_graph_features_type = aggregate_graph_features_type
    self.aggregate_graph_features_args = aggregate_graph_features_args
    self.aggregate_graph_features_kwargs = aggregate_graph_features_kwargs

    # set up encoder type and downstream tasks
    self.encoder = instantiate_type(encoder_type, encoder_args, encoder_kwargs)

    downstream_task_modules = []
    for task_type, args, kwargs in downstream_tasks:
        downstream_task_modules.append(instantiate_type(task_type, args, kwargs))

    self.downstream_tasks = torch.nn.ModuleList(downstream_task_modules)

    if self.with_pooling:
        if pooling_layers is not None:
            if len(pooling_layers) == 0:
                raise ValueError("At least one pooling layer must be provided.")

            self.pooling_layers = torch.nn.ModuleList(
                [
                    instantiate_type(pl_type, args, kwargs)
                    for pl_type, args, kwargs in pooling_layers
                ]
            )
        else:
            self.pooling_layers = None

        # aggregate pooling layer
        if aggregate_pooling_type is not None:
            self.aggregate_pooling = instantiate_type(
                aggregate_pooling_type,
                aggregate_pooling_args,
                aggregate_pooling_kwargs,
            )
        else:
            self.aggregate_pooling = torch.nn.Identity()  # no-op

    if self.with_latent:
        # alternatively to pooling_layer, set up latent_model
        self.latent_model = instantiate_type(
            latent_model_type, latent_model_args, latent_model_kwargs
        )

    # set up graph features processing if provided
    if graph_features_net_type is not None:
        self.graph_features_net = instantiate_type(
            graph_features_net_type,
            graph_features_net_args,
            graph_features_net_kwargs,
        )
    else:
        self.graph_features_net = torch.nn.Identity()  # no-op

    if aggregate_graph_features_type is not None:
        self.aggregate_graph_features = instantiate_type(
            aggregate_graph_features_type,
            aggregate_graph_features_args,
            aggregate_graph_features_kwargs,
        )
    else:
        self.aggregate_graph_features = torch.nn.Identity()  # no-op

    # active tasks
    if active_tasks:
        if len(active_tasks) != len(self.downstream_tasks) or set(
            active_tasks.keys()
        ) != set(range(len(self.downstream_tasks))):
            raise ValueError(
                "active_tasks keys must match the indices of downstream tasks."
            )
        self.active_tasks: Dict[int, bool] = active_tasks
    else:
        self.active_tasks = {i: True for i in range(0, len(self.downstream_tasks))}

compute_downstream_tasks(x, args=None, kwargs=None)

Compute the outputs of the downstream tasks. Only the active tasks will be computed.

Parameters:

    x (Tensor, required): Input embeddings tensor.
    args (Sequence[Tuple | Sequence] | None, default None): Arguments for downstream tasks.
    kwargs (Sequence[Dict[str, Any]] | None, default None): Keyword arguments for downstream tasks.

Returns:

    Dict[int, torch.Tensor | Collection[torch.Tensor]]: Outputs of the downstream tasks.

Source code in src/QuantumGrav/gnn_model.py
def compute_downstream_tasks(
    self,
    x: torch.Tensor,
    args: Sequence[Tuple | Sequence] | None = None,
    kwargs: Sequence[Dict[str, Any]] | None = None,
) -> Dict[int, torch.Tensor | Collection[torch.Tensor]]:
    """Compute the outputs of the downstream tasks. Only the active tasks will be computed.

    Args:
        x (torch.Tensor): Input embeddings tensor
        args (Sequence[Tuple | Sequence] | None, optional): Arguments for downstream tasks. Defaults to None.
        kwargs (Sequence[Dict[str, Any]] | None, optional): Keyword arguments for downstream tasks. Defaults to None.

    Returns:
        Dict[int, torch.Tensor | Collection[torch.Tensor]]: Outputs of the downstream tasks.
    """
    d_args = [[] for _ in self.downstream_tasks] if args is None else args
    d_kwargs: Sequence[Dict[str, Any]] = (
        [{} for _ in self.downstream_tasks] if kwargs is None else kwargs
    )
    return {
        i: self.downstream_tasks[i](x, *d_args[i], **d_kwargs[i])
        for i in range(0, len(self.downstream_tasks))
        if self.active_tasks[i]
    }
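As a usage sketch, assuming the model from the construction example at the top and node features x, edge_index, and batch from a PyTorch Geometric batch, the args/kwargs sequences are aligned positionally with the task indices:

embeddings = model.get_embeddings(x, edge_index, batch)

outputs = model.compute_downstream_tasks(
    embeddings,
    args=[[]],    # one positional-args sequence per downstream task, aligned by index
    kwargs=[{}],  # one kwargs dict per downstream task
)
logits = outputs[0]  # results are keyed by task index; inactive tasks are omitted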

forward(x, edge_index, batch, graph_features=None, latent_args=None, latent_kwargs=None, downstream_task_args=None, downstream_task_kwargs=None, embedding_kwargs=None)

Forward pass of the GNN model with optional graph features. First executes the graph neural network backbone, then processes the graph features, and finally applies the downstream tasks.

Parameters:

    x (Tensor, required): Input node features.
    edge_index (Tensor, required): Graph connectivity information.
    batch (Tensor, required): Batch vector for pooling.
    graph_features (Tensor | None, default None): Additional graph features.
    latent_args (Sequence[Any] | None, default None): Positional arguments passed to the latent model, if one is configured.
    latent_kwargs (Dict[str, Any] | None, default None): Keyword arguments passed to the latent model, if one is configured.
    downstream_task_args (Sequence[Tuple | Sequence[Any]] | None, default None): Arguments for downstream tasks.
    downstream_task_kwargs (Sequence[Dict[str, Any]] | None, default None): Keyword arguments for downstream tasks.
    embedding_kwargs (dict[Any, Any] | None, default None): Additional arguments for the GCN.

Returns:

    Dict[int, torch.Tensor]: Raw output of downstream tasks.

Source code in src/QuantumGrav/gnn_model.py
def forward(
    self,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    batch: torch.Tensor,
    graph_features: torch.Tensor | None = None,
    latent_args: Sequence[Any] | None = None,
    latent_kwargs: Dict[str, Any] | None = None,
    downstream_task_args: Sequence[Tuple | Sequence[Any]] | None = None,
    downstream_task_kwargs: Sequence[Dict[str, Any]] | None = None,
    embedding_kwargs: Dict[Any, Any] | None = None,
) -> Dict[int, torch.Tensor]:
    """Forward run of the gnn model with optional graph features.
    First execute the graph-neural network backbone, then process the graph features, and finally apply the downstream tasks.

    Args:
        x (torch.Tensor): Input node features.
        edge_index (torch.Tensor): Graph connectivity information.
        batch (torch.Tensor): Batch vector for pooling.
        graph_features (torch.Tensor | None, optional): Additional graph features. Defaults to None.
        downstream_task_args ( Sequence[Tuple | Sequence[Any]] | None, optional): Arguments for downstream tasks. Defaults to None.
        downstream_task_kwargs (Sequence[Dict[str, Any]] | None, optional): Keyword arguments for downstream tasks. Defaults to None.
        embedding_kwargs (dict[Any, Any] | None, optional): Additional arguments for the GCN. Defaults to None.

    Returns:
        Dict[int, torch.Tensor]: Raw output of downstream tasks.
    """
    # apply the GCN backbone to the node features
    embeddings = self.get_embeddings(
        x,
        edge_index,
        batch,
        gcn_kwargs=embedding_kwargs,
        latent_args=latent_args,
        latent_kwargs=latent_kwargs,
    )
    # If we have graph features, we need to process them and concatenate them with the node features
    if graph_features is not None:
        graph_features = self.graph_features_net(graph_features)
        embeddings = self.aggregate_graph_features(embeddings, graph_features)

    # downstream tasks are given out as is, no softmax or other assumptions
    return self.compute_downstream_tasks(
        embeddings,
        args=downstream_task_args,
        kwargs=downstream_task_kwargs,
    )
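A short end-to-end sketch, reusing the model from the construction example above (16 input node features) together with a small PyTorch Geometric batch:

import torch
from torch_geometric.data import Batch, Data

# Two tiny graphs batched together; node feature size matches the encoder's in_channels.
graphs = [
    Data(x=torch.randn(4, 16), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]])),
    Data(x=torch.randn(3, 16), edge_index=torch.tensor([[0, 1], [1, 2]])),
]
batch = Batch.from_data_list(graphs)

# Returns a dict mapping task index -> raw task output (no softmax applied).
outputs = model(batch.x, batch.edge_index, batch.batch)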

from_config(config) classmethod

Create a GNNModel instance from a configuration dictionary.

Parameters:

    config (dict, required): Configuration dictionary with keys matching __init__ parameters. Must include: encoder_type, downstream_tasks. Optional: encoder_args, encoder_kwargs, pooling_layers, aggregate_pooling_type, graph_features_net_type, etc.

Returns:

    GNNModel: An initialized GNNModel instance.

Raises:

    RuntimeError: If model creation fails (wraps underlying exceptions).
    jsonschema.ValidationError: If config is invalid.

Source code in src/QuantumGrav/gnn_model.py
@classmethod
def from_config(cls, config: dict) -> "GNNModel":
    """Create a GNNModel instance from a configuration dictionary.

    Args:
        config (dict): Configuration dictionary with keys matching __init__ parameters.
            Must include: encoder_type, downstream_tasks.
            Optional: encoder_args, encoder_kwargs, pooling_layers, aggregate_pooling_type, graph_features_net_type, etc.

    Returns:
        GNNModel: An initialized GNNModel instance.

    Raises:
        RuntimeError: If model creation fails (wraps underlying exceptions).
        jsonschema.ValidationError: If config is invalid.
    """
    try:
        jsonschema.validate(config, cls.schema)
        return cls(
            config["encoder_type"],
            config["downstream_tasks"],
            encoder_args=config.get("encoder_args", None),
            encoder_kwargs=config.get("encoder_kwargs", None),
            pooling_layers=config.get("pooling_layers", None),
            aggregate_pooling_type=config.get("aggregate_pooling_type"),
            aggregate_pooling_args=config.get("aggregate_pooling_args"),
            aggregate_pooling_kwargs=config.get("aggregate_pooling_kwargs"),
            latent_model_type=config.get("latent_model_type"),
            latent_model_args=config.get("latent_model_args"),
            latent_model_kwargs=config.get("latent_model_kwargs"),
            graph_features_net_type=config.get("graph_features_net_type"),
            graph_features_net_args=config.get("graph_features_net_args"),
            graph_features_net_kwargs=config.get("graph_features_net_kwargs"),
            aggregate_graph_features_type=config.get(
                "aggregate_graph_features_type"
            ),
            aggregate_graph_features_args=config.get(
                "aggregate_graph_features_args"
            ),
            aggregate_graph_features_kwargs=config.get(
                "aggregate_graph_features_kwargs"
            ),
            active_tasks=config.get("active_tasks"),
        )
    except Exception as e:
        raise RuntimeError(
            f"Error during creation of GNNModel from config: {e}"
        ) from e
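A sketch of such a config dict, assuming the hypothetical MLPHead from the construction example earlier. Since instantiate_type accepts classes, Module instances, or callables, Python objects go into the dict directly; keys not listed in the schema are rejected (additionalProperties is False):

import torch
import torch_geometric

def cat_pooled(pooled):  # illustrative aggregator: concatenate the pooled outputs
    return torch.cat(pooled, dim=-1)

config = {
    "encoder_type": torch_geometric.nn.GCN,
    "encoder_kwargs": {"in_channels": 16, "hidden_channels": 32, "num_layers": 2},
    "downstream_tasks": [[MLPHead, [32, 3], {}]],  # [task_type, task_args, task_kwargs]
    "pooling_layers": [[torch_geometric.nn.global_mean_pool, [], {}]],
    "aggregate_pooling_type": cat_pooled,
}
model = GNNModel.from_config(config)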

get_embeddings(x, edge_index, batch=None, gcn_kwargs=None, latent_args=None, latent_kwargs=None)

Get embeddings from the GCN model.

Parameters:

    x (Tensor, required): Input node features.
    edge_index (Tensor, required): Graph connectivity information.
    batch (Tensor | None, default None): Batch vector for pooling.
    gcn_kwargs (dict, default None): Additional arguments for the GCN.
    latent_args (Sequence[Any] | None, default None): Positional arguments passed to the latent model, if one is configured.
    latent_kwargs (Dict[str, Any] | None, default None): Keyword arguments passed to the latent model, if one is configured.

Returns:

    torch.Tensor: Embedding tensor for the graph.

Source code in src/QuantumGrav/gnn_model.py
def get_embeddings(
    self,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    batch: torch.Tensor | None = None,
    gcn_kwargs: Dict[str, Any] | None = None,
    latent_args: Sequence[Any] | None = None,
    latent_kwargs: Dict[str, Any] | None = None,
) -> torch.Tensor:
    """Get embeddings from the GCN model.

    Args:
        x (torch.Tensor): Input node features.
        edge_index (torch.Tensor): Graph connectivity information.
        batch (torch.Tensor | None, optional): Batch vector for pooling. Defaults to None.
        gcn_kwargs (dict, optional): Additional arguments for the GCN. Defaults to None.
        latent_args (Sequence[Any] | None, optional): Positional arguments passed to the latent model, if one is configured. Defaults to None.
        latent_kwargs (Dict[str, Any] | None, optional): Keyword arguments passed to the latent model, if one is configured. Defaults to None.

    Returns:
        torch.Tensor: Embedding tensor for the graph.
    """
    # apply the GCN backbone to the node features
    embeddings = self.encoder(x, edge_index, **(gcn_kwargs if gcn_kwargs else {}))

    # pool everything together into a single graph representation
    if self.with_pooling:
        if not self.pooling_layers or self.pooling_layers == [None]:
            # No pooling layers provided; pass embeddings directly
            pooled_embeddings = [
                embeddings,
            ]
        else:
            pooled_embeddings = [
                pooling_op(embeddings, batch) if pooling_op else embeddings
                for pooling_op in self.pooling_layers
            ]

        return self.aggregate_pooling(pooled_embeddings)

    elif self.with_latent:
        return self.latent_model(
            embeddings,
            *(latent_args if latent_args is not None else []),
            **(latent_kwargs if latent_kwargs is not None else {}),
        )

    # neither pooling nor a latent model is configured:
    # return the node-level embeddings directly
    return embeddings

load(path, config=None, args=None, kwargs=None, device=torch.device('cpu')) classmethod

Load a GNNModel, whose architecture is defined by the provided config, from a file saved with the save() method. It is assumed that the config used when the model was saved is the same as the one provided here, or that it defines the same model architecture. Therefore, configs should always be saved alongside the model weights.

Parameters:

    path (str | Path, required): Path to the saved model file.
    config (Dict[str, Any] | None, default None): Config for building the model.
    args (Sequence[Any] | None, default None): Arguments for building the model if config is not supplied.
    kwargs (Dict[str, Any] | None, default None): Keyword arguments for building the model if config is not supplied.
    device (torch.device, default torch.device("cpu")): Device to load the model onto.

Returns:

    GNNModel: Fully initialized model instance with loaded weights.

Source code in src/QuantumGrav/gnn_model.py
@classmethod
def load(
    cls,
    path: str | Path,
    config: Dict[str, Any] | None = None,
    args: Sequence[Any] | None = None,
    kwargs: Dict[str, Any] | None = None,
    device: torch.device = torch.device("cpu"),
) -> "GNNModel":
    """Load a GNNModel from a file saved with the save() method that's defined by the provided config.
    It is assumed that the config used to save the model is the same as the one provided here or defines the
    same model architecture. Therefore, configs should always be saved alongside the model weights.

    Args:
        path (str | Path): Path to the saved model file.
        config (Dict[str, Any] | None): Config for building the model
        args (Sequence[Any] | None): Arguments for building the model if config is not supplied
        kwargs (Dict[str, Any] | None): Keyword arguments for building the model if config is not supplied
        device (torch.device, optional): Device to load the model onto. Defaults to torch.device("cpu").

    Returns:
        GNNModel: Fully initialized model instance with loaded weights.
    """
    if config is not None:
        model = cls.from_config(config).to(device)
    else:
        model = cls(
            *(args if args is not None else []),
            **(kwargs if kwargs is not None else {}),
        )
    model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    return model
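A save/load round trip, reusing the config sketch from the from_config example above:

model.save("gnn_model.pt")

# Rebuild the architecture from the config, then restore the saved weights.
restored = GNNModel.load("gnn_model.pt", config=config)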

save(path)

Save the model state dictionary to file.

Parameters:

    path (str | Path, required): File path where the model will be saved.
Source code in src/QuantumGrav/gnn_model.py
def save(self, path: str | Path) -> None:
    """Save the model state dictionary to file.

    Args:
        path (str | Path): File path where the model will be saved.
    """

    torch.save(
        self.state_dict(),
        path,
    )

set_task_active(key)

Set a downstream task as active.

Parameters:

    key (int, required): Index of the downstream task to activate.
Source code in src/QuantumGrav/gnn_model.py
def set_task_active(self, key: int) -> None:
    """Set a downstream task as active.

    Args:
        key (int): Index of the downstream task to activate.
    """
    if key not in self.active_tasks:
        raise KeyError(f"Task {key} not found in active tasks.")
    self.active_tasks[key] = True

set_task_inactive(key)

Set a downstream task as inactive.

Parameters:

    key (int, required): Index of the downstream task to deactivate.
Source code in src/QuantumGrav/gnn_model.py
def set_task_inactive(self, key: int) -> None:
    """Set a downstream task as inactive.

    Args:
        key (int): Index of the downstream task to deactivate.
    """

    if key not in self.active_tasks:
        raise KeyError(f"Task {key} not found in active tasks.")
    self.active_tasks[key] = False
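For example, to skip a head during part of a multi-task training loop (task indices follow the order of downstream_tasks):

model.set_task_inactive(0)  # task 0 is now skipped by compute_downstream_tasks / forward
model.set_task_active(0)    # re-enable it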

ModuleWrapper

Bases: Module

Wrapper to make pooling functions compatible with ModuleList and ModuleDict

Source code in src/QuantumGrav/gnn_model.py
class ModuleWrapper(torch.nn.Module):
    """Wrapper to make pooling functions compatible with ModuleList and ModuleDict"""

    def __init__(self, fn: Callable):
        super().__init__()
        self.fn = fn

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self.fn(*args, **kwargs)

    def get_fn(self) -> Callable:
        return self.fn

instantiate_type(object_or_type, args, kwargs)

Helper to instantiate a type from args and kwargs, or to use a passed object directly. When a plain function is passed, it is wrapped in a ModuleWrapper instance.

Parameters:

    object_or_type (type | Module | Callable, required): Type or object to check and instantiate.
    args (Sequence[Any] | None, required): Args to build the object.
    kwargs (Dict[str, Any] | None, required): Kwargs to build the object.

Raises:

    ValueError: When object_or_type is neither a torch.nn.Module subclass or instance nor a callable.

Returns:

    Newly instantiated object of type object_or_type, or the passed object.

Source code in src/QuantumGrav/gnn_model.py
def instantiate_type(
    object_or_type: type | torch.nn.Module | Callable,
    args: Sequence[Any] | None,
    kwargs: Dict[str, Any] | None,
):
    """Helper to instantiate a type from args, kwargs or use it directly. When a function is passed, it will be wrapped in a `ModuleWrapper` instance

    Args:
        object_or_type (type | torch.nn.Module | Callable): type or object to check and instantiate
        args (Sequence[Any] | None): args to build the object
        kwargs (Dict[str, Any] | None): kwargs to build the object

    Raises:
        ValueError: When object_or_type is neither a torch.nn.Module subclass or instance nor a callable

    Returns:
         newly instantiated object of type 'object_or_type' or the passed object
    """

    if isinstance(object_or_type, torch.nn.Module):
        return object_or_type
    elif isclass(object_or_type) and issubclass(object_or_type, torch.nn.Module):
        return object_or_type(*(args if args else []), **(kwargs if kwargs else {}))
    elif callable(object_or_type):
        return ModuleWrapper(object_or_type)
    else:
        raise ValueError(
            f"{object_or_type} must be either a subtype of torch.nn.Module or an instance of such a type or a callable"
        )
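The three accepted inputs, as a short sketch (ModuleWrapper is the wrapper class defined above):

import torch
import torch_geometric

# A class is instantiated with the given args/kwargs.
head = instantiate_type(torch.nn.Linear, [32, 3], None)

# An existing Module instance is returned unchanged.
assert instantiate_type(head, None, None) is head

# A plain function is wrapped so it can live inside a ModuleList/ModuleDict.
pool = instantiate_type(torch_geometric.nn.global_mean_pool, None, None)
assert isinstance(pool, ModuleWrapper)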

Graph neural network submodels

The submodel classes in this section comprise the graph neural network backbone of a QuantumGrav model.

Graph model block

This submodel is the main part of the graph neural network backbone, composed of a set of GNN layers from pytorch-geometric with dropout and BatchNorm.
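A minimal construction sketch using the defaults visible in the signature below (GCNConv layer, identity normalizer, skip connection enabled); the import path is inferred from the "Source code" note:

from QuantumGrav.models.gnn_block import GNNBlock  # path assumed from src/QuantumGrav/models/gnn_block.py

block = GNNBlock(in_dim=16, out_dim=32, dropout=0.1)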

GNNBlock

Bases: Module, Configurable

Graph Neural Network Block. Consists of a GNN layer, a normalizer, an activation function, and a residual connection. The GNN layer is applied first, followed by the normalizer and activation function. The result is then projected from the input dimensions to the output dimensions using a linear layer and added to the original input (residual connection). Finally, dropout is applied for regularization.

Source code in src/QuantumGrav/models/gnn_block.py
class GNNBlock(torch.nn.Module, base.Configurable):
    """Graph Neural Network Block. Consists of a GNN layer, a normalizer, an activation function,
    and a residual connection. The gnn-layer is applied first, followed by the normalizer and activation function. The result is then projected from the input dimensions to the output dimensions using a linear layer and added to the original input (residual connection). Finally, dropout is applied for regularization.
    """

    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "title": "GNNBlock Configuration",
        "type": "object",
        "properties": {
            "in_dim": {
                "type": "integer",
                "description": "input feature size",
                "minimum": 0,
            },
            "out_dim": {
                "type": "integer",
                "description": "output feature size",
                "minimum": 0,
            },
            "dropout": {
                "type": "number",
                "description": "dropout fraction",
                "minimum": 0.0,
                "maximum": 1.0,
            },
            "with_skip": {
                "type": "boolean",
                "description": "Whether a skip connection should be used or not",
            },
            "gnn_layer_type": {
                "description": "type of the graph convolution layer",
            },
            "gnn_layer_args": {
                "type": "array",
                "description": "Arguments of the gcn layer",
                "items": {},
            },
            "gnn_layer_kwargs": {
                "type": "object",
                "description": "Keyword arguments for the gcn layer",
            },
            "normalizer_type": {
                "description": "type of the normalizer module, e.g. BatchNorm",
            },
            "norm_args": {
                "type": "array",
                "description": "Arguments of the normalization layer",
                "items": {},
            },
            "norm_kwargs": {
                "type": "object",
                "description": "Keyword arguments for the normalization layer",
            },
            "activation_type": {
                "description": "type of the activation function",
            },
            "activation_args": {
                "type": "array",
                "description": "Arguments of the activation layer",
                "items": {},
            },
            "activation_kwargs": {
                "type": "object",
                "description": "Keyword arguments for the activation layer",
            },
            "skip_args": {
                "type": "array",
                "description": "Arguments of the skip connection layer",
                "items": {},
            },
            "skip_kwargs": {
                "type": "object",
                "description": "Keyword arguments for the skip connection layer",
            },
        },
        "required": ["in_dim", "out_dim"],
        "additionalProperties": False,
    }

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.3,
        with_skip: bool = True,
        gnn_layer_type: type[torch.nn.Module] = torch_geometric.nn.conv.GCNConv,
        gnn_layer_args: list[Any] | None = None,
        gnn_layer_kwargs: Dict[str, Any] | None = None,
        normalizer_type: type[torch.nn.Module] = torch.nn.Identity,
        norm_args: list[Any] | None = None,
        norm_kwargs: Dict[str, Any] | None = None,
        activation_type: type[torch.nn.Module] = torch.nn.ReLU,
        activation_args: list[Any] | None = None,
        activation_kwargs: Dict[str, Any] | None = None,
        skip_args: list[Any] | None = None,
        skip_kwargs: Dict[str, Any] | None = None,
    ):
        """Create a GNNBlock instance.

        Args:
            in_dim (int): The dimensions of the input features.
            out_dim (int): The dimensions of the output features.
            dropout (float, optional): The dropout probability. Defaults to 0.3.
            with_skip (bool, optional): Whether to use a skip connection. Defaults to True.

            gnn_layer_type (type[torch.nn.Module], optional): The type of GNN layer to use. Defaults to torch_geometric.nn.conv.GCNConv.
            gnn_layer_args (list[Any], optional): Additional arguments for the GNN layer. Defaults to None.
            gnn_layer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the GNN layer. Defaults to None.

            normalizer_type (type[torch.nn.Module], optional): The type of normalizer layer to use. Defaults to torch.nn.Identity.
            norm_args (list[Any], optional): Additional arguments for the normalizer layer. Defaults to None.
            norm_kwargs (Dict[str, Any], optional): Additional keyword arguments for the normalizer layer. Defaults to None.

            activation_type (type[torch.nn.Module], optional): The type of activation function to use. Defaults to torch.nn.ReLU.
            activation_args (list[Any], optional): Additional arguments for the activation function. Defaults to None.
            activation_kwargs (Dict[str, Any], optional): Additional keyword arguments for the activation function. Defaults to None.

            skip_args (list[Any], optional): Additional arguments for the projection layer. Defaults to None.
            skip_kwargs (Dict[str, Any], optional): Additional keyword arguments for the projection layer. Defaults to None.
        """
        super().__init__()

        # save parameters
        self.dropout_p = dropout
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.with_skip = with_skip
        # save args/kwargs
        self.gnn_layer_args = gnn_layer_args
        self.gnn_layer_kwargs = gnn_layer_kwargs
        self.norm_args = norm_args
        self.norm_kwargs = norm_kwargs
        self.activation_args = activation_args
        self.activation_kwargs = activation_kwargs
        self.skip_args = skip_args
        self.skip_kwargs = skip_kwargs

        # initialize layers
        self.dropout = torch.nn.Dropout(p=dropout, inplace=False)

        self.normalizer = normalizer_type(
            *(norm_args if norm_args is not None else []),
            **(norm_kwargs if norm_kwargs is not None else {}),
        )

        self.activation = activation_type(
            *(activation_args if activation_args is not None else []),
            **(activation_kwargs if activation_kwargs is not None else {}),
        )

        self.conv = gnn_layer_type(
            in_dim,
            out_dim,
            *(gnn_layer_args if gnn_layer_args is not None else []),
            **(gnn_layer_kwargs if gnn_layer_kwargs is not None else {}),
        )

        if self.skip_kwargs is None:
            self.skip_kwargs = {}

        if self.skip_args is None:
            self.skip_args = [in_dim, out_dim]

        if with_skip:
            self.skip = skipconnection.SkipConnection(
                *(self.skip_args if self.skip_args else [in_dim, out_dim]),
                **(self.skip_kwargs if self.skip_kwargs else {}),
            )

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Forward pass for the GNNBlock.
        First apply the graph convolution layer, then normalize and apply the activation function.
        Finally, apply a residual connection and dropout.
        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The graph connectivity information.
            kwargs (dict[Any, Any], optional): Additional keyword arguments forwarded to the GNN layer, e.g. edge_weight.

        Returns:
            torch.Tensor: The output node features.
        """

        # convolution, then normalize and apply nonlinearity
        x_res = self.conv(x, edge_index, **kwargs)
        x_res = self.normalizer(x_res)
        x_res = self.activation(x_res)

        # Residual connection
        if self.with_skip:
            x_res = self.skip(x, x_res)

        # Apply dropout as regularization
        x_res = self.dropout(x_res)

        return x_res

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GNNBlock":
        """Create a GNNBlock from a configuration dictionary.
        When the config does not have 'dropout', it defaults to 0.3.

        Args:
            config (dict[str, Any]): Configuration dictionary containing the parameters for the GNNBlock.

        Returns:
            GNNBlock: An instance of GNNBlock initialized with the provided configuration.
        """
        validate(config, cls.schema)

        try:
            return cls(
                in_dim=config["in_dim"],
                out_dim=config["out_dim"],
                dropout=config.get("dropout", 0.3),
                with_skip=config.get("with_skip", True),
                gnn_layer_type=config["gnn_layer_type"],
                normalizer_type=config["normalizer_type"],
                activation_type=config["activation_type"],
                gnn_layer_args=config.get("gnn_layer_args", []),
                gnn_layer_kwargs=config.get("gnn_layer_kwargs", {}),
                norm_args=config.get("norm_args", []),
                norm_kwargs=config.get("norm_kwargs", {}),
                activation_args=config.get("activation_args", []),
                activation_kwargs=config.get("activation_kwargs", {}),
                skip_args=config.get("skip_args", None),
                skip_kwargs=config.get("skip_kwargs", None),
            )

        except Exception as e:
            raise RuntimeError(f"Error while building GNNBlock from config: {e}") from e

    def to_config(self) -> dict[str, Any]:
        """Convert the GNNBlock instance to a configuration dictionary."""
        config = {
            "in_dim": self.in_dim,
            "out_dim": self.out_dim,
            "dropout": self.dropout.p,
            "with_skip": self.with_skip,
            "gnn_layer_type": f"{type(self.conv).__module__}.{type(self.conv).__name__}",
            "gnn_layer_args": self.gnn_layer_args
            if self.gnn_layer_args is not None
            else [],
            "gnn_layer_kwargs": self.gnn_layer_kwargs
            if self.gnn_layer_kwargs is not None
            else {},
            "normalizer_type": f"{type(self.normalizer).__module__}.{type(self.normalizer).__name__}",
            "norm_args": self.norm_args if self.norm_args is not None else [],
            "norm_kwargs": self.norm_kwargs if self.norm_kwargs is not None else {},
            "activation_type": f"{type(self.activation).__module__}.{type(self.activation).__name__}",
            "activation_args": self.activation_args
            if self.activation_args is not None
            else [],
            "activation_kwargs": self.activation_kwargs
            if self.activation_kwargs is not None
            else {},
            "skip_args": self.skip_args if self.skip_args is not None else [],
            "skip_kwargs": self.skip_kwargs if self.skip_kwargs is not None else {},
        }
        return config

    def save(self, path: str | Path) -> None:
        """Save the model's state to file.

        Args:
            path (str | Path): path to save the model to.
        """

        self_as_cfg = self.to_config()

        torch.save({"config": self_as_cfg, "state_dict": self.state_dict()}, path)

    @classmethod
    def load(
        cls, path: str | Path, device: torch.device = torch.device("cpu")
    ) -> "GNNBlock":
        """Load a mode instance from file

        Args:
            path (str | Path): Path to the file to load.
            device (torch.device): device to put the model on. Defaults to torch.device("cpu").
        Returns:
            GNNBlock: A GNNBlock instance initialized from the data loaded from the file.
        """

        modeldata = torch.load(path, weights_only=False)

        cfg = modeldata["config"]
        cfg["gnn_layer_type"] = utils.import_and_get(cfg["gnn_layer_type"])
        cfg["normalizer_type"] = utils.import_and_get(cfg["normalizer_type"])
        cfg["activation_type"] = utils.import_and_get(cfg["activation_type"])

        model = cls.from_config(cfg).to(device)
        model.load_state_dict(modeldata["state_dict"])
        return model
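
A minimal usage sketch (dimensions and connectivity chosen for illustration; relies on the defaults GCNConv, Identity, and ReLU, and on the default SkipConnection projecting from in_dim to out_dim):

import torch
import torch_geometric

block = GNNBlock(in_dim=16, out_dim=32, dropout=0.1)

x = torch.randn(10, 16)                     # 10 nodes with 16 features
edge_index = torch.randint(0, 10, (2, 40))  # random connectivity, illustration only
out = block(x, edge_index)                  # shape: (10, 32)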

__init__(in_dim, out_dim, dropout=0.3, with_skip=True, gnn_layer_type=torch_geometric.nn.conv.GCNConv, gnn_layer_args=None, gnn_layer_kwargs=None, normalizer_type=torch.nn.Identity, norm_args=None, norm_kwargs=None, activation_type=torch.nn.ReLU, activation_args=None, activation_kwargs=None, skip_args=None, skip_kwargs=None)

Create a GNNBlock instance.

Parameters:

- in_dim (int): The dimensions of the input features. Required.
- out_dim (int): The dimensions of the output features. Required.
- dropout (float): The dropout probability. Default: 0.3.
- with_skip (bool): Whether to use a skip connection. Default: True.
- gnn_layer_type (type[Module]): The type of GNN layer to use. Default: torch_geometric.nn.conv.GCNConv.
- gnn_layer_args (list[Any] | None): Additional arguments for the GNN layer. Default: None.
- gnn_layer_kwargs (Dict[str, Any] | None): Additional keyword arguments for the GNN layer. Default: None.
- normalizer_type (type[Module]): The type of normalizer layer to use. Default: torch.nn.Identity.
- norm_args (list[Any] | None): Additional arguments for the normalizer layer. Default: None.
- norm_kwargs (Dict[str, Any] | None): Additional keyword arguments for the normalizer layer. Default: None.
- activation_type (type[Module]): The type of activation function to use. Default: torch.nn.ReLU.
- activation_args (list[Any] | None): Additional arguments for the activation function. Default: None.
- activation_kwargs (Dict[str, Any] | None): Additional keyword arguments for the activation function. Default: None.
- skip_args (list[Any] | None): Additional arguments for the skip projection layer. Default: None.
- skip_kwargs (Dict[str, Any] | None): Additional keyword arguments for the skip projection layer. Default: None.
Source code in src/QuantumGrav/models/gnn_block.py
def __init__(
    self,
    in_dim: int,
    out_dim: int,
    dropout: float = 0.3,
    with_skip: bool = True,
    gnn_layer_type: type[torch.nn.Module] = torch_geometric.nn.conv.GCNConv,
    gnn_layer_args: list[Any] | None = None,
    gnn_layer_kwargs: Dict[str, Any] | None = None,
    normalizer_type: type[torch.nn.Module] = torch.nn.Identity,
    norm_args: list[Any] | None = None,
    norm_kwargs: Dict[str, Any] | None = None,
    activation_type: type[torch.nn.Module] = torch.nn.ReLU,
    activation_args: list[Any] | None = None,
    activation_kwargs: Dict[str, Any] | None = None,
    skip_args: list[Any] | None = None,
    skip_kwargs: Dict[str, Any] | None = None,
):
    """Create a GNNBlock instance.

    Args:
        in_dim (int): The dimensions of the input features.
        out_dim (int): The dimensions of the output features.
        dropout (float, optional): The dropout probability. Defaults to 0.3.
        with_skip (bool, optional): Whether to use a skip connection. Defaults to True.

        gnn_layer_type (type[torch.nn.Module], optional): The type of GNN layer to use. Defaults to torch_geometric.nn.conv.GCNConv.
        gnn_layer_args (list[Any], optional): Additional arguments for the GNN layer. Defaults to None.
        gnn_layer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the GNN layer. Defaults to None.

        normalizer_type (type[torch.nn.Module], optional): The type of normalizer layer to use. Defaults to torch.nn.Identity.
        norm_args (list[Any], optional): Additional arguments for the normalizer layer. Defaults to None.
        norm_kwargs (Dict[str, Any], optional): Additional keyword arguments for the normalizer layer. Defaults to None.

        activation_type (type[torch.nn.Module], optional): The type of activation function to use. Defaults to torch.nn.ReLU.
        activation_args (list[Any], optional): Additional arguments for the activation function. Defaults to None.
        activation_kwargs (Dict[str, Any], optional): Additional keyword arguments for the activation function. Defaults to None.

        skip_args (list[Any], optional): Additional arguments for the projection layer. Defaults to None.
        skip_kwargs (Dict[str, Any], optional): Additional keyword arguments for the projection layer. Defaults to None.
    """
    super().__init__()

    # save parameters
    self.dropout_p = dropout
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.with_skip = with_skip
    # save args/kwargs
    self.gnn_layer_args = gnn_layer_args
    self.gnn_layer_kwargs = gnn_layer_kwargs
    self.norm_args = norm_args
    self.norm_kwargs = norm_kwargs
    self.activation_args = activation_args
    self.activation_kwargs = activation_kwargs
    self.skip_args = skip_args
    self.skip_kwargs = skip_kwargs

    # initialize layers
    self.dropout = torch.nn.Dropout(p=dropout, inplace=False)

    self.normalizer = normalizer_type(
        *(norm_args if norm_args is not None else []),
        **(norm_kwargs if norm_kwargs is not None else {}),
    )

    self.activation = activation_type(
        *(activation_args if activation_args is not None else []),
        **(activation_kwargs if activation_kwargs is not None else {}),
    )

    self.conv = gnn_layer_type(
        in_dim,
        out_dim,
        *(gnn_layer_args if gnn_layer_args is not None else []),
        **(gnn_layer_kwargs if gnn_layer_kwargs is not None else {}),
    )

    if self.skip_kwargs is None:
        self.skip_kwargs = {}

    if self.skip_args is None:
        self.skip_args = [in_dim, out_dim]

    if with_skip:
        self.skip = skipconnection.SkipConnection(
            *(self.skip_args if self.skip_args else [in_dim, out_dim]),
            **(self.skip_kwargs if self.skip_kwargs else {}),
        )

forward(x, edge_index, **kwargs)

Forward pass for the GNNBlock. First apply the graph convolution layer, then normalize and apply the activation function. Finally, apply the residual connection and dropout.

Parameters:

- x (Tensor): The input node features. Required.
- edge_index (Tensor): The graph connectivity information. Required.
- kwargs: Additional keyword arguments forwarded to the GNN layer, e.g. edge_weight.

Returns:

- torch.Tensor: The output node features.

Source code in src/QuantumGrav/models/gnn_block.py
def forward(
    self, x: torch.Tensor, edge_index: torch.Tensor, **kwargs
) -> torch.Tensor:
    """Forward pass for the GNNBlock.
    First apply the graph convolution layer, then normalize and apply the activation function.
    Finally, apply a residual connection and dropout.
    Args:
        x (torch.Tensor): The input node features.
        edge_index (torch.Tensor): The graph connectivity information.
        kwargs (dict[Any, Any], optional): Additional keyword arguments forwarded to the GNN layer, e.g. edge_weight.

    Returns:
        torch.Tensor: The output node features.
    """

    # convolution, then normalize and apply nonlinearity
    x_res = self.conv(x, edge_index, **kwargs)
    x_res = self.normalizer(x_res)
    x_res = self.activation(x_res)

    # Residual connection
    if self.with_skip:
        x_res = self.skip(x, x_res)

    # Apply dropout as regularization
    x_res = self.dropout(x_res)

    return x_res

from_config(config) classmethod

Create a GNNBlock from a configuration dictionary. When the config does not have 'dropout', it defaults to 0.3.

Parameters:

- config (dict[str, Any]): Configuration dictionary containing the parameters for the GNNBlock. Required.

Returns:

- GNNBlock: An instance of GNNBlock initialized with the provided configuration.

Source code in src/QuantumGrav/models/gnn_block.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GNNBlock":
    """Create a GNNBlock from a configuration dictionary.
    When the config does not have 'dropout', it defaults to 0.3.

    Args:
        config (dict[str, Any]): Configuration dictionary containing the parameters for the GNNBlock.

    Returns:
        GNNBlock: An instance of GNNBlock initialized with the provided configuration.
    """
    validate(config, cls.schema)

    try:
        return cls(
            in_dim=config["in_dim"],
            out_dim=config["out_dim"],
            dropout=config.get("dropout", 0.3),
            with_skip=config.get("with_skip", True),
            gnn_layer_type=config["gnn_layer_type"],
            normalizer_type=config["normalizer_type"],
            activation_type=config["activation_type"],
            gnn_layer_args=config.get("gnn_layer_args", []),
            gnn_layer_kwargs=config.get("gnn_layer_kwargs", {}),
            norm_args=config.get("norm_args", []),
            norm_kwargs=config.get("norm_kwargs", {}),
            activation_args=config.get("activation_args", []),
            activation_kwargs=config.get("activation_kwargs", {}),
            skip_args=config.get("skip_args", None),
            skip_kwargs=config.get("skip_kwargs", None),
        )

    except Exception as e:
        raise RuntimeError(f"Error while building GNNBlock from config: {e}") from e

load(path, device=torch.device('cpu')) classmethod

Load a model instance from a file.

Parameters:

- path (str | Path): Path to the file to load. Required.
- device (device): device to put the model on. Default: device('cpu').

Returns:

- GNNBlock: A GNNBlock instance initialized from the data loaded from the file.

Source code in src/QuantumGrav/models/gnn_block.py
@classmethod
def load(
    cls, path: str | Path, device: torch.device = torch.device("cpu")
) -> "GNNBlock":
    """Load a mode instance from file

    Args:
        path (str | Path): Path to the file to load.
        device (torch.device): device to put the model on. Defaults to torch.device("cpu").
    Returns:
        GNNBlock: A GNNBlock instance initialized from the data loaded from the file.
    """

    modeldata = torch.load(path, weights_only=False)

    cfg = modeldata["config"]
    cfg["gnn_layer_type"] = utils.import_and_get(cfg["gnn_layer_type"])
    cfg["normalizer_type"] = utils.import_and_get(cfg["normalizer_type"])
    cfg["activation_type"] = utils.import_and_get(cfg["activation_type"])

    model = cls.from_config(cfg).to(device)
    model.load_state_dict(modeldata["state_dict"])
    return model
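
Together with save, this allows a simple round trip (a sketch, assuming a GNNBlock instance named block):

block.save("gnn_block.pt")
restored = GNNBlock.load("gnn_block.pt", device=torch.device("cpu"))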

save(path)

Save the model's state to file.

Parameters:

- path (str | Path): path to save the model to. Required.
Source code in src/QuantumGrav/models/gnn_block.py
def save(self, path: str | Path) -> None:
    """Save the model's state to file.

    Args:
        path (str | Path): path to save the model to.
    """

    self_as_cfg = self.to_config()

    torch.save({"config": self_as_cfg, "state_dict": self.state_dict()}, path)

to_config()

Convert the GNNBlock instance to a configuration dictionary.

Source code in src/QuantumGrav/models/gnn_block.py
def to_config(self) -> dict[str, Any]:
    """Convert the GNNBlock instance to a configuration dictionary."""
    config = {
        "in_dim": self.in_dim,
        "out_dim": self.out_dim,
        "dropout": self.dropout.p,
        "with_skip": self.with_skip,
        "gnn_layer_type": f"{type(self.conv).__module__}.{type(self.conv).__name__}",
        "gnn_layer_args": self.gnn_layer_args
        if self.gnn_layer_args is not None
        else [],
        "gnn_layer_kwargs": self.gnn_layer_kwargs
        if self.gnn_layer_kwargs is not None
        else {},
        "normalizer_type": f"{type(self.normalizer).__module__}.{type(self.normalizer).__name__}",
        "norm_args": self.norm_args if self.norm_args is not None else [],
        "norm_kwargs": self.norm_kwargs if self.norm_kwargs is not None else {},
        "activation_type": f"{type(self.activation).__module__}.{type(self.activation).__name__}",
        "activation_args": self.activation_args
        if self.activation_args is not None
        else [],
        "activation_kwargs": self.activation_kwargs
        if self.activation_kwargs is not None
        else {},
        "skip_args": self.skip_args if self.skip_args is not None else [],
        "skip_kwargs": self.skip_kwargs if self.skip_kwargs is not None else {},
    }
    return config

Base class for models composed of linear layers

LinearSequential

Bases: Module, Configurable

This class implements a neural network block consisting of a sequence of linear layers interspersed with activation functions. It can act as the backbone or the output head of a classifier (including heads for multi-objective classification), and more generally supports any sequential processing built from linear layers.

Source code in src/QuantumGrav/models/linear_sequential.py
class LinearSequential(torch.nn.Module, base.Configurable):
    """This class implements a neural network block consisting of a backbone
    (a sequence of linear layers with activation functions) and multiple
    output layers for classification tasks. It supports multi-objective
    classification by allowing multiple output layers, each corresponding
    to a different classification task, but can also be used for any other type of sequential processing that involves linear layers.
    """

    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "title": "LinearSequential Configuration",
        "type": "object",
        "properties": {
            "dims": {
                "type": "array",
                "description": "(input_channels, output_channels) for each layer",
                "items": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "minItems": 2,
                    "maxItems": 2,
                },
                "minItems": 1,
            },
            "activations": {
                "type": "array",
                "description": "list of activation function objects, e.g. torch.nn.ReLU",
                "items": {},
            },
            "linear_kwargs": {
                "type": "array",
                "description": "keyword arguments to torch_geometric.nn.dense.Linear",
                "items": {"type": "object"},
            },
            "activation_kwargs": {
                "type": "array",
                "description": "keyword arguments to the activation function objects",
                "items": {"type": "object"},
            },
        },
        "required": ["dims", "activations"],
        "additionalProperties": False,  # Only allow specified properties
    }

    def __init__(
        self,
        dims: list[Sequence[int]],
        activations: list[type[torch.nn.Module]],
        linear_kwargs: list[Dict[Any, Any]] | None = None,
        activation_kwargs: list[Dict[Any, Any]] | None = None,
    ):
        """Create a LinearSequential object containing a sequence of MLPs of type torch_geometric.nn.dense.Linear, interspersed with activation functions. All layers are of type `Linear` with an activation function in between (the backbone) and a set of linear output layers.

        Args:
            input_dim (int): input dimension of the LinearSequential object
            output_dim (int): output dimension for the output layer, i.e., the classification task
            hidden_dims (list[int]): list of hidden dimensions for the backbone
            activation (type[torch.nn.Module], optional): activation function to use. Defaults to torch.nn.ReLU.
            backbone_kwargs (list[dict], optional): additional arguments for the backbone layers. Defaults to None.
            output_kwargs (dict, optional): additional keyword arguments for the output layers. Defaults to None.

        Raises:
            ValueError: If hidden_dims contains non-positive integers.
            ValueError: If output_dim is a non-positive integer.
        """
        super().__init__()

        if len(dims) == 0:
            raise ValueError("dims must not be empty")

        if len(dims) != len(activations):
            raise ValueError("dims and activations must have the same length")

        if linear_kwargs is None:
            linear_kwargs = [{} for _ in range(len(dims))]

        if activation_kwargs is None:
            activation_kwargs = [{} for _ in range(len(dims))]

        if len(linear_kwargs) != len(dims):
            raise ValueError("linear_kwargs must have the same length as dims")

        if len(activation_kwargs) != len(dims):
            raise ValueError("activation_kwargs must have the same length as dims")

        # build backbone with Sequential
        layers = []
        for i in range(len(dims)):
            in_dim = dims[i][0]
            out_dim = dims[i][1]
            layers.append(
                torch_geometric.nn.dense.Linear(
                    in_dim,
                    out_dim,
                    **linear_kwargs[i],
                )
            )
            layers.append(
                activations[i](
                    **activation_kwargs[i],
                )
            )

        self.layers = torch.nn.Sequential(*layers)
        self.linear_kwargs = linear_kwargs
        self.activation_kwargs = activation_kwargs

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass through the LinearSequential object.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after the final linear/activation pair.
        """
        return self.layers(x)

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "LinearSequential":
        """Create a LinearSequential from a configuration dictionary.

        Args:
            config (Dict[str, Any]): Configuration dictionary containing parameters for the LinearSequential.

        Returns:
            LinearSequential: An instance of LinearSequential initialized with the provided configuration.

        Raises:
            ValueError: If the specified activation function is not registered.
        """
        validate(config, cls.schema)

        try:
            n_layers = len(config["dims"])

            if "linear_kwargs" in config and len(config["linear_kwargs"]) != n_layers:
                raise ValueError("linear_kwargs must match dims length")

            if "activation_kwargs" in config and len(
                config["activation_kwargs"]
            ) != len(config["activations"]):
                raise ValueError("activation_kwargs must match dims length")

            return cls(
                dims=config["dims"],
                activations=config["activations"],
                linear_kwargs=config.get("linear_kwargs", None),
                activation_kwargs=config.get("activation_kwargs", None),
            )
        except Exception as e:
            raise RuntimeError(
                f"Error while building LinearSequential from config: {e}"
            ) from e

    def to_config(self) -> Dict[str, Any]:
        """Build a config file from the current model

        Returns:
            Dict[str, Any]: Model config
        """
        linear_dims = []
        activations = []

        for layer in self.layers:
            if isinstance(layer, torch_geometric.nn.dense.Linear):
                linear_dims.append([layer.in_channels, layer.out_channels])
            elif isinstance(layer, torch.nn.Linear):
                linear_dims.append([layer.in_features, layer.out_features])
            elif isinstance(layer, torch.nn.Module) or callable(layer):
                activations.append(f"{layer.__module__}.{type(layer).__name__}")
            else:
                raise ValueError(f"Unknown layer type: {type(layer)}")

        config = {
            "dims": linear_dims,
            "activations": activations,
            "linear_kwargs": self.linear_kwargs,
            "activation_kwargs": self.activation_kwargs,
        }

        return config

    def save(self, path: str | Path) -> None:
        """Save the model's state to file.

        Args:
            path (str | Path): path to save the model to.
        """

        self_as_config = self.to_config()

        torch.save({"config": self_as_config, "state_dict": self.state_dict()}, path)

    @classmethod
    def load(
        cls, path: str | Path, device: torch.device = torch.device("cpu")
    ) -> "LinearSequential":
        """Load a LinearSequential instance from file

        Args:
            path (str | Path): path to the file to load the model from
            device (torch.device): device to put the model on. Defaults to torch.device("cpu").
        Returns:
            LinearSequential: An instance of LinearSequential initialized from the loaded data.
        """
        cfg = torch.load(path)
        cfg["config"]["activations"] = [
            utils.import_and_get(act) for act in cfg["config"]["activations"]
        ]
        model = cls.from_config(cfg["config"])
        model.load_state_dict(cfg["state_dict"], strict=False)
        model.to(device)

        return model
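
A minimal construction sketch (one (in, out) pair and one activation type per layer; values are illustrative):

import torch

mlp = LinearSequential(
    dims=[(16, 64), (64, 8)],
    activations=[torch.nn.ReLU, torch.nn.Identity],
)
y = mlp(torch.randn(4, 16))  # shape: (4, 8)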

__init__(dims, activations, linear_kwargs=None, activation_kwargs=None)

Create a LinearSequential object: a sequence of torch_geometric.nn.dense.Linear layers, each followed by an activation function.

Parameters:

- dims (list[Sequence[int]]): (input_channels, output_channels) pairs, one per linear layer. Required.
- activations (list[type[Module]]): activation function types, one per layer, e.g. torch.nn.ReLU. Required.
- linear_kwargs (list[Dict[Any, Any]] | None): keyword arguments for each linear layer. Default: None.
- activation_kwargs (list[Dict[Any, Any]] | None): keyword arguments for each activation function. Default: None.

Raises:

- ValueError: If dims is empty.
- ValueError: If dims, activations, linear_kwargs, and activation_kwargs do not all have the same length.
Source code in src/QuantumGrav/models/linear_sequential.py
def __init__(
    self,
    dims: list[Sequence[int]],
    activations: list[type[torch.nn.Module]],
    linear_kwargs: list[Dict[Any, Any]] | None = None,
    activation_kwargs: list[Dict[Any, Any]] | None = None,
):
    """Create a LinearSequential object containing a sequence of MLPs of type torch_geometric.nn.dense.Linear, interspersed with activation functions. All layers are of type `Linear` with an activation function in between (the backbone) and a set of linear output layers.

    Args:
        input_dim (int): input dimension of the LinearSequential object
        output_dim (int): output dimension for the output layer, i.e., the classification task
        hidden_dims (list[int]): list of hidden dimensions for the backbone
        activation (type[torch.nn.Module], optional): activation function to use. Defaults to torch.nn.ReLU.
        backbone_kwargs (list[dict], optional): additional arguments for the backbone layers. Defaults to None.
        output_kwargs (dict, optional): additional keyword arguments for the output layers. Defaults to None.

    Raises:
        ValueError: If hidden_dims contains non-positive integers.
        ValueError: If output_dim is a non-positive integer.
    """
    super().__init__()

    if len(dims) == 0:
        raise ValueError("dims must not be empty")

    if len(dims) != len(activations):
        raise ValueError("dims and activations must have the same length")

    if linear_kwargs is None:
        linear_kwargs = [{} for _ in range(len(dims))]

    if activation_kwargs is None:
        activation_kwargs = [{} for _ in range(len(dims))]

    if len(linear_kwargs) != len(dims):
        raise ValueError("linear_kwargs must have the same length as dims")

    if len(activation_kwargs) != len(dims):
        raise ValueError("activation_kwargs must have the same length as dims")

    # build backbone with Sequential
    layers = []
    for i in range(len(dims)):
        in_dim = dims[i][0]
        out_dim = dims[i][1]
        layers.append(
            torch_geometric.nn.dense.Linear(
                in_dim,
                out_dim,
                **linear_kwargs[i],
            )
        )
        layers.append(
            activations[i](
                **activation_kwargs[i],
            )
        )

    self.layers = torch.nn.Sequential(*layers)
    self.linear_kwargs = linear_kwargs
    self.activation_kwargs = activation_kwargs

forward(x)

Forward pass through the LinearSequential object.

Parameters:

- x (Tensor): Input tensor. Required.

Returns:

- torch.Tensor: Output tensor after the final linear/activation pair.

Source code in src/QuantumGrav/models/linear_sequential.py
def forward(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    """Forward pass through the LinearSequential object.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Output tensor after the final linear/activation pair.
    """
    return self.layers(x)

from_config(config) classmethod

Create a LinearSequential from a configuration dictionary.

Parameters:

- config (Dict[str, Any]): Configuration dictionary containing parameters for the LinearSequential. Required.

Returns:

- LinearSequential: An instance of LinearSequential initialized with the provided configuration.

Raises:

- ValueError: If the specified activation function is not registered.

Source code in src/QuantumGrav/models/linear_sequential.py
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "LinearSequential":
    """Create a LinearSequential from a configuration dictionary.

    Args:
        config (Dict[str, Any]): Configuration dictionary containing parameters for the LinearSequential.

    Returns:
        LinearSequential: An instance of LinearSequential initialized with the provided configuration.

    Raises:
        ValueError: If the specified activation function is not registered.
    """
    validate(config, cls.schema)

    try:
        n_layers = len(config["dims"])

        if "linear_kwargs" in config and len(config["linear_kwargs"]) != n_layers:
            raise ValueError("linear_kwargs must match dims length")

        if "activation_kwargs" in config and len(
            config["activation_kwargs"]
        ) != len(config["activations"]):
            raise ValueError("activation_kwargs must match dims length")

        return cls(
            dims=config["dims"],
            activations=config["activations"],
            linear_kwargs=config.get("linear_kwargs", None),
            activation_kwargs=config.get("activation_kwargs", None),
        )
    except Exception as e:
        raise RuntimeError(
            f"Error while building LinearSequential from config: {e}"
        ) from e
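
For example (a sketch; activation classes are passed directly, since from_config does not resolve strings itself, while load does so via utils.import_and_get):

import torch

config = {
    "dims": [[16, 64], [64, 8]],
    "activations": [torch.nn.ReLU, torch.nn.Identity],
}
mlp = LinearSequential.from_config(config)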

load(path, device=torch.device('cpu')) classmethod

Load a LinearSequential instance from file

Parameters:

- path (str | Path): path to the file to load the model from. Required.
- device (device): device to put the model on. Default: device('cpu').

Returns:

- LinearSequential: An instance of LinearSequential initialized from the loaded data.

Source code in src/QuantumGrav/models/linear_sequential.py
@classmethod
def load(
    cls, path: str | Path, device: torch.device = torch.device("cpu")
) -> "LinearSequential":
    """Load a LinearSequential instance from file

    Args:
        path (str | Path): path to the file to load the model from
        device (torch.device): device to put the model on. Defaults to torch.device("cpu").
    Returns:
        LinearSequential: An instance of LinearSequential initialized from the loaded data.
    """
    cfg = torch.load(path)
    cfg["config"]["activations"] = [
        utils.import_and_get(act) for act in cfg["config"]["activations"]
    ]
    model = cls.from_config(cfg["config"])
    model.load_state_dict(cfg["state_dict"], strict=False)
    model.to(device)

    return model

save(path)

Save the model's state to file.

Parameters:

- path (str | Path): path to save the model to. Required.
Source code in src/QuantumGrav/models/linear_sequential.py
def save(self, path: str | Path) -> None:
    """Save the model's state to file.

    Args:
        path (str | Path): path to save the model to.
    """

    self_as_config = self.to_config()

    torch.save({"config": self_as_config, "state_dict": self.state_dict()}, path)

to_config()

Build a config dictionary from the current model.

Returns:

- Dict[str, Any]: Model config.

Source code in src/QuantumGrav/models/linear_sequential.py
def to_config(self) -> Dict[str, Any]:
    """Build a config file from the current model

    Returns:
        Dict[str, Any]: Model config
    """
    linear_dims = []
    activations = []

    for layer in self.layers:
        if isinstance(layer, torch_geometric.nn.dense.Linear):
            linear_dims.append([layer.in_channels, layer.out_channels])
        elif isinstance(layer, torch.nn.Linear):
            linear_dims.append([layer.in_features, layer.out_features])
        elif isinstance(layer, torch.nn.Module) or callable(layer):
            activations.append(f"{layer.__module__}.{type(layer).__name__}")
        else:
            raise ValueError(f"Unknown layer type: {type(layer)}")

    config = {
        "dims": linear_dims,
        "activations": activations,
        "linear_kwargs": self.linear_kwargs,
        "activation_kwargs": self.activation_kwargs,
    }

    return config

Model evaluation

This module provides base classes that take the output of applying the model to a validation or training dataset and derive useful quantities for judging model quality. These classes do nothing useful by default; rather, you must derive your own class from them that implements your desired evaluation, e.g. one based on an F1 score (a sketch of such a subclass follows the DefaultEvaluator source below).

DefaultEvaluator

Default evaluator for model evaluation - testing and validation during training

Source code in src/QuantumGrav/evaluate.py
class DefaultEvaluator:
    """Default evaluator for model evaluation - testing and validation during training"""

    def __init__(
        self,
        device: str | torch.device | int,
        criterion: Callable,
        apply_model: Callable | None = None,
    ):
        """Default evaluator for model evaluation.

        Args:
            device (str | torch.device | int): The device to run the evaluation on.
            criterion (Callable): The loss function to use for evaluation.
            apply_model (Callable | None, optional): A function to apply the model to the data. Defaults to None.
        """
        self.criterion = criterion
        self.apply_model = apply_model
        self.device = device
        self.data: pd.DataFrame | list = []
        self.logger = logging.getLogger(__name__)

    def evaluate(
        self,
        model: torch.nn.Module,
        data_loader: torch_geometric.loader.DataLoader,  # type: ignore
    ) -> list[Any]:
        """Evaluate the model on the given data loader.

        Args:
            model (torch.nn.Module): Model to evaluate.
            data_loader (torch_geometric.loader.DataLoader): Data loader for evaluation.

        Returns:
             list[Any]: A list of evaluation results.
        """
        model.eval()
        current_data = []

        with torch.no_grad():
            for i, batch in enumerate(data_loader):
                data = batch.to(self.device)
                if self.apply_model:
                    outputs = self.apply_model(model, data)
                else:
                    outputs = model(data.x, data.edge_index, data.batch)
                loss = self.criterion(outputs, data)
                current_data.append(loss)

        return current_data

    def report(self, data: list | pd.Series | torch.Tensor | np.ndarray) -> None:
        """Report the evaluation results.

        Args:
            data (list | pd.Series | torch.Tensor | np.ndarray): The evaluation results.
        """

        if isinstance(data, torch.Tensor):
            data = data.cpu().numpy()

        if isinstance(data, list):
            for i, d in enumerate(data):
                if isinstance(d, torch.Tensor):
                    data[i] = d.cpu().numpy()

        avg = np.mean(data)
        sigma = np.std(data)
        self.logger.info(f"Average loss: {avg}, Standard deviation: {sigma}")

        if isinstance(self.data, list):
            self.data.append((avg, sigma))
        else:
            self.data = pd.concat(
                [
                    self.data,
                    pd.DataFrame({"loss": avg, "std": sigma}, index=[0]),
                ],
                axis=0,
                ignore_index=True,
            )
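
As noted above, meaningful metrics come from subclassing. Below is a hypothetical sketch of an F1-based evaluator; it assumes binary graph-level labels in data.y, a single-logit model output, and scikit-learn being available. The name F1Evaluator and the criterion closure are illustrative, not part of the package.

import torch
from sklearn.metrics import f1_score

class F1Evaluator(DefaultEvaluator):
    """Hypothetical evaluator that records an F1 score per batch."""

    def __init__(self, device, apply_model=None):
        def f1_criterion(outputs, data):
            # Threshold a single logit per graph into a binary prediction.
            preds = (torch.sigmoid(outputs).squeeze(-1) > 0.5).long()
            return f1_score(data.y.cpu().numpy(), preds.cpu().numpy())

        super().__init__(device, criterion=f1_criterion, apply_model=apply_model)

With this, evaluate collects one F1 value per batch and report logs their mean and standard deviation.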

__init__(device, criterion, apply_model=None)

Default evaluator for model evaluation.

Parameters:

- device (str | device | int): The device to run the evaluation on. Required.
- criterion (Callable): The loss function to use for evaluation. Required.
- apply_model (Callable | None): A function to apply the model to the data. Default: None.
Source code in src/QuantumGrav/evaluate.py
def __init__(
    self,
    device: str | torch.device | int,
    criterion: Callable,
    apply_model: Callable | None = None,
):
    """Default evaluator for model evaluation.

    Args:
        device (str | torch.device | int): The device to run the evaluation on.
        criterion (Callable): The loss function to use for evaluation.
        apply_model (Callable | None, optional): A function to apply the model to the data. Defaults to None.
    """
    self.criterion = criterion
    self.apply_model = apply_model
    self.device = device
    self.data: pd.DataFrame | list = []
    self.logger = logging.getLogger(__name__)

evaluate(model, data_loader)

Evaluate the model on the given data loader.

Parameters:

- model (Module): Model to evaluate. Required.
- data_loader (DataLoader): Data loader for evaluation. Required.

Returns:

- list[Any]: A list of evaluation results.

Source code in src/QuantumGrav/evaluate.py
def evaluate(
    self,
    model: torch.nn.Module,
    data_loader: torch_geometric.loader.DataLoader,  # type: ignore
) -> list[Any]:
    """Evaluate the model on the given data loader.

    Args:
        model (torch.nn.Module): Model to evaluate.
        data_loader (torch_geometric.loader.DataLoader): Data loader for evaluation.

    Returns:
         list[Any]: A list of evaluation results.
    """
    model.eval()
    current_data = []

    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            data = batch.to(self.device)
            if self.apply_model:
                outputs = self.apply_model(model, data)
            else:
                outputs = model(data.x, data.edge_index, data.batch)
            loss = self.criterion(outputs, data)
            current_data.append(loss)

    return current_data

report(data)

Report the evaluation results.

Parameters:

- data (list | Series | Tensor | ndarray): The evaluation results. Required.
Source code in src/QuantumGrav/evaluate.py
def report(self, data: list | pd.Series | torch.Tensor | np.ndarray) -> None:
    """Report the evaluation results.

    Args:
        data (list | pd.Series | torch.Tensor | np.ndarray): The evaluation results.
    """

    if isinstance(data, torch.Tensor):
        data = data.cpu().numpy()

    if isinstance(data, list):
        for i, d in enumerate(data):
            if isinstance(d, torch.Tensor):
                data[i] = d.cpu().numpy()

    avg = np.mean(data)
    sigma = np.std(data)
    self.logger.info(f"Average loss: {avg}, Standard deviation: {sigma}")

    if isinstance(self.data, list):
        self.data.append((avg, sigma))
    else:
        self.data = pd.concat(
            [
                self.data,
                pd.DataFrame({"loss": avg, "std": sigma}, index=[0]),
            ],
            axis=0,
            ignore_index=True,
        )

DefaultTester

Bases: DefaultEvaluator

Default tester for model testing. Inherits from DefaultEvaluator and provides functionality for testing models using a specified criterion and an optional model application function.

Source code in src/QuantumGrav/evaluate.py
class DefaultTester(DefaultEvaluator):
    """Default tester for model testing.

    Args:
        DefaultEvaluator (Class): Inherits from DefaultEvaluator and provides functionality for validating models
    using a specified criterion and optional model application function.
    """

    def __init__(
        self,
        device: str | torch.device | int,
        criterion: Callable,
        apply_model: Callable | None = None,
    ):
        """Default tester for model testing.

        Args:
            device (str | torch.device | int): The device to run the testing on.
            criterion (Callable): The loss function to use for testing.
            apply_model (Callable | None, optional): A function to apply the model to the data. Defaults to None.
        """
        super().__init__(device, criterion, apply_model)

    def test(
        self,
        model: torch.nn.Module,
        data_loader: torch_geometric.loader.DataLoader,  # type: ignore
    ) -> list[Any]:
        """Test the model on the given data loader.

        Args:
            model (torch.nn.Module): Model to test.
            data_loader (torch_geometric.loader.DataLoader): Data loader for testing.

        Returns:
            list[Any]: A list of testing results.
        """
        return self.evaluate(model, data_loader)

__init__(device, criterion, apply_model=None)

Default tester for model testing.

Parameters:

- device (str | device | int): The device to run the testing on. Required.
- criterion (Callable): The loss function to use for testing. Required.
- apply_model (Callable | None): A function to apply the model to the data. Default: None.
Source code in src/QuantumGrav/evaluate.py
def __init__(
    self,
    device: str | torch.device | int,
    criterion: Callable,
    apply_model: Callable | None = None,
):
    """Default tester for model testing.

    Args:
        device (str | torch.device | int): The device to run the testing on.
        criterion (Callable): The loss function to use for testing.
        apply_model (Callable | None, optional): A function to apply the model to the data. Defaults to None.
    """
    super().__init__(device, criterion, apply_model)

test(model, data_loader)

Test the model on the given data loader.

Parameters:

- model (Module): Model to test. Required.
- data_loader (DataLoader): Data loader for testing. Required.

Returns:

- list[Any]: A list of testing results.

Source code in src/QuantumGrav/evaluate.py
def test(
    self,
    model: torch.nn.Module,
    data_loader: torch_geometric.loader.DataLoader,  # type: ignore
) -> list[Any]:
    """Test the model on the given data loader.

    Args:
        model (torch.nn.Module): Model to test.
        data_loader (torch_geometric.loader.DataLoader): Data loader for testing.

    Returns:
        list[Any]: A list of testing results.
    """
    return self.evaluate(model, data_loader)

DefaultValidator

Bases: DefaultEvaluator

Default validator for model validation. Inherits from DefaultEvaluator and provides functionality for validating models using a specified criterion and an optional model application function.

Source code in src/QuantumGrav/evaluate.py
class DefaultValidator(DefaultEvaluator):
    """Default validator for model validation.

    Args:
        DefaultEvaluator (Class): Inherits from DefaultEvaluator and provides functionality for validating models
    using a specified criterion and optional model application function.
    """

    def __init__(
        self,
        device: str | torch.device | int,
        criterion: Callable,
        apply_model: Callable | None = None,
    ):
        """Default validator for model validation.

        Args:
            device (str | torch.device | int): The device to run the validation on.
            criterion (Callable): The loss function to use for validation.
            apply_model (Callable | None, optional): A function to apply the model to the data. Defaults to None.
        """
        super().__init__(device, criterion, apply_model)

    def validate(
        self,
        model: torch.nn.Module,
        data_loader: torch_geometric.loader.DataLoader,  # type: ignore
    ) -> list[Any]:
        """Validate the model on the given data loader.

        Args:
            model (torch.nn.Module): Model to validate.
            data_loader (torch_geometric.loader.DataLoader): Data loader for validation.
        Returns:
            list[Any]: A list of validation results.
        """
        return self.evaluate(model, data_loader)

__init__(device, criterion, apply_model=None)

Default validator for model validation.

Parameters:

- device (str | device | int): The device to run the validation on. Required.
- criterion (Callable): The loss function to use for validation. Required.
- apply_model (Callable | None): A function to apply the model to the data. Default: None.
Source code in src/QuantumGrav/evaluate.py
def __init__(
    self,
    device: str | torch.device | int,
    criterion: Callable,
    apply_model: Callable | None = None,
):
    """Default validator for model validation.

    Args:
        device (str | torch.device | int): The device to run the validation on.
        criterion (Callable): The loss function to use for validation.
        apply_model (Callable | None, optional): A function to apply the model to the data. Defaults to None.
    """
    super().__init__(device, criterion, apply_model)

validate(model, data_loader)

Validate the model on the given data loader.

Parameters:

- model (Module): Model to validate. Required.
- data_loader (DataLoader): Data loader for validation. Required.

Returns:

- list[Any]: A list of validation results.

Source code in src/QuantumGrav/evaluate.py
def validate(
    self,
    model: torch.nn.Module,
    data_loader: torch_geometric.loader.DataLoader,  # type: ignore
) -> list[Any]:
    """Validate the model on the given data loader.

    Args:
        model (torch.nn.Module): Model to validate.
        data_loader (torch_geometric.loader.DataLoader): Data loader for validation.
    Returns:
        list[Any]: A list of validation results.
    """
    return self.evaluate(model, data_loader)

Datasets

The package supports three kinds of datasets with a common base class, QGDatasetBase. For the basics of how these work, check out the pytorch-geometric documentation on datasets.

These are:

- QGDataset: A dataset that relies on on-disk storage of the processed data. It lazily loads csets from disk when needed.
- QGDatasetInMemory: A dataset that holds the entire processed dataset in memory at once.
- QGDatasetOnthefly: This dataset holds nothing on disk or in memory, but creates the data on demand from supplied Julia code.
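
A hypothetical reader matching the documented reader signature might look as follows. The zarr group layout ('x' and 'edge_index' arrays per sample group) is purely illustrative; the real layout depends on how the data was written.

import torch
import zarr
from torch_geometric.data import Data

def my_reader(group: zarr.Group, float_type: torch.dtype,
              int_type: torch.dtype, validate: bool) -> list[Data]:
    """Illustrative reader: assumes one subgroup per sample."""
    samples = []
    for key in group.group_keys():
        g = group[key]
        x = torch.as_tensor(g["x"][:], dtype=float_type)
        edge_index = torch.as_tensor(g["edge_index"][:], dtype=int_type)
        data = Data(x=x, edge_index=edge_index)
        if validate:
            data.validate()  # available in recent torch_geometric versions
        samples.append(data)
    return samples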

Dataset base class

QGDatasetBase

Mixin class that provides common functionality for the dataset classes. Works only for file-based datasets. Provides methods for processing data.

Source code in src/QuantumGrav/dataset_base.py
class QGDatasetBase:
    """Mixin class that provides common functionality for the dataset classes. Works only for file-based datasets. Provides methods for processing data."""

    def __init__(
        self,
        input: list[str | Path],
        output: str | Path,
        reader: Callable[
            [zarr.Group, torch.dtype, torch.dtype, bool],
            list[Data],
        ]
        | None = None,
        float_type: torch.dtype = torch.float32,
        int_type: torch.dtype = torch.int64,
        validate_data: bool = True,
        n_processes: int = 1,
        chunksize: int = 1000,
        preprocess: bool = False,
    ):
        """Initialize a QGDatasetBase instance. This class is designed to provide some common functionality that can be used by downstream datasets built on top of torch dataset classes. It is not to be instantiated directly, but rather used as a mixin for other dataset classes.

        Args:
            input (list[str | Path]): The list of input files for the dataset, or a callable that generates a set of input files.
            output (str | Path): The output directory where processed data will be stored.
            reader (Callable[[zarr.Group, torch.dtype, torch.dtype, bool], list[Data]] | None, optional): A function to load data from a file. Defaults to None.
            float_type (torch.dtype, optional): The data type to use for floating point values. Defaults to torch.float32.
            int_type (torch.dtype, optional): The data type to use for integer values. Defaults to torch.int64.
            validate_data (bool, optional): Whether to validate the data after loading. Defaults to True.
            n_processes (int, optional): The number of processes to use for parallel processing of read data. Defaults to 1.
            chunksize (int, optional): The size of the chunks to process in parallel. Defaults to 1000.
            preprocess (bool, optional): Whether data preprocessing should happen and the results be stored on disk. Defaults to False.

        Raises:
            ValueError: If no reader function is provided.
            FileNotFoundError: If an input file does not exist.
        """
        if reader is None:
            raise ValueError("A reader function must be provided to load the data.")

        self.input = input
        for file in self.input:
            if Path(file).exists() is False:
                raise FileNotFoundError(f"Input file {file} does not exist.")

        self.output = output
        self.data_reader = reader
        self.metadata = {}
        self.float_type = float_type
        self.int_type = int_type
        self.validate_data = validate_data
        self.n_processes = n_processes
        self.chunksize = chunksize
        self.preprocess = preprocess

        # get the number of samples in the dataset
        self._num_samples = 0
        num_samples_per_file = []
        for filepath in self.input:
            if not Path(filepath).exists():
                raise FileNotFoundError(f"Input file {filepath} does not exist.")
            n = self._get_num_samples_per_file(filepath)
            num_samples_per_file.append(n)
            self._num_samples += n

        self._num_samples_per_file = np.stack(num_samples_per_file, axis=0)

        # load existing metadata, or create it if preprocessing is requested
        if Path(self.processed_dir).exists():
            with open(Path(self.processed_dir) / "metadata.yaml", "r") as f:
                self.metadata = yaml.load(f, Loader=yaml.FullLoader)
        elif preprocess:
            Path(self.processed_dir).mkdir(parents=True, exist_ok=True)
            self.metadata = {
                "num_samples": int(self._num_samples),
                "input": [str(Path(f).resolve().absolute()) for f in self.input],
                "output": str(Path(self.output).resolve().absolute()),
                "float_type": str(self.float_type),
                "int_type": str(self.int_type),
                "validate_data": self.validate_data,
                "n_processes": self.n_processes,
                "chunksize": self.chunksize,
                "preprocess": self.preprocess,
            }

            with open(Path(self.processed_dir) / "metadata.yaml", "w") as f:
                yaml.dump(self.metadata, f)
        else:
            # do nothing here b/c this branch doesn't need any action
            pass

    def _get_num_samples_per_file(self, filepath: str | Path) -> int | np.ndarray:
        """Get the number of samples in a given file.

        Args:
            filepath (str | Path): The path to the file.

        Raises:
            ValueError: If the file is not a valid Zarr file.

        Returns:
            int: The number of samples in the file.
        """

        # try to find the sample number from a dedicated dataset
        def try_find_numsamples(f):
            s = None
            for name in ["num_causal_sets", "num_samples"]:
                if name in f:
                    s = f[name]
                    break
            return s

        # ... if that fails, we try to read it from any scalar dataset.
        # ... if we can't because they are of unequal sizes, we return None
        # ... to indicate an unresolvable state
        def fallback(f) -> int | None:
            # find scalar datasets and use their sizes to determine size
            shapes = [f[k].shape[0] for k in f.keys() if len(f[k].shape) == 1]
            max_shape = max(shapes)
            min_shape = min(shapes)
            if max_shape != min_shape:
                return None
            else:
                return max_shape

        # open the file as a Zarr group and apply the two strategies above
        try:
            group = zarr.open_group(
                zarr.storage.LocalStore(filepath, read_only=True),
                path="",
                mode="r",
            )
            # note that fallback returns an int directly,
            # while for try_find_numsamples we need to index into the result
            s = try_find_numsamples(group)
            if s is not None:
                return s[0]
            else:
                s = fallback(group)
                if s is not None:
                    return s
                else:
                    raise RuntimeError("Unable to determine number of samples.")
        except Exception:
            # we need an extra fallback for zarr b/c Julia Zarr and python Zarr
            # can differ in layout - Julia Zarr does not have to have a group
            try:
                store = zarr.storage.LocalStore(filepath, read_only=True)
                arr = zarr.open_array(store, path="adjacency_matrix")
                s = max(arr.shape)
                return s
            except Exception:
                raise

    @property
    def processed_dir(self) -> str:
        """Get the path to the processed directory.

        Returns:
            str: The path to the processed directory.
        """
        processed_path = Path(self.output).resolve().absolute() / "processed"
        return str(processed_path)

    @property
    def raw_file_names(self) -> list[str]:
        """Get the raw file paths from the input list.

        Returns:
            list[str]: A list of raw file names (the .zarr entries from the input list).
        """
        suf = ".zarr"
        return [str(Path(f).name) for f in self.input if Path(f).suffix == suf]

    @property
    def processed_file_names(self) -> list[str]:
        """Get a list of processed files in the processed directory.

        Returns:
            list[str]: The names of the processed .pt data files.
        """
        if not Path(self.processed_dir).exists():
            return []

        return [
            str(f.name)
            for f in Path(self.processed_dir).iterdir()
            if f.is_file() and f.suffix == ".pt" and "data" in f.name
        ]

    def process_chunk(
        self,
        store: zarr.storage.LocalStore,
        start: int,
        pre_transform: Callable[[Data | Collection], Data] | None = None,
        pre_filter: Callable[[Data | Collection], bool] | None = None,
    ) -> Sequence[Data]:
        """Process a chunk of data from the raw file. This method is intended to be used in the data loading pipeline to read a chunk of data, apply transformations, and filter the read data, and thus should not be called directly.

        Args:
            store (zarr.storage.LocalStore): local zarr storage
            start (int): start index
            pre_transform (Callable[[Data], Data] | None, optional): Transformation that adds additional features to the data. Defaults to None.
            pre_filter (Callable[[Data], bool] | None, optional): A function that filters the data. Defaults to None.

        Returns:
            Sequence[Data]: The processed data, with filtered-out items removed (possibly empty).
        """
        N = self._get_num_samples_per_file(store.root)
        rootgroup = zarr.open_group(store.root)

        def process_item(i: int):
            item = self.data_reader(
                rootgroup,
                i,
                self.float_type,
                self.int_type,
                self.validate_data,
            )
            if pre_filter is not None and not pre_filter(item):
                return None
            if pre_transform is not None:
                return pre_transform(item)
            return item

        if self.n_processes > 1:
            results = Parallel(n_jobs=self.n_processes)(
                delayed(process_item)(i)
                for i in range(start, min(start + self.chunksize, N))
            )
        else:
            results = [
                process_item(i) for i in range(start, min(start + self.chunksize, N))
            ]

        return [res for res in results if res is not None]
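Note that process_chunk (and the on-disk dataset's get) invoke the reader as reader(group, index, float_type, int_type, validate), one sample at a time, even though the type annotation above omits the index argument. A hedged sketch of such a reader, assuming the Zarr group stores an adjacency_matrix array per sample (the array name is an assumption, borrowed from the fallback logic above):

import torch
import zarr
from torch_geometric.data import Data

def my_reader(
    group: zarr.Group,
    index: int,
    float_type: torch.dtype,
    int_type: torch.dtype,
    validate: bool,
) -> Data:
    # read the adjacency matrix of sample `index` and derive edge indices
    adj = torch.as_tensor(group["adjacency_matrix"][index], dtype=float_type)
    edge_index = adj.nonzero().t().to(int_type)
    x = torch.ones((adj.shape[0], 1), dtype=float_type)  # trivial node features
    data = Data(x=x, edge_index=edge_index)
    if validate:
        data.validate(raise_on_error=True)  # PyG's built-in sanity checks
    return data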

processed_dir property

Get the path to the processed directory.

Returns:

- str: The path to the processed directory.

processed_file_names property

Get a list of processed files in the processed directory.

Returns:

- list[str]: The names of the processed .pt data files.

raw_file_names property

Get the raw file paths from the input list.

Returns:

- list[str]: A list of raw file names (the .zarr entries from the input list).

__init__(input, output, reader=None, float_type=torch.float32, int_type=torch.int64, validate_data=True, n_processes=1, chunksize=1000, preprocess=False)

Initialize a QGDatasetBase instance. This class is designed to provide some common functionality that can be used by downstream datasets built on top of torch dataset classes. It is not to be instantiated directly, but rather used as a mixin for other dataset classes.

Parameters:

- input (list[str | Path]): The list of input files for the dataset, or a callable that generates a set of input files. Required.
- output (str | Path): The output directory where processed data will be stored. Required.
- reader (Callable[[zarr.Group, torch.dtype, torch.dtype, bool], list[Data]] | None): A function to load data from a file. Defaults to None.
- float_type (torch.dtype): The data type to use for floating point values. Defaults to torch.float32.
- int_type (torch.dtype): The data type to use for integer values. Defaults to torch.int64.
- validate_data (bool): Whether to validate the data after loading. Defaults to True.
- n_processes (int): The number of processes to use for parallel processing of read data. Defaults to 1.
- chunksize (int): The size of the chunks to process in parallel. Defaults to 1000.
- preprocess (bool): Whether data preprocessing should happen and the results be stored on disk. Defaults to False.

Raises:

- ValueError: If no reader function is provided.
- FileNotFoundError: If an input file does not exist.

Source code in src/QuantumGrav/dataset_base.py
def __init__(
    self,
    input: list[str | Path],
    output: str | Path,
    reader: Callable[
        [zarr.Group, torch.dtype, torch.dtype, bool],
        list[Data],
    ]
    | None = None,
    float_type: torch.dtype = torch.float32,
    int_type: torch.dtype = torch.int64,
    validate_data: bool = True,
    n_processes: int = 1,
    chunksize: int = 1000,
    preprocess: bool = False,
):
    """Initialize a QGDatasetBase instance. This class is designed to provide some common functionality that can be used by downstream datasets built on top of torch dataset classes. It is not to be instantiated directly, but rather used as a mixin for other dataset classes.

    Args:
        input (list[str | Path]): The list of input files for the dataset, or a callable that generates a set of input files.
        output (str | Path): The output directory where processed data will be stored.
        reader (Callable[[zarr.Group, torch.dtype, torch.dtype, bool], list[Data]] | None, optional): A function to load data from a file. Defaults to None.
        float_type (torch.dtype, optional): The data type to use for floating point values. Defaults to torch.float32.
        int_type (torch.dtype, optional): The data type to use for integer values. Defaults to torch.int64.
        validate_data (bool, optional): Whether to validate the data after loading. Defaults to True.
        n_processes (int, optional): The number of processes to use for parallel processing of read data. Defaults to 1.
        chunksize (int, optional): The size of the chunks to process in parallel. Defaults to 1000.
        preprocess (bool, optional): Whether data preprocessing should happen and the results be stored on disk. Defaults to False.

    Raises:
        ValueError: If no reader function is provided.
        FileNotFoundError: If an input file does not exist.
    """
    if reader is None:
        raise ValueError("A reader function must be provided to load the data.")

    self.input = input
    for file in self.input:
        if Path(file).exists() is False:
            raise FileNotFoundError(f"Input file {file} does not exist.")

    self.output = output
    self.data_reader = reader
    self.metadata = {}
    self.float_type = float_type
    self.int_type = int_type
    self.validate_data = validate_data
    self.n_processes = n_processes
    self.chunksize = chunksize
    self.preprocess = preprocess

    # get the number of samples in the dataset
    self._num_samples = 0
    num_samples_per_file = []
    for filepath in self.input:
        if not Path(filepath).exists():
            raise FileNotFoundError(f"Input file {filepath} does not exist.")
        n = self._get_num_samples_per_file(filepath)
        num_samples_per_file.append(n)
        self._num_samples += n

    self._num_samples_per_file = np.stack(num_samples_per_file, axis=0)

    # load existing metadata, or create it if preprocessing is requested
    if Path(self.processed_dir).exists():
        with open(Path(self.processed_dir) / "metadata.yaml", "r") as f:
            self.metadata = yaml.load(f, Loader=yaml.FullLoader)
    elif preprocess:
        Path(self.processed_dir).mkdir(parents=True, exist_ok=True)
        self.metadata = {
            "num_samples": int(self._num_samples),
            "input": [str(Path(f).resolve().absolute()) for f in self.input],
            "output": str(Path(self.output).resolve().absolute()),
            "float_type": str(self.float_type),
            "int_type": str(self.int_type),
            "validate_data": self.validate_data,
            "n_processes": self.n_processes,
            "chunksize": self.chunksize,
            "preprocess": self.preprocess,
        }

        with open(Path(self.processed_dir) / "metadata.yaml", "w") as f:
            yaml.dump(self.metadata, f)
    else:
        # do nothing here b/c this branch doesn't need any action
        pass

process_chunk(store, start, pre_transform=None, pre_filter=None)

Process a chunk of data from the raw file. This method is intended to be used in the data loading pipeline to read a chunk of data, apply transformations, and filter the read data, and thus should not be called directly.

Parameters:

- store (zarr.storage.LocalStore): Local Zarr store to read from. Required.
- start (int): Start index of the chunk. Required.
- pre_transform (Callable[[Data], Data] | None): Transformation that adds additional features to the data. Defaults to None.
- pre_filter (Callable[[Data], bool] | None): A function that filters the data. Defaults to None.

Returns:

- Sequence[Data]: The processed data, with filtered-out items removed (possibly empty).

Source code in src/QuantumGrav/dataset_base.py
def process_chunk(
    self,
    store: zarr.storage.LocalStore,
    start: int,
    pre_transform: Callable[[Data | Collection], Data] | None = None,
    pre_filter: Callable[[Data | Collection], bool] | None = None,
) -> Sequence[Data]:
    """Process a chunk of data from the raw file. This method is intended to be used in the data loading pipeline to read a chunk of data, apply transformations, and filter the read data, and thus should not be called directly.

    Args:
        store (zarr.storage.LocalStore): local zarr storage
        start (int): start index
        pre_transform (Callable[[Data], Data] | None, optional): Transformation that adds additional features to the data. Defaults to None.
        pre_filter (Callable[[Data], bool] | None, optional): A function that filters the data. Defaults to None.

    Returns:
        Sequence[Data]: The processed data, with filtered-out items removed (possibly empty).
    """
    N = self._get_num_samples_per_file(store.root)
    rootgroup = zarr.open_group(store.root)

    def process_item(i: int):
        item = self.data_reader(
            rootgroup,
            i,
            self.float_type,
            self.int_type,
            self.validate_data,
        )
        if pre_filter is not None and not pre_filter(item):
            return None
        if pre_transform is not None:
            return pre_transform(item)
        return item

    if self.n_processes > 1:
        results = Parallel(n_jobs=self.n_processes)(
            delayed(process_item)(i)
            for i in range(start, min(start + self.chunksize, N))
        )
    else:
        results = [
            process_item(i) for i in range(start, min(start + self.chunksize, N))
        ]

    return [res for res in results if res is not None]

Dataset loading data from disk

QGDataset

Bases: QGDatasetBase, Dataset

A dataset class for QuantumGrav data that is designed to handle large datasets stored on disk. This class provides methods for loading, processing, and writing data that are common to both in-memory and on-disk datasets.

Source code in src/QuantumGrav/dataset_ondisk.py
class QGDataset(QGDatasetBase, Dataset):
    """A dataset class for QuantumGrav data that is designed to handle large datasets stored on disk. This class provides methods for loading, processing, and writing data that are common to both in-memory and on-disk datasets."""

    def __init__(
        self,
        input: list[str | Path],
        output: str | Path,
        reader: Callable[
            [zarr.Group, int, torch.dtype, torch.dtype, bool], Collection[Any]
        ]
        | None = None,
        float_type: torch.dtype = torch.float32,
        int_type: torch.dtype = torch.int64,
        validate_data: bool = True,
        chunksize: int = 1000,
        n_processes: int = 1,
        # dataset properties
        transform: Callable[[Data | Collection[Any]], Data] | None = None,
        pre_transform: Callable[[Data | Collection[Any]], Data] | None = None,
        pre_filter: Callable[[Data | Collection[Any]], bool] | None = None,
    ):
        """Create a new QGDataset instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets that are stored on disk. When there is no pre_transform and no pre_filter is given, the system will not create a `processed` directory.

        Args:
            input (list[str | Path]): List of input Zarr file paths.
            output (str | Path): Output directory where processed data will be stored.
            reader (Callable[[zarr.Group, int, torch.dtype, torch.dtype, bool], Collection[Any]] | None, optional): Function to read a single sample from the Zarr file. Defaults to None.
            float_type (torch.dtype, optional): Data type for float tensors. Defaults to torch.float32.
            int_type (torch.dtype, optional): Data type for int tensors. Defaults to torch.int64.
            validate_data (bool, optional): Whether to validate the data. Defaults to True.
            chunksize (int, optional): Size of data chunks to process at once. Defaults to 1000.
            n_processes (int, optional): Number of processes to use for data loading. Defaults to 1.
            transform (Callable[[Data], Data] | None, optional): Function to transform the data. Defaults to None.
            pre_transform (Callable[[Data], Data] | None, optional): Function to pre-transform the data. Defaults to None.
            pre_filter (Callable[[Data], bool] | None, optional): Function to pre-filter the data. Defaults to None.
        """
        preprocess = pre_transform is not None or pre_filter is not None
        self.stores = {}

        QGDatasetBase.__init__(
            self,
            input,
            output,
            reader=reader,
            float_type=float_type,
            int_type=int_type,
            validate_data=validate_data,
            chunksize=chunksize,
            n_processes=n_processes,
            preprocess=preprocess,
        )

        Dataset.__init__(
            self,
            root=output,
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
        )

    def write_data(self, data: list[Data], idx: int) -> int:
        """Write the processed data to disk using `torch.save`. This is a default implementation that can be overridden by subclasses, and is intended to be used in the data loading pipeline. Thus, is not intended to be called directly.

        Args:
            data (list[Data]): The list of Data objects to write to disk.
            idx (int): The index to use for naming the files.
        """
        if not Path(self.processed_dir).exists():
            Path(self.processed_dir).mkdir(parents=True, exist_ok=True)

        for d in data:
            if d is not None:
                file_path = Path(self.processed_dir) / f"data_{idx}.pt"
                torch.save(d, file_path)
                idx += 1
        return idx

    def process(self) -> None:
        """Process the dataset from the read rawdata into its final form."""
        # process data files
        k = 0  # index to create the filenames for the processed data
        if self.pre_filter is None and self.pre_transform is None:
            return

        for file in self.input:
            N = self._get_num_samples_per_file(Path(file).resolve().absolute())

            num_chunks = N // self.chunksize

            raw_file = zarr.storage.LocalStore(
                str(Path(file).resolve().absolute()), read_only=True
            )

            for i in range(0, num_chunks * self.chunksize, self.chunksize):
                data = self.process_chunk(
                    raw_file,
                    i,
                    pre_transform=self.pre_transform,
                    pre_filter=self.pre_filter,
                )

                k = self.write_data(data, k)

            # final chunk processing
            data = self.process_chunk(
                raw_file,
                num_chunks * self.chunksize,
                pre_transform=self.pre_transform,
                pre_filter=self.pre_filter,
            )

            k = self.write_data(data, k)

            raw_file.close()

    def _get_store_group(
        self, file: Path | str
    ) -> Tuple[zarr.storage.LocalStore, zarr.Group]:
        """Get a requested open store and add it to an internal cache if not open yet.

        Args:
            file (Path | str): filepath to store

        Returns:
            Tuple[zarr.storage.LocalStore, zarr.Group]: tuple containing the opened store and its root group
        """
        if file not in self.stores:
            store = zarr.storage.LocalStore(file, read_only=True)
            rootgroup = zarr.open_group(store.root)
            self.stores[file] = (store, rootgroup)

        return self.stores[file]

    def close(self) -> None:
        "Close all open zarr stores."
        for store, _ in self.stores.values():
            store.close()
        self.stores.clear()

    def __del__(self):
        "Cleanup on deletion."
        self.close()

    def map_index(self, idx: int) -> Tuple[str | Path, int]:
        """Map a global index to a specific file and local index within that file.

        Args:
            idx (int): The global index to map.

        Raises:
            RuntimeError: If the index cannot be mapped to any file.

        Returns:
            Tuple[str | Path, int]: The file and local index corresponding to the global index.
        """
        original_index = idx
        final_file: Path | str | None = None

        for size, dfile in zip(self._num_samples_per_file, self.input):
            if idx < size:
                final_file = dfile
                break
            else:
                idx -= size

        if final_file is None:
            raise RuntimeError(
                f"Error, index {original_index} could not be found in the supplied data files of size {self._num_samples_per_file} with total size {self._num_samples}"
            )
        return final_file, idx

    def get(self, idx: int) -> Data:
        """Get a single data sample by index."""
        if self._num_samples is None:
            raise ValueError("Dataset has not been processed yet.")

        # Load the data from the processed files
        if self.preprocess:
            datapoint = torch.load(
                Path(self.processed_dir) / f"data_{idx}.pt", weights_only=False
            )
            if self.transform is not None:
                datapoint = self.transform(datapoint)
        else:
            # TODO: this is inefficient, but it's the only robust way I could find
            dfile, idx = self.map_index(idx)
            _, rootgroup = self._get_store_group(dfile)
            datapoint = self.data_reader(
                rootgroup,
                idx,
                self.float_type,
                self.int_type,
                self.validate_data,
            )

        return datapoint

    def __getitem__(
        self, idx: int | Sequence[int]
    ) -> Data | Sequence[Data] | Collection[Any]:
        """_summary_

        Args:
            idx (int | Sequence[int]): _description_

        Returns:
            Data | Sequence[Data] | Collection[Any]: _description_
        """
        if isinstance(idx, int):
            return self.get(idx)
        else:
            return [self.get(i) for i in idx]

    def len(self) -> int:
        """Get the number of samples in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        if self.preprocess:
            return len(self.processed_file_names)
        else:
            return self._num_samples

__del__()

Cleanup on deletion.

Source code in src/QuantumGrav/dataset_ondisk.py
def __del__(self):
    "Cleanup on deletion."
    self.close()

__getitem__(idx)

Get one or more data samples by index.

Parameters:

- idx (int | Sequence[int]): A single index or a sequence of indices. Required.

Returns:

- Data | Sequence[Data] | Collection[Any]: The sample at idx, or a list of samples for a sequence of indices.

Source code in src/QuantumGrav/dataset_ondisk.py
def __getitem__(
    self, idx: int | Sequence[int]
) -> Data | Sequence[Data] | Collection[Any]:
    """_summary_

    Args:
        idx (int | Sequence[int]): _description_

    Returns:
        Data | Sequence[Data] | Collection[Any]: _description_
    """
    if isinstance(idx, int):
        return self.get(idx)
    else:
        return [self.get(i) for i in idx]
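In short: an integer index returns a single Data object, while any sequence of indices returns a list of Data objects.

dataset[3]          # -> Data
dataset[[0, 1, 2]]  # -> list[Data]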

__init__(input, output, reader=None, float_type=torch.float32, int_type=torch.int64, validate_data=True, chunksize=1000, n_processes=1, transform=None, pre_transform=None, pre_filter=None)

Create a new QGDataset instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets that are stored on disk. When there is no pre_transform and no pre_filter is given, the system will not create a processed directory.

Parameters:

- input (list[str | Path]): List of input Zarr file paths. Required.
- output (str | Path): Output directory where processed data will be stored. Required.
- reader (Callable[[zarr.Group, int, torch.dtype, torch.dtype, bool], Collection[Any]] | None): Function to read a single sample from the Zarr file. Defaults to None.
- float_type (torch.dtype): Data type for float tensors. Defaults to torch.float32.
- int_type (torch.dtype): Data type for int tensors. Defaults to torch.int64.
- validate_data (bool): Whether to validate the data. Defaults to True.
- chunksize (int): Size of data chunks to process at once. Defaults to 1000.
- n_processes (int): Number of processes to use for data loading. Defaults to 1.
- transform (Callable[[Data], Data] | None): Function to transform the data. Defaults to None.
- pre_transform (Callable[[Data], Data] | None): Function to pre-transform the data. Defaults to None.
- pre_filter (Callable[[Data], bool] | None): Function to pre-filter the data. Defaults to None.
Source code in src/QuantumGrav/dataset_ondisk.py
def __init__(
    self,
    input: list[str | Path],
    output: str | Path,
    reader: Callable[
        [zarr.Group, int, torch.dtype, torch.dtype, bool], Collection[Any]
    ]
    | None = None,
    float_type: torch.dtype = torch.float32,
    int_type: torch.dtype = torch.int64,
    validate_data: bool = True,
    chunksize: int = 1000,
    n_processes: int = 1,
    # dataset properties
    transform: Callable[[Data | Collection[Any]], Data] | None = None,
    pre_transform: Callable[[Data | Collection[Any]], Data] | None = None,
    pre_filter: Callable[[Data | Collection[Any]], bool] | None = None,
):
    """Create a new QGDataset instance. This class is designed to handle the loading, processing, and writing of QuantumGrav datasets that are stored on disk. When there is no pre_transform and no pre_filter is given, the system will not create a `processed` directory.

    Args:
        input (list[str | Path]): List of input Zarr file paths.
        output (str | Path): Output directory where processed data will be stored.
        reader (Callable[[zarr.Group, int, torch.dtype, torch.dtype, bool], Collection[Any]] | None, optional): Function to read a single sample from the Zarr file. Defaults to None.
        float_type (torch.dtype, optional): Data type for float tensors. Defaults to torch.float32.
        int_type (torch.dtype, optional): Data type for int tensors. Defaults to torch.int64.
        validate_data (bool, optional): Whether to validate the data. Defaults to True.
        chunksize (int, optional): Size of data chunks to process at once. Defaults to 1000.
        n_processes (int, optional): Number of processes to use for data loading. Defaults to 1.
        transform (Callable[[Data], Data] | None, optional): Function to transform the data. Defaults to None.
        pre_transform (Callable[[Data], Data] | None, optional): Function to pre-transform the data. Defaults to None.
        pre_filter (Callable[[Data], bool] | None, optional): Function to pre-filter the data. Defaults to None.
    """
    preprocess = pre_transform is not None or pre_filter is not None
    self.stores = {}

    QGDatasetBase.__init__(
        self,
        input,
        output,
        reader=reader,
        float_type=float_type,
        int_type=int_type,
        validate_data=validate_data,
        chunksize=chunksize,
        n_processes=n_processes,
        preprocess=preprocess,
    )

    Dataset.__init__(
        self,
        root=output,
        transform=transform,
        pre_transform=pre_transform,
        pre_filter=pre_filter,
    )
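Since preprocessing is enabled exactly when a pre_transform or pre_filter is supplied, passing either one switches the dataset from lazy per-sample reads to serving pre-saved .pt files. A hedged sketch (the filter and transform bodies are illustrative):

dataset = QGDataset(
    input=["data/shard_0.zarr"],
    output="data/out",
    reader=my_reader,                      # as sketched earlier
    pre_filter=lambda d: d.num_nodes > 2,  # drop trivial graphs
    pre_transform=lambda d: d,             # add derived features here
)
# process() now writes data_<k>.pt files under data/out/processed,
# and get() serves samples from those files instead of the raw Zarr stores.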

close()

Close all open zarr stores.

Source code in src/QuantumGrav/dataset_ondisk.py
def close(self) -> None:
    "Close all open zarr stores."
    for store, _ in self.stores.values():
        store.close()
    self.stores.clear()

get(idx)

Get a single data sample by index.

Source code in src/QuantumGrav/dataset_ondisk.py
def get(self, idx: int) -> Data:
    """Get a single data sample by index."""
    if self._num_samples is None:
        raise ValueError("Dataset has not been processed yet.")

    # Load the data from the processed files
    if self.preprocess:
        datapoint = torch.load(
            Path(self.processed_dir) / f"data_{idx}.pt", weights_only=False
        )
        if self.transform is not None:
            datapoint = self.transform(datapoint)
    else:
        # TODO: this is inefficient, but it's the only robust way I could find
        dfile, idx = self.map_index(idx)
        _, rootgroup = self._get_store_group(dfile)
        datapoint = self.data_reader(
            rootgroup,
            idx,
            self.float_type,
            self.int_type,
            self.validate_data,
        )

    return datapoint

len()

Get the number of samples in the dataset.

Returns:

- int: The number of samples in the dataset.

Source code in src/QuantumGrav/dataset_ondisk.py
def len(self) -> int:
    """Get the number of samples in the dataset.

    Returns:
        int: The number of samples in the dataset.
    """
    if self.preprocess:
        return len(self.processed_file_names)
    else:
        return self._num_samples

map_index(idx)

Map a global index to a specific file and local index within that file.

Parameters:

- idx (int): The global index to map. Required.

Raises:

- RuntimeError: If the index cannot be mapped to any file.

Returns:

- Tuple[str | Path, int]: The file and local index corresponding to the global index.

Source code in src/QuantumGrav/dataset_ondisk.py
def map_index(self, idx: int) -> Tuple[str | Path, int]:
    """Map a global index to a specific file and local index within that file.

    Args:
        idx (int): The global index to map.

    Raises:
        RuntimeError: If the index cannot be mapped to any file.

    Returns:
        Tuple[str | Path, int]: The file and local index corresponding to the global index.
    """
    original_index = idx
    final_file: Path | str | None = None

    for size, dfile in zip(self._num_samples_per_file, self.input):
        if idx < size:
            final_file = dfile
            break
        else:
            idx -= size

    if final_file is None:
        raise RuntimeError(
            f"Error, index {original_index} could not be found in the supplied data files of size {self._num_samples_per_file} with total size {self._num_samples}"
        )
    return final_file, idx
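For example, with two input files holding 100 and 50 samples, map_index(120) skips the first file (since 120 >= 100), subtracts its size, and returns the second file with local index 20.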

process()

Process the dataset from the raw data read from disk into its final form.

Source code in src/QuantumGrav/dataset_ondisk.py
def process(self) -> None:
    """Process the dataset from the read rawdata into its final form."""
    # process data files
    k = 0  # index to create the filenames for the processed data
    if self.pre_filter is None and self.pre_transform is None:
        return

    for file in self.input:
        N = self._get_num_samples_per_file(Path(file).resolve().absolute())

        num_chunks = N // self.chunksize

        raw_file = zarr.storage.LocalStore(
            str(Path(file).resolve().absolute()), read_only=True
        )

        for i in range(0, num_chunks * self.chunksize, self.chunksize):
            data = self.process_chunk(
                raw_file,
                i,
                pre_transform=self.pre_transform,
                pre_filter=self.pre_filter,
            )

            k = self.write_data(data, k)

        # final chunk processing
        data = self.process_chunk(
            raw_file,
            num_chunks * self.chunksize,
            pre_transform=self.pre_transform,
            pre_filter=self.pre_filter,
        )

        k = self.write_data(data, k)

        raw_file.close()
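For example, with 2500 samples in a file and chunksize=1000, the loop handles the chunks starting at 0 and 1000, and the final call processes the trailing partial chunk starting at 2000, so the last 500 samples are not dropped.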

write_data(data, idx)

Write the processed data to disk using torch.save. This is a default implementation that can be overridden by subclasses and is intended to be used in the data loading pipeline; thus, it is not intended to be called directly.

Parameters:

- data (list[Data]): The list of Data objects to write to disk. Required.
- idx (int): The index to use for naming the files. Required.

Returns:

- int: The next free file index after writing.
Source code in src/QuantumGrav/dataset_ondisk.py
def write_data(self, data: list[Data], idx: int) -> int:
    """Write the processed data to disk using `torch.save`. This is a default implementation that can be overridden by subclasses, and is intended to be used in the data loading pipeline. Thus, is not intended to be called directly.

    Args:
        data (list[Data]): The list of Data objects to write to disk.
        idx (int): The index to use for naming the files.
    """
    if not Path(self.processed_dir).exists():
        Path(self.processed_dir).mkdir(parents=True, exist_ok=True)

    for d in data:
        if d is not None:
            file_path = Path(self.processed_dir) / f"data_{idx}.pt"
            torch.save(d, file_path)
            idx += 1
    return idx

Julia-Python integration

This class provides a bridge to some user-supplied Julia code and converts its output into something Python can work with.

JuliaWorker

This class runs a given Julia callable object from a given Julia code file. It additionally imports the QuantumGrav Julia module and installs the given dependencies if provided. After creation, the wrapped Julia callable can be invoked via the __call__ method of this class. Warning: this class requires the juliacall package to be installed in the Python environment. Warning: this class is in early development and may change in the future, be slow, or otherwise not be ready for high-performance production use.
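A hedged construction sketch; the Julia file name, constructor name, and keyword arguments are all placeholders:

worker = JuliaWorker(
    jl_code_path="julia/generator.jl",     # must define the constructor below
    jl_constructor_name="make_generator",  # returns a callable object
    jl_kwargs={"n_points": 64},            # passed as one dict to the constructor
    jl_dependencies=["Distributions"],     # installed via Pkg.add during setup
)
raw = worker(42)  # forwards arguments to the wrapped Julia callable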

Source code in src/QuantumGrav/julia_worker.py
class JuliaWorker:
    """This class runs a given Julia callable object from a given Julia code file. It additionally imports the QuantumGrav julia module and installs given dependencies if provided. After creation, the wrapped julia callable can be called via the __call__ method of this calls.
    **Warning**: This class requires the juliacall package to be installed in the Python environment.
    **Warning**: This class is in early development and may change in the future, be slow, or otherwise not ready for high performance production use.
    """

    jl_constructor_name = None

    def __init__(
        self,
        jl_kwargs: dict[str, Any] | None = None,
        jl_code_path: str | Path | None = None,
        jl_constructor_name: str | None = None,
        jl_base_module_path: str | Path | None = None,
        jl_dependencies: list[str] | None = None,
    ):
        """Initializes the JuliaWorker with the given parameters.

        Args:
            jl_kwargs (dict[str, Any] | None, optional): Keyword arguments to pass to the Julia callable object constructor. Defaults to None.
            jl_code_path (str | Path | None, optional): Path to the Julia code file in which the callable object is defined. Defaults to None.
            jl_constructor_name (str | None, optional): Name of the Julia constructor function. Defaults to None.
            jl_base_module_path (str | Path | None, optional): Path to the base Julia module 'QuantumGrav.jl'. If not given, tries to load it via a default `using QuantumGrav` import. Defaults to None.
            jl_dependencies (list[str] | None, optional): List of Julia package dependencies. Defaults to None. Will be installed via `Pkg.add` during initialization if provided.

        Raises:
            ValueError: If the Julia function name is not provided.
            ValueError: If the Julia code path is not provided.
            FileNotFoundError: If the Julia code path does not exist.
            RuntimeError: If there is an error loading the base module.
            RuntimeError: If there is an error loading Julia dependencies.
            RuntimeError: If there is an error loading the Julia code.
        """

        # we test for a bunch of needed args first
        if jl_constructor_name is None:
            raise ValueError("Julia function name must be provided.")

        if jl_code_path is None:
            raise ValueError("Julia code path must be provided.")

        jl_code_path = Path(jl_code_path).resolve().absolute()
        if not jl_code_path.exists():
            raise FileNotFoundError(f"Julia code path {jl_code_path} does not exist.")

        self.jl_constructor_name = jl_constructor_name
        self.jl_module_name = "QuantumGravPy2Jl"  # the module name is hardcoded here

        # try to initialize the new Julia module, then later do every julia call through this module
        try:
            self.jl_module = jcall.newmodule(self.jl_module_name)

        except jcall.JuliaError as e:
            raise RuntimeError(f"Error creating new julia module: {e}") from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected exception while creating Julia module {self.jl_module_name}: {e}"
            ) from e

        # add base module for dependencies if exists
        if jl_base_module_path is not None:
            jl_base_module_path = Path(jl_base_module_path).resolve().absolute()
            try:
                self.jl_module.seval(
                    f'using Pkg; Pkg.develop(path="{str(jl_base_module_path)}")'
                )  # only for now -> get from package index later
            except jcall.JuliaError as e:
                raise RuntimeError(
                    f"Error loading base module {str(jl_base_module_path)}: {e}"
                ) from e
            except Exception as e:
                raise RuntimeError(
                    f"Unexpected exception while initializing julia base module: {e}"
                ) from e

        try:
            # add dependencies if provided
            if jl_dependencies is not None:
                for dep in jl_dependencies:
                    self.jl_module.seval(f'using Pkg; Pkg.add("{dep}")')
        except jcall.JuliaError as e:
            raise RuntimeError(f"Error processing Julia dependencies: {e}") from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected exception while processing Julia dependencies: {e}"
            ) from e

        try:
            # load the user-supplied Julia code for data generation
            self.jl_module.seval(f'push!(LOAD_PATH, "{jl_code_path}")')
            self.jl_module.seval("using QuantumGrav")
            self.jl_module.seval(f'include("{jl_code_path}")')
            constructor_name = getattr(self.jl_module, jl_constructor_name)
            self.jl_generator = constructor_name(jl_kwargs)
        except jcall.JuliaError as e:
            raise RuntimeError(
                f"Error evaluating Julia code to activate base module: {e}"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected exception while loading Julia base module: {e}"
            ) from e

    def __call__(self, *args, **kwargs) -> Any:
        """Calls the wrapped Julia generator with the given arguments.

        Raises:
            RuntimeError: If the Julia module is not initialized.
        Args:
            *args: Positional arguments to pass to the Julia generator.
            **kwargs: Keyword arguments to pass to the Julia generator.
        Returns:
            Any: The raw data generated by the Julia generator.
        """
        if self.jl_module is None:
            raise RuntimeError("Julia module is not initialized.")
        raw_data = self.jl_generator(*args, **kwargs)
        return raw_data

__call__(*args, **kwargs)

Calls the wrapped Julia generator with the given arguments.

Parameters:

- *args: Positional arguments to pass to the Julia generator.
- **kwargs: Keyword arguments to pass to the Julia generator.

Raises:

- RuntimeError: If the Julia module is not initialized.

Returns:

- Any: The raw data generated by the Julia generator.

Source code in src/QuantumGrav/julia_worker.py
def __call__(self, *args, **kwargs) -> Any:
    """Calls the wrapped Julia generator with the given arguments.

    Raises:
        RuntimeError: If the Julia module is not initialized.
    Args:
        *args: Positional arguments to pass to the Julia generator.
        **kwargs: Keyword arguments to pass to the Julia generator.
    Returns:
        Any: The raw data generated by the Julia generator.
    """
    if self.jl_module is None:
        raise RuntimeError("Julia module is not initialized.")
    raw_data = self.jl_generator(*args, **kwargs)
    return raw_data

__init__(jl_kwargs=None, jl_code_path=None, jl_constructor_name=None, jl_base_module_path=None, jl_dependencies=None)

Initializes the JuliaWorker with the given parameters.

Parameters:

- jl_kwargs (dict[str, Any] | None): Keyword arguments to pass to the Julia callable object constructor. Defaults to None.
- jl_code_path (str | Path | None): Path to the Julia code file in which the callable object is defined. Defaults to None.
- jl_constructor_name (str | None): Name of the Julia constructor function. Defaults to None.
- jl_base_module_path (str | Path | None): Path to the base Julia module 'QuantumGrav.jl'. If not given, tries to load it via a default `using QuantumGrav` import. Defaults to None.
- jl_dependencies (list[str] | None): List of Julia package dependencies; installed via `Pkg.add` during initialization if provided. Defaults to None.

Raises:

- ValueError: If the Julia function name is not provided.
- ValueError: If the Julia code path is not provided.
- FileNotFoundError: If the Julia code path does not exist.
- RuntimeError: If there is an error loading the base module.
- RuntimeError: If there is an error loading Julia dependencies.
- RuntimeError: If there is an error loading the Julia code.

Source code in src/QuantumGrav/julia_worker.py
def __init__(
    self,
    jl_kwargs: dict[str, Any] | None = None,
    jl_code_path: str | Path | None = None,
    jl_constructor_name: str | None = None,
    jl_base_module_path: str | Path | None = None,
    jl_dependencies: list[str] | None = None,
):
    """Initializes the JuliaWorker with the given parameters.

    Args:
        jl_kwargs (dict[str, Any] | None, optional): Keyword arguments to pass to the Julia callable object constructor. Defaults to None.
        jl_code_path (str | Path | None, optional): Path to the Julia code file in which the callable object is defined. Defaults to None.
        jl_constructor_name (str | None, optional): Name of the Julia constructor function. Defaults to None.
        jl_base_module_path (str | Path | None, optional): Path to the base Julia module 'QuantumGrav.jl'. If not given, tries to load it via a default `using QuantumGrav` import. Defaults to None.
        jl_dependencies (list[str] | None, optional): List of Julia package dependencies. Defaults to None. Will be installed via `Pkg.add` during initialization if provided.

    Raises:
        ValueError: If the Julia function name is not provided.
        ValueError: If the Julia code path is not provided.
        FileNotFoundError: If the Julia code path does not exist.
        RuntimeError: If there is an error loading the base module.
        RuntimeError: If there is an error loading Julia dependencies.
        RuntimeError: If there is an error loading the Julia code.
    """

    # we test for a bunch of needed args first
    if jl_constructor_name is None:
        raise ValueError("Julia function name must be provided.")

    if jl_code_path is None:
        raise ValueError("Julia code path must be provided.")

    jl_code_path = Path(jl_code_path).resolve().absolute()
    if not jl_code_path.exists():
        raise FileNotFoundError(f"Julia code path {jl_code_path} does not exist.")

    self.jl_constructor_name = jl_constructor_name
    self.jl_module_name = "QuantumGravPy2Jl"  # the module name is hardcoded here

    # try to initialize the new Julia module, then later do every julia call through this module
    try:
        self.jl_module = jcall.newmodule(self.jl_module_name)

    except jcall.JuliaError as e:
        raise RuntimeError(f"Error creating new julia module: {e}") from e
    except Exception as e:
        raise RuntimeError(
            f"Unexpected exception while creating Julia module {self.jl_module_name}: {e}"
        ) from e

    # add base module for dependencies if exists
    if jl_base_module_path is not None:
        jl_base_module_path = Path(jl_base_module_path).resolve().absolute()
        try:
            self.jl_module.seval(
                f'using Pkg; Pkg.develop(path="{str(jl_base_module_path)}")'
            )  # only for now -> get from package index later
        except jcall.JuliaError as e:
            raise RuntimeError(
                f"Error loading base module {str(jl_base_module_path)}: {e}"
            ) from e
        except Exception as e:
            raise RuntimeError(
                f"Unexpected exception while initializing julia base module: {e}"
            ) from e

    try:
        # add dependencies if provided
        if jl_dependencies is not None:
            for dep in jl_dependencies:
                self.jl_module.seval(f'using Pkg; Pkg.add("{dep}")')
    except jcall.JuliaError as e:
        raise RuntimeError(f"Error processing Julia dependencies: {e}") from e
    except Exception as e:
        raise RuntimeError(
            f"Unexpected exception while processing Julia dependencies: {e}"
        ) from e

    try:
        # load the user-supplied Julia code for data generation
        self.jl_module.seval(f'push!(LOAD_PATH, "{jl_code_path}")')
        self.jl_module.seval("using QuantumGrav")
        self.jl_module.seval(f'include("{jl_code_path}")')
        constructor_name = getattr(self.jl_module, jl_constructor_name)
        self.jl_generator = constructor_name(jl_kwargs)
    except jcall.JuliaError as e:
        raise RuntimeError(
            f"Error evaluating Julia code to activate base module: {e}"
        ) from e
    except Exception as e:
        raise RuntimeError(
            f"Unexpected exception while loading Julia base module: {e}"
        ) from e

Model training

Model training consists of two classes: Trainer, which provides the basic training functionality, and TrainerDDP, which derives from it and adds support for distributed data-parallel training.

Trainer

This class provides wrapper functions for setting up a model and for training and evaluating it. The basic concept is that the entire run is defined in a YAML file, which is handed to this class together with evaluator classes. After construction, the train and test functions take care of training and testing the model.
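A hedged sketch of that flow; the exact constructor signature and how the evaluator objects are passed are not shown in this excerpt, so those details are assumptions:

import yaml

with open("train_config.yaml") as f:
    config = yaml.safe_load(f)  # structure follows the schema shown below

trainer = Trainer(config)  # evaluator objects are supplied here as well (names assumed)
trainer.train()            # training loop driven by config["training"]
trainer.test()             # evaluation of the trained model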

Trainer

Bases: Configurable

Trainer class for training and evaluating GNN models.

Source code in src/QuantumGrav/train.py
class Trainer(base.Configurable):
    """Trainer class for training and evaluating GNN models."""

    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "title": "Model trainer class Configuration",
        "type": "object",
        "definitions": {
            "constructor": {
                "type": "object",
                "description": "Python constructor spec: type(*args, **kwargs)",
                "properties": {
                    "type": {
                        "description": "Fully-qualified import path or name of the type/callable to initialize",
                    },
                    "args": {
                        "type": "array",
                        "description": "Positional arguments for constructor",
                        "items": {},
                    },
                    "kwargs": {
                        "type": "object",
                        "description": "Keyword arguments for constructor",
                        "additionalProperties": {},
                    },
                },
                "required": ["type"],
                "additionalProperties": False,
            }
        },
        "properties": {
            "name": {
                "type": "string",
                "description": "Name of the training run",
            },
            "log_level": {
                "description": "Optional logging level (int or string, e.g. INFO)",
                "anyOf": [{"type": "integer"}, {"type": "string"}],
            },
            "training": {
                "type": "object",
                "description": "Training configuration",
                "properties": {
                    "seed": {"type": "integer", "description": "Random seed"},
                    "device": {
                        "type": "string",
                        "description": "Torch device string, e.g. 'cpu', 'cuda', 'cuda:0'",
                    },
                    "path": {
                        "type": "string",
                        "description": "Output directory for run artifacts and checkpoints",
                    },
                    "num_epochs": {
                        "type": "integer",
                        "minimum": 1,
                        "description": "Number of training epochs",
                    },
                    "batch_size": {
                        "type": "integer",
                        "minimum": 1,
                        "description": "Training DataLoader batch size",
                    },
                    "optimizer_type": {
                        "description": "Optimizer type name, e.g. 'torch.optim.Adam' or 'torch.optim.SGD'",
                    },
                    "optimizer_args": {
                        "type": "array",
                        "description": "Arguments for optimizer",
                        "items": {},
                    },
                    "optimizer_kwargs": {
                        "type": "object",
                        "description": "Optimizer keyword arguments",
                        "additionalProperties": {},
                    },
                    "lr_scheduler_type": {
                        "description": "type of the learning rate scheduler",
                    },
                    "lr_scheduler_args": {
                        "type": "array",
                        "description": "arguments to construct the learning rate scheduler",
                        "items": {},
                    },
                    "lr_scheduler_kwargs": {
                        "type": "object",
                        "description": "keyword arguments for the construction of learning rate scheduler",
                        "additionalProperties": {},
                    },
                    "num_workers": {
                        "type": "integer",
                        "minimum": 0,
                        "description": "DataLoader workers for training",
                    },
                    "pin_memory": {
                        "type": "boolean",
                        "description": "Pin GPU memory in DataLoader",
                    },
                    "drop_last": {
                        "type": "boolean",
                        "description": "Drop last incomplete batch",
                    },
                    "prefetch_factor": {
                        "type": ["integer", "null"],
                        "minimum": 2,
                        "description": "Prefetch samples per worker (None or >=2)",
                    },
                    "shuffle": {
                        "type": "boolean",
                        "description": "Shuffle training dataset",
                    },
                    "checkpoint_at": {
                        "type": ["integer", "null"],
                        "minimum": 1,
                        "description": "Checkpoint every N epochs (or None to disable)",
                    },
                },
                "required": [
                    "seed",
                    "device",
                    "path",
                    "num_epochs",
                    "batch_size",
                    "optimizer_type",
                    "optimizer_args",
                    "optimizer_kwargs",
                    "num_workers",
                    "drop_last",
                    "checkpoint_at",
                ],
                "additionalProperties": True,
            },
            "data": {
                "type": "object",
                "description": "Dataset configuration",
                "properties": {
                    "pre_transform": {
                        "description": "Name of the python object to use for the pre-transform function to use. Must refer to a callable"
                    },
                    "transform": {
                        "description": "Name of the python object to use for the transform function to use. Must refer to a callable"
                    },
                    "pre_filter": {
                        "description": "Name of the python object to use for the pre_filter function to use. Must refer to a callable"
                    },
                    "reader": {
                        "description": "Name  of the python object to read raw data from file. Must be callable",
                    },
                    "files": {
                        "type": "array",
                        "description": "list of zarr stores to get data from",
                        "minItems": 1,
                        "items": {
                            "type": "string",
                            "description": "zarr file names to read data from",
                        },
                    },
                    "output": {
                        "type": "string",
                        "description": "path to store preprocessed data at.",
                    },
                    "validate_data": {
                        "type": "boolean",
                        "description": "Whether to validate the transformed data objects or not",
                    },
                    "n_processes": {
                        "type": "integer",
                        "description": "number of processes to use for preprocessing the dataset",
                        "minimum": 0,
                    },
                    "chunksize": {
                        "type": "integer",
                        "description": "Number of datapoints to process at once during preprocessing",
                    },
                    "shuffle": {
                        "type": "boolean",
                        "description": "Whether to shuffle the dataset or not",
                    },
                    "subset": {
                        "type": "number",
                        "description": "Fraction of the dataset to use. Full dataset is used when not given",
                    },
                    "split": {
                        "type": "array",
                        "description": "Split ratios of the dataset",
                        "items": {
                            "type": "number",
                            "minItems": 3,
                            "maxItems": 3,
                        },
                    },
                },
                "required": ["output", "files", "reader"],
                "additionalProperties": False,
            },
            "model": {
                "description": "Model config: either constructor triple or full GNNModel schema",
                "anyOf": [
                    {"$ref": "#/definitions/constructor"},
                    gnn_model.GNNModel.schema,
                ],
            },
            "validation": {
                "type": "object",
                "description": "Model validation configuration",
                "properties": {
                    "batch_size": {
                        "type": "integer",
                        "minimum": 1,
                        "description": "Validation DataLoader batch size",
                    },
                    "num_workers": {
                        "type": "integer",
                        "minimum": 0,
                        "description": "DataLoader workers",
                    },
                    "pin_memory": {
                        "type": "boolean",
                        "description": "Pin GPU memory in DataLoader",
                    },
                    "drop_last": {
                        "type": "boolean",
                        "description": "Drop last incomplete batch",
                    },
                    "prefetch_factor": {
                        "type": ["integer", "null"],
                        "minimum": 2,
                        "description": "Prefetch samples per worker (None or >=2)",
                    },
                    "shuffle": {
                        "type": "boolean",
                        "description": "Shuffle validation dataset",
                    },
                    "validator": {
                        "$ref": "#/definitions/constructor",
                        "description": "Validator constructor spec: provides type, args, kwargs",
                    },
                },
                "required": ["batch_size"],
                "additionalProperties": True,
            },
            "testing": {
                "type": "object",
                "description": "Configuration for model testing after training",
                "properties": {
                    "batch_size": {
                        "type": "integer",
                        "minimum": 1,
                        "description": "Test DataLoader batch size",
                    },
                    "num_workers": {
                        "type": "integer",
                        "minimum": 0,
                        "description": "DataLoader workers",
                    },
                    "pin_memory": {
                        "type": "boolean",
                        "description": "Pin GPU memory in DataLoader",
                    },
                    "drop_last": {
                        "type": "boolean",
                        "description": "Drop last incomplete batch",
                    },
                    "prefetch_factor": {
                        "type": ["integer", "null"],
                        "minimum": 2,
                        "description": "Prefetch samples per worker (None or >=2)",
                    },
                    "shuffle": {
                        "type": "boolean",
                        "description": "Shuffle test dataset",
                    },
                    "tester": {
                        "$ref": "#/definitions/constructor",
                        "description": "Tester constructor spec: provides type, args, kwargs",
                    },
                },
                "required": ["batch_size"],
                "additionalProperties": True,
            },
            "early_stopping": {
                "$ref": "#/definitions/constructor",
                "description": "Early stopping constructor spec: provides type, args, kwargs",
            },
            "apply_model": {
                "description": "Optional method to call the model on data. Useful when using optional signatures for instance "
            },
            "criterion": {
                "description": "The loss function used for training as a python type"
            },
        },
        "required": [
            "training",
            "model",
            "validation",
            "testing",
            "criterion",
        ],
        "additionalProperties": True,
    }

    def __init__(
        self,
        config: Dict[str, Any],
        # training and evaluation functions
    ):
        """Initialize the trainer.

        Args:
            config (dict[str, Any]): The configuration dictionary.

        Raises:
            ValueError: If the configuration is invalid.
        """

        jsonschema.validate(instance=config, schema=self.schema)
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(config.get("log_level", logging.INFO))
        self.logger.info("Initializing Trainer instance")

        # functions for executing training and evaluation
        self.criterion = config["criterion"]
        self.apply_model = config.get("apply_model")
        self.seed = config["training"]["seed"]
        self.device = torch.device(config["training"]["device"])

        self.nprng = np.random.default_rng(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed)

        # parameters for finding out which model is best
        self.best_score = None
        self.best_epoch = 0
        self.epoch = 0

        # date and time of run:
        run_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.data_path = (
            Path(self.config["training"]["path"])
            / f"{config.get('name', 'run')}_{run_date}"
        )

        # set up paths for storing model snapshots and data
        if not self.data_path.exists():
            self.data_path.mkdir(parents=True)
        self.logger.info(f"Data path set to: {self.data_path}")

        self.checkpoint_path = self.data_path / "model_checkpoints"
        self.checkpoint_at = config["training"].get("checkpoint_at", None)
        self.latest_checkpoint = None

        # model and optimizer initialization placeholders
        self.model = None
        self.optimizer = None

        # early stopping and evaluation functors
        try:
            self.early_stopping = early_stopping.DefaultEarlyStopping.from_config(
                config["early_stopping"]
            )
        except Exception as e:
            self.logger.debug(
                f"from_config failed for early stopping, using direct instantiation: {e}"
            )
            self.early_stopping = config["early_stopping"]["type"](
                *config["early_stopping"]["args"], **config["early_stopping"]["kwargs"]
            )

        try:
            self.validator = evaluate.DefaultValidator.from_config(
                config["validation"]["validator"]
            )
        except Exception as e:
            self.logger.debug(
                f"from_config failed for validator, using direct instantiation: {e}"
            )
            self.validator = config["validation"]["validator"]["type"](
                *config["validation"]["validator"]["args"],
                **config["validation"]["validator"]["kwargs"],
            )

        try:
            self.tester = evaluate.DefaultTester.from_config(
                config["testing"]["tester"]
            )
        except Exception as e:
            self.logger.debug(
                f"from_config failed for tester, using direct instantiation: {e}"
            )
            self.tester = config["testing"]["tester"]["type"](
                *config["testing"]["tester"]["args"],
                **config["testing"]["tester"]["kwargs"],
            )

        with open(self.data_path / "config.yaml", "w") as f:
            yaml.dump(self.config, f)

        self.logger.info("Trainer initialized")
        self.logger.debug(f"Configuration: {self.config}")

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Trainer":
        """Create a Trainer instance from a configuration dictionary.

        Args:
            config (Dict[str, Any]): The configuration dictionary.
        """
        return cls(
            config=config,
        )

    def initialize_model(self) -> Any:
        """Initialize the model for training.

        Returns:
            Any: The initialized model.
        """
        if hasattr(self, "model") and self.model is not None:
            self.logger.warning(
                "Model is already initialized. This will replace it with a new instance"
            )

        try:
            self.model = gnn_model.GNNModel.from_config(self.config["model"]).to(
                self.device
            )

        except Exception:
            self.logger.debug(
                "from_config for model initialization failed, using direct initialization instead"
            )
            self.model = self.config["model"]["type"](
                *self.config["model"]["args"], **self.config["model"]["kwargs"]
            ).to(self.device)

        self.logger.info(f"Model initialized to device: {self.device}")
        return self.model

    def initialize_lr_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler | None:
        """Initialize the learning rate scheduler for training.

        Raises:
            RuntimeError: If the optimizer is not initialized.

        Returns:
            torch.optim.lr_scheduler._LRScheduler: The initialized learning rate scheduler.
        """
        if self.config["training"].get("lr_scheduler_type") is None:
            self.logger.info("No learning rate scheduler specified in config.")
            return None
        else:
            if not hasattr(self, "optimizer") or self.optimizer is None:
                raise RuntimeError(
                    "Optimizer must be initialized before initializing learning rate scheduler."
                )

            try:
                self.lr_scheduler = self.config["training"].get("lr_scheduler_type")(
                    self.optimizer,
                    *self.config["training"].get("lr_scheduler_args", []),
                    **self.config["training"].get("lr_scheduler_kwargs", {}),
                )
                self.logger.info("Learning rate scheduler initialized.")
                return self.lr_scheduler
            except Exception as e:
                self.logger.error(f"Error initializing learning rate scheduler: {e}")
                raise e

    def initialize_optimizer(self) -> torch.optim.Optimizer | None:
        """Initialize the optimizer for training.

        Raises:
            RuntimeError: If the model is not initialized.

        Returns:
            torch.optim.Optimizer: The initialized optimizer.
        """

        if not hasattr(self, "model") or self.model is None:
            raise RuntimeError(
                "Model must be initialized before initializing optimizer."
            )

        if hasattr(self, "optimizer") and self.optimizer is not None:
            self.logger.warning(
                "Optimizer is already initialized. This will replace it with a new instance"
            )

        try:
            optimizer = self.config["training"].get("optimizer_type", torch.optim.Adam)(
                self.model.parameters(),
                *self.config["training"].get("optimizer_args", []),
                **self.config["training"].get("optimizer_kwargs", {}),
            )
            self.optimizer = optimizer
            self.logger.info("Optimizer initialized")
        except Exception as e:
            self.logger.error(f"Error initializing optimizer: {e}")
            raise e
        return self.optimizer

    def prepare_dataset(
        self,
        dataset: Dataset | None = None,
        split: list[float] = [0.8, 0.1, 0.1],
        train_dataset: torch.utils.data.Subset | None = None,
        val_dataset: torch.utils.data.Subset | None = None,
        test_dataset: torch.utils.data.Subset | None = None,
    ) -> Tuple[Dataset, Dataset, Dataset]:
        """Set up the split for training, validation, and testing datasets.

        Args:
            dataset (Dataset | None, optional): Dataset to be split. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided. Defaults to None.
            split (list[float], optional): split ratios for train, validation, and test datasets. Defaults to [0.8, 0.1, 0.1].
            train_dataset (torch.utils.data.Subset | None, optional): Training subset of the dataset. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided. Defaults to None.
            val_dataset (torch.utils.data.Subset | None, optional): Validation subset of the dataset. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided. Defaults to None.
            test_dataset (torch.utils.data.Subset | None, optional): Testing subset of the dataset. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided. Defaults to None.

        Raises:
            ValueError: If providing train, val, or test datasets, the full dataset must not be provided.
            ValueError: If the split ratios do not sum to 1.
            ValueError: If train size is 0
            ValueError: If validation size is 0
            ValueError: If test size is 0

        Returns:
            Tuple[Dataset, Dataset, Dataset]: train, validation, and test datasets.
        """
        if dataset is not None and (
            train_dataset is not None
            or val_dataset is not None
            or test_dataset is not None
        ):
            raise ValueError(
                "If providing train, val, or test datasets, the full dataset must not be provided."
            )

        if dataset is None:
            cfg = self.config["data"]
            dataset = dataset_ondisk.QGDataset(
                cfg["files"],
                cfg["output"],
                cfg["reader"],
                float_type=cfg.get("float_type", torch.float32),
                int_type=cfg.get("int_type", torch.int32),
                validate_data=cfg.get("validate_data", True),
                chunksize=cfg.get("chunksize", 1),
                n_processes=cfg.get("n_processes", 1),
                transform=cfg.get("transform"),
                pre_transform=cfg.get("pre_transform"),
                pre_filter=cfg.get("pre_filter"),
            )

            if cfg.get("subset"):
                num_points = ceil(len(dataset) * cfg["subset"])
                dataset = dataset.index_select(
                    self.nprng.integers(0, len(dataset), size=num_points).tolist()
                )

            if cfg.get("shuffle"):
                dataset.shuffle()

        if train_dataset is None and val_dataset is None and test_dataset is None:
            split = self.config.get("data", {}).get("split", split)
            if not np.isclose(
                np.sum(split), 1.0, rtol=1e-05, atol=1e-08, equal_nan=False
            ):
                raise ValueError(
                    f"Split ratios must sum to 1.0. Provided split: {split}"
                )

            train_size = ceil(len(dataset) * split[0])
            val_size = floor(len(dataset) * split[1])
            test_size = len(dataset) - train_size - val_size

            if train_size == 0:
                raise ValueError("train size cannot be 0")

            if val_size == 0:
                raise ValueError("validation size cannot be 0")

            if test_size == 0:
                raise ValueError("test size cannot be 0")

            train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size, test_size]
            )

        return train_dataset, val_dataset, test_dataset

    def prepare_dataloaders(
        self,
        dataset: Dataset | None = None,
        split: list[float] = [0.8, 0.1, 0.1],
        train_dataset: torch.utils.data.Subset | None = None,
        val_dataset: torch.utils.data.Subset | None = None,
        test_dataset: torch.utils.data.Subset | None = None,
        training_sampler: torch.utils.data.Sampler | None = None,
    ) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Prepare the data loaders for training, validation, and testing.

        Args:
            dataset (Dataset): The dataset to prepare.
            split (list[float], optional): The split ratios for training, validation, and test sets. Defaults to [0.8, 0.1, 0.1].
            training_sampler (torch.utils.data.Sampler, optional): The sampler for the training data loader. Defaults to None.

        Returns:
            Tuple[DataLoader, DataLoader, DataLoader]: The data loaders for training, validation, and testing.
        """
        self.train_dataset, self.val_dataset, self.test_dataset = self.prepare_dataset(
            dataset=dataset,
            split=split,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
        )
        train_loader = DataLoader(
            self.train_dataset,  # type: ignore
            batch_size=self.config["training"]["batch_size"],
            num_workers=self.config["training"].get("num_workers", 0),
            pin_memory=self.config["training"].get("pin_memory", True),
            drop_last=self.config["training"].get("drop_last", False),
            prefetch_factor=self.config["training"].get("prefetch_factor", None),
            shuffle=self.config["training"].get("shuffle", True),
            sampler=training_sampler,
        )

        val_loader = DataLoader(
            self.val_dataset,  # type: ignore
            batch_size=self.config["validation"]["batch_size"],
            num_workers=self.config["validation"].get("num_workers", 0),
            pin_memory=self.config["validation"].get("pin_memory", True),
            drop_last=self.config["validation"].get("drop_last", False),
            prefetch_factor=self.config["validation"].get("prefetch_factor", None),
            shuffle=self.config["validation"].get("shuffle", True),
        )

        test_loader = DataLoader(
            self.test_dataset,  # type: ignore
            batch_size=self.config["testing"]["batch_size"],
            num_workers=self.config["testing"].get("num_workers", 0),
            pin_memory=self.config["testing"].get("pin_memory", True),
            drop_last=self.config["testing"].get("drop_last", False),
            prefetch_factor=self.config["testing"].get("prefetch_factor", None),
            shuffle=self.config["testing"].get("shuffle", True),
        )

        if dataset is not None:
            self.logger.info(
                f"Data loaders prepared with splits: {split} and dataset sizes: {len(self.train_dataset)}, {len(self.val_dataset)}, {len(self.test_dataset)}"
            )
        return train_loader, val_loader, test_loader

    # training helper functions
    def _evaluate_batch(
        self,
        model: torch.nn.Module,
        data: Data,
    ) -> torch.Tensor | Collection[torch.Tensor]:
        """Evaluate a single batch of data using the model.

        Args:
            model (torch.nn.Module): The model to evaluate.
            data (Data): The input data for the model.

        Returns:
            torch.Tensor | Collection[torch.Tensor]: The output of the model.
        """
        self.logger.debug(f"  Evaluating batch on device: {self.device}")
        if self.apply_model:
            outputs = self.apply_model(model, data)
        else:
            outputs = model(data.x, data.edge_index, data.batch)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return outputs

    def _run_train_epoch(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        train_loader: DataLoader,
    ) -> torch.Tensor:
        """Run a single training epoch.

        Args:
            model (torch.nn.Module): The model to train.
            optimizer (torch.optim.Optimizer): The optimizer for the model.
            train_loader (DataLoader): The data loader for the training set.
        Raises:
            RuntimeError: If the model is not initialized.
            RuntimeError: If the optimizer is not initialized.

        Returns:
            torch.Tensor: The training loss for each batch stored in a torch.Tensor
        """

        if model is None:
            raise RuntimeError("Model must be initialized before training.")

        if optimizer is None:
            raise RuntimeError("Optimizer must be initialized before training.")

        losses = torch.zeros(len(train_loader), dtype=torch.float32, device=self.device)
        self.logger.info(f"  Starting training epoch {self.epoch}")
        # training run
        for i, batch in enumerate(
            tqdm.tqdm(train_loader, desc=f"Training Epoch {self.epoch}")
        ):
            self.logger.debug(f"    Moving batch {i} to device: {self.device}")
            optimizer.zero_grad()

            data = batch.to(self.device)
            outputs = self._evaluate_batch(model, data)

            self.logger.debug("    Computing loss")
            loss = self.criterion(outputs, data, self)

            self.logger.debug(f"    Backpropagating loss: {loss.item()}")
            loss.backward()

            optimizer.step()

            # store a detached copy so each batch's autograd graph can be freed
            losses[i] = loss.detach()

        if hasattr(self, "lr_scheduler") and self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return losses

    def _check_model_status(self, eval_data: pd.DataFrame) -> bool:
        """Check the status of the model during training.

        Args:
            eval_data (pd.DataFrame): The evaluation data from the training epoch.

        Returns:
            bool: Whether the training should stop early.
        """
        if self.model is None:
            raise ValueError("Model must be initialized before saving checkpoints")

        if (
            self.checkpoint_at is not None
            and self.epoch % self.checkpoint_at == 0
            and self.epoch > 0
        ):
            self.save_checkpoint()

        if self.early_stopping is not None:
            if self.early_stopping(eval_data):
                self.logger.debug(f"Early stopping at epoch {self.epoch}.")
                self.save_checkpoint(name_addition=f"_{self.epoch}_early_stopping")
                return True

            if self.early_stopping.found_better_model:
                self.logger.debug(f"Found better model at epoch {self.epoch}.")
                self.save_checkpoint(name_addition=f"_{self.epoch}_current_best")
                # not returning true because this is not the end of training

        return False

    def run_training(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        trial: optuna.trial.Trial | None = None,
    ) -> Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]:
        """Run the training process.

        Args:
            train_loader (DataLoader): The data loader for the training set.
            val_loader (DataLoader): The data loader for the validation set.
            trial (optuna.trial.Trial | None, optional): An Optuna trial
                for hyperparameter tuning. Defaults to None.

        Returns:
            Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.
        """
        self.logger.info("Starting training process.")
        # training loop
        num_epochs = self.config["training"]["num_epochs"]

        self.initialize_model()

        self.initialize_optimizer()

        self.initialize_lr_scheduler()

        total_training_data = torch.zeros(num_epochs, 2, dtype=torch.float32)

        for epoch in range(0, num_epochs):
            self.logger.info(f"  Current epoch: {self.epoch}/{num_epochs}")
            self.model.train()

            epoch_data = self._run_train_epoch(self.model, self.optimizer, train_loader)

            # collect mean and std for each epoch
            total_training_data[epoch, :] = torch.Tensor(
                [epoch_data.mean(dim=0).item(), epoch_data.std(dim=0).item()]
            )

            self.logger.info(
                f"  Completed epoch {epoch}. training loss: {total_training_data[epoch, 0]:.8f} +/- {total_training_data[epoch, 1]:.8f}."
            )

            # evaluation run on validation set
            if self.validator is not None:
                validation_result = self.validator.validate(self.model, val_loader)
                self.validator.report(validation_result)

                # integrate Optuna here for hyperparameter tuning
                if trial is not None:
                    avg_sigma_loss = self.validator.data[self.epoch]
                    avg_loss = avg_sigma_loss[0]
                    trial.report(avg_loss, self.epoch)

                    # Handle pruning based on the intermediate value.
                    if trial.should_prune():
                        raise optuna.exceptions.TrialPruned()

            should_stop = self._check_model_status(
                self.validator.data if self.validator else total_training_data,
            )
            if should_stop:
                self.logger.info("Stopping training early.")
                break
            self.epoch += 1

        self.logger.info("Training process completed.")
        self.logger.info("Saving model")

        outpath = self.data_path / f"final_model_epoch={self.epoch}.pt"
        self.model.save(outpath)

        return total_training_data, self.validator.data if self.validator else []

    def run_test(
        self, test_loader: DataLoader, model_name_addition: str = "current_best.pt"
    ) -> Collection[Any]:
        """Run testing phase.

        Args:
            test_loader (DataLoader): The data loader for the test set.
            model_name_addition (str): An optional string to append to the checkpoint filename.
        Raises:
            RuntimeError: If the model is not initialized.
            RuntimeError: If the test data is not available.

        Returns:
            Collection[Any]: A collection of test results that can be scalars, tensors, lists, dictionaries or any other data type that the tester might return.
        """
        self.logger.info("Starting testing process.")
        # get the best model again

        saved_models = [
            f
            for f in Path(self.checkpoint_path).iterdir()
            if f.is_file() and model_name_addition in str(f)
        ]

        if len(saved_models) == 0:
            raise RuntimeError(
                f"No model with the name addition '{model_name_addition}' found, did training work?"
            )

        # get the latest of the best models
        best_of_the_best = max(saved_models, key=lambda f: f.stat().st_mtime)

        self.logger.info(f"loading best model found: {str(best_of_the_best)}")

        self.model = gnn_model.GNNModel.load(
            best_of_the_best, self.config["model"], device=self.device
        )
        self.model.eval()
        if self.tester is None:
            raise RuntimeError("Tester must be initialized before testing.")
        test_result = self.tester.test(self.model, test_loader)
        self.tester.report(test_result)
        self.logger.info("Testing process completed.")
        self.save_checkpoint(name_addition="best_model_found")
        return self.tester.data

    def save_checkpoint(self, name_addition: str = ""):
        """Save model checkpoint.

        Raises:
            ValueError: If the model is not initialized.
            ValueError: If the model configuration does not contain 'name'.
            ValueError: If the training configuration does not contain 'checkpoint_path'.
        """
        self.logger.info(
            f"Saving checkpoint for model at epoch {self.epoch} to {self.checkpoint_path}"
        )
        outpath = self.checkpoint_path / f"model_{name_addition}.pt"

        if not outpath.exists():
            outpath.parent.mkdir(parents=True, exist_ok=True)
            self.logger.debug(f"Created directory {outpath.parent} for checkpoint.")

        self.latest_checkpoint = outpath
        self.model.save(outpath)

    def load_checkpoint(self, name_addition: str = "") -> None:
        """Load model checkpoint to the device given

        Args:
            name_addition (str): An optional string to append to the checkpoint filename.

        Raises:
            RuntimeError: If the model is not initialized.
        """

        if self.model is None:
            raise RuntimeError("Model must be initialized before loading checkpoint.")

        if not Path(self.checkpoint_path).exists():
            raise RuntimeError("Checkpoint path does not exist.")

        self.logger.info(
            "available checkpoints: %s", list(Path(self.checkpoint_path).iterdir())
        )

        loadpath = Path(self.checkpoint_path) / f"model_{name_addition}.pt"

        if not loadpath.exists():
            raise FileNotFoundError(f"Checkpoint file {loadpath} does not exist.")

        self.model = gnn_model.GNNModel.load(
            loadpath,
            self.config["model"],
        )
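
Putting the schema into practice, the following is a minimal sketch of a configuration that satisfies the required keys. Everything here is illustrative: my_reader and my_criterion are hypothetical project-specific callables, and MyGNN, MyEarlyStopping, MyValidator, and MyTester stand in for concrete classes. Note that although the schema only marks validation.batch_size as required, __init__ unconditionally reads the early_stopping, validation.validator, and testing.tester entries, so in practice they must be present as well.

import torch

def my_reader(store):
    # hypothetical: parses raw samples out of a zarr store
    raise NotImplementedError

def my_criterion(outputs, data, trainer):
    # the Trainer invokes the criterion as criterion(outputs, data, self)
    return torch.nn.functional.cross_entropy(outputs, data.y)

config = {
    "name": "example_run",
    "training": {
        "seed": 42,
        "device": "cpu",
        "path": "runs",
        "num_epochs": 10,
        "batch_size": 32,
        "optimizer_type": torch.optim.Adam,
        "optimizer_args": [],
        "optimizer_kwargs": {"lr": 1e-3},
        "num_workers": 0,
        "drop_last": False,
        "checkpoint_at": 5,
    },
    "data": {
        "files": ["data/raw.zarr"],
        "output": "data/processed",
        "reader": my_reader,
        "split": [0.8, 0.1, 0.1],
    },
    # constructor-triple form; a full GNNModel schema dict also validates here
    "model": {"type": MyGNN, "args": [], "kwargs": {}},
    "early_stopping": {"type": MyEarlyStopping, "args": [], "kwargs": {}},
    "validation": {
        "batch_size": 32,
        "validator": {"type": MyValidator, "args": [], "kwargs": {}},
    },
    "testing": {
        "batch_size": 32,
        "tester": {"type": MyTester, "args": [], "kwargs": {}},
    },
    "criterion": my_criterion,
}
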

__init__(config)

Initialize the trainer.

Parameters:

Name    Type            Description                    Default
config  dict[str, Any]  The configuration dictionary.  required

Raises:

Type        Description
ValueError  If the configuration is invalid.
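
The criterion supplied in the configuration is not called like a plain loss module: _run_train_epoch invokes it as criterion(outputs, data, self), passing the Trainer itself as the third argument. A minimal compatible sketch, assuming graph-level class labels in data.y:

import torch.nn.functional as F

def my_criterion(outputs, data, trainer):
    # outputs: whatever the model (or apply_model) returned for this batch
    # data:    the batch, already moved to trainer.device
    # trainer: the Trainer instance, e.g. for epoch-dependent loss weighting
    return F.cross_entropy(outputs, data.y)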

Source code in src/QuantumGrav/train.py
def __init__(
    self,
    config: Dict[str, Any],
    # training and evaluation functions
):
    """Initialize the trainer.

    Args:
        config (dict[str, Any]): The configuration dictionary.

    Raises:
        ValueError: If the configuration is invalid.
    """

    jsonschema.validate(instance=config, schema=self.schema)
    self.config = config
    self.logger = logging.getLogger(__name__)
    self.logger.setLevel(config.get("log_level", logging.INFO))
    self.logger.info("Initializing Trainer instance")

    # functions for executing training and evaluation
    self.criterion = config["criterion"]
    self.apply_model = config.get("apply_model")
    self.seed = config["training"]["seed"]
    self.device = torch.device(config["training"]["device"])

    self.nprng = np.random.default_rng(self.seed)
    torch.manual_seed(self.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(self.seed)

    # parameters for finding out which model is best
    self.best_score = None
    self.best_epoch = 0
    self.epoch = 0

    # date and time of run:
    run_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    self.data_path = (
        Path(self.config["training"]["path"])
        / f"{config.get('name', 'run')}_{run_date}"
    )

    # set up paths for storing model snapshots and data
    if not self.data_path.exists():
        self.data_path.mkdir(parents=True)
    self.logger.info(f"Data path set to: {self.data_path}")

    self.checkpoint_path = self.data_path / "model_checkpoints"
    self.checkpoint_at = config["training"].get("checkpoint_at", None)
    self.latest_checkpoint = None

    # model and optimizer initialization placeholders
    self.model = None
    self.optimizer = None

    # early stopping and evaluation functors
    try:
        self.early_stopping = early_stopping.DefaultEarlyStopping.from_config(
            config["early_stopping"]
        )
    except Exception as e:
        self.logger.debug(
            f"from_config failed for early stopping, using direct instantiation: {e}"
        )
        self.early_stopping = config["early_stopping"]["type"](
            *config["early_stopping"]["args"], **config["early_stopping"]["kwargs"]
        )

    try:
        self.validator = evaluate.DefaultValidator.from_config(
            config["validation"]["validator"]
        )
    except Exception as e:
        self.logger.debug(
            f"from_config failed for validator, using direct instantiation: {e}"
        )
        self.validator = config["validation"]["validator"]["type"](
            *config["validation"]["validator"]["args"],
            **config["validation"]["validator"]["kwargs"],
        )

    try:
        self.tester = evaluate.DefaultTester.from_config(
            config["testing"]["tester"]
        )
    except Exception as e:
        self.logger.debug(
            f"from_config failed for tester, using direct instantiation: {e}"
        )
        self.tester = config["testing"]["tester"]["type"](
            *config["testing"]["tester"]["args"],
            **config["testing"]["tester"]["kwargs"],
        )

    with open(self.data_path / "config.yaml", "w") as f:
        yaml.dump(self.config, f)

    self.logger.info("Trainer initialized")
    self.logger.debug(f"Configuration: {self.config}")

from_config(config) classmethod

Create a Trainer instance from a configuration dictionary.

Parameters:

Name    Type            Description                    Default
config  Dict[str, Any]  The configuration dictionary.  required
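
from_config is a thin convenience wrapper around the constructor, so the two calls below are equivalent:

trainer = Trainer.from_config(config)
trainer = Trainer(config=config)
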
Source code in src/QuantumGrav/train.py
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Trainer":
    """Create a Trainer instance from a configuration dictionary.

    Args:
        config (Dict[str, Any]): The configuration dictionary.
    """
    return cls(
        config=config,
    )

initialize_lr_scheduler()

Initialize the learning rate scheduler for training.

Raises:

Type          Description
RuntimeError  If the optimizer is not initialized.

Returns:

Type                 Description
_LRScheduler | None  The initialized learning rate scheduler, or None if no scheduler is configured.
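
The scheduler is constructed as lr_scheduler_type(optimizer, *lr_scheduler_args, **lr_scheduler_kwargs), so any torch scheduler can be plugged in via the training section. A sketch using StepLR with illustrative hyperparameters:

import torch

config["training"].update(
    {
        "lr_scheduler_type": torch.optim.lr_scheduler.StepLR,
        "lr_scheduler_args": [],
        "lr_scheduler_kwargs": {"step_size": 10, "gamma": 0.5},
    }
)
trainer.initialize_model()
trainer.initialize_optimizer()      # the scheduler needs an optimizer first
scheduler = trainer.initialize_lr_scheduler()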

Source code in src/QuantumGrav/train.py
def initialize_lr_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler | None:
    """Initialize the learning rate scheduler for training.

    Raises:
        RuntimeError: If the optimizer is not initialized.

    Returns:
        torch.optim.lr_scheduler._LRScheduler: The initialized learning rate scheduler.
    """
    if self.config["training"].get("lr_scheduler_type") is None:
        self.logger.info("No learning rate scheduler specified in config.")
        return None
    else:
        if not hasattr(self, "optimizer") or self.optimizer is None:
            raise RuntimeError(
                "Optimizer must be initialized before initializing learning rate scheduler."
            )

        try:
            self.lr_scheduler = self.config["training"].get("lr_scheduler_type")(
                self.optimizer,
                *self.config["training"].get("lr_scheduler_args", []),
                **self.config["training"].get("lr_scheduler_kwargs", {}),
            )
            self.logger.info("Learning rate scheduler initialized.")
            return self.lr_scheduler
        except Exception as e:
            self.logger.error(f"Error initializing learning rate scheduler: {e}")
            raise e

initialize_model()

Initialize the model for training.

Returns:

Name  Type  Description
Any   Any   The initialized model.
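
The model entry accepts either a full GNNModel configuration (handled by GNNModel.from_config) or, as a fallback, a constructor triple that is instantiated directly. A sketch of the fallback form, where MyGNN is a hypothetical torch.nn.Module subclass:

config["model"] = {"type": MyGNN, "args": [], "kwargs": {"hidden_dim": 64}}
model = trainer.initialize_model()  # instantiated and moved to trainer.device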

Source code in src/QuantumGrav/train.py
def initialize_model(self) -> Any:
    """Initialize the model for training.

    Returns:
        Any: The initialized model.
    """
    if hasattr(self, "model") and self.model is not None:
        self.logger.warning(
            "Model is already initialized. This will replace it with a new instance"
        )

    try:
        self.model = gnn_model.GNNModel.from_config(self.config["model"]).to(
            self.device
        )

    except Exception:
        self.logger.debug(
            "from_config for model initialization failed, using direct initialization instead"
        )
        self.model = self.config["model"]["type"](
            *self.config["model"]["args"], **self.config["model"]["kwargs"]
        ).to(self.device)

    self.logger.info(f"Model initialized to device: {self.device}")
    return self.model

initialize_optimizer()

Initialize the optimizer for training.

Raises:

Type          Description
RuntimeError  If the model is not initialized.

Returns:

Type              Description
Optimizer | None  The initialized optimizer.
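
The optimizer is built as optimizer_type(model.parameters(), *optimizer_args, **optimizer_kwargs), falling back to torch.optim.Adam when no type is given. For example, with illustrative hyperparameters:

import torch

config["training"]["optimizer_type"] = torch.optim.SGD
config["training"]["optimizer_kwargs"] = {"lr": 0.01, "momentum": 0.9}
optimizer = trainer.initialize_optimizer()  # requires an initialized model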

Source code in src/QuantumGrav/train.py
def initialize_optimizer(self) -> torch.optim.Optimizer | None:
    """Initialize the optimizer for training.

    Raises:
        RuntimeError: If the model is not initialized.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.
    """

    if not hasattr(self, "model") or self.model is None:
        raise RuntimeError(
            "Model must be initialized before initializing optimizer."
        )

    if hasattr(self, "optimizer") and self.optimizer is not None:
        self.logger.warning(
            "Optimizer is already initialized. This will replace it with a new instance"
        )

    try:
        optimizer = self.config["training"].get("optimizer_type", torch.optim.Adam)(
            self.model.parameters(),
            *self.config["training"].get("optimizer_args", []),
            **self.config["training"].get("optimizer_kwargs", {}),
        )
        self.optimizer = optimizer
        self.logger.info("Optimizer initialized")
    except Exception as e:
        self.logger.error(f"Error initializing optimizer: {e}")
        raise e
    return self.optimizer

load_checkpoint(name_addition='')

Load a model checkpoint onto the configured device.

Parameters:

Name           Type  Description                                                Default
name_addition  str   An optional string to append to the checkpoint filename.  ''

Raises:

Type          Description
RuntimeError  If the model is not initialized.
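
Checkpoints are stored as model_{name_addition}.pt inside the run's model_checkpoints directory, so loading uses the same suffix that was passed to save_checkpoint. A sketch, assuming a best-model checkpoint was written at epoch 5 during training:

# training saves best models via save_checkpoint(name_addition=f"_{epoch}_current_best"),
# which yields files such as model__5_current_best.pt
trainer.load_checkpoint(name_addition="_5_current_best")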

Source code in src/QuantumGrav/train.py
def load_checkpoint(self, name_addition: str = "") -> None:
    """Load model checkpoint to the device given

    Args:
        name_addition (str): An optional string to append to the checkpoint filename.

    Raises:
        RuntimeError: If the model is not initialized.
    """

    if self.model is None:
        raise RuntimeError("Model must be initialized before loading checkpoint.")

    if not Path(self.checkpoint_path).exists():
        raise RuntimeError("Checkpoint path does not exist.")

    self.logger.info(
        "available checkpoints: %s", list(Path(self.checkpoint_path).iterdir())
    )

    loadpath = Path(self.checkpoint_path) / f"model_{name_addition}.pt"

    if not loadpath.exists():
        raise FileNotFoundError(f"Checkpoint file {loadpath} does not exist.")

    self.model = gnn_model.GNNModel.load(
        loadpath,
        self.config["model"],
    )

prepare_dataloaders(dataset=None, split=[0.8, 0.1, 0.1], train_dataset=None, val_dataset=None, test_dataset=None, training_sampler=None)

Prepare the data loaders for training, validation, and testing.

Parameters:

Name              Type         Description                                                 Default
dataset           Dataset      The dataset to prepare.                                     None
split             list[float]  The split ratios for training, validation, and test sets.  [0.8, 0.1, 0.1]
training_sampler  Sampler      The sampler for the training data loader.                  None

Returns:

Type                                       Description
Tuple[DataLoader, DataLoader, DataLoader]  The data loaders for training, validation, and testing.
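
In the common case the loaders are built straight from the data section of the configuration; alternatively, pre-split subsets can be supplied. A sketch of both routes, where train_subset, val_subset, and test_subset are hypothetical torch.utils.data.Subset objects:

# route 1: let the Trainer build and split the dataset from config["data"]
train_loader, val_loader, test_loader = trainer.prepare_dataloaders()

# route 2: supply pre-split subsets directly
train_loader, val_loader, test_loader = trainer.prepare_dataloaders(
    train_dataset=train_subset,
    val_dataset=val_subset,
    test_dataset=test_subset,
)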

Source code in src/QuantumGrav/train.py
def prepare_dataloaders(
    self,
    dataset: Dataset | None = None,
    split: list[float] = [0.8, 0.1, 0.1],
    train_dataset: torch.utils.data.Subset | None = None,
    val_dataset: torch.utils.data.Subset | None = None,
    test_dataset: torch.utils.data.Subset | None = None,
    training_sampler: torch.utils.data.Sampler | None = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Prepare the data loaders for training, validation, and testing.

    Args:
        dataset (Dataset): The dataset to prepare.
        split (list[float], optional): The split ratios for training, validation, and test sets. Defaults to [0.8, 0.1, 0.1].
        training_sampler (torch.utils.data.Sampler, optional): The sampler for the training data loader. Defaults to None.

    Returns:
        Tuple[DataLoader, DataLoader, DataLoader]: The data loaders for training, validation, and testing.
    """
    self.train_dataset, self.val_dataset, self.test_dataset = self.prepare_dataset(
        dataset=dataset,
        split=split,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
    )
    train_loader = DataLoader(
        self.train_dataset,  # type: ignore
        batch_size=self.config["training"]["batch_size"],
        num_workers=self.config["training"].get("num_workers", 0),
        pin_memory=self.config["training"].get("pin_memory", True),
        drop_last=self.config["training"].get("drop_last", False),
        prefetch_factor=self.config["training"].get("prefetch_factor", None),
        shuffle=self.config["training"].get("shuffle", True),
        sampler=training_sampler,
    )

    val_loader = DataLoader(
        self.val_dataset,  # type: ignore
        batch_size=self.config["validation"]["batch_size"],
        num_workers=self.config["validation"].get("num_workers", 0),
        pin_memory=self.config["validation"].get("pin_memory", True),
        drop_last=self.config["validation"].get("drop_last", False),
        prefetch_factor=self.config["validation"].get("prefetch_factor", None),
        shuffle=self.config["validation"].get("shuffle", True),
    )

    test_loader = DataLoader(
        self.test_dataset,  # type: ignore
        batch_size=self.config["testing"]["batch_size"],
        num_workers=self.config["testing"].get("num_workers", 0),
        pin_memory=self.config["testing"].get("pin_memory", True),
        drop_last=self.config["testing"].get("drop_last", False),
        prefetch_factor=self.config["testing"].get("prefetch_factor", None),
        shuffle=self.config["testing"].get("shuffle", True),
    )

    if dataset is not None:
        self.logger.info(
            f"Data loaders prepared with splits: {split} and dataset sizes: {len(self.train_dataset)}, {len(self.val_dataset)}, {len(self.test_dataset)}"
        )
    return train_loader, val_loader, test_loader

prepare_dataset(dataset=None, split=[0.8, 0.1, 0.1], train_dataset=None, val_dataset=None, test_dataset=None)

Set up the split for training, validation, and testing datasets.

Parameters:

Name           Type            Description                                                                                                          Default
dataset        Dataset | None  Dataset to be split. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided.               None
split          list[float]     Split ratios for train, validation, and test datasets.                                                              [0.8, 0.1, 0.1]
train_dataset  Subset | None   Training subset of the dataset. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided.    None
val_dataset    Subset | None   Validation subset of the dataset. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided.  None
test_dataset   Subset | None   Testing subset of the dataset. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided.     None

Raises:

Type        Description
ValueError  If providing train, val, or test datasets, the full dataset must not be provided.
ValueError  If the split ratios do not sum to 1.
ValueError  If the train size is 0.
ValueError  If the validation size is 0.
ValueError  If the test size is 0.

Returns:

Type                              Description
Tuple[Dataset, Dataset, Dataset]  The train, validation, and test datasets.
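
A sketch of the data section that drives this method when no dataset is passed in; the file names and reader are hypothetical:

config["data"] = {
    "files": ["data/simulations.zarr"],  # zarr stores holding the raw graphs
    "output": "data/processed",          # where preprocessed data is written
    "reader": my_reader,                 # callable that parses raw records
    "subset": 0.5,                       # optional: use a random half of the data
    "shuffle": True,                     # optional: shuffle before splitting
    "split": [0.8, 0.1, 0.1],            # train/val/test ratios, must sum to 1
}
train_ds, val_ds, test_ds = trainer.prepare_dataset()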

Source code in src/QuantumGrav/train.py
def prepare_dataset(
    self,
    dataset: Dataset | None = None,
    split: list[float] = [0.8, 0.1, 0.1],
    train_dataset: torch.utils.data.Subset | None = None,
    val_dataset: torch.utils.data.Subset | None = None,
    test_dataset: torch.utils.data.Subset | None = None,
) -> Tuple[Dataset, Dataset, Dataset]:
    """Set up the split for training, validation, and testing datasets.

    Args:
        dataset (Dataset | None, optional): Dataset to be split. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided. Defaults to None.
        split (list[float], optional): split ratios for train, validation, and test datasets. Defaults to [0.8, 0.1, 0.1].
        train_dataset (torch.utils.data.Subset | None, optional): Training subset of the dataset. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided. Defaults to None.
        val_dataset (torch.utils.data.Subset | None, optional): Validation subset of the dataset. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided. Defaults to None.
        test_dataset (torch.utils.data.Subset | None, optional): Testing subset of the dataset. Only one of dataset, train_dataset, val_dataset, test_dataset should be provided. Defaults to None.

    Raises:
        ValueError: If providing train, val, or test datasets, the full dataset must not be provided.
        ValueError: If the split ratios do not sum to 1.
        ValueError: If train size is 0
        ValueError: If validation size is 0
        ValueError: If test size is 0

    Returns:
        Tuple[Dataset, Dataset, Dataset]: train, validation, and test datasets.
    """
    if dataset is not None and (
        train_dataset is not None
        or val_dataset is not None
        or test_dataset is not None
    ):
        raise ValueError(
            "If providing train, val, or test datasets, the full dataset must not be provided."
        )

    if dataset is None:
        cfg = self.config["data"]
        dataset = dataset_ondisk.QGDataset(
            cfg["files"],
            cfg["output"],
            cfg["reader"],
            float_type=cfg.get("float_type", torch.float32),
            int_type=cfg.get("int_type", torch.int32),
            validate_data=cfg.get("validate_data", True),
            chunksize=cfg.get("chunksize", 1),
            n_processes=cfg.get("n_processes", 1),
            transform=cfg.get("transform"),
            pre_transform=cfg.get("pre_transform"),
            pre_filter=cfg.get("pre_filter"),
        )

        if cfg.get("subset"):
            num_points = ceil(len(dataset) * cfg["subset"])
            dataset = dataset.index_select(
                self.nprng.integers(0, len(dataset), size=num_points).tolist()
            )

        if cfg.get("shuffle"):
            dataset.shuffle()

    if train_dataset is None and val_dataset is None and test_dataset is None:
        split = self.config.get("data", {}).get("split", split)
        if not np.isclose(
            np.sum(split), 1.0, rtol=1e-05, atol=1e-08, equal_nan=False
        ):
            raise ValueError(
                f"Split ratios must sum to 1.0. Provided split: {split}"
            )

        train_size = ceil(len(dataset) * split[0])
        val_size = floor(len(dataset) * split[1])
        test_size = len(dataset) - train_size - val_size

        if train_size == 0:
            raise ValueError("train size cannot be 0")

        if val_size == 0:
            raise ValueError("validation size cannot be 0")

        if test_size == 0:
            raise ValueError("test size cannot be 0")

        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size, test_size]
        )

    return train_dataset, val_dataset, test_dataset

run_test(test_loader, model_name_addition='current_best.pt')

Run testing phase.

Parameters:

Name                 Type        Description                                                Default
test_loader          DataLoader  The data loader for the test set.                          required
model_name_addition  str         An optional string to append to the checkpoint filename.  'current_best.pt'

Raises:

Type          Description
RuntimeError  If the model is not initialized.
RuntimeError  If the test data is not available.

Returns:

Type             Description
Collection[Any]  A collection of test results that can be scalars, tensors, lists, dictionaries or any other data type that the tester might return.
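
run_test reloads the most recent checkpoint whose filename contains model_name_addition before evaluating, so it is normally called after run_training has produced at least one current-best checkpoint:

results = trainer.run_test(test_loader)  # uses the default 'current_best.pt' filter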

Source code in src/QuantumGrav/train.py
def run_test(
    self, test_loader: DataLoader, model_name_addition: str = "current_best.pt"
) -> Collection[Any]:
    """Run testing phase.

    Args:
        test_loader (DataLoader): The data loader for the test set.
        model_name_addition (str): An optional string to append to the checkpoint filename.
    Raises:
        RuntimeError: If the model is not initialized.
        RuntimeError: If the test data is not available.

    Returns:
        Collection[Any]: A collection of test results that can be scalars, tensors, lists, dictionaries or any other data type that the tester might return.
    """
    self.logger.info("Starting testing process.")
    # get the best model again

    saved_models = [
        f
        for f in Path(self.checkpoint_path).iterdir()
        if f.is_file() and model_name_addition in str(f)
    ]

    if len(saved_models) == 0:
        raise RuntimeError(
            f"No model with the name addition '{model_name_addition}' found, did training work?"
        )

    # get the latest of the best models
    best_of_the_best = max(saved_models, key=lambda f: f.stat().st_mtime)

    self.logger.info(f"loading best model found: {str(best_of_the_best)}")

    self.model = gnn_model.GNNModel.load(
        best_of_the_best, self.config["model"], device=self.device
    )
    self.model.eval()
    if self.tester is None:
        raise RuntimeError("Tester must be initialized before testing.")
    test_result = self.tester.test(self.model, test_loader)
    self.tester.report(test_result)
    self.logger.info("Testing process completed.")
    self.save_checkpoint(name_addition="best_model_found")
    return self.tester.data

run_training(train_loader, val_loader, trial=None)

Run the training process.

Parameters:

Name          Type          Description                                  Default
train_loader  DataLoader    The data loader for the training set.        required
val_loader    DataLoader    The data loader for the validation set.      required
trial         Trial | None  An Optuna trial for hyperparameter tuning.   None

Returns:

Type                                                       Description
Tuple[Tensor | Collection[Any], Tensor | Collection[Any]]  The training and validation results.

Source code in src/QuantumGrav/train.py
def run_training(
    self,
    train_loader: DataLoader,
    val_loader: DataLoader,
    trial: optuna.trial.Trial | None = None,
) -> Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]:
    """Run the training process.

    Args:
        train_loader (DataLoader): The data loader for the training set.
        val_loader (DataLoader): The data loader for the validation set.
        trial (optuna.trial.Trial | None, optional): An Optuna trial
            for hyperparameter tuning. Defaults to None.

    Returns:
        Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.
    """
    self.logger.info("Starting training process.")
    # training loop
    num_epochs = self.config["training"]["num_epochs"]

    self.initialize_model()

    self.initialize_optimizer()

    self.initialize_lr_scheduler()

    total_training_data = torch.zeros(num_epochs, 2, dtype=torch.float32)

    for epoch in range(0, num_epochs):
        self.logger.info(f"  Current epoch: {self.epoch}/{num_epochs}")
        self.model.train()

        epoch_data = self._run_train_epoch(self.model, self.optimizer, train_loader)

        # collect mean and std for each epoch
        total_training_data[epoch, :] = torch.Tensor(
            [epoch_data.mean(dim=0).item(), epoch_data.std(dim=0).item()]
        )

        self.logger.info(
            f"  Completed epoch {epoch}. training loss: {total_training_data[epoch, 0]:.8f} +/- {total_training_data[epoch, 1]:.8f}."
        )

        # evaluation run on validation set
        if self.validator is not None:
            validation_result = self.validator.validate(self.model, val_loader)
            self.validator.report(validation_result)

            # integrate Optuna here for hyperparameter tuning
            if trial is not None:
                avg_sigma_loss = self.validator.data[self.epoch]
                avg_loss = avg_sigma_loss[0]
                trial.report(avg_loss, self.epoch)

                # Handle pruning based on the intermediate value.
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()

        should_stop = self._check_model_status(
            self.validator.data if self.validator else total_training_data,
        )
        if should_stop:
            self.logger.info("Stopping training early.")
            break
        self.epoch += 1

    self.logger.info("Training process completed.")
    self.logger.info("Saving model")

    outpath = self.data_path / f"final_model_epoch={self.epoch}.pt"
    self.model.save(outpath)

    return total_training_data, self.validator.data if self.validator else []
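
For hyperparameter tuning, run_training reports intermediate validation losses to Optuna and honors pruning when a trial is passed. A hedged sketch follows; `make_trainer` and `make_loaders` are hypothetical project-side helpers, not part of this package, and the final return value assumes the validator stores a (mean, std) pair per epoch as the source above does.

import optuna

def objective(trial: optuna.trial.Trial) -> float:
    trainer = make_trainer(trial)                        # hypothetical: build a Trainer from trial params
    train_loader, val_loader, _ = make_loaders(trainer)  # hypothetical: prepare the data loaders
    _, val_data = trainer.run_training(train_loader, val_loader, trial=trial)
    return float(val_data[-1][0])                        # mean validation loss of the final epoch

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)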

save_checkpoint(name_addition='')

Save model checkpoint.

Raises:

Type Description
ValueError

If the model is not initialized.

ValueError

If the model configuration does not contain 'name'.

ValueError

If the training configuration does not contain 'checkpoint_path'.

Source code in src/QuantumGrav/train.py
def save_checkpoint(self, name_addition: str = ""):
    """Save model checkpoint.

    Raises:
        ValueError: If the model is not initialized.
        ValueError: If the model configuration does not contain 'name'.
        ValueError: If the training configuration does not contain 'checkpoint_path'.
    """
    self.logger.info(
        f"Saving checkpoint for model at epoch {self.epoch} to {self.checkpoint_path}"
    )
    outpath = self.checkpoint_path / f"model_{name_addition}.pt"

    if not outpath.exists():
        outpath.parent.mkdir(parents=True, exist_ok=True)
        self.logger.debug(f"Created directory {outpath.parent} for checkpoint.")

    self.latest_checkpoint = outpath
    self.model.save(outpath)

Distributed data parallel Trainer class

This is based on the Distributed Data Parallel section of the PyTorch documentation and is untested at the time of writing.

TrainerDDP

Bases: Trainer

Distributed Data Parallel (DDP) Trainer for training GNN models across multiple processes.

Source code in src/QuantumGrav/train_ddp.py
class TrainerDDP(train.Trainer):
    """Distributed Data Parallel (DDP) Trainer for training GNN models across multiple processes."""

    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "title": "Model trainer class Configuration",
        "type": "object",
        "definitions": train.Trainer.schema.get("definitions", {}),
        "properties": {
            "parallel": {
                "type": "object",
                "properties": {
                    "world_size": {"type": "integer", "minimum": 1},
                    "output_device": {"type": ["integer", "null"]},
                    "find_unused_parameters": {"type": "boolean"},
                    "rank": {"type": "integer", "minimum": 0},
                    "master_addr": {"type": "string"},
                    "master_port": {"type": "string"},
                },
                "required": ["world_size"],
            },
            "log_level": train.Trainer.schema["properties"]["log_level"],
            "data": train.Trainer.schema["properties"]["data"],
            "training": train.Trainer.schema["properties"]["training"],
            "model": train.Trainer.schema["properties"]["model"],
            "validation": train.Trainer.schema["properties"]["validation"],
            "testing": train.Trainer.schema["properties"]["testing"],
            "early_stopping": train.Trainer.schema["properties"]["early_stopping"],
            "apply_model": train.Trainer.schema["properties"]["apply_model"],
            "criterion": train.Trainer.schema["properties"]["criterion"],
        },
        "required": [
            "parallel",
            "training",
            "model",
            "validation",
            "testing",
            "criterion",
        ],
        "additionalProperties": False,
    }

    def __init__(
        self,
        rank: int,
        config: dict[str, Any],
        # training and evaluation functions
    ):
        """Initialize the distributed data parallel (DDP) trainer.

        Args:
            rank (int): The rank of the current process.
            config (dict[str, Any]): The configuration dictionary.
            criterion (Callable): The loss function.
            apply_model (Callable | None, optional): The function to apply the model. Defaults to None.
            early_stopping (Callable[[list[dict[str, Any]]], bool] | None, optional): The early stopping function. Defaults to None.
            validator (DefaultValidator | None, optional): The validator for model evaluation. Defaults to None.
            tester (DefaultTester | None, optional): The tester for model testing. Defaults to None.

        Raises:
            ValueError: If the configuration is invalid.
        """
        jsonschema.validate(instance=config, schema=self.schema)

        super().__init__(
            config,
        )

        self.config = config  # keep the full config including parallel section

        # initialize the systems differently on each process/rank
        torch.manual_seed(self.seed + rank)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed + rank)

        if torch.cuda.is_available() and config["training"]["device"] != "cpu":
            torch.cuda.set_device(rank)
            self.device = torch.device(f"cuda:{rank}")
        else:
            self.device = torch.device("cpu")

        self.rank = rank
        self.world_size = config["parallel"]["world_size"]
        self.logger.info("Initialized DDP trainer")

    def initialize_model(self) -> DDP:
        """Initialize the model for training.

        Returns:
            DDP: The initialized model.
        """
        model = gnn_model.GNNModel.from_config(self.config["model"])

        if self.device.type == "cpu" or (
            isinstance(self.device, torch.device) and self.device.type == "cpu"
        ):
            d_id = None
            o_id = None
        else:
            d_id = [
                self.device,
            ]
            o_id = self.config["parallel"].get("output_device", None)
        model = DDP(
            model,
            device_ids=d_id,
            output_device=o_id,
            find_unused_parameters=self.config["parallel"].get(
                "find_unused_parameters", False
            ),
        )
        self.model = model.to(self.device, non_blocking=True)
        self.logger.info(f"Model initialized on device: {self.device}")
        return self.model

    def prepare_dataloaders(
        self,
        dataset: Dataset | None = None,
        split: list[float] = [0.8, 0.1, 0.1],
        train_dataset: torch.utils.data.Dataset | torch.utils.data.Subset | None = None,
        val_dataset: torch.utils.data.Dataset | torch.utils.data.Subset | None = None,
        test_dataset: torch.utils.data.Dataset | torch.utils.data.Subset | None = None,
        training_sampler: torch.utils.data.Sampler | None = None,
    ) -> Tuple[
        DataLoader,
        DataLoader,
        DataLoader,
    ]:
        """Prepare dataloader for distributed training.

        Args:
            dataset (Dataset | None, optional): Dataset to use. Defaults to None.
            split (list[float], optional): Splits into train, validation and test datasets. Defaults to [0.8, 0.1, 0.1].
            train_dataset (torch.utils.data.Subset | None, optional): Training dataset. Only used when dataset is None. Defaults to None.
            val_dataset (torch.utils.data.Subset | None, optional): Validation dataset. Only used when dataset is None. Defaults to None.
            test_dataset (torch.utils.data.Subset | None, optional): Test dataset. Only used when dataset is None. Defaults to None.
            training_sampler (torch.utils.data.Sampler | None, optional): Ignored here. Defaults to None.

        Returns:
            Tuple[ DataLoader, DataLoader, DataLoader, ]: Train, validation and test dataloaders
        """
        self.train_dataset, self.val_dataset, self.test_dataset = self.prepare_dataset(
            dataset=dataset,
            split=split,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
        )

        # samplers are needed to distribute the data across processes in such a way that each process gets a unique subset of the data
        self.train_sampler = torch.utils.data.DistributedSampler(
            self.train_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True,
        )

        self.val_sampler = torch.utils.data.DistributedSampler(
            self.val_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=False,
        )

        self.test_sampler = torch.utils.data.DistributedSampler(
            self.test_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=False,
        )

        # make the data loaders
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["training"]["batch_size"],
            sampler=self.train_sampler,
            num_workers=self.config["training"].get("num_workers", 0),
            pin_memory=self.config["training"].get("pin_memory", True),
            drop_last=self.config["training"].get("drop_last", False),
            prefetch_factor=self.config["training"].get("prefetch_factor", None),
        )

        val_loader = DataLoader(
            self.val_dataset,
            sampler=self.val_sampler,
            batch_size=self.config["validation"]["batch_size"],
            num_workers=self.config["validation"].get("num_workers", 0),
            pin_memory=self.config["validation"].get("pin_memory", True),
            drop_last=self.config["validation"].get("drop_last", False),
            prefetch_factor=self.config["validation"].get("prefetch_factor", None),
        )

        test_loader = DataLoader(
            self.test_dataset,
            sampler=self.test_sampler,
            batch_size=self.config["testing"]["batch_size"],
            num_workers=self.config["testing"].get("num_workers", 0),
            pin_memory=self.config["testing"].get("pin_memory", True),
            drop_last=self.config["testing"].get("drop_last", False),
            prefetch_factor=self.config["testing"].get("prefetch_factor", None),
        )
        self.logger.info(
            f"Data loaders prepared with splits: {split} and dataset sizes: {len(self.train_dataset)}, {len(self.val_dataset)}, {len(self.test_dataset)}"
        )
        return train_loader, val_loader, test_loader

    def _check_model_status(self, eval_data: pd.DataFrame | list[torch.Tensor]) -> bool:
        """Check the status of the model during evaluation.

        Args:
            eval_data (pd.DataFrame | list[torch.Tensor]): The evaluation data to check.

        Returns:
            bool: Whether the model training should stop.
        """
        should_stop = False
        if self.rank == 0:
            should_stop = super()._check_model_status(eval_data)
        return should_stop

    def save_checkpoint(self, name_addition: str = ""):
        """Save model checkpoint.

        Raises:
            ValueError: If the model is not initialized.
            ValueError: If the model configuration does not contain 'name'.
            ValueError: If the training configuration does not contain 'checkpoint_path'.
        """
        # TODO: verify that this really works - it should save the best model available
        if self.rank == 0:
            if self.model is None:
                raise ValueError("Model must be initialized before saving checkpoint.")

            self.logger.info(
                f"Saving checkpoint for model model at epoch {self.epoch} to {self.checkpoint_path}"
            )
            outpath = self.checkpoint_path / f"model_{name_addition}.pt"

            if not outpath.exists():
                outpath.parent.mkdir(parents=True, exist_ok=True)
                self.logger.info(f"Created directory {outpath.parent} for checkpoint.")

            self.latest_checkpoint = outpath
            torch.save(self.model, outpath)

    def run_training(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        trial: optuna.trial.Trial | None = None,
    ) -> Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]:
        """
        Run the training loop for the distributed model. Processes synchronize for validation; no testing is performed in this function. Checkpointing and early stopping happen only on the rank-0 process.

        Args:
            train_loader (DataLoader): The training data loader.
            val_loader (DataLoader): The validation data loader.
            trial (optuna.trial.Trial | None, optional): An Optuna trial for hyperparameter optimization.
                Defaults to None.

        Returns:
            Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.
        """

        self.model = self.initialize_model()
        self.optimizer = self.initialize_optimizer()

        num_epochs = self.config["training"]["num_epochs"]
        self.logger.info("Starting training process.")

        total_training_data = []
        all_training_data: list[Any] = [None for _ in range(self.world_size)]
        all_validation_data: list[Any] = [None for _ in range(self.world_size)]
        for _ in range(0, num_epochs):
            self.logger.info(f"  Current epoch: {self.epoch}/{num_epochs}")
            self.model.train()
            # a DataLoader always has a sampler; only DistributedSampler exposes set_epoch
            if isinstance(train_loader.sampler, torch.utils.data.DistributedSampler):
                train_loader.sampler.set_epoch(self.epoch)
            epoch_data = self._run_train_epoch(self.model, self.optimizer, train_loader)
            total_training_data.append(epoch_data)

            # evaluation run on validation set
            self.model.eval()
            if self.validator is not None:
                validation_result = self.validator.validate(self.model, val_loader)
                if self.rank == 0:
                    self.validator.report(validation_result)

                    # integrate Optuna here for hyperparameter tuning
                    if trial is not None:
                        avg_sigma_loss = self.validator.data[self.epoch]
                        avg_loss = avg_sigma_loss[0]
                        trial.report(avg_loss, self.epoch)

                        # Handle pruning based on the intermediate value.
                        if trial.should_prune():
                            raise optuna.exceptions.TrialPruned()

            dist.barrier()  # Ensure all processes have completed the epoch before checking status
            should_stop = self._check_model_status(
                self.validator.data if self.validator else total_training_data,
            )

            # broadcast_object_list mutates object_list in place and returns None,
            # so rank 0's stop decision is read back out of the list
            object_list = [should_stop]
            dist.broadcast_object_list(object_list, src=0, device=self.device)
            should_stop = object_list[0]

            if should_stop:
                break

            self.epoch += 1

        dist.barrier()
        dist.all_gather_object(all_training_data, total_training_data)
        dist.all_gather_object(
            all_validation_data, self.validator.data if self.validator else []
        )
        self.logger.info("Training process completed.")

        return all_training_data, all_validation_data
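
For reference, a minimal sketch of the "parallel" configuration section accepted by the schema above; all values are illustrative, and only world_size is required.

parallel_section = {
    "world_size": 2,                  # required: total number of processes
    "output_device": 0,               # optional: forwarded to DDP
    "find_unused_parameters": False,  # optional: forwarded to DDP
    "rank": 0,                        # optional
    "master_addr": "localhost",       # optional
    "master_port": "12345",           # optional
}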

__init__(rank, config)

Initialize the distributed data parallel (DDP) trainer.

Parameters:

Name Type Description Default
rank int

The rank of the current process.

required
config dict[str, Any]

The configuration dictionary. Must satisfy TrainerDDP.schema; the loss function, validation, and testing behavior are configured through this dictionary.

required

Raises:

Type Description
ValidationError

If the configuration is invalid.

Source code in src/QuantumGrav/train_ddp.py
def __init__(
    self,
    rank: int,
    config: dict[str, Any],
    # training and evaluation functions
):
    """Initialize the distributed data parallel (DDP) trainer.

    Args:
        rank (int): The rank of the current process.
        config (dict[str, Any]): The configuration dictionary.
        criterion (Callable): The loss function.
        apply_model (Callable | None, optional): The function to apply the model. Defaults to None.
        early_stopping (Callable[[list[dict[str, Any]]], bool] | None, optional): The early stopping function. Defaults to None.
        validator (DefaultValidator | None, optional): The validator for model evaluation. Defaults to None.
        tester (DefaultTester | None, optional): The tester for model testing. Defaults to None.

    Raises:
        ValueError: If the configuration is invalid.
    """
    jsonschema.validate(instance=config, schema=self.schema)

    super().__init__(
        config,
    )

    self.config = config  # keep the full config including parallel section

    # initialize the systems differently on each process/rank
    torch.manual_seed(self.seed + rank)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(self.seed + rank)

    if torch.cuda.is_available() and config["training"]["device"] != "cpu":
        torch.cuda.set_device(rank)
        self.device = torch.device(f"cuda:{rank}")
    else:
        self.device = torch.device("cpu")

    self.rank = rank
    self.world_size = config["parallel"]["world_size"]
    self.logger.info("Initialized DDP trainer")

initialize_model()

Initialize the model for training.

Returns:

Name Type Description
DDP DistributedDataParallel

The initialized model.

Source code in src/QuantumGrav/train_ddp.py
def initialize_model(self) -> DDP:
    """Initialize the model for training.

    Returns:
        DDP: The initialized model.
    """
    model = gnn_model.GNNModel.from_config(self.config["model"])

    if self.device.type == "cpu" or (
        isinstance(self.device, torch.device) and self.device.type == "cpu"
    ):
        d_id = None
        o_id = None
    else:
        d_id = [
            self.device,
        ]
        o_id = self.config["parallel"].get("output_device", None)
    model = DDP(
        model,
        device_ids=d_id,
        output_device=o_id,
        find_unused_parameters=self.config["parallel"].get(
            "find_unused_parameters", False
        ),
    )
    self.model = model.to(self.device, non_blocking=True)
    self.logger.info(f"Model initialized on device: {self.device}")
    return self.model

prepare_dataloaders(dataset=None, split=[0.8, 0.1, 0.1], train_dataset=None, val_dataset=None, test_dataset=None, training_sampler=None)

Prepare dataloader for distributed training.

Parameters:

Name Type Description Default
dataset Dataset | None

Dataset to use. Defaults to None.

None
split list[float]

Splits into train, validation and test datasets. Defaults to [0.8, 0.1, 0.1].

[0.8, 0.1, 0.1]
train_dataset Subset | None

Training dataset. Only used when dataset is None. Defaults to None.

None
val_dataset Subset | None

Validation dataset. Only used when dataset is None. Defaults to None.

None
test_dataset Subset | None

Test dataset. Only used when dataset is None. Defaults to None.

None
training_sampler Sampler | None

Ignored here. Defaults to None.

None

Returns:

Type Description
Tuple[DataLoader, DataLoader, DataLoader]

Tuple[ DataLoader, DataLoader, DataLoader, ]: Train, validation and test dataloaders

Source code in src/QuantumGrav/train_ddp.py
def prepare_dataloaders(
    self,
    dataset: Dataset | None = None,
    split: list[float] = [0.8, 0.1, 0.1],
    train_dataset: torch.utils.data.Dataset | torch.utils.data.Subset | None = None,
    val_dataset: torch.utils.data.Dataset | torch.utils.data.Subset | None = None,
    test_dataset: torch.utils.data.Dataset | torch.utils.data.Subset | None = None,
    training_sampler: torch.utils.data.Sampler | None = None,
) -> Tuple[
    DataLoader,
    DataLoader,
    DataLoader,
]:
    """Prepare dataloader for distributed training.

    Args:
        dataset (Dataset | None, optional): Dataset to use. Defaults to None.
        split (list[float], optional): Splits into train, validation and test datasets. Defaults to [0.8, 0.1, 0.1].
        train_dataset (torch.utils.data.Subset | None, optional): Training dataset. Only used when dataset is None. Defaults to None.
        val_dataset (torch.utils.data.Subset | None, optional): Validation dataset. Only used when dataset is None. Defaults to None.
        test_dataset (torch.utils.data.Subset | None, optional): Test dataset. Only used when dataset is None. Defaults to None.
        training_sampler (torch.utils.data.Sampler | None, optional): Ignored here. Defaults to None.

    Returns:
        Tuple[ DataLoader, DataLoader, DataLoader, ]: Train, validation and test dataloaders
    """
    self.train_dataset, self.val_dataset, self.test_dataset = self.prepare_dataset(
        dataset=dataset,
        split=split,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
    )

    # samplers are needed to distribute the data across processes in such a way that each process gets a unique subset of the data
    self.train_sampler = torch.utils.data.DistributedSampler(
        self.train_dataset,
        num_replicas=self.world_size,
        rank=self.rank,
        shuffle=True,
    )

    self.val_sampler = torch.utils.data.DistributedSampler(
        self.val_dataset,
        num_replicas=self.world_size,
        rank=self.rank,
        shuffle=False,
    )

    self.test_sampler = torch.utils.data.DistributedSampler(
        self.test_dataset,
        num_replicas=self.world_size,
        rank=self.rank,
        shuffle=False,
    )

    # make the data loaders
    train_loader = DataLoader(
        self.train_dataset,
        batch_size=self.config["training"]["batch_size"],
        sampler=self.train_sampler,
        num_workers=self.config["training"].get("num_workers", 0),
        pin_memory=self.config["training"].get("pin_memory", True),
        drop_last=self.config["training"].get("drop_last", False),
        prefetch_factor=self.config["training"].get("prefetch_factor", None),
    )

    val_loader = DataLoader(
        self.val_dataset,
        sampler=self.val_sampler,
        batch_size=self.config["validation"]["batch_size"],
        num_workers=self.config["validation"].get("num_workers", 0),
        pin_memory=self.config["validation"].get("pin_memory", True),
        drop_last=self.config["validation"].get("drop_last", False),
        prefetch_factor=self.config["validation"].get("prefetch_factor", None),
    )

    test_loader = DataLoader(
        self.test_dataset,
        sampler=self.test_sampler,
        batch_size=self.config["testing"]["batch_size"],
        num_workers=self.config["testing"].get("num_workers", 0),
        pin_memory=self.config["testing"].get("pin_memory", True),
        drop_last=self.config["testing"].get("drop_last", False),
        prefetch_factor=self.config["testing"].get("prefetch_factor", None),
    )
    self.logger.info(
        f"Data loaders prepared with splits: {split} and dataset sizes: {len(self.train_dataset)}, {len(self.val_dataset)}, {len(self.test_dataset)}"
    )
    return train_loader, val_loader, test_loader

run_training(train_loader, val_loader, trial=None)

Run the training loop for the distributed model. Processes synchronize for validation; no testing is performed in this function. Checkpointing and early stopping happen only on the rank-0 process.

Parameters:

Name Type Description Default
train_loader DataLoader

The training data loader.

required
val_loader DataLoader

The validation data loader.

required
trial Trial | None

An Optuna trial for hyperparameter optimization. Defaults to None.

None

Returns:

Type Description
Tuple[Tensor | Collection[Any], Tensor | Collection[Any]]

Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.

Source code in src/QuantumGrav/train_ddp.py
def run_training(
    self,
    train_loader: DataLoader,
    val_loader: DataLoader,
    trial: optuna.trial.Trial | None = None,
) -> Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]:
    """
    Run the training loop for the distributed model. Processes synchronize for validation; no testing is performed in this function. Checkpointing and early stopping happen only on the rank-0 process.

    Args:
        train_loader (DataLoader): The training data loader.
        val_loader (DataLoader): The validation data loader.
        trial (optuna.trial.Trial | None, optional): An Optuna trial for hyperparameter optimization.
            Defaults to None.

    Returns:
        Tuple[torch.Tensor | Collection[Any], torch.Tensor | Collection[Any]]: The training and validation results.
    """

    self.model = self.initialize_model()
    self.optimizer = self.initialize_optimizer()

    num_epochs = self.config["training"]["num_epochs"]
    self.logger.info("Starting training process.")

    total_training_data = []
    all_training_data: list[Any] = [None for _ in range(self.world_size)]
    all_validation_data: list[Any] = [None for _ in range(self.world_size)]
    for _ in range(0, num_epochs):
        self.logger.info(f"  Current epoch: {self.epoch}/{num_epochs}")
        self.model.train()
        # a DataLoader always has a sampler; only DistributedSampler exposes set_epoch
        if isinstance(train_loader.sampler, torch.utils.data.DistributedSampler):
            train_loader.sampler.set_epoch(self.epoch)
        epoch_data = self._run_train_epoch(self.model, self.optimizer, train_loader)
        total_training_data.append(epoch_data)

        # evaluation run on validation set
        self.model.eval()
        if self.validator is not None:
            validation_result = self.validator.validate(self.model, val_loader)
            if self.rank == 0:
                self.validator.report(validation_result)

                # integrate Optuna here for hyperparameter tuning
                if trial is not None:
                    avg_sigma_loss = self.validator.data[self.epoch]
                    avg_loss = avg_sigma_loss[0]
                    trial.report(avg_loss, self.epoch)

                    # Handle pruning based on the intermediate value.
                    if trial.should_prune():
                        raise optuna.exceptions.TrialPruned()

        dist.barrier()  # Ensure all processes have completed the epoch before checking status
        should_stop = self._check_model_status(
            self.validator.data if self.validator else total_training_data,
        )

        # broadcast_object_list mutates object_list in place and returns None,
        # so rank 0's stop decision is read back out of the list
        object_list = [should_stop]
        dist.broadcast_object_list(object_list, src=0, device=self.device)
        should_stop = object_list[0]

        if should_stop:
            break

        self.epoch += 1

    dist.barrier()
    dist.all_gather_object(all_training_data, total_training_data)
    dist.all_gather_object(
        all_validation_data, self.validator.data if self.validator else []
    )
    self.logger.info("Training process completed.")

    return all_training_data, all_validation_data

save_checkpoint(name_addition='')

Save model checkpoint.

Raises:

Type Description
ValueError

If the model is not initialized.

ValueError

If the model configuration does not contain 'name'.

ValueError

If the training configuration does not contain 'checkpoint_path'.

Source code in src/QuantumGrav/train_ddp.py
def save_checkpoint(self, name_addition: str = ""):
    """Save model checkpoint.

    Raises:
        ValueError: If the model is not initialized.
        ValueError: If the model configuration does not contain 'name'.
        ValueError: If the training configuration does not contain 'checkpoint_path'.
    """
    # TODO: verify that this really works - it should save the best model available
    if self.rank == 0:
        if self.model is None:
            raise ValueError("Model must be initialized before saving checkpoint.")

        self.logger.info(
            f"Saving checkpoint for model model at epoch {self.epoch} to {self.checkpoint_path}"
        )
        outpath = self.checkpoint_path / f"model_{name_addition}.pt"

        if not outpath.exists():
            outpath.parent.mkdir(parents=True, exist_ok=True)
            self.logger.info(f"Created directory {outpath.parent} for checkpoint.")

        self.latest_checkpoint = outpath
        torch.save(self.model, outpath)

cleanup_ddp()

Clean up the distributed process group.

Source code in src/QuantumGrav/train_ddp.py
def cleanup_ddp() -> None:
    """Clean up the distributed process group."""
    if dist.is_initialized():
        dist.destroy_process_group()
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)

initialize_ddp(rank, worldsize, master_addr='localhost', master_port='12345', backend='nccl')

Initialize the distributed process group. This assumes one process per GPU.

Parameters:

Name Type Description Default
rank int

The rank of the current process.

required
worldsize int

The total number of processes.

required
master_addr str

The address of the master process. Defaults to "localhost". This needs to be the IP address of the master node if you are running on a cluster.

'localhost'
master_port str

The port of the master process. Defaults to "12345". Choose a high port if you are running multiple jobs on the same machine to avoid conflicts. If running on a cluster, this should be the port that the master node is listening on.

'12345'
backend str

The backend to use for distributed training. Defaults to "nccl".

'nccl'

Raises:

Type Description
RuntimeError

If the distributed process group is already initialized.

Source code in src/QuantumGrav/train_ddp.py
def initialize_ddp(
    rank: int,
    worldsize: int,
    master_addr: str = "localhost",
    master_port: str = "12345",
    backend: str = "nccl",
) -> None:
    """Initialize the distributed process group. This assumes one process per GPU.

    Args:
        rank (int): The rank of the current process.
        worldsize (int): The total number of processes.
        master_addr (str, optional): The address of the master process. Defaults to "localhost". This needs to be the IP address of the master node if you are running on a cluster.
        master_port (str, optional): The port of the master process. Defaults to "12345". Choose a high port if you are running multiple jobs on the same machine to avoid conflicts. If running on a cluster, this should be the port that the master node is listening on.
        backend (str, optional): The backend to use for distributed training. Defaults to "nccl".

    Raises:
        RuntimeError: If the distributed process group is already initialized.
    """
    if dist.is_initialized():
        raise RuntimeError("The distributed process group is already initialized.")
    else:
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = master_port
        dist.init_process_group(backend=backend, rank=rank, world_size=worldsize)
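
Putting the pieces together, a typical launch spawns one process per GPU, initializes the process group, trains, and cleans up. In the sketch below, build_config() and build_dataset() are hypothetical helpers that return a schema-conforming configuration and the dataset; they are not part of QuantumGrav.

import torch
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    initialize_ddp(rank, world_size, backend="nccl")
    trainer = TrainerDDP(rank, build_config())         # hypothetical config helper
    train_loader, val_loader, _ = trainer.prepare_dataloaders(
        dataset=build_dataset()                        # hypothetical dataset helper
    )
    trainer.run_training(train_loader, val_loader)
    cleanup_ddp()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)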

Utilities

General utilities that are used throughout this package.

assign_at_path(cfg, path, value)

Assign a value to a key in the nested dictionary 'cfg'. The path to follow through this nested structure is given by 'path'.

Parameters:

Name Type Description Default
cfg dict

The configuration dictionary to modify.

required
path Sequence[Any]

The path to the key to modify as a list of nodes to traverse.

required
value Any

The value to assign to the key.

required
Source code in src/QuantumGrav/utils.py
def assign_at_path(cfg: dict, path: Sequence[Any], value: Any) -> None:
    """Assign a value to a key in a nested dictionary 'dict'. The path to follow through this nested structure is given by 'path'.

    Args:
        cfg (dict): The configuration dictionary to modify.
        path (Sequence[Any]): The path to the key to modify as a list of nodes to traverse.
        value (Any): The value to assign to the key.
    """
    for p in path[:-1]:
        cfg = cfg[p]
    cfg[path[-1]] = value
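
A quick usage example:

cfg = {"training": {"optimizer": {"lr": 1e-3}}}
assign_at_path(cfg, ["training", "optimizer", "lr"], 5e-4)
# cfg is now {"training": {"optimizer": {"lr": 5e-4}}}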

get_at_path(cfg, path, default=None)

Get the value at a key in a nested dictionary. The path to follow through this nested structure is given by 'path'.

Parameters:

Name Type Description Default
cfg dict

The configuration dictionary to read from.

required
path Sequence[Any]

The path to the key to get as a list of nodes to traverse.

required
default Any

The value to return when the final key is not found. Defaults to None.

None

Returns:

Name Type Description
Any Any

The value at the specified key, or the given default if the final key is not found.

Source code in src/QuantumGrav/utils.py
def get_at_path(cfg: dict, path: Sequence[Any], default: Any = None) -> Any:
    """Get the value at a key in a nested dictionary. The path to follow through this nested structure is given by 'path'.

    Args:
        cfg (dict): The configuration dictionary to read from.
        path (Sequence[Any]): The path to the key to get as a list of nodes to traverse.
        default (Any, optional): The value to return when the final key is not found. Defaults to None.

    Returns:
        Any: The value at the specified key, or the given default if the final key is not found.
    """
    for p in path[:-1]:
        cfg = cfg[p]

    return cfg.get(path[-1], default)
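
A quick usage example; note that intermediate keys must exist, since only the final lookup falls back to the default:

cfg = {"training": {"optimizer": {"lr": 1e-3}}}
get_at_path(cfg, ["training", "optimizer", "lr"])             # 0.001
get_at_path(cfg, ["training", "optimizer", "momentum"], 0.9)  # 0.9 (default)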

import_and_get(importpath)

Import a module and get an object from it.

Parameters:

Name Type Description Default
importpath str

The import path of the object to get.

required

Returns:

Name Type Description
Any Any

The name as imported from the module.

Raises:

Type Description
KeyError

When the module indicated by the path is not found

KeyError

When the object name indicated by the path is not found in the module

Source code in src/QuantumGrav/utils.py
def import_and_get(importpath: str) -> Any:
    """Import a module and get an object from it.

    Args:
        importpath (str): The import path of the object to get.

    Returns:
        Any: The name as imported from the module.

    Raises:
        KeyError: When the module indicated by the path is not found
        KeyError: When the object name indicated by the path is not found in the module
    """
    parts = importpath.split(".")
    module_name = ".".join(parts[:-1])
    object_name = parts[-1]

    try:
        module = importlib.import_module(module_name)
    except Exception as e:
        raise KeyError(f"Importing module {module_name} unsuccessful") from e
    try:
        return getattr(module, object_name)
    except Exception as e:
        raise KeyError(f"Could not load name {object_name} from {module_name}") from e