@@ -265,7 +265,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createMemoryPrecisi
265
265
// unrolled (head dimension vastly exceeds head block dimension).
266
266
if (lowPrecisionIntermediates) {
267
267
memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
268
- memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; // GEMMOperandPrecision:: BF16;
268
+ memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::BF16;
269
269
} else {
270
270
memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
271
271
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
@@ -356,7 +356,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
356
356
// The register precision of L/D only counts for backward key-value.
357
357
if (lowPrecisionIntermediates) {
358
358
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
359
- registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
359
+ registerPrecisions[AttentionOperand::D] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
360
360
} else {
361
361
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
362
362
registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
@@ -383,7 +383,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
383
383
registerPrecisions[AttentionOperand::S] = lowPrecisionInputs ? GEMMOperandPrecision::FP16 : GEMMOperandPrecision::FP32;
384
384
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP16;
385
385
registerPrecisions[AttentionOperand::dP] = GEMMOperandPrecision::FP32;
386
- registerPrecisions[AttentionOperand::dS] = GEMMOperandPrecision::FP32;
386
+ registerPrecisions[AttentionOperand::dS] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
387
387
} else {
388
388
registerPrecisions[AttentionOperand::S] = GEMMOperandPrecision::FP32;
389
389
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP32;
0 commit comments