Skip to content

Commit

Permalink
Cleaner handling of numpy-based extensions in setup.py
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: fairinternal/fairseq-py#853

Differential Revision: D17147879

Pulled By: myleott

fbshipit-source-id: b1f5e838533de62ade52fa82112ea5308734c70f
  • Loading branch information
myleott authored and facebook-github-bot committed Aug 31, 2019
1 parent 746e59a commit 8d4588b
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
import sys


Expand All @@ -23,6 +22,23 @@
extra_compile_args = ['-std=c++11', '-O3']


class NumpyExtension(Extension):
"""Source: https://stackoverflow.com/a/54128391"""

def __init__(self, *args, **kwargs):
self.__include_dirs = []
super().__init__(*args, **kwargs)

@property
def include_dirs(self):
import numpy
return self.__include_dirs + [numpy.get_include()]

@include_dirs.setter
def include_dirs(self, dirs):
self.__include_dirs = dirs


extensions = [
Extension(
'fairseq.libbleu',
Expand All @@ -32,13 +48,13 @@
],
extra_compile_args=extra_compile_args,
),
Extension(
NumpyExtension(
'fairseq.data.data_utils_fast',
sources=['fairseq/data/data_utils_fast.pyx'],
language='c++',
extra_compile_args=extra_compile_args,
),
Extension(
NumpyExtension(
'fairseq.data.token_block_utils_fast',
sources=['fairseq/data/token_block_utils_fast.pyx'],
language='c++',
Expand All @@ -47,15 +63,6 @@
]


class CustomBuildExtCommand(build_ext):
"""Source: https://stackoverflow.com/a/42163080"""
def run(self):
# Import numpy here, only when headers are needed
import numpy
self.include_dirs.append(numpy.get_include())
super().run()


setup(
name='fairseq',
version='0.8.0',
Expand All @@ -71,7 +78,6 @@ def run(self):
long_description=readme,
long_description_content_type='text/markdown',
setup_requires=[
'numpy',
'cython',
'numpy',
'setuptools>=18.0',
Expand Down Expand Up @@ -99,6 +105,5 @@ def run(self):
'fairseq-validate = fairseq_cli.validate:cli_main',
],
},
cmdclass={'build_ext': CustomBuildExtCommand},
zip_safe=False,
)

0 comments on commit 8d4588b

Please sign in to comment.