Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for leading_zeros, trailing_zeros and fix count_ones #213

Merged
merged 15 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 3 additions & 30 deletions crates/rustc_codegen_spirv/src/builder/ext_inst.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
use crate::custom_insts;
use rspirv::dr::Operand;
use rspirv::spirv::{GLOp, Word};
use rspirv::{dr::Operand, spirv::Capability};

const GLSL_STD_450: &str = "GLSL.std.450";

Expand All @@ -13,7 +13,6 @@ pub struct ExtInst {
custom: Option<Word>,

glsl: Option<Word>,
integer_functions_2_intel: bool,
}

impl ExtInst {
Expand All @@ -38,32 +37,11 @@ impl ExtInst {
id
}
}

pub fn require_integer_functions_2_intel(&mut self, bx: &Builder<'_, '_>, to_zombie: Word) {
if !self.integer_functions_2_intel {
self.integer_functions_2_intel = true;
if !bx
.builder
.has_capability(Capability::IntegerFunctions2INTEL)
{
bx.zombie(to_zombie, "capability IntegerFunctions2INTEL is required");
}
if !bx
.builder
.has_extension(bx.sym.spv_intel_shader_integer_functions2)
{
bx.zombie(
to_zombie,
"extension SPV_INTEL_shader_integer_functions2 is required",
);
}
}
}
}

impl<'a, 'tcx> Builder<'a, 'tcx> {
pub fn custom_inst(
&mut self,
&self,
result_type: Word,
inst: custom_insts::CustomInst<Operand>,
) -> SpirvValue {
Expand All @@ -80,12 +58,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.with_type(result_type)
}

pub fn gl_op(
&mut self,
op: GLOp,
result_type: Word,
args: impl AsRef<[SpirvValue]>,
) -> SpirvValue {
pub fn gl_op(&self, op: GLOp, result_type: Word, args: impl AsRef<[SpirvValue]>) -> SpirvValue {
let args = args.as_ref();
let glsl = self.ext_inst.borrow_mut().import_glsl(self);
self.emit()
Expand Down
283 changes: 245 additions & 38 deletions crates/rustc_codegen_spirv/src/builder/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,46 +211,15 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
self.rotate(val, shift, is_left)
}

// TODO: Do we want to manually implement these instead of using intel instructions?
sym::ctlz | sym::ctlz_nonzero => {
let result = self
.emit()
.u_count_leading_zeros_intel(
args[0].immediate().ty,
None,
args[0].immediate().def(self),
)
.unwrap();
self.ext_inst
.borrow_mut()
.require_integer_functions_2_intel(self, result);
result.with_type(args[0].immediate().ty)
}
sym::cttz | sym::cttz_nonzero => {
let result = self
.emit()
.u_count_trailing_zeros_intel(
args[0].immediate().ty,
None,
args[0].immediate().def(self),
)
.unwrap();
self.ext_inst
.borrow_mut()
.require_integer_functions_2_intel(self, result);
result.with_type(args[0].immediate().ty)
sym::ctlz => self.count_leading_trailing_zeros(args[0].immediate(), false, false),
sym::ctlz_nonzero => {
self.count_leading_trailing_zeros(args[0].immediate(), false, true)
}
sym::cttz => self.count_leading_trailing_zeros(args[0].immediate(), true, false),
sym::cttz_nonzero => self.count_leading_trailing_zeros(args[0].immediate(), true, true),

sym::ctpop => self
.emit()
.bit_count(args[0].immediate().ty, None, args[0].immediate().def(self))
.unwrap()
.with_type(args[0].immediate().ty),
sym::bitreverse => self
.emit()
.bit_reverse(args[0].immediate().ty, None, args[0].immediate().def(self))
.unwrap()
.with_type(args[0].immediate().ty),
sym::ctpop => self.count_ones(args[0].immediate()),
sym::bitreverse => self.bit_reverse(args[0].immediate()),
sym::bswap => {
// https://github.com/KhronosGroup/SPIRV-LLVM/pull/221/files
// TODO: Definitely add tests to make sure this impl is right.
Expand Down Expand Up @@ -398,6 +367,244 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
}

impl Builder<'_, '_> {
pub fn count_ones(&self, arg: SpirvValue) -> SpirvValue {
let ty = arg.ty;
match self.cx.lookup_type(ty) {
SpirvType::Integer(bits, false) => {
let u32 = SpirvType::Integer(32, false).def(self.span(), self);

match bits {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a style thing, but I generally feel it is clearer to matched on signed as well, it makes all the cases clearer, similar to what you did in bitcast.

8 | 16 => {
let arg = arg.def(self);
let arg = self.emit().u_convert(u32, None, arg).unwrap();
self.emit().bit_count(u32, None, arg).unwrap()
}
32 => self.emit().bit_count(u32, None, arg.def(self)).unwrap(),
64 => {
let u32_32 = self.constant_u32(self.span(), 32).def(self);
let arg = arg.def(self);
let lower = self.emit().u_convert(u32, None, arg).unwrap();
let higher = self
.emit()
.shift_right_logical(ty, None, arg, u32_32)
.unwrap();
let higher = self.emit().u_convert(u32, None, higher).unwrap();

let lower_bits = self.emit().bit_count(u32, None, lower).unwrap();
let higher_bits = self.emit().bit_count(u32, None, higher).unwrap();
self.emit()
.i_add(u32, None, lower_bits, higher_bits)
.unwrap()
}
_ => {
let undef = self.undef(ty).def(self);
self.zombie(
undef,
&format!("count_ones() on unsupported {ty:?} bit integer type"),
);
undef
}
}
.with_type(u32)
}
_ => self.fatal(format!(
"count_ones() expected an unsigned integer type, got {:?}",
self.cx.lookup_type(ty)
)),
}
}

pub fn bit_reverse(&self, arg: SpirvValue) -> SpirvValue {
let ty = arg.ty;
match self.cx.lookup_type(ty) {
SpirvType::Integer(bits, false) => {
let u32 = SpirvType::Integer(32, false).def(self.span(), self);
let uint = SpirvType::Integer(bits, false).def(self.span(), self);

match bits {
8 | 16 => {
let arg = arg.def(self);
let arg = self.emit().u_convert(u32, None, arg).unwrap();

let reverse = self.emit().bit_reverse(u32, None, arg).unwrap();
let shift = self.constant_u32(self.span(), 32 - bits).def(self);
let reverse = self
.emit()
.shift_right_logical(u32, None, reverse, shift)
.unwrap();
self.emit().u_convert(uint, None, reverse).unwrap()
}
32 => self.emit().bit_reverse(u32, None, arg.def(self)).unwrap(),
64 => {
let u32_32 = self.constant_u32(self.span(), 32).def(self);
let arg = arg.def(self);
let lower = self.emit().u_convert(u32, None, arg).unwrap();
let higher = self
.emit()
.shift_right_logical(ty, None, arg, u32_32)
.unwrap();
let higher = self.emit().u_convert(u32, None, higher).unwrap();

// note that higher and lower have swapped
let higher_bits = self.emit().bit_reverse(u32, None, lower).unwrap();
let lower_bits = self.emit().bit_reverse(u32, None, higher).unwrap();

let higher_bits = self.emit().u_convert(uint, None, higher_bits).unwrap();
let higher_bits = self
.emit()
.shift_left_logical(uint, None, higher_bits, u32_32)
.unwrap();
let lower_bits = self.emit().u_convert(uint, None, lower_bits).unwrap();

self.emit()
.bitwise_or(ty, None, lower_bits, higher_bits)
.unwrap()
}
_ => {
let undef = self.undef(ty).def(self);
self.zombie(
undef,
&format!("bit_reverse() on unsupported {ty:?} bit integer type"),
);
undef
}
}
.with_type(ty)
}
_ => self.fatal(format!(
"bit_reverse() expected an unsigned integer type, got {:?}",
self.cx.lookup_type(ty)
)),
}
}

pub fn count_leading_trailing_zeros(
&self,
arg: SpirvValue,
trailing: bool,
non_zero: bool,
) -> SpirvValue {
let ty = arg.ty;
match self.cx.lookup_type(ty) {
SpirvType::Integer(bits, false) => {
let bool = SpirvType::Bool.def(self.span(), self);
let u32 = SpirvType::Integer(32, false).def(self.span(), self);

let glsl = self.ext_inst.borrow_mut().import_glsl(self);
let find_xsb = |arg| {
if trailing {
self.emit()
.ext_inst(u32, None, glsl, GLOp::FindILsb as u32, [Operand::IdRef(
arg,
)])
.unwrap()
} else {
// rust is always unsigned, so FindUMsb
let msb_bit = self
.emit()
.ext_inst(u32, None, glsl, GLOp::FindUMsb as u32, [Operand::IdRef(
arg,
)])
.unwrap();
// the glsl op returns the Msb bit, not the amount of leading zeros of this u32
// leading zeros = 31 - Msb bit
let u32_31 = self.constant_u32(self.span(), 31).def(self);
self.emit().i_sub(u32, None, u32_31, msb_bit).unwrap()
}
};

let converted = match bits {
8 | 16 => {
if trailing {
let arg = self.emit().u_convert(u32, None, arg.def(self)).unwrap();
find_xsb(arg)
} else {
let arg = arg.def(self);
let arg = self.emit().u_convert(u32, None, arg).unwrap();
let xsb = find_xsb(arg);
let subtrahend = self.constant_u32(self.span(), 32 - bits).def(self);
self.emit().i_sub(u32, None, xsb, subtrahend).unwrap()
}
}
32 => find_xsb(arg.def(self)),
64 => {
let u32_0 = self.constant_int(u32, 0).def(self);
let u32_32 = self.constant_u32(self.span(), 32).def(self);

let arg = arg.def(self);
let lower = self.emit().u_convert(u32, None, arg).unwrap();
let higher = self
.emit()
.shift_right_logical(ty, None, arg, u32_32)
.unwrap();
let higher = self.emit().u_convert(u32, None, higher).unwrap();

let lower_bits = find_xsb(lower);
let higher_bits = find_xsb(higher);

if trailing {
let use_lower = self.emit().i_equal(bool, None, higher, u32_0).unwrap();
let lower_bits =
self.emit().i_add(u32, None, lower_bits, u32_32).unwrap();
self.emit()
.select(u32, None, use_lower, lower_bits, higher_bits)
.unwrap()
} else {
let use_higher = self.emit().i_equal(bool, None, lower, u32_0).unwrap();
let higher_bits =
self.emit().i_add(u32, None, higher_bits, u32_32).unwrap();
self.emit()
.select(u32, None, use_higher, higher_bits, lower_bits)
.unwrap()
}
}
_ => {
let undef = self.undef(ty).def(self);
self.zombie(undef, &format!(
"count_leading_trailing_zeros() on unsupported {ty:?} bit integer type"
));
undef
}
};

if non_zero {
converted
} else {
let int_0 = self.constant_int(ty, 0).def(self);
let int_bits = self.constant_int(u32, bits as u128).def(self);
let is_0 = self
.emit()
.i_equal(bool, None, arg.def(self), int_0)
.unwrap();
self.emit()
.select(u32, None, is_0, int_bits, converted)
.unwrap()
}
.with_type(u32)
}
SpirvType::Integer(bits, true) => {
// rustc wants `[i8,i16,i32,i64]::leading_zeros()` with `non_zero: true` for some reason. I do not know
// how these are reachable, marking them as zombies makes none of our compiletests fail.
let unsigned = SpirvType::Integer(bits, false).def(self.span(), self);
let arg = self
.emit()
.bitcast(unsigned, None, arg.def(self))
.unwrap()
.with_type(unsigned);
let result = self.count_leading_trailing_zeros(arg, trailing, non_zero);
self.emit()
.bitcast(ty, None, result.def(self))
.unwrap()
.with_type(ty)
}
e => {
self.fatal(format!(
"count_leading_trailing_zeros(trailing: {trailing}, non_zero: {non_zero}) expected an integer type, got {e:?}",
));
}
}
}

pub fn abort_with_kind_and_message_debug_printf(
&mut self,
kind: &str,
Expand Down
Loading