diff --git a/examples/hello-world/index.html b/examples/hello-world/index.html index 6a8c61ab..6f89ad8a 100644 --- a/examples/hello-world/index.html +++ b/examples/hello-world/index.html @@ -24,7 +24,6 @@ - - - - diff --git a/examples/stochastic/index.html b/examples/stochastic/index.html deleted file mode 100644 index df52da07..00000000 --- a/examples/stochastic/index.html +++ /dev/null @@ -1,79 +0,0 @@ - - - - - - - Spark • Stochastic Rendering - - - - - - - - - diff --git a/src/OldSparkRenderer.ts b/src/OldSparkRenderer.ts deleted file mode 100644 index 2a3224c6..00000000 --- a/src/OldSparkRenderer.ts +++ /dev/null @@ -1,1083 +0,0 @@ -import * as THREE from "three"; - -import { - OldSparkViewpoint, - type OldSparkViewpointOptions, -} from "./OldSparkViewpoint"; -import { OldSplatAccumulator } from "./OldSplatAccumulator"; -import { OldSplatGeometry } from "./OldSplatGeometry"; -import { PackedSplats } from "./PackedSplats"; -import { RgbaArray } from "./RgbaArray"; -import type { GeneratorMapping } from "./SplatAccumulator"; -import { SplatEdit } from "./SplatEdit"; -import { SplatGenerator, SplatModifier } from "./SplatGenerator"; -import { SplatMesh } from "./SplatMesh"; -import { - DEFAULT_SPLAT_ENCODING, - LN_SCALE_MAX, - LN_SCALE_MIN, - type SplatEncoding, -} from "./defines"; -import { - DynoVec3, - DynoVec4, - Gsplat, - TPackedSplats, - dynoBlock, - readPackedSplat, - transformGsplat, -} from "./dyno"; -import { getShaders } from "./shaders"; -import { - averagePositions, - averageQuaternions, - cloneClock, - withinCoorientDist, -} from "./utils"; - -// SparkRenderer aggregates splats from multiple generators into a single -// accumulated collection per frame. In normal operation we only need a -// maximum of 3 accumulators: One currently being viewed, one currently -// being sorted, and one more for generating the next frame. Accumulators -// must be "released" by each viewpoint using it, so in unusual cases -// such as slow render-outs, we may want to allow more than 3 so the -// pipeline can continue generating new frames, but we limit to a maximum -// of 5 to avoid excessive memory usage. -const MAX_ACCUMULATORS = 5; - -export type OldSparkRendererOptions = { - /** - * Pass in your THREE.WebGLRenderer instance so Spark can perform work - * outside the usual render loop. Should be created with antialias: false - * (default setting) as WebGL anti-aliasing doesn't improve Gaussian Splatting - * rendering and significantly reduces performance. - */ - renderer: THREE.WebGLRenderer; - /** - * Whether to use premultiplied alpha when accumulating splat RGB - * @default true - */ - premultipliedAlpha?: boolean; - /** - * Pass in a THREE.Clock to synchronize time-based effects across different - * systems. Alternatively, you can set the SparkRenderer properties time and - * deltaTime directly. (default: new THREE.Clock) - */ - clock?: THREE.Clock; - /** - * Controls whether to check and automatically update Gsplat collection after - * each frame render. - * @default true - */ - autoUpdate?: boolean; - /** - * Controls whether to update the Gsplats before or after rendering. For WebXR - * this must be false in order to complete rendering as soon as possible. - * @default false - */ - preUpdate?: boolean; - /** - * Distance threshold for SparkRenderer movement triggering a Gsplat update at - * the new origin. - * @default 1.0 - */ - originDistance?: number; - /** - * Maximum standard deviations from the center to render Gaussians. Values - * Math.sqrt(5)..Math.sqrt(8) produce good results and can be tweaked for - * performance. - * @default Math.sqrt(8) - */ - maxStdDev?: number; - /** - * Minimum pixel radius for splat rendering. - * @default 0.0 - */ - minPixelRadius?: number; - /** - * Maximum pixel radius for splat rendering. - * @default 512.0 - */ - maxPixelRadius?: number; - /** - * Minimum alpha value for splat rendering. - * @default 0.5 * (1.0 / 255.0) - */ - minAlpha?: number; - /** - * Enable 2D Gaussian splatting rendering ability. When this mode is enabled, - * any scale x/y/z component that is exactly 0 (minimum quantized value) results - * in the other two non-0 axis being interpreted as an oriented 2D Gaussian Splat, - * rather instead of the usual projected 3DGS Z-slice. When reading PLY files, - * scale values less than e^-30 will be interpreted as 0. - * @default false - */ - enable2DGS?: boolean; - /** - * Scalar value to add to 2D splat covariance diagonal, effectively blurring + - * enlarging splats. In scenes trained without the Gsplat anti-aliasing tweak - * this value was typically 0.3, but with anti-aliasing it is 0.0 - * @default 0.0 - */ - preBlurAmount?: number; - /** - * Scalar value to add to 2D splat covarianve diagonal, with opacity adjustment - * to correctly account for "blurring" when anti-aliasing. Typically 0.3 - * (equivalent to approx 0.5 pixel radius) in scenes trained with anti-aliasing. - */ - blurAmount?: number; - /** - * Depth-of-field distance to focal plane - */ - focalDistance?: number; - /** - * Full-width angle of aperture opening (in radians), 0.0 to disable - * @default 0.0 - */ - apertureAngle?: number; - /** - * Modulate Gaussian kernel falloff. 0 means "no falloff, flat shading", - * while 1 is the normal Gaussian kernel. - * @default 1.0 - */ - falloff?: number; - /** - * X/Y clipping boundary factor for Gsplat centers against view frustum. - * 1.0 clips any centers that are exactly out of bounds, while 1.4 clips - * centers that are 40% beyond the bounds. - * @default 1.4 - */ - clipXY?: number; - /** - * Parameter to adjust projected splat scale calculation to match other renderers, - * similar to the same parameter in the MKellogg 3DGS renderer. Higher values will - * tend to sharpen the splats. A value 2.0 can be used to match the behavior of - * the PlayCanvas renderer. - * @default 1.0 - */ - focalAdjustment?: number; - /** - * Configures the SparkViewpointOptions for the default SparkViewpoint - * associated with this SparkRenderer. Notable option: sortRadial (sort by - * radial distance or Z-depth) - */ - view?: OldSparkViewpointOptions; - /** - * Override the default splat encoding ranges for the PackedSplats. - * (default: undefined) - */ - splatEncoding?: SplatEncoding; -}; - -export class OldSparkRenderer extends THREE.Mesh { - renderer: THREE.WebGLRenderer; - premultipliedAlpha: boolean; - material: THREE.ShaderMaterial; - uniforms: ReturnType; - - autoUpdate: boolean; - preUpdate: boolean; - needsUpdate: boolean; - originDistance: number; - maxStdDev: number; - minPixelRadius: number; - maxPixelRadius: number; - minAlpha: number; - enable2DGS: boolean; - preBlurAmount: number; - blurAmount: number; - focalDistance: number; - apertureAngle: number; - falloff: number; - clipXY: number; - focalAdjustment: number; - splatEncoding: SplatEncoding; - - splatTexture: null | { - enable?: boolean; - texture?: THREE.Data3DTexture; - multiply?: THREE.Matrix2; - add?: THREE.Vector2; - near?: number; - far?: number; - mid?: number; - } = null; - - time?: number; - deltaTime?: number; - clock: THREE.Clock; - - // Latest Gsplat collection being displayed - active: OldSplatAccumulator; - // Free list of accumulators for reuse - private freeAccumulators: OldSplatAccumulator[]; - // Total number of accumulators currently allocated - private accumulatorCount: number; - // Default SparkViewpoint used for rendering to the canvas - defaultView: OldSparkViewpoint; - // List of SparkViewpoints with autoUpdate enabled - autoViewpoints: OldSparkViewpoint[] = []; - - // Dynos used to transform Gsplats to the accumulator coordinate system - private rotateToAccumulator = new DynoVec4({ value: new THREE.Quaternion() }); - private translateToAccumulator = new DynoVec3({ value: new THREE.Vector3() }); - private modifier: SplatModifier; - - // Last rendered frame number so we know when we're rendering a new frame - private lastFrame = -1; - // Last update timestamp to compute deltaTime - private lastUpdateTime: number | null = null; - // List of cameras used for the current viewpoint (for WebXR) - private defaultCameras: THREE.Matrix4[] = []; - private lastStochastic: boolean | null = null; - - // Should be set to the defaultView, but can be temporarily changed to another - // viewpoint using prepareViewpoint() for rendering from a different viewpoint. - viewpoint: OldSparkViewpoint; - - // Holds data needed to perform a scheduled Gsplat update. - private pendingUpdate = { - scene: null as THREE.Scene | null, - originToWorld: new THREE.Matrix4(), - timeoutId: -1, - }; - - // Internal SparkViewpoint used for environment map rendering. - private envViewpoint: OldSparkViewpoint | null = null; - - // Data and buffers used for environment map rendering - private static cubeRender: { - target: THREE.WebGLCubeRenderTarget; - camera: THREE.CubeCamera; - near: number; - far: number; - } | null = null; - private static pmrem: THREE.PMREMGenerator | null = null; - - static EMPTY_SPLAT_TEXTURE = new THREE.Data3DTexture(); - - constructor(options: OldSparkRendererOptions) { - const uniforms = OldSparkRenderer.makeUniforms(); - const shaders = getShaders(); - const premultipliedAlpha = options.premultipliedAlpha ?? true; - const material = new THREE.ShaderMaterial({ - glslVersion: THREE.GLSL3, - vertexShader: shaders.oldSplatVertex, - fragmentShader: shaders.oldSplatFragment, - uniforms, - premultipliedAlpha, - transparent: true, - depthTest: true, - depthWrite: false, - side: THREE.DoubleSide, - }); - - super(EMPTY_GEOMETRY, material); - // Disable frustum culling because we want to always draw them all - // and cull Gsplats individually in the shader - this.frustumCulled = false; - - this.renderer = options.renderer; - this.material = material; - this.uniforms = uniforms; - - // Create a Gsplat modifier that takes the output of any SplatGenerator - // and transforms them into the accumulator's coordinate system - const modifier = dynoBlock( - { gsplat: Gsplat }, - { gsplat: Gsplat }, - ({ gsplat }) => { - if (!gsplat) { - throw new Error("gsplat not defined"); - } - gsplat = transformGsplat(gsplat, { - rotate: this.rotateToAccumulator, - translate: this.translateToAccumulator, - }); - return { gsplat }; - }, - ); - this.modifier = new SplatModifier(modifier); - - this.premultipliedAlpha = premultipliedAlpha; - this.autoUpdate = options.autoUpdate ?? true; - this.preUpdate = options.preUpdate ?? false; - this.needsUpdate = false; - this.originDistance = options.originDistance ?? 1; - this.maxStdDev = options.maxStdDev ?? Math.sqrt(8.0); - this.minPixelRadius = options.minPixelRadius ?? 0.0; - this.maxPixelRadius = options.maxPixelRadius ?? 512.0; - this.minAlpha = options.minAlpha ?? 0.5 * (1.0 / 255.0); - this.enable2DGS = options.enable2DGS ?? false; - this.preBlurAmount = options.preBlurAmount ?? 0.0; - this.blurAmount = options.blurAmount ?? 0.3; - this.focalDistance = options.focalDistance ?? 0.0; - this.apertureAngle = options.apertureAngle ?? 0.0; - this.falloff = options.falloff ?? 1.0; - this.clipXY = options.clipXY ?? 1.4; - this.focalAdjustment = options.focalAdjustment ?? 1.0; - this.splatEncoding = options.splatEncoding ?? { ...DEFAULT_SPLAT_ENCODING }; - - this.active = new OldSplatAccumulator(); - this.active.refCount = 1; - this.accumulatorCount = 1; - this.freeAccumulators = []; - // Start with the minimum of 2 total accumulators - for (let count = 0; count < 1; ++count) { - this.freeAccumulators.push(new OldSplatAccumulator()); - this.accumulatorCount += 1; - } - - // Create a default SparkViewpoint that is used when we call render() - // on the scene and has the sorted Gsplat collection from that viewpoint. - this.defaultView = new OldSparkViewpoint({ - ...options.view, - autoUpdate: true, - spark: this, - }); - this.viewpoint = this.defaultView; - this.prepareViewpoint(this.viewpoint); - - this.clock = options.clock ? cloneClock(options.clock) : new THREE.Clock(); - } - - static makeUniforms() { - // Create uniforms used for Gsplat vertex and fragment shaders - const uniforms = { - // Size of render viewport in pixels - renderSize: { value: new THREE.Vector2() }, - // Near and far plane distances - near: { value: 0.1 }, - far: { value: 1000.0 }, - // Total number of Gsplats in packedSplats to render - numSplats: { value: 0 }, - // SplatAccumulator to view transformation quaternion - renderToViewQuat: { value: new THREE.Quaternion() }, - // SplatAccumulator to view transformation translation - renderToViewPos: { value: new THREE.Vector3() }, - // Maximum distance (in stddevs) from Gsplat center to render - maxStdDev: { value: 1.0 }, - // Minimum pixel radius for splat rendering - minPixelRadius: { value: 0.0 }, - // Maximum pixel radius for splat rendering - maxPixelRadius: { value: 512.0 }, - // Minimum alpha value for splat rendering - minAlpha: { value: 0.5 * (1.0 / 255.0) }, - // Enable stochastic splat rendering - stochastic: { value: false }, - // Enable interpreting 0-thickness Gsplats as 2DGS - enable2DGS: { value: false }, - // Add to projected 2D splat covariance diagonal (thickens and brightens) - preBlurAmount: { value: 0.0 }, - // Add to 2D splat covariance diagonal and adjust opacity (anti-aliasing) - blurAmount: { value: 0.3 }, - // Depth-of-field distance to focal plane - focalDistance: { value: 0.0 }, - // Full-width angle of aperture opening (in radians) - apertureAngle: { value: 0.0 }, - // Modulate Gaussian kernal falloff. 0 means "no falloff, flat shading", - // 1 is normal e^-x^2 falloff. - falloff: { value: 1.0 }, - // Clip Gsplats that are clipXY times beyond the +-1 frustum bounds - clipXY: { value: 1.4 }, - // Debug renderSize scale factor - focalAdjustment: { value: 1.0 }, - // Enable splat texture rendering - splatTexEnable: { value: false }, - // Splat texture to render - splatTexture: { type: "t", value: OldSparkRenderer.EMPTY_SPLAT_TEXTURE }, - // Splat texture UV transform (multiply) - splatTexMul: { value: new THREE.Matrix2() }, - // Splat texture UV transform (add) - splatTexAdd: { value: new THREE.Vector2() }, - // Splat texture near plane distance - splatTexNear: { value: 0.1 }, - // Splat texture far plane distance - splatTexFar: { value: 1000.0 }, - // Splat texture mid plane distance, or 0.0 to disable - splatTexMid: { value: 0.0 }, - // Gsplat collection to render - packedSplats: { type: "t", value: PackedSplats.getEmptyArray }, - // Splat encoding ranges - rgbMinMaxLnScaleMinMax: { value: new THREE.Vector4() }, - // Time in seconds for time-based effects - time: { value: 0 }, - // Delta time in seconds since last frame - deltaTime: { value: 0 }, - // Whether to encode Gsplat with linear RGB (for environment mapping) - encodeLinear: { value: false }, - // Debug flag that alternates each frame - debugFlag: { value: false }, - }; - return uniforms; - } - - private canAllocAccumulator(): boolean { - // Returns true if can allocate an accumulator immediately - return ( - this.freeAccumulators.length > 0 || - this.accumulatorCount < MAX_ACCUMULATORS - ); - } - - private maybeAllocAccumulator(): OldSplatAccumulator | null { - // Allocate an accumulator immediately if possible, else return null - let accumulator = this.freeAccumulators.pop(); - if (accumulator === undefined) { - if (this.accumulatorCount >= MAX_ACCUMULATORS) { - return null; - } - accumulator = new OldSplatAccumulator(); - this.accumulatorCount += 1; - } - accumulator.refCount = 1; - return accumulator; - } - - releaseAccumulator(accumulator: OldSplatAccumulator) { - // Decrement reference count and recycle if no longer in use - accumulator.refCount -= 1; - if (accumulator.refCount === 0) { - this.freeAccumulators.push(accumulator); - } - } - - newViewpoint(options: OldSparkViewpointOptions) { - // Create a new SparkViewpoint for this SparkRenderer. - // Note that every SparkRenderer has an initial spark.defaultView: SparkViewpoint - // from construction, which is used for the default canvas render loop. - // Calling this method allows you to create additional viewpoints, which can be - // updated automatically each frame (performing Gsplat sorting every time there - // is an update), or updated on-demand for controlled rendering for video render - // or similar applications. - return new OldSparkViewpoint({ ...options, spark: this }); - } - - onBeforeRender( - renderer: THREE.WebGLRenderer, - scene: THREE.Scene, - camera: THREE.Camera, - ) { - // throw new Error("onBeforeRender disabled in SparkRenderer"); - - // Called by Three.js before rendering this SparkRenderer. - // At this point we can't modify the geometry or material, all these must - // be set in the scene already before this is called. Update the uniforms - // to render the Gsplats from the current active viewpoint. - const time = this.time ?? this.clock.getElapsedTime(); - const deltaTime = time - (this.viewpoint.lastTime ?? time); - this.viewpoint.lastTime = time; - - const frame = renderer.info.render.frame; - const isNewFrame = frame !== this.lastFrame; - this.lastFrame = frame; - - const viewpoint = this.viewpoint; - if (viewpoint === this.defaultView) { - // When rendering is triggered on the default viewpoint, - // perform automatic updates. - if (isNewFrame) { - if (!renderer.xr.isPresenting) { - // Non-WebXR mode, just a single camera - this.defaultView.viewToWorld = camera.matrixWorld.clone(); - this.defaultCameras = [this.defaultView.viewToWorld]; - } else { - // In WebXR mode we are called multiple times, once for each eye, - // so use their average to compute the sort center. - const cameras = renderer.xr.getCamera().cameras; - this.defaultCameras = cameras.map((camera) => camera.matrixWorld); - this.defaultView.viewToWorld = - averageOriginToWorlds(this.defaultCameras) ?? new THREE.Matrix4(); - } - } - - if (this.autoUpdate) { - this.update({ scene, viewToWorld: this.defaultView.viewToWorld }); - } - } - - // Update uniforms for rendering - - if (isNewFrame) { - // Keep these uniforms the same for both eyes if in WebXR - if (this.material.premultipliedAlpha !== this.premultipliedAlpha) { - this.material.premultipliedAlpha = this.premultipliedAlpha; - this.material.needsUpdate = true; - } - this.uniforms.time.value = time; - this.uniforms.deltaTime.value = deltaTime; - // Alternating debug flag that can aid in visual debugging - this.uniforms.debugFlag.value = (performance.now() / 1000.0) % 2.0 < 1.0; - - if (viewpoint.display && viewpoint.stochastic) { - (this.geometry as OldSplatGeometry).instanceCount = - this.uniforms.numSplats.value; - } - } - - if (viewpoint.target) { - // Rendering to a texture target, so its dimensions - this.uniforms.renderSize.value.set( - viewpoint.target.width, - viewpoint.target.height, - ); - } else { - // Rendering to the canvas or WebXR - const renderSize = renderer.getDrawingBufferSize( - this.uniforms.renderSize.value, - ); - if (renderSize.x === 1 && renderSize.y === 1) { - // WebXR mode on Apple Vision Pro returns 1x1 when presenting. - // Use a different means to figure out the render size. - const baseLayer = renderer.xr.getSession()?.renderState.baseLayer; - if (baseLayer) { - renderSize.x = baseLayer.framebufferWidth; - renderSize.y = baseLayer.framebufferHeight; - } - } - } - - // Update uniforms from instance properties - const typedCamera = camera as - | THREE.PerspectiveCamera - | THREE.OrthographicCamera; - this.uniforms.near.value = typedCamera.near; - this.uniforms.far.value = typedCamera.far; - this.uniforms.encodeLinear.value = viewpoint.encodeLinear; - this.uniforms.maxStdDev.value = this.maxStdDev; - this.uniforms.minPixelRadius.value = this.minPixelRadius; - this.uniforms.maxPixelRadius.value = this.maxPixelRadius; - this.uniforms.minAlpha.value = this.minAlpha; - this.uniforms.stochastic.value = viewpoint.stochastic; - this.uniforms.enable2DGS.value = this.enable2DGS; - this.uniforms.preBlurAmount.value = this.preBlurAmount; - this.uniforms.blurAmount.value = this.blurAmount; - this.uniforms.focalDistance.value = this.focalDistance; - this.uniforms.apertureAngle.value = this.apertureAngle; - this.uniforms.falloff.value = this.falloff; - this.uniforms.clipXY.value = this.clipXY; - this.uniforms.focalAdjustment.value = this.focalAdjustment; - - if (this.lastStochastic !== !viewpoint.stochastic) { - this.lastStochastic = !viewpoint.stochastic; - this.material.transparent = !viewpoint.stochastic; - this.material.depthWrite = viewpoint.stochastic; - this.material.needsUpdate = true; - } - - if (this.splatTexture) { - const { enable, texture, multiply, add, near, far, mid } = - this.splatTexture; - if (enable && texture) { - this.uniforms.splatTexEnable.value = true; - this.uniforms.splatTexture.value = texture; - if (multiply) { - this.uniforms.splatTexMul.value.fromArray(multiply.elements); - } else { - this.uniforms.splatTexMul.value.set( - 0.5 / this.maxStdDev, - 0, - 0, - 0.5 / this.maxStdDev, - ); - } - this.uniforms.splatTexAdd.value.set(add?.x ?? 0.5, add?.y ?? 0.5); - this.uniforms.splatTexNear.value = near ?? this.uniforms.near.value; - this.uniforms.splatTexFar.value = far ?? this.uniforms.far.value; - this.uniforms.splatTexMid.value = mid ?? 0.0; - } else { - this.uniforms.splatTexEnable.value = false; - this.uniforms.splatTexture.value = OldSparkRenderer.EMPTY_SPLAT_TEXTURE; - } - } else { - this.uniforms.splatTexEnable.value = false; - this.uniforms.splatTexture.value = OldSparkRenderer.EMPTY_SPLAT_TEXTURE; - } - - // Calculate the transform from the accumulator to the current camera - const accumToWorld = - viewpoint.display?.accumulator.toWorld ?? new THREE.Matrix4(); - const worldToCamera = camera.matrixWorld.clone().invert(); - const originToCamera = accumToWorld.clone().premultiply(worldToCamera); - originToCamera.decompose( - this.uniforms.renderToViewPos.value, - this.uniforms.renderToViewQuat.value, - new THREE.Vector3(), - ); - } - - // Update the uniforms for the given viewpoint. - // Note that the client expects to be able to call render() at any point - // to update the canvas, so we must switch the viewpoint back to - // defaultView when we're finished. - prepareViewpoint(viewpoint?: OldSparkViewpoint) { - this.viewpoint = viewpoint ?? this.viewpoint; - - if (this.viewpoint.display) { - const { accumulator, geometry } = this.viewpoint.display; - this.uniforms.numSplats.value = accumulator.splats.numSplats; - this.uniforms.packedSplats.value = accumulator.splats.getTexture(); - this.uniforms.rgbMinMaxLnScaleMinMax.value.set( - accumulator.splats.splatEncoding?.rgbMin ?? 0.0, - accumulator.splats.splatEncoding?.rgbMax ?? 1.0, - accumulator.splats.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, - accumulator.splats.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, - ); - this.geometry = geometry; - this.material.transparent = !this.viewpoint.stochastic; - this.material.depthWrite = this.viewpoint.stochastic; - this.material.needsUpdate = true; - } else { - // No Gsplats to display for this viewpoint yet - this.uniforms.numSplats.value = 0; - this.uniforms.packedSplats.value = PackedSplats.getEmptyArray; - this.geometry = EMPTY_GEOMETRY; - } - } - - // If spark.autoUpdate is false then you must manually call - // spark.update({ scene }) to have the scene Gsplats be re-generated. - update({ - scene, - viewToWorld, - }: { scene: THREE.Scene; viewToWorld?: THREE.Matrix4 }) { - // Compute the transform for the SparkRenderer to use as origin - // for Gsplat generation and accumulation. - const originToWorld = this.matrixWorld; - - // Either do the update now, or in the next "tick" depending on preUpdate - if (this.preUpdate) { - this.updateInternal({ - scene, - originToWorld: originToWorld.clone(), - viewToWorld, - }); - } else { - // Pass the update parameters to be performed on the next tick - this.pendingUpdate.scene = scene; - this.pendingUpdate.originToWorld.copy(originToWorld); - - // Schedule a timeout if there isn't one already - if (this.pendingUpdate.timeoutId === -1) { - this.pendingUpdate.timeoutId = setTimeout(() => { - const { scene, originToWorld } = this.pendingUpdate; - this.pendingUpdate.scene = null; - this.pendingUpdate.timeoutId = -1; - const updated = this.updateInternal({ - scene: scene as THREE.Scene, - originToWorld, - viewToWorld, - }); - - if (updated) { - // Flush to encourage eager execution - const gl = this.renderer.getContext() as WebGL2RenderingContext; - gl.flush(); - } - }, 1); - } - } - } - - updateInternal({ - scene, - originToWorld, - viewToWorld, - }: { - scene: THREE.Scene; - originToWorld?: THREE.Matrix4; - viewToWorld?: THREE.Matrix4; - }): boolean { - if (!this.canAllocAccumulator()) { - // We don't have any available accumulators because of sorting - // back pressure, so don't update this time but try again next time. - // Signal update not attempted. - return false; - } - - // Figure out the frame of the SparkRenderer and current view - if (!originToWorld) { - originToWorld = this.active.toWorld; - } - viewToWorld = viewToWorld ?? originToWorld.clone(); - - const time = this.time ?? this.clock.getElapsedTime(); - const deltaTime = time - (this.lastUpdateTime ?? time); - this.lastUpdateTime = time; - - // Create a lookup from last active SplatGenerator to Gsplat mapping record - const activeMapping = this.active.mapping.reduce((map, record) => { - map.set(record.node, record); - return map; - }, new Map()); - - // Traverse visible scene to find all SplatGenerators and global SplatEdits - const { generators, visibleGenerators, globalEdits } = - this.compileScene(scene); - - // Let all SplatGenerators run their frameUpdate() method - for (const object of generators) { - object.frameUpdate?.({ - renderer: this.renderer, - object, - time, - deltaTime, - viewToWorld, - globalEdits, - }); - } - - const visibleGenHash = new Set(visibleGenerators.map((g) => g.uuid)); - - // Make sure we have new version numbers for any objects with either - // generator or numSplats that have changed since the last frame. - for (const object of generators) { - const current = activeMapping.get(object); - const isVisible = object.generator && visibleGenHash.has(object.uuid); - const numSplats = isVisible ? object.numSplats : 0; - if ( - this.needsUpdate || - object.generator !== current?.generator || - numSplats !== current?.count - ) { - object.updateVersion(); - } - } - - // Check if the origin is within the maximum allowed distance before - // we trigger an update. - const originUpdate = !withinCoorientDist({ - matrix1: originToWorld, - matrix2: this.active.toWorld, - maxDistance: this.originDistance, - }); - - // Check if we need any update at all - const needsUpdate = - this.needsUpdate || - originUpdate || - generators.length !== activeMapping.size || - generators.some((g) => g.version !== activeMapping.get(g)?.version); - this.needsUpdate = false; - - let accumulator: OldSplatAccumulator | null = null; - if (needsUpdate) { - // Need to update, so allocate an accumulator - accumulator = this.maybeAllocAccumulator(); - if (!accumulator) { - // This should never happen since we checked canAllocAccumulator() above - throw new Error("Unreachable"); - } - - // Compute whether our view frame has changed enough to warrant - // doing a Gsplat sort. Check both distance epsilon and - // minimum co-orientation (dot product of quaternions) - const originChanged = !withinCoorientDist({ - matrix1: originToWorld, - matrix2: accumulator.toWorld, - maxDistance: 0.00001, - minCoorient: 0.99999, - }); - - // Compute an ordering of the generators with the rough goal - // of keeping unchanging generators near the front to minimize - // the number of Gsplats that need to be regenerated. - const sorted = visibleGenerators - .map((g, gIndex): [number, number, SplatGenerator] => { - const lastGen = activeMapping.get(g); - // If no previous generator, sort by absolute version, which will - // tend to push frequently updated generators toward the end - return !lastGen - ? [Number.POSITIVE_INFINITY, g.version, g] - : // Sort by version deltas then by previous ordering in the mapping, - // attempting to keep unchanging generators near the front - // to improve our chances of avoiding a re-generation. - [g.version - lastGen.version, lastGen.base, g]; - }) - .sort((a, b) => { - // Sort by first then second element of the tuple - if (a[0] !== b[0]) { - return a[0] - b[0]; - } - return a[1] - b[1]; - }); - const genOrder = sorted.map(([_version, _seq, g]) => g); - - // Compute sequential layout of generated splats - const splatCounts = genOrder.map((g) => g.numSplats); - const { maxSplats, mapping } = - accumulator.splats.generateMapping(splatCounts); - const newGenerators = genOrder.map((node, gIndex) => { - const { base, count } = mapping[gIndex]; - return { - node, - generator: node.generator, - version: node.version, - base, - count, - }; - }); - - // Compute worldToAccumulator origin transform (no scale) - originToWorld - .clone() - .invert() - .decompose( - this.translateToAccumulator.value, - this.rotateToAccumulator.value, - new THREE.Vector3(), - ); - - // Generate the Gsplats according to the mapping that need updating - accumulator.ensureGenerate(maxSplats); - accumulator.splats.splatEncoding = { ...this.splatEncoding }; - const generated = accumulator.generateSplats({ - renderer: this.renderer, - modifier: this.modifier, - generators: newGenerators, - forceUpdate: originChanged, - originToWorld, - }); - - // Update splat version number - accumulator.splatsVersion = this.active.splatsVersion + 1; - // Increment the mapping version if the mapping isn't identical to before - const hasCorrespondence = accumulator.hasCorrespondence(this.active); - accumulator.mappingVersion = - this.active.mappingVersion + (hasCorrespondence ? 0 : 1); - - // Release the old accumulator and make the new one active - this.releaseAccumulator(this.active); - this.active = accumulator; - this.prepareViewpoint(); - } - - // Let the system breath before potentially triggering sorts - setTimeout(() => { - // Notify all auto-updating viewpoints that we updated the Gsplats - for (const view of this.autoViewpoints) { - view.autoPoll({ accumulator: accumulator ?? undefined }); - } - }, 1); - - // Signal update was performed - return true; - } - - private compileScene(scene: THREE.Scene): { - generators: SplatGenerator[]; - visibleGenerators: SplatGenerator[]; - globalEdits: SplatEdit[]; - } { - // Take a snapshot of the SplatGenerators and SplatEdits in the scene - // to be used to run an update. - const generators: SplatGenerator[] = []; - // Collect all SplatGenerators, even if not visible, because we want to - // be able to call their update functions every frame. - scene.traverse((node) => { - if (node instanceof SplatGenerator) { - generators.push(node); - } - }); - - const visibleGenerators: SplatGenerator[] = []; - scene.traverseVisible((node) => { - if (node instanceof SplatGenerator) { - visibleGenerators.push(node); - } - }); - - const globalEdits = new Set(); - scene.traverseVisible((node) => { - if (node instanceof SplatEdit) { - let ancestor = node.parent; - while (ancestor != null && !(ancestor instanceof SplatMesh)) { - ancestor = ancestor.parent; - } - if (ancestor == null) { - // Not part of a SplatMesh so it's a global edit - globalEdits.add(node); - } - } - }); - return { - generators, - visibleGenerators, - globalEdits: Array.from(globalEdits), - }; - } - - // Renders out the scene to an environment map that can be used for - // Image-based lighting or similar applications. First optionally updates Gsplats, - // sorts them with respect to the provided worldCenter, renders 6 cube faces, - // then pre-filters them using THREE.PMREMGenerator and returns a THREE.Texture - // that can assigned directly to a THREE.MeshStandardMaterial.envMap property. - async renderEnvMap({ - renderer, - scene, - worldCenter, - size = 256, - near = 0.1, - far = 1000, - hideObjects = [], - update = false, - }: { - renderer?: THREE.WebGLRenderer; - scene: THREE.Scene; - worldCenter: THREE.Vector3; - size?: number; - near?: number; - far?: number; - hideObjects?: THREE.Object3D[]; - update?: boolean; - }): Promise { - if (!this.envViewpoint) { - this.envViewpoint = this.newViewpoint({ sort360: true }); - } - if ( - !OldSparkRenderer.cubeRender || - OldSparkRenderer.cubeRender.target.width !== size || - OldSparkRenderer.cubeRender.near !== near || - OldSparkRenderer.cubeRender.far !== far - ) { - if (OldSparkRenderer.cubeRender) { - OldSparkRenderer.cubeRender.target.dispose(); - } - const target = new THREE.WebGLCubeRenderTarget(size, { - format: THREE.RGBAFormat, - generateMipmaps: true, - minFilter: THREE.LinearMipMapLinearFilter, - }); - const camera = new THREE.CubeCamera(near, far, target); - OldSparkRenderer.cubeRender = { target, camera, near, far }; - } - - if (!OldSparkRenderer.pmrem) { - OldSparkRenderer.pmrem = new THREE.PMREMGenerator( - renderer ?? this.renderer, - ); - } - - // Prepare the viewpoint, sorting Gsplats for this view origin. - const viewToWorld = new THREE.Matrix4().setPosition(worldCenter); - await this.envViewpoint?.prepare({ scene, viewToWorld, update }); - - const { target, camera } = OldSparkRenderer.cubeRender; - camera.position.copy(worldCenter); - - // Save the visibility state of objects we want to hide before render - const objectVisibility = new Map(); - for (const object of hideObjects) { - objectVisibility.set(object, object.visible); - object.visible = false; - } - - // Update the CubeCamera, which performs 6 cube face renders - this.prepareViewpoint(this.envViewpoint); - camera.update(renderer ?? this.renderer, scene); - - // Restore viewpoint to default and object visibility - this.prepareViewpoint(this.defaultView); - for (const [object, visible] of objectVisibility.entries()) { - object.visible = visible; - } - - // Pre-filter the cube map using THREE.PMREMGenerator - return OldSparkRenderer.pmrem?.fromCubemap(target.texture).texture; - } - - // Utility function to recursively set the envMap property for any - // THREE.MeshStandardMaterial within the subtree of root. - recurseSetEnvMap(root: THREE.Object3D, envMap: THREE.Texture) { - root.traverse((node) => { - if (node instanceof THREE.Mesh) { - if (Array.isArray(node.material)) { - for (const material of node.material) { - if (material instanceof THREE.MeshStandardMaterial) { - material.envMap = envMap; - } - } - } else { - if (node.material instanceof THREE.MeshStandardMaterial) { - node.material.envMap = envMap; - } - } - } - }); - } - - // Utility function that helps extract the Gsplat RGBA values from a - // SplatGenerator, including the result of any real-time RGBA SDF edits applied - // to a SplatMesh. This effectively "bakes" any computed RGBA values, which can - // now be used as a pipeline input via SplatMesh.splatRgba to inject these - // baked values into the Gsplat data. - getRgba({ - generator, - rgba, - }: { generator: SplatGenerator; rgba?: RgbaArray }): RgbaArray { - const mapping = this.active.mapping.find(({ node }) => node === generator); - if (!mapping) { - throw new Error("Generator not found"); - } - - rgba = rgba ?? new RgbaArray(); - rgba.fromPackedSplats({ - packedSplats: this.active.splats, - base: mapping.base, - count: mapping.count, - renderer: this.renderer, - }); - return rgba; - } - - // Utility function that builds on getRgba({ generator }) and additionally - // reads back the RGBA values to the CPU in a Uint8Array with packed RGBA - // in that byte order. - async readRgba({ - generator, - rgba, - }: { generator: SplatGenerator; rgba?: RgbaArray }): Promise { - rgba = this.getRgba({ generator, rgba }); - return rgba.read(); - } -} - -const EMPTY_GEOMETRY = new OldSplatGeometry(new Uint32Array(1), 0); - -const reorderSplats = dynoBlock( - { packedSplats: TPackedSplats, index: "int" }, - { gsplat: Gsplat }, - ({ packedSplats, index }) => { - if (!packedSplats || !index) { - throw new Error("Invalid input"); - } - const gsplat = readPackedSplat(packedSplats, index); - return { gsplat }; - }, -); - -function averageOriginToWorlds( - originToWorlds: THREE.Matrix4[], -): THREE.Matrix4 | null { - if (originToWorlds.length === 0) { - return null; - } - - const position = new THREE.Vector3(); - const quaternion = new THREE.Quaternion(); - const scale = new THREE.Vector3(); - - const positions: THREE.Vector3[] = []; - const quaternions: THREE.Quaternion[] = []; - for (const matrix of originToWorlds) { - matrix.decompose(position, quaternion, scale); - positions.push(position); - quaternions.push(quaternion); - } - - return new THREE.Matrix4().compose( - averagePositions(positions), - averageQuaternions(quaternions), - new THREE.Vector3(1, 1, 1), - ); -} diff --git a/src/OldSparkViewpoint.ts b/src/OldSparkViewpoint.ts deleted file mode 100644 index 32c9db2a..00000000 --- a/src/OldSparkViewpoint.ts +++ /dev/null @@ -1,879 +0,0 @@ -import * as THREE from "three"; - -import type { OldSparkRenderer } from "./OldSparkRenderer"; -import type { OldSplatAccumulator } from "./OldSplatAccumulator"; -import { OldSplatGeometry } from "./OldSplatGeometry"; -import { withWorker } from "./OldSplatWorker"; -import { DynoPackedSplats } from "./PackedSplats"; -import { Readback } from "./Readback"; -import { - type DynoBlock, - DynoBool, - DynoFloat, - type DynoVal, - DynoVec3, - Gsplat, - add, - combine, - defineGsplat, - dyno, - dynoBlock, - dynoConst, - floatBitsToUint, - mul, - packHalf2x16, - readPackedSplat, - uintToRgba8, - unindent, - unindentLines, -} from "./dyno"; -import { FreeList, withinCoorientDist } from "./utils"; - -export type OldSparkViewpointOptions = { - /** - * Controls whether to auto-update its sort order whenever the SparkRenderer - * updates the Gsplats. If you expect to render/display from this viewpoint - * most frames, set this to true. - * @default false - */ - autoUpdate?: boolean; - /** - * Set a THREE.Camera for this viewpoint to follow. - * @default undefined - */ - camera?: THREE.Camera; - /** - * Set an explicit view-to-world transformation matrix for this viewpoint (equivalent - * to camera.matrixWorld), overrides any camera setting. - * @default undefined - */ - viewToWorld?: THREE.Matrix4; - /** - * Configure viewpoint with an off-screen render target. - * @default undefined - */ - target?: { - /** - * Width of the render target in pixels. - */ - width: number; - /** - * Height of the render target in pixels. - */ - height: number; - /** - * If you want to be able to render a scene that depends on this target's - * output (for example, a recursive viewport), set this to true to enable - * double buffering. - * @default false - */ - doubleBuffer?: boolean; - /** - * Super-sampling factor for the render target. Values 1-4 are supported. - * Note that re-sampling back down to .width x .height is done on the CPU - * with simple averaging only when calling readTarget(). - * @default 1 - */ - superXY?: number; - }; - /** - * Callback function that is called when the render target texture is updated. - * Receives the texture as a parameter. Use this to update a viewport with - * the latest viewpoint render each frame. - * @default undefined - */ - onTextureUpdated?: (texture: THREE.Texture) => void; - /** - * Whether to sort splats radially (geometric distance) from the viewpoint (true) - * or by Z-depth (false). Most scenes are trained with the Z-depth sort metric - * and will render more accurately at certain viewpoints. However, radial sorting - * is more stable under viewpoint rotations. - * @default true - */ - sortRadial?: boolean; - /** - * Distance threshold for re-sorting splats. If the viewpoint moves more than - * this distance, splats will be re-sorted. - * @default 0.01 units - */ - sortDistance?: number; - /** - * View direction dot product threshold for re-sorting splats. For - * sortRadial: true we use 0.99 while sortRadial: false uses 0.999 because it is - * more sensitive to view direction. - * @default 0.99 if sortRadial else 0.999 - */ - sortCoorient?: boolean; - /** - * Constant added to Z-depth to bias values into the positive range for - * sortRadial: false, but also used for culling Gsplats "well behind" - * the viewpoint origin - * @default 1.0 - */ - depthBias?: number; - /** - * Set this to true if rendering a 360 to disable "behind the viewpoint" - * culling during sorting. This is set automatically when rendering 360 envMaps - * using the SparkRenderer.renderEnvMap() utility function. - * @default false - */ - sort360?: boolean; - /* - * Set this to true to sort with float32 precision with two-pass sort. - * @default true - */ - sort32?: boolean; - /* - * Set this to true to enable sort-free stochastic splat rendering. - * @default false - */ - stochastic?: boolean; -}; - -// A SparkViewpoint is created from and tied to a SparkRenderer, and represents -// an independent viewpoint of all the scene Gsplats and their sort order. Making -// these viewpoints explicit allows us to have multiple, simultaneous viewpoint -// renders, for example for camera preview panes or overhead map views. -// -// When creating a SparkRenderer it automatically creates a default viewpoint -// .defaultView that is used in the normal render loop when drawing to the canvas, -// and is automatically updated whenever the camera moves. Additional viewpoints -// can be created and configured separately. - -export class OldSparkViewpoint { - spark: OldSparkRenderer; - autoUpdate: boolean; - camera?: THREE.Camera; - viewToWorld: THREE.Matrix4; - lastTime: number | null = null; - - target?: THREE.WebGLRenderTarget; - private back?: THREE.WebGLRenderTarget; - onTextureUpdated?: (texture: THREE.Texture) => void; - encodeLinear = false; - superXY = 1; - private superPixels?: Uint8Array; - private pixels?: Uint8Array; - - sortRadial: boolean; - sortDistance?: number; - sortCoorient?: boolean; - depthBias?: number; - sort360?: boolean; - sort32?: boolean; - stochastic: boolean; - - display: { - accumulator: OldSplatAccumulator; - viewToWorld: THREE.Matrix4; - geometry: OldSplatGeometry; - } | null = null; - - private sorting: { viewToWorld: THREE.Matrix4 } | null = null; - private pending: { - accumulator?: OldSplatAccumulator; - viewToWorld: THREE.Matrix4; - displayed: boolean; - } | null = null; - private sortingCheck = false; - - private readback16: Uint16Array = new Uint16Array(0); - private readback32: Uint32Array = new Uint32Array(0); - private orderingFreelist: FreeList; - - constructor(options: OldSparkViewpointOptions & { spark: OldSparkRenderer }) { - this.spark = options.spark; - this.camera = options.camera; - this.viewToWorld = options.viewToWorld ?? new THREE.Matrix4(); - - if (options.target) { - const { width, height, doubleBuffer } = options.target; - const superXY = Math.max(1, Math.min(4, options.target.superXY ?? 1)); - this.superXY = superXY; - if (width * superXY > 8192 || height * superXY > 8192) { - throw new Error("Target size too large"); - } - - this.target = new THREE.WebGLRenderTarget( - width * superXY, - height * superXY, - { - format: THREE.RGBAFormat, - type: THREE.UnsignedByteType, - colorSpace: THREE.SRGBColorSpace, - }, - ); - if (doubleBuffer) { - this.back = new THREE.WebGLRenderTarget( - width * superXY, - height * superXY, - { - format: THREE.RGBAFormat, - type: THREE.UnsignedByteType, - colorSpace: THREE.SRGBColorSpace, - }, - ); - } - this.encodeLinear = true; - } - this.onTextureUpdated = options.onTextureUpdated; - - this.sortRadial = options.sortRadial ?? true; - this.sortDistance = options.sortDistance; - this.sortCoorient = options.sortCoorient; - this.depthBias = options.depthBias; - this.sort360 = options.sort360; - this.sort32 = options.sort32; - this.stochastic = options.stochastic ?? false; - - this.orderingFreelist = new FreeList({ - allocate: (maxSplats) => - new Uint32Array(maxSplats) as Uint32Array, - valid: (ordering, maxSplats) => ordering.length === maxSplats, - }); - - this.autoUpdate = false; - this.setAutoUpdate(options.autoUpdate ?? false); - } - - // Call this when you are done with the SparkViewpoint and want to - // free up its resources (GPU targets, pixel buffers, etc.) - dispose() { - this.setAutoUpdate(false); - if (this.target) { - this.target.dispose(); - this.target = undefined; - } - if (this.back) { - this.back.dispose(); - this.back = undefined; - } - if (this.display) { - this.spark.releaseAccumulator(this.display.accumulator); - this.display.geometry.dispose(); - this.display = null; - } - if (this.pending?.accumulator) { - this.spark.releaseAccumulator(this.pending.accumulator); - this.pending = null; - } - } - - // Use this function to change whether this viewpoint will auto-update - // its sort order whenever the attached SparkRenderer updates the Gsplats. - // Turn this on or off depending on whether you expect to do renders from - // this viewpoint most frames. - setAutoUpdate(autoUpdate: boolean) { - if (!this.autoUpdate && autoUpdate) { - this.spark.autoViewpoints.push(this); - } else if (this.autoUpdate && !autoUpdate) { - this.spark.autoViewpoints = this.spark.autoViewpoints.filter( - (v) => v !== this, - ); - } - this.autoUpdate = autoUpdate; - } - - // See below async prepareRenderPixels() for explanation of parameters. - // Awaiting this method updates the Gsplats in the scene and performs a sort of the - // Gsplats from this viewpoint, preparing it for a subsequent this.renderTarget() - // call in the same tick. - async prepare({ - scene, - camera, - viewToWorld, - update, - forceOrigin, - }: { - scene: THREE.Scene; - camera?: THREE.Camera; - viewToWorld?: THREE.Matrix4; - update?: boolean; - forceOrigin?: boolean; - }) { - if (viewToWorld) { - this.viewToWorld = viewToWorld; - } else { - this.camera = camera ?? this.camera; - if (this.camera) { - this.camera.updateMatrixWorld(); - this.viewToWorld = this.camera.matrixWorld.clone(); - } - } - while (update ?? true) { - // Force an update, possibly with origin centered at this camera - // to yield the best quality output. - const originToWorld = forceOrigin - ? this.viewToWorld - : this.spark.matrixWorld; - const updated = this.spark.updateInternal({ scene, originToWorld }); - if (updated) { - break; - } - // A bit of a hack, but try again. We shouldn't be starved for long. - await new Promise((resolve) => setTimeout(resolve, 10)); - } - - const accumulator = this.spark.active; - // Hold reference to accumulator while sorting - accumulator.refCount += 1; - await this.sortUpdate({ accumulator, viewToWorld: this.viewToWorld }); - // Release accumulator reference - this.spark.releaseAccumulator(accumulator); - } - - // Render out the viewpoint to the view target RGBA buffer. - // Swaps buffers if doubleBuffer: true was set. - // Calls onTextureUpdated(texture) with the resulting texture. - renderTarget({ - scene, - camera, - }: { scene: THREE.Scene; camera?: THREE.Camera }) { - const target = this.back ?? this.target; - if (!target) { - throw new Error("Must initialize SparkViewpoint with target"); - } - - camera = camera ?? this.camera; - if (!camera) { - throw new Error("Must provide camera"); - } - if (camera instanceof THREE.PerspectiveCamera) { - const newCam = new THREE.PerspectiveCamera().copy(camera, false); - newCam.aspect = target.width / target.height; - newCam.updateProjectionMatrix(); - camera = newCam; - } - this.viewToWorld = camera.matrixWorld.clone(); - - const previousTarget = this.spark.renderer.getRenderTarget(); - try { - this.spark.renderer.setRenderTarget(target); - this.spark.prepareViewpoint(this); - - this.spark.renderer.render(scene, camera); - } finally { - this.spark.prepareViewpoint(this.spark.defaultView); - this.spark.renderer.setRenderTarget(previousTarget); - } - - if (target !== this.target) { - // Swap back buffer and target - [this.target, this.back] = [this.back, this.target]; - } - this.onTextureUpdated?.(target.texture); - } - - // Read back the previously rendered target image as a Uint8Array of packed - // RGBA values (in that order). If superXY was set greater than 1 then - // downsampling is performed in the target pixel array with simple averaging - // to derive the returned pixel values. Subsequent calls to this.readTarget() - // will reuse the same buffers to minimize memory allocations. - async readTarget(): Promise { - if (!this.target) { - throw new Error("Must initialize SparkViewpoint with target"); - } - const { width, height } = this.target; - const byteSize = width * height * 4; - if (!this.superPixels || this.superPixels.length < byteSize) { - this.superPixels = new Uint8Array(byteSize); - } - await this.spark.renderer.readRenderTargetPixelsAsync( - this.target, - 0, - 0, - width, - height, - this.superPixels, - ); - - const { superXY } = this; - if (superXY === 1) { - return this.superPixels; - } - - const subWidth = width / superXY; - const subHeight = height / superXY; - const subSize = subWidth * subHeight * 4; - if (!this.pixels || this.pixels.length < subSize) { - this.pixels = new Uint8Array(subSize); - } - - const { superPixels, pixels } = this; - const super2 = superXY * superXY; - for (let y = 0; y < subHeight; y++) { - const row = y * subWidth; - for (let x = 0; x < subWidth; x++) { - const superCol = x * superXY; - let r = 0; - let g = 0; - let b = 0; - let a = 0; - for (let sy = 0; sy < superXY; sy++) { - const superRow = (y * superXY + sy) * this.target.width; - for (let sx = 0; sx < superXY; sx++) { - const superIndex = (superRow + superCol + sx) * 4; - r += superPixels[superIndex]; - g += superPixels[superIndex + 1]; - b += superPixels[superIndex + 2]; - a += superPixels[superIndex + 3]; - } - } - const pixelIndex = (row + x) * 4; - pixels[pixelIndex] = r / super2; - pixels[pixelIndex + 1] = g / super2; - pixels[pixelIndex + 2] = b / super2; - pixels[pixelIndex + 3] = a / super2; - } - } - return pixels; - } - - // Render out a viewpoint as a Uint8Array of RGBA values for the provided scene - // and any camera/viewToWorld viewpoint overrides. By default update is true, - // which triggers its SparkRenderer to check and potentially update the Gsplats. - // Setting update to false disables this and sorts the Gsplats as they are. - // Setting forceOrigin (default: false) to true forces the view update to - // recalculate the splats with this view origin, potentially altering any - // view-dependent effects. If you expect view-dependent effects to play a role - // in the rendering quality, enable this. - // - // Underneath, prepareRenderPixels() simply calls await this.prepare(...), - // this.renderTarget(...), and finally returns the result this.readTarget(), - // a Promise to a Uint8Array with RGBA values for all the pixels (potentially - // downsampled if the superXY parameter was used). These steps can also be called - // manually, for example if you need to alter the scene before and after - // this.renderTarget(...) to hide UI elements from being rendered. - async prepareRenderPixels({ - scene, - camera, - viewToWorld, - update, - forceOrigin, - }: { - scene: THREE.Scene; - camera?: THREE.Camera; - viewToWorld?: THREE.Matrix4; - update?: boolean; - forceOrigin?: boolean; - }) { - await this.prepare({ scene, camera, viewToWorld, update, forceOrigin }); - this.renderTarget({ scene, camera }); - return this.readTarget(); - } - - // This is called automatically by SparkRenderer, there is no need to call it! - // The method cannot be private because then SparkRenderer would - // not be able to call it. - autoPoll({ accumulator }: { accumulator?: OldSplatAccumulator }) { - if (this.camera) { - this.camera.updateMatrixWorld(); - this.viewToWorld = this.camera.matrixWorld.clone(); - } - - let needsSort = false; - let displayed = false; - - if (!this.display) { - // Need to do first sort - needsSort = true; - } else if (accumulator) { - needsSort = true; - const { mappingVersion } = this.display.accumulator; - if (accumulator.mappingVersion === mappingVersion) { - // Splat mapping has not changed, so reuse the existing sorted - // geometry to show updates faster. We will still fire off - // a re-sort if necessary. First release old accumulator. - accumulator.refCount += 1; - this.spark.releaseAccumulator(this.display.accumulator); - this.display.accumulator = accumulator; - this.display.viewToWorld.copy(this.viewToWorld); - displayed = true; - - if (this.spark.viewpoint === this) { - this.spark.prepareViewpoint(this); - } - } - } - - const latestView = this.sorting?.viewToWorld ?? this.display?.viewToWorld; - if ( - latestView && - !withinCoorientDist({ - matrix1: this.viewToWorld, - matrix2: latestView, - // By default update sort each 1 cm - maxDistance: this.sortDistance ?? 0.01, - // By default for radial sort, update for intermittent movement so that - // we bring back splats culled by being behind the camera. - // For depth sort, small rotations can change sort order a lot, so - // update sort for even small rotations. - minCoorient: (this.sortCoorient ?? this.sortRadial) ? 0.99 : 0.999, - }) - ) { - needsSort = true; - } - - if (!needsSort) { - // Stop here, no sort necessary - return; - } - - if (accumulator) { - // Hold a reference to the accumulator for sorting - accumulator.refCount += 1; - } - - if (this.pending?.accumulator) { - this.spark.releaseAccumulator(this.pending.accumulator); - } - this.pending = { accumulator, viewToWorld: this.viewToWorld, displayed }; - - // Don't await this, just trigger the sort if necessary - this.driveSort(); - } - - private async driveSort() { - while (true) { - if (this.sorting || !this.pending) { - return; // Sort already in process or nothing to sort - } - - const { viewToWorld, displayed } = this.pending; - let accumulator = this.pending.accumulator; - if (!accumulator) { - // Hold a reference to the accumulator while sorting - accumulator = this.display?.accumulator ?? this.spark.active; - accumulator.refCount += 1; - } - this.pending = null; - if (!accumulator) { - throw new Error("No accumulator to sort"); - } - - this.sorting = { viewToWorld }; - await this.sortUpdate({ accumulator, viewToWorld, displayed }); - this.sorting = null; - - // Release the reference to the accumulator - this.spark.releaseAccumulator(accumulator); - - // Continue in loop with any queued sort - } - } - - private async sortUpdate({ - accumulator, - viewToWorld, - displayed = false, - }: { - accumulator?: OldSplatAccumulator; - viewToWorld: THREE.Matrix4; - displayed?: boolean; - }) { - if (this.sortingCheck) { - throw new Error("Only one sort at a time"); - } - this.sortingCheck = true; - - accumulator = accumulator ?? this.spark.active; - const { numSplats, maxSplats } = accumulator.splats; - let activeSplats = 0; - let ordering = this.orderingFreelist.alloc(maxSplats); - - if (this.stochastic) { - activeSplats = numSplats; - // Render all splats in order since the Z-buffer - // will handle ordering. - for (let i = 0; i < numSplats; ++i) { - ordering[i] = i; - } - } else if (numSplats > 0) { - const { - reader, - doubleSortReader, - sort32Reader, - dynoSortRadial, - dynoOrigin, - dynoDirection, - dynoDepthBias, - dynoSort360, - dynoSplats, - } = OldSparkViewpoint.makeSorter(); - const sort32 = this.sort32 ?? false; - let readback: Uint16Array | Uint32Array; - if (sort32) { - this.readback32 = reader.ensureBuffer(maxSplats, this.readback32); - readback = this.readback32; - } else { - const halfMaxSplats = Math.ceil(maxSplats / 2); - this.readback16 = reader.ensureBuffer(halfMaxSplats, this.readback16); - readback = this.readback16; - } - - const worldToOrigin = accumulator.toWorld.clone().invert(); - const viewToOrigin = viewToWorld.clone().premultiply(worldToOrigin); - - dynoSortRadial.value = this.sort360 ? true : this.sortRadial; - dynoOrigin.value.set(0, 0, 0).applyMatrix4(viewToOrigin); - dynoDirection.value - .set(0, 0, -1) - .applyMatrix4(viewToOrigin) - .sub(dynoOrigin.value) - .normalize(); - dynoDepthBias.value = this.depthBias ?? 1.0; - dynoSort360.value = this.sort360 ?? false; - dynoSplats.packedSplats = accumulator.splats; - - const sortReader = sort32 ? sort32Reader : doubleSortReader; - const count = sort32 ? numSplats : Math.ceil(numSplats / 2); - await reader.renderReadback({ - renderer: this.spark.renderer, - reader: sortReader, - count, - readback, - }); - - const result = (await withWorker(async (worker) => { - const rpcName = sort32 ? "sort32Splats" : "sortDoubleSplats"; - return worker.call(rpcName, { - maxSplats, - numSplats, - readback, - ordering, - }); - })) as { - readback: Uint16Array | Uint32Array; - ordering: Uint32Array; - activeSplats: number; - }; - if (sort32) { - this.readback32 = result.readback as Uint32Array; - } else { - this.readback16 = result.readback as Uint16Array; - } - ordering = result.ordering; - activeSplats = result.activeSplats; - } - - this.updateDisplay({ - accumulator, - viewToWorld, - ordering, - activeSplats, - displayed, - }); - this.sortingCheck = false; - } - - private updateDisplay({ - accumulator, - viewToWorld, - ordering, - activeSplats, - displayed = false, - }: { - accumulator: OldSplatAccumulator; - viewToWorld: THREE.Matrix4; - ordering: Uint32Array; - activeSplats: number; - displayed?: boolean; - }) { - if (!this.display) { - // Hold a reference to the accumulator while part of display - accumulator.refCount += 1; - this.display = { - accumulator, - viewToWorld, - geometry: new OldSplatGeometry(ordering, activeSplats), - }; - } else { - if (!displayed && accumulator !== this.display.accumulator) { - // Hold a reference to the new accumulator being displayed - accumulator.refCount += 1; - // Release the reference to the previously displayed accumulator - this.spark.releaseAccumulator(this.display.accumulator); - this.display.accumulator = accumulator; - } - - this.display.viewToWorld = viewToWorld; - - const oldOrdering = this.display.geometry.ordering; - if (oldOrdering.length === ordering.length) { - this.display.geometry.update(ordering, activeSplats); - } else { - this.display.geometry.dispose(); - // console.log("*** alloc SplatGeometry", ordering.length); - this.display.geometry = new OldSplatGeometry(ordering, activeSplats); - } - this.orderingFreelist.free(oldOrdering); - } - if (this.spark.viewpoint === this) { - this.spark.prepareViewpoint(this); - } - } - - // If you need an empty THREE.Texture to use to initialize a uniform that is - // updated via onTextureUpdated(texture), this static texture can be handy. - static EMPTY_TEXTURE = new THREE.Texture(); - - private static dynos: { - dynoSortRadial: DynoBool; - dynoOrigin: DynoVec3; - dynoDirection: DynoVec3; - dynoDepthBias: DynoFloat; - dynoSort360: DynoBool; - dynoSplats: DynoPackedSplats; - reader: Readback; - doubleSortReader: DynoBlock<{ index: "int" }, { rgba8: "vec4" }>; - sort32Reader: DynoBlock<{ index: "int" }, { rgba8: "vec4" }>; - } | null = null; - - private static makeSorter() { - if (!OldSparkViewpoint.dynos) { - const dynoSortRadial = new DynoBool({ value: true }); - const dynoOrigin = new DynoVec3({ value: new THREE.Vector3() }); - const dynoDirection = new DynoVec3({ value: new THREE.Vector3() }); - const dynoDepthBias = new DynoFloat({ value: 1.0 }); - const dynoSort360 = new DynoBool({ value: false }); - const dynoSplats = new DynoPackedSplats(); - - const reader = new Readback(); - const doubleSortReader = dynoBlock( - { index: "int" }, - { rgba8: "vec4" }, - ({ index }) => { - if (!index) { - throw new Error("No index"); - } - const sortParams = { - sortRadial: dynoSortRadial, - sortOrigin: dynoOrigin, - sortDirection: dynoDirection, - sortDepthBias: dynoDepthBias, - sort360: dynoSort360, - }; - const index2 = mul(index, dynoConst("int", 2)); - - const gsplat0 = readPackedSplat(dynoSplats, index2); - const metric0 = computeSortMetric({ gsplat: gsplat0, ...sortParams }); - - const gsplat1 = readPackedSplat( - dynoSplats, - add(index2, dynoConst("int", 1)), - ); - const metric1 = computeSortMetric({ gsplat: gsplat1, ...sortParams }); - - const combined = combine({ - vectorType: "vec2", - x: metric0, - y: metric1, - }); - const rgba8 = uintToRgba8(packHalf2x16(combined)); - return { rgba8 }; - }, - ); - - const sort32Reader = dynoBlock( - { index: "int" }, - { rgba8: "vec4" }, - ({ index }) => { - if (!index) { - throw new Error("No index"); - } - const sortParams = { - sortRadial: dynoSortRadial, - sortOrigin: dynoOrigin, - sortDirection: dynoDirection, - sortDepthBias: dynoDepthBias, - sort360: dynoSort360, - }; - - const gsplat = readPackedSplat(dynoSplats, index); - const metric = computeSortMetric({ gsplat, ...sortParams }); - const rgba8 = uintToRgba8(floatBitsToUint(metric)); - return { rgba8 }; - }, - ); - - OldSparkViewpoint.dynos = { - dynoSortRadial, - dynoOrigin, - dynoDirection, - dynoDepthBias, - dynoSort360, - dynoSplats, - reader, - doubleSortReader, - sort32Reader, - }; - } - return OldSparkViewpoint.dynos; - } -} - -const defineComputeSortMetric = unindent(` - float computeSort(Gsplat gsplat, bool sortRadial, vec3 sortOrigin, vec3 sortDirection, float sortDepthBias, bool sort360) { - if (!isGsplatActive(gsplat.flags)) { - return INFINITY; - } - - vec3 center = gsplat.center - sortOrigin; - float biasedDepth = dot(center, sortDirection) + sortDepthBias; - if (!sort360 && (biasedDepth <= 0.0)) { - return INFINITY; - } - - return sortRadial ? length(center) : biasedDepth; - } -`); - -function computeSortMetric({ - gsplat, - sortRadial, - sortOrigin, - sortDirection, - sortDepthBias, - sort360, -}: { - gsplat: DynoVal; - sortRadial: DynoVal<"bool">; - sortOrigin: DynoVal<"vec3">; - sortDirection: DynoVal<"vec3">; - sortDepthBias: DynoVal<"float">; - sort360: DynoVal<"bool">; -}) { - return dyno({ - inTypes: { - gsplat: Gsplat, - sortRadial: "bool", - sortOrigin: "vec3", - sortDirection: "vec3", - sortDepthBias: "float", - sort360: "bool", - }, - outTypes: { metric: "float" }, - globals: () => [defineGsplat, defineComputeSortMetric], - inputs: { - gsplat, - sortRadial, - sortOrigin, - sortDirection, - sortDepthBias, - sort360, - }, - statements: ({ inputs, outputs }) => { - const { - gsplat, - sortRadial, - sortOrigin, - sortDirection, - sortDepthBias, - sort360, - } = inputs; - return unindentLines(` - ${outputs.metric} = computeSort(${gsplat}, ${sortRadial}, ${sortOrigin}, ${sortDirection}, ${sortDepthBias}, ${sort360}); - `); - }, - }).outputs.metric; -} diff --git a/src/OldSplatAccumulator.ts b/src/OldSplatAccumulator.ts deleted file mode 100644 index e2baa091..00000000 --- a/src/OldSplatAccumulator.ts +++ /dev/null @@ -1,107 +0,0 @@ -import * as THREE from "three"; - -import { PackedSplats } from "./PackedSplats"; -import type { GeneratorMapping } from "./SplatAccumulator"; -import type { SplatGenerator, SplatModifier } from "./SplatGenerator"; - -// SplatAccumulator helps manage the generation of splats from multiple -// SplatGenerators, keeping track of the splat mapping, coordinate system, -// and reference count. - -export class OldSplatAccumulator { - splats = new PackedSplats(); - // The transform from Accumulator coordinate system to world coordinates. - toWorld = new THREE.Matrix4(); - // An array of all Gsplat mappings that were used for generation - mapping: GeneratorMapping[] = []; - // Number of SparkViewpoints (or other) that reference this accumulator, used - // to figure out when it can be recycled for use - refCount = 0; - - // Incremented every time the splats are updated/generated. - splatsVersion = -1; - // Incremented every time the splat mapping/layout is updated. - // Splat sort order can be reused between equivalent mapping versions. - mappingVersion = -1; - - ensureGenerate(maxSplats: number) { - if (this.splats.ensureGenerate(maxSplats)) { - // If we had to resize our PackedSplats then clear all previous mappings - this.mapping = []; - } - } - - // Generate all Gsplats from an array of generators - generateSplats({ - renderer, - modifier, - generators, - forceUpdate, - originToWorld, - }: { - renderer: THREE.WebGLRenderer; - modifier: SplatModifier; - generators: GeneratorMapping[]; - forceUpdate?: boolean; - originToWorld: THREE.Matrix4; - }) { - // Create a lookup from last SplatGenerator - const mapping = this.mapping.reduce((map, record) => { - map.set(record.node, record); - return map; - }, new Map()); - - // Run generators that are different from existing mapping - let updated = 0; - let numSplats = 0; - for (const { node, generator, version, base, count } of generators) { - const current = mapping.get(node); - if ( - forceUpdate || - generator !== current?.generator || - version !== current?.version || - base !== current?.base || - count !== current?.count - ) { - // Something is different from before so we should generate these Gsplats - if (generator && count > 0) { - const modGenerator = modifier.apply(generator); - try { - this.splats.generate({ - generator: modGenerator, - base, - count, - renderer, - }); - } catch (error) { - node.generator = undefined; - node.generatorError = error; - } - updated += 1; - } - } - numSplats = Math.max(numSplats, base + count); - } - - this.splats.numSplats = numSplats; - this.toWorld.copy(originToWorld); - this.mapping = generators; - return updated !== 0; - } - - // Check if this accumulator has exactly the same generator mapping as - // the previous one. If so, we can reuse the Gsplat sort order. - hasCorrespondence(other: OldSplatAccumulator) { - if (this.mapping.length !== other.mapping.length) { - return false; - } - return this.mapping.every(({ node, base, count }, i) => { - const { - node: otherNode, - base: otherBase, - count: otherCount, - } = other.mapping[i]; - return node === otherNode && base === otherBase && count === otherCount; - }); - } -} diff --git a/src/OldSplatGeometry.ts b/src/OldSplatGeometry.ts deleted file mode 100644 index 23ba6f92..00000000 --- a/src/OldSplatGeometry.ts +++ /dev/null @@ -1,44 +0,0 @@ -import * as THREE from "three"; - -// SplatGeometry is an internal class used by SparkRenderer to render a collection -// of Gsplats in a single draw call by extending THREE.InstancedBufferGeometry. -// Each Gsplat is drawn as two triangles, with the order of the Gsplats determined -// by the instance attribute "ordering". - -export class OldSplatGeometry extends THREE.InstancedBufferGeometry { - ordering: Uint32Array; - attribute: THREE.InstancedBufferAttribute; - - constructor(ordering: Uint32Array, activeSplats: number) { - super(); - - this.ordering = ordering; - - this.setAttribute("position", new THREE.BufferAttribute(QUAD_VERTICES, 3)); - this.setIndex(new THREE.BufferAttribute(QUAD_INDICES, 1)); - - // Hack to work around Three.js - // @ts-ignore - this._maxInstanceCount = ordering.length; - this.instanceCount = activeSplats; - - this.attribute = new THREE.InstancedBufferAttribute(ordering, 1, false, 1); - this.attribute.setUsage(THREE.DynamicDrawUsage); - this.setAttribute("splatIndex", this.attribute); - } - - update(ordering: Uint32Array, activeSplats: number) { - this.ordering = ordering; - this.attribute.array = ordering; - this.instanceCount = activeSplats; - this.attribute.addUpdateRange(0, activeSplats); - this.attribute.needsUpdate = true; - } -} - -// Each instance draws to triangles covering a quad over coords (-1,-1,0)..(1,1,0) -const QUAD_VERTICES = new Float32Array([ - -1, -1, 0, 1, -1, 0, 1, 1, 0, -1, 1, 0, -]); - -const QUAD_INDICES = new Uint16Array([0, 1, 2, 0, 2, 3]); diff --git a/src/OldSplatWorker.ts b/src/OldSplatWorker.ts deleted file mode 100644 index c164198d..00000000 --- a/src/OldSplatWorker.ts +++ /dev/null @@ -1,128 +0,0 @@ -import BundledWorker from "./oldWorker?worker&inline"; -import { getTransferable } from "./utils.js"; - -// SplatWorker is an internal class that manages a WebWorker for executing -// longer running CPU tasks such as Gsplat file decoding and sorting. -// Although a SplatWorker can be created and used directly, the utility -// function withWorker() is recommended to allocate from a managed -// pool of SplatWorkers. - -export class OldSplatWorker { - worker: Worker; - messages: Record< - number, - { resolve: (value: unknown) => void; reject: (reason?: unknown) => void } - > = {}; - messageIdNext = 0; - - constructor() { - // this.worker = new Worker(new URL("./worker", import.meta.url), { type: "module" }); - this.worker = new BundledWorker(); - this.worker.onmessage = (event) => this.onMessage(event); - } - - makeMessageId(): number { - return ++this.messageIdNext; - } - - makeMessagePromiseId(): { id: number; promise: Promise } { - const id = this.makeMessageId(); - const promise = new Promise((resolve, reject) => { - this.messages[id] = { resolve, reject }; - }); - return { id, promise }; - } - - onMessage(event: MessageEvent) { - // console.log("SplatWorker.onMessage:", event); - const { id, result, error } = event.data; - // console.log(`SplatWorker.onMessage(${id}):`, result, error); - const handler = this.messages[id]; - if (handler) { - delete this.messages[id]; - if (error) { - handler.reject(error); - } else { - handler.resolve(result); - } - } - } - - // Invoke an RPC on the worker with the given name and arguments. - // The normal usage of a worker is to run one activity at a time, - // but this function allows for concurrent calls, tagging each request - // with a unique message Id and awaiting a response to that same Id. - // The method will automatically transfer any ArrayBuffers in the - // arguments to the worker. If you'd like to transfer a copy of a - // buffer then you must clone it before passing to this function. - async call(name: string, args: unknown): Promise { - const { id, promise } = this.makeMessagePromiseId(); - // console.log(`SplatWorker.call(${name}):`, args); - this.worker.postMessage( - { name, args, id }, - { transfer: getTransferable(args) }, - ); - return promise; - } -} - -let maxWorkers = 4; - -let numWorkers = 0; -const freeWorkers: OldSplatWorker[] = []; -const workerQueue: ((worker: OldSplatWorker) => void)[] = []; - -// Set the maximum number of workers to allocate for the pool. (default: 4) -export function setWorkerPool(count = 4) { - maxWorkers = count; -} - -// Allocate a worker from the pool. If none are available and we are below the -// maximum, create a new one. Otherwise, add the request to a queue and wait -// for it to be fulfilled. -export async function allocWorker(): Promise { - const worker = freeWorkers.shift(); - if (worker) { - return worker; - } - - if (numWorkers < maxWorkers) { - const worker = new OldSplatWorker(); - numWorkers += 1; - return worker; - } - - return new Promise((resolve) => { - workerQueue.push(resolve); - }); -} - -// Return a worker to the pool. Pass the worker to any pending waiter. -export function freeWorker(worker: OldSplatWorker) { - if (numWorkers > maxWorkers) { - // Worker no longer needed - numWorkers -= 1; - return; - } - - const waiter = workerQueue.shift(); - if (waiter) { - waiter(worker); - return; - } - - freeWorkers.push(worker); -} - -// Allocate a worker from the pool and invoke the callback with the worker. -// When the callback completes, the worker will be returned to the pool. -export async function withWorker( - callback: (worker: OldSplatWorker) => Promise, -): Promise { - const worker = await allocWorker(); - try { - return await callback(worker); - } finally { - freeWorker(worker); - } -} diff --git a/src/SplatLoader.ts b/src/SplatLoader.ts index 5f2484f4..b5866099 100644 --- a/src/SplatLoader.ts +++ b/src/SplatLoader.ts @@ -1,7 +1,6 @@ import { unzipSync } from "fflate"; import { FileLoader, Loader, type LoadingManager } from "three"; import { ExtSplats, type ExtSplatsOptions } from "./ExtSplats"; -import { withWorker } from "./OldSplatWorker"; import { PackedSplats, type PackedSplatsOptions } from "./PackedSplats"; import { SplatMesh } from "./SplatMesh"; import { workerPool } from "./SplatWorker"; @@ -634,129 +633,6 @@ export function tryPcSogsZip( } } -export async function unpackSplats({ - input, - extraFiles, - fileType, - pathOrUrl, - splatEncoding, -}: { - input: Uint8Array | ArrayBuffer; - extraFiles?: Record; - fileType?: SplatFileType; - pathOrUrl?: string; - splatEncoding?: SplatEncoding; -}): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra?: Record; -}> { - const fileBytes = - input instanceof ArrayBuffer ? new Uint8Array(input) : input; - let splatFileType = fileType; - if (!fileType) { - splatFileType = getSplatFileType(fileBytes); - if (!splatFileType && pathOrUrl) { - splatFileType = getSplatFileTypeFromPath(pathOrUrl); - } - } - - switch (splatFileType) { - case SplatFileType.PLY: { - const ply = new PlyReader({ fileBytes }); - await ply.parseHeader(); - const numSplats = ply.numSplats; - const maxSplats = getTextureSize(numSplats).maxSplats; - const args = { - fileBytes, - packedArray: new Uint32Array(maxSplats * 4), - splatEncoding, - }; - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "unpackPly", - args, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - case SplatFileType.SPZ: { - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "decodeSpz", - { - fileBytes, - splatEncoding, - }, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - case SplatFileType.SPLAT: { - return await withWorker(async (worker) => { - const { packedArray, numSplats } = (await worker.call( - "decodeAntiSplat", - { - fileBytes, - splatEncoding, - }, - )) as { packedArray: Uint32Array; numSplats: number }; - return { packedArray, numSplats }; - }); - } - case SplatFileType.KSPLAT: { - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "decodeKsplat", - { fileBytes, splatEncoding }, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - case SplatFileType.PCSOGS: { - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "decodePcSogs", - { fileBytes, extraFiles, splatEncoding }, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - case SplatFileType.PCSOGSZIP: { - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "decodePcSogsZip", - { fileBytes, splatEncoding }, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - default: { - throw new Error(`Unknown splat file type: ${splatFileType}`); - } - } -} - export class SplatData { numSplats: number; maxSplats: number; @@ -894,18 +770,6 @@ export class SplatData { } } -export async function transcodeSpz( - input: TranscodeSpzInput, -): Promise<{ input: TranscodeSpzInput; fileBytes: Uint8Array }> { - return await withWorker(async (worker) => { - const result = (await worker.call("transcodeSpz", input)) as { - input: TranscodeSpzInput; - fileBytes: Uint8Array; - }; - return result; - }); -} - export type FileInput = { fileBytes: Uint8Array; fileType?: SplatFileType; diff --git a/src/SplatMesh.ts b/src/SplatMesh.ts index 5a033fab..ff96b630 100644 --- a/src/SplatMesh.ts +++ b/src/SplatMesh.ts @@ -7,7 +7,6 @@ import init_wasm, { raycast_packed_buffer, } from "spark-rs"; import { ExtSplats } from "./ExtSplats"; -import { OldSparkRenderer } from "./OldSparkRenderer"; import { PackedSplats } from "./PackedSplats"; import { type RgbaArray, TRgbaArray } from "./RgbaArray"; import { SparkRenderer } from "./SparkRenderer"; @@ -1279,39 +1278,3 @@ export const emptyLodIndices = (() => { texture.needsUpdate = true; return texture; })(); - -const EMPTY_GEOMETRY = new THREE.BufferGeometry(); -const EMPTY_MATERIAL = new THREE.ShaderMaterial(); - -// Creates an empty mesh to hook into Three.js rendering. -// This is used to detect if a SparkRenderer is present in the scene. -// If not, one will be injected automatically. -function createRendererDetectionMesh(): THREE.Mesh { - const mesh = new THREE.Mesh(EMPTY_GEOMETRY, EMPTY_MATERIAL); - mesh.frustumCulled = false; - mesh.onBeforeRender = function (renderer, scene) { - if (!scene.isScene) { - // The SplatMesh is part of render call that doesn't have a Scene at its root - // Don't auto-inject a renderer. - this.removeFromParent(); - return; - } - - // Check if the scene has a SparkRenderer instance - let hasSparkRenderer = false; - scene.traverse((c) => { - if (c instanceof SparkRenderer || c instanceof OldSparkRenderer) { - hasSparkRenderer = true; - } - }); - - if (!hasSparkRenderer) { - // No spark renderer present in the scene, inject one. - scene.add(new SparkRenderer({ renderer })); - } - - // Remove mesh to stop checking - this.removeFromParent(); - }; - return mesh; -} diff --git a/src/index.ts b/src/index.ts index ce9b00a3..4a68545e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,12 +1,3 @@ -export { - OldSparkRenderer, - type OldSparkRendererOptions, -} from "./OldSparkRenderer"; -export { - OldSparkViewpoint, - type OldSparkViewpointOptions, -} from "./OldSparkViewpoint"; - export { SparkRenderer, type SparkRendererOptions, @@ -19,7 +10,6 @@ export { RgbaArray, readRgbaArray } from "./RgbaArray"; export { SplatLoader, - unpackSplats, getSplatFileType, isPcSogs, } from "./SplatLoader"; @@ -36,7 +26,6 @@ export { type GsplatModifier, SplatTransformer, } from "./SplatGenerator"; -export { OldSplatAccumulator } from "./OldSplatAccumulator"; export { Readback, type Rgba8Readback, type ReadbackBuffer } from "./Readback"; export { diff --git a/src/oldWorker.ts b/src/oldWorker.ts deleted file mode 100644 index fd169bfb..00000000 --- a/src/oldWorker.ts +++ /dev/null @@ -1,697 +0,0 @@ -import init_wasm, { sort_splats, sort32_splats } from "spark-worker-rs"; -import type { PcSogsJson, TranscodeSpzInput } from "./SplatLoader"; -import { unpackAntiSplat } from "./antisplat"; -import { type SplatEncoding, WASM_SPLAT_SORT } from "./defines"; -import { unpackKsplat } from "./ksplat"; -import { unpackPcSogs, unpackPcSogsZip } from "./pcsogs"; -import { PlyReader } from "./ply"; -import { SpzReader, transcodeSpz } from "./spz"; -import { - computeMaxSplats, - encodeSh1Rgb, - encodeSh2Rgb, - encodeSh3Rgb, - getTransferable, - setPackedSplat, - setPackedSplatCenter, - setPackedSplatOpacity, - setPackedSplatQuat, - setPackedSplatRgb, - setPackedSplatScales, - toHalf, -} from "./utils"; - -// WebWorker for Spark's background CPU tasks, such as Gsplat file decoding -// and sorting. - -async function onMessage(event: MessageEvent) { - // Unpack RPC function name, arguments, and ID from the main thread. - const { name, args, id }: { name: string; args: unknown; id: number } = - event.data; - // console.log(`worker.onMessage(${id}, ${name}):`, args); - - // Initialize return result/error, to be filled out below. - let result = undefined; - let error = undefined; - - try { - switch (name) { - case "unpackPly": { - const { packedArray, fileBytes, splatEncoding } = args as { - packedArray: Uint32Array; - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = await unpackPly({ - packedArray, - fileBytes, - splatEncoding, - }); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "decodeSpz": { - const { fileBytes, splatEncoding } = args as { - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = await unpackSpz(fileBytes, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "decodeAntiSplat": { - const { fileBytes, splatEncoding } = args as { - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = unpackAntiSplat(fileBytes, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - }; - break; - } - case "decodeKsplat": { - const { fileBytes, splatEncoding } = args as { - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = unpackKsplat(fileBytes, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "decodePcSogs": { - const { fileBytes, extraFiles, splatEncoding } = args as { - fileBytes: Uint8Array; - extraFiles: Record; - splatEncoding: SplatEncoding; - }; - const json = JSON.parse( - new TextDecoder().decode(fileBytes), - ) as PcSogsJson; - const decoded = await unpackPcSogs(json, extraFiles, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "decodePcSogsZip": { - const { fileBytes, splatEncoding } = args as { - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = await unpackPcSogsZip(fileBytes, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "sortSplats": { - // Sort maxSplats splats using readback data, which encodes one uint32 per - // Gsplats, with the low bytes encoding a float16 distance sort metric. - const { maxSplats, totalSplats, readback, ordering } = args as { - maxSplats: number; - totalSplats: number; - readback: Uint8Array[]; - ordering: Uint32Array; - }; - // Sort totalSplats splats each with 4 bytes of readback, and outputs Uint32Array ordering of splat indices - result = { - id, - readback, - ...sortSplats({ totalSplats, readback, ordering }), - }; - break; - } - case "sortDoubleSplats": { - // Sort numSplats splats using the readback distance metric, which encodes - // one float16 per splat (no unused high bytes like for sortSplats). - const { numSplats, readback, ordering } = args as { - numSplats: number; - readback: Uint16Array; - ordering: Uint32Array; - }; - if (WASM_SPLAT_SORT) { - result = { - id, - readback, - ordering, - activeSplats: sort_splats(numSplats, readback, ordering), - }; - } else { - result = { - id, - readback, - ...sortDoubleSplats({ numSplats, readback, ordering }), - }; - } - break; - } - case "sort32Splats": { - const { maxSplats, numSplats, readback, ordering } = args as { - maxSplats: number; - numSplats: number; - readback: Uint32Array; - ordering: Uint32Array; - }; - // Benchmark sort - // benchmarkSort(numSplats, readback, ordering); - if (WASM_SPLAT_SORT) { - result = { - id, - readback, - ordering, - activeSplats: sort32_splats(numSplats, readback, ordering), - }; - } else { - result = { - id, - readback, - ...sort32Splats({ maxSplats, numSplats, readback, ordering }), - }; - } - break; - } - case "transcodeSpz": { - const input = args as TranscodeSpzInput; - const spzBytes = await transcodeSpz(input); - result = { - id, - fileBytes: spzBytes, - input, - }; - break; - } - default: { - throw new Error(`Unknown name: ${name}`); - } - } - } catch (e) { - error = e; - console.error(error); - } - - // Send the result or error back to the main thread, making sure to transfer any ArrayBuffers - self.postMessage( - { id, result, error }, - { transfer: getTransferable(result) }, - ); -} - -function benchmarkSort( - numSplats: number, - readback32: Uint32Array, - ordering: Uint32Array, -) { - if (numSplats > 0) { - console.log("Running sort benchmark"); - const readbackF32 = new Float32Array(readback32.buffer); - const readback16 = new Uint16Array(readback32.length); - for (let i = 0; i < numSplats; ++i) { - readback16[i] = toHalf(readbackF32[i]); - } - - const WARMUP = 10; - for (let i = 0; i < WARMUP; ++i) { - const activeSplats = sort_splats(numSplats, readback16, ordering); - const activeSplats32 = sort32_splats(numSplats, readback32, ordering); - const results = sortDoubleSplats({ - numSplats, - readback: readback16, - ordering, - }); - const results32 = sort32Splats({ - maxSplats: numSplats, - numSplats, - readback: readback32, - ordering, - }); - } - - const TIMING_SAMPLES = 1000; - let start: number; - - start = performance.now(); - for (let i = 0; i < TIMING_SAMPLES; ++i) { - const activeSplats = sort_splats(numSplats, readback16, ordering); - } - const wasmTime = (performance.now() - start) / TIMING_SAMPLES; - - start = performance.now(); - for (let i = 0; i < TIMING_SAMPLES; ++i) { - const results = sortDoubleSplats({ - numSplats, - readback: readback16, - ordering, - }); - } - const jsTime = (performance.now() - start) / TIMING_SAMPLES; - - console.log( - `JS: ${jsTime} ms, WASM: ${wasmTime} ms, numSplats: ${numSplats}`, - ); - - start = performance.now(); - for (let i = 0; i < TIMING_SAMPLES; ++i) { - const activeSplats32 = sort32_splats(numSplats, readback32, ordering); - } - const wasm32Time = (performance.now() - start) / TIMING_SAMPLES; - - start = performance.now(); - for (let i = 0; i < TIMING_SAMPLES; ++i) { - const results = sort32Splats({ - maxSplats: numSplats, - numSplats, - readback: readback32, - ordering, - }); - } - const js32Time = (performance.now() - start) / TIMING_SAMPLES; - - console.log( - `JS32: ${js32Time} ms, WASM32: ${wasm32Time} ms, numSplats: ${numSplats}`, - ); - } -} - -async function unpackPly({ - packedArray, - fileBytes, - splatEncoding, -}: { - packedArray: Uint32Array; - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; -}): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra: Record; -}> { - const ply = new PlyReader({ fileBytes }); - await ply.parseHeader(); - const numSplats = ply.numSplats; - - const extra: Record = {}; - - ply.parseSplats( - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - setPackedSplat( - packedArray, - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - splatEncoding, - ); - }, - (index, sh1, sh2, sh3) => { - if (sh1) { - if (!extra.sh1) { - extra.sh1 = new Uint32Array(numSplats * 2); - } - encodeSh1Rgb(extra.sh1 as Uint32Array, index, sh1, splatEncoding); - } - if (sh2) { - if (!extra.sh2) { - extra.sh2 = new Uint32Array(numSplats * 4); - } - encodeSh2Rgb(extra.sh2 as Uint32Array, index, sh2, splatEncoding); - } - if (sh3) { - if (!extra.sh3) { - extra.sh3 = new Uint32Array(numSplats * 4); - } - encodeSh3Rgb(extra.sh3 as Uint32Array, index, sh3, splatEncoding); - } - }, - ); - - return { packedArray, numSplats, extra }; -} - -async function unpackSpz( - fileBytes: Uint8Array, - splatEncoding: SplatEncoding, -): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra: Record; -}> { - const spz = new SpzReader({ fileBytes }); - await spz.parseHeader(); - const numSplats = spz.numSplats; - const maxSplats = computeMaxSplats(numSplats); - const packedArray = new Uint32Array(maxSplats * 4); - const extra: Record = {}; - - let extraCallbacks = {}; - if (spz.flagLod) { - const childCounts = new Uint16Array(numSplats); - const childStarts = new Uint32Array(numSplats); - extra.childCounts = childCounts; - extra.childStarts = childStarts; - extraCallbacks = { - childCounts: (index: number, count: number) => { - childCounts[index] = count; - }, - childStarts: (index: number, start: number) => { - childStarts[index] = start; - }, - }; - } - - await spz.parseSplats( - (index, x, y, z) => { - setPackedSplatCenter(packedArray, index, x, y, z); - }, - (index, alpha) => { - setPackedSplatOpacity(packedArray, index, alpha); - }, - (index, r, g, b) => { - setPackedSplatRgb(packedArray, index, r, g, b, splatEncoding); - }, - (index, scaleX, scaleY, scaleZ) => { - setPackedSplatScales( - packedArray, - index, - scaleX, - scaleY, - scaleZ, - splatEncoding, - ); - }, - (index, quatX, quatY, quatZ, quatW) => { - setPackedSplatQuat(packedArray, index, quatX, quatY, quatZ, quatW); - }, - (index, sh1, sh2, sh3) => { - if (sh1) { - if (!extra.sh1) { - extra.sh1 = new Uint32Array(numSplats * 2); - } - encodeSh1Rgb(extra.sh1 as Uint32Array, index, sh1, splatEncoding); - } - if (sh2) { - if (!extra.sh2) { - extra.sh2 = new Uint32Array(numSplats * 4); - } - encodeSh2Rgb(extra.sh2 as Uint32Array, index, sh2, splatEncoding); - } - if (sh3) { - if (!extra.sh3) { - extra.sh3 = new Uint32Array(numSplats * 4); - } - encodeSh3Rgb(extra.sh3 as Uint32Array, index, sh3, splatEncoding); - } - }, - extraCallbacks, - ); - return { packedArray, numSplats, extra }; -} - -// Array of buckets for sorting float16 distances with range [0, DEPTH_INFINITY]. -const DEPTH_INFINITY_F16 = 0x7c00; -const DEPTH_SIZE_16 = DEPTH_INFINITY_F16 + 1; -let depthArray16: Uint32Array | null = null; - -function sortSplats({ - totalSplats, - readback, - ordering, -}: { totalSplats: number; readback: Uint8Array[]; ordering: Uint32Array }): { - activeSplats: number; - ordering: Uint32Array; -} { - // Sort totalSplats Gsplats, each with 4 bytes of readback, and outputs Uint32Array - // of indices from most distant to nearest. Each 4 bytes encode a float16 distance - // and unused high bytes. - if (!depthArray16) { - depthArray16 = new Uint32Array(DEPTH_SIZE_16); - } - depthArray16.fill(0); - - const readbackUint32 = readback.map((layer) => new Uint32Array(layer.buffer)); - const layerSize = readbackUint32[0].length; - const numLayers = Math.ceil(totalSplats / layerSize); - - let layerBase = 0; - for (let layer = 0; layer < numLayers; ++layer) { - const readbackLayer = readbackUint32[layer]; - const layerSplats = Math.min(readbackLayer.length, totalSplats - layerBase); - for (let i = 0; i < layerSplats; ++i) { - const pri = readbackLayer[i] & 0x7fff; - if (pri < DEPTH_INFINITY_F16) { - depthArray16[pri] += 1; - } - } - layerBase += layerSplats; - } - - let activeSplats = 0; - for (let j = 0; j < DEPTH_SIZE_16; ++j) { - const nextIndex = activeSplats + depthArray16[j]; - depthArray16[j] = activeSplats; - activeSplats = nextIndex; - } - - layerBase = 0; - for (let layer = 0; layer < numLayers; ++layer) { - const readbackLayer = readbackUint32[layer]; - const layerSplats = Math.min(readbackLayer.length, totalSplats - layerBase); - for (let i = 0; i < layerSplats; ++i) { - const pri = readbackLayer[i] & 0x7fff; - if (pri < DEPTH_INFINITY_F16) { - ordering[depthArray16[pri]] = layerBase + i; - depthArray16[pri] += 1; - } - } - layerBase += layerSplats; - } - if (depthArray16[DEPTH_SIZE_16 - 1] !== activeSplats) { - throw new Error( - `Expected ${activeSplats} active splats but got ${depthArray16[DEPTH_SIZE_16 - 1]}`, - ); - } - - return { activeSplats, ordering }; -} - -// Sort numSplats splats, each with 2 bytes of float16 readback for distance metric, -// using one bucket sort pass, outputting Uint32Array of indices. -function sortDoubleSplats({ - numSplats, - readback, - ordering, -}: { numSplats: number; readback: Uint16Array; ordering: Uint32Array }): { - activeSplats: number; - ordering: Uint32Array; -} { - // Ensure depthArray is allocated and zeroed out for our buckets. - if (!depthArray16) { - depthArray16 = new Uint32Array(DEPTH_SIZE_16); - } - depthArray16.fill(0); - - // Count the number of splats in each bucket (cull Gsplats at infinity). - for (let i = 0; i < numSplats; ++i) { - const pri = readback[i]; - if (pri < DEPTH_INFINITY_F16) { - depthArray16[pri] += 1; - } - } - - // Compute the beginning index of each bucket in the output array and the - // total number of active (non-infinity) splats, going in reverse order - // because we want most distant Gsplats to be first in the output array. - let activeSplats = 0; - for (let j = DEPTH_INFINITY_F16 - 1; j >= 0; --j) { - const nextIndex = activeSplats + depthArray16[j]; - depthArray16[j] = activeSplats; - activeSplats = nextIndex; - } - - // Write out the sorted indices into the output array according - // bucket order. - for (let i = 0; i < numSplats; ++i) { - const pri = readback[i]; - if (pri < DEPTH_INFINITY_F16) { - ordering[depthArray16[pri]] = i; - depthArray16[pri] += 1; - } - } - // Sanity check that the end of the closest bucket is the same as - // our total count of active splats (not at infinity). - if (depthArray16[0] !== activeSplats) { - throw new Error( - `Expected ${activeSplats} active splats but got ${depthArray16[0]}`, - ); - } - - return { activeSplats, ordering }; -} - -const DEPTH_INFINITY_F32 = 0x7f800000; -let bucket16lo: Uint32Array | null = null; -let bucket16hi: Uint32Array | null = null; -let scratchSplats: Uint32Array | null = null; - -// two-pass radix sort (base 65536) of 32-bit keys in readback, -// but placing largest values first. -function sort32Splats({ - maxSplats, - numSplats, - readback, // Uint32Array of bit‑patterns - ordering, // Uint32Array to fill with sorted indices -}: { - maxSplats: number; - numSplats: number; - readback: Uint32Array; - ordering: Uint32Array; -}): { activeSplats: number; ordering: Uint32Array } { - const BASE = 1 << 16; // 65536 - - // allocate once - if (!bucket16lo) { - bucket16lo = new Uint32Array(BASE); - } - if (!bucket16hi) { - bucket16hi = new Uint32Array(BASE); - } - if (!scratchSplats || scratchSplats.length < maxSplats) { - scratchSplats = new Uint32Array(maxSplats); - } - - // tally low and high buckets - bucket16lo.fill(0); - bucket16hi.fill(0); - for (let i = 0; i < numSplats; ++i) { - const key = readback[i]; - if (key < DEPTH_INFINITY_F32) { - const inv = ~key >>> 0; - bucket16lo[inv & 0xffff] += 1; - bucket16hi[inv >>> 16] += 1; - } - } - - // - // ——— Pass #1: bucket by inv(lo 16 bits) ——— - // - // exclusive prefix‑sum → starting offsets - let total = 0; - for (let b = 0; b < BASE; ++b) { - const c = bucket16lo[b]; - bucket16lo[b] = total; - total += c; - } - const activeSplats = total; - - // scatter into scratch by low bits of inv - for (let i = 0; i < numSplats; ++i) { - const key = readback[i]; - if (key < DEPTH_INFINITY_F32) { - const inv = ~key >>> 0; - scratchSplats[bucket16lo[inv & 0xffff]++] = i; - } - } - - // - // ——— Pass #2: bucket by inv(hi 16 bits) ——— - // - // exclusive prefix‑sum again - let sum = 0; - for (let b = 0; b < BASE; ++b) { - const c = bucket16hi[b]; - bucket16hi[b] = sum; - sum += c; - } - - // scatter into final ordering by high bits of inv - for (let k = 0; k < activeSplats; ++k) { - const idx = scratchSplats[k]; - const inv = ~readback[idx] >>> 0; - ordering[bucket16hi[inv >>> 16]++] = idx; - } - - // sanity‑check: the last bucket should have eaten all entries - if (bucket16hi[BASE - 1] !== activeSplats) { - throw new Error( - `Expected ${activeSplats} active splats but got ${bucket16hi[BASE - 1]}`, - ); - } - - return { activeSplats, ordering }; -} - -// Buffer to queue any messages received while initializing, for example -// early messages to unpack a Gsplat file while still initializing the WASM code. -const messageBuffer: MessageEvent[] = []; - -function bufferMessage(event: MessageEvent) { - messageBuffer.push(event); -} - -async function initialize() { - // Hold any messages received while initializing - self.addEventListener("message", bufferMessage); - - await init_wasm(); - - self.removeEventListener("message", bufferMessage); - self.addEventListener("message", onMessage); - - // Process any buffered messages - for (const event of messageBuffer) { - onMessage(event); - } - messageBuffer.length = 0; -} - -initialize().catch(console.error); diff --git a/src/shaders.ts b/src/shaders.ts index ec598bec..89d05dea 100644 --- a/src/shaders.ts +++ b/src/shaders.ts @@ -4,8 +4,6 @@ import computeUvec4Template from "./shaders/computeUvec4.glsl"; import computeUvec4Vec4Template from "./shaders/computeUvec4_Vec4.glsl"; import computeUvec4x2Vec4Template from "./shaders/computeUvec4x2_Vec4.glsl"; import computeVec4Template from "./shaders/computeVec4.glsl"; -import oldSplatFragment from "./shaders/oldSplatFragment.glsl"; -import oldSplatVertex from "./shaders/oldSplatVertex.glsl"; import splatDefines from "./shaders/splatDefines.glsl"; import splatFragment from "./shaders/splatFragment.glsl"; import splatVertex from "./shaders/splatVertex.glsl"; @@ -17,8 +15,6 @@ export function getShaders(): Record { // @ts-ignore THREE.ShaderChunk.splatDefines = splatDefines; shaders = { - oldSplatVertex, - oldSplatFragment, splatVertex, splatFragment, computeVec4Template, diff --git a/src/shaders/oldSplatFragment.glsl b/src/shaders/oldSplatFragment.glsl deleted file mode 100644 index 358ca57c..00000000 --- a/src/shaders/oldSplatFragment.glsl +++ /dev/null @@ -1,96 +0,0 @@ - -precision highp float; -precision highp int; - -#include - -uniform float near; -uniform float far; -uniform bool encodeLinear; -uniform float time; -uniform bool debugFlag; -uniform float maxStdDev; -uniform float minAlpha; -uniform bool stochastic; -uniform bool disableFalloff; -uniform float falloff; - -uniform bool splatTexEnable; -uniform sampler3D splatTexture; -uniform mat2 splatTexMul; -uniform vec2 splatTexAdd; -uniform float splatTexNear; -uniform float splatTexFar; -uniform float splatTexMid; - -out vec4 fragColor; - -in vec4 vRgba; -in vec2 vSplatUv; -in vec3 vNdc; -flat in uint vSplatIndex; - -void main() { - vec4 rgba = vRgba; - - float z = dot(vSplatUv, vSplatUv); - if (!splatTexEnable) { - if (z > (maxStdDev * maxStdDev)) { - discard; - } - } else { - vec2 uv = splatTexMul * vSplatUv + splatTexAdd; - float ndcZ = vNdc.z; - float depth = (2.0 * near * far) / (far + near - ndcZ * (far - near)); - float clampedFar = max(splatTexFar, splatTexNear); - float clampedDepth = clamp(depth, splatTexNear, clampedFar); - float logDepth = log2(clampedDepth + 1.0); - float logNear = log2(splatTexNear + 1.0); - float logFar = log2(clampedFar + 1.0); - - float texZ; - if (splatTexMid > 0.0) { - float clampedMid = clamp(splatTexMid, splatTexNear, clampedFar); - float logMid = log2(clampedMid + 1.0); - texZ = (clampedDepth <= clampedMid) ? - (0.5 * ((logDepth - logNear) / (logMid - logNear))) : - (0.5 * ((logDepth - logMid) / (logFar - logMid)) + 0.5); - } else { - texZ = (logDepth - logNear) / (logFar - logNear); - } - - vec4 modulate = texture(splatTexture, vec3(uv, 1.0 - texZ)); - rgba *= modulate; - } - - rgba.a *= mix(1.0, exp(-0.5 * z), falloff); - - if (rgba.a < minAlpha) { - discard; - } - if (encodeLinear) { - rgba.rgb = srgbToLinear(rgba.rgb); - } - - if (stochastic) { - const bool STEADY = false; - uint uTime = STEADY ? 0u : floatBitsToUint(time); - uvec2 coord = uvec2(gl_FragCoord.xy); - uint state = uTime + 0x9e3779b9u * coord.x + 0x85ebca6bu * coord.y + 0xc2b2ae35u * uint(vSplatIndex); - state = state * 747796405u + 2891336453u; - uint hash = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - hash = (hash >> 22u) ^ hash; - float rand = float(hash) / 4294967296.0; - if (rand < rgba.a) { - fragColor = vec4(rgba.rgb, 1.0); - } else { - discard; - } - } else { - #ifdef PREMULTIPLIED_ALPHA - fragColor = vec4(rgba.rgb * rgba.a, rgba.a); - #else - fragColor = rgba; - #endif - } -} diff --git a/src/shaders/oldSplatVertex.glsl b/src/shaders/oldSplatVertex.glsl deleted file mode 100644 index ad90d698..00000000 --- a/src/shaders/oldSplatVertex.glsl +++ /dev/null @@ -1,218 +0,0 @@ - -precision highp float; -precision highp int; -precision highp usampler2DArray; - -#include - -attribute uint splatIndex; - -out vec4 vRgba; -out vec2 vSplatUv; -out vec3 vNdc; -flat out uint vSplatIndex; - -uniform vec2 renderSize; -uniform uint numSplats; -uniform vec4 renderToViewQuat; -uniform vec3 renderToViewPos; -uniform float maxStdDev; -uniform float minPixelRadius; -uniform float maxPixelRadius; -uniform float time; -uniform float deltaTime; -uniform bool debugFlag; -uniform float minAlpha; -uniform bool stochastic; -uniform bool enable2DGS; -uniform float blurAmount; -uniform float preBlurAmount; -uniform float focalDistance; -uniform float apertureAngle; -uniform float clipXY; -uniform float focalAdjustment; - -uniform usampler2DArray packedSplats; -uniform vec4 rgbMinMaxLnScaleMinMax; - -void main() { - // Default to outside the frustum so it's discarded if we return early - gl_Position = vec4(0.0, 0.0, 2.0, 1.0); - - if (uint(gl_InstanceID) >= numSplats) { - return; - } - - ivec3 texCoord; - if (stochastic) { - texCoord = ivec3( - uint(gl_InstanceID) & SPLAT_TEX_WIDTH_MASK, - (uint(gl_InstanceID) >> SPLAT_TEX_WIDTH_BITS) & SPLAT_TEX_HEIGHT_MASK, - (uint(gl_InstanceID) >> SPLAT_TEX_LAYER_BITS) - ); - } else { - if (splatIndex == 0xffffffffu) { - // Special value reserved for "no splat" - return; - } - texCoord = ivec3( - splatIndex & SPLAT_TEX_WIDTH_MASK, - (splatIndex >> SPLAT_TEX_WIDTH_BITS) & SPLAT_TEX_HEIGHT_MASK, - splatIndex >> SPLAT_TEX_LAYER_BITS - ); - } - uvec4 packed = texelFetch(packedSplats, texCoord, 0); - - vec3 center, scales; - vec4 quaternion, rgba; - unpackSplatEncoding(packed, center, scales, quaternion, rgba, rgbMinMaxLnScaleMinMax); - - if (rgba.a < minAlpha) { - return; - } - bvec3 zeroScales = equal(scales, vec3(0.0)); - if (all(zeroScales)) { - return; - } - - // Compute the view space center of the splat - vec3 viewCenter = quatVec(renderToViewQuat, center) + renderToViewPos; - - // Discard splats behind the camera - if (viewCenter.z >= 0.0) { - return; - } - - // Compute the clip space center of the splat - vec4 clipCenter = projectionMatrix * vec4(viewCenter, 1.0); - - // Discard splats outside near/far planes - if (abs(clipCenter.z) >= clipCenter.w) { - return; - } - - // Discard splats more than clipXY times outside the XY frustum - float clip = clipXY * clipCenter.w; - if (abs(clipCenter.x) > clip || abs(clipCenter.y) > clip) { - return; - } - - // Record the splat index for entropy - vSplatIndex = splatIndex; - - // Compute view space quaternion of splat - vec4 viewQuaternion = quatQuat(renderToViewQuat, quaternion); - - if (enable2DGS && any(zeroScales)) { - vRgba = rgba; - vSplatUv = position.xy * maxStdDev; - - vec3 offset; - if (zeroScales.z) { - offset = vec3(vSplatUv.xy * scales.xy, 0.0); - } else if (zeroScales.y) { - offset = vec3(vSplatUv.x * scales.x, 0.0, vSplatUv.y * scales.z); - } else { - offset = vec3(0.0, vSplatUv.xy * scales.yz); - } - - vec3 viewPos = viewCenter + quatVec(viewQuaternion, offset); - gl_Position = projectionMatrix * vec4(viewPos, 1.0); - vNdc = gl_Position.xyz / gl_Position.w; - return; - } - - // Compute NDC center of the splat - vec3 ndcCenter = clipCenter.xyz / clipCenter.w; - - // Compute the 3D covariance matrix of the splat - mat3 RS = scaleQuaternionToMatrix(scales, viewQuaternion); - mat3 cov3D = RS * transpose(RS); - - // Compute the Jacobian of the splat's projection at its center - vec2 scaledRenderSize = renderSize * focalAdjustment; - vec2 focal = 0.5 * scaledRenderSize * vec2(projectionMatrix[0][0], projectionMatrix[1][1]); - - mat3 J; - if(isOrthographic) { - J = mat3( - focal.x, 0.0, 0.0, - 0.0, focal.y, 0.0, - 0.0, 0.0, 0.0 - ); - } else { - float invZ = 1.0 / viewCenter.z; - vec2 J1 = focal * invZ; - vec2 J2 = -(J1 * viewCenter.xy) * invZ; - J = mat3( - J1.x, 0.0, J2.x, - 0.0, J1.y, J2.y, - 0.0, 0.0, 0.0 - ); - } - - // Compute the 2D covariance by projecting the 3D covariance - // and picking out the XY plane components. - // Keeping below because we may need it in the future - // for skinning deformations. - // mat3 W = transpose(mat3(viewMatrix)); - // mat3 T = W * J; - // mat3 cov2D = transpose(T) * cov3D * T; - mat3 cov2D = transpose(J) * cov3D * J; - float a = cov2D[0][0]; - float d = cov2D[1][1]; - float b = cov2D[0][1]; - - // Optionally pre-blur the splat to match non-antialias optimized splats - a += preBlurAmount; - d += preBlurAmount; - - float fullBlurAmount = blurAmount; - if ((focalDistance > 0.0) && (apertureAngle > 0.0)) { - float focusRadius = maxPixelRadius; - if (viewCenter.z < 0.0) { - float focusBlur = abs((-viewCenter.z - focalDistance) / viewCenter.z); - float apertureRadius = focal.x * tan(0.5 * apertureAngle); - focusRadius = focusBlur * apertureRadius; - } - fullBlurAmount = clamp(sqr(focusRadius), blurAmount, sqr(maxPixelRadius)); - } - - // Do convolution with a 0.5-pixel Gaussian for anti-aliasing: sqrt(0.3) ~= 0.5 - float detOrig = a * d - b * b; - a += fullBlurAmount; - d += fullBlurAmount; - float det = a * d - b * b; - - // Compute anti-aliasing intensity scaling factor - float blurAdjust = sqrt(max(0.0, detOrig / det)); - rgba.a *= blurAdjust; - if (rgba.a < minAlpha) { - return; - } - - // Compute the eigenvalue and eigenvectors of the 2D covariance matrix - float eigenAvg = 0.5 * (a + d); - float eigenDelta = sqrt(max(0.0, eigenAvg * eigenAvg - det)); - float eigen1 = eigenAvg + eigenDelta; - float eigen2 = eigenAvg - eigenDelta; - - vec2 eigenVec1 = normalize(vec2((abs(b) < 0.001) ? 1.0 : b, eigen1 - a)); - vec2 eigenVec2 = vec2(eigenVec1.y, -eigenVec1.x); - - float scale1 = min(maxPixelRadius, maxStdDev * sqrt(eigen1)); - float scale2 = min(maxPixelRadius, maxStdDev * sqrt(eigen2)); - if (scale1 < minPixelRadius && scale2 < minPixelRadius) { - return; - } - - // Compute the NDC coordinates for the ellipsoid's diagonal axes. - vec2 pixelOffset = position.x * eigenVec1 * scale1 + position.y * eigenVec2 * scale2; - vec2 ndcOffset = (2.0 / scaledRenderSize) * pixelOffset; - vec3 ndc = vec3(ndcCenter.xy + ndcOffset, ndcCenter.z); - - vRgba = rgba; - vSplatUv = position.xy * maxStdDev; - vNdc = ndc; - gl_Position = vec4(ndc.xy * clipCenter.w, clipCenter.zw); -}