Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support shared encoding defined with linear layout #5720

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

binarman
Copy link
Contributor

@binarman binarman commented Jan 27, 2025

This PR enables basic support for linear layout used as a shared encoding.

This change is needed to support custom memory layout introduced in #4984

This PR enables basic support for linear layout used as a shared encoding.
@binarman
Copy link
Contributor Author

+cc @lezcano

AMD team wants to experiment with non standard swizzling patterns(for example #4984), so I had an idea to use LinearEncodingAttr as a memory layout. While implementation I've noticed that it inherits DistributedEncoding.

What do you think about renaming LinearEncodingAttr to DistributedLinearEncodingAttr and introduction of new MemoryLinearEncodingAttr encoding?

@masahi
Copy link
Contributor

masahi commented Jan 27, 2025

I'm interested in this discussion as well. The NV lowering path heavily relies on the legacy encoding of the SMEM layout, so going all in on using LL for SMEM at the IR level is difficult. Moreover, how are you going to query swizzling properties of SMEM represented only via LL?

@Jokeren
Copy link
Contributor

Jokeren commented Jan 27, 2025

I'm interested in this discussion as well. The NV lowering path heavily relies on the legacy encoding of the SMEM layout, so going all in on using LL for SMEM at the IR level is difficult. Moreover, how are you going to query swizzling properties of SMEM represented only via LL?

We can query the first few bases of the offset dimension to see if they are contiguous and check if the rest bases do not overlap with them.

There's indeed an algorithm we plan to implement for the ldmatrix path that won't rely on the standard shared encoding

@Jokeren
Copy link
Contributor

Jokeren commented Jan 27, 2025

AMD path might be much actually simpler since it doesn't slice shared memory and doesn't have ldmatrix/stmatrix instructions.

The only concern is that checking if base names have "offset" to determine a shared encoding isn't a solid solution.

@lezcano
Copy link
Contributor

lezcano commented Jan 28, 2025

@masahi just bumped into this quite recently as well, and as @Jokeren has mentioned, we'll start using new shmem layouts sooner than later.

The issue with SharedLayouts is that they don't have an API like the one DistributedLayouts have. They just have very few attributes that are rather unique to their own structure. In general, a characterisation of the shared memory layout that we care about may be given by is a LinearLayout for which al its bases have at most 2 bits equal to one (popc(b) <= 2). Also, the only case in which we have broadcasting (a basis that is a zero) is the mxfp4 in BW case (we don't model it with a basis equal to zero yet, but we'll do so in the future). In all the other cases these LLs are bijective maps onto their domain.

Taking this structure into account, the tricky part here is to create an API that's returns the relevant properties we need to work with this layout. This would be the equivalent to

SmallVector<unsigned> basesPerDimImpl(const LinearLayout::BasesT &namedBases,
StringAttr dimName, size_t rank,
bool skipBroadcast = true) {
const auto &bases = namedBases.find(dimName)->second;
if (bases.empty()) {
return SmallVector<unsigned>(rank, 1);
}
SmallVector<unsigned> ret(rank, 1);
auto nonZero = [](auto val) { return val != 0; };
int nonZeroIdx = 0;
for (const auto &basis : bases) {
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
// Bases can have one or zero non-zero elements
// Skip a basis if it's broadcasting (all zeros)
// e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout)
if (it != basis.end()) {
nonZeroIdx = it - basis.begin();
ret[nonZeroIdx] *= 2;
} else if (!skipBroadcast) {
// If we've seen a non-zero basis, we double the size of the previous dim
// This is just needed to count the CTAsPerCGA
ret[nonZeroIdx] *= 2;
}
}
return ret;
}
SmallVector<unsigned>
LinearEncodingAttr::basesPerDim(StringAttr dimName, bool skipBroadcast) const {
auto ll = getLinearLayout();
auto rank = ll.getNumOutDims();
return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast);
}
SmallVector<unsigned>
LinearEncodingAttr::orderPerDim(StringAttr dimName,
ArrayRef<unsigned> defaultOrder) const {
auto ll = getLinearLayout();
const auto &bases = ll.getBases().find(dimName)->second;
llvm::SetVector<unsigned> order;
auto nonZero = [](auto val) { return val != 0; };
for (const auto &basis : bases) {
// Bases can have one or zero non-zero elements
// Skip a basis if it's broadcasting (all zeros)
// e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout)
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
if (it != basis.end()) {
auto i = it - basis.begin();
order.insert(i);
}
}
// If any dim is missing, we add them in the defaultOrder
for (auto i : defaultOrder) {
order.insert(i);
}
return SmallVector<unsigned>(order.begin(), order.end());
}
// [Note. Divergence of methods wrt. legacy layouts]
// For smaller shapes where the CTATile is larger than the output
// tensor, some methods return different values than the legacy layouts. I think
// this is benign tho. An example: what is the the vector of `warpsPerCTA` if
// all the warps hold the same data? I think it should be [1, 1], even if we
// have 4 warps. But perhaps for this we have to add some masking in some
// places... We'll see
SmallVector<unsigned> LinearEncodingAttr::getRepOrder() const {
// This is not correct, but:
// - It happens to agree in most places with the legacy layout
// - getRepOrder does not make sense for LinearEncodingAttr as it already has
// the same shape as the tensor that uses it
return getOrder();
}
SmallVector<unsigned> LinearEncodingAttr::getCTAsPerCGA() const {
// CTAs are split into an identity part (SplitNum) and a broadcast part
return basesPerDim(StringAttr::get(getContext(), "block"),
/*skipBroadcast=*/false);
}
SmallVector<unsigned> LinearEncodingAttr::getCTAOrder() const {
return orderPerDim(StringAttr::get(getContext(), "block"), getOrder());
}
SmallVector<unsigned> LinearEncodingAttr::getCTASplitNum() const {
return basesPerDim(StringAttr::get(getContext(), "block"));
}
SmallVector<unsigned> LinearEncodingAttr::getWarpsPerCTA() const {
return basesPerDim(StringAttr::get(getContext(), "warp"));
}
SmallVector<unsigned> LinearEncodingAttr::getWarpOrder() const {
return orderPerDim(StringAttr::get(getContext(), "warp"), getOrder());
}
SmallVector<unsigned> LinearEncodingAttr::getThreadsPerWarp() const {
return basesPerDim(StringAttr::get(getContext(), "lane"));
}
SmallVector<unsigned> LinearEncodingAttr::getThreadOrder() const {
return orderPerDim(StringAttr::get(getContext(), "lane"), getOrder());
}
SmallVector<unsigned> LinearEncodingAttr::getSizePerThread() const {
auto rank = getRepOrder().size();
auto ll = getLinearLayout();
auto ctx = getContext();
auto kRegister = StringAttr::get(ctx, "register");
// We canonicalize on the spot, as if we use CGAs the regs are not in
// canonical form The order is [reg, lane, warp, rep, block], so we first
// remove the blocks
llvm::SmallVector<unsigned> ctaShape;
for (auto [shape, cgaNum] :
llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) {
ctaShape.push_back(shape / cgaNum);
}
LinearLayout::BasesT bases = ll.getBases();
llvm::SetVector<unsigned> reverseRepOrder;
auto nonZero = [](auto val) { return val != 0; };
auto &registers = bases[StringAttr::get(ctx, "register")];
while (!registers.empty()) {
auto &basis = registers.back();
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
// If there's broadcasting (base == zeros) there are no more reps
if (it == basis.end()) {
break;
}
auto dim = it - basis.begin();
reverseRepOrder.insert(dim);
// As soon as we stop finding reps, we stop
if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) {
break;
}
ctaShape[dim] /= 2;
registers.pop_back();
}
return basesPerDimImpl(bases, kRegister, rank);
}
SmallVector<unsigned> LinearEncodingAttr::getOrder() const {
auto rank = getLinearLayout().getNumOutDims();
SmallVector<unsigned> order(rank);
// Choose [rank-1, rank-2, ... 0] as the default order in case
// there are dims that do not move in the register
// This order is as good as any really
std::iota(order.rbegin(), order.rend(), 0);
return orderPerDim(StringAttr::get(getContext(), "register"), order);
}
LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
auto ll = getLinearLayout();
auto canonicalDims = llvm::to_vector(ll.getOutDimNames());
llvm::SmallDenseMap<StringAttr, int64_t> namedShape;
llvm::SmallVector<StringAttr> permutedDims;
for (auto dim : getRepOrder()) {
permutedDims.push_back(canonicalDims[dim]);
namedShape[canonicalDims[dim]] = shape[dim];
}
ll = ll.transposeOuts(permutedDims);
ll = ensureLayoutNotSmallerThan(ll, namedShape);
ll = ensureLayoutNotLargerThan(ll, namedShape, /*broadcastRegisters=*/false);
ll = ll.transposeOuts(canonicalDims);
return ll;
}
SmallVector<unsigned>
LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
// When broadcasting the layout the shape changes, otherwise the shape is
// the same as the shape of the tensor
// We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
// the invariant that the shape of the LL is that of the tensor
// We choose the former for BC
auto scaledLayout = get(getContext(), toLinearLayout(shape));
auto kRegister = StringAttr::get(getContext(), "register");
return scaledLayout.basesPerDim(kRegister, /*skipBroadcast=*/false);
}

This API is going to be very different to the one from LinearEncodingAttr, so they should be separate, and I expect it to be also a bit different to the one in SharedEncoding. As such, it would be best to create a parent class (LinearSharedEncodingAttr?) that just holds a LinearLayout and make SharedEncoding inherit from it. From here, the way forward would not be simply to implement getVec, getPerPhase, getMaxPhase etc, but actually change passes, aux functions, etc to support LinearSharedEncodingAttr in its full generality (i.e., make them to work with an arbitrary invertible (probably easiest not to think about the mxfp4 case for now) linear layout with popc(b) in (1, 2). You'll have to come up with your own API for this, which may be tricky.

This is not the easiest task, but it's what we'll need to tackle if we want to support generic swizzled layouts. If this is a bit too much in one go, that's fine. Another way to go is to define another subclass of SharedEncodings that are less general than the one I described above, and start implementing that one and adding support across the codebase for it. That is what @masahi is set to do with mxfp4 for BW.

I'll probably make some strides towards tackling the general case next month anyway.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants