Skip to content

Commit 3ed9334

Browse files
DwarKapexyhtang
andauthored
Add MJX docker container build and small perf test (#497)
Adding support for MJX to JAX-Toolbox --------- Co-authored-by: Yu-Hang 'Maxin' Tang <[email protected]>
1 parent 5535cbf commit 3ed9334

File tree

3 files changed

+231
-1
lines changed

3 files changed

+231
-1
lines changed

.github/container/Dockerfile.mjx

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# syntax=docker/dockerfile:1-labs
2+
3+
ARG BASE_IMAGE=ghcr.io/nvidia/jax:mealkit
4+
ARG SRC_PATH_MUJOCO=/opt/mujoco
5+
6+
###############################################################################
7+
## Download source and add auxiliary scripts
8+
###############################################################################
9+
10+
FROM ${BASE_IMAGE} as mealkit
11+
ARG SRC_PATH_MUJOCO
12+
RUN <<"EOF" bash -ex
13+
get-source.sh -l mujoco -m ${MANIFEST_FILE} -b $(dirname ${SRC_PATH_MUJOCO})
14+
echo "-f https://py.mujoco.org/" >> /opt/pip-tools.d/requirements-mjx.in
15+
echo "-e file://${SRC_PATH_MUJOCO}/mjx" >> /opt/pip-tools.d/requirements-mjx.in
16+
EOF
17+
18+
19+
###############################################################################
20+
## Install accumulated packages from the base image and the previous stage
21+
###############################################################################
22+
23+
FROM mealkit as final
24+
25+
RUN pip-finalize.sh

.github/container/manifest.yaml

+6-1
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,9 @@ haliax:
126126
url: https://github.com/stanford-crfm/haliax.git
127127
tracking_ref: main
128128
latest_verified_commit: 0f29c95eea05ed9e2d9d01c7ae48f4231cf1a57d
129-
mode: git-clone
129+
mode: git-clone
130+
mujoco:
131+
url: https://github.com/google-deepmind/mujoco.git
132+
tracking_ref: main
133+
latest_verified_commit: 4f53d9a0d7bde4b9a69994d79449dfd57a04c305
134+
mode: git-clone
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
name: MJX build
2+
run-name: MJX build - (${{ github.event_name == 'workflow_run' && format('nightly {0}', github.event.workflow_run.created_at) || github.event_name }})
3+
4+
on:
5+
push:
6+
schedule:
7+
- cron: '30 9 * * *' # Pacific Time 01:30 AM in UTC
8+
workflow_dispatch:
9+
inputs:
10+
BASE_IMAGE_AMD64:
11+
type: string
12+
description: 'JAX mealkit AMD64 imagebuilt by NVIDIA/JAX-Toolbox'
13+
default: ''
14+
required: false
15+
BASE_IMAGE_ARM64:
16+
type: string
17+
description: 'JAX mealkit AMD64 imagebuilt by NVIDIA/JAX-Toolbox'
18+
default: ''
19+
required: false
20+
PUBLISH:
21+
type: boolean
22+
description: Publish dated images and update the 'latest' tag?
23+
default: false
24+
required: false
25+
26+
27+
env:
28+
DOCKER_REGISTRY: ghcr.io/nvidia
29+
DEFAULT_BASE_IMAGE: ghcr.io/nvidia/jax:mealkit-2024-01-26
30+
31+
32+
permissions:
33+
contents: read # to fetch code
34+
actions: write # to cancel previous workflows
35+
packages: write # to upload container
36+
37+
jobs:
38+
39+
metadata:
40+
runs-on: ubuntu-22.04
41+
outputs:
42+
PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }}
43+
BASE_IMAGE_AMD64: ${{ steps.base-image.outputs.BASE_IMAGE_AMD64 }}
44+
BASE_IMAGE_ARM64: ${{ steps.base-image.outputs.BASE_IMAGE_ARM64 }}
45+
BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }}
46+
47+
steps:
48+
49+
- name: Cancel workflow if upstream workflow did not success
50+
if: ${{ steps.if-upstream-failed.outputs.UPSTREAM_FAILED == 'true' }}
51+
run: |
52+
echo "Upstream workflow failed, cancelling this workflow"
53+
curl -X POST -H "Authorization: token ${{ github.token }}" \
54+
-H "Accept: application/vnd.github.v3+json" \
55+
"https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/cancel"
56+
cat # blocks execution in case workflow cancellation takes time
57+
58+
- name: Determine if the resulting container should be 'published'
59+
id: if-publish
60+
shell: bash -x -e {0}
61+
run:
62+
# A container should be published if:
63+
# 1) the workflow is triggered by workflow_dispatch and the PUBLISH input is true, or
64+
# 2) the workflow is triggered by workflow_run (i.e., a nightly build)
65+
echo "PUBLISH=${{ github.event_name == 'workflow_run' || (github.event_name == 'workflow_dispatch' && inputs.PUBLISH) }}" >> $GITHUB_OUTPUT
66+
67+
- name: Set build date
68+
id: date
69+
shell: bash -x -e {0}
70+
run: |
71+
BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d')
72+
echo "BUILD_DATE=${BUILD_DATE}" >> $GITHUB_OUTPUT
73+
74+
- name: Set base image
75+
id: base-image
76+
shell: bash -x -e {0}
77+
run: |
78+
if [[ -z "${{ inputs.BASE_IMAGE }}" ]]; then
79+
BASE_IMAGE_AMD64=${{ env.DEFAULT_BASE_IMAGE }}
80+
BASE_IMAGE_ARM64=${{ env.DEFAULT_BASE_IMAGE }}
81+
else
82+
BASE_IMAGE_AMD64=${{ inputs.BASE_IMAGE_AMD64 }}
83+
BASE_IMAGE_ARM64=${{ inputs.BASE_IMAGE_ARM64 }}
84+
fi
85+
echo "BASE_IMAGE_AMD64=${BASE_IMAGE_AMD64}" >> $GITHUB_OUTPUT
86+
echo "BASE_IMAGE_ARM64=${BASE_IMAGE_ARM64}" >> $GITHUB_OUTPUT
87+
88+
amd64:
89+
needs: metadata
90+
uses: ./.github/workflows/_build.yaml
91+
with:
92+
ARCHITECTURE: amd64
93+
ARTIFACT_NAME: artifact-mjx-build
94+
BADGE_FILENAME: badge-mjx-build
95+
BASE_IMAGE: ${{ needs.metadata.outputs.BASE_IMAGE_AMD64 }}
96+
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
97+
CONTAINER_NAME: mjx
98+
DOCKERFILE: .github/container/Dockerfile.mjx
99+
secrets: inherit
100+
101+
arm64:
102+
needs: metadata
103+
uses: ./.github/workflows/_build.yaml
104+
with:
105+
ARCHITECTURE: arm64
106+
ARTIFACT_NAME: artifact-mjx-build
107+
BADGE_FILENAME: badge-mjx-build
108+
BASE_IMAGE: ${{ needs.metadata.outputs.BASE_IMAGE_ARM64 }}
109+
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
110+
CONTAINER_NAME: mjx
111+
DOCKERFILE: .github/container/Dockerfile.mjx
112+
secrets: inherit
113+
114+
publish-mealkit:
115+
needs: [metadata, amd64, arm64]
116+
if: false
117+
#if: needs.metadata.outputs.PUBLISH == 'true'
118+
uses: ./.github/workflows/_publish_container.yaml
119+
with:
120+
SOURCE_IMAGE: |
121+
${{ needs.amd64.outputs.DOCKER_TAG_MEALKIT }}
122+
${{ needs.arm64.outputs.DOCKER_TAG_MEALKIT }}
123+
TARGET_IMAGE: jax
124+
TARGET_TAGS: |
125+
type=raw,value=mjx-mealkit,priority=500
126+
type=raw,value=mjx-mealkit-${{ needs.metadata.outputs.BUILD_DATE }},priority=500
127+
128+
publish-final:
129+
needs: [metadata, amd64, arm64]
130+
if: false
131+
#if: needs.metadata.outputs.PUBLISH == 'true'
132+
uses: ./.github/workflows/_publish_container.yaml
133+
with:
134+
SOURCE_IMAGE: |
135+
${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}
136+
${{ needs.arm64.outputs.DOCKER_TAG_FINAL }}
137+
TARGET_IMAGE: jax
138+
TARGET_TAGS: |
139+
type=raw,value=mjx-latest,priority=1000
140+
type=raw,value=mjx-nightly-${{ needs.metadata.outputs.BUILD_DATE }},priority=900
141+
142+
# small perf tests
143+
runner:
144+
uses: ./.github/workflows/_runner_ondemand_slurm.yaml
145+
with:
146+
NAME: "A100-${{ github.run_id }}"
147+
LABELS: "A100:${{ github.run_id }}"
148+
TIME: "01:00:00"
149+
secrets: inherit
150+
151+
mjx-unit-test:
152+
needs: amd64
153+
strategy:
154+
fail-fast: false
155+
matrix:
156+
GPU_ARCH: [A100]
157+
# ensures A100 job lands on dedicated runner for this particular job
158+
runs-on: [self-hosted, "${{ matrix.GPU_ARCH == 'A100' && format('{0}:{1}', matrix.GPU_ARCH, github.run_id) || matrix.GPU_ARCH }}"]
159+
steps:
160+
- name: Print environment variables
161+
run: env
162+
163+
- name: Print GPU information
164+
run: nvidia-smi
165+
166+
- name: Check out repository
167+
uses: actions/checkout@v3
168+
169+
- name: Login to GitHub Container Registry
170+
uses: docker/login-action@v2
171+
with:
172+
registry: ghcr.io
173+
username: ${{ github.repository_owner }}
174+
password: ${{ secrets.GITHUB_TOKEN }}
175+
176+
- name: Pull MJX image
177+
shell: bash -x -e {0}
178+
run: |
179+
docker pull ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }}
180+
181+
- name: MJX speed test
182+
shell: bash -x -e {0}
183+
continue-on-error: true
184+
run: |
185+
docker run --gpus=all --shm-size=1g ${{ needs.amd64.outputs.DOCKER_TAG_FINAL }} bash -ec "mjx-testspeed --mjcf=humanoid/humanoid.xml --batch_size=8192 --unroll=4 --output=tsv" | tee -a test-mjx.log
186+
187+
- name: Save perf to summary
188+
shell: bash -x -e {0}
189+
continue-on-error: true
190+
run: |
191+
SUMMARY_PATTERN="^mjx-testspeed"
192+
SUMMARY=$(cat test-mjx.log | grep "$SUMMARY_PATTERN")
193+
echo "${SUMMARY}" | tee -a $GITHUB_STEP_SUMMARY
194+
195+
- name: Upload artifacts
196+
uses: actions/upload-artifact@v3
197+
with:
198+
name: ${{ env.DEFAULT_ARTIFACT_NAME }}-${{ matrix.GPU_ARCH }}
199+
path: |
200+
test-mjx.log

0 commit comments

Comments
 (0)