Skip to content

Commit de59d3e

Browse files
committed
Switch D to BF16; this also needs to be gated for macOS 14.
1 parent fe99c48 commit de59d3e

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

lib/nnc/mfa/v2/AttentionDescriptor.cpp

+3-3
Original file line number | Diff line number | Diff line change
@@ -265,7 +265,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createMemoryPrecisi
265265
// unrolled (head dimension vastly exceeds head block dimension).
266266
if (lowPrecisionIntermediates) {
267267
memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
268-
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; // GEMMOperandPrecision::BF16;
268+
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::BF16;
269269
} else {
270270
memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
271271
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
@@ -356,7 +356,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
356356
// The register precision of L/D only counts for backward key-value.
357357
if (lowPrecisionIntermediates) {
358358
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
359-
registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
359+
registerPrecisions[AttentionOperand::D] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
360360
} else {
361361
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
362362
registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
@@ -383,7 +383,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
383383
registerPrecisions[AttentionOperand::S] = lowPrecisionInputs ? GEMMOperandPrecision::FP16 : GEMMOperandPrecision::FP32;
384384
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP16;
385385
registerPrecisions[AttentionOperand::dP] = GEMMOperandPrecision::FP32;
386-
registerPrecisions[AttentionOperand::dS] = GEMMOperandPrecision::FP32;
386+
registerPrecisions[AttentionOperand::dS] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
387387
} else {
388388
registerPrecisions[AttentionOperand::S] = GEMMOperandPrecision::FP32;
389389
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP32;

lib/nnc/mfa/v2/AttentionKernel.cpp

+18-1
Original file line number | Diff line number | Diff line change
@@ -395,7 +395,24 @@ unsigned short AttentionKernel::createThreadgroupMemoryAllocation() const noexce
395395
std::string AttentionKernel::createSource() const noexcept {
396396
CodeWriter source;
397397

398-
bool injectBF16Methods = (memoryPrecisions[AttentionOperand::Q] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::K] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::S] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::P] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::V] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::O] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::L] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::D] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dO] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dV] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dP] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dS] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dK] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dQ] == GEMMOperandPrecision::BF16);
398+
bool injectBF16Methods = false;
399+
switch (type.value) {
400+
case AttentionKernelType::forward:
401+
if ((memoryPrecisions[AttentionOperand::Q] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::K] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::S] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::P] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::V] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::O] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::L] == GEMMOperandPrecision::BF16)) {
402+
injectBF16Methods = true;
403+
}
404+
break;
405+
case AttentionKernelType::backwardQuery:
406+
if ((memoryPrecisions[AttentionOperand::Q] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::K] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::S] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::P] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::V] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::O] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::L] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::D] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dO] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dP] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dS] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dQ] == GEMMOperandPrecision::BF16)) {
407+
injectBF16Methods = true;
408+
}
409+
break;
410+
case AttentionKernelType::backwardKeyValue:
411+
if ((memoryPrecisions[AttentionOperand::Q] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::K] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::S] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::P] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::V] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::O] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::L] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::D] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dO] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dV] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dP] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dS] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dK] == GEMMOperandPrecision::BF16)) {
412+
injectBF16Methods = true;
413+
}
414+
break;
415+
}
399416

400417
// Inject the contents of the headers.
401418
source += createMetalSimdgroupEvent() + "\n";

0 commit comments

Comments
 (0)