Skip to content

Commit a7a9099

Browse files
committed
Add 64 bit atomics
1 parent 6040820 commit a7a9099

File tree

14 files changed

+327
-7
lines changed

14 files changed

+327
-7
lines changed

Diff for: CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ By @cwfitzgerald in [#5325](https://github.com/gfx-rs/wgpu/pull/5325).
109109

110110
By @ErichDonGubler in [#5146](https://github.com/gfx-rs/wgpu/pull/5146), [#5046](https://github.com/gfx-rs/wgpu/pull/5046).
111111
- Signed and unsigned 64 bit integer support in shaders. By @rodolphito and @cwfitzgerald in [#5154](https://github.com/gfx-rs/wgpu/pull/5154)
112+
- 64 bit integer atomic support in shaders. By @rodolphito and @JMS55 in [#5381](https://github.com/gfx-rs/wgpu/pull/5381)
112113
- `wgpu::Instance` can now report which `wgpu::Backends` are available based on the build configuration. By @wumpf [#5167](https://github.com/gfx-rs/wgpu/pull/5167)
113114
```diff
114115
-wgpu::Instance::any_backend_feature_enabled()

Diff for: naga/src/valid/mod.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ bitflags::bitflags! {
7777
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
7878
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
7979
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
80-
pub struct Capabilities: u16 {
80+
pub struct Capabilities: u32 {
8181
/// Support for [`AddressSpace:PushConstant`].
8282
const PUSH_CONSTANT = 0x1;
8383
/// Float values with width = 8.
@@ -110,6 +110,8 @@ bitflags::bitflags! {
110110
const CUBE_ARRAY_TEXTURES = 0x4000;
111111
/// Support for 64-bit signed and unsigned integers.
112112
const SHADER_INT64 = 0x8000;
113+
/// Support for 64-bit signed and unsigned integers.
114+
const SHADER_INT64_ATOMIC = 0x10000;
113115
}
114116
}
115117

Diff for: naga/src/valid/type.rs

+19-6
Original file line numberDiff line numberDiff line change
@@ -353,16 +353,29 @@ impl super::Validator {
353353
)
354354
}
355355
Ti::Atomic(crate::Scalar { kind, width }) => {
356-
let good = match kind {
356+
match kind {
357357
crate::ScalarKind::Bool
358358
| crate::ScalarKind::Float
359359
| crate::ScalarKind::AbstractInt
360-
| crate::ScalarKind::AbstractFloat => false,
361-
crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4,
360+
| crate::ScalarKind::AbstractFloat => {
361+
return Err(TypeError::InvalidAtomicWidth(kind, width))
362+
}
363+
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
364+
if width == 8 {
365+
if self
366+
.capabilities
367+
.contains(Capabilities::SHADER_INT64_ATOMIC)
368+
{
369+
} else {
370+
return Err(TypeError::MissingCapability(
371+
Capabilities::SHADER_INT64_ATOMIC,
372+
));
373+
}
374+
} else if width != 4 {
375+
return Err(TypeError::InvalidAtomicWidth(kind, width));
376+
}
377+
}
362378
};
363-
if !good {
364-
return Err(TypeError::InvalidAtomicWidth(kind, width));
365-
}
366379
TypeInfo::new(
367380
TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE,
368381
Alignment::from_width(width),

Diff for: naga/tests/in/atomicCompareExchange-int64.param.ron

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
(
2+
god_mode: true,
3+
spv: (
4+
version: (1, 0),
5+
capabilities: [ Int64, Int64ImageEXT, Int64Atomics ],
6+
),
7+
hlsl: (
8+
shader_model: V6_0,
9+
binding_map: {},
10+
fake_missing_bindings: true,
11+
special_constants_binding: Some((space: 1, register: 0)),
12+
push_constants_target: Some((space: 0, register: 0)),
13+
zero_initialize_workgroup_memory: true,
14+
),
15+
msl: (
16+
lang_version: (3, 1),
17+
per_entry_point_map: {},
18+
inline_samplers: [],
19+
spirv_cross_compatibility: false,
20+
fake_missing_bindings: true,
21+
zero_initialize_workgroup_memory: true,
22+
),
23+
)

Diff for: naga/tests/in/atomicCompareExchange-int64.wgsl

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
const SIZE: u64 = 128u;
2+
3+
@group(0) @binding(0)
4+
var<storage,read_write> arr_i64: array<atomic<i64>, SIZE>;
5+
@group(0) @binding(1)
6+
var<storage,read_write> arr_u64: array<atomic<u64>, SIZE>;
7+
8+
@compute @workgroup_size(1)
9+
fn test_atomic_compare_exchange_i64() {
10+
for(var i = 0u; i < SIZE; i++) {
11+
var old = atomicLoad(&arr_i64[i]);
12+
var exchanged = false;
13+
while(!exchanged) {
14+
let new_ = bitcast<i64>(bitcast<f32>(old) + 1.0);
15+
let result = atomicCompareExchangeWeak(&arr_i64[i], old, new_);
16+
old = result.old_value;
17+
exchanged = result.exchanged;
18+
}
19+
}
20+
}
21+
22+
@compute @workgroup_size(1)
23+
fn test_atomic_compare_exchange_u64() {
24+
for(var i = 0u; i < SIZE; i++) {
25+
var old = atomicLoad(&arr_u64[i]);
26+
var exchanged = false;
27+
while(!exchanged) {
28+
let new_ = bitcast<u64>(bitcast<f32>(old) + 1.0);
29+
let result = atomicCompareExchangeWeak(&arr_u64[i], old, new_);
30+
old = result.old_value;
31+
exchanged = result.exchanged;
32+
}
33+
}
34+
}

Diff for: naga/tests/in/atomicOps-int64.param.ron

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
(
2+
god_mode: true,
3+
spv: (
4+
version: (1, 0),
5+
capabilities: [ Int64, Int64ImageEXT, Int64Atomics ],
6+
),
7+
hlsl: (
8+
shader_model: V6_0,
9+
binding_map: {},
10+
fake_missing_bindings: true,
11+
special_constants_binding: Some((space: 1, register: 0)),
12+
push_constants_target: Some((space: 0, register: 0)),
13+
zero_initialize_workgroup_memory: true,
14+
),
15+
msl: (
16+
lang_version: (3, 1),
17+
per_entry_point_map: {},
18+
inline_samplers: [],
19+
spirv_cross_compatibility: false,
20+
fake_missing_bindings: true,
21+
zero_initialize_workgroup_memory: true,
22+
),
23+
)

Diff for: naga/tests/in/atomicOps-int64.wgsl

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// This test covers the cross product of:
2+
//
3+
// * All atomic operations.
4+
// * On all applicable scopes (storage read-write, workgroup).
5+
// * For all shapes of modeling atomic data.
6+
7+
struct Struct {
8+
atomic_scalar: atomic<u64>,
9+
atomic_arr: array<atomic<i64>, 2>,
10+
}
11+
12+
@group(0) @binding(0)
13+
var<storage, read_write> storage_atomic_scalar: atomic<u64>;
14+
@group(0) @binding(1)
15+
var<storage, read_write> storage_atomic_arr: array<atomic<i64>, 2>;
16+
@group(0) @binding(2)
17+
var<storage, read_write> storage_struct: Struct;
18+
19+
var<workgroup> workgroup_atomic_scalar: atomic<u64>;
20+
var<workgroup> workgroup_atomic_arr: array<atomic<i64>, 2>;
21+
var<workgroup> workgroup_struct: Struct;
22+
23+
@compute
24+
@workgroup_size(2)
25+
fn cs_main(@builtin(local_invocation_id) id: vec3<u64>) {
26+
atomicStore(&storage_atomic_scalar, 1lu);
27+
atomicStore(&storage_atomic_arr[1], 1li);
28+
atomicStore(&storage_struct.atomic_scalar, 1lu);
29+
atomicStore(&storage_struct.atomic_arr[1], 1li);
30+
atomicStore(&workgroup_atomic_scalar, 1lu);
31+
atomicStore(&workgroup_atomic_arr[1], 1li);
32+
atomicStore(&workgroup_struct.atomic_scalar, 1lu);
33+
atomicStore(&workgroup_struct.atomic_arr[1], 1li);
34+
35+
workgroupBarrier();
36+
37+
let l0 = atomicLoad(&storage_atomic_scalar);
38+
let l1 = atomicLoad(&storage_atomic_arr[1]);
39+
let l2 = atomicLoad(&storage_struct.atomic_scalar);
40+
let l3 = atomicLoad(&storage_struct.atomic_arr[1]);
41+
let l4 = atomicLoad(&workgroup_atomic_scalar);
42+
let l5 = atomicLoad(&workgroup_atomic_arr[1]);
43+
let l6 = atomicLoad(&workgroup_struct.atomic_scalar);
44+
let l7 = atomicLoad(&workgroup_struct.atomic_arr[1]);
45+
46+
workgroupBarrier();
47+
48+
atomicAdd(&storage_atomic_scalar, 1lu);
49+
atomicAdd(&storage_atomic_arr[1], 1li);
50+
atomicAdd(&storage_struct.atomic_scalar, 1lu);
51+
atomicAdd(&storage_struct.atomic_arr[1], 1li);
52+
atomicAdd(&workgroup_atomic_scalar, 1lu);
53+
atomicAdd(&workgroup_atomic_arr[1], 1li);
54+
atomicAdd(&workgroup_struct.atomic_scalar, 1lu);
55+
atomicAdd(&workgroup_struct.atomic_arr[1], 1li);
56+
57+
workgroupBarrier();
58+
59+
atomicSub(&storage_atomic_scalar, 1lu);
60+
atomicSub(&storage_atomic_arr[1], 1li);
61+
atomicSub(&storage_struct.atomic_scalar, 1lu);
62+
atomicSub(&storage_struct.atomic_arr[1], 1li);
63+
atomicSub(&workgroup_atomic_scalar, 1lu);
64+
atomicSub(&workgroup_atomic_arr[1], 1li);
65+
atomicSub(&workgroup_struct.atomic_scalar, 1lu);
66+
atomicSub(&workgroup_struct.atomic_arr[1], 1li);
67+
68+
workgroupBarrier();
69+
70+
atomicMax(&storage_atomic_scalar, 1lu);
71+
atomicMax(&storage_atomic_arr[1], 1li);
72+
atomicMax(&storage_struct.atomic_scalar, 1lu);
73+
atomicMax(&storage_struct.atomic_arr[1], 1li);
74+
atomicMax(&workgroup_atomic_scalar, 1lu);
75+
atomicMax(&workgroup_atomic_arr[1], 1li);
76+
atomicMax(&workgroup_struct.atomic_scalar, 1lu);
77+
atomicMax(&workgroup_struct.atomic_arr[1], 1li);
78+
79+
workgroupBarrier();
80+
81+
atomicMin(&storage_atomic_scalar, 1lu);
82+
atomicMin(&storage_atomic_arr[1], 1li);
83+
atomicMin(&storage_struct.atomic_scalar, 1lu);
84+
atomicMin(&storage_struct.atomic_arr[1], 1li);
85+
atomicMin(&workgroup_atomic_scalar, 1lu);
86+
atomicMin(&workgroup_atomic_arr[1], 1li);
87+
atomicMin(&workgroup_struct.atomic_scalar, 1lu);
88+
atomicMin(&workgroup_struct.atomic_arr[1], 1li);
89+
90+
workgroupBarrier();
91+
92+
atomicAnd(&storage_atomic_scalar, 1lu);
93+
atomicAnd(&storage_atomic_arr[1], 1li);
94+
atomicAnd(&storage_struct.atomic_scalar, 1lu);
95+
atomicAnd(&storage_struct.atomic_arr[1], 1li);
96+
atomicAnd(&workgroup_atomic_scalar, 1lu);
97+
atomicAnd(&workgroup_atomic_arr[1], 1li);
98+
atomicAnd(&workgroup_struct.atomic_scalar, 1lu);
99+
atomicAnd(&workgroup_struct.atomic_arr[1], 1li);
100+
101+
workgroupBarrier();
102+
103+
atomicOr(&storage_atomic_scalar, 1lu);
104+
atomicOr(&storage_atomic_arr[1], 1li);
105+
atomicOr(&storage_struct.atomic_scalar, 1lu);
106+
atomicOr(&storage_struct.atomic_arr[1], 1li);
107+
atomicOr(&workgroup_atomic_scalar, 1lu);
108+
atomicOr(&workgroup_atomic_arr[1], 1li);
109+
atomicOr(&workgroup_struct.atomic_scalar, 1lu);
110+
atomicOr(&workgroup_struct.atomic_arr[1], 1li);
111+
112+
workgroupBarrier();
113+
114+
atomicXor(&storage_atomic_scalar, 1lu);
115+
atomicXor(&storage_atomic_arr[1], 1li);
116+
atomicXor(&storage_struct.atomic_scalar, 1lu);
117+
atomicXor(&storage_struct.atomic_arr[1], 1li);
118+
atomicXor(&workgroup_atomic_scalar, 1lu);
119+
atomicXor(&workgroup_atomic_arr[1], 1li);
120+
atomicXor(&workgroup_struct.atomic_scalar, 1lu);
121+
atomicXor(&workgroup_struct.atomic_arr[1], 1li);
122+
123+
atomicExchange(&storage_atomic_scalar, 1lu);
124+
atomicExchange(&storage_atomic_arr[1], 1li);
125+
atomicExchange(&storage_struct.atomic_scalar, 1lu);
126+
atomicExchange(&storage_struct.atomic_arr[1], 1li);
127+
atomicExchange(&workgroup_atomic_scalar, 1lu);
128+
atomicExchange(&workgroup_atomic_arr[1], 1li);
129+
atomicExchange(&workgroup_struct.atomic_scalar, 1lu);
130+
atomicExchange(&workgroup_struct.atomic_arr[1], 1li);
131+
132+
// // TODO: https://github.com/gpuweb/gpuweb/issues/2021
133+
// atomicCompareExchangeWeak(&storage_atomic_scalar, 1lu);
134+
// atomicCompareExchangeWeak(&storage_atomic_arr[1], 1li);
135+
// atomicCompareExchangeWeak(&storage_struct.atomic_scalar, 1lu);
136+
// atomicCompareExchangeWeak(&storage_struct.atomic_arr[1], 1li);
137+
// atomicCompareExchangeWeak(&workgroup_atomic_scalar, 1lu);
138+
// atomicCompareExchangeWeak(&workgroup_atomic_arr[1], 1li);
139+
// atomicCompareExchangeWeak(&workgroup_struct.atomic_scalar, 1lu);
140+
// atomicCompareExchangeWeak(&workgroup_struct.atomic_arr[1], 1li);
141+
}

Diff for: naga/tests/in/int64.param.ron

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
god_mode: true,
33
spv: (
44
version: (1, 0),
5+
capabilities: [ Int64 ],
56
),
67
hlsl: (
78
shader_model: V6_0,

Diff for: wgpu-core/src/device/resource.rs

+4
Original file line numberDiff line numberDiff line change
@@ -1515,6 +1515,10 @@ impl<A: HalApi> Device<A> {
15151515
Caps::SHADER_INT64,
15161516
self.features.contains(wgt::Features::SHADER_INT64),
15171517
);
1518+
caps.set(
1519+
Caps::SHADER_INT64_ATOMIC,
1520+
self.features.contains(wgt::Features::SHADER_INT64_ATOMIC),
1521+
);
15181522
caps.set(
15191523
Caps::MULTISAMPLED_SHADING,
15201524
self.downlevel

Diff for: wgpu-hal/src/dx12/adapter.rs

+17
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,23 @@ impl super::Adapter {
311311
};
312312
features.set(wgt::Features::SHADER_INT64, int64_shader_ops_supported);
313313

314+
let atomic_int64_on_typed_resource_supported = {
315+
let mut features9: crate::dx12::types::D3D12_FEATURE_DATA_D3D12_OPTIONS9 =
316+
unsafe { mem::zeroed() };
317+
let hr = unsafe {
318+
device.CheckFeatureSupport(
319+
37, // D3D12_FEATURE_D3D12_OPTIONS9
320+
&mut features9 as *mut _ as *mut _,
321+
mem::size_of::<crate::dx12::types::D3D12_FEATURE_DATA_D3D12_OPTIONS9>() as _,
322+
)
323+
};
324+
hr == 0 && features9.AtomicInt64OnTypedResourceSupported != 0
325+
};
326+
features.set(
327+
wgt::Features::SHADER_INT64_ATOMIC,
328+
atomic_int64_on_typed_resource_supported,
329+
);
330+
314331
// float32-filterable should always be available on d3d12
315332
features.set(wgt::Features::FLOAT32_FILTERABLE, true);
316333

Diff for: wgpu-hal/src/dx12/types.rs

+18
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ winapi::ENUM! {
3232
}
3333
}
3434

35+
winapi::ENUM! {
36+
enum D3D12_WAVE_MMA_TIER {
37+
D3D12_WAVE_MMA_TIER_NOT_SUPPORTED = 0,
38+
D3D12_WAVE_MMA_TIER_1_0 = 10,
39+
}
40+
}
41+
3542
winapi::STRUCT! {
3643
struct D3D12_FEATURE_DATA_D3D12_OPTIONS3 {
3744
CopyQueueTimestampQueriesSupported: winapi::shared::minwindef::BOOL,
@@ -41,3 +48,14 @@ winapi::STRUCT! {
4148
BarycentricsSupported: winapi::shared::minwindef::BOOL,
4249
}
4350
}
51+
52+
winapi::STRUCT! {
53+
struct D3D12_FEATURE_DATA_D3D12_OPTIONS9 {
54+
MeshShaderPipelineStatsSupported: winapi::shared::minwindef::BOOL,
55+
MeshShaderSupportsFullRangeRenderTargetArrayIndex: winapi::shared::minwindef::BOOL,
56+
AtomicInt64OnTypedResourceSupported: winapi::shared::minwindef::BOOL,
57+
AtomicInt64OnGroupSharedSupported: winapi::shared::minwindef::BOOL,
58+
DerivativesInMeshAndAmplificationShadersSupported: winapi::shared::minwindef::BOOL,
59+
WaveMMATier: D3D12_WAVE_MMA_TIER,
60+
}
61+
}

Diff for: wgpu-hal/src/metal/adapter.rs

+4
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,10 @@ impl super::PrivateCapabilities {
882882
F::SHADER_INT64,
883883
self.msl_version >= MTLLanguageVersion::V2_3,
884884
);
885+
features.set(
886+
F::SHADER_INT64_ATOMIC,
887+
self.msl_version >= MTLLanguageVersion::V3_1,
888+
);
885889

886890
features.set(
887891
F::ADDRESS_MODE_CLAMP_TO_BORDER,

0 commit comments

Comments
 (0)