Skip to content

Commit c0c5879

Browse files
committed
count_ones: fix bit_reverse, must be u32-only in vulkan
1 parent 0759f28 commit c0c5879

File tree

1 file changed

+73
-5
lines changed

1 file changed

+73
-5
lines changed

crates/rustc_codegen_spirv/src/builder/intrinsics.rs

+73-5
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,7 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
219219
sym::cttz_nonzero => self.count_leading_trailing_zeros(args[0].immediate(), true, true),
220220

221221
sym::ctpop => self.count_ones(args[0].immediate()),
222-
sym::bitreverse => self
223-
.emit()
224-
.bit_reverse(args[0].immediate().ty, None, args[0].immediate().def(self))
225-
.unwrap()
226-
.with_type(args[0].immediate().ty),
222+
sym::bitreverse => self.bit_reverse(args[0].immediate()),
227223
sym::bswap => {
228224
// https://github.com/KhronosGroup/SPIRV-LLVM/pull/221/files
229225
// TODO: Definitely add tests to make sure this impl is right.
@@ -418,6 +414,78 @@ impl Builder<'_, '_> {
418414
_ => self.fatal("count_ones on a non-integer type"),
419415
}
420416
}
417+
pub fn bit_reverse(&self, arg: SpirvValue) -> SpirvValue {
418+
let ty = arg.ty;
419+
match self.cx.lookup_type(ty) {
420+
SpirvType::Integer(bits, signed) => {
421+
let u32 = SpirvType::Integer(32, false).def(self.span(), self);
422+
let uint = SpirvType::Integer(bits, false).def(self.span(), self);
423+
424+
match (bits, signed) {
425+
(8 | 16, signed) => {
426+
let arg = arg.def(self);
427+
let arg = if signed {
428+
self.emit().bitcast(uint, None, arg).unwrap()
429+
} else {
430+
arg
431+
};
432+
let arg = self.emit().u_convert(u32, None, arg).unwrap();
433+
434+
let reverse = self.emit().bit_reverse(u32, None, arg).unwrap();
435+
let shift = self.constant_u32(self.span(), 32 - bits).def(self);
436+
let reverse = self.emit().shift_right_logical(u32, None, reverse, shift).unwrap();
437+
let reverse = self.emit().u_convert(uint, None, reverse).unwrap();
438+
if signed {
439+
self.emit().bitcast(ty, None, reverse).unwrap()
440+
} else {
441+
reverse
442+
}
443+
}
444+
(32, false) => self.emit().bit_reverse(u32, None, arg.def(self)).unwrap(),
445+
(32, true) => {
446+
let arg = self.emit().bitcast(u32, None, arg.def(self)).unwrap();
447+
let reverse = self.emit().bit_reverse(u32, None, arg).unwrap();
448+
self.emit().bitcast(ty, None, reverse).unwrap()
449+
},
450+
(64, signed) => {
451+
let u32_32 = self.constant_u32(self.span(), 32).def(self);
452+
let arg = arg.def(self);
453+
let lower = self.emit().s_convert(u32, None, arg).unwrap();
454+
let higher = self
455+
.emit()
456+
.shift_left_logical(ty, None, arg, u32_32)
457+
.unwrap();
458+
let higher = self.emit().s_convert(u32, None, higher).unwrap();
459+
460+
// note that higher and lower have swapped
461+
let higher_bits = self.emit().bit_reverse(u32, None, lower).unwrap();
462+
let lower_bits = self.emit().bit_reverse(u32, None, higher).unwrap();
463+
464+
let higher_bits = self.emit().u_convert(uint, None, higher_bits).unwrap();
465+
let shift = self.constant_u32(self.span(), 32).def(self);
466+
let higher_bits = self.emit().shift_right_logical(uint, None, higher_bits, shift).unwrap();
467+
let lower_bits = self.emit().u_convert(uint, None, lower_bits).unwrap();
468+
469+
let result = self.emit().bitwise_or(ty, None, lower_bits, higher_bits).unwrap();
470+
if signed {
471+
self.emit().bitcast(ty, None, result).unwrap()
472+
} else {
473+
result
474+
}
475+
}
476+
_ => {
477+
let undef = self.undef(ty).def(self);
478+
self.zombie(undef, &format!(
479+
"counting leading / trailing zeros on unsupported {ty:?} bit integer type"
480+
));
481+
undef
482+
}
483+
}
484+
.with_type(ty)
485+
}
486+
_ => self.fatal("count_ones on a non-integer type"),
487+
}
488+
}
421489

422490
pub fn count_leading_trailing_zeros(
423491
&self,

0 commit comments

Comments
 (0)