Commit 97c0893

Add AMD GPU test for ray clusters

1 parent b612ce3, commit 97c0893
6 files changed: +97 -52 lines

Diff for: go.mod (+1 -1)

@@ -11,7 +11,7 @@ require (
 	github.com/openshift/api v0.0.0-20240904015708-69df64132c91
 	github.com/openshift/client-go v0.0.0-20240904130219-3795e907a202
 	github.com/project-codeflare/appwrapper v1.0.4
-	github.com/project-codeflare/codeflare-common v0.0.0-20250306164418-eb812487be82
+	github.com/project-codeflare/codeflare-common v0.0.0-20250317102908-1c124db97844
 	github.com/ray-project/kuberay/ray-operator v1.2.2
 	go.uber.org/zap v1.27.0
 	golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56
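Note: this codeflare-common bump is what supplies the typed accelerator support the tests below switch to (the CPU, NVIDIA and AMD values, the Accelerator parameter type, and GetRayROCmImage()). The exact definition lives upstream; the following is only a minimal sketch consistent with how this diff uses it, with the concrete Type values assumed rather than confirmed:

package support

// Accelerator is sketched from its uses in this diff: Type feeds the
// ACCELERATOR env var consumed by the training script, and ResourceLabel is
// the Kubernetes extended-resource name used in quotas, requests and
// tolerations.
type Accelerator struct {
	Type          string // e.g. "cpu" or "gpu" (assumed values)
	ResourceLabel string // e.g. "nvidia.com/gpu", "amd.com/gpu"
}

var (
	CPU    = Accelerator{Type: "cpu"}
	NVIDIA = Accelerator{Type: "gpu", ResourceLabel: "nvidia.com/gpu"}
	AMD    = Accelerator{Type: "gpu", ResourceLabel: "amd.com/gpu"}
)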

Diff for: go.sum (+2 -2)

@@ -225,8 +225,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/project-codeflare/appwrapper v1.0.4 h1:364zQLX0tsi4LvBBYNKZL7PPbNWPbVU7vK6+/kVV/FQ=
 github.com/project-codeflare/appwrapper v1.0.4/go.mod h1:A1b6bMFNMX5Btv3ckgeuAHVVZzp1G30pSBe6BE/xJWE=
-github.com/project-codeflare/codeflare-common v0.0.0-20250306164418-eb812487be82 h1:cL1K2+r1lJVwBkhXiVFr2A9DphnylJmilYDIqg/W62M=
-github.com/project-codeflare/codeflare-common v0.0.0-20250306164418-eb812487be82/go.mod h1:DPSv5khRiRDFUD43SF8da+MrVQTWmxNhuKJmwSLOyO0=
+github.com/project-codeflare/codeflare-common v0.0.0-20250317102908-1c124db97844 h1:hEjZ2pV4Fp81wytijJZ7uHWovKIqirVBA/t1F5hIrbA=
+github.com/project-codeflare/codeflare-common v0.0.0-20250317102908-1c124db97844/go.mod h1:DPSv5khRiRDFUD43SF8da+MrVQTWmxNhuKJmwSLOyO0=
 github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
 github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=

Diff for: test/e2e/deployment_appwrapper_test.go (+1 -1)

@@ -45,7 +45,7 @@ func TestDeploymentAppWrapper(t *testing.T) {
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ResourceFlavors().Delete(test.Ctx(), resourceFlavor.Name, metav1.DeleteOptions{})
 	}()
-	clusterQueue := createClusterQueue(test, resourceFlavor, 0)
+	clusterQueue := createClusterQueue(test, resourceFlavor, 0, CPU)
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
 	}()

Diff for: test/e2e/job_appwrapper_test.go (+1 -1)

@@ -43,7 +43,7 @@ func TestBatchJobAppWrapper(t *testing.T) {
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ResourceFlavors().Delete(test.Ctx(), resourceFlavor.Name, metav1.DeleteOptions{})
 	}()
-	clusterQueue := createClusterQueue(test, resourceFlavor, 0)
+	clusterQueue := createClusterQueue(test, resourceFlavor, 0, CPU)
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
 	}()
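Note: the two hunks above are the same mechanical change: every existing createClusterQueue call gains an explicit Accelerator argument, with the CPU-only suites pinning CPU. A hedged sketch of the call matrix the new parameter enables (helpers and values are the ones this diff introduces; the function body appears in mnist_rayjob_raycluster_test.go below):

// Illustration only, inside the e2e test package.
func exampleClusterQueues(test Test, resourceFlavor *v1beta1.ResourceFlavor) {
	_ = createClusterQueue(test, resourceFlavor, 0, CPU)    // CPU suites: zero GPU quota
	_ = createClusterQueue(test, resourceFlavor, 1, NVIDIA) // CUDA suites: covers nvidia.com/gpu
	_ = createClusterQueue(test, resourceFlavor, 1, AMD)    // ROCm suites: covers amd.com/gpu instead
}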

Diff for: test/e2e/mnist_pytorch_appwrapper_test.go (+5 -5)

@@ -32,15 +32,15 @@ import (
 )

 func TestMnistPyTorchAppWrapperCpu(t *testing.T) {
-	runMnistPyTorchAppWrapper(t, "cpu", 0)
+	runMnistPyTorchAppWrapper(t, CPU, 0)
 }

 func TestMnistPyTorchAppWrapperGpu(t *testing.T) {
-	runMnistPyTorchAppWrapper(t, "gpu", 1)
+	runMnistPyTorchAppWrapper(t, NVIDIA, 1)
 }

 // Trains the MNIST dataset as a batch Job in an AppWrapper, and asserts successful completion of the training job.
-func runMnistPyTorchAppWrapper(t *testing.T, accelerator string, numberOfGpus int) {
+func runMnistPyTorchAppWrapper(t *testing.T, accelerator Accelerator, numberOfGpus int) {
 	test := With(t)

 	// Create a namespace
@@ -51,7 +51,7 @@ func runMnistPyTorchAppWrapper(t *testing.T, accelerator string, numberOfGpus in
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ResourceFlavors().Delete(test.Ctx(), resourceFlavor.Name, metav1.DeleteOptions{})
 	}()
-	clusterQueue := createClusterQueue(test, resourceFlavor, numberOfGpus)
+	clusterQueue := createClusterQueue(test, resourceFlavor, numberOfGpus, accelerator)
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
 	}()
@@ -109,7 +109,7 @@ func runMnistPyTorchAppWrapper(t *testing.T, accelerator string, numberOfGpus in
 			{Name: "MNIST_DATASET_URL", Value: GetMnistDatasetURL()},
 			{Name: "PIP_INDEX_URL", Value: GetPipIndexURL()},
 			{Name: "PIP_TRUSTED_HOST", Value: GetPipTrustedHost()},
-			{Name: "ACCELERATOR", Value: accelerator},
+			{Name: "ACCELERATOR", Value: accelerator.Type},
 		},
 		Command: []string{"/bin/sh", "-c", "pip install -r /test/requirements.txt && torchrun /test/mnist.py"},
 		VolumeMounts: []corev1.VolumeMount{

Diff for: test/e2e/mnist_rayjob_raycluster_test.go (+87 -42)
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"net/http"
 	"net/url"
+	"strings"
 	"testing"

 	. "github.com/onsi/gomega"
@@ -40,14 +41,18 @@ import (
 // directly managed by Kueue, and asserts successful completion of the training job.

 func TestMnistRayJobRayClusterCpu(t *testing.T) {
-	runMnistRayJobRayCluster(t, "cpu", 0)
+	runMnistRayJobRayCluster(t, CPU, 0, GetRayImage())
 }

-func TestMnistRayJobRayClusterGpu(t *testing.T) {
-	runMnistRayJobRayCluster(t, "gpu", 1)
+func TestMnistRayJobRayClusterCudaGpu(t *testing.T) {
+	runMnistRayJobRayCluster(t, NVIDIA, 1, GetRayImage())
 }

-func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int) {
+func TestMnistRayJobRayClusterROCmGpu(t *testing.T) {
+	runMnistRayJobRayCluster(t, AMD, 1, GetRayROCmImage())
+}
+
+func runMnistRayJobRayCluster(t *testing.T, accelerator Accelerator, numberOfGpus int, rayImage string) {
 	test := With(t)

 	// Create a static namespace to ensure a consistent Ray Dashboard hostname entry in /etc/hosts before executing the test.
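Note: with the accelerator and image now plain parameters, the three entry points above could equally be expressed as one table-driven test. A sketch of that consolidation (not part of this commit), using only names introduced in this diff:

func TestMnistRayJobRayClusterMatrix(t *testing.T) {
	// Hypothetical alternative to the three wrapper tests above.
	cases := []struct {
		name  string
		acc   Accelerator
		gpus  int
		image string
	}{
		{"Cpu", CPU, 0, GetRayImage()},
		{"CudaGpu", NVIDIA, 1, GetRayImage()},
		{"ROCmGpu", AMD, 1, GetRayROCmImage()},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			runMnistRayJobRayCluster(t, tc.acc, tc.gpus, tc.image)
		})
	}
}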
@@ -58,11 +63,12 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ResourceFlavors().Delete(test.Ctx(), resourceFlavor.Name, metav1.DeleteOptions{})
 	}()
-	clusterQueue := createClusterQueue(test, resourceFlavor, numberOfGpus)
+
+	clusterQueue := createClusterQueue(test, resourceFlavor, numberOfGpus, accelerator)
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
 	}()
-	CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name, AsDefaultQueue)
+	localQueue := CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name, AsDefaultQueue)

 	// Create MNIST training script
 	mnist := constructMNISTConfigMap(test, namespace)
@@ -71,7 +77,7 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
 	test.T().Logf("Created ConfigMap %s/%s successfully", mnist.Namespace, mnist.Name)

 	// Create RayCluster and assign it to the localqueue
-	rayCluster := constructRayCluster(test, namespace, mnist, numberOfGpus)
+	rayCluster := constructRayCluster(test, namespace, localQueue.Name, mnist, numberOfGpus, accelerator, rayImage)
 	rayCluster, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Create(test.Ctx(), rayCluster, metav1.CreateOptions{})
 	test.Expect(err).NotTo(HaveOccurred())
 	test.T().Logf("Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name)
@@ -81,7 +87,7 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
 		Should(WithTransform(RayClusterState, Equal(rayv1.Ready)))

 	// Create RayJob
-	rayJob := constructRayJob(test, namespace, rayCluster, accelerator, numberOfGpus)
+	rayJob := constructRayJob(test, namespace, rayCluster, accelerator, numberOfGpus, rayImage)
 	rayJob, err = test.Client().Ray().RayV1().RayJobs(namespace.Name).Create(test.Ctx(), rayJob, metav1.CreateOptions{})
 	test.Expect(err).NotTo(HaveOccurred())
 	test.T().Logf("Created RayJob %s/%s successfully", rayJob.Namespace, rayJob.Name)
@@ -110,15 +116,19 @@ func runMnistRayJobRayCluster(t *testing.T, accelerator string, numberOfGpus int
 }

 func TestMnistRayJobRayClusterAppWrapperCpu(t *testing.T) {
-	runMnistRayJobRayClusterAppWrapper(t, "cpu", 0)
+	runMnistRayJobRayClusterAppWrapper(t, CPU, 0, GetRayImage())
+}
+
+func TestMnistRayJobRayClusterAppWrapperCudaGpu(t *testing.T) {
+	runMnistRayJobRayClusterAppWrapper(t, NVIDIA, 1, GetRayImage())
 }

-func TestMnistRayJobRayClusterAppWrapperGpu(t *testing.T) {
-	runMnistRayJobRayClusterAppWrapper(t, "gpu", 1)
+func TestMnistRayJobRayClusterAppWrapperROCmGpu(t *testing.T) {
+	runMnistRayJobRayClusterAppWrapper(t, AMD, 1, GetRayROCmImage())
 }

 // Same as TestMNISTRayJobRayCluster, except the RayCluster is wrapped in an AppWrapper
-func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, numberOfGpus int) {
+func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator Accelerator, numberOfGpus int, rayImage string) {
 	test := With(t)

 	// Create a static namespace to ensure a consistent Ray Dashboard hostname entry in /etc/hosts before executing the test.
@@ -129,7 +139,7 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ResourceFlavors().Delete(test.Ctx(), resourceFlavor.Name, metav1.DeleteOptions{})
 	}()
-	clusterQueue := createClusterQueue(test, resourceFlavor, numberOfGpus)
+	clusterQueue := createClusterQueue(test, resourceFlavor, numberOfGpus, accelerator)
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
 	}()
@@ -142,7 +152,7 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number
 	test.T().Logf("Created ConfigMap %s/%s successfully", mnist.Namespace, mnist.Name)

 	// Create RayCluster, wrap in AppWrapper and assign to localqueue
-	rayCluster := constructRayCluster(test, namespace, mnist, numberOfGpus)
+	rayCluster := constructRayCluster(test, namespace, localQueue.Name, mnist, numberOfGpus, accelerator, rayImage)
 	raw := Raw(test, rayCluster)
 	raw = RemoveCreationTimestamp(test, raw)

@@ -183,7 +193,7 @@ func runMnistRayJobRayClusterAppWrapper(t *testing.T, accelerator string, number
 		Should(WithTransform(RayClusterState, Equal(rayv1.Ready)))

 	// Create RayJob
-	rayJob := constructRayJob(test, namespace, rayCluster, accelerator, numberOfGpus)
+	rayJob := constructRayJob(test, namespace, rayCluster, accelerator, numberOfGpus, rayImage)
 	rayJob, err = test.Client().Ray().RayV1().RayJobs(namespace.Name).Create(test.Ctx(), rayJob, metav1.CreateOptions{})
 	test.Expect(err).NotTo(HaveOccurred())
 	test.T().Logf("Created RayJob %s/%s successfully", rayJob.Namespace, rayJob.Name)
@@ -223,11 +233,11 @@ func TestRayClusterImagePullSecret(t *testing.T) {
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ResourceFlavors().Delete(test.Ctx(), resourceFlavor.Name, metav1.DeleteOptions{})
 	}()
-	clusterQueue := createClusterQueue(test, resourceFlavor, 0)
+	clusterQueue := createClusterQueue(test, resourceFlavor, 0, CPU)
 	defer func() {
 		_ = test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
 	}()
-	CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name, AsDefaultQueue)
+	localQueue := CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name, AsDefaultQueue)

 	// Create MNIST training script
 	mnist := constructMNISTConfigMap(test, namespace)
@@ -236,7 +246,7 @@ func TestRayClusterImagePullSecret(t *testing.T) {
 	test.T().Logf("Created ConfigMap %s/%s successfully", mnist.Namespace, mnist.Name)

 	// Create RayCluster with imagePullSecret and assign it to the localqueue
-	rayCluster := constructRayCluster(test, namespace, mnist, 0)
+	rayCluster := constructRayCluster(test, namespace, localQueue.Name, mnist, 0, CPU, GetRayImage())
 	rayCluster.Spec.HeadGroupSpec.Template.Spec.ImagePullSecrets = []corev1.LocalObjectReference{{Name: "custom-pull-secret"}}
 	rayCluster, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Create(test.Ctx(), rayCluster, metav1.CreateOptions{})
 	test.Expect(err).NotTo(HaveOccurred())
@@ -266,15 +276,18 @@ func constructMNISTConfigMap(test Test, namespace *corev1.Namespace) *corev1.Con
 	}
 }

-func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.ConfigMap, numberOfGpus int) *rayv1.RayCluster {
-	return &rayv1.RayCluster{
+func constructRayCluster(_ Test, namespace *corev1.Namespace, localQueueName string, mnist *corev1.ConfigMap, numberOfGpus int, accelerator Accelerator, rayImage string) *rayv1.RayCluster {
+	raycluster := rayv1.RayCluster{
 		TypeMeta: metav1.TypeMeta{
 			APIVersion: rayv1.GroupVersion.String(),
 			Kind:       "RayCluster",
 		},
 		ObjectMeta: metav1.ObjectMeta{
 			Name:      "raycluster",
 			Namespace: namespace.Name,
+			Labels: map[string]string{
+				"kueue.x-k8s.io/queue-name": localQueueName,
+			},
 		},
 		Spec: rayv1.RayClusterSpec{
 			RayVersion: GetRayVersion(),
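Note: the new kueue.x-k8s.io/queue-name label is the standard Kueue mechanism for binding a workload to a LocalQueue; previously the RayCluster presumably relied on the LocalQueue created AsDefaultQueue. A minimal, runnable illustration (namespace and queue names are placeholders, not the test's values):

package main

import (
	"fmt"

	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func main() {
	// The label Kueue inspects when admitting the RayCluster workload.
	meta := metav1.ObjectMeta{
		Name:      "raycluster",
		Namespace: "test-ns-rayjob", // placeholder
		Labels:    map[string]string{"kueue.x-k8s.io/queue-name": "local-queue"}, // placeholder queue name
	}
	fmt.Println(meta.Labels["kueue.x-k8s.io/queue-name"])
}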
@@ -287,7 +300,7 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
 					Containers: []corev1.Container{
 						{
 							Name:  "ray-head",
-							Image: GetRayImage(),
+							Image: rayImage,
 							Ports: []corev1.ContainerPort{
 								{
 									ContainerPort: 6379,
@@ -342,7 +355,7 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
 					Containers: []corev1.Container{
 						{
 							Name:  "ray-worker",
-							Image: GetRayImage(),
+							Image: rayImage,
 							Lifecycle: &corev1.Lifecycle{
 								PreStop: &corev1.LifecycleHandler{
 									Exec: &corev1.ExecAction{
@@ -352,14 +365,14 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
 							},
 							Resources: corev1.ResourceRequirements{
 								Requests: corev1.ResourceList{
-									corev1.ResourceCPU:    resource.MustParse("250m"),
-									corev1.ResourceMemory: resource.MustParse("1G"),
-									"nvidia.com/gpu":      resource.MustParse(fmt.Sprint(numberOfGpus)),
+									corev1.ResourceCPU:                    resource.MustParse("250m"),
+									corev1.ResourceMemory:                 resource.MustParse("1G"),
+									corev1.ResourceName("nvidia.com/gpu"): resource.MustParse(fmt.Sprint(numberOfGpus)),
 								},
 								Limits: corev1.ResourceList{
-									corev1.ResourceCPU:    resource.MustParse("2"),
-									corev1.ResourceMemory: resource.MustParse("4G"),
-									"nvidia.com/gpu":      resource.MustParse(fmt.Sprint(numberOfGpus)),
+									corev1.ResourceCPU:                    resource.MustParse("2"),
+									corev1.ResourceMemory:                 resource.MustParse("4G"),
+									corev1.ResourceName("nvidia.com/gpu"): resource.MustParse(fmt.Sprint(numberOfGpus)),
 								},
 							},
 							VolumeMounts: []corev1.VolumeMount{
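Note: the Requests/Limits hunk above changes no behaviour. corev1.ResourceList is map[corev1.ResourceName]resource.Quantity, so the untyped "nvidia.com/gpu" key already converted implicitly; the explicit corev1.ResourceName(...) form simply matches the delete() and index operations the AMD branch below performs. A runnable check:

package main

import (
	"fmt"

	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"
)

func main() {
	rl := corev1.ResourceList{
		corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("1"),
	}
	// The untyped string constant refers to the same key as the typed form.
	delete(rl, "nvidia.com/gpu")
	fmt.Println(len(rl)) // 0
}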
@@ -388,9 +401,37 @@ func constructRayCluster(_ Test, namespace *corev1.Namespace, mnist *corev1.Conf
 			},
 		},
 	}
+
+	if accelerator.ResourceLabel == "amd.com/gpu" {
+		// Remove the nvidia.com/gpu resource
+		delete(raycluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests, corev1.ResourceName("nvidia.com/gpu"))
+		delete(raycluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Limits, corev1.ResourceName("nvidia.com/gpu"))
+
+		// update with amd.com/gpu resource
+		raycluster.Spec.WorkerGroupSpecs[0].Template.Spec.Tolerations[0].Key = "amd.com/gpu"
+		raycluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Requests[corev1.ResourceName("amd.com/gpu")] = resource.MustParse(fmt.Sprint(numberOfGpus))
+		raycluster.Spec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources.Limits[corev1.ResourceName("amd.com/gpu")] = resource.MustParse(fmt.Sprint(numberOfGpus))
+	}
+
+	return &raycluster
 }

-func constructRayJob(_ Test, namespace *corev1.Namespace, rayCluster *rayv1.RayCluster, accelerator string, numberOfGpus int) *rayv1.RayJob {
+func constructRayJob(_ Test, namespace *corev1.Namespace, rayCluster *rayv1.RayCluster, accelerator Accelerator, numberOfGpus int, rayImage string) *rayv1.RayJob {
+	pipPackages := []string{
+		"pytorch_lightning==2.4.0",
+		"torchmetrics==1.6.0",
+		"torchvision==0.19.1",
+	}
+
+	// Append AMD-specific packages
+	if accelerator.ResourceLabel == "amd.com/gpu" {
+		pipPackages = append(pipPackages,
+			"--extra-index-url https://download.pytorch.org/whl/rocm6.1",
+			"torch==2.4.1+rocm6.1",
+		)
+	}
+
+	// Construct RayJob with the final pip list
 	return &rayv1.RayJob{
 		TypeMeta: metav1.TypeMeta{
 			APIVersion: rayv1.GroupVersion.String(),
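Note: the AMD branch above mutates the already-built spec (hence the switch from returning a literal to building raycluster and returning &raycluster): it deletes the nvidia.com/gpu entries, rewrites the toleration key, and inserts amd.com/gpu with the same count. A possible generalisation, not in this commit and purely hypothetical, that avoids naming either vendor twice:

package e2e

import (
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"
)

// swapExtendedResource is a sketch: it moves one extended resource to another
// name on a container, mirroring the request/limit handling used above.
func swapExtendedResource(c *corev1.Container, from, to corev1.ResourceName, qty resource.Quantity) {
	delete(c.Resources.Requests, from)
	delete(c.Resources.Limits, from)
	if c.Resources.Requests == nil {
		c.Resources.Requests = corev1.ResourceList{}
	}
	if c.Resources.Limits == nil {
		c.Resources.Limits = corev1.ResourceList{}
	}
	c.Resources.Requests[to] = qty
	c.Resources.Limits[to] = qty
}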
@@ -402,17 +443,15 @@ func constructRayJob(_ Test, namespace *corev1.Namespace, rayCluster *rayv1.RayC
 		},
 		Spec: rayv1.RayJobSpec{
 			Entrypoint: "python /home/ray/jobs/mnist.py",
-			RuntimeEnvYAML: `
-pip:
- - pytorch_lightning==2.4.0
- - torchmetrics==1.6.0
- - torchvision==0.20.1
-env_vars:
- MNIST_DATASET_URL: "` + GetMnistDatasetURL() + `"
- PIP_INDEX_URL: "` + GetPipIndexURL() + `"
- PIP_TRUSTED_HOST: "` + GetPipTrustedHost() + `"
- ACCELERATOR: "` + accelerator + `"
-`,
+			RuntimeEnvYAML: fmt.Sprintf(`
+pip:
+ - %s
+env_vars:
+ MNIST_DATASET_URL: "%s"
+ PIP_INDEX_URL: "%s"
+ PIP_TRUSTED_HOST: "%s"
+ ACCELERATOR: "%s"
+`, strings.Join(pipPackages, "\n - "), GetMnistDatasetURL(), GetPipIndexURL(), GetPipTrustedHost(), accelerator.Type),
 			ClusterSelector: map[string]string{
 				RayJobDefaultClusterSelectorKey: rayCluster.Name,
 			},
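Note: the fmt.Sprintf/strings.Join combination renders each extra package, including the --extra-index-url option line, as its own YAML list item, which pip accepts requirements-style; the torchvision pin also moves from 0.20.1 to 0.19.1 alongside the explicit torch 2.4.1+rocm6.1 pin. A runnable reproduction of the template for the AMD case (the URL and accelerator type here are placeholders, not the values the Get* helpers return):

package main

import (
	"fmt"
	"strings"
)

func main() {
	pipPackages := []string{
		"pytorch_lightning==2.4.0",
		"torchmetrics==1.6.0",
		"torchvision==0.19.1",
		"--extra-index-url https://download.pytorch.org/whl/rocm6.1",
		"torch==2.4.1+rocm6.1",
	}
	fmt.Printf(`
pip:
 - %s
env_vars:
 MNIST_DATASET_URL: "%s"
 ACCELERATOR: "%s"
`, strings.Join(pipPackages, "\n - "), "https://example.invalid/mnist", "gpu")
}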
@@ -422,7 +461,7 @@ func constructRayJob(_ Test, namespace *corev1.Namespace, rayCluster *rayv1.RayC
 					RestartPolicy: corev1.RestartPolicyNever,
 					Containers: []corev1.Container{
 						{
-							Image: GetRayImage(),
+							Image: rayImage,
 							Name:  "rayjob-submitter-pod",
 						},
 					},
@@ -477,7 +516,7 @@ func getRayDashboardURL(test Test, namespace, rayClusterName string) string {
 }

 // Create ClusterQueue
-func createClusterQueue(test Test, resourceFlavor *v1beta1.ResourceFlavor, numberOfGpus int) *v1beta1.ClusterQueue {
+func createClusterQueue(test Test, resourceFlavor *v1beta1.ResourceFlavor, numberOfGpus int, accelerator Accelerator) *v1beta1.ClusterQueue {
 	cqSpec := v1beta1.ClusterQueueSpec{
 		NamespaceSelector: &metav1.LabelSelector{},
 		ResourceGroups: []v1beta1.ResourceGroup{
@@ -505,5 +544,11 @@ func createClusterQueue(test Test, resourceFlavor *v1beta1.ResourceFlavor, numbe
 			},
 		},
 	}
+
+	if accelerator.ResourceLabel == "amd.com/gpu" {
+		cqSpec.ResourceGroups[0].CoveredResources[2] = corev1.ResourceName(accelerator.ResourceLabel)
+		cqSpec.ResourceGroups[0].Flavors[0].Resources[2].Name = corev1.ResourceName(accelerator.ResourceLabel)
+	}
+
 	return CreateKueueClusterQueue(test, cqSpec)
 }
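Note: the net effect of the branch above is that, for AMD, entry [2] of both CoveredResources and the flavor's Resources flips from nvidia.com/gpu to amd.com/gpu. The positional index assumes the GPU entry stays third in the fixture; a lookup by name would survive reordering. Sketch of the resulting spec (flavor name and quotas are placeholders; the real ones come from the part of createClusterQueue not shown in this hunk):

// Hypothetical post-mutation shape, e2e test package context assumed.
cqSpec := v1beta1.ClusterQueueSpec{
	NamespaceSelector: &metav1.LabelSelector{},
	ResourceGroups: []v1beta1.ResourceGroup{{
		CoveredResources: []corev1.ResourceName{
			corev1.ResourceCPU,
			corev1.ResourceMemory,
			corev1.ResourceName("amd.com/gpu"), // index 2, was "nvidia.com/gpu"
		},
		Flavors: []v1beta1.FlavorQuotas{{
			Name: v1beta1.ResourceFlavorReference("default-flavor"), // placeholder
			Resources: []v1beta1.ResourceQuota{
				{Name: corev1.ResourceCPU, NominalQuota: resource.MustParse("8")},       // placeholder quota
				{Name: corev1.ResourceMemory, NominalQuota: resource.MustParse("12Gi")}, // placeholder quota
				{Name: corev1.ResourceName("amd.com/gpu"), NominalQuota: resource.MustParse("1")},
			},
		}},
	}},
}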
