Skip to content

Commit d541bd2

Browse files
committed
Customize NCCL for base container
1 parent 277b9ef commit d541bd2

File tree

5 files changed

+87
-7
lines changed

5 files changed

+87
-7
lines changed

.github/container/Dockerfile.base

+12
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ FROM ${BASE_IMAGE}
2929
ARG GIT_USER_EMAIL
3030
ARG GIT_USER_NAME
3131
ARG CLANG_VERSION
32+
ARG JAX_NCCL_VERSION
33+
ARG JAX_LIBNCCL_PACKAGE
34+
35+
###############################################################################
36+
## Update NCCL version env variables
37+
###############################################################################
38+
39+
ENV NV_LIBNCCL_DEV_PACKAGE=${NV_LIBNCCL_DEV_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE}
40+
ENV NV_LIBNCCL_DEV_PACKAGE_VERSION=${JAX_NCCL_VERSION}
41+
ENV NCCL_VERSION=${JAX_NCCL_VERSION}
42+
ENV NV_LIBNCCL_PACKAGE=${NV_LIBNCCL_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE}
43+
ENV NV_LIBNCCL_PACKAGE_VERSION=${JAX_NCCL_VERSION}
3244

3345
###############################################################################
3446
## Install Python and essential tools

.github/container/install-nccl.sh

+18-6
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,19 @@ set -ex -o pipefail
55
export DEBIAN_FRONTEND=noninteractive
66
export TZ=America/Los_Angeles
77

8-
# If NCCL is already installed, don't reinstall it. Print a message and exit
9-
if dpkg -s libnccl2 libnccl-dev &> /dev/null; then
10-
echo "NCCL is already installed. Skipping installation."
8+
# If NCCL_VERSION is not set in the environment, try to derive it from the installed libnccl-dev package
9+
if [[ -z $NCCL_VERSION ]]; then
10+
# Derive the NCCL version (the part before "+cuda") from the installed package.
# NOTE: cuda_version appears to be extracted only later (in the else branch), so
# the pattern must not depend on it; `|| true` keeps `set -e -o pipefail` from
# aborting the script when libnccl-dev is not installed, so that the
# "Skip NCCL installation" branch below stays reachable.
NCCL_VERSION=$( (dpkg -s libnccl-dev 2>/dev/null || true) | sed -n "s/^Version: \([^+]*\)+cuda.*$/\1/p" | head -n 1)
11+
fi
12+
13+
# Skip NCCL installation if both JAX_NCCL_VERSION (user defined) and
14+
# NCCL_VERSION (defined in nvidia/cuda containers) are unset.
15+
# This case means that the base container is built from a custom image with
16+
# a custom network communicator or unset NCCL_VERSION env variable.
17+
if [[ -z $JAX_NCCL_VERSION && -z $NCCL_VERSION ]]; then
18+
echo "Skip NCCL installation"
1119
else
20+
JAX_NCCL_VERSION=${JAX_NCCL_VERSION:-$NCCL_VERSION}
1221
apt-get update
1322

1423
# Extract CUDA version from `nvcc --version` output line
@@ -18,21 +27,24 @@ else
1827

1928
# Find latest NCCL version compatible with existing CUDA by matching
2029
# ${cuda_version} in the package version string
21-
libnccl2_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
22-
libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1)
30+
libnccl2_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(${JAX_NCCL_VERSION}.*+cuda.*\)$/\1/p" | head -n 1)
31+
libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(${JAX_NCCL_VERSION}.*+cuda.*\)$/\1/p" | head -n 1)
2332
if [[ -z "${libnccl2_version}" || -z "${libnccl_dev_version}" ]]; then
2433
echo "Could not find compatible NCCL version for CUDA ${cuda_version}"
2534
exit 1
2635
fi
2736

28-
apt-get install -y \
37+
apt-get install -y --allow-change-held-packages \
2938
libnccl2=${libnccl2_version} \
3039
libnccl-dev=${libnccl_dev_version}
3140

3241
apt-get clean
3342
rm -rf /var/lib/apt/lists/*
3443
fi
3544

45+
# Smoke test of installed NCCL packages
46+
dpkg -s libnccl2 libnccl-dev
47+
3648
# Create a prefix with include/ and lib/ directories containing symlinks to the NCCL
3749
# version installed at the system level; this is useful to pass to XLA to avoid it
3850
# fetching its own copy.

.github/workflows/_build_base.yaml

+44-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ on:
4242
description: Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch
4343
default: ''
4444
required: false
45+
JAX_LIBNCCL_PACKAGE:
46+
type: string
47+
description: NCCL lib package version to be installed (in the format `2.19.3-1+cuda12.3`)
48+
default: ''
49+
required: false
4550
outputs:
4651
DOCKER_TAG:
4752
description: "Tag of the image built"
@@ -56,8 +61,44 @@ permissions:
5661
packages: write # to upload container
5762

5863
jobs:
64+
nccl-version:
65+
runs-on: ubuntu-22.04
66+
outputs:
67+
JAX_NCCL_VERSION: ${{ steps.get-nccl-version.outputs.JAX_NCCL_VERSION }}
68+
JAX_LIBNCCL_PACKAGE: ${{ steps.get-nccl-version.outputs.JAX_LIBNCCL_PACKAGE }}
69+
steps:
70+
- name: Print environment variables
71+
run: env
72+
73+
- name: Check out the repository under ${GITHUB_WORKSPACE}
74+
uses: actions/checkout@v4
75+
76+
- name: Get NCCL version
77+
id: get-nccl-version
78+
shell: bash -x -e {0}
79+
run: |
80+
JAX_LIBNCCL_PACKAGE=${{ inputs.JAX_LIBNCCL_PACKAGE }}
81+
if [[ -z $JAX_LIBNCCL_PACKAGE ]]; then
82+
BASE_IMAGE=${{ inputs.BASE_IMAGE }}
83+
if [[ $BASE_IMAGE == latest ]]; then
84+
BASE_IMAGE=$(cat .github/container/Dockerfile.base | sed -n "s/^ARG BASE_IMAGE=\(.*\)$/\1/p")
85+
fi
86+
# Try to get the NCCL version from the linux/amd64 (x86_64) variant of the provided BASE_IMAGE
87+
if [[ -z "$BASE_IMAGE" ]]; then
88+
echo "Need to pass non-empty BASE_IMAGE variable"
89+
exit 1
90+
fi
91+
source .github/workflows/scripts/get_remote_env.sh
92+
# Anchor on "NV_LIBNCCL_PACKAGE=" so NV_LIBNCCL_PACKAGE_NAME / NV_LIBNCCL_PACKAGE_VERSION
# are not matched as well, and keep only the version string after "<name>="
# (Dockerfile.base prepends ${NV_LIBNCCL_PACKAGE_NAME}= itself). Assumes the image
# stores the variable as NV_LIBNCCL_PACKAGE=<name>=<version> -- TODO confirm.
# With `cut` last in the pipeline, a base image without this variable yields an
# empty value instead of failing the step, matching the JAX_NCCL_VERSION line below.
JAX_LIBNCCL_PACKAGE=$(get_remote_env ${BASE_IMAGE} linux amd64 | jq -r '.[]' | egrep '^NV_LIBNCCL_PACKAGE=' | cut -d= -f3-)
93+
JAX_NCCL_VERSION=$(get_remote_env ${BASE_IMAGE} linux amd64 | jq -r '.[]' | egrep '^NCCL_VERSION=' | cut -d= -f2-)
94+
else
95+
JAX_NCCL_VERSION=$(echo $JAX_LIBNCCL_PACKAGE | cut -d= -f2 | cut -d+ -f1)
96+
fi
97+
echo "JAX_NCCL_VERSION=$JAX_NCCL_VERSION" >> $GITHUB_OUTPUT
98+
echo "JAX_LIBNCCL_PACKAGE=$JAX_LIBNCCL_PACKAGE" >> $GITHUB_OUTPUT
5999
60100
build-base:
101+
needs: nccl-version
61102
runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", small]
62103
env:
63104
BADGE_FILENAME_FULL: ${{ inputs.BADGE_FILENAME }}-${{ inputs.ARCHITECTURE }}.json
@@ -133,7 +174,9 @@ jobs:
133174
GIT_USER_EMAIL=${{ inputs.GIT_USER_EMAIL }}
134175
BUILD_DATE=${{ inputs.BUILD_DATE }}
135176
${{ inputs.BASE_IMAGE != 'latest' && format('BASE_IMAGE={0}', inputs.BASE_IMAGE) || '' }}
136-
177+
JAX_NCCL_VERSION=${{ needs.nccl-version.outputs.JAX_NCCL_VERSION }}
178+
JAX_LIBNCCL_PACKAGE=${{ needs.nccl-version.outputs.JAX_LIBNCCL_PACKAGE }}
179+
137180
- name: Generate sitrep
138181
if: "!cancelled()"
139182
shell: bash -x -e {0}

.github/workflows/_ci.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ on:
2626
description: 'A JSON object containing git url+refs for softwares to be built'
2727
required: false
2828
default: '{}'
29+
JAX_LIBNCCL_PACKAGE:
30+
type: string
31+
description: NCCL lib package version to be installed (for example, `2.20.3-1+cuda12.4`)
32+
default: ''
33+
required: false
2934
outputs:
3035
DOCKER_TAGS:
3136
description: 'JSON object containing tags of all docker images built'
@@ -45,6 +50,7 @@ jobs:
4550
BASE_IMAGE: ${{ inputs.CUDA_IMAGE }}
4651
BUILD_DATE: ${{ inputs.BUILD_DATE }}
4752
MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }}
53+
JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }}
4854
secrets: inherit
4955

5056
build-jax:

.github/workflows/ci.yaml

+7
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ on:
4040
PACKAGE∊{JAX,XLA,Flax,transformer-engine,T5X,paxml,praxis,maxtext,levanter,haliax,mujuco,mujuco-mpc,gemma,big-vision,common-loop-utils,flaxformer,panopticapi} (case-insensitive)
4141
default: ''
4242
required: false
43+
JAX_LIBNCCL_PACKAGE:
44+
type: string
45+
description: NCCL lib package version to be installed (for example, 2.20.3-1+cuda12.4)
46+
default: ''
47+
required: false
4348

4449
concurrency:
4550
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@@ -197,6 +202,7 @@ jobs:
197202
CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
198203
MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
199204
SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
205+
JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }}
200206
secrets: inherit
201207

202208
arm64:
@@ -208,6 +214,7 @@ jobs:
208214
CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
209215
MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
210216
SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
217+
JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }}
211218
secrets: inherit
212219

213220
# Only merge if everything succeeds

0 commit comments

Comments
 (0)