Skip to content

Commit e9a291b

Browse files
GS Cleanup + depth renderer + more tests (#17544)
- removed `viewDirectionFactor` uniform - more explicit `eyeToSplatLocalSpace` name - more consistent Y scaling: keep data unchanged and change scaling.y at loading time - add options to `updateData` to preserve back compat - update spz v3 scene accordingly - removed `invy` matrix usage in shaders - depth renderer : https://playground.babylonjs.com/#V80DRL#12 - `loadFileAsync` now loads more format as it uses ImportMeshAsync internally - 4 new tests
1 parent 63b8d6a commit e9a291b

File tree

21 files changed

+219
-100
lines changed

21 files changed

+219
-100
lines changed

packages/dev/core/src/Materials/GaussianSplatting/gaussianSplattingMaterial.ts

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ export class GaussianSplattingMaterial extends PushMaterial {
148148
"focal",
149149
"eyePosition",
150150
"kernelSize",
151-
"viewDirectionFactor",
152151
"alpha",
152+
"depthValues",
153153
];
154154
private _sourceMesh: GaussianSplattingMesh | null = null;
155155
/**
@@ -328,7 +328,6 @@ export class GaussianSplattingMaterial extends PushMaterial {
328328
}
329329

330330
effect.setFloat2("focal", focal, focal);
331-
effect.setVector3("viewDirectionFactor", gsMesh.viewDirectionFactor);
332331
effect.setFloat("kernelSize", gsMaterial && gsMaterial.kernelSize ? gsMaterial.kernelSize : GaussianSplattingMaterial.KernelSize);
333332
effect.setFloat("alpha", gsMaterial.alpha);
334333
scene.bindEyePosition(effect, "eyePosition", true);
@@ -398,6 +397,84 @@ export class GaussianSplattingMaterial extends PushMaterial {
398397
this._afterBind(mesh, this._activeEffect, subMesh);
399398
}
400399

400+
protected static _BindEffectUniforms(gsMesh: GaussianSplattingMesh, gsMaterial: GaussianSplattingMaterial, shaderMaterial: ShaderMaterial, scene: Scene): void {
401+
const engine = scene.getEngine();
402+
const effect = shaderMaterial.getEffect()!;
403+
404+
gsMesh.getMeshUniformBuffer().bindToEffect(effect, "Mesh");
405+
shaderMaterial.bindView(effect);
406+
shaderMaterial.bindViewProjection(effect);
407+
408+
const renderWidth = engine.getRenderWidth();
409+
const renderHeight = engine.getRenderHeight();
410+
effect.setFloat2("invViewport", 1 / renderWidth, 1 / renderHeight);
411+
412+
const projection = scene.getProjectionMatrix();
413+
const t = projection.m[5];
414+
const focal = (renderWidth * t) / 2.0;
415+
416+
effect.setFloat2("focal", focal, focal);
417+
effect.setFloat("kernelSize", gsMaterial && gsMaterial.kernelSize ? gsMaterial.kernelSize : GaussianSplattingMaterial.KernelSize);
418+
effect.setFloat("alpha", gsMaterial.alpha);
419+
420+
let minZ: number, maxZ: number;
421+
422+
const camera = scene.activeCamera;
423+
if (!camera) {
424+
return;
425+
}
426+
const cameraIsOrtho = camera.mode === Camera.ORTHOGRAPHIC_CAMERA;
427+
if (cameraIsOrtho) {
428+
minZ = !engine.useReverseDepthBuffer && engine.isNDCHalfZRange ? 0 : 1;
429+
maxZ = engine.useReverseDepthBuffer && engine.isNDCHalfZRange ? 0 : 1;
430+
} else {
431+
minZ = engine.useReverseDepthBuffer && engine.isNDCHalfZRange ? camera.minZ : engine.isNDCHalfZRange ? 0 : camera.minZ;
432+
maxZ = engine.useReverseDepthBuffer && engine.isNDCHalfZRange ? 0 : camera.maxZ;
433+
}
434+
435+
effect.setFloat2("depthValues", minZ, minZ + maxZ);
436+
437+
if (gsMesh.covariancesATexture) {
438+
const textureSize = gsMesh.covariancesATexture.getSize();
439+
effect.setFloat2("dataTextureSize", textureSize.width, textureSize.height);
440+
441+
effect.setTexture("covariancesATexture", gsMesh.covariancesATexture);
442+
effect.setTexture("covariancesBTexture", gsMesh.covariancesBTexture);
443+
effect.setTexture("centersTexture", gsMesh.centersTexture);
444+
effect.setTexture("colorsTexture", gsMesh.colorsTexture);
445+
}
446+
}
447+
448+
/**
449+
* Create a depth rendering material for a Gaussian Splatting mesh
450+
* @param scene scene it belongs to
451+
* @param shaderLanguage GLSL or WGSL
452+
* @returns depth rendering shader material
453+
*/
454+
public makeDepthRenderingMaterial(scene: Scene, shaderLanguage: ShaderLanguage): ShaderMaterial {
455+
const shaderMaterial = new ShaderMaterial(
456+
"gaussianSplattingDepthRender",
457+
scene,
458+
{
459+
vertex: "gaussianSplattingDepth",
460+
fragment: "gaussianSplattingDepth",
461+
},
462+
{
463+
attributes: GaussianSplattingMaterial._Attribs,
464+
uniforms: GaussianSplattingMaterial._Uniforms,
465+
samplers: GaussianSplattingMaterial._Samplers,
466+
uniformBuffers: GaussianSplattingMaterial._UniformBuffers,
467+
shaderLanguage: shaderLanguage,
468+
defines: ["#define DEPTH_RENDER"],
469+
}
470+
);
471+
shaderMaterial.onBindObservable.add((mesh: AbstractMesh) => {
472+
const gsMaterial = mesh.material as GaussianSplattingMaterial;
473+
const gsMesh = mesh as GaussianSplattingMesh;
474+
GaussianSplattingMaterial._BindEffectUniforms(gsMesh, gsMaterial, shaderMaterial, scene);
475+
});
476+
return shaderMaterial;
477+
}
401478
protected static _MakeGaussianSplattingShadowDepthWrapper(scene: Scene, shaderLanguage: ShaderLanguage): ShadowDepthWrapper {
402479
const shaderMaterial = new ShaderMaterial(
403480
"gaussianSplattingDepth",
@@ -420,34 +497,10 @@ export class GaussianSplattingMaterial extends PushMaterial {
420497
});
421498

422499
shaderMaterial.onBindObservable.add((mesh: AbstractMesh) => {
423-
const effect = shaderMaterial.getEffect()!;
424500
const gsMaterial = mesh.material as GaussianSplattingMaterial;
425501
const gsMesh = mesh as GaussianSplattingMesh;
426502

427-
mesh.getMeshUniformBuffer().bindToEffect(effect, "Mesh");
428-
shaderMaterial.bindView(effect);
429-
shaderMaterial.bindViewProjection(effect);
430-
431-
const shadowmapWidth = scene.getEngine().getRenderWidth();
432-
const shadowmapHeight = scene.getEngine().getRenderHeight();
433-
effect.setFloat2("invViewport", 1 / shadowmapWidth, 1 / shadowmapHeight);
434-
435-
const projection = scene.getProjectionMatrix();
436-
const t = projection.m[5];
437-
const focal = (shadowmapWidth * t) / 2.0;
438-
439-
effect.setFloat2("focal", focal, focal);
440-
effect.setFloat("kernelSize", gsMaterial && gsMaterial.kernelSize ? gsMaterial.kernelSize : GaussianSplattingMaterial.KernelSize);
441-
442-
if (gsMesh.covariancesATexture) {
443-
const textureSize = gsMesh.covariancesATexture.getSize();
444-
effect.setFloat2("dataTextureSize", textureSize.width, textureSize.height);
445-
446-
effect.setTexture("covariancesATexture", gsMesh.covariancesATexture);
447-
effect.setTexture("covariancesBTexture", gsMesh.covariancesBTexture);
448-
effect.setTexture("centersTexture", gsMesh.centersTexture);
449-
effect.setTexture("colorsTexture", gsMesh.colorsTexture);
450-
}
503+
GaussianSplattingMaterial._BindEffectUniforms(gsMesh, gsMaterial, shaderMaterial, scene);
451504
});
452505

453506
return shadowDepthWrapper;

packages/dev/core/src/Materials/Node/Blocks/GaussianSplatting/gaussianSplattingBlock.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
133133
state._emitUniformFromString("invViewport", NodeMaterialBlockConnectionPointTypes.Vector2);
134134
state._emitUniformFromString("kernelSize", NodeMaterialBlockConnectionPointTypes.Float);
135135
state._emitUniformFromString("eyePosition", NodeMaterialBlockConnectionPointTypes.Vector3);
136-
state._emitUniformFromString("viewDirectionFactor", NodeMaterialBlockConnectionPointTypes.Vector3);
137136
state.attributes.push(VertexBuffer.PositionKind);
138137
state.attributes.push("splatIndex0");
139138
state.attributes.push("splatIndex1");
@@ -167,16 +166,14 @@ export class GaussianSplattingBlock extends NodeMaterialBlock {
167166
if (state.shaderLanguage === ShaderLanguage.WGSL) {
168167
state.compilationString += `let worldRot: mat3x3f = mat3x3f(${world.associatedVariableName}[0].xyz, ${world.associatedVariableName}[1].xyz, ${world.associatedVariableName}[2].xyz);`;
169168
state.compilationString += `let normWorldRot: mat3x3f = inverseMat3(worldRot);`;
170-
state.compilationString += `var dir: vec3f = normalize(normWorldRot * (${splatPosition.associatedVariableName}.xyz - uniforms.eyePosition));\n`;
171-
state.compilationString += `dir *= uniforms.viewDirectionFactor;\n`;
169+
state.compilationString += `var eyeToSplatLocalSpace: vec3f = normalize(normWorldRot * (${splatPosition.associatedVariableName}.xyz - uniforms.eyePosition));\n`;
172170
} else {
173171
state.compilationString += `mat3 worldRot = mat3(${world.associatedVariableName});`;
174172
state.compilationString += `mat3 normWorldRot = inverseMat3(worldRot);`;
175-
state.compilationString += `vec3 dir = normalize(normWorldRot * (${splatPosition.associatedVariableName}.xyz - eyePosition));\n`;
176-
state.compilationString += `dir *= viewDirectionFactor;\n`;
173+
state.compilationString += `vec3 eyeToSplatLocalSpace = normalize(normWorldRot * (${splatPosition.associatedVariableName}.xyz - eyePosition));\n`;
177174
}
178175

179-
state.compilationString += `${state._declareOutput(sh)} = computeSH(splat, dir);\n`;
176+
state.compilationString += `${state._declareOutput(sh)} = computeSH(splat, eyeToSplatLocalSpace);\n`;
180177
state.compilationString += `#else\n`;
181178
state.compilationString += `${state._declareOutput(sh)} = vec3${addF}(0.,0.,0.);\n`;
182179
state.compilationString += `#endif;\n`;

packages/dev/core/src/Meshes/GaussianSplatting/gaussianSplattingMesh.ts

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import { Logger } from "core/Misc/logger";
1111
import { GaussianSplattingMaterial } from "core/Materials/GaussianSplatting/gaussianSplattingMaterial";
1212
import { RawTexture } from "core/Materials/Textures/rawTexture";
1313
import { Constants } from "core/Engines/constants";
14-
import { Tools } from "core/Misc/tools";
1514
import "core/Meshes/thinInstanceMesh";
1615
import type { ThinEngine } from "core/Engines/thinEngine";
1716
import { ToHalfFloat } from "core/Misc/textureTools";
@@ -20,6 +19,7 @@ import { Scalar } from "core/Maths/math.scalar";
2019
import { runCoroutineSync, runCoroutineAsync, createYieldingScheduler, type Coroutine } from "core/Misc/coroutine";
2120
import { EngineStore } from "core/Engines/engineStore";
2221
import type { Camera } from "core/Cameras/camera";
22+
import { ImportMeshAsync } from "core/Loading/sceneLoader";
2323

2424
interface IDelayedTextureUpdate {
2525
covA: Uint16Array;
@@ -28,6 +28,9 @@ interface IDelayedTextureUpdate {
2828
centers: Float32Array;
2929
sh?: Uint8Array[];
3030
}
31+
interface IUpdateOptions {
32+
flipY?: boolean;
33+
}
3134

3235
// @internal
3336
const UnpackUnorm = (value: number, bits: number) => {
@@ -84,11 +87,6 @@ interface ICompressedPLYChunk {
8487
maxColor: Vector3;
8588
}
8689

87-
interface IPLYConversionBuffers {
88-
buffer: ArrayBuffer;
89-
sh?: [];
90-
}
91-
9290
/**
9391
* To support multiple camera rendering, rendered mesh is separated from the GaussianSplattingMesh itself.
9492
* The GS mesh serves as a proxy and a different mesh is rendered for each camera. This hot switch is done
@@ -325,15 +323,15 @@ export class GaussianSplattingMesh extends Mesh {
325323
// batch size between 2 yield calls during the PLY to splat conversion.
326324
private static _PlyConversionBatchSize = 32768;
327325
private _shDegree = 0;
328-
private _viewDirectionFactor = new Vector3(1, 1, -1);
329326

330327
private static readonly _BatchSize = 16; // 16 splats per instance
331328
private _cameraViewInfos = new Map<number, ICameraViewInfo>();
332329
/**
333330
* View direction factor used to compute the SH view direction in the shader.
331+
* @deprecated Not used anymore for SH rendering
334332
*/
335333
public get viewDirectionFactor() {
336-
return this._viewDirectionFactor;
334+
return Vector3.OneReadOnly;
337335
}
338336

339337
/**
@@ -420,7 +418,7 @@ export class GaussianSplattingMesh extends Mesh {
420418
*/
421419
public override set material(value: Material) {
422420
this._material = value;
423-
this._material.backFaceCulling = true;
421+
this._material.backFaceCulling = false;
424422
this._material.cullBackFaces = false;
425423
value.resetDrawCache();
426424
}
@@ -1320,15 +1318,14 @@ export class GaussianSplattingMesh extends Mesh {
13201318
}
13211319

13221320
/**
1323-
* Loads a .splat Gaussian or .ply Splatting file asynchronously
1321+
* Loads a Gaussian or Splatting file asynchronously
13241322
* @param url path to the splat file to load
1323+
* @param scene optional scene it belongs to
13251324
* @returns a promise that resolves when the operation is complete
13261325
* @deprecated Please use SceneLoader.ImportMeshAsync instead
13271326
*/
1328-
public async loadFileAsync(url: string): Promise<void> {
1329-
const plyBuffer = await Tools.LoadFileAsync(url, true);
1330-
const splatsData: IPLYConversionBuffers = await (GaussianSplattingMesh.ConvertPLYWithSHToSplatAsync(plyBuffer) as any);
1331-
await this.updateDataAsync(splatsData.buffer, splatsData.sh);
1327+
public async loadFileAsync(url: string, scene?: Scene): Promise<void> {
1328+
await ImportMeshAsync(url, (scene || EngineStore.LastCreatedScene)!, { pluginOptions: { splat: { gaussianSplattingMesh: this } } });
13321329
}
13331330

13341331
/**
@@ -1470,15 +1467,16 @@ export class GaussianSplattingMesh extends Mesh {
14701467
covB: Uint16Array,
14711468
colorArray: Uint8Array,
14721469
minimum: Vector3,
1473-
maximum: Vector3
1470+
maximum: Vector3,
1471+
options: IUpdateOptions
14741472
): void {
14751473
const matrixRotation = TmpVectors.Matrix[0];
14761474
const matrixScale = TmpVectors.Matrix[1];
14771475
const quaternion = TmpVectors.Quaternion[0];
14781476
const covBSItemSize = this._useRGBACovariants ? 4 : 2;
14791477

14801478
const x = fBuffer[8 * index + 0];
1481-
const y = -fBuffer[8 * index + 1];
1479+
const y = fBuffer[8 * index + 1] * (options.flipY ? -1 : 1);
14821480
const z = fBuffer[8 * index + 2];
14831481

14841482
this._splatPositions![4 * index + 0] = x;
@@ -1582,7 +1580,7 @@ export class GaussianSplattingMesh extends Mesh {
15821580
}
15831581
}
15841582

1585-
private *_updateData(data: ArrayBuffer, isAsync: boolean, sh?: Uint8Array[]): Coroutine<void> {
1583+
private *_updateData(data: ArrayBuffer, isAsync: boolean, sh?: Uint8Array[], options: IUpdateOptions = { flipY: false }): Coroutine<void> {
15861584
// if a covariance texture is present, then it's not a creation but an update
15871585
if (!this._covariancesATexture) {
15881586
this._readyToDisplay = false;
@@ -1630,7 +1628,7 @@ export class GaussianSplattingMesh extends Mesh {
16301628
const updateLine = partIndex * lineCountUpdate;
16311629
const splatIndexBase = updateLine * textureSize.x;
16321630
for (let i = 0; i < textureLengthPerUpdate; i++) {
1633-
this._makeSplat(splatIndexBase + i, fBuffer, uBuffer, covA, covB, colorArray, minimum, maximum);
1631+
this._makeSplat(splatIndexBase + i, fBuffer, uBuffer, covA, covB, colorArray, minimum, maximum, options);
16341632
}
16351633
this._updateSubTextures(this._splatPositions, covA, covB, colorArray, updateLine, Math.min(lineCountUpdate, textureSize.y - updateLine));
16361634
// Update the binfo
@@ -1648,7 +1646,7 @@ export class GaussianSplattingMesh extends Mesh {
16481646
} else {
16491647
const paddedVertexCount = (vertexCount + 15) & ~0xf;
16501648
for (let i = 0; i < vertexCount; i++) {
1651-
this._makeSplat(i, fBuffer, uBuffer, covA, covB, colorArray, minimum, maximum);
1649+
this._makeSplat(i, fBuffer, uBuffer, covA, covB, colorArray, minimum, maximum, options);
16521650
if (isAsync && i % GaussianSplattingMesh._SplatBatchSize === 0) {
16531651
yield;
16541652
}
@@ -1662,6 +1660,7 @@ export class GaussianSplattingMesh extends Mesh {
16621660
// Update the binfo
16631661
this.getBoundingInfo().reConstruct(minimum, maximum, this.getWorldMatrix());
16641662
this.setEnabled(true);
1663+
this._sortIsDirty = true;
16651664
}
16661665
this._postToWorker(true);
16671666
}
@@ -1681,9 +1680,10 @@ export class GaussianSplattingMesh extends Mesh {
16811680
* Update data from GS (position, orientation, color, scaling)
16821681
* @param data array that contain all the datas
16831682
* @param sh optional array of uint8 array for SH data
1683+
* @param options optional informations on how to treat data
16841684
*/
1685-
public updateData(data: ArrayBuffer, sh?: Uint8Array[]): void {
1686-
runCoroutineSync(this._updateData(data, false, sh));
1685+
public updateData(data: ArrayBuffer, sh?: Uint8Array[], options: IUpdateOptions = { flipY: true }): void {
1686+
runCoroutineSync(this._updateData(data, false, sh, options));
16871687
}
16881688

16891689
/**

packages/dev/core/src/Rendering/depthRenderer.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { BindBonesParameters, BindMorphTargetParameters, PrepareDefinesAndAttrib
2121
import { ShaderLanguage } from "core/Materials/shaderLanguage";
2222
import { EffectFallbacks } from "core/Materials/effectFallbacks";
2323
import type { IEffectCreationOptions } from "core/Materials/effect";
24+
import type { GaussianSplattingMaterial } from "../Materials/GaussianSplatting/gaussianSplattingMaterial";
2425

2526
/**
2627
* This represents a depth renderer in Babylon.
@@ -245,7 +246,15 @@ export class DepthRenderer {
245246
if (this.isReady(subMesh, hardwareInstancedRendering) && camera) {
246247
subMesh._renderId = scene.getRenderId();
247248

248-
const renderingMaterial = effectiveMesh._internalAbstractMeshDataInfo._materialForRenderPass?.[engine.currentRenderPassId];
249+
let renderingMaterial = effectiveMesh._internalAbstractMeshDataInfo._materialForRenderPass?.[engine.currentRenderPassId];
250+
if (renderingMaterial === undefined && effectiveMesh.getClassName() === "GaussianSplattingMesh") {
251+
const gsMaterial = effectiveMesh.material! as GaussianSplattingMaterial;
252+
renderingMaterial = gsMaterial.makeDepthRenderingMaterial(this._scene, this._shaderLanguage);
253+
this.setMaterialForRendering(effectiveMesh, renderingMaterial);
254+
if (!renderingMaterial.isReady()) {
255+
return;
256+
}
257+
}
249258

250259
let drawWrapper = subMesh._getDrawWrapper();
251260
if (!drawWrapper && renderingMaterial) {

packages/dev/core/src/Shaders/ShadersInclude/gaussianSplatting.fx

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,7 @@ vec4 gaussianSplatting(vec2 meshPos, vec3 worldPos, vec2 scale, vec3 covA, vec3
244244
);
245245
}
246246

247-
mat3 invy = mat3(1, 0, 0, 0, -1, 0, 0, 0, 1);
248-
249-
mat3 T = invy * transpose(mat3(modelView)) * J;
247+
mat3 T = transpose(mat3(modelView)) * J;
250248
mat3 cov2d = transpose(T) * Vrk * T;
251249

252250
#if COMPENSATION

packages/dev/core/src/Shaders/gaussianSplatting.vertex.fx

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ uniform vec2 dataTextureSize;
1616
uniform vec2 focal;
1717
uniform float kernelSize;
1818
uniform vec3 eyePosition;
19-
uniform vec3 viewDirectionFactor;
2019
uniform float alpha;
2120

2221
uniform sampler2D covariancesATexture;
@@ -55,9 +54,8 @@ void main () {
5554
mat3 worldRot = mat3(world);
5655
mat3 normWorldRot = inverseMat3(worldRot);
5756

58-
vec3 dir = normalize(normWorldRot * (worldPos.xyz - eyePosition));
59-
dir *= viewDirectionFactor;
60-
vColor.xyz = splat.color.xyz + computeSH(splat, dir);
57+
vec3 eyeToSplatLocalSpace = normalize(normWorldRot * (worldPos.xyz - eyePosition));
58+
vColor.xyz = splat.color.xyz + computeSH(splat, eyeToSplatLocalSpace);
6159
#endif
6260
vColor.w *= alpha;
6361

packages/dev/core/src/Shaders/gaussianSplattingDepth.fragment.fx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@ precision highp float;
33
varying vec2 vPosition;
44
varying vec4 vColor;
55

6+
#ifdef DEPTH_RENDER
7+
varying float vDepthMetric;
8+
#endif
9+
610
void main(void) {
711
float A = -dot(vPosition, vPosition);
8-
912
#if defined(SM_SOFTTRANSPARENTSHADOW) && SM_SOFTTRANSPARENTSHADOW == 1
1013
float alpha = exp(A) * vColor.a;
1114
if (A < -4.) discard;
1215
#else
13-
if (A < -1.) discard;
16+
if (A < -vColor.a) discard;
17+
#endif
18+
#ifdef DEPTH_RENDER
19+
gl_FragColor = vec4(vDepthMetric, 0.0, 0.0, 1.0);
1420
#endif
1521
}

0 commit comments

Comments
 (0)