diff --git a/apps/typegpu-docs/src/pages/benchmark/benchmark-app.tsx b/apps/typegpu-docs/src/pages/benchmark/benchmark-app.tsx index fd71a53de..bc301f7a4 100644 --- a/apps/typegpu-docs/src/pages/benchmark/benchmark-app.tsx +++ b/apps/typegpu-docs/src/pages/benchmark/benchmark-app.tsx @@ -79,7 +79,109 @@ async function runBench(params: BenchParameterSet): Promise { root.device.queue.writeBuffer(root.unwrap(buffer), 0, data); root.destroy(); - }); + }) + .add('mass boid transfer (partial write)', async () => { + const root = await tgpu.init(); + + const Boid = d.struct({ + pos: d.vec3f, + vel: d.vec3f, + }); + + const BoidArray = d.arrayOf(Boid, amountOfBoids); + + const buffer = root.createBuffer(BoidArray); + + const randomBoid = Math.floor(Math.random() * amountOfBoids); + + buffer.writePartial([ + { + idx: randomBoid, + value: { pos: d.vec3f(1, 2, 3), vel: d.vec3f(4, 5, 6) }, + }, + ]); + + root.destroy(); + }) + .add( + 'mass boid transfer (partial write 20% of the buffer - not contiguous)', + async () => { + const root = await tgpu.init(); + + const Boid = d.struct({ + pos: d.vec3f, + vel: d.vec3f, + }); + + const BoidArray = d.arrayOf(Boid, amountOfBoids); + + const buffer = root.createBuffer(BoidArray); + + const writes = Array.from({ length: amountOfBoids }) + .map((_, i) => i) + .filter((i) => i % 5 === 0) + .map((i) => ({ + idx: i, + value: { pos: d.vec3f(1, 2, 3), vel: d.vec3f(4, 5, 6) }, + })); + + buffer.writePartial(writes); + + root.destroy(); + }, + ) + .add( + 'mass boid transfer (partial write 20% of the buffer, contiguous)', + async () => { + const root = await tgpu.init(); + + const Boid = d.struct({ + pos: d.vec3f, + vel: d.vec3f, + }); + + const BoidArray = d.arrayOf(Boid, amountOfBoids); + + const buffer = root.createBuffer(BoidArray); + + const writes = Array.from({ length: amountOfBoids / 5 }) + .map((_, i) => i) + .map((i) => ({ + idx: i, + value: { pos: d.vec3f(1, 2, 3), vel: d.vec3f(4, 5, 6) }, + })); + + buffer.writePartial(writes); + + root.destroy(); + }, + ) + .add( + 'mass boid transfer (partial write 100% of the buffer - contiguous (duh))', + async () => { + const root = await tgpu.init(); + + const Boid = d.struct({ + pos: d.vec3f, + vel: d.vec3f, + }); + + const BoidArray = d.arrayOf(Boid, amountOfBoids); + + const buffer = root.createBuffer(BoidArray); + + const writes = Array.from({ length: amountOfBoids }) + .map((_, i) => i) + .map((i) => ({ + idx: i, + value: { pos: d.vec3f(1, 2, 3), vel: d.vec3f(4, 5, 6) }, + })); + + buffer.writePartial(writes); + + root.destroy(); + }, + ); await bench.run(); diff --git a/packages/typegpu/src/core/buffer/buffer.ts b/packages/typegpu/src/core/buffer/buffer.ts index 6b8a0120c..d6dad1ade 100644 --- a/packages/typegpu/src/core/buffer/buffer.ts +++ b/packages/typegpu/src/core/buffer/buffer.ts @@ -2,11 +2,12 @@ import { BufferReader, BufferWriter } from 'typed-binary'; import { isWgslData } from '../../data'; import { readData, writeData } from '../../data/dataIO'; import type { AnyData } from '../../data/dataTypes'; +import { getWriteInstructions } from '../../data/partialIO'; import { sizeOf } from '../../data/sizeOf'; import type { BaseData, WgslTypeLiteral } from '../../data/wgslTypes'; import type { Storage } from '../../extension'; import type { TgpuNamable } from '../../namable'; -import type { Infer } from '../../shared/repr'; +import type { Infer, InferPartial } from '../../shared/repr'; import type { MemIdentity } from '../../shared/repr'; import type { UnionToIntersection } from '../../shared/utilityTypes'; import { isGPUBuffer } from '../../types'; @@ -81,6 +82,7 @@ export interface TgpuBuffer extends TgpuNamable { as>(usage: T): UsageTypeToBufferUsage[T]; write(data: Infer): void; + writePartial(data: InferPartial): void; copyFrom(srcBuffer: TgpuBuffer>): void; read(): Promise>; destroy(): void; @@ -253,6 +255,32 @@ class TgpuBufferImpl implements TgpuBuffer { device.queue.writeBuffer(gpuBuffer, 0, hostBuffer, 0, size); } + public writePartial(data: InferPartial): void { + const gpuBuffer = this.buffer; + const device = this._group.device; + + const instructions = getWriteInstructions(this.dataType, data); + + if (gpuBuffer.mapState === 'mapped') { + const mappedRange = gpuBuffer.getMappedRange(); + const mappedView = new Uint8Array(mappedRange); + + for (const instruction of instructions) { + mappedView.set(instruction.data, instruction.data.byteOffset); + } + } else { + for (const instruction of instructions) { + device.queue.writeBuffer( + gpuBuffer, + instruction.data.byteOffset, + instruction.data, + 0, + instruction.data.byteLength, + ); + } + } + } + copyFrom(srcBuffer: TgpuBuffer>): void { if (this.buffer.mapState === 'mapped') { throw new Error('Cannot copy to a mapped buffer.'); diff --git a/packages/typegpu/src/data/array.ts b/packages/typegpu/src/data/array.ts index b8f331c39..db8364298 100644 --- a/packages/typegpu/src/data/array.ts +++ b/packages/typegpu/src/data/array.ts @@ -1,4 +1,4 @@ -import type { Infer, MemIdentity } from '../shared/repr'; +import type { Infer, InferPartial, MemIdentity } from '../shared/repr'; import { sizeOf } from './sizeOf'; import type { AnyWgslData, BaseData, WgslArray } from './wgslTypes'; @@ -33,6 +33,11 @@ class WgslArrayImpl implements WgslArray { /** Type-token, not available at runtime */ public readonly '~repr'!: Infer[]; /** Type-token, not available at runtime */ + public readonly '~reprPartial'!: { + idx: number; + value: InferPartial; + }[]; + /** Type-token, not available at runtime */ public readonly '~memIdent'!: WgslArray>; constructor( diff --git a/packages/typegpu/src/data/dataTypes.ts b/packages/typegpu/src/data/dataTypes.ts index 648cb6a56..0b7c424e1 100644 --- a/packages/typegpu/src/data/dataTypes.ts +++ b/packages/typegpu/src/data/dataTypes.ts @@ -1,5 +1,10 @@ import type { TgpuNamable } from '../namable'; -import type { Infer, InferRecord } from '../shared/repr'; +import type { + Infer, + InferPartial, + InferPartialRecord, + InferRecord, +} from '../shared/repr'; import { vertexFormats } from '../shared/vertexFormat'; import type { PackedData } from './vertexFormatData'; import * as wgsl from './wgslTypes'; @@ -17,6 +22,7 @@ export interface Disarray { readonly elementCount: number; readonly elementType: TElement; readonly '~repr': Infer[]; + readonly '~reprPartial': { idx: number; value: InferPartial }[]; } /** @@ -34,6 +40,7 @@ export interface Unstruct< readonly type: 'unstruct'; readonly propTypes: TProps; readonly '~repr': InferRecord; + readonly '~reprPartial': Partial>; } export interface LooseDecorated< diff --git a/packages/typegpu/src/data/disarray.ts b/packages/typegpu/src/data/disarray.ts index fe7ebe28f..9420d1a2b 100644 --- a/packages/typegpu/src/data/disarray.ts +++ b/packages/typegpu/src/data/disarray.ts @@ -1,4 +1,4 @@ -import type { Infer } from '../shared/repr'; +import type { Infer, InferPartial } from '../shared/repr'; import type { AnyData, Disarray } from './dataTypes'; import type { Exotic } from './exotic'; @@ -37,6 +37,11 @@ class DisarrayImpl implements Disarray { public readonly type = 'disarray'; /** Type-token, not available at runtime */ public readonly '~repr'!: Infer[]; + /** Type-token, not available at runtime */ + public readonly '~reprPartial'!: { + idx: number; + value: InferPartial; + }[]; constructor( public readonly elementType: TElement, diff --git a/packages/typegpu/src/data/offsets.ts b/packages/typegpu/src/data/offsets.ts new file mode 100644 index 000000000..aabeb6a1c --- /dev/null +++ b/packages/typegpu/src/data/offsets.ts @@ -0,0 +1,69 @@ +import { Measurer } from 'typed-binary'; +import { roundUp } from '../mathUtils'; +import alignIO from './alignIO'; +import { alignmentOf, customAlignmentOf } from './alignmentOf'; +import { type Unstruct, isUnstruct } from './dataTypes'; +import { sizeOf } from './sizeOf'; +import type { AnyWgslStruct, WgslStruct } from './struct'; +import type { BaseData } from './wgslTypes'; + +export interface OffsetInfo { + offset: number; + size: number; + padding?: number | undefined; +} + +const cachedOffsets = new WeakMap< + AnyWgslStruct | Unstruct, + Record +>(); + +export function offsetsForProps>( + struct: WgslStruct | Unstruct, +): Record { + const cached = cachedOffsets.get(struct); + if (cached) { + return cached as Record; + } + + const measurer = new Measurer(); + const offsets = {} as Record; + let lastEntry: OffsetInfo | undefined = undefined; + + for (const key in struct.propTypes) { + const prop = struct.propTypes[key]; + if (prop === undefined) { + throw new Error(`Property ${key} is undefined in struct`); + } + + const beforeAlignment = measurer.size; + + alignIO( + measurer, + isUnstruct(struct) ? customAlignmentOf(prop) : alignmentOf(prop), + ); + + if (lastEntry) { + lastEntry.padding = measurer.size - beforeAlignment; + } + + const propSize = sizeOf(prop); + offsets[key] = { offset: measurer.size, size: propSize }; + lastEntry = offsets[key]; + measurer.add(propSize); + } + + if (lastEntry) { + lastEntry.padding = + roundUp(sizeOf(struct), alignmentOf(struct)) - measurer.size; + } + + cachedOffsets.set( + struct as + | WgslStruct> + | Unstruct>, + offsets, + ); + + return offsets; +} diff --git a/packages/typegpu/src/data/partialIO.ts b/packages/typegpu/src/data/partialIO.ts new file mode 100644 index 000000000..412803679 --- /dev/null +++ b/packages/typegpu/src/data/partialIO.ts @@ -0,0 +1,132 @@ +import { BufferWriter } from 'typed-binary'; +import { roundUp } from '../mathUtils'; +import type { Infer, InferPartial } from '../shared/repr'; +import { alignmentOf } from './alignmentOf'; +import { writeData } from './dataIO'; +import { isDisarray, isUnstruct } from './dataTypes'; +import { offsetsForProps } from './offsets'; +import { sizeOf } from './sizeOf'; +import type * as wgsl from './wgslTypes'; +import { isWgslArray, isWgslStruct } from './wgslTypes'; + +export interface WriteInstruction { + data: Uint8Array; +} + +export function getWriteInstructions( + schema: TData, + data: InferPartial, +): WriteInstruction[] { + const totalSize = sizeOf(schema); + if (totalSize === 0 || data === undefined || data === null) { + return []; + } + + const bigBuffer = new ArrayBuffer(totalSize); + const writer = new BufferWriter(bigBuffer); + + const segments: Array<{ + start: number; + end: number; + padding?: number | undefined; + }> = []; + + function gatherAndWrite( + node: T, + partialValue: InferPartial | undefined, + offset: number, + padding?: number | undefined, + ) { + if (partialValue === undefined || partialValue === null) { + return; + } + + if (isWgslStruct(node) || isUnstruct(node)) { + const propOffsets = offsetsForProps(node); + + for (const [key, propOffset] of Object.entries(propOffsets)) { + const subSchema = node.propTypes[key]; + if (!subSchema) { + continue; + } + + const childValue = partialValue[key as keyof typeof partialValue]; + if (childValue !== undefined) { + gatherAndWrite( + subSchema, + childValue, + offset + propOffset.offset, + propOffset.padding ?? padding, + ); + } + } + } else if (isWgslArray(node) || isDisarray(node)) { + const arrSchema = node; + const elementSize = roundUp( + sizeOf(arrSchema.elementType), + alignmentOf(arrSchema.elementType), + ); + + if (!Array.isArray(partialValue)) { + throw new Error('Partial value for array must be an array'); + } + const arrayPartialValue = partialValue as InferPartial; + + arrayPartialValue.sort((a, b) => a.idx - b.idx); + + for (const { idx, value } of arrayPartialValue) { + gatherAndWrite( + arrSchema.elementType, + value, + offset + idx * elementSize, + elementSize - sizeOf(arrSchema.elementType), + ); + } + } else { + const leafSize = sizeOf(node); + writer.seekTo(offset); + writeData(writer, node, partialValue as Infer); + + segments.push({ start: offset, end: offset + leafSize, padding }); + } + } + + gatherAndWrite(schema, data, 0); + + if (segments.length === 0) { + return []; + } + + const instructions: WriteInstruction[] = []; + let current = segments[0]; + + for (let i = 1; i < segments.length; i++) { + const next = segments[i]; + if (!next || !current) { + throw new Error('Internal error: missing segment'); + } + if (next.start === current.end + (current.padding ?? 0)) { + current.end = next.end; + current.padding = next.padding; + } else { + instructions.push({ + data: new Uint8Array( + bigBuffer, + current.start, + current.end - current.start, + ), + }); + current = next; + } + } + + if (!current) { + throw new Error('Internal error: missing segment'); + } + + instructions.push({ + data: new Uint8Array(bigBuffer, current.start, current.end - current.start), + }); + + return instructions; +} diff --git a/packages/typegpu/src/data/struct.ts b/packages/typegpu/src/data/struct.ts index 036da40eb..f77e065da 100644 --- a/packages/typegpu/src/data/struct.ts +++ b/packages/typegpu/src/data/struct.ts @@ -1,5 +1,9 @@ import type { TgpuNamable } from '../namable'; -import type { InferRecord, MemIdentityRecord } from '../shared/repr'; +import type { + InferPartialRecord, + InferRecord, + MemIdentityRecord, +} from '../shared/repr'; import type { Prettify } from '../shared/utilityTypes'; import type { ExoticRecord } from './exotic'; import type { AnyWgslData, BaseData } from './wgslTypes'; @@ -26,6 +30,8 @@ export interface WgslStruct< readonly '~repr': InferRecord; /** Type-token, not available at runtime */ readonly '~memIdent': WgslStruct>; + /** Type-token, not available at runtime */ + readonly '~reprPartial': Partial>; } // biome-ignore lint/suspicious/noExplicitAny: > diff --git a/packages/typegpu/src/data/unstruct.ts b/packages/typegpu/src/data/unstruct.ts index fbeb92e2d..aff8ce33a 100644 --- a/packages/typegpu/src/data/unstruct.ts +++ b/packages/typegpu/src/data/unstruct.ts @@ -1,4 +1,4 @@ -import type { InferRecord } from '../shared/repr'; +import type { InferPartialRecord, InferRecord } from '../shared/repr'; import type { Unstruct } from './dataTypes'; import type { ExoticRecord } from './exotic'; import type { BaseData } from './wgslTypes'; @@ -42,6 +42,8 @@ class UnstructImpl> public readonly type = 'unstruct'; /** Type-token, not available at runtime */ public readonly '~repr'!: InferRecord; + /** Type-token, not available at runtime */ + public readonly '~reprPartial'!: Partial>; constructor(public readonly propTypes: TProps) {} diff --git a/packages/typegpu/src/data/wgslTypes.ts b/packages/typegpu/src/data/wgslTypes.ts index ab183f3b0..8dc335fa1 100644 --- a/packages/typegpu/src/data/wgslTypes.ts +++ b/packages/typegpu/src/data/wgslTypes.ts @@ -1,4 +1,4 @@ -import type { Infer, MemIdentity } from '../shared/repr'; +import type { Infer, InferPartial, MemIdentity } from '../shared/repr'; import type { AnyWgslStruct, WgslStruct } from './struct'; type DecoratedLocation = Decorated[]>; @@ -815,6 +815,7 @@ export interface WgslArray { readonly elementType: TElement; /** Type-token, not available at runtime */ readonly '~repr': Infer[]; + readonly '~reprPartial': { idx: number; value: InferPartial }[]; readonly '~memIdent': WgslArray>; } diff --git a/packages/typegpu/src/shared/repr.ts b/packages/typegpu/src/shared/repr.ts index 758e7b711..82518f2ae 100644 --- a/packages/typegpu/src/shared/repr.ts +++ b/packages/typegpu/src/shared/repr.ts @@ -7,11 +7,24 @@ import type { AnyData } from '../data/dataTypes'; * type B = Infer> // => number[] */ export type Infer = T extends { readonly '~repr': infer TRepr } ? TRepr : T; +export type InferPartial = T extends { readonly '~reprPartial': infer TRepr } + ? TRepr + : T extends { readonly '~repr': infer TRepr } + ? TRepr | undefined + : T extends Record + ? InferPartialRecord + : T; export type InferRecord> = { [Key in keyof T]: Infer; }; +export type InferPartialRecord< + T extends Record, +> = { + [Key in keyof T]: InferPartial; +}; + export type MemIdentity = T extends { readonly '~memIdent': infer TMemIdent extends AnyData; } diff --git a/packages/typegpu/src/shared/utilityTypes.ts b/packages/typegpu/src/shared/utilityTypes.ts index 3cd1f7e3f..f50035a5a 100644 --- a/packages/typegpu/src/shared/utilityTypes.ts +++ b/packages/typegpu/src/shared/utilityTypes.ts @@ -31,6 +31,17 @@ export type Mutable = { -readonly [P in keyof T]: T[P]; }; +/** + * Any typed array + */ +export type TypedArray = + | Uint8Array + | Uint16Array + | Uint32Array + | Int32Array + | Float32Array + | Float64Array; + export function assertExhaustive(x: never, location: string): never { throw new Error(`Failed to handle ${x} at ${location}`); } diff --git a/packages/typegpu/tests/buffer.test.ts b/packages/typegpu/tests/buffer.test.ts index 500ffdf87..8e2a81ec3 100644 --- a/packages/typegpu/tests/buffer.test.ts +++ b/packages/typegpu/tests/buffer.test.ts @@ -1,7 +1,27 @@ import { describe, expect } from 'vitest'; import * as d from '../src/data'; +import type { TypedArray } from '../src/shared/utilityTypes'; import { it } from './utils/extendedIt'; +function toUint8Array(...arrays: Array): Uint8Array { + let totalByteLength = 0; + for (const arr of arrays) { + totalByteLength += arr.byteLength; + } + + const merged = new Uint8Array(totalByteLength); + let offset = 0; + for (const arr of arrays) { + merged.set( + new Uint8Array(arr.buffer, arr.byteOffset, arr.byteLength), + offset, + ); + offset += arr.byteLength; + } + + return merged; +} + describe('TgpuBuffer', () => { it('should properly write to buffer', ({ root, device }) => { const buffer = root.createBuffer(d.u32); @@ -182,6 +202,133 @@ describe('TgpuBuffer', () => { }).toThrow(); }); + it('should allow for partial writes', ({ root, device }) => { + const buffer = root.createBuffer(d.struct({ a: d.u32, b: d.u32 })); + + buffer.writePartial({ a: 3 }); + + const rawBuffer = root.unwrap(buffer); + expect(rawBuffer).toBeDefined(); + + expect(device.mock.queue.writeBuffer.mock.calls).toEqual([ + [rawBuffer, 0, toUint8Array(new Uint32Array([3])), 0, 4], + ]); + + buffer.writePartial({ b: 4 }); + + expect(device.mock.queue.writeBuffer.mock.calls).toEqual([ + [rawBuffer, 0, toUint8Array(new Uint32Array([3])), 0, 4], + [rawBuffer, 4, toUint8Array(new Uint32Array([4])), 0, 4], + ]); + + buffer.writePartial({ a: 5, b: 6 }); // should merge the writes + + expect(device.mock.queue.writeBuffer.mock.calls).toEqual([ + [rawBuffer, 0, toUint8Array(new Uint32Array([3])), 0, 4], + [rawBuffer, 4, toUint8Array(new Uint32Array([4])), 0, 4], + [rawBuffer, 0, toUint8Array(new Uint32Array([5, 6])), 0, 8], + ]); + }); + + it('should allow for partial writes with complex data', ({ + root, + device, + }) => { + const buffer = root.createBuffer( + d.struct({ + a: d.u32, + b: d.struct({ c: d.vec2f }), + d: d.arrayOf(d.u32, 3), + }), + ); + + buffer.writePartial({ a: 3 }); + + const rawBuffer = root.unwrap(buffer); + expect(rawBuffer).toBeDefined(); + + expect(device.mock.queue.writeBuffer.mock.calls).toStrictEqual([ + [rawBuffer, 0, toUint8Array(new Uint32Array([3])), 0, 4], + ]); + + buffer.writePartial({ b: { c: d.vec2f(1, 2) } }); + + expect(device.mock.queue.writeBuffer.mock.calls).toStrictEqual([ + [rawBuffer, 0, toUint8Array(new Uint32Array([3])), 0, 4], + [rawBuffer, 8, toUint8Array(new Float32Array([1, 2])), 0, 8], + ]); + + buffer.writePartial({ + d: [ + { idx: 0, value: 1 }, + { idx: 2, value: 3 }, + ], + }); + + expect(device.mock.queue.writeBuffer.mock.calls).toStrictEqual([ + [rawBuffer, 0, toUint8Array(new Uint32Array([3])), 0, 4], + [rawBuffer, 8, toUint8Array(new Float32Array([1, 2])), 0, 8], + [rawBuffer, 16, toUint8Array(new Uint32Array([1])), 0, 4], + [rawBuffer, 24, toUint8Array(new Uint32Array([3])), 0, 4], + ]); + + buffer.writePartial({ + b: { c: d.vec2f(3, 4) }, + d: [ + { idx: 0, value: 2 }, + { idx: 1, value: 3 }, + ], + }); // should merge the writes + + expect(device.mock.queue.writeBuffer.mock.calls).toStrictEqual([ + [rawBuffer, 0, toUint8Array(new Uint32Array([3])), 0, 4], + [rawBuffer, 8, toUint8Array(new Float32Array([1, 2])), 0, 8], + [rawBuffer, 16, toUint8Array(new Uint32Array([1])), 0, 4], + [rawBuffer, 24, toUint8Array(new Uint32Array([3])), 0, 4], + [ + rawBuffer, + 8, + toUint8Array(new Float32Array([3, 4]), new Uint32Array([2, 3])), + 0, + 16, + ], + ]); + }); + + it('should allow for partial writes with loose data', ({ root, device }) => { + const buffer = root.createBuffer( + d.unstruct({ + a: d.disarrayOf(d.unorm16x2, 4), + b: d.snorm8x2, + c: d.unstruct({ d: d.u32 }), + }), + ); + + buffer.writePartial({ a: [{ idx: 2, value: d.vec2f(0.5, 0.5) }] }); + + const rawBuffer = root.unwrap(buffer); + expect(rawBuffer).toBeDefined(); + + expect(device.mock.queue.writeBuffer.mock.calls).toStrictEqual([ + [rawBuffer, 8, new Uint8Array([-1, 127, -1, 127]), 0, 4], + ]); + + buffer.writePartial({ b: d.vec2f(-0.5, 0.5) }); + + expect(device.mock.queue.writeBuffer.mock.calls).toStrictEqual([ + [rawBuffer, 8, new Uint8Array([-1, 127, -1, 127]), 0, 4], + [rawBuffer, 16, new Uint8Array([64, -65]), 0, 2], + ]); + + buffer.writePartial({ c: { d: 3 } }); + + expect(device.mock.queue.writeBuffer.mock.calls).toStrictEqual([ + [rawBuffer, 8, new Uint8Array([-1, 127, -1, 127]), 0, 4], + [rawBuffer, 16, new Uint8Array([64, -65]), 0, 2], + [rawBuffer, 18, new Uint8Array([3, 0, 0, 0]), 0, 4], + ]); + }); + it('should be able to copy from a buffer identical on the byte level', ({ root, }) => { diff --git a/packages/typegpu/tests/partialIo.test.ts b/packages/typegpu/tests/partialIo.test.ts new file mode 100644 index 000000000..ea0bb7612 --- /dev/null +++ b/packages/typegpu/tests/partialIo.test.ts @@ -0,0 +1,301 @@ +import { describe, expect } from 'vitest'; +import * as d from '../src/data'; +import { offsetsForProps } from '../src/data/offsets'; +import { + type WriteInstruction, + getWriteInstructions, +} from '../src/data/partialIO'; +import type { TypedArray } from '../src/shared/utilityTypes'; +import { it } from './utils/extendedIt'; + +function expectInstruction( + instruction: WriteInstruction, + { + start, + length, + expectedData, + }: { + start: number; + length: number; + expectedData: TypedArray | TypedArray[]; + }, +): void { + expect(instruction.data.byteOffset).toBe(start); + expect(instruction.data.byteLength).toBe(length); + + const dataArrays = Array.isArray(expectedData) + ? expectedData + : [expectedData]; + + const totalByteLength = dataArrays.reduce( + (acc, arr) => acc + arr.byteLength, + 0, + ); + + const mergedExpected = new Uint8Array(totalByteLength); + let offset = 0; + for (const arr of dataArrays) { + mergedExpected.set( + new Uint8Array(arr.buffer, arr.byteOffset, arr.byteLength), + offset, + ); + offset += arr.byteLength; + } + + expect(instruction.data).toHaveLength(totalByteLength); + expect(instruction.data).toEqual(mergedExpected); +} + +describe('offsetsForProps', () => { + it('should return correct offsets for props', () => { + const struct = d.struct({ + a: d.u32, + b: d.vec3f, + c: d.struct({ d: d.u32 }), + }); + + const offsets = offsetsForProps(struct); + expect(offsets).toStrictEqual({ + a: { offset: 0, size: 4, padding: 12 }, + b: { offset: 16, size: 12, padding: 0 }, + c: { offset: 28, size: 4, padding: 0 }, + }); + }); + + it('should return correct offsets for props with arrays', () => { + const struct = d.struct({ + a: d.u32, + b: d.arrayOf(d.vec3f, 4), + c: d.struct({ d: d.u32 }), + }); + + const offsets = offsetsForProps(struct); + expect(offsets).toStrictEqual({ + a: { offset: 0, size: 4, padding: 12 }, + b: { offset: 16, size: 64, padding: 0 }, + c: { offset: 80, size: 4, padding: 12 }, + }); + }); + + it('should return correct offsets for deeply nested structs', () => { + const One = d.struct({ + a: d.u32, + b: d.vec3f, + }); + + const Two = d.struct({ + c: d.arrayOf(One, 3), + d: d.vec4u, + }); + + const Three = d.struct({ + e: One, + f: d.arrayOf(Two, 2), + }); + + const offsets = offsetsForProps(Three); + + expect(offsets).toStrictEqual({ + e: { offset: 0, size: 32, padding: 0 }, + f: { offset: 32, size: 224, padding: 0 }, + }); + + const oneOffsets = offsetsForProps(One); + + expect(oneOffsets).toStrictEqual({ + a: { offset: 0, size: 4, padding: 12 }, + b: { offset: 16, size: 12, padding: 4 }, + }); + + const twoOffsets = offsetsForProps(Two); + + expect(twoOffsets).toStrictEqual({ + c: { offset: 0, size: 96, padding: 0 }, + d: { offset: 96, size: 16, padding: 0 }, + }); + }); +}); + +describe('getWriteInstructions', () => { + it('should return correct write instructions for simple data', () => { + const instructions = getWriteInstructions(d.u32, 3) as [WriteInstruction]; + + expect(instructions).toHaveLength(1); + + expectInstruction(instructions[0], { + start: 0, + length: 4, + expectedData: new Uint32Array([3]), + }); + }); + + it('should return correct write instructions for props', () => { + const struct = d.struct({ + a: d.u32, + b: d.vec3f, + c: d.struct({ d: d.u32 }), + }); + + const data = { + b: d.vec3f(1, 2, 3), + a: 3, + c: { d: 4 }, + }; + + const instructions = getWriteInstructions(struct, data) as [ + WriteInstruction, + ]; + expect(instructions).toHaveLength(1); + + expectInstruction(instructions[0], { + start: 0, + length: 32, + expectedData: [ + new Uint32Array([3, 0, 0, 0]), + new Float32Array([1, 2, 3]), + new Uint32Array([4]), + ], + }); + }); + + it('should return correct write instructions for props with arrays', () => { + const struct = d.struct({ + a: d.u32, + b: d.arrayOf(d.vec3f, 4), + c: d.struct({ d: d.u32 }), + }); + + const data = { + a: 3, + c: { d: 4 }, + b: [ + { idx: 1, value: d.vec3f(4, 5, 6) }, + { idx: 0, value: d.vec3f(1, 2, 3) }, + { idx: 3, value: d.vec3f(10, 11, 12) }, + { idx: 2, value: d.vec3f(7, 8, 9) }, + ], + }; + + const instructions = getWriteInstructions(struct, data) as [ + WriteInstruction, + ]; + expect(instructions).toHaveLength(1); + + expectInstruction(instructions[0], { + start: 0, + length: 84, + expectedData: [ + new Uint32Array([3, 0, 0, 0]), + new Float32Array([1, 2, 3, 0]), + new Float32Array([4, 5, 6, 0]), + new Float32Array([7, 8, 9, 0]), + new Float32Array([10, 11, 12, 0]), + new Uint32Array([4]), + ], + }); + }); + + it('should return correct write instructions for props with arrays and missing data', () => { + const struct = d.struct({ + a: d.u32, + b: d.arrayOf(d.vec3f, 4), + c: d.struct({ d: d.u32 }), + }); + + const data = { + b: [ + { idx: 2, value: d.vec3f(7, 8, 9) }, + { idx: 0, value: d.vec3f(1, 2, 3) }, + { idx: 3, value: d.vec3f(10, 11, 12) }, + ], + c: { d: 4 }, + }; + + const instructions = getWriteInstructions(struct, data) as [ + WriteInstruction, + WriteInstruction, + ]; + expect(instructions).toHaveLength(2); + + expectInstruction(instructions[0], { + start: 16, + length: 12, + expectedData: [new Float32Array([1, 2, 3])], + }); + + expectInstruction(instructions[1], { + start: 48, + length: 36, + expectedData: [ + new Float32Array([7, 8, 9, 0]), + new Float32Array([10, 11, 12, 0]), + new Uint32Array([4]), + ], + }); + }); + + it('should return correct write instructions for arrays of structs', () => { + const Boid = d.struct({ + position: d.vec3f, + velocity: d.vec3f, + }); + + const struct = d.struct({ + boids: d.arrayOf(Boid, 3), + }); + + const data = [ + { + idx: 1, + value: { position: d.vec3f(1, 2, 3) }, + }, + ]; + + const instructions = getWriteInstructions(struct, { + boids: data, + }) as [WriteInstruction]; + + expect(instructions).toHaveLength(1); + + expectInstruction(instructions[0], { + start: 32, + length: 12, + expectedData: [new Float32Array([1, 2, 3])], + }); + }); + + it('should not accept invalid data', () => { + const struct = d.struct({ + a: d.u32, + b: d.vec3f, + c: d.struct({ d: d.u32 }), + }); + + // @ts-expect-error + getWriteInstructions(struct, { a: 3, b: 4, c: 5 }); + }); + + it('should not merge instructions if there is a gap', () => { + const array = d.arrayOf(d.vec3f, 1024); + + const data = Array.from({ length: 1024 }) + .map((_, i) => i) + .filter((i) => i % 2 === 0) + .map((i) => ({ idx: i, value: d.vec3f(1, 2, 3) })); + + const instructions = getWriteInstructions(array, data); + + expect(instructions).toHaveLength(512); + + for (let i = 0; i < 512; i++) { + if (instructions[i] === undefined) { + throw new Error('Instruction is undefined'); + } + expectInstruction(instructions[i] as WriteInstruction, { + start: i * 2 * 16, + length: 12, + expectedData: new Float32Array([1, 2, 3]), + }); + } + }); +});