You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
NNGT/setup.py

111 lines
2.6 KiB
Python

# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2015-2023 Tanguy Fardet
# SPDX-License-Identifier: CC0-1.0
# setup.py
import os
import os.path as op
import platform
import re
import sys
import sysconfig
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext
import numpy as np
from Cython.Build import cythonize
# ------------------ #
# Paths and platform #
# ------------------ #
# OS name: Linux/Darwin (Mac)/Windows
os_name = platform.system()
# OpenMP
omp_lib = [] if os_name == "Windows" else ["gomp"]
omp_pos = sys.argv.index("--omp") if "--omp" in sys.argv else -1
omp_lib_dir = "/usr/lib" if omp_pos == -1 else sys.argv[omp_pos + 1]
dirname = os.path.join(".", "nngt/generation/")
# ------------------------ #
# Compiling OMP algorithms #
# ------------------------ #
# compiler options
copt = {
'msvc': ['/openmp', '/O2', '/fp:precise', '/permissive-', '/Zc:twoPhase-'],
'unix': [
'-std=c++11', '-Wno-cpp', '-Wno-unused-function', '-fopenmp',
'-ffast-math', '-msse', '-ftree-vectorize', '-O2', '-g',
]
}
lopt = {
'unix': ['-fopenmp'],
'clang': ['-fopenmp'],
}
extensions = Extension(
"nngt.generation.cconnect",
sources=[
os.path.join(dirname, "cconnect.pyx"),
os.path.join(dirname, "func_connect.cpp")
],
extra_compile_args=[],
language="c++",
include_dirs=[dirname, np.get_include()],
libraries=omp_lib,
library_dirs=[dirname, omp_lib_dir]
)
class build_ext(_build_ext):
def initialize_options(self):
super().initialize_options()
if self.distribution.ext_modules is None:
self.distribution.ext_modules = []
self.distribution.ext_modules = cythonize(extensions)
def build_extensions(self):
# only Unix compilers and their ports have `compiler_so`
c = getattr(self.compiler, 'compiler_so', None)
if c is None:
c = os.environ.get('CC', os.environ.get('CXX'))
else:
c = c[0]
if c is None:
c = sysconfig.get_config_var('CC')
if re.match(r"gcc|g\+\+|mingw|clang", c):
c = "unix"
elif "msvc" in c:
c = "msvc"
for e in self.distribution.ext_modules:
e.extra_link_args.extend(lopt.get(c, []))
e.extra_compile_args.extend(copt.get(c, []))
try:
self.compiler.compiler_so.remove("-Wstrict-prototypes")
self.compiler.compiler_so.remove("-O3")
except:
pass
super().build_extensions()
if __name__ == "__main__":
setup(cmdclass={"build_ext": build_ext}, ext_modules=cythonize(extensions))