NNGT/testing/test_random.py

86 lines
2.0 KiB
Python

# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2015-2023 Tanguy Fardet
# SPDX-License-Identifier: GPL-3.0-or-later
# testing/test_random.py
# This file is part of the NNGT module
# Distributed as a free software, in the hope that it will be useful, under the
# terms of the GNU General Public License.
"""
Test the random seeding and generation
"""
import os
import numpy as np
import pytest
import nngt
import nngt.generation as ng
# Parallel execution for this test module is opt-in through the environment:
# any non-empty ``MPI`` value enables MPI, and a non-empty ``OMP`` value is
# converted to an int and stored as the "omp" config entry (presumably the
# OpenMP thread count).
if os.environ.get("MPI"):
    nngt.set_config("mpi", True)

_omp_env = os.environ.get("OMP")
if _omp_env:
    nngt.set_config("omp", int(_omp_env))
def get_num_seeds():
    """
    Return how many seeds ``nngt.seed`` should receive in this setup.

    With multithreading enabled this is the ``"omp"`` config value;
    otherwise it is the number of MPI processes reported by nngt.
    """
    threaded = nngt.get_config("multithreading")
    return nngt.get_config("omp") if threaded else nngt.num_mpi_processes()
# ----- #
# Tests #
# ----- #
def test_random_seeded():
    """
    Check that explicit seeding makes graph generation reproducible.

    Two graphs generated right after identical ``nngt.seed`` calls must have
    the same edges; a third graph generated without re-seeding is expected
    to differ from the second one.
    """
    num_seeds = get_num_seeds()
    seeds = range(1, num_seeds + 1)

    # test equality of graphs generated with the same seeds
    # (a fresh list is passed each time, as in the original calls, in case
    # nngt keeps a reference to it)
    nngt.seed(msd=0, seeds=list(seeds))
    g1 = ng.gaussian_degree(10, 1, nodes=100)

    nngt.seed(msd=0, seeds=list(seeds))
    g2 = ng.gaussian_degree(10, 1, nodes=100)

    # np.array_equal returns False on a shape mismatch instead of raising,
    # so a reproducibility failure is reported as a plain assertion error
    assert np.array_equal(g1.edges_array, g2.edges_array)

    # check that subsequent graphs are different
    g3 = ng.gaussian_degree(10, 1, nodes=100)

    # with MPI on non-distributed backends, test only on the master process
    if nngt.get_config("backend") == "nngt" or nngt.on_master_process():
        # graphs with different edge counts are already different, and
        # array_equal handles that case without a separate size check
        assert not np.array_equal(g2.edges_array, g3.edges_array)
def test_random_unseeded():
    """
    Check that seeds chosen automatically by ``nngt.seed`` can be replayed.

    After seeding with only a master seed, the seeds stored in the
    ``"seeds"`` config are read back and passed explicitly; generating the
    same graph again must then give identical edges.
    """
    nngt.seed(msd=42)

    # check that seeds generated by first call indeed allow to reproduce the
    # graph in later calls
    g1 = ng.gaussian_degree(10, 1, nodes=100)
    seeds = nngt.get_config("seeds")

    nngt.seed(msd=42, seeds=seeds)
    g2 = ng.gaussian_degree(10, 1, nodes=100)

    # np.array_equal returns False on a shape mismatch instead of raising
    assert np.array_equal(g1.edges_array, g2.edges_array)
# --------- #
# Run tests #
# --------- #
if __name__ == "__main__":
    # run the module's tests directly, in definition order
    for _test in (test_random_seeded, test_random_unseeded):
        _test()