124 lines
4.1 KiB
Python
124 lines
4.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
# SPDX-FileCopyrightText: 2015-2023 Tanguy Fardet
|
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
# testing/test_mpi.py
|
|
|
|
"""
|
|
Test the main methods of the :mod:`~nngt.generation` module.
|
|
"""
|
|
|
|
import os
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
import nngt
|
|
from nngt.analysis import *
|
|
from nngt.lib.connect_tools import _compute_connections
|
|
|
|
from base_test import TestBasis, XmlHandler, network_dir
|
|
from test_generation import _distance_rule_theo, _distance_rule_exp
|
|
from tools_testing import foreach_graph
|
|
|
|
|
|
# Switch NNGT to MPI mode when the "MPI" environment variable is set to a
# non-empty value; this must happen before the test class is defined because
# its skip decorator reads the "mpi" configuration entry.
if os.getenv("MPI"):
    nngt.set_config("mpi", True)
|
|
|
|
|
|
# -------- #
|
|
# Test MPI #
|
|
# -------- #
|
|
|
|
class TestMPI(TestBasis):

    '''
    Class testing graph generation through :func:`nngt.generate` when NNGT
    is configured to use MPI.
    '''

    # functions returning the reference properties from the instructions
    # alone, keyed by graph type
    theo_prop = {
        "distance_rule": _distance_rule_theo,
    }

    # functions measuring the same properties on a generated graph, keyed by
    # graph type
    exp_prop = {
        "distance_rule": _distance_rule_exp,
    }

    # maximal mean relative error accepted on the distance distribution for
    # the 'lin' rule (every other rule uses 0.25)
    tolerance = 0.08

    @property
    def test_name(self):
        ''' Name identifying this test case. '''
        return "test_mpi"

    # NOTE: `skipIf` on a plain function makes any call raise
    # `unittest.SkipTest` when MPI is disabled, which marks the calling test
    # as skipped.
    @unittest.skipIf(not nngt.get_config('mpi'), "Not using MPI.")
    def gen_graph(self, graph_name):
        '''
        Generate the graph called `graph_name` from the options stored in the
        parser.

        Parameters
        ----------
        graph_name : str
            Name of the graph, used to fetch its generation options.

        Returns
        -------
        graph, di_instructions : the generated graph and the options that
        were passed to :func:`nngt.generate`.
        '''
        di_instructions = self.parser.get_graph_options(graph_name)
        graph = nngt.generate(di_instructions)

        # the name is only set on the master process
        if nngt.on_master_process():
            graph.set_name(graph_name)

        return graph, di_instructions

    @foreach_graph
    def test_model_properties(self, graph, instructions, **kwargs):
        '''
        When generating graphs from one of the preconfigured models, check
        that the expected properties are indeed obtained.

        With the "nngt" backend the check runs on every process and the
        measured average degree is multiplied by the number of processes;
        with any other backend it runs on the master process only.
        '''
        if nngt.get_config("backend") == "nngt":
            # imported lazily so that mpi4py is only required in this case
            from mpi4py import MPI
            num_proc = MPI.COMM_WORLD.Get_size()
        elif nngt.on_master_process():
            num_proc = 1
        else:
            # other backends: non-master processes have nothing to check
            return

        self._check_model_properties(graph, instructions, num_proc)

    def _check_model_properties(self, graph, instructions, num_proc):
        '''
        Compare the properties measured on `graph` to the reference ones.

        Parameters
        ----------
        graph : the generated graph.
        instructions : dict
            Options used to generate the graph; must contain "graph_type"
            (and "rule" for the 'distance_rule' type).
        num_proc : int
            Factor applied to the measured average degree before comparing
            it to the reference (1 leaves it unchanged).
            NOTE(review): presumably each process only sees its own share of
            the edges with the "nngt" backend -- confirm.

        Raises
        ------
        KeyError if the graph type has no registered property functions.
        AssertionError if a property does not match the reference.
        '''
        graph_type = instructions["graph_type"]
        ref_result = self.theo_prop[graph_type](instructions)
        computed_result = self.exp_prop[graph_type](graph, instructions)

        if graph_type != 'distance_rule':
            return

        # average degree (first entry of both results)
        self.assertTrue(
            ref_result[0] == computed_result[0] * num_proc,
            "Avg. deg. for graph {} failed:\nref = {} vs exp {}".format(
                graph.name, ref_result[0], computed_result[0]))

        # average error on distance distribution (remaining entries):
        # mean of |ref - computed| / |computed|
        sqd = np.square(np.subtract(ref_result[1:], computed_result[1:]))
        avg_sqd = sqd / np.square(computed_result[1:])
        err = np.sqrt(avg_sqd).mean()

        tolerance = (self.tolerance if instructions['rule'] == 'lin'
                     else 0.25)

        self.assertTrue(
            err <= tolerance,
            "Distance distribution for graph {} failed:\nerr = {} > {}".format(
                graph.name, err, tolerance))
|
|
|
|
|
|
|
|
# ---------- #
|
|
# Test suite #
|
|
# ---------- #
|
|
|
|
# The suite is only built when NNGT is configured to use MPI.
if nngt.get_config("mpi"):
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMPI)

# Allow running this file directly as a script.
if __name__ == "__main__":
    unittest.main()
|