Skip to content

Commit

Permalink
Support alpha-blended depth for lower-artifact frame extrapolation on…
Browse files Browse the repository at this point in the history
… Vision Pro

- Added a new "highQualityDepth" flag, which is on by default; when writing depth but you'd prefer a few more fps rather than smooth depth, set this to false
- When writing high quality depth, blend splat depth using alpha, so that mostly-transparent splats don't contribute significantly to depth. This is achieved by storing color/depth in tile memory and blending explicitely, then dumping to color/depth during a final pass.
- This new three-stage pipeline still has a performance penalty somewhere around 10% -- even if we don't care about depth. If I optimize the multi-stage path for the no-depth case (in particular getting rid of the unused depth blending), I can get that down to 1% on vertex-bound scenes or 5% on fragment-bound; but I'd like to avoid that regression. So I add a special single-stage shader path duplicating the old flow to keep that no-depth path nicely optimized.
- I suspect the multi-stage path might be useful in the future for other purposes which operate on pixels after accumulating all fragments, so high-quality depth may not be the only reason to use the multi-stage path eventually; hence, no-depth situations are technically supported in the multi-stage path, even though there's no reason to use that right now.
  • Loading branch information
scier committed Aug 5, 2024
1 parent 1eadbec commit ad56c89
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 111 deletions.
87 changes: 87 additions & 0 deletions MetalSplatter/Resources/MultiStageRenderPath.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#include "SplatProcessing.h"

typedef struct
{
half4 color [[raster_order_group(0)]];
float depth [[raster_order_group(0)]];
} FragmentValues;

typedef struct
{
FragmentValues values [[imageblock_data]];
} FragmentStore;

typedef struct
{
half4 color [[color(0)]];
float depth [[depth(any)]];
} FragmentOut;

kernel void initializeFragmentStore(imageblock<FragmentValues, imageblock_layout_explicit> blockData,
ushort2 localThreadID [[thread_position_in_threadgroup]]) {
threadgroup_imageblock FragmentValues *values = blockData.data(localThreadID);
values->color = { 0, 0, 0, 0 };
values->depth = 0;
}

vertex FragmentIn multiStageSplatVertexShader(uint vertexID [[vertex_id]],
uint instanceID [[instance_id]],
ushort amplificationID [[amplification_id]],
constant Splat* splatArray [[ buffer(BufferIndexSplat) ]],
constant UniformsArray & uniformsArray [[ buffer(BufferIndexUniforms) ]]) {
Uniforms uniforms = uniformsArray.uniforms[min(int(amplificationID), kMaxViewCount)];

uint splatID = instanceID * uniforms.indexedSplatCount + (vertexID / 4);
if (splatID >= uniforms.splatCount) {
FragmentIn out;
out.position = float4(1, 1, 0, 1);
return out;
}

Splat splat = splatArray[splatID];

return splatVertex(splat, uniforms, vertexID % 4);
}

fragment FragmentStore multiStageSplatFragmentShader(FragmentIn in [[stage_in]],
FragmentValues previousFragmentValues [[imageblock_data]]) {
FragmentStore out;

half alpha = splatFragmentAlpha(in.relativePosition, in.color.a);
half4 colorWithPremultipliedAlpha = half4(in.color.rgb * alpha, alpha);

half oneMinusAlpha = 1 - alpha;

half4 previousColor = previousFragmentValues.color;
out.values.color = previousColor * oneMinusAlpha + colorWithPremultipliedAlpha;

float previousDepth = previousFragmentValues.depth;
float depth = in.position.z;
out.values.depth = previousDepth * oneMinusAlpha + depth * alpha;

return out;
}

/// Generate a single triangle covering the entire screen
vertex FragmentIn postprocessVertexShader(uint vertexID [[vertex_id]]) {
FragmentIn out;

float4 position;
position.x = (vertexID == 2) ? 3.0 : -1.0;
position.y = (vertexID == 0) ? -3.0 : 1.0;
position.zw = 1.0;

out.position = position;
return out;
}

fragment FragmentOut postprocessFragmentShader(FragmentValues fragmentValues [[imageblock_data]]) {
FragmentOut out;
out.depth = (fragmentValues.color.a == 0) ? 0 : fragmentValues.depth / fragmentValues.color.a;
out.color = fragmentValues.color;
return out;
}

fragment half4 postprocessFragmentShaderNoDepth(FragmentValues fragmentValues [[imageblock_data]]) {
return fragmentValues.color;
}
49 changes: 49 additions & 0 deletions MetalSplatter/Resources/ShaderCommon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

constant const int kMaxViewCount = 2;
constant static const half kBoundsRadius = 3;
constant static const half kBoundsRadiusSquared = kBoundsRadius*kBoundsRadius;

enum BufferIndex: int32_t
{
BufferIndexUniforms = 0,
BufferIndexSplat = 1,
};

typedef struct
{
matrix_float4x4 projectionMatrix;
matrix_float4x4 viewMatrix;
uint2 screenSize;

/*
The first N splats are represented as as 2N primitives and 4N vertex indices. The remained are represented
as instanced of these first N. This allows us to limit the size of the indexed array (and associated memory),
but also avoid the performance penalty of a very large number of instances.
*/
uint splatCount;
uint indexedSplatCount;
} Uniforms;

typedef struct
{
Uniforms uniforms[kMaxViewCount];
} UniformsArray;

typedef struct
{
packed_float3 position;
packed_half4 color;
packed_half3 covA;
packed_half3 covB;
} Splat;

typedef struct
{
float4 position [[position]];
half2 relativePosition; // Ranges from -kBoundsRadius to +kBoundsRadius
half4 color;
} FragmentIn;
25 changes: 25 additions & 0 deletions MetalSplatter/Resources/SingleStageRenderPath.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "SplatProcessing.h"

vertex FragmentIn singleStageSplatVertexShader(uint vertexID [[vertex_id]],
uint instanceID [[instance_id]],
ushort amplificationID [[amplification_id]],
constant Splat* splatArray [[ buffer(BufferIndexSplat) ]],
constant UniformsArray & uniformsArray [[ buffer(BufferIndexUniforms) ]]) {
Uniforms uniforms = uniformsArray.uniforms[min(int(amplificationID), kMaxViewCount)];

uint splatID = instanceID * uniforms.indexedSplatCount + (vertexID / 4);
if (splatID >= uniforms.splatCount) {
FragmentIn out;
out.position = float4(1, 1, 0, 1);
return out;
}

Splat splat = splatArray[splatID];

return splatVertex(splat, uniforms, vertexID % 4);
}

fragment half4 singleStageSplatFragmentShader(FragmentIn in [[stage_in]]) {
half alpha = splatFragmentAlpha(in.relativePosition, in.color.a);
return half4(alpha * in.color.rgb, alpha);
}
16 changes: 16 additions & 0 deletions MetalSplatter/Resources/SplatProcessing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#import "ShaderCommon.h"

float3 calcCovariance2D(float3 viewPos,
packed_half3 cov3Da,
packed_half3 cov3Db,
float4x4 viewMatrix,
float4x4 projectionMatrix,
uint2 screenSize);

void decomposeCovariance(float3 cov2D, thread float2 &v1, thread float2 &v2);

FragmentIn splatVertex(Splat splat,
Uniforms uniforms,
uint relativeVertexIndex);

half splatFragmentAlpha(half2 relativePosition, half splatAlpha);
Original file line number Diff line number Diff line change
@@ -1,52 +1,4 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

constant const int kMaxViewCount = 2;
constant static const half kBoundsRadius = 3;
constant static const half kBoundsRadiusSquared = kBoundsRadius*kBoundsRadius;

enum BufferIndex: int32_t
{
BufferIndexUniforms = 0,
BufferIndexSplat = 1,
};

typedef struct
{
matrix_float4x4 projectionMatrix;
matrix_float4x4 viewMatrix;
uint2 screenSize;

/*
The first N splats are represented as as 2N primitives and 4N vertex indices. The remained are represented
as instanced of these first N. This allows us to limit the size of the indexed array (and associated memory),
but also avoid the performance penalty of a very large number of instances.
*/
uint splatCount;
uint indexedSplatCount;
} Uniforms;

typedef struct
{
Uniforms uniforms[kMaxViewCount];
} UniformsArray;

typedef struct
{
packed_float3 position;
packed_half4 color;
packed_half3 covA;
packed_half3 covB;
} Splat;

typedef struct
{
float4 position [[position]];
half2 relativePosition; // Ranges from -kBoundsRadius to +kBoundsRadius
half4 color;
} ColorInOut;
#import "SplatProcessing.h"

float3 calcCovariance2D(float3 viewPos,
packed_half3 cov3Da,
Expand Down Expand Up @@ -121,22 +73,11 @@ void decomposeCovariance(float3 cov2D, thread float2 &v1, thread float2 &v2) {
v2 = eigenvector2 * sqrt(lambda2);
}

vertex ColorInOut splatVertexShader(uint vertexID [[vertex_id]],
uint instanceID [[instance_id]],
ushort amp_id [[amplification_id]],
constant Splat* splatArray [[ buffer(BufferIndexSplat) ]],
constant UniformsArray & uniformsArray [[ buffer(BufferIndexUniforms) ]]) {
ColorInOut out;
FragmentIn splatVertex(Splat splat,
Uniforms uniforms,
uint relativeVertexIndex) {
FragmentIn out;

Uniforms uniforms = uniformsArray.uniforms[min(int(amp_id), kMaxViewCount)];

uint splatID = instanceID * uniforms.indexedSplatCount + (vertexID / 4);
if (splatID >= uniforms.splatCount) {
out.position = float4(1, 1, 0, 1);
return out;
}

Splat splat = splatArray[splatID];
float4 viewPosition4 = uniforms.viewMatrix * float4(splat.position, 1);
float3 viewPosition3 = viewPosition4.xyz;

Expand All @@ -161,7 +102,7 @@ vertex ColorInOut splatVertexShader(uint vertexID [[vertex_id]],
}

const half2 relativeCoordinatesArray[] = { { -1, -1 }, { -1, 1 }, { 1, -1 }, { 1, 1 } };
half2 relativeCoordinates = relativeCoordinatesArray[vertexID % 4];
half2 relativeCoordinates = relativeCoordinatesArray[relativeVertexIndex];
half2 screenSizeFloat = half2(uniforms.screenSize.x, uniforms.screenSize.y);
half2 projectedScreenDelta =
(relativeCoordinates.x * half2(axis1) + relativeCoordinates.y * half2(axis2))
Expand All @@ -178,14 +119,7 @@ vertex ColorInOut splatVertexShader(uint vertexID [[vertex_id]],
return out;
}

fragment half4 splatFragmentShader(ColorInOut in [[stage_in]]) {
half2 v = in.relativePosition;
half negativeVSquared = -dot(v, v);
if (negativeVSquared < -kBoundsRadiusSquared) {
discard_fragment();
}

half alpha = exp(0.5 * negativeVSquared) * in.color.a;
return half4(alpha * in.color.rgb, alpha);
half splatFragmentAlpha(half2 relativePosition, half splatAlpha) {
half negativeMagnitudeSquared = -dot(relativePosition, relativePosition);
return (negativeMagnitudeSquared < -kBoundsRadiusSquared) ? 0 : exp(0.5 * negativeMagnitudeSquared) * splatAlpha;
}

Loading

0 comments on commit ad56c89

Please sign in to comment.