API Reference

Computation graph module

This module defines the ComputationGraph class, which represents a heiplanet model as a directed acyclic graph (DAG) of interdependent tasks and manages the setup and execution of that graph.

ComputationGraph

A class to represent a computation DAG that executes a series of tasks which together represent the run of a given heiplanet model. Models are defined as combinations of functions known to the class. Modules are loose collections of functions that are registered with the class and combined into a computational graph to create a functional system. Functions are therefore registered either as part of a module or as utility functions, e.g., if they are used by multiple modules. The computational graph is built from these functions and executed via Dask tasks, which allows for parallel, lazy execution and efficient resource management. Computations can be combined freely from the functions registered with different modules.

Attributes:

Name Type Description
modules dict[str, Any]

A dictionary of modules, where each module is a module object imported from a given path.

module_functions dict[str, dict[str, Callable]]

A dictionary mapping module names to dictionaries of function names and their corresponding callable objects.

task_graph dict[str, Delayed]

A dictionary representing the Dask computational graph, where each node is a dask.delayed object.

config dict[str, Any]

A configuration dictionary describing the computation, including the computational graph structure.

sink_node Delayed | None

The sink node of the computational graph, which is the final node that triggers the execution of the entire computation.
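
Example:

A minimal sketch of building and running a two-node graph from an in-memory configuration. The module file my_module.py and its functions load_data and process_data are hypothetical stand-ins for registered functions.

from heiplanet_models.computation_graph import ComputationGraph

# Hypothetical two-node graph: "load_data" feeds "process_data" (the sink node).
config = {
    "execution": {"scheduler": "synchronous"},
    "graph": {
        "load_data": {
            "function": "load_data",
            "module": "my_module.py",  # hypothetical module file; must exist on disk
            "input": [],
            "args": ["input.nc"],
            "kwargs": None,
        },
        "process_data": {
            "function": "process_data",
            "module": "my_module.py",
            "input": ["load_data"],  # runs after load_data and receives its result
            "args": None,
            "kwargs": {"threshold": 0.5},
        },
    },
}

graph = ComputationGraph(config)
result = graph.execute()  # the graph is lazy; computation happens only here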

Source code in src/heiplanet_models/computation_graph.py
class ComputationGraph:
    """A class to represent a computation DAG that executes a series of tasks which together represent the run of a given heiplanet model. These models are defined as combinations of functions known to the class.
    Modules are a loose collection of functions that are registered with the class and combined into a computational graph to create a functional system. Therefore, functions are registered as either part of a module or as utility functions, e.g., if they are used by multiple modules. The computational graph is built from these functions and executed in via dask tasks to allow for parallel, lazy execution and efficient resource management. Computations can be combined freely from the functions registered with different modules.

    Attributes:
        modules (dict[str, Any]): A dictionary of modules, where each module is a module object imported from a given path.
        module_functions (dict[str, dict[str, Callable]]): A dictionary mapping module names to dictionaries of function names and their corresponding callable objects.
        task_graph (dict[str, Delayed]): A dictionary representing the Dask computational graph, where each node is a dask.delayed object.
        config (dict[str, Any]): A configuration dictionary describing the computation, including the computational graph structure.
        sink_node (Delayed | None): The sink node of the computational graph, which is the final node that triggers the execution of the entire computation.
    """

    module_functions: dict[str, dict[str, Callable]] = {}
    config: dict[str, Any] = None  # Configuration for the computation
    sink_node: Delayed | None = None  # The sink node of the computational graph
    task_graph: dict[str, Delayed] | None = None
    sink_node_name: str | None = None
    scheduler: str | None = None  # The Dask scheduler to use for execution
    default_modules: set[str] = {
        "utils",
        "Jmodel",
    }  # README: we need a better way to manage default modules

    def __init__(self, config: dict[str, Any]):
        """Initialize the computation graph from the given configuration.
        This method verifies the configuration, loads the necessary modules, retrieves the functions from the modules, builds the computational graph, and sets the Dask scheduler.

        Args:
            config (dict[str, Any]): The configuration dictionary.

        Raises:
            ValueError: If the configuration is invalid.
        """
        config_valid, msg = self._verify_config(config)

        if not config_valid:
            raise ValueError(f"Configuration verification failed: {msg}")

        self.config = config

        self.logger = logging.getLogger("ComputationGraph")
        self.logger.setLevel(
            logging.DEBUG
            if "log_level" not in config["execution"]
            else config["execution"]["log_level"]
        )
        # load needed code.
        self.module_functions = self._get_functions_from_module(config)

        # build the computational graph and find the sink node which we use to execute the graph
        self.sink_node_name = self._find_sink_node(config)
        self.task_graph = self._build_dag(config)
        self.sink_node = self.task_graph[self.sink_node_name]

        # set the dask scheduler
        self.scheduler = config["execution"]["scheduler"]

    def _get_functions_from_module(
        self, config: dict[str, Any]
    ) -> dict[str, dict[str, Callable]]:
        """Find all functions in the given modules and return a dictionary mapping module names to dictionaries of function names and their corresponding callable objects.

        Args:
            config (dict[str, Any]): A configuration dictionary containing the computational graph structure in which function names and modules they live in are defined.

        Returns:
            dict[str, dict[str, Callable]]: A dictionary mapping module names to dictionaries of function names and their corresponding callable objects.
        """
        module_functions = {}
        for name, spec in config["graph"].items():
            module_path = Path(spec["module"]).resolve().absolute()
            module_name = module_path.stem
            if module_name in self.default_modules:
                continue
            function_name = spec["function"]
            try:
                func = utils.load_name_from_module(
                    module_name=module_name,
                    file_path=module_path,
                    name=function_name,
                )
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load function '{function_name}' from module '{module_name}': {e}"
                ) from e

            if module_name not in module_functions:
                module_functions[module_name] = {}
            module_functions[module_name][function_name] = func

        # add the default modules and utility functions needed
        # README: this needs to be generalized later when we have a more stable
        # way of handling model code
        module_functions["Jmodel"] = {}
        module_functions["utils"] = {}

        for module in [utils, Jmodel]:
            for name, obj in inspect.getmembers(module, inspect.isfunction):
                if obj.__module__ == module.__name__ and name[0] != "_":
                    module_functions[module.__name__.split(".")[-1]][name] = obj

        return module_functions

    def _find_sink_node(self, config: dict[str, Any]) -> str:
        """Find the sink node in the computational graph.

        Args:
            config (dict[str, Any]): Configuration dictionary containing the computational graph structure.

        Raises:
            ValueError: If multiple sink nodes are found.
            ValueError: If no sink node is found.

        Returns:
            str: The name of the sink node.
        """
        all_inputs = []

        for node in config["graph"].values():
            if "input" in node and isinstance(node["input"], list):
                all_inputs.extend(node["input"])
        all_inputs = set(all_inputs)

        sink_node = None
        for node_name, _ in config["graph"].items():
            if node_name not in all_inputs:
                if sink_node is not None:
                    raise ValueError(
                        f"Multiple sink nodes found in the computational graph: {sink_node}, {node_name}."
                    )
                sink_node = node_name

        if sink_node is None:
            raise ValueError("No sink node found in the computational graph.")

        self.logger.debug(f"Sink node found: {sink_node}")
        return sink_node

    def _build_dag(self, config: dict[str, Any]) -> dict[str, Delayed]:
        """Build the Dask computational graph from the configuration.
            The returned executable graph is a dictionary mapping node names to Dask delayed objects, and the returned sink node is the final node in the graph that triggers the execution of the entire computation. There must only be one sink node in the graph.
            If there are multiple sink nodes, an error is raised.
        Args:
            config (dict[str, Any]): Configuration dictionary containing the computational graph structure as name: {
                "function": function_name,
                "input": [input_node_name1, input_node_name2, ...],
                "args": [args],
                "kwargs": {kwargs},
                "module": module_name
            }

        Raises:
            ValueError: A module is not found in the registered modules.
            ValueError:  A function is not found in the module.

        Returns:
            dict[str, Delayed]: A dictionary mapping node names to Dask delayed objects.

        """

        # recursive function that progressively works its way up the computational graph from the sink node until it reaches source nodes that have no dependencies, and then down again to build the computation tasks in the graph from sources to sink.
        def build_node(node_name, delayed_tasks):
            """Build a single node in the Dask computational graph."""
            self.logger.debug(f" Building input node: {node_name}")

            if node_name in delayed_tasks:
                self.logger.debug(f" Node already defined: {node_name}")
                return

            node_spec = config["graph"][node_name]
            input_nodes = node_spec["input"]
            args = node_spec["args"] if node_spec["args"] is not None else []
            kwargs = node_spec["kwargs"] if node_spec["kwargs"] is not None else {}

            # move up the DAG to build the input nodes first. Since none are defined initially, this will eventually reach the source node.
            for input_node in input_nodes:
                build_node(input_node, delayed_tasks)

            # when the input nodes are all defined, we can create the task
            # for the source node there are no input nodes, so we can create the task immediately
            # Check if all input nodes are already defined
            if (
                all(input_node in delayed_tasks for input_node in input_nodes)
                or len(input_nodes) == 0
            ):
                self.logger.debug(f" Creating task for node: {node_name}")
                # All input nodes are already defined, so we can create the task
                module_name = str(Path(node_spec["module"]).stem)
                func = self.module_functions[module_name][node_spec["function"]]
                delayed_tasks[node_name] = dask.delayed(func)(
                    *[delayed_tasks[input_node] for input_node in input_nodes],
                    *args,
                    **kwargs,
                )
            # no else because on the way down the graph we will never encounter this case.

        delayed_tasks = {}
        build_node(self.sink_node_name, delayed_tasks)
        self.logger.debug(f"build_graph: {delayed_tasks.keys()}")
        return delayed_tasks

    def _verify_computation_config(self, config: dict[str, Any]) -> tuple[bool, str]:
        """Verify the configuration of the computational graph.

        Args:
            config (dict[str, Any]): The configuration dictionary.

        Returns:
            tuple[bool, str]: A tuple containing a boolean indicating whether the configuration is valid and an error message if it is not.
        """
        # verify the computation structure.
        for node, value in config.items():
            # verify that the node is a dict
            if value is None or not isinstance(value, dict):
                return False, f"Node {node} is not a dict."

            # every computation node must specify the name of the function to call, a list of nodes that must run before this one and whose results are used as inputs, and additional positional and keyword arguments that are passed to the function. The latter two can be empty or None, but the keys must be present to make this choice explicit and to distinguish it from having forgotten to specify them.
            if any(
                key not in value
                for key in ["function", "input", "args", "kwargs", "module"]
            ):
                return (
                    False,
                    f"Node {node} is missing required keys. Required keys are 'function', 'input', 'args', 'kwargs', and 'module'.",
                )

            # check that the module path exists and is a valid file
            if (
                str(Path(value["module"]).stem) not in self.default_modules
                and Path.exists(Path(value["module"]).resolve().absolute()) is False
            ):
                module_name = value["module"]
                return (
                    False,
                    f"Module {module_name} for node {node} at path {Path(value['module']).resolve().absolute()} does not exist.",
                )

            # the input nodes must be a list of names of other nodes
            if not isinstance(value["input"], list):
                return False, f"input nodes for node {node} must be a list"

            # the input nodes must be explicitly specified and must be present in the graph somewhere, otherwise we cannot resolve them.
            for input_node in value["input"]:
                if input_node not in config:
                    return (
                        False,
                        f"input node {input_node} of node {node} not found in graph",
                    )

            # the positional arguments and keyword arguments must be a list and a dict, respectively, or None
            if not isinstance(value["args"], list) and value["args"] is not None:
                return False, f"arguments for node {node} must be a list"

            if not isinstance(value["kwargs"], dict) and value["kwargs"] is not None:
                return False, f"keyword arguments for node {node} must be a dict"

        return True, "Configuration is valid."

    def _verify_config(self, config: dict[str, Any]) -> tuple[bool, str]:
        """Verify the configuration dictionary.

        Args:
            config (dict[str, Any]): The configuration dictionary.

        Returns:
            tuple[bool, str]: A tuple containing a boolean indicating whether the configuration is valid and an error message if it is not.
        """

        # verify the structure of the configuration file. Checks that all needed nodes are present and of the right type and within allowed parameters

        # verify the high-level structure of the configuration
        needed_high_level_keys = ["graph", "execution"]
        if not all(key in config for key in needed_high_level_keys):
            return (
                False,
                f"Configuration is missing required keys. Required keys are {needed_high_level_keys}.",
            )

        # we need to have a dask scheduler defined in the execution section...
        if "scheduler" not in config["execution"]:
            return False, "Execution configuration is missing 'scheduler' key."

        # ... and it must be one of those that are supported by dask
        if config["execution"]["scheduler"] not in [
            "synchronous",
            "threads",
            "multiprocessing",
            "distributed",
        ]:
            scheduler = config["execution"]["scheduler"]
            return (
                False,
                f"Unsupported scheduler: {scheduler}. Supported schedulers are 'synchronous', 'threads', 'multiprocessing', or 'distributed'.",
            )

        return self._verify_computation_config(config["graph"])

    def execute(self, client: distributed.client.Client = None):
        """Executes the computational graph.

        Args:
            client (distributed.client.Client, optional): The client to use for execution if the computation should be executed on a cluster. If None, will use the local machine. Defaults to None. For more on how to use the client, see https://distributed.dask.org/en/stable/client.html.

        Returns:
            Any: The result of the computation.
        """

        return self.sink_node.compute(scheduler=self.scheduler, client=client)

    def visualize(self, filename: str):
        """Visualizes the computational graph.

        Raises:
            ValueError: If the sink node is not defined.

        Returns:
            Any: The visualization of the sink node as returned by the Delayed.visualize method.
        """
        if self.sink_node is None:
            raise ValueError("Sink node is not defined. Cannot visualize the graph.")
        return self.sink_node.visualize(
            filename=str(filename),
            optimize_graph=False,  # Shows the full graph structure
            rankdir="TB",  # Top to bottom layout
        )

    @classmethod
    def from_config(cls, path_to_config: str | Path) -> "ComputationGraph":
        """Creates a `ComputationGraph` instance from a configuration dictionary read from a json file."""
        with open(path_to_config, "r") as f:
            config = json.load(f)
        return cls(config)

__init__(config)

Initialize the computation graph from the given configuration. This method verifies the configuration, loads the necessary modules, retrieves the functions from the modules, builds the computational graph, and sets the Dask scheduler.

Parameters:

Name Type Description Default
config dict[str, Any]

The configuration dictionary.

required

Raises:

Type Description
ValueError

If the configuration is invalid.

Source code in src/heiplanet_models/computation_graph.py
def __init__(self, config: dict[str, Any]):
    """Initialize the computation graph from the given configuration.
    This method verifies the configuration, loads the necessary modules, retrieves the functions from the modules, builds the computational graph, and sets the Dask scheduler.

    Args:
        config (dict[str, Any]): The configuration dictionary.

    Raises:
        ValueError: If the configuration is invalid.
    """
    config_valid, msg = self._verify_config(config)

    if not config_valid:
        raise ValueError(f"Configuration verification failed: {msg}")

    self.config = config

    self.logger = logging.getLogger("ComputationGraph")
    self.logger.setLevel(
        logging.DEBUG
        if "log_level" not in config["execution"]
        else config["execution"]["log_level"]
    )
    # load needed code.
    self.module_functions = self._get_functions_from_module(config)

    # build the computational graph and find the sink node which we use to execute the graph
    self.sink_node_name = self._find_sink_node(config)
    self.task_graph = self._build_dag(config)
    self.sink_node = self.task_graph[self.sink_node_name]

    # set the dask scheduler
    self.scheduler = config["execution"]["scheduler"]

execute(client=None)

Executes the computational graph.

Parameters:

Name Type Description Default
client Client

The client to use for execution if the computation should be executed on a cluster. If None, will use the local machine. Defaults to None. For more on how to use the client, see https://distributed.dask.org/en/stable/client.html.

None

Returns:

Name Type Description
Any

The result of the computation.
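
Example:

A sketch of executing on a local Dask cluster; this assumes the configuration requests the "distributed" scheduler and that graph is an already constructed ComputationGraph.

from dask.distributed import Client

client = Client(n_workers=4)  # local cluster on this machine
result = graph.execute(client=client)
client.close()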

Source code in src/heiplanet_models/computation_graph.py
def execute(self, client: distributed.client.Client = None):
    """Executes the computational graph.

    Args:
        client (distributed.client.Client, optional): The client to use for execution if the computation should be executed on a cluster. If None, will use the local machine. Defaults to None. For more on how to use the client, see https://distributed.dask.org/en/stable/client.html.

    Returns:
        Any: The result of the computation.
    """

    return self.sink_node.compute(scheduler=self.scheduler, client=client)

from_config(path_to_config) classmethod

Creates a ComputationGraph instance from a configuration dictionary read from a json file.
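
Example:

A sketch assuming graph_config.json contains the same structure as the dictionary accepted by the constructor.

from heiplanet_models.computation_graph import ComputationGraph

graph = ComputationGraph.from_config("graph_config.json")
result = graph.execute()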

Source code in src/heiplanet_models/computation_graph.py
@classmethod
def from_config(cls, path_to_config: str | Path) -> "ComputationGraph":
    """Creates a `ComputationGraph` instance from a configuration dictionary read from a json file."""
    with open(path_to_config, "r") as f:
        config = json.load(f)
    return cls(config)

visualize(filename)

Visualizes the computational graph.

Parameters:

Name Type Description Default
filename str

Path of the file to write the visualization to.

required

Raises:

Type Description
ValueError

If the sink node is not defined.

Returns:

Name Type Description
Any

The visualization of the sink node as returned by the Delayed.visualize method.
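
Example:

A sketch that renders the task graph to a PNG file; Delayed.visualize requires the graphviz package to be installed.

graph.visualize("computation_graph.png")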

Source code in src/heiplanet_models/computation_graph.py
def visualize(self, filename: str):
    """Visualizes the computational graph.

    Raises:
        ValueError: If the sink node is not defined.

    Returns:
        Any: The visualization of the sink node as returned by the Delayed.visualize method.
    """
    if self.sink_node is None:
        raise ValueError("Sink node is not defined. Cannot visualize the graph.")
    return self.sink_node.visualize(
        filename=str(filename),
        optimize_graph=False,  # Shows the full graph structure
        rankdir="TB",  # Top to bottom layout
    )

Jmodel model implementation

read_default_config()

Reads the default configuration for the JModel from a JSON file.

Returns:

Type Description
dict[str, str | int64 | None]

dict[str, str | np.int64 | None]: A dictionary containing the default configuration.

Source code in src/heiplanet_models/Jmodel.py
def read_default_config() -> dict[str, str | np.int64 | None]:
    """Reads the default configuration for the JModel from a JSON file.

    Returns:
        dict[str, str | np.int64 | None]: A dictionary containing the default configuration.
    """
    config_path = Path(__file__).parent / "config_Jmodel.json"
    with open(config_path, "r") as f:
        config = json.load(f)
    return config

read_input_data(model_data)

Read input data from the given source 'model_data.input'.

Parameters:

Name Type Description Default
model_data JModelData

Data class containing the model configuration and input data path.

required

Returns:

Type Description
Dataset

xr.Dataset: xarray dataset containing the input data for the model.
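
Example:

A sketch using hypothetical file paths; r0.csv is assumed to contain a Temperature column as expected by setup_modeldata.

from heiplanet_models import Jmodel

model_data = Jmodel.setup_modeldata(
    input="era5_t2m.nc",
    output="r0_out.nc",
    r0_path="r0.csv",
)
data = Jmodel.read_input_data(model_data)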

Source code in src/heiplanet_models/Jmodel.py
def read_input_data(model_data: JModelData) -> xr.Dataset:
    """Read input data from given source 'model_data.input'

    Args:
        model_data (JModelData): Data class containing the model configuration and input data path.

    Returns:
        xr.Dataset: xarray dataset containing the input data for the model.
    """

    # nothing done here yet
    data = xr.open_dataset(
        model_data.input, chunks=None if model_data.run_mode == "forbidden" else "auto"
    )

    if data is None:
        raise ValueError("Input data source is not defined in the configuration.")

    # ensure the data has a coordinate reference system (CRS)
    data = detect_csr(data)

    # read the grid data if we want to crop the data
    if all(
        [
            model_data.grid_data_baseurl is not None,
            model_data.nuts_level is not None,
            model_data.resolution is not None,
            model_data.year is not None,
        ]
    ):
        grid_data = read_geodata(
            base_url=model_data.grid_data_baseurl,
            nuts_level=model_data.nuts_level,
            resolution=model_data.resolution,
            year=model_data.year,
            url=lambda base_url, resolution, year, nuts_level: f"{base_url}/geojson/NUTS_RG_{resolution}_{year}_4326_LEVL_{nuts_level}.geojson",
        )

        if grid_data.crs != data.rio.crs:
            raise ValueError(
                f"Coordinate reference system mismatch: Grid data CRS {grid_data.crs} does not match input data CRS {data.rio.crs}."
            )

        # crop the data to the grid. This will remove the pixels outside the grid area
        data = data.rio.clip(
            grid_data.geometry.values,
            grid_data.crs,
            drop=True,  # Drop pixels outside the clipping area
        )

    if model_data.run_mode == "forbidden":
        # run synchronously on one cpu
        return data.compute()
    else:
        return data

run_model(model_data, data)

Runs the JModel with the provided input data. Applies the R0 interpolation based on temperature values from the stored R0 data and returns a new dataset or dataframe with the R0 data.

Parameters:

Name Type Description Default
model_data JModelData

Data class containing the model configuration, including the R0 lookup data and temperature bounds.

required
data Dataset | DataFrame

Input dataset or dataframe containing the temperature data.

required

Returns:

Type Description
Dataset | DataFrame

xr.Dataset | pd.DataFrame: A dataset or dataframe with the incoming R0 data interpolated based on the temperature values at each grid point.
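
Example:

Continuing the read_input_data sketch above with the same hypothetical paths: interpolate R0 over the temperature field and write the result.

r0 = Jmodel.run_model(model_data, data)
r0.to_netcdf(model_data.output)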

Source code in src/heiplanet_models/Jmodel.py
def run_model(
    model_data: JModelData, data: xr.Dataset | pd.DataFrame
) -> xr.Dataset | pd.DataFrame:
    """Runs the JModel with the provided input data. Applies the R0 interpolation based on temperature values from the stored R0 data and returns a new dataset or dataframe with the R0 data.

    Args:
        model_data (JModelData): Data class containing the model configuration, including the R0 lookup data and temperature bounds.
        data (xr.Dataset | pd.DataFrame): Input dataset or dataframe containing the temperature data.

    Returns:
        xr.Dataset | pd.DataFrame: A dataset or dataframe with the incoming R0 data interpolated based on the temperature values at each grid point.
    """
    r0_map = xr.apply_ufunc(
        lambda t: _interpolate_r0(
            t,
            model_data.r0_data,
            model_data.min_temp,
            model_data.max_temp,
        ),
        data[model_data.temp_colname],
        input_core_dims=[[]],
        output_core_dims=[[]],
        dask=model_data.run_mode,
        keep_attrs=True,
    ).rename(model_data.out_colname)

    return r0_map

setup_modeldata(input=None, output=None, r0_path=None, run_mode='forbidden', grid_data_baseurl=None, nuts_level=None, resolution=None, year=None, temp_colname='t2m', out_colname='R0')

Initializes the JModel with the given configuration.

Parameters:

Name Type Description Default
input str | None

Path to the input data file.

None
output str | None

Path to the output data file.

None
r0_path str | None

Path to the R0 data file.

None
run_mode str

Dask run mode used by xarray, default is "forbidden".

'forbidden'
grid_data_baseurl str | None

Base URL for the grid data.

None
nuts_level int | None

NUTS level for the model, default is None

None
resolution str | None

Resolution for the NUTS data, default is None.

None
year int | None

Year for the model, default is None.

None
temp_colname str

Name of the temperature column in the input data, default is "t2m".

't2m'
out_colname str

Name of the output column for R0 data, default is "R0".

'R0'
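
Example:

A sketch with hypothetical paths. The four grid parameters (grid_data_baseurl, nuts_level, resolution, year) must either all be provided, to crop the input to the NUTS regions, or all be left as None.

from heiplanet_models import Jmodel

model_data = Jmodel.setup_modeldata(
    input="era5_t2m.nc",
    output="r0_out.nc",
    r0_path="r0.csv",
    grid_data_baseurl="https://gisco-services.ec.europa.eu/distribution/v2/nuts",
    nuts_level=0,
    resolution="10M",
    year=2024,
)
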
Source code in src/heiplanet_models/Jmodel.py
def setup_modeldata(
    input: str | None = None,
    output: str | None = None,
    r0_path: str | None = None,
    run_mode: str = "forbidden",
    grid_data_baseurl: str | None = None,
    nuts_level: int | None = None,
    resolution: str | None = None,
    year: int | None = None,
    temp_colname: str = "t2m",
    out_colname: str = "R0",
) -> JModelData:
    """Initializes the JModel with the given configuration.

    Args:
        input (str | None): Path to the input data file.
        output (str | None): Path to the output data file.
        r0_path (str | None): Path to the R0 data file.
        run_mode (str): Dask run mode used by xarray, default is "forbidden".
        grid_data_baseurl (str | None): Base URL for the grid data.
        nuts_level (int | None): NUTS level for the model, default is None
        resolution (str | None): Resolution for the NUTS data, default is None.
        year (int | None): Year for the model, default is None.
        temp_colname (str): Name of the temperature column in the input data, default is "t2m".
        out_colname (str): Name of the output column for R0 data, default is "R0".

    """

    # set up plumbing for the model
    if run_mode not in ["forbidden", "parallelized"]:
        raise ValueError(
            f"Invalid run mode: {run_mode}. Supported modes are 'forbidden', 'parallelized'. For the meaning of these modes, see the documentation. of xarray.apply_ufunc"
        )

    # set data paths and get r0 data
    if input is None:
        raise ValueError("Input data path must be provided in the configuration.")

    if output is None:
        raise ValueError("Output data path must be provided in the configuration.")

    # read R0 data from the given path
    if r0_path is None:
        raise ValueError("R0 data path must be provided in the configuration.")
    else:
        r0_data = pd.read_csv(r0_path)
    step_temp = (
        r0_data.Temperature[1] - r0_data.Temperature[0]
    )  # assume uniform step size
    min_temp = r0_data.Temperature.min()
    max_temp = r0_data.Temperature.max()

    if any(
        [
            grid_data_baseurl is None,
            nuts_level is None,
            resolution is None,
            year is None,
        ]
    ) and not all(
        [
            grid_data_baseurl is None,
            nuts_level is None,
            resolution is None,
            year is None,
        ]
    ):
        raise ValueError(
            "Grid data configuration is incomplete. Please provide all parameters: grid_data_baseurl, nuts_level, resolution, and year, or do not set any to have them all set to 'None'."
        )
    else:
        # don't do anything here, because None indicates the grid data is not used
        pass

    return JModelData(
        name="JModel",
        input=input,
        output=output,
        run_mode=run_mode,
        r0_data=r0_data,
        min_temp=min_temp,
        max_temp=max_temp,
        step=step_temp,
        temp_colname=temp_colname,
        out_colname=out_colname,
        grid_data_baseurl=grid_data_baseurl,
        nuts_level=nuts_level,
        resolution=resolution,
        year=year,
    )

Utilities used throughout the code

detect_csr(data)

Detects and sets the coordinate reference system (CRS) for an xarray dataset. Uses rioxarray to handle the CRS. If the CRS is not defined, it checks whether the coordinates match the expected ranges for EPSG:4326 (standard lat/lon coordinates).

Parameters:

Name Type Description Default
data Dataset

xarray dataset to check and set the CRS for. Typically these are ERA5 data or other climate data, which often do not come with a given CRS. Currently, this only supports the EPSG:4326 standard lat/lon coordinates: longitude from -180 to 180 degrees and latitude from -90 to 90 degrees. The spatial coordinates of the dataset must be called 'latitude' and 'longitude'.

required

Raises:

Type Description
ValueError

When the CRS is not defined and the coordinates do not match the expected ranges for EPSG:4326.

Returns:

Type Description
Dataset

xr.Dataset: dataset with the CRS set to EPSG:4326 if it was not already defined and the coordinates match the expected ranges.
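
Example:

A sketch assuming a hypothetical global dataset with latitude/longitude coordinates spanning the full EPSG:4326 ranges.

import rioxarray  # registers the .rio accessor used below
import xarray as xr
from heiplanet_models.utils import detect_csr

ds = xr.open_dataset("era5_global.nc")  # hypothetical file
ds = detect_csr(ds)  # writes EPSG:4326 via rioxarray if the ranges match
print(ds.rio.crs)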

Source code in src/heiplanet_models/utils.py
def detect_csr(data: xr.Dataset) -> xr.Dataset:
    """Detects and sets the coordinate reference system (CRS) for an xarray dataset. Uses rioxarray to handle the CRS. If the crs is not defined, it checks if the coordinates match the expected ranges for EPSG:4326 (standard lat/lon coordinates).

    Args:
        data (xr.Dataset): xarray dataset to check and set the CRS for. typically these are era5 data or other climate data which often do not come with a given crs. Currently, this only supports the
        EPSG.4326 standard lat/lon coordinates, which are defined as follows:
        - Longitude: -180 to 180 degrees
        - Latitude: -90 to 90 degrees
        The spatial coordinates of the dataset must be called 'latitude' and 'longitude'.

    Raises:
        ValueError: When the CRS is not defined and the coordinates do not match the expected ranges for EPSG:4326.

    Returns:
        xr.Dataset: dataset with the CRS set to EPSG:4326 if it was not already defined and the coordinates match the expected ranges.

    """

    # this currently only detects EPSG:4326 standard lat/lon coordinates
    if (
        -181.0 < data.longitude.min().values < -179.0
        and 179.0 < data.longitude.max().values < 181.0
        and -91.0 < data.latitude.min().values < -89.0
        and 89.0 < data.latitude.max().values < 91.0
    ):
        data = data.rio.write_crs("EPSG:4326")
    else:
        raise ValueError(
            "Coordinate reference system (CRS) is not defined and coordinates do not match expected ranges for EPSG:4326."
        )
    return data

load_module(module_name, file_path)

load_module: Load a Python module from 'file_path' under the alias 'module_name'.

Parameters:

Name Type Description Default
module_name str

module alias.

required
file_path str

Path to load the module from

required

Returns:

Name Type Description
module

Python module that has been loaded

Source code in src/heiplanet_models/utils.py
def load_module(module_name: str, file_path: str):
    """
    load_module: Load a Python module from 'file_path' under the alias 'module_name'.

    Args:
        module_name (str): module alias.
        file_path (str): Path to load the module from

    Returns:
        module: Python module that has been loaded
    """
    try:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    except Exception as e:
        raise RuntimeError(f"Error in loading module {file_path}") from e
    return module

load_name_from_module(module_name, file_path, name)

load_name_from_module: Load the object named 'name' from the Python module at 'file_path', imported under the alias 'module_name'.

Parameters:

Name Type Description Default
module_name str

module alias.

required
file_path str

Path to load the module from

required
name str

name to import

required

Returns: The object bound to 'name' in the loaded module.
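
Example:

A sketch that fetches a hypothetical function process_data from a local file my_module.py.

from heiplanet_models.utils import load_name_from_module

process_data = load_name_from_module(
    module_name="my_module",
    file_path="my_module.py",
    name="process_data",
)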

Source code in src/heiplanet_models/utils.py
def load_name_from_module(module_name: str, file_path: str, name: str):
    """
    load_name_from_module: Load the object named 'name' from the Python module at 'file_path', imported under the alias 'module_name'.

    Args:
        module_name (str): module alias.
        file_path (str): Path to load the module from
        name (str): name to import
    Returns:
        Any: The object bound to 'name' in the loaded module.
    """
    module = load_module(module_name, file_path)
    return getattr(module, name)

read_geodata(nuts_level=3, year=2024, resolution='10M', base_url='https://gisco-services.ec.europa.eu/distribution/v2/nuts', url=lambda base_url, resolution, year, nuts_level: f'{base_url}/geojson/NUTS_RG_{resolution}_{year}_4326_LEVL_{nuts_level}.geojson')

Load Eurostat NUTS geospatial data from the Eurostat service.

Parameters:

Name Type Description Default
nuts_level int

nuts administrative region level. Defaults to 3.

3
year int

year to load data for. Defaults to 2024.

2024
resolution str

Resolution of the geospatial data. One of "60" (1:60 million), "20" (1:20 million), "10" (1:10 million), "03" (1:3 million), or "01" (1:1 million). Defaults to "10M".

'10M'
base_url str

Base URL of the Eurostat NUTS data service. Defaults to "https://gisco-services.ec.europa.eu/distribution/v2/nuts".

'https://gisco-services.ec.europa.eu/distribution/v2/nuts'
url callable

Builds the full URL from the arguments passed to the function. Must have the signature url(base_url, resolution, year, nuts_level).

lambda base_url, resolution, year, nuts_level: f'{base_url}/geojson/NUTS_RG_{resolution}_{year}_4326_LEVL_{nuts_level}.geojson'

Returns:

Type Description

geopandas.GeoDataFrame: GeoDataFrame containing the NUTS geospatial data.
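
Example:

Loading country-level (NUTS level 0) boundaries for 2024 at the default resolution; this downloads from the public GISCO service shown in the defaults.

from heiplanet_models.utils import read_geodata

nuts = read_geodata(nuts_level=0, year=2024)
print(nuts.head())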

Source code in src/heiplanet_models/utils.py
def read_geodata(
    nuts_level: int = 3,
    year: int = 2024,
    resolution: str = "10M",
    base_url: str = "https://gisco-services.ec.europa.eu/distribution/v2/nuts",
    url: callable = lambda base_url, resolution, year, nuts_level: f"{base_url}/geojson/NUTS_RG_{resolution}_{year}_4326_LEVL_{nuts_level}.geojson",
):
    """load Eurostat NUTS geospatial data from the Eurostat service.

    Args:
        nuts_level (int, optional): nuts administrative region level. Defaults to 3.
        year (int, optional): year to load data for. Defaults to 2024.
        resolution (str, optional): resolution of the map. Resolution of the geospatial data. One of
        "60" (1:60million),
        "20" (1:20million)
        "10" (1:10million)
        "03" (1:3million) or
        "01" (1:1million).
        Defaults to "10M".
        base_url (str, optional): _description_. Defaults to "https://gisco-services.ec.europa.eu/distribution/v2/nuts".
        url (callable, optional): builds the full url from the arguments passed to the function.must have the signature url(base_url, resolution, year, nuts_level).

    Returns:
        geopandas.dataframe: Dataframe containing the NUTS geospatial data.
    """
    url_str = url(
        nuts_level=nuts_level, year=year, resolution=resolution, base_url=base_url
    )

    try:
        nuts_data = gpd.read_file(url_str)
        return nuts_data
    except Exception as e:
        raise RuntimeError(f"Failed to download from {url_str}: {e}")

validate_spatial_alignment(arr1, arr2)

Validates that two xarray DataArrays have aligned spatial coordinates.

Parameters:

Name Type Description Default
arr1 DataArray

The first DataArray.

required
arr2 DataArray

The second DataArray.

required

Raises:

Type Description
ValueError

If the 'latitude' or 'longitude' coordinates do not match or if the coordinates are missing.
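
Example:

A minimal sketch with two synthetic arrays on the same grid; the call passes silently when the coordinates match and raises ValueError otherwise.

import numpy as np
import xarray as xr
from heiplanet_models.utils import validate_spatial_alignment

lat = np.linspace(-90, 90, 19)
lon = np.linspace(-180, 180, 37)
a = xr.DataArray(np.zeros((19, 37)), coords={"latitude": lat, "longitude": lon}, dims=("latitude", "longitude"))
b = xr.DataArray(np.ones((19, 37)), coords={"latitude": lat, "longitude": lon}, dims=("latitude", "longitude"))
validate_spatial_alignment(a, b)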

Source code in src/heiplanet_models/utils.py
def validate_spatial_alignment(arr1: xr.DataArray, arr2: xr.DataArray) -> None:
    """Validates that two xarray DataArrays have aligned spatial coordinates.

    Args:
        arr1 (xr.DataArray): The first DataArray.
        arr2 (xr.DataArray): The second DataArray.

    Raises:
        ValueError: If the 'latitude' or 'longitude' coordinates do not match
                    or if the coordinates are missing.
    """
    # Check latitude
    try:
        if not np.array_equal(arr1.latitude.values, arr2.latitude.values):
            raise ValueError(
                "Spatial coordinate 'latitude' of input arrays must be aligned."
            )
    except AttributeError:
        raise ValueError("Input DataArrays must have a 'latitude' coordinate.")

    # Check longitude
    try:
        if not np.array_equal(arr1.longitude.values, arr2.longitude.values):
            raise ValueError(
                "Spatial coordinate 'longitude' of input arrays must be aligned."
            )
    except AttributeError:
        raise ValueError("Input DataArrays must have a 'longitude' coordinate.")