From d541bd2eaba7029bc3dd33a5d2a8ebe7bab682da Mon Sep 17 00:00:00 2001 From: Vladislav Kozlov Date: Mon, 28 Oct 2024 16:22:29 -0700 Subject: [PATCH] Customize NCCL for base container --- .github/container/Dockerfile.base | 12 ++++++++ .github/container/install-nccl.sh | 24 ++++++++++++---- .github/workflows/_build_base.yaml | 45 +++++++++++++++++++++++++++++- .github/workflows/_ci.yaml | 6 ++++ .github/workflows/ci.yaml | 7 +++++ 5 files changed, 87 insertions(+), 7 deletions(-) diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base index 9f0851897..573e585cd 100644 --- a/.github/container/Dockerfile.base +++ b/.github/container/Dockerfile.base @@ -29,6 +29,18 @@ FROM ${BASE_IMAGE} ARG GIT_USER_EMAIL ARG GIT_USER_NAME ARG CLANG_VERSION +ARG JAX_NCCL_VERSION +ARG JAX_LIBNCCL_PACKAGE + +############################################################################### +## Set NCCL version env variables from the JAX_NCCL_VERSION / JAX_LIBNCCL_PACKAGE build args +############################################################################### + +ENV NV_LIBNCCL_DEV_PACKAGE=${NV_LIBNCCL_DEV_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE} +ENV NV_LIBNCCL_DEV_PACKAGE_VERSION=${JAX_NCCL_VERSION} +ENV NCCL_VERSION=${JAX_NCCL_VERSION} +ENV NV_LIBNCCL_PACKAGE=${NV_LIBNCCL_PACKAGE_NAME}=${JAX_LIBNCCL_PACKAGE} +ENV NV_LIBNCCL_PACKAGE_VERSION=${JAX_NCCL_VERSION} ############################################################################### ## Install Python and essential tools diff --git a/.github/container/install-nccl.sh b/.github/container/install-nccl.sh index 892bf2d4d..cca8c8adf 100755 --- a/.github/container/install-nccl.sh +++ b/.github/container/install-nccl.sh @@ -5,10 +5,19 @@ set -ex -o pipefail export DEBIAN_FRONTEND=noninteractive export TZ=America/Los_Angeles -# If NCCL is already installed, don't reinstall it. Print a message and exit -if dpkg -s libnccl2 libnccl-dev &> /dev/null; then - echo "NCCL is already installed. Skipping installation."
+# Try to get NCCL_VERSION from an already installed libnccl-dev; it stays empty (no error) when the package is absent +if [[ -z $NCCL_VERSION ]]; then + NCCL_VERSION=$(dpkg -s libnccl-dev 2>/dev/null | sed -n 's/^Version: \([^+]*\)+cuda.*$/\1/p' | head -n 1 || true) +fi + +# Skip NCCL installation if both JAX_NCCL_VERSION (user defined) and +# NCCL_VERSION (defined in nvidia/cuda containers) are unset. +# This case means that the base container is built from a custom image with +# a custom network communicator or unset NCCL_VERSION env variable. +if [[ -z $JAX_NCCL_VERSION && -z $NCCL_VERSION ]]; then + echo "Skip NCCL installation" else + JAX_NCCL_VERSION=${JAX_NCCL_VERSION:-$NCCL_VERSION} apt-get update # Extract CUDA version from `nvcc --version` output line @@ -18,14 +27,14 @@ else # Find latest NCCL version compatible with existing CUDA by matching # ${cuda_version} in the package version string - libnccl2_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1) - libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(.*+cuda${cuda_version}\)$/\1/p" | head -n 1) + libnccl2_version=$(apt-cache show libnccl2 | sed -n "s/^Version: \(${JAX_NCCL_VERSION}.*+cuda.*\)$/\1/p" | head -n 1) + libnccl_dev_version=$(apt-cache show libnccl-dev | sed -n "s/^Version: \(${JAX_NCCL_VERSION}.*+cuda.*\)$/\1/p" | head -n 1) if [[ -z "${libnccl2_version}" || -z "${libnccl_dev_version}" ]]; then echo "Could not find compatible NCCL version for CUDA ${cuda_version}" exit 1 fi - apt-get install -y \ + apt-get install -y --allow-change-held-packages \ libnccl2=${libnccl2_version} \ libnccl-dev=${libnccl_dev_version} @@ -33,6 +42,9 @@ else rm -rf /var/lib/apt/lists/* fi +# Smoke test of installed NCCL packages +dpkg -s libnccl2 libnccl-dev + # Create a prefix with include/ and lib/ directories containing symlinks to the NCCL # version installed at the system level; this is useful to pass to XLA to avoid it # fetching its own copy.
diff --git a/.github/workflows/_build_base.yaml b/.github/workflows/_build_base.yaml index b575ec14b..07004dabc 100644 --- a/.github/workflows/_build_base.yaml +++ b/.github/workflows/_build_base.yaml @@ -42,6 +42,11 @@ on: description: Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch default: '' required: false + JAX_LIBNCCL_PACKAGE: + type: string + description: NCCL lib package version to be installed (in the format `2.19.3-1+cuda12.3`) + default: '' + required: false outputs: DOCKER_TAG: description: "Tag of the image built" @@ -56,8 +61,44 @@ permissions: packages: write # to upload container jobs: + nccl-version: + runs-on: ubuntu-22.04 + outputs: + JAX_NCCL_VERSION: ${{ steps.get-nccl-version.outputs.JAX_NCCL_VERSION }} + JAX_LIBNCCL_PACKAGE: ${{ steps.get-nccl-version.outputs.JAX_LIBNCCL_PACKAGE }} + steps: + - name: Print environment variables + run: env + + - name: Check out the repository under ${GITHUB_WORKSPACE} + uses: actions/checkout@v4 + + - name: Get NCCL version + id: get-nccl-version + shell: bash -x -e {0} + run: | + JAX_LIBNCCL_PACKAGE="${{ inputs.JAX_LIBNCCL_PACKAGE }}" + if [[ -z $JAX_LIBNCCL_PACKAGE ]]; then + BASE_IMAGE="${{ inputs.BASE_IMAGE }}" + if [[ $BASE_IMAGE == latest ]]; then + BASE_IMAGE=$(cat .github/container/Dockerfile.base | sed -n "s/^ARG BASE_IMAGE=\(.*\)$/\1/p") + fi + # BASE_IMAGE is required: the NCCL package version (NV_LIBNCCL_PACKAGE=<name>=<version>) and NCCL_VERSION are read from its linux/amd64 env + if [[ -z "$BASE_IMAGE" ]]; then + echo "Need to pass non-empty BASE_IMAGE variable" + exit 1 + fi + source .github/workflows/scripts/get_remote_env.sh + JAX_LIBNCCL_PACKAGE=$(get_remote_env "${BASE_IMAGE}" linux amd64 | jq -r '.[]' | grep -E '^NV_LIBNCCL_PACKAGE=' | cut -d= -f3-) + JAX_NCCL_VERSION=$(get_remote_env "${BASE_IMAGE}" linux amd64 | jq -r '.[]' | grep -E '^NCCL_VERSION=' | cut -d= -f2-) + else + JAX_NCCL_VERSION=$(echo "$JAX_LIBNCCL_PACKAGE" | cut -d= -f2 | cut -d+ -f1) + fi + echo "JAX_NCCL_VERSION=$JAX_NCCL_VERSION" >> "$GITHUB_OUTPUT" + echo
"JAX_LIBNCCL_PACKAGE=$JAX_LIBNCCL_PACKAGE" >> "$GITHUB_OUTPUT" build-base: + needs: nccl-version runs-on: [self-hosted, "${{ inputs.ARCHITECTURE }}", small] env: BADGE_FILENAME_FULL: ${{ inputs.BADGE_FILENAME }}-${{ inputs.ARCHITECTURE }}.json @@ -133,7 +174,9 @@ jobs: GIT_USER_EMAIL=${{ inputs.GIT_USER_EMAIL }} BUILD_DATE=${{ inputs.BUILD_DATE }} ${{ inputs.BASE_IMAGE != 'latest' && format('BASE_IMAGE={0}', inputs.BASE_IMAGE) || '' }} - + JAX_NCCL_VERSION=${{ needs.nccl-version.outputs.JAX_NCCL_VERSION }} + JAX_LIBNCCL_PACKAGE=${{ needs.nccl-version.outputs.JAX_LIBNCCL_PACKAGE }} + - name: Generate sitrep if: "!cancelled()" shell: bash -x -e {0} diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 17ad5f06b..2c0756cbc 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -26,6 +26,11 @@ on: description: 'A JSON object containing git url+refs for softwares to be built' required: false default: '{}' + JAX_LIBNCCL_PACKAGE: + type: string + description: NCCL version to be installed (for example, `2.20.3-1+cuda12.4`) + default: '' + required: false outputs: DOCKER_TAGS: description: 'JSON object containing tags of all docker images built' @@ -45,6 +50,7 @@ jobs: BASE_IMAGE: ${{ inputs.CUDA_IMAGE }} BUILD_DATE: ${{ inputs.BUILD_DATE }} MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }} + JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }} secrets: inherit build-jax: diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0c3c8bdb0..22d123c0c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -40,6 +40,11 @@ on: PACKAGE∊{JAX,XLA,Flax,transformer-engine,T5X,paxml,praxis,maxtext,levanter,haliax,mujuco,mujuco-mpc,gemma,big-vision,common-loop-utils,flaxformer,panopticapi} (case-insensitive) default: '' required: false + JAX_LIBNCCL_PACKAGE: + type: string + description: NCCL version to be installed (for example, 2.20.3-1+cuda12.4) + default: '' + required: false concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -197,6 +202,7 @@ jobs: CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }} MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }} SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }} + JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }} secrets: inherit arm64: @@ -208,6 +214,7 @@ jobs: CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }} MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }} SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }} + JAX_LIBNCCL_PACKAGE: ${{ inputs.JAX_LIBNCCL_PACKAGE }} secrets: inherit # Only merge if everything succeeds