代码分析说明（主程序 TypeScript 部分）：
import { assert, makeSample, SampleInit } from '../../components/SampleLayout';
import spriteWGSL from './sprite.wgsl';
import updateSpritesWGSL from './updateSprites.wgsl';
/**
 * Builds the on-screen overlay used to report GPU pass timings and attaches
 * it to the canvas' parent node. If the canvas has no parent, an error is
 * logged and the overlay is left detached.
 *
 * @param canvas element whose parent node receives the overlay
 * @returns the `<pre>` element whose `textContent` receives the timing text
 */
function createPerfDisplay(canvas: HTMLElement): HTMLPreElement {
  const container = document.createElement('div');
  container.style.color = 'white';
  container.style.backdropFilter = 'blur(10px)';
  container.style.position = 'absolute';
  container.style.bottom = '10px';
  container.style.left = '10px';
  container.style.textAlign = 'left';

  const display = document.createElement('pre');
  display.style.margin = '.5em';
  container.appendChild(display);

  if (canvas.parentNode) {
    canvas.parentNode.appendChild(container);
  } else {
    console.error('canvas.parentNode is null');
  }
  return display;
}

/**
 * Builds the initial particle state: `count` records of four f32 values laid
 * out as [pos.x, pos.y, vel.x, vel.y]. Positions are uniform in [-1, 1) and
 * velocities uniform in [-0.1, 0.1).
 */
function createInitialParticleData(count: number): Float32Array {
  const data = new Float32Array(count * 4);
  for (let i = 0; i < count; ++i) {
    data[4 * i + 0] = 2 * (Math.random() - 0.5);
    data[4 * i + 1] = 2 * (Math.random() - 0.5);
    data[4 * i + 2] = 2 * (Math.random() - 0.5) * 0.1;
    data[4 * i + 3] = 2 * (Math.random() - 0.5) * 0.1;
  }
  return data;
}

/**
 * Compute-boids sample.
 *
 * Each frame, a compute pass advances `numParticles` boids. A render pass then
 * draws one triangle per boid, instanced from the particle buffer the compute
 * pass just wrote.
 *
 * Two particle buffers are used alternately. On frame `t` the compute pass
 * reads `particleBuffers[t % 2]` and writes `particleBuffers[(t + 1) % 2]`.
 * The render pass then binds that second buffer as per-instance vertex data.
 *
 * When the adapter supports 'timestamp-query', both passes are timed. The
 * average durations over every 100 valid samples are shown in an overlay.
 */
const init: SampleInit = async ({ canvas, pageState, gui }) => {
  const adapter = await navigator.gpu.requestAdapter();
  assert(adapter, 'requestAdapter returned null');

  // 'timestamp-query' is optional: request it only when the adapter exposes
  // it, since requesting an unsupported feature makes requestDevice reject.
  const hasTimestampQuery = adapter.features.has('timestamp-query');
  const device = await adapter.requestDevice({
    requiredFeatures: hasTimestampQuery ? ['timestamp-query'] : [],
  });

  const perfDisplay = createPerfDisplay(canvas);

  if (!pageState.active) return;

  const context = canvas.getContext('webgpu') as GPUCanvasContext | null;
  assert(context, 'canvas.getContext("webgpu") returned null');

  const devicePixelRatio = window.devicePixelRatio;
  canvas.width = canvas.clientWidth * devicePixelRatio;
  canvas.height = canvas.clientHeight * devicePixelRatio;
  const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
  context.configure({
    device,
    format: presentationFormat,
    alphaMode: 'premultiplied',
  });

  const spriteShaderModule = device.createShaderModule({ code: spriteWGSL });
  const renderPipeline = device.createRenderPipeline({
    layout: 'auto',
    vertex: {
      module: spriteShaderModule,
      entryPoint: 'vert_main',
      buffers: [
        {
          // Per-instance particle record: vec2 position + vec2 velocity (f32).
          arrayStride: 4 * 4,
          // Advance once per instance, so all vertices of one sprite read the
          // same particle.
          stepMode: 'instance',
          attributes: [
            {
              // instance position
              shaderLocation: 0,
              offset: 0,
              format: 'float32x2',
            },
            {
              // instance velocity
              shaderLocation: 1,
              offset: 2 * 4,
              format: 'float32x2',
            },
          ],
        },
        {
          // Per-vertex sprite geometry: one vec2 per triangle corner.
          arrayStride: 2 * 4,
          stepMode: 'vertex',
          attributes: [
            {
              // vertex positions
              shaderLocation: 2,
              offset: 0,
              format: 'float32x2',
            },
          ],
        },
      ],
    },
    fragment: {
      module: spriteShaderModule,
      entryPoint: 'frag_main',
      targets: [{ format: presentationFormat }],
    },
    primitive: {
      topology: 'triangle-list',
    },
  });

  const computePipeline = device.createComputePipeline({
    layout: 'auto',
    compute: {
      module: device.createShaderModule({ code: updateSpritesWGSL }),
      entryPoint: 'main',
    },
  });

  // Held separately so frame() can assign the current swap-chain view
  // directly. GPURenderPassDescriptor.colorAttachments is typed as an
  // Iterable, so it cannot be indexed under strict type checking.
  const colorAttachment = {
    view: undefined as GPUTextureView, // Assigned every frame in frame()
    clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 },
    loadOp: 'clear' as const,
    storeOp: 'store' as const,
  };
  const renderPassDescriptor: GPURenderPassDescriptor = {
    colorAttachments: [colorAttachment],
  };
  const computePassDescriptor: GPUComputePassDescriptor = {};

  /**
   * GPU timing resources. Present only when 'timestamp-query' is enabled.
   * Slots 0/1 bracket the compute pass and slots 2/3 bracket the render pass.
   * All four are resolved into `resolveBuffer` each frame.
   */
  const timestampQuery = hasTimestampQuery
    ? {
        querySet: device.createQuerySet({ type: 'timestamp', count: 4 }),
        resolveBuffer: device.createBuffer({
          size: 4 * BigInt64Array.BYTES_PER_ELEMENT,
          usage: GPUBufferUsage.QUERY_RESOLVE | GPUBufferUsage.COPY_SRC,
        }),
      }
    : undefined;
  if (timestampQuery) {
    computePassDescriptor.timestampWrites = {
      querySet: timestampQuery.querySet,
      beginningOfPassWriteIndex: 0,
      endOfPassWriteIndex: 1,
    };
    renderPassDescriptor.timestampWrites = {
      querySet: timestampQuery.querySet,
      beginningOfPassWriteIndex: 2,
      endOfPassWriteIndex: 3,
    };
  }

  /**
   * Pool of spare buffers for MAP_READing the timestamps back to the CPU.
   * A buffer is taken from the pool (if available) when a readback is needed.
   * It is placed back once the readback is done and it has been unmapped.
   */
  const spareResultBuffers: GPUBuffer[] = [];

  // Sprite geometry: a single triangle given as three (x, y) corners.
  // prettier-ignore
  const vertexBufferData = new Float32Array([
    -0.01, -0.02,
     0.01, -0.02,
     0.0,   0.02,
  ]);
  const spriteVertexBuffer = device.createBuffer({
    size: vertexBufferData.byteLength,
    usage: GPUBufferUsage.VERTEX,
    mappedAtCreation: true,
  });
  new Float32Array(spriteVertexBuffer.getMappedRange()).set(vertexBufferData);
  spriteVertexBuffer.unmap();

  const simParams = {
    deltaT: 0.04,
    rule1Distance: 0.1,
    rule2Distance: 0.025,
    rule3Distance: 0.025,
    rule1Scale: 0.02,
    rule2Scale: 0.05,
    rule3Scale: 0.005,
  };
  const simParamBufferSize = 7 * Float32Array.BYTES_PER_ELEMENT;
  const simParamBuffer = device.createBuffer({
    size: simParamBufferSize,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });

  /** Uploads `simParams`; the value order must match the WGSL SimParams struct. */
  function updateSimParams() {
    device.queue.writeBuffer(
      simParamBuffer,
      0,
      new Float32Array([
        simParams.deltaT,
        simParams.rule1Distance,
        simParams.rule2Distance,
        simParams.rule3Distance,
        simParams.rule1Scale,
        simParams.rule2Scale,
        simParams.rule3Scale,
      ])
    );
  }
  updateSimParams();

  if (gui === undefined) {
    console.error('GUI not initialized');
  } else {
    // Re-upload the uniforms whenever a GUI edit is committed.
    for (const key of Object.keys(simParams) as (keyof typeof simParams)[]) {
      gui.add(simParams, key).onFinishChange(updateSimParams);
    }
  }

  const numParticles = 1500;
  const initialParticleData = createInitialParticleData(numParticles);

  // Both buffers start with the same data. They are used both as storage
  // (compute) and as instanced vertex input (render).
  const particleBuffers: GPUBuffer[] = new Array(2);
  const particleBindGroups: GPUBindGroup[] = new Array(2);
  for (let i = 0; i < 2; ++i) {
    particleBuffers[i] = device.createBuffer({
      size: initialParticleData.byteLength,
      usage: GPUBufferUsage.VERTEX | GPUBufferUsage.STORAGE,
      mappedAtCreation: true,
    });
    new Float32Array(particleBuffers[i].getMappedRange()).set(
      initialParticleData
    );
    particleBuffers[i].unmap();
  }
  // Bind group i reads particleBuffers[i] (binding 1) and writes
  // particleBuffers[(i + 1) % 2] (binding 2).
  for (let i = 0; i < 2; ++i) {
    particleBindGroups[i] = device.createBindGroup({
      layout: computePipeline.getBindGroupLayout(0),
      entries: [
        {
          binding: 0,
          resource: {
            buffer: simParamBuffer,
          },
        },
        {
          binding: 1,
          resource: {
            buffer: particleBuffers[i],
            offset: 0,
            size: initialParticleData.byteLength,
          },
        },
        {
          binding: 2,
          resource: {
            buffer: particleBuffers[(i + 1) % 2],
            offset: 0,
            size: initialParticleData.byteLength,
          },
        },
      ],
    });
  }

  // Must equal @workgroup_size(...) of `main` in updateSprites.wgsl.
  const workgroupSize = 64;
  const kNumTimerSamplesPerUpdate = 100;

  let t = 0;
  let computePassDurationSum = 0;
  let renderPassDurationSum = 0;
  let timerSamples = 0;

  /**
   * Maps `buffer` (holding four resolved 64-bit timestamps) and adds the two
   * pass durations to the running sums. It refreshes the overlay every
   * `kNumTimerSamplesPerUpdate` valid samples, then returns the buffer to the
   * spare pool.
   */
  function readBackTimestamps(buffer: GPUBuffer): void {
    buffer
      .mapAsync(GPUMapMode.READ)
      .then(() => {
        const times = new BigInt64Array(buffer.getMappedRange());
        const computePassDuration = Number(times[1] - times[0]);
        const renderPassDuration = Number(times[3] - times[2]);
        // In some cases the timestamps may wrap around and produce a negative
        // number as the GPU resets its timings. These can safely be ignored.
        if (computePassDuration > 0 && renderPassDuration > 0) {
          computePassDurationSum += computePassDuration;
          renderPassDurationSum += renderPassDuration;
          timerSamples++;
        }
        buffer.unmap();

        // Periodically update the text for the timer stats
        if (timerSamples >= kNumTimerSamplesPerUpdate) {
          const avgComputeMicroseconds = Math.round(
            computePassDurationSum / timerSamples / 1000
          );
          const avgRenderMicroseconds = Math.round(
            renderPassDurationSum / timerSamples / 1000
          );
          perfDisplay.textContent = `\
avg compute pass duration: ${avgComputeMicroseconds}µs
avg render pass duration: ${avgRenderMicroseconds}µs
spare readback buffers: ${spareResultBuffers.length}`;
          computePassDurationSum = 0;
          renderPassDurationSum = 0;
          timerSamples = 0;
        }
        spareResultBuffers.push(buffer);
      })
      .catch((err: unknown) => {
        // Mapping can fail (e.g. the buffer was destroyed or the device was
        // lost). The buffer is then dropped instead of being pooled again.
        console.error('timestamp readback failed:', err);
      });
  }

  function frame() {
    // Sample is no longer the active page.
    if (!pageState.active) return;

    colorAttachment.view = context.getCurrentTexture().createView();

    const commandEncoder = device.createCommandEncoder();
    {
      const passEncoder = commandEncoder.beginComputePass(
        computePassDescriptor
      );
      passEncoder.setPipeline(computePipeline);
      // Reads particleBuffers[t % 2], writes particleBuffers[(t + 1) % 2].
      passEncoder.setBindGroup(0, particleBindGroups[t % 2]);
      // One invocation per particle, rounded up to whole workgroups.
      passEncoder.dispatchWorkgroups(Math.ceil(numParticles / workgroupSize));
      passEncoder.end();
    }
    {
      const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
      passEncoder.setPipeline(renderPipeline);
      // Draw from the buffer the compute pass above just wrote.
      passEncoder.setVertexBuffer(0, particleBuffers[(t + 1) % 2]);
      passEncoder.setVertexBuffer(1, spriteVertexBuffer);
      // 3 vertices (one triangle) per instance, one instance per particle.
      passEncoder.draw(3, numParticles, 0, 0);
      passEncoder.end();
    }

    let resultBuffer: GPUBuffer | undefined;
    if (timestampQuery) {
      resultBuffer =
        spareResultBuffers.pop() ??
        device.createBuffer({
          size: 4 * BigInt64Array.BYTES_PER_ELEMENT,
          usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
        });
      commandEncoder.resolveQuerySet(
        timestampQuery.querySet,
        0,
        4,
        timestampQuery.resolveBuffer,
        0
      );
      commandEncoder.copyBufferToBuffer(
        timestampQuery.resolveBuffer,
        0,
        resultBuffer,
        0,
        resultBuffer.size
      );
    }
    device.queue.submit([commandEncoder.finish()]);

    if (resultBuffer) {
      readBackTimestamps(resultBuffer);
    }

    ++t;
    requestAnimationFrame(frame);
  }
  requestAnimationFrame(frame);
};
计算管线着色器（updateSprites.wgsl）：
// A single boid. Its layout (4 x f32, 16 bytes) must match the host side:
// position at byte offset 0 and velocity at byte offset 8 of each record.
struct Particle {
  pos : vec2f,
  vel : vec2f,
}

// Simulation parameters. Member order must match the Float32Array that the
// host uploads in updateSimParams().
struct SimParams {
  deltaT : f32,
  rule1Distance : f32,
  rule2Distance : f32,
  rule3Distance : f32,
  rule1Scale : f32,
  rule2Scale : f32,
  rule3Scale : f32,
}

// Runtime-sized particle array; its length is taken from the bound buffer.
struct Particles {
  particles : array<Particle>,
}

// Uniform simulation parameters.
@group(0) @binding(0) var<uniform> params : SimParams;
// Input state for this dispatch (read only).
@group(0) @binding(1) var<storage, read> particlesA : Particles;
// Output state for this dispatch.
@group(0) @binding(2) var<storage, read_write> particlesB : Particles;
// Boids update, adapted from:
// https://github.com/austinEng/Project6-Vulkan-Flocking/blob/master/data/shaders/computeparticles/particle.comp
//
// One invocation per particle. Invocation `index`:
//   1. reads all particles from particlesA,
//   2. applies the three flocking rules to particle `index`,
//   3. integrates and wraps the position back into [-1, 1],
//   4. writes the result to particlesB[index].
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) GlobalInvocationID : vec3<u32>) {
  let index = GlobalInvocationID.x;
  let particleCount = arrayLength(&particlesA.particles);

  // The dispatch is rounded up to whole workgroups of 64, so the last group
  // can contain invocations past the end of the array. In WGSL an
  // out-of-bounds store is not guaranteed to be dropped: it may land on any
  // element of the buffer and corrupt a real particle. Those invocations
  // must exit before touching memory.
  if (index >= particleCount) {
    return;
  }

  var vPos = particlesA.particles[index].pos;
  var vVel = particlesA.particles[index].vel;
  var cMass = vec2(0.0);
  var cVel = vec2(0.0);
  var colVel = vec2(0.0);
  var cMassCount = 0u;
  var cVelCount = 0u;

  for (var i = 0u; i < particleCount; i++) {
    if (i == index) {
      continue;
    }
    let pos = particlesA.particles[i].pos;
    let vel = particlesA.particles[i].vel;
    let dist = distance(pos, vPos);
    // Rule 1 (cohesion): accumulate nearby positions for the center of mass.
    if (dist < params.rule1Distance) {
      cMass += pos;
      cMassCount++;
    }
    // Rule 2 (separation): steer away from very close neighbours.
    if (dist < params.rule2Distance) {
      colVel -= pos - vPos;
    }
    // Rule 3 (alignment): accumulate nearby velocities.
    if (dist < params.rule3Distance) {
      cVel += vel;
      cVelCount++;
    }
  }
  if (cMassCount > 0u) {
    cMass = (cMass / vec2(f32(cMassCount))) - vPos;
  }
  if (cVelCount > 0u) {
    cVel /= f32(cVelCount);
  }
  vVel += (cMass * params.rule1Scale) + (colVel * params.rule2Scale) + (cVel * params.rule3Scale);

  // Clamp the speed to 0.1 for a more pleasing simulation. normalize() of a
  // zero vector is indeterminate (often NaN), which would corrupt the
  // particle permanently, so a zero velocity is left as is.
  let speed = length(vVel);
  if (speed > 0.0) {
    vVel = normalize(vVel) * min(speed, 0.1);
  }

  // Kinematic update.
  vPos += vVel * params.deltaT;

  // Wrap around boundary
  if (vPos.x < -1.0) {
    vPos.x = 1.0;
  }
  if (vPos.x > 1.0) {
    vPos.x = -1.0;
  }
  if (vPos.y < -1.0) {
    vPos.y = 1.0;
  }
  if (vPos.y > 1.0) {
    vPos.y = -1.0;
  }

  // Write back
  particlesB.particles[index].pos = vPos;
  particlesB.particles[index].vel = vVel;
}
顶点 / 片元着色器（sprite.wgsl）：
// Vertex-stage output: the clip-space position, plus a color passed to the
// fragment stage through location 4.
struct VertexOutput {
  @builtin(position) position : vec4f,
  @location(4) color : vec4f,
}
// Sprite vertex shader. Rotates the local triangle corner `a_pos` by an angle
// derived from the particle velocity and offsets it by the particle position.
// It also derives a per-vertex color from that angle and the velocity.
@vertex
fn vert_main(
  @location(0) a_particlePos : vec2<f32>,
  @location(1) a_particleVel : vec2<f32>,
  @location(2) a_pos : vec2<f32>
) -> VertexOutput {
  let angle = -atan2(a_particleVel.x, a_particleVel.y);
  let cosA = cos(angle);
  let sinA = sin(angle);

  // 2D rotation of the corner by `angle`.
  let rotated = vec2(
    (a_pos.x * cosA) - (a_pos.y * sinA),
    (a_pos.x * sinA) + (a_pos.y * cosA)
  );

  let clipPos = vec4(rotated + a_particlePos, 0.0, 1.0);
  let tint = vec4(
    1.0 - sin(angle + 1.0) - a_particleVel.y,
    rotated.x * 100.0 - a_particleVel.y + 0.1,
    a_particleVel.x + cos(angle + 0.5),
    1.0
  );
  return VertexOutput(clipPos, tint);
}
// Fragment shader: writes the interpolated vertex color unchanged to the
// single color target.
@fragment
fn frag_main(@location(4) vertexColor : vec4<f32>) -> @location(0) vec4<f32> {
  return vertexColor;
}
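总结：
计算管线与渲染管线之间的数据衔接与同步
- 计算管线使用两个粒子 buffer 交替读写（ping-pong）。第 t 帧从 particleBuffers[t % 2] 读取上一帧的结果，并把新结果写入 particleBuffers[(t + 1) % 2]；下一帧两者角色互换。
- 渲染管线与计算管线共用这两个 buffer（创建时同时带有 STORAGE 和 VERTEX 用途），因此计算结果无需拷贝即可直接作为实例顶点数据使用，具体如：
passEncoder.setVertexBuffer(0, particleBuffers[(t + 1) % 2]);
- 计算 pass 与渲染 pass 录制在同一个 command encoder 中并按顺序提交。WebGPU 会自动处理两个 pass 之间的资源依赖，渲染 pass 读到的就是计算 pass 写完的数据，无需手动插入同步。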