@@ -98,11 +98,11 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   pool->drain();
   auto kernel = pipelineValue->kernel;
   auto pipeline = pipelineValue->pipeline;
-  // Allocate a new command.
+  // Allocate a new command.
   auto encoder = command_batch->startCommand();
   encoder->setComputePipelineState(pipeline.get());
   encoder->setThreadgroupMemoryLength(kernel->threadgroupMemoryAllocation, 0);
-
+
   // Bind the function arguments.
   encoder->useResource(tensors[0], MTL::ResourceUsageRead);
   encoder->useResource(tensors[1], MTL::ResourceUsageRead);
@@ -146,17 +146,37 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
       encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::L).bufferIndex());
     }
   }
-
+
   MTL::Size gridSize
-    (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]),
-     hash.Hq,
-     attentionDesc.batchDimension);
+    (ceilDivide(int64_t(hash.R), kernel->blockDimensions[0]), 1, 1);
   MTL::Size groupSize
     (int64_t(kernel->threadgroupSize), 1, 1);
-
-  // Dispatch the required number of threads.
-  encoder->dispatchThreadgroups(gridSize, groupSize);
-
+
+  const size_t bytesPerElement = attentionDesc.lowPrecisionInputs ? sizeof(uint16_t) : sizeof(float);
+  for (int i = 0; i < attentionDesc.batchDimension; i++) {
+    for (int j = 0; j < hash.Hq; j++) {
+      encoder->setBufferOffset(tensor_offsets[0] + bytesPerElement * (i * hash.R * hash.D * hash.Hq + j * hash.D), AttentionOperand(AttentionOperand::Q).bufferIndex());
+      encoder->setBufferOffset(tensor_offsets[1] + bytesPerElement * (i * hash.C * hash.D * hash.Hk + j * hash.D), AttentionOperand(AttentionOperand::K).bufferIndex());
+      encoder->setBufferOffset(tensor_offsets[2] + bytesPerElement * (i * hash.C * hash.D * hash.Hk + j * hash.D), AttentionOperand(AttentionOperand::V).bufferIndex());
+      if (attentionDesc.lowPrecisionInputs) {
+        encoder->setBufferOffset(sizeof(float) * (i * hash.R * hash.D * hash.Hq + j * hash.D), AttentionOperand(AttentionOperand::O).bufferIndex());
+        if (tensors[5]) {
+          encoder->setBufferOffset(tensor_offsets[5] + sizeof(float) * (i * hash.R * hash.Hq + j * hash.R), AttentionOperand(AttentionOperand::L).bufferIndex());
+        } else {
+          encoder->setBufferOffset(sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension + sizeof(float) * (i * hash.R * hash.Hq + j * hash.R), AttentionOperand(AttentionOperand::L).bufferIndex());
+        }
+      } else {
+        encoder->setBufferOffset(tensor_offsets[3] + sizeof(float) * (i * hash.R * hash.D * hash.Hq + j * hash.D), AttentionOperand(AttentionOperand::O).bufferIndex());
+        if (tensors[5]) {
+          encoder->setBufferOffset(tensor_offsets[5] + sizeof(float) * (i * hash.R * hash.Hq + j * hash.R), AttentionOperand(AttentionOperand::L).bufferIndex());
+        } else {
+          encoder->setBufferOffset(sizeof(float) * (i * hash.R * hash.Hq + j * hash.R), AttentionOperand(AttentionOperand::L).bufferIndex());
+        }
+      }
+      // Dispatch the required number of threads.
+      encoder->dispatchThreadgroups(gridSize, groupSize);
+    }
+  }
   // Finish the command.
   command_batch->finishCommand(encoder);
   if (attentionDesc.lowPrecisionInputs) {
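Note on the new dispatch loop: instead of spanning heads and batches in the grid, the encoder now rebases the Q/K/V/O/L buffer offsets and issues one dispatch per (batch, head) pair. Below is a minimal sketch of the byte-offset arithmetic that loop relies on, assuming Q/O are laid out as [batch, R, Hq, D] and K/V as [batch, C, Hk, D], packed contiguously; the helper names and standalone signatures are illustrative only, not part of the commit.

#include <cstddef>

// Byte offset of the first element owned by (batch i, head j) in a tensor laid
// out as [batch, seqLen, heads, headDim]; this is the same product as
// i * seqLen * headDim * heads + j * headDim used for Q in the loop above.
static size_t headOffsetBytes(size_t i, size_t j,
                              size_t seqLen, size_t heads, size_t headDim,
                              size_t bytesPerElement) {
  return bytesPerElement * (i * seqLen * heads * headDim + j * headDim);
}

// In the low-precision path, when no dedicated L tensor (tensors[5]) is bound,
// L is placed in the scratch buffer immediately after the FP32 staging copy of
// O, which is why that branch adds the R * D * Hq * batchDimension base term.
static size_t scratchLOffsetBytes(size_t i, size_t j,
                                  size_t R, size_t Hq, size_t D,
                                  size_t batchDimension) {
  const size_t o_region = sizeof(float) * R * D * Hq * batchDimension;
  return o_region + sizeof(float) * (i * R * Hq + j * R);
}

For example, the Q offset bound for batch i and head j in the loop corresponds to tensor_offsets[0] + headOffsetBytes(i, j, hash.R, hash.Hq, hash.D, bytesPerElement).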