Skip to content

[jax] [package list]

jax/0.4.1-cpeCray-22.08-rocm5.3 (jax-0.4.1-cpeCray-22.08-rocm5.3.eb)

This software is archived in the LUMI-EasyBuild-contrib GitHub repository as easybuild/easyconfigs/__archive__/j/jax/jax-0.4.1-cpeCray-22.08-rocm5.3.eb. The corresponding module would be jax/0.4.1-cpeCray-22.08-rocm5.3.

# Created for LUMI by Orian Louant
# Bundle-style easyconfig: the individual packages are declared in
# `components` and `exts_list`.
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.1'
versionsuffix = '-rocm5.3'

# Versions of the external pieces this recipe is pinned to.
local_craypython_version = '3.9.12.1'
local_rocm_version = '5.3.3'
local_wheel_version = '0.37.1'

# First two components of the Cray Python version ('3.9.12.1' -> '3.9').
local_craypython_maj_min = '.'.join(local_craypython_version.split('.', 2)[:2])

homepage = 'https://pypi.python.org/pypi/jax'

# One-line summary; the long form is in `description` below.
whatis = [
    'Description: JAX is Autograd and XLA, brought together for '
    'high-performance numerical computing.',
]

description = """
JAX is Autograd and XLA, brought together for high-performance numerical
computing.

JAX provides a familiar NumPy-style API for ease of adoption by researchers and
engineers. It includes composable function transformations for compilation,
batching, automatic differentiation, and parallelization. The same code executes
on multiple backends, including CPU, GPU, & TPU.
"""

# Upstream documentation and license locations.
docurls = [
    'https://jax.readthedocs.io/en/latest/',
]
software_license_urls = [
    'https://github.com/google/jax/blob/main/LICENSE',
]

toolchain = {'name': 'cpeCray', 'version': '22.08'}

# Build-time only: wheel, in the variant whose version suffix names the
# Cray Python major.minor release.
builddependencies = [
    ('wheel', local_wheel_version, '-cray-python' + local_craypython_maj_min),
]

# Cray Python and ROCm are declared as external modules (EXTERNAL_MODULE),
# each pinned to the version set at the top of the file.
dependencies = [
    ('cray-python/' + local_craypython_version, EXTERNAL_MODULE),
    ('rocm/' + local_rocm_version, EXTERNAL_MODULE),
]

# Download the TensorFlow tarball up front so that Bazel does not fetch it
# during the build.
# Note: ideally this is the exact commit referenced in WORKSPACE; in practice
#       we have to pick the commit of the AMD repo that synced the changes
#       coming from the official TF repo (usually synced after a few days).
local_tf_commit = '463c9aeff2f52e7988c869b15e1c3876d311b578'
local_tf_dir = 'tensorflow-upstream-' + local_tf_commit
local_tf_builddir = '%(builddir)s/' + local_tf_dir

# Replace the remote TensorFlow repository with the local one from EB:
# first apply the sed script to WORKSPACE, then substitute the
# EB_TF_REPOPATH placeholder with the unpacked TensorFlow directory.
# The trailing ' && ' lets the build command be appended directly.
local_jax_prebuildopts = (
    "sed -i -f jaxlib_local-tensorflow-repo.sed WORKSPACE && "
    "sed -i 's|EB_TF_REPOPATH|%s|' WORKSPACE && "
) % local_tf_builddir

# Install with pip (set here for the bundle and again in the component defaults).
use_pip = True

# Easyblock and settings shared by the entries of `components`; the jaxlib
# component supplies its own sources, source_urls and start_dir.
default_easyblock = 'PythonPackage'
default_component_specs = {
    'sources': [SOURCE_TAR_GZ],
    'source_urls': [PYPI_SOURCE],
    'start_dir': '%(name)s-%(version)s',
    'use_pip': True,
    'sanity_pip_check': True,
    'download_dep_fail': True,
}

components = [
    # absl-py is imported under the module name 'absl'.
    ('absl-py', '1.4.0', {
        'options': {'modulename': 'absl'},
        'checksums': ['d2c244d01048ba476e7c080bd2c6df5e141d211de80223460d5b3b8a2a58433d'],
    }),
    # jaxlib: JAX GitHub archive plus the pinned ROCm TensorFlow tarball,
    # which is wired into WORKSPACE by local_jax_prebuildopts.
    ('jaxlib', version, {
        'sources': [
            '%(name)s-v%(version)s.tar.gz',
            {
                # downloaded as <commit>.tar.gz, stored under a descriptive name
                'download_filename': local_tf_commit + '.tar.gz',
                'filename': local_tf_dir + '.tar.gz',
            },
        ],
        'source_urls': [
            'https://github.com/google/jax/archive/',
            'https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/archive/',
        ],
        'patches': [
            # sed script copied into the start directory; used by prebuildopts
            ('jaxlib_local-tensorflow-repo.sed', '.'),
        ],
        'checksums': [
            # jaxlib-v0.4.1.tar.gz
            'a001de25e0b7ca5847dcbbcf4c31a8e354c0b6cdaa970b7cc4aeea027fd638d7',
            # tensorflow-upstream-463c9aeff2f52e7988c869b15e1c3876d311b578.tar.gz
            '4cae7f07e5c5d30af78e029a753ab077789ef121c6bea6d3499d15f4ae943a46',
            # jaxlib_local-tensorflow-repo.sed
            'abb5c3b97f4e317bce9f22ed3eeea3b9715365818d8b50720d937e2d41d5c4e5',
        ],
        'start_dir': 'jax-jaxlib-v%(version)s',
        'prebuildopts': local_jax_prebuildopts,
        'use_mkl_dnn': False,
        'binutils_root': '/opt/cray/pe/cce/14.0.2/binutils/x86_64',
    }),
]

# Python packages installed as extensions of the bundle, in list order;
# jax itself comes last. Every entry carries a SHA256 under the 'checksums' key.
exts_list = [
    ('opt_einsum', '3.3.0', {
        'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'],
    }),
    ('flit_core', '3.8.0', {
        # The key must be 'checksums' (plural); the previous 'checksum'
        # spelling meant this tarball was not verified against its SHA256.
        'checksums': ['b305b30c99526df5e63d6022dd2310a0a941a187bd3884f4c8ef0418df6c39f3'],
    }),
    ('etils', '0.8.0', {
        'checksums': ['d1d5af7bd9c784a273c4e1eccfaa8feaca5e0481a08717b5313fa231da22a903'],
    }),
    ('typing-extensions', '4.4.0', {
        # the source tarball uses an underscore instead of the hyphen in the package name
        'sources': ['typing_extensions-4.4.0.tar.gz'],
        'checksums': ['1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa'],
    }),
    (name, version, {
        # jax is fetched from the GitHub release archive (jax-v<version>.tar.gz)
        'source_tmpl': '%(name)s-v%(version)s.tar.gz',
        'source_urls': ['https://github.com/google/jax/archive/'],
        'checksums': ['27d74d95f84af0303813456900e5b7604e5025742adff1f09df16e508424a8a0'],
    }),
]

# The `pip check` sanity step is disabled for the bundle: in the Cray Python
# installation, dask requires cloudpickle, which is missing there.
sanity_pip_check = False

moduleclass = 'tools'

[jax] [package list]