diff --git a/cuslines/__init__.py b/cuslines/__init__.py index 2f9f9f7..792fc8f 100644 --- a/cuslines/__init__.py +++ b/cuslines/__init__.py @@ -30,7 +30,6 @@ def _detect_backend(): pass return None - BACKEND = _detect_backend() if BACKEND == "metal": @@ -38,7 +37,7 @@ def _detect_backend(): MetalBootDirectionGetter as BootDirectionGetter, ) from cuslines.metal import ( - MetalGPUTracker as GPUTracker, + MetalGPUTracker as Tracker, ) from cuslines.metal import ( MetalProbDirectionGetter as ProbDirectionGetter, @@ -49,7 +48,7 @@ def _detect_backend(): elif BACKEND == "cuda": from cuslines.cuda_python import ( BootDirectionGetter, - GPUTracker, + GPUTracker as Tracker, ProbDirectionGetter, PttDirectionGetter, ) @@ -64,18 +63,24 @@ def _detect_backend(): WebGPUPttDirectionGetter as PttDirectionGetter, ) from cuslines.webgpu import ( - WebGPUTracker as GPUTracker, + WebGPUTracker as Tracker, ) else: - raise ImportError( - "No GPU backend available. Install either:\n" - " - CUDA: pip install 'cuslines[cu13]' (NVIDIA GPU)\n" - " - Metal: pip install 'cuslines[metal]' (Apple Silicon)\n" - " - WebGPU: pip install 'cuslines[webgpu]' (cross-platform)" + from cuslines.numba import ( + CPUBootDirectionGetter as BootDirectionGetter, + ) + from cuslines.numba import ( + CPUProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.numba import ( + CPUPttDirectionGetter as PttDirectionGetter, + ) + from cuslines.numba import ( + CPUTracker as Tracker, ) __all__ = [ - "GPUTracker", + "Tracker", "ProbDirectionGetter", "PttDirectionGetter", "BootDirectionGetter", diff --git a/cuslines/cuda_c/boot.cu b/cuslines/cuda_c/boot.cu index 978d158..9089bfa 100644 --- a/cuslines/cuda_c/boot.cu +++ b/cuslines/cuda_c/boot.cu @@ -582,8 +582,7 @@ __device__ int get_direction_boot_d( const int ndir = peak_directions_d(__h_sh, dirs, sphere_vertices, - sphere_edges, - reinterpret_cast(__r_sh)); // reuse __r_sh as shInd in func which is large enough + sphere_edges); if (NATTEMPTS == 1) { // init=True... return ndir; // and dirs; } else { // init=False... diff --git a/cuslines/cuda_c/generate_streamlines_cuda.cu b/cuslines/cuda_c/generate_streamlines_cuda.cu index 68c32d0..ab56c8b 100644 --- a/cuslines/cuda_c/generate_streamlines_cuda.cu +++ b/cuslines/cuda_c/generate_streamlines_cuda.cu @@ -64,8 +64,8 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32; const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1))); - extern __shared__ REAL_T __sh[]; - REAL_T *__pmf_data_sh = __sh + tidy*N32DIMT; + __shared__ REAL_T pmf_data_sh[BDIM_Y][DIMT]; + REAL_T* __pmf_data_sh = pmf_data_sh[tidy]; // pmf = self.pmf_gen.get_pmf_c(&point[0], pmf) __syncwarp(WMASK); @@ -94,13 +94,11 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, __syncwarp(WMASK); if (IS_START) { - int *__shInd = reinterpret_cast(__sh + BDIM_Y*N32DIMT) + tidy*N32DIMT; return peak_directions_d(__pmf_data_sh, dirs, sphere_vertices, - sphere_edges, - __shInd); + sphere_edges); } else { REAL_T __tmp; #ifdef DEBUG @@ -148,7 +146,7 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, dir.y*sphere_vertices[i].y+ dir.z*sphere_vertices[i].z; - if (FABS(dot) < cos_similarity) { + if (APPLY_ABS_IF_SYM(dot) < cos_similarity) { __pmf_data_sh[i] = 0.0; } } diff --git a/cuslines/cuda_c/globals.h b/cuslines/cuda_c/globals.h index 505f5cc..6080690 100644 --- a/cuslines/cuda_c/globals.h +++ b/cuslines/cuda_c/globals.h @@ -67,6 +67,12 @@ #endif +#if SPHERE_SYMM == 0 + #define APPLY_ABS_IF_SYM(x) (x) +#else + #define APPLY_ABS_IF_SYM(x) FABS(x) +#endif + #define MIN(x,y) (((x)<(y))?(x):(y)) #define MAX(x,y) (((x)>(y))?(x):(y)) #define POW2(n) (1 << (n)) diff --git a/cuslines/cuda_c/ptt.cu b/cuslines/cuda_c/ptt.cu index fd770df..b2fc9d5 100644 --- a/cuslines/cuda_c/ptt.cu +++ b/cuslines/cuda_c/ptt.cu @@ -31,12 +31,14 @@ __device__ float interp4_d(const float3 pos, const float* frame, }; const int odf_idx = static_cast(tex3D(*sphere_vertices_lut, uvw.z, uvw.y, uvw.x)); - const int grid_col = odf_idx & WIDTH_MASK; - const int grid_row = odf_idx >> LOG2_WIDTH; - - const float x_query = (float)(grid_col * DIMX) + pos.x; - const float y_query = (float)(grid_row * DIMY) + pos.y; - return tex3D(*pmf, x_query, y_query, pos.z); + const int tx = odf_idx & WIDTH_MASK; + const int ty = (odf_idx >> LOG2_X) & HEIGHT_MASK; + const int tz = (odf_idx >> (LOG2_X + LOG2_Y)); + + const float x_query = (float)(tx * DIMX) + pos.x; + const float y_query = (float)(ty * DIMY) + pos.y; + const float z_query = (float)(tz * DIMZ) + pos.z; + return tex3D(*pmf, x_query, y_query, z_query); } __device__ void prepare_propagator_d(float k1, float k2, float arclength, @@ -231,6 +233,8 @@ __device__ int get_direction_ptt_d( } } + __syncwarp(WMASK); + const float first_val = interp4_d( pos, __frame_sh, pmf, sphere_vertices_lut); @@ -287,17 +291,14 @@ __device__ int get_direction_ptt_d( // Move vert to face for (int ii = tidx; ii < DISC_FACE_CNT; ii+=BDIM_X) { - bool all_verts_valid = 1; for (int jj = 0; jj < 3; jj++) { float vert_val = __vert_pdf_sh[DISC_FACE[ii*3 + jj]]; if (vert_val == 0) { - all_verts_valid = IS_INIT; // On init, even go with faces that are not fully supported + __face_cdf_sh[ii] = 0; + break; } __face_cdf_sh[ii] += vert_val; } - if (!all_verts_valid) { - __face_cdf_sh[ii] = 0; - } } __syncwarp(WMASK); @@ -456,15 +457,16 @@ __device__ bool init_frame_ptt_d( } } else { if (tidx == 0) { - for (int ii = 0; ii < 9; ii++) { + for (int ii = 0; ii < 6; ii++) { __frame[ii] = -__frame[ii]; } } __syncwarp(WMASK); } if (tidx == 0) { - for (int ii = 0; ii < 9; ii++) { + for (int ii = 0; ii < 6; ii++) { __frame[9+ii] = -__frame[ii]; // save flipped frame for second run + // note that we only flip the tangent and normal, not the binormal } } __syncwarp(WMASK); diff --git a/cuslines/cuda_c/ptt_init.cu b/cuslines/cuda_c/ptt_init.cu index 9d45a9e..9bcf316 100644 --- a/cuslines/cuda_c/ptt_init.cu +++ b/cuslines/cuda_c/ptt_init.cu @@ -27,33 +27,33 @@ __global__ void getNumStreamlinesPtt_k( const int nseed, curandStatePhilox4_32_10_t st; curand_init(RNG_SEED, gid, 0, &st); - extern __shared__ REAL_T __sh[]; - REAL_T *__pmf_data_sh = __sh + tidy*N32DIMT; + __shared__ REAL_T pmf_data_sh[BDIM_Y][DIMT]; + REAL_T* __pmf_data_sh = pmf_data_sh[tidy]; REAL3_T point = seeds[slid]; #pragma unroll for (int i = tidx; i < DIMT; i += BDIM_X) { - const int grid_col = i & WIDTH_MASK; - const int grid_row = i >> LOG2_WIDTH; - - const REAL_T x_query = (REAL_T)(grid_col * DIMX) + point.x; - const REAL_T y_query = (REAL_T)(grid_row * DIMY) + point.y; - __pmf_data_sh[i] = tex3D(*pmf, x_query, y_query, point.z); + const int tx = i & WIDTH_MASK; + const int ty = (i >> LOG2_X) & HEIGHT_MASK; + const int tz = (i >> (LOG2_X + LOG2_Y)); + + const REAL_T x_query = (REAL_T)(tx * DIMX) + point.x; + const REAL_T y_query = (REAL_T)(ty * DIMY) + point.y; + const REAL_T z_query = (REAL_T)(tz * DIMZ) + point.z; + __pmf_data_sh[i] = tex3D(*pmf, x_query, y_query, z_query); if (__pmf_data_sh[i] < PMF_THRESHOLD_P) { __pmf_data_sh[i] = 0.0; } } __syncwarp(WMASK); - int *__shInd = reinterpret_cast(__sh + BDIM_Y*N32DIMT) + tidy*N32DIMT; int ndir = peak_directions_d< BDIM_X, BDIM_Y>(__pmf_data_sh, __shDir, sphere_vertices, - sphere_edges, - __shInd); + sphere_edges); if (tidx == 0) { slineOutOff[slid] = ndir; diff --git a/cuslines/cuda_c/tracking_helpers.cu b/cuslines/cuda_c/tracking_helpers.cu index b9d9e20..0de438b 100644 --- a/cuslines/cuda_c/tracking_helpers.cu +++ b/cuslines/cuda_c/tracking_helpers.cu @@ -79,17 +79,18 @@ template cos_similarity) { break; } @@ -232,7 +234,7 @@ __device__ int peak_directions_d(const REAL_T *__restrict__ odf, __syncwarp(WMASK); */ #else - const int indMax = max_d(__shInd[tidy], -1); + const int indMax = max_d(SAMPLM_NR, __shInd, -1); if (indMax != -1) { __ret = MAKE_REAL3(sphere_vertices[indMax][0], sphere_vertices[indMax][1], diff --git a/cuslines/cuda_python/cu_direction_getters.py b/cuslines/cuda_python/cu_direction_getters.py index 466c0e0..f921537 100644 --- a/cuslines/cuda_python/cu_direction_getters.py +++ b/cuslines/cuda_python/cu_direction_getters.py @@ -82,6 +82,7 @@ def compile_program(self, gpu_tracker, debug: bool = False): "RNG_SEED": str(int(gpu_tracker.rng_seed)), "SAMPLM_NR": str(int(gpu_tracker.samplm_nr)), "NUM_EDGES": str(int(gpu_tracker.nedges)), + "SPHERE_SYMM": "1" if gpu_tracker.sphere_symm else "0", "EXCESS_ALLOC_FACT": str(int(EXCESS_ALLOC_FACT)), "MAX_SLINES_PER_SEED": str(int(MAX_SLINES_PER_SEED)), "MAX_SLINE_LEN": str(int(MAX_SLINE_LEN)), @@ -92,8 +93,10 @@ def compile_program(self, gpu_tracker, debug: bool = False): } self.set_macros(gpu_tracker) optional_macros = [ - "log2_width", - "width_mask", + "LOG2_X", + "LOG2_Y", + "WIDTH_MASK", + "HEIGHT_MASK", "probe_step_size", "max_curvature", "probe_quality", @@ -368,11 +371,7 @@ def __init__(self): def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp): ker = self.module.get_kernel(self.getnum_kernel_name) - shared_memory = ( - REAL_SIZE * BLOCK_Y * sp.gpu_tracker.n32dimt - + np.int32().nbytes * BLOCK_Y * sp.gpu_tracker.n32dimt - ) - config = LaunchConfig(block=block, grid=grid, shmem_size=shared_memory) + config = LaunchConfig(block=block, grid=grid, shmem_size=0) if isinstance(sp.gpu_tracker.dataf_d[n], runtime.cudaTextureObject_t): dataf_d_n = sp.gpu_tracker.dataf_d[n].getPtr() @@ -394,8 +393,7 @@ def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp): def generateStreamlines(self, n, nseeds_gpu, block, grid, sp): ker = self.module.get_kernel(self.genstreamlines_kernel_name) - shared_memory = REAL_SIZE * BLOCK_Y * sp.gpu_tracker.n32dimt - config = LaunchConfig(block=block, grid=grid, shmem_size=shared_memory) + config = LaunchConfig(block=block, grid=grid, shmem_size=0) launch( sp.gpu_tracker.streams[n], @@ -456,8 +454,10 @@ def __init__( self.target_short_step = target_short_step def set_macros(self, gpu_tracker): - self.macros["LOG2_WIDTH"] = str(int(self.log2_width)) - self.macros["WIDTH_MASK"] = str(int(self.width_mask)) + self.macros["LOG2_X"] = str(int(self.log2_x)) + self.macros["LOG2_Y"] = str(int(self.log2_y)) + self.macros["WIDTH_MASK"] = str(int(self.nx - 1)) + self.macros["HEIGHT_MASK"] = str(int(self.ny - 1)) self.macros["PROBE_QUALITY"] = str(float(self.probe_quality)) self.macros["PROBE_STEP_SIZE"] = str(float(self.probe_length)) self.macros["STEP_FRAC"] = str( @@ -528,7 +528,7 @@ def deallocate_on_gpu(self, n): hard_error=False, ) - def prepare_data(self, dataf, stop_map, stop_threshold, sphere_vertices): + def prepare_data(self, dataf, stop_map, stop_threshold, sphere_vertices, sphere_symm): dimx, dimy, dimz, dimt = dataf.shape dataf = dataf.clip(min=0) @@ -542,16 +542,36 @@ def prepare_data(self, dataf, stop_map, stop_threshold, sphere_vertices): # This rearrangement is for cuda texture memory # In particular, for texture memory, we want each dimension - # to be less than 65,535, so we tile t across x and y - # additionally, we then make the tiles in the x dim + # to be less than 2k, so we tile t across x, y, and z, + # additionally, we then make the tiles in the dims are # a power of 2 to ensure it is fast to calculate indices # into the tiles - ideal_tiles_per_row = math.ceil(math.sqrt(dimt)) - self.log2_width = math.ceil(math.log2(ideal_tiles_per_row)) - tiles_per_row = 1 << self.log2_width - self.width_mask = tiles_per_row - 1 - tiles_per_col = math.ceil(dimt / tiles_per_row) - total_slots = tiles_per_row * tiles_per_col + + def calculate_tight_fit(dimt): + nx, ny, nz = 1, 1, 1 + + # Keep 64x the smallest dimension until we have enough slots + while nx * ny * nz < dimt: + if nx <= ny and nx <= nz: + nx *= 2 + elif ny <= nx and ny <= nz: + ny *= 2 + else: + nz *= 2 + + return nx, ny, nz + + nx, ny, nz = calculate_tight_fit(dimt) + + # these must be saved to be passed as + # macros to the CUDA code + self.log2_x = int(math.log2(nx)) + self.log2_y = int(math.log2(ny)) + self.nx = nx + self.ny = ny + + total_slots = nx * ny * nz + if dimt < total_slots: padding = np.zeros((dimx, dimy, dimz, total_slots - dimt), dtype=np.float32) data_f_rearranged = np.concatenate([dataf, padding], axis=3) @@ -559,22 +579,22 @@ def prepare_data(self, dataf, stop_map, stop_threshold, sphere_vertices): data_f_rearranged = dataf total_memory_usage_gb = ( - (tiles_per_row * dimx) * (tiles_per_col * dimy) * dimz * 4 / 1e9 + (nx * dimx) * (ny * dimy) * (nz * dimz) * 4 / 1e9 ) logger.info( ( f"For PTT, we will allocate a 3D texture of size " - f"{tiles_per_row * dimx}x{tiles_per_col * dimy}x{dimz} " + f"{nx * dimx}x{ny * dimy}x{nz * dimz} " "to store the ODFs on the GPU. This will be in 4 byte floats and use " f"{total_memory_usage_gb:.2f} GB of GPU memory. " "If this is too near your total GPU memory, it will error" ) ) data_f_rearranged = data_f_rearranged.reshape( - dimx, dimy, dimz, tiles_per_col, tiles_per_row + dimx, dimy, dimz, nz, ny, nx ) - data_f_rearranged = data_f_rearranged.transpose(2, 3, 1, 4, 0).reshape( - dimz, tiles_per_col * dimy, tiles_per_row * dimx + data_f_rearranged = data_f_rearranged.transpose(3, 2, 4, 1, 5, 0).reshape( + nz * dimz, ny * dimy, nx * dimx ) data_f_rearranged = np.ascontiguousarray(data_f_rearranged, dtype=np.float32) @@ -584,8 +604,18 @@ def prepare_data(self, dataf, stop_map, stop_threshold, sphere_vertices): grid_x, grid_y, grid_z = np.meshgrid(coords, coords, coords, indexing="ij") grid_points = np.stack([grid_x.ravel(), grid_y.ravel(), grid_z.ravel()], axis=1) - tree = KDTree(sphere_vertices) - _, closest_indices = tree.query(grid_points) + if sphere_symm: + symmetric_vertices = np.concatenate([sphere_vertices, -sphere_vertices], axis=0) + num_orig = len(sphere_vertices) + index_map = np.concatenate([np.arange(num_orig), np.arange(num_orig)], axis=0) + + tree = KDTree(symmetric_vertices) + _, closest_augmented_indices = tree.query(grid_points) + closest_indices = index_map[closest_augmented_indices] + else: + tree = KDTree(sphere_vertices) + _, closest_indices = tree.query(grid_points) + lut = closest_indices.reshape( (self.odf_lut_res, self.odf_lut_res, self.odf_lut_res) ) diff --git a/cuslines/cuda_python/cu_tractography.py b/cuslines/cuda_python/cu_tractography.py index ebee530..8f7275a 100644 --- a/cuslines/cuda_python/cu_tractography.py +++ b/cuslines/cuda_python/cu_tractography.py @@ -10,6 +10,8 @@ from tqdm import tqdm from trx.trx_file_memmap import TrxFile +from cuslines.generic_tracker import GenericTracker + from cuslines.cuda_python.cu_direction_getters import ( BootDirectionGetter, GPUDirectionGetter, @@ -31,7 +33,7 @@ # Remove small/long streamlines on gpu -class GPUTracker: +class GPUTracker(GenericTracker): def __init__( self, dg: GPUDirectionGetter, @@ -40,6 +42,7 @@ def __init__( stop_threshold: float, sphere_vertices: np.ndarray, sphere_edges: np.ndarray, + sphere_symm: bool = False, max_angle: float = radians(60), step_size: float = 0.5, min_pts=0, @@ -70,6 +73,9 @@ def __init__( Vertices of the sphere used for direction sampling. sphere_edges : np.ndarray Edges of the sphere used for direction sampling. + sphere_symm : bool, optional + Whether to assume sphere vertices are antipodally symmetric + default: False max_angle : float, optional Maximum angle (in radians) between steps default: radians(60) @@ -117,6 +123,7 @@ def __init__( stop_map, stop_threshold, sphere_vertices, + sphere_symm, ) else: self.dataf = np.ascontiguousarray(dataf, dtype=REAL_DTYPE) @@ -129,6 +136,12 @@ def __init__( if isinstance(dg, BootDirectionGetter): self.samplm_nr = int(dg.sampling_matrix.shape[0]) else: + if len(self.sphere_vertices) != self.dimt: + raise ValueError( + f"Number of sphere vertices ({len(self.sphere_vertices)}) must" + f" match last dimension of dataf ({self.dimt}), " + "because dataf should be an ODF when using prob or ptt tracking" + ) self.samplm_nr = self.dimt self.n32dimt = ((self.dimt + 31) // 32) * 32 @@ -143,6 +156,7 @@ def __init__( self.rng_seed = int(rng_seed) self.rng_offset = int(rng_offset) self.chunk_size = int(chunk_size) + self.sphere_symm = bool(sphere_symm) avail = checkCudaErrors(runtime.cudaGetDeviceCount()) if self.ngpus > avail: @@ -284,96 +298,3 @@ def __exit__(self, exc_type, exc, tb): runtime.cudaStreamDestroy(self.streams[n]), hard_error=False ) return False - - def _divide_chunks(self, seeds): - global_chunk_sz = self.chunk_size * self.ngpus - nchunks = (seeds.shape[0] + global_chunk_sz - 1) // global_chunk_sz - return global_chunk_sz, nchunks - - def generate_sft(self, seeds, ref_img): - global_chunk_sz, nchunks = self._divide_chunks(seeds) - buffer_size = 0 - generators = [] - - with tqdm(total=seeds.shape[0]) as pbar: - for idx in range(nchunks): - self.seed_propagator.propagate( - seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz] - ) - buffer_size += self.seed_propagator.get_buffer_size() - generators.append(self.seed_propagator.as_generator()) - pbar.update( - seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz].shape[0] - ) - array_sequence = ArraySequence( - (item for gen in generators for item in gen), buffer_size - ) - return StatefulTractogram(array_sequence, ref_img, Space.VOX) - - def generate_trx(self, seeds, ref_img): - global_chunk_sz, nchunks = self._divide_chunks(seeds) - - # Will resize by a factor of 2 if these are exceeded - sl_len_guess = 100 - sl_per_seed_guess = 2 - n_sls_guess = sl_per_seed_guess * seeds.shape[0] - - # trx files use memory mapping - trx_reference = TrxFile(reference=ref_img) - trx_reference.streamlines._data = trx_reference.streamlines._data.astype( - np.float32 - ) - trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype( - np.uint64 - ) - - trx_file = TrxFile( - nb_streamlines=n_sls_guess, - nb_vertices=n_sls_guess * sl_len_guess, - init_as=trx_reference, - ) - offsets_idx = 0 - sls_data_idx = 0 - - with tqdm(total=seeds.shape[0]) as pbar: - for idx in range(int(nchunks)): - self.seed_propagator.propagate( - seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz] - ) - tractogram = Tractogram( - self.seed_propagator.as_array_sequence(), - affine_to_rasmm=ref_img.affine, - ) - tractogram.to_world() - sls = tractogram.streamlines - - new_offsets_idx = offsets_idx + len(sls._offsets) - new_sls_data_idx = sls_data_idx + len(sls._data) - - if ( - new_offsets_idx > trx_file.header["NB_STREAMLINES"] - or new_sls_data_idx > trx_file.header["NB_VERTICES"] - ): - logger.info("TRX resizing...") - trx_file.resize( - nb_streamlines=new_offsets_idx * 2, - nb_vertices=new_sls_data_idx * 2, - ) - - # TRX uses memmaps here - trx_file.streamlines._data[sls_data_idx:new_sls_data_idx] = sls._data - trx_file.streamlines._offsets[offsets_idx:new_offsets_idx] = ( - sls_data_idx + sls._offsets - ) - trx_file.streamlines._lengths[offsets_idx:new_offsets_idx] = ( - sls._lengths - ) - - offsets_idx = new_offsets_idx - sls_data_idx = new_sls_data_idx - pbar.update( - seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz].shape[0] - ) - trx_file.resize() - - return trx_file diff --git a/cuslines/generic_tracker.py b/cuslines/generic_tracker.py new file mode 100644 index 0000000..c753f6a --- /dev/null +++ b/cuslines/generic_tracker.py @@ -0,0 +1,116 @@ +import logging +import numpy as np +from tqdm import tqdm +from trx.trx_file_memmap import TrxFile +from dipy.io.stateful_tractogram import Space, StatefulTractogram +from nibabel.streamlines.array_sequence import ArraySequence +from nibabel.streamlines.tractogram import Tractogram + +logger = logging.getLogger("GPUStreamlines") + + +class GenericTracker: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def _ngpus(self): + if hasattr(self, "ngpus"): + return self.ngpus + else: + return 1 + + def _divide_chunks(self, seeds): + global_chunk_sz = self.chunk_size * self._ngpus() + nchunks = (seeds.shape[0] + global_chunk_sz - 1) // global_chunk_sz + return global_chunk_sz, nchunks + + def generate_sft(self, seeds, ref_img): + global_chunk_sz, nchunks = self._divide_chunks(seeds) + buffer_size = 0 + generators = [] + + with tqdm(total=seeds.shape[0]) as pbar: + for idx in range(nchunks): + self.seed_propagator.propagate( + seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz] + ) + buffer_size += self.seed_propagator.get_buffer_size() + generators.append(self.seed_propagator.as_generator()) + pbar.update( + seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz].shape[0] + ) + array_sequence = ArraySequence( + (item for gen in generators for item in gen), buffer_size + ) + return StatefulTractogram(array_sequence, ref_img, Space.VOX) + + def generate_trx(self, seeds, ref_img): + global_chunk_sz, nchunks = self._divide_chunks(seeds) + + # Will resize by a factor of 2 if these are exceeded + sl_len_guess = 100 + sl_per_seed_guess = 2 + n_sls_guess = sl_per_seed_guess * seeds.shape[0] + + # trx files use memory mapping + trx_reference = TrxFile(reference=ref_img) + trx_reference.streamlines._data = trx_reference.streamlines._data.astype( + np.float32 + ) + trx_reference.streamlines._offsets = trx_reference.streamlines._offsets.astype( + np.uint64 + ) + + trx_file = TrxFile( + nb_streamlines=n_sls_guess, + nb_vertices=n_sls_guess * sl_len_guess, + init_as=trx_reference, + ) + offsets_idx = 0 + sls_data_idx = 0 + + with tqdm(total=seeds.shape[0]) as pbar: + for idx in range(int(nchunks)): + self.seed_propagator.propagate( + seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz] + ) + tractogram = Tractogram( + self.seed_propagator.as_array_sequence(), + affine_to_rasmm=ref_img.affine, + ) + tractogram.to_world() + sls = tractogram.streamlines + + new_offsets_idx = offsets_idx + len(sls._offsets) + new_sls_data_idx = sls_data_idx + len(sls._data) + + if ( + new_offsets_idx > trx_file.header["NB_STREAMLINES"] + or new_sls_data_idx > trx_file.header["NB_VERTICES"] + ): + logger.info("TRX resizing...") + trx_file.resize( + nb_streamlines=new_offsets_idx * 2, + nb_vertices=new_sls_data_idx * 2, + ) + + # TRX uses memmaps here + trx_file.streamlines._data[sls_data_idx:new_sls_data_idx] = sls._data + trx_file.streamlines._offsets[offsets_idx:new_offsets_idx] = ( + sls_data_idx + sls._offsets + ) + trx_file.streamlines._lengths[offsets_idx:new_offsets_idx] = ( + sls._lengths + ) + + offsets_idx = new_offsets_idx + sls_data_idx = new_sls_data_idx + pbar.update( + seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz].shape[0] + ) + trx_file.resize() + + return trx_file diff --git a/cuslines/metal/mt_direction_getters.py b/cuslines/metal/mt_direction_getters.py index d6ed0ff..d639a48 100644 --- a/cuslines/metal/mt_direction_getters.py +++ b/cuslines/metal/mt_direction_getters.py @@ -42,11 +42,11 @@ def getNumStreamlines(self, nseeds_gpu, block, grid, sp): def generateStreamlines(self, nseeds_gpu, block, grid, sp): pass - def setup_device(self, device): + def setup_device(self, device, sphere_symm=False): """Called once when GPUTracker allocates resources.""" pass - def compile_program(self, device): + def compile_program(self, device, sphere_symm=False): import Metal import re @@ -96,7 +96,9 @@ def compile_program(self, device): # Prepend compile-time constants enable = 1 if self.angular_weight > 0 else 0 + sphere_symm_define = 1 if sphere_symm else 0 defines = ( + f"#define SPHERE_SYMM {sphere_symm_define}\n" f"#define ENABLE_ANGULAR_WEIGHT {enable}\n" f"#define ANGULAR_WEIGHT {self.angular_weight:.2f}f\n" ) @@ -153,8 +155,8 @@ def __init__(self): def _shader_files(self): return [] - def setup_device(self, device): - self.compile_program(device) + def setup_device(self, device, sphere_symm=False): + self.compile_program(device, sphere_symm) self.getnum_pipeline = self._make_pipeline(device, "getNumStreamlinesProb_k") self.gen_pipeline = self._make_pipeline(device, "genStreamlinesMergeProb_k") @@ -249,8 +251,8 @@ class MetalPttDirectionGetter(MetalProbDirectionGetter): def _shader_files(self): return ["ptt.metal"] - def setup_device(self, device): - self.compile_program(device) + def setup_device(self, device, sphere_symm=False): + self.compile_program(device, sphere_symm) # PTT reuses Prob's getNum kernel for initial direction finding self.getnum_pipeline = self._make_pipeline(device, "getNumStreamlinesProb_k") # PTT has its own gen kernel with parallel transport frame tracking @@ -335,10 +337,10 @@ def from_dipy_csa(cls, gtab, sphere, sh_order_max=6, full_basis=False, def _shader_files(self): return ["boot.metal"] - def setup_device(self, device): + def setup_device(self, device, sphere_symm=False): from cuslines.metal.mt_tractography import _make_shared_buffer - self.compile_program(device) + self.compile_program(device, sphere_symm) self.getnum_pipeline = self._make_pipeline(device, "getNumStreamlinesBoot_k") self.gen_pipeline = self._make_pipeline(device, "genStreamlinesMergeBoot_k") diff --git a/cuslines/metal/mt_tractography.py b/cuslines/metal/mt_tractography.py index 15bc698..88a3337 100644 --- a/cuslines/metal/mt_tractography.py +++ b/cuslines/metal/mt_tractography.py @@ -73,6 +73,7 @@ def __init__( stop_threshold: float, sphere_vertices: np.ndarray, sphere_edges: np.ndarray, + sphere_symm: bool = False, max_angle: float = radians(60), step_size: float = 0.5, min_pts=0, @@ -106,6 +107,7 @@ def __init__( self.n32dimt = ((self.dimt + 31) // 32) * 32 self.dg = dg + self.sphere_symm = bool(sphere_symm) self.max_angle = np.float32(max_angle) self.tc_threshold = np.float32(stop_threshold) self.step_size = np.float32(step_size) @@ -155,7 +157,7 @@ def _allocate(self): self.sphere_vertices_buf = _make_shared_buffer(self.device, self.sphere_vertices) self.sphere_edges_buf = _make_shared_buffer(self.device, self.sphere_edges) - self.dg.setup_device(self.device) + self.dg.setup_device(self.device, self.sphere_symm) self._allocated = True def __exit__(self, exc_type, exc, tb): diff --git a/cuslines/metal_shaders/generate_streamlines_metal.metal b/cuslines/metal_shaders/generate_streamlines_metal.metal index 4a0a681..0b84258 100644 --- a/cuslines/metal_shaders/generate_streamlines_metal.metal +++ b/cuslines/metal_shaders/generate_streamlines_metal.metal @@ -127,7 +127,7 @@ inline int get_direction_prob(thread PhiloxState& st, for (int i = int(tidx); i < dimt; i += THR_X_SL) { float3 sv = load_f3(sphere_vertices, uint(i)); const float dot = dir.x * sv.x + dir.y * sv.y + dir.z * sv.z; - if (FABS(dot) < cos_similarity) { + if (APPLY_ABS_IF_SYM(dot) < cos_similarity) { pmf_data_sh[i] = 0.0f; } } diff --git a/cuslines/metal_shaders/globals.h b/cuslines/metal_shaders/globals.h index c6eb014..6e33183 100644 --- a/cuslines/metal_shaders/globals.h +++ b/cuslines/metal_shaders/globals.h @@ -35,6 +35,12 @@ using namespace metal; #define MAX_SLINES_PER_SEED (10) +#if SPHERE_SYMM == 0 + #define APPLY_ABS_IF_SYM(x) (x) +#else + #define APPLY_ABS_IF_SYM(x) FABS(x) +#endif + #define MIN(x,y) (((x)<(y))?(x):(y)) #define MAX(x,y) (((x)>(y))?(x):(y)) #define POW2(n) (1 << (n)) diff --git a/cuslines/metal_shaders/tracking_helpers.metal b/cuslines/metal_shaders/tracking_helpers.metal index 8ef2148..8df5a83 100644 --- a/cuslines/metal_shaders/tracking_helpers.metal +++ b/cuslines/metal_shaders/tracking_helpers.metal @@ -200,7 +200,7 @@ inline int peak_directions(const threadgroup float* odf, int j = 0; for (; j < k; j++) { - const float cs = FABS(abc.x * dirs[j].x + + const float cs = APPLY_ABS_IF_SYM(abc.x * dirs[j].x + abc.y * dirs[j].y + abc.z * dirs[j].z); if (cs > cos_similarity) { diff --git a/cuslines/numba/__init__.py b/cuslines/numba/__init__.py new file mode 100644 index 0000000..cd99a96 --- /dev/null +++ b/cuslines/numba/__init__.py @@ -0,0 +1,13 @@ +from .nu_tractography import ( + CPUBootDirectionGetter, + CPUProbDirectionGetter, + CPUPttDirectionGetter, + CPUTracker, +) + +__all__ = [ + "CPUTracker", + "CPUProbDirectionGetter", + "CPUPttDirectionGetter", + "CPUBootDirectionGetter", +] diff --git a/cuslines/numba/nu_globals.py b/cuslines/numba/nu_globals.py new file mode 100644 index 0000000..ccdb10f --- /dev/null +++ b/cuslines/numba/nu_globals.py @@ -0,0 +1,19 @@ +import importlib.util +from pathlib import Path +import numpy as np + +# Import _globals.py directly (bypasses cuslines.cuda_python.__init__ +# which would trigger CUDA imports). +_globals_path = Path(__file__).resolve().parent.parent / "cuda_python" / "_globals.py" +_spec = importlib.util.spec_from_file_location("_globals", str(_globals_path)) +_globals_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_globals_mod) + +MAX_SLINE_LEN = _globals_mod.MAX_SLINE_LEN +PMF_THRESHOLD_P = _globals_mod.PMF_THRESHOLD_P +if _globals_mod.REAL_SIZE == 4: + REAL_DTYPE = np.float32 +elif _globals_mod.REAL_SIZE == 8: + REAL_DTYPE = np.float64 +else: + raise NotImplementedError(f"Unsupported REAL_SIZE={_globals_mod.REAL_SIZE}") diff --git a/cuslines/numba/nu_tractography.py b/cuslines/numba/nu_tractography.py new file mode 100644 index 0000000..a8770f6 --- /dev/null +++ b/cuslines/numba/nu_tractography.py @@ -0,0 +1,243 @@ +import math +from math import radians + +import numpy as np +from dipy.io.stateful_tractogram import Space, StatefulTractogram +from nibabel.streamlines.array_sequence import ArraySequence, MEGABYTE +from tqdm import tqdm + +from cuslines.generic_tracker import GenericTracker +from cuslines.numba_njit.num_streamlines_numba import getNumStreamlinesProb_generator +from cuslines.numba_njit.generate_streamlines_numba import genStreamlinesMergeProb_generator +from cuslines.numba.nu_globals import MAX_SLINE_LEN, REAL_DTYPE + + +class CPUProbDirectionGetter: + pass + +class CPUPttDirectionGetter: + def __init__(self): + raise NotImplementedError( + "Only CPU detected. Only ProbDirectionGetter implemented on CPU. \n" + "Either switch to ProbDirectionGetter or use a backend. Install either:\n" + " - CUDA: pip install 'cuslines[cu13]' (NVIDIA GPU)\n" + " - Metal: pip install 'cuslines[metal]' (Apple Silicon)\n" + " - WebGPU: pip install 'cuslines[webgpu]' (cross-platform)") + +class CPUBootDirectionGetter: + def __init__(self): + raise NotImplementedError( + "Only CPU detected. Only ProbDirectionGetter implemented on CPU. \n" + "Either switch to ProbDirectionGetter or use a backend. Install either:\n" + " - CUDA: pip install 'cuslines[cu13]' (NVIDIA GPU)\n" + " - Metal: pip install 'cuslines[metal]' (Apple Silicon)\n" + " - WebGPU: pip install 'cuslines[webgpu]' (cross-platform)") + +class SeedBatchPropagator: + def __init__(self, cpu_tracker, minlen: int = 0, maxlen: float = np.inf): + self.cpu_tracker = cpu_tracker + self.minlen = minlen + self.maxlen = maxlen + + self.nSlines = 0 + self.slines = None + self.sline_lens = None + + t = self.cpu_tracker + self.getNumStreamlinesProb = getNumStreamlinesProb_generator( + t.dimx, + t.dimy, + t.dimz, + t.dimt, + t.relative_peak_thresh, + t.min_separation_angle, + t.nedges, + t.sphere_symm, + ) + + self.genStreamlinesMergeProb = genStreamlinesMergeProb_generator( + t.dimx, + t.dimy, + t.dimz, + t.dimt, + t.sphere_symm, + t.step_size, + t.max_angle, + t.tc_threshold, + ) + + def _get_num_streamlines(self, seeds): + t = self.cpu_tracker + nseed = len(seeds) + + shDir0 = np.zeros((nseed * t.dimt, 3), dtype=REAL_DTYPE) + slineOutOff = np.zeros(nseed + 1, dtype=np.int32) + + self.getNumStreamlinesProb( + seeds, + t.dataf, + t.sphere_vertices, + t.sphere_edges, + shDir0, + slineOutOff, + ) + + __pval = slineOutOff[0] + slineOutOff[0] = 0 + for jj in range(1, nseed + 1): + __cval = slineOutOff[jj] + slineOutOff[jj] = slineOutOff[jj - 1] + __pval + __pval = __cval + + return shDir0, slineOutOff + + def _generate_streamlines(self, seeds, shDir0, slineOutOff): + t = self.cpu_tracker + nSlines = int(slineOutOff[-1]) + + slineSeed = np.full(nSlines, -1, dtype=np.int32) + slineLen = np.zeros(nSlines, dtype=np.int32) + sline = np.zeros((nSlines * MAX_SLINE_LEN * 2, 3), dtype=REAL_DTYPE) + + self.genStreamlinesMergeProb( + seeds, + t.dataf, + t.metric_map, + t.sphere_vertices, + t.sphere_edges, + slineOutOff, + shDir0, + slineSeed, + slineLen, + sline, + ) + return nSlines, slineLen, sline + + def propagate(self, seeds: np.ndarray): + """ + Run full two-phase tracking for `seeds` (float32[N, 3]). + Results stored in self.slines, self.sline_lens, self.nSlines. + """ + seeds = np.ascontiguousarray(seeds, dtype=REAL_DTYPE) + + shDir0, slineOutOff = self._get_num_streamlines(seeds) + nSlines, slineLen, sline = self._generate_streamlines(seeds, shDir0, slineOutOff) + + self.nSlines = nSlines + self.sline_lens = slineLen + self.slines = sline + + def get_buffer_size(self) -> int: + """Return estimated buffer size in MB (mirrors GPU version).""" + if self.sline_lens is None: + return 0 + total_pts = sum( + l for l in self.sline_lens[:self.nSlines] + if self.minlen <= l <= self.maxlen + ) + return math.ceil(total_pts * 3 * REAL_DTYPE(0).itemsize / MEGABYTE) + + def as_generator(self): + def _yield_slines(): + slines = self.slines + sline_lens = self.sline_lens + step = MAX_SLINE_LEN * 2 # points allocated per streamline + + for i in range(self.nSlines): + npts = int(sline_lens[i]) + if npts < self.minlen or npts > self.maxlen: + continue + yield np.asarray(slines[i * step : i * step + npts], dtype=REAL_DTYPE) + return _yield_slines() + + def as_array_sequence(self) -> ArraySequence: + return ArraySequence(self.as_generator(), self.get_buffer_size()) + + +class CPUTracker(GenericTracker): + """ + CPU probabilistic tractography tracker. + + Parameters + ---------- + dg : DirectionGetter + Direction getter to use. + Can only be CPUProbDirectionGetter. + Maintained to match API with other backends. + dataf : np.ndarray, shape (dimx, dimy, dimz, dimt) + ODF volume. + stop_map : np.ndarray, shape (dimx, dimy, dimz) + Stopping metric (e.g. GFA or FA). + stop_threshold : float + Voxels with stop_map <= stop_threshold are endpoints. + sphere_vertices : np.ndarray, shape (dimt, 3) + Unit sphere vertices. + sphere_edges : np.ndarray, shape (num_edges, 2) + Sphere adjacency list (int32). + max_angle : float + Maximum turning angle in radians. Default: radians(60). + step_size : float + Step size in voxels. Default: 0.5. + min_pts : int + Minimum streamline length (points) to keep. Default: 0. + max_pts : float + Maximum streamline length (points) to keep. Default: inf. + relative_peak_thresh : float + Relative peak threshold for direction selection. Default: 0.25. + min_separation_angle : float + Minimum separation angle (radians) between peaks. Default: radians(45). + ngpus : int, optional + Ignored. Maintained to match API with other backends. + default: 1 + rng_seed : int, optional + Seed for random number generator + default: 0 + rng_offset : int, optional + Ignored. Maintained to match API with other backends. + default: 0 + chunk_size : int + Seeds per propagate() call in generate_sft(). Default: 100000. + """ + + def __init__( + self, + dg: object, + dataf: np.ndarray, + stop_map: np.ndarray, + stop_threshold: float, + sphere_vertices: np.ndarray, + sphere_edges: np.ndarray, + sphere_symm: bool = False, + max_angle: float = radians(60), + step_size: float = 0.5, + min_pts: int = 0, + max_pts: float = np.inf, + relative_peak_thresh: float = 0.25, + min_separation_angle: float = radians(45), + ngpus: int = 1, + rng_seed: int = 0, + rng_offset: int = 0, + chunk_size: int = 100000, + ): + self.dataf = np.ascontiguousarray(dataf, dtype=REAL_DTYPE) + self.metric_map = np.ascontiguousarray(stop_map, dtype=REAL_DTYPE) + self.sphere_vertices = np.ascontiguousarray(sphere_vertices, dtype=REAL_DTYPE) + self.sphere_edges = np.ascontiguousarray(sphere_edges, dtype=np.int32) + + self.sphere_symm = sphere_symm + self.dimx, self.dimy, self.dimz, self.dimt = dataf.shape + + self.max_angle = float(max_angle) + self.tc_threshold = float(stop_threshold) + self.step_size = float(step_size) + self.relative_peak_thresh = float(relative_peak_thresh) + self.min_separation_angle = float(min_separation_angle) + self.chunk_size = int(chunk_size) + self.nedges = int(sphere_edges.shape[0]) + + if rng_seed != 0: + np.random.seed(rng_seed) + + self.seed_propagator = SeedBatchPropagator( + cpu_tracker=self, minlen=min_pts, maxlen=max_pts + ) diff --git a/cuslines/numba_njit/generate_streamlines_numba.py b/cuslines/numba_njit/generate_streamlines_numba.py new file mode 100644 index 0000000..b703723 --- /dev/null +++ b/cuslines/numba_njit/generate_streamlines_numba.py @@ -0,0 +1,284 @@ +import math +import numpy as np +import random +from numba import njit, prange +from cuslines.numba_njit.tracking_helpers import trilinear_interp_generator +from cuslines.numba.nu_globals import MAX_SLINE_LEN, PMF_THRESHOLD_P + + +def genStreamlinesMergeProb_generator(DIMX, DIMY, DIMZ, DIMT, SPHERE_SYMM, STEP_SIZE, MAX_ANGLE, TC_THRESHOLD): + trilinear_interp = trilinear_interp_generator(DIMX, DIMY, DIMZ, DIMT) + + @njit + def check_point(point, metric_map): + px, py, pz = point[0], point[1], point[2] + + if px < -0.5 or px + 0.5 >= DIMX: + return 0 + if py < -0.5 or py + 0.5 >= DIMY: + return 0 + if pz < -0.5 or pz + 0.5 >= DIMZ: + return 0 + + flx = math.floor(px) + fly = math.floor(py) + flz = math.floor(pz) + + wx1 = px - flx; wx0 = 1.0 - wx1 + wy1 = py - fly; wy0 = 1.0 - wy1 + wz1 = pz - flz; wz0 = 1.0 - wz1 + + ix0 = max(0, int(flx)); ix1 = min(DIMX - 1, ix0 + 1) + iy0 = max(0, int(fly)); iy1 = min(DIMY - 1, iy0 + 1) + iz0 = max(0, int(flz)); iz1 = min(DIMZ - 1, iz0 + 1) + + val = (wx0 * wy0 * wz0 * metric_map[ix0, iy0, iz0] + + wx0 * wy0 * wz1 * metric_map[ix0, iy0, iz1] + + wx0 * wy1 * wz0 * metric_map[ix0, iy1, iz0] + + wx0 * wy1 * wz1 * metric_map[ix0, iy1, iz1] + + wx1 * wy0 * wz0 * metric_map[ix1, iy0, iz0] + + wx1 * wy0 * wz1 * metric_map[ix1, iy0, iz1] + + wx1 * wy1 * wz0 * metric_map[ix1, iy1, iz0] + + wx1 * wy1 * wz1 * metric_map[ix1, iy1, iz1]) + + return 2 if val > TC_THRESHOLD else 3 + + + @njit + def get_direction_prob_step( + pmf_volume, # float32[DIMX, DIMY, DIMZ, DIMT] + direction, # float32[3] current propagation direction + point, # float32[3] current position + sphere_vertices, # float32[DIMT, 3] + new_dir, # float32[3] output: chosen next direction (written in-place) + pmf_scratch, # float32[DIMT] scratch buffer + ): + # 1 interpolate PMF + rv = trilinear_interp(pmf_volume, point, pmf_scratch) + if rv != 0: + return 0 + + # 2 absolute PMF threshold + max_pmf = -np.inf + for i in range(DIMT): + if pmf_scratch[i] > max_pmf: + max_pmf = pmf_scratch[i] + abs_thresh = PMF_THRESHOLD_P * max_pmf + for i in range(DIMT): + if pmf_scratch[i] < abs_thresh: + pmf_scratch[i] = 0.0 + + # 3 angle filtering + cos_sim = math.cos(MAX_ANGLE) + for i in range(DIMT): + dot = (direction[0] * sphere_vertices[i, 0] + + direction[1] * sphere_vertices[i, 1] + + direction[2] * sphere_vertices[i, 2]) + if SPHERE_SYMM: + if dot < 0.0: + dot = -dot + if dot < cos_sim: + pmf_scratch[i] = 0.0 + + # 4 cumulative sum (prefix_sum_sh_d) + for i in range(1, DIMT): + pmf_scratch[i] += pmf_scratch[i - 1] + + last_cdf = pmf_scratch[DIMT - 1] + if last_cdf == 0.0: + return 0 + + u = random.random() * last_cdf + + # Binary search for first CDF entry >= u (mirrors CUDA bisect + ballot) + low = 0 + high = DIMT - 1 + while low < high: + mid = (low + high) // 2 + if pmf_scratch[mid] <= u: + low = mid + 1 + else: + high = mid + ind_prob = low + + # 5 flip vertex to match current hemisphere + if SPHERE_SYMM: + dot = (direction[0] * sphere_vertices[ind_prob, 0] + + direction[1] * sphere_vertices[ind_prob, 1] + + direction[2] * sphere_vertices[ind_prob, 2]) + if dot > 0.0: + new_dir[0] = sphere_vertices[ind_prob, 0] + new_dir[1] = sphere_vertices[ind_prob, 1] + new_dir[2] = sphere_vertices[ind_prob, 2] + else: + new_dir[0] = -sphere_vertices[ind_prob, 0] + new_dir[1] = -sphere_vertices[ind_prob, 1] + new_dir[2] = -sphere_vertices[ind_prob, 2] + else: + new_dir[0] = sphere_vertices[ind_prob, 0] + new_dir[1] = sphere_vertices[ind_prob, 1] + new_dir[2] = sphere_vertices[ind_prob, 2] + + return 1 + + + @njit + def tracker( + seed, # float32[3] starting position + first_step, # float32[3] initial direction + pmf_volume, # float32[DIMX, DIMY, DIMZ, DIMT] + metric_map, # float32[DIMX, DIMY, DIMZ] + sphere_vertices, # float32[DIMT, 3] + sphere_edges, # int32[NUM_EDGES, 2] (unused in PROB model, kept for API symmetry) + streamline, # float32[MAX_SLINE_LEN*2, 3] output, written in-place + pmf_scratch, # float32[DIMT] scratch buffer + ): + point = np.empty(3, dtype=np.float32) + direction = np.empty(3, dtype=np.float32) + new_dir = np.empty(3, dtype=np.float32) + + point[0] = seed[0]; point[1] = seed[1]; point[2] = seed[2] + direction[0] = first_step[0] + direction[1] = first_step[1] + direction[2] = first_step[2] + + streamline[0, 0] = point[0] + streamline[0, 1] = point[1] + streamline[0, 2] = point[2] + + tissue_class = 2 + + i = 1 + while i < MAX_SLINE_LEN: + ndir = get_direction_prob_step( + pmf_volume, + direction, + point, + sphere_vertices, + new_dir, + pmf_scratch, + ) + + if ndir == 0: + break + + direction[0] = new_dir[0] + direction[1] = new_dir[1] + direction[2] = new_dir[2] + + # voxel_size == (1,1,1) → step = direction * STEP_SIZE + point[0] += direction[0] * STEP_SIZE + point[1] += direction[1] * STEP_SIZE + point[2] += direction[2] * STEP_SIZE + + streamline[i, 0] = point[0] + streamline[i, 1] = point[1] + streamline[i, 2] = point[2] + + tissue_class = check_point(point, metric_map) + + if (tissue_class == 0 or + tissue_class == 1 or + tissue_class == 3): + break + + i += 1 + + return i, tissue_class + + + @njit(parallel=True) + def genStreamlinesMergeProb( + seeds, # float32[nseed, 3] + pmf_volume, # float32[DIMX, DIMY, DIMZ, DIMT] + metric_map, # float32[DIMX, DIMY, DIMZ] + sphere_vertices, # float32[DIMT, 3] + sphere_edges, # int32[NUM_EDGES, 2] + slineOutOff, # int32[nseed+1] prefix-sum offsets from getNumStreamlinesProb + shDir0, # float32[nseed*DIMT, 3] peak directions from getNumStreamlinesProb + slineSeed, # int32[total_slines] output: seed index per streamline + slineLen, # int32[total_slines] output: length of each streamline + sline, # float32[total_slines * MAX_SLINE_LEN*2, 3] output: streamline points + ): + nseed = seeds.shape[0] + + for slid in prange(nseed): + # Number of peak directions for this seed (from the prefix-sum offsets) + ndir = slineOutOff[slid + 1] - slineOutOff[slid] + slineOff = slineOutOff[slid] # index of first output streamline for this seed + + # Per-worker scratch (equivalent to CUDA shared memory / registers) + pmf_scratch = np.empty(DIMT, dtype=np.float32) + + seed = seeds[slid] # float32[3] view read-only + + for i in range(ndir): + # first_step is the i-th peak direction for this seed + dir_idx = slid * DIMT + i + first_step = shDir0[dir_idx] # float32[3] view + + # Offset into the flat streamline output array + sline_start = slineOff * MAX_SLINE_LEN * 2 # in points + + # Record which seed produced this streamline + slineSeed[slineOff] = slid + + # ---------------------------------------------------------- + # Backward pass: start from seed, direction = -first_step + # ---------------------------------------------------------- + neg_first = np.empty(3, dtype=np.float32) + neg_first[0] = -first_step[0] + neg_first[1] = -first_step[1] + neg_first[2] = -first_step[2] + + curr_sline = sline[sline_start : sline_start + MAX_SLINE_LEN * 2] + + stepsB, _ = tracker( + seed, + neg_first, + pmf_volume, + metric_map, + sphere_vertices, + sphere_edges, + curr_sline, + pmf_scratch, + ) + + # Reverse the backward segment in-place (mirrors the CUDA loop) + lo = 0 + hi = stepsB - 1 + while lo < hi: + tmp0 = curr_sline[lo, 0]; tmp1 = curr_sline[lo, 1]; tmp2 = curr_sline[lo, 2] + curr_sline[lo, 0] = curr_sline[hi, 0] + curr_sline[lo, 1] = curr_sline[hi, 1] + curr_sline[lo, 2] = curr_sline[hi, 2] + curr_sline[hi, 0] = tmp0 + curr_sline[hi, 1] = tmp1 + curr_sline[hi, 2] = tmp2 + lo += 1 + hi -= 1 + + # ---------------------------------------------------------- + # Forward pass: append at the junction (currSline + stepsB-1) + # The backward segment ends at index stepsB-1, which becomes + # index 0 of the forward segment (the seed point is shared). + # ---------------------------------------------------------- + fwd_sline = curr_sline[stepsB - 1 :] # view starting at junction + + stepsF, _ = tracker( + seed, + first_step, + pmf_volume, + metric_map, + sphere_vertices, + sphere_edges, + fwd_sline, + pmf_scratch, + ) + + # Total length: backward points + forward points, junction counted once + slineLen[slineOff] = stepsB - 1 + stepsF + + slineOff += 1 + + return genStreamlinesMergeProb diff --git a/cuslines/numba_njit/num_streamlines_numba.py b/cuslines/numba_njit/num_streamlines_numba.py new file mode 100644 index 0000000..82bf5a2 --- /dev/null +++ b/cuslines/numba_njit/num_streamlines_numba.py @@ -0,0 +1,170 @@ +import math +import numpy as np +from numba import njit, prange +from cuslines.numba_njit.tracking_helpers import trilinear_interp_generator +from cuslines.numba.nu_globals import PMF_THRESHOLD_P + + +def getNumStreamlinesProb_generator(DIMX, DIMY, DIMZ, DIMT, RELATIVE_PEAK_THRESH, MIN_SEPARATION_ANGLE, NUM_EDGES, SPHERE_SYMM): + trilinear_interp = trilinear_interp_generator(DIMX, DIMY, DIMZ, DIMT) + + @njit + def peak_directions(odf, sphere_vertices, sphere_edges, dirs_out): + # --- shInd buffer (equivalent of __shared__ int __shInd[DIMT]) --- + shInd = np.zeros(DIMT, dtype=np.int32) + + # --- odf_min = max(0, min(odf)) (mirrors min_d then MAX(0, odf_min)) --- + odf_min = np.inf + for i in range(DIMT): + if odf[i] < odf_min: + odf_min = odf[i] + if odf_min < 0.0: + odf_min = 0.0 + + # --- local_maxima: mark via edges --- + # For each edge (u,v): the smaller-valued vertex is marked -1 (not a max), + # the larger is marked >=1 (candidate). atomicExch/atomicOr become plain + # assignments; races were "benign" in CUDA and are non-existent serially. + for e in range(NUM_EDGES): + u = sphere_edges[e, 0] + v = sphere_edges[e, 1] + u_val = odf[u] + v_val = odf[v] + if u_val < v_val: + shInd[u] = -1 + if shInd[v] != -1: # preserve -1 (atomicOr with 1 ≡ set bit 0) + shInd[v] = shInd[v] | 1 + elif v_val < u_val: + shInd[v] = -1 + if shInd[u] != -1: + shInd[u] = shInd[u] | 1 + + masked_max = -np.inf + for i in range(DIMT): + if shInd[i] > 0: + val = odf[i] - odf_min + if val > masked_max: + masked_max = val + comp_thres = RELATIVE_PEAK_THRESH * masked_max + + # --- compact: keep indices where shInd[i]>0 AND (odf[i]-odf_min) >= compThres --- + n = 0 + for i in range(DIMT): + if shInd[i] > 0 and (odf[i] - odf_min) >= comp_thres: + shInd[n] = i + n += 1 + + if n == 0: + return 0 + + # --- sort compacted indices by descending odf value --- + # CUDA used warp_sort (bitonic); insertion sort is correct and fine for + # small n (bounded by DIMT, typically < 10). + for i in range(1, n): + key_i = odf[shInd[i]] + idx_i = shInd[i] + j = i - 1 + while j >= 0 and odf[shInd[j]] < key_i: + shInd[j + 1] = shInd[j] + j -= 1 + shInd[j + 1] = idx_i + + # --- remove_similar_vertices --- + # Keep a direction only if it is at least MIN_SEPARATION_ANGLE away from + # all already-kept directions (mirrors the single-threaded tidx==0 block). + cos_sim = math.cos(MIN_SEPARATION_ANGLE) + + dirs_out[0, 0] = sphere_vertices[shInd[0], 0] + dirs_out[0, 1] = sphere_vertices[shInd[0], 1] + dirs_out[0, 2] = sphere_vertices[shInd[0], 2] + k = 1 + + for i in range(1, n): + ax = sphere_vertices[shInd[i], 0] + ay = sphere_vertices[shInd[i], 1] + az = sphere_vertices[shInd[i], 2] + + too_close = False + for j in range(k): + dot = ax * dirs_out[j, 0] + ay * dirs_out[j, 1] + az * dirs_out[j, 2] + if SPHERE_SYMM: + if dot < 0.0: + dot = -dot + if dot > cos_sim: + too_close = True + break + + if not too_close: + dirs_out[k, 0] = ax + dirs_out[k, 1] = ay + dirs_out[k, 2] = az + k += 1 + + return k + + @njit + def get_direction_prob_start( + pmf_volume, # 4-D float array [nz, ny, nx, DIMT] + point, # float[3] fractional voxel coords + sphere_vertices, # float[DIMT, 3] + sphere_edges, # int[E, 2] + dirs_out, # float[DIMT, 3] output buffer (pre-allocated) + pmf_scratch, # float[DIMT] scratch buffer (pre-allocated) + ): + # Step 1 interpolate PMF + rv = trilinear_interp(pmf_volume, point, pmf_scratch) # must be defined + if rv != 0: + return 0 + + # Step 2 threshold + max_pmf = -np.inf + for i in range(DIMT): + if pmf_scratch[i] > max_pmf: + max_pmf = pmf_scratch[i] + abs_thresh = PMF_THRESHOLD_P * max_pmf + + for i in range(DIMT): + if pmf_scratch[i] < abs_thresh: + pmf_scratch[i] = 0.0 + + # Step 3 peak directions (IS_START branch; no angle filtering) + ndir = peak_directions(pmf_scratch, sphere_vertices, sphere_edges, dirs_out) # must be defined + + return ndir + + @njit(parallel=True) + def getNumStreamlinesProb( + seeds, # float[nseed, 3] + pmf_volume, # float[nz, ny, nx, DIMT] + sphere_vertices, # float[DIMT, 3] + sphere_edges, # int[E, 2] + shDir0, # float[nseed * DIMT, 3] output: peak dirs per seed + slineOutOff, # int[nseed + 1] output: prefix-sum offsets + ): + nseed = seeds.shape[0] + + for slid in prange(nseed): + # --- scratch buffers (thread-local on CPU) --------------------- + pmf_scratch = np.empty(DIMT, dtype=np.float32) + dirs_out = np.empty((DIMT, 3), dtype=np.float32) + + # --- get peak directions at this seed -------------------------- + # shDir0 slice for this seed + dir_base = slid * DIMT + ndir = get_direction_prob_start( + pmf_volume, + seeds[slid], # point = seed (fractional voxel coords) + sphere_vertices, + sphere_edges, + dirs_out, + pmf_scratch, + ) + + # --- write outputs -------------------------------------------- + slineOutOff[slid] = ndir + for d in range(ndir): + shDir0[dir_base + d, 0] = dirs_out[d, 0] + shDir0[dir_base + d, 1] = dirs_out[d, 1] + shDir0[dir_base + d, 2] = dirs_out[d, 2] + + return getNumStreamlinesProb diff --git a/cuslines/numba_njit/tracking_helpers.py b/cuslines/numba_njit/tracking_helpers.py new file mode 100644 index 0000000..e21b0b3 --- /dev/null +++ b/cuslines/numba_njit/tracking_helpers.py @@ -0,0 +1,62 @@ +import math +import numpy as np +from numba import njit + + +def trilinear_interp_generator(DIMX, DIMY, DIMZ, DIMT): + @njit + def trilinear_interp(pmf_volume, point, out_pmf): + px = point[0] + py = point[1] + pz = point[2] + + # Bounds check (matches: point.x < -0.5 || point.x+0.5 >= DIMX ...) + if px < -0.5 or px + 0.5 >= DIMX: + return -1 + if py < -0.5 or py + 0.5 >= DIMY: + return -1 + if pz < -0.5 or pz + 0.5 >= DIMZ: + return -1 + + # Floor coordinates + flx = math.floor(px) + fly = math.floor(py) + flz = math.floor(pz) + + # Weights along each axis + wx1 = px - flx # high-side weight + wx0 = 1.0 - wx1 # low-side weight + + wy1 = py - fly + wy0 = 1.0 - wy1 + + wz1 = pz - flz + wz0 = 1.0 - wz1 + + # Clamped voxel indices (matches MAX(0,fl) / MIN(DIM-1, ...)) + ix0 = max(0, int(flx)) + ix1 = min(DIMX - 1, ix0 + 1) + + iy0 = max(0, int(fly)) + iy1 = min(DIMY - 1, iy0 + 1) + + iz0 = max(0, int(flz)) + iz1 = min(DIMZ - 1, iz0 + 1) + + # Accumulate the 8-corner trilinear blend for every ODF direction + # Matches the triple loop: for i in {0,1}: for j in {0,1}: for k in {0,1}: + # out[t] += wgh[0][i]*wgh[1][j]*wgh[2][k] * dataf[ix,iy,iz,t] + for t in range(DIMT): + out_pmf[t] = ( + wx0 * wy0 * wz0 * pmf_volume[ix0, iy0, iz0, t] + + wx0 * wy0 * wz1 * pmf_volume[ix0, iy0, iz1, t] + + wx0 * wy1 * wz0 * pmf_volume[ix0, iy1, iz0, t] + + wx0 * wy1 * wz1 * pmf_volume[ix0, iy1, iz1, t] + + wx1 * wy0 * wz0 * pmf_volume[ix1, iy0, iz0, t] + + wx1 * wy0 * wz1 * pmf_volume[ix1, iy0, iz1, t] + + wx1 * wy1 * wz0 * pmf_volume[ix1, iy1, iz0, t] + + wx1 * wy1 * wz1 * pmf_volume[ix1, iy1, iz1, t] + ) + + return 0 + return trilinear_interp diff --git a/cuslines/webgpu/wg_direction_getters.py b/cuslines/webgpu/wg_direction_getters.py index 367c20e..da1b48a 100644 --- a/cuslines/webgpu/wg_direction_getters.py +++ b/cuslines/webgpu/wg_direction_getters.py @@ -36,11 +36,11 @@ def getNumStreamlines(self, nseeds_gpu, block, grid, sp): def generateStreamlines(self, nseeds_gpu, block, grid, sp): pass - def setup_device(self, device, has_subgroups=True): + def setup_device(self, device, has_subgroups=True, sphere_symm=False): """Called once when WebGPUTracker allocates resources.""" pass - def compile_program(self, device, has_subgroups=True): + def compile_program(self, device, has_subgroups=True, sphere_symm=False): start_time = time() logger.info("Compiling WebGPU/WGSL shaders...") @@ -85,6 +85,10 @@ def compile_program(self, device, has_subgroups=True): full_source = "\n".join(source_parts) + # Prepend compile-time constants + sphere_symm_val = 1 if sphere_symm else 0 + full_source = f"const SPHERE_SYMM: u32 = {sphere_symm_val}u;\n" + full_source + shader_module = device.create_shader_module(code=full_source) self.shader_module = shader_module logger.info("WGSL shaders compiled in %.2f seconds", time() - start_time) @@ -130,8 +134,8 @@ def __init__(self): def _shader_files(self): return [] - def setup_device(self, device, has_subgroups=True): - self.compile_program(device, has_subgroups) + def setup_device(self, device, has_subgroups=True, sphere_symm=False): + self.compile_program(device, has_subgroups, sphere_symm) self.getnum_pipeline = self._make_pipeline(device, "getNumStreamlinesProb_k") self.gen_pipeline = self._make_pipeline(device, "genStreamlinesMergeProb_k") @@ -241,8 +245,8 @@ class WebGPUPttDirectionGetter(WebGPUProbDirectionGetter): def _shader_files(self): return ["disc.wgsl", "ptt.wgsl"] - def setup_device(self, device, has_subgroups=True): - self.compile_program(device, has_subgroups) + def setup_device(self, device, has_subgroups=True, sphere_symm=False): + self.compile_program(device, has_subgroups, sphere_symm) # PTT reuses Prob's getNum kernel for initial direction finding self.getnum_pipeline = self._make_pipeline(device, "getNumStreamlinesProb_k") # PTT has its own gen kernel @@ -331,8 +335,8 @@ def _kernel_files(self): # boot.wgsl is self-contained (has its own buffer bindings, params, entry points) return [] - def setup_device(self, device, has_subgroups=True): - self.compile_program(device, has_subgroups) + def setup_device(self, device, has_subgroups=True, sphere_symm=False): + self.compile_program(device, has_subgroups, sphere_symm) self.getnum_pipeline = self._make_pipeline(device, "getNumStreamlinesBoot_k") self.gen_pipeline = self._make_pipeline(device, "genStreamlinesMergeBoot_k") diff --git a/cuslines/webgpu/wg_tractography.py b/cuslines/webgpu/wg_tractography.py index 35a29fc..e8793f7 100644 --- a/cuslines/webgpu/wg_tractography.py +++ b/cuslines/webgpu/wg_tractography.py @@ -35,6 +35,7 @@ def __init__( stop_threshold: float, sphere_vertices: np.ndarray, sphere_edges: np.ndarray, + sphere_symm: bool = False, max_angle: float = radians(60), step_size: float = 0.5, min_pts=0, @@ -63,6 +64,7 @@ def __init__( self.n32dimt = ((self.dimt + 31) // 32) * 32 self.dg = dg + self.sphere_symm = bool(sphere_symm) self.max_angle = np.float32(max_angle) self.tc_threshold = np.float32(stop_threshold) self.step_size = np.float32(step_size) @@ -176,7 +178,7 @@ def _allocate(self): self.device, self.sphere_edges.ravel(), label="sphere_edges" ) - self.dg.setup_device(self.device, self.has_subgroups) + self.dg.setup_device(self.device, self.has_subgroups, self.sphere_symm) except Exception: # Clean up any partially allocated buffers self.dataf_buf = None diff --git a/cuslines/wgsl_shaders/generate_streamlines.wgsl b/cuslines/wgsl_shaders/generate_streamlines.wgsl index 0cbf9fa..4a0af23 100644 --- a/cuslines/wgsl_shaders/generate_streamlines.wgsl +++ b/cuslines/wgsl_shaders/generate_streamlines.wgsl @@ -136,7 +136,7 @@ fn get_direction_prob( for (var i = i32(tidx); i < dimt; i += i32(THR_X_SL)) { let sv = load_sphere_verts_f3(u32(i)); let dot_val = dir.x * sv.x + dir.y * sv.y + dir.z * sv.z; - if (abs(dot_val) < cos_similarity) { + if (select(dot_val, abs(dot_val), SPHERE_SYMM == 1u) < cos_similarity) { wg_sh_mem[sh_offset + u32(i)] = 0.0; } } diff --git a/cuslines/wgsl_shaders/tracking_helpers.wgsl b/cuslines/wgsl_shaders/tracking_helpers.wgsl index 693733c..4d7eb31 100644 --- a/cuslines/wgsl_shaders/tracking_helpers.wgsl +++ b/cuslines/wgsl_shaders/tracking_helpers.wgsl @@ -237,7 +237,8 @@ fn peak_directions_fn( let dx = wg_dirs_sh[d_base]; let dy = wg_dirs_sh[d_base + 1u]; let dz = wg_dirs_sh[d_base + 2u]; - let cs = abs(abc.x * dx + abc.y * dy + abc.z * dz); + let dot_val = abc.x * dx + abc.y * dy + abc.z * dz; + let cs = select(dot_val, abs(dot_val), SPHERE_SYMM == 1u); if (cs > cos_similarity) { break; } diff --git a/pyproject.toml b/pyproject.toml index 1406cf7..52a3242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ readme = "README.md" requires-python = ">=3.7" dependencies = [ "numpy", + "numba", "nibabel", "tqdm", "dipy", diff --git a/run_gpu_streamlines.py b/run_gpu_streamlines.py index 90600a3..7776763 100644 --- a/run_gpu_streamlines.py +++ b/run_gpu_streamlines.py @@ -36,7 +36,8 @@ import nibabel as nib import numpy as np from dipy.core.gradients import gradient_table, unique_bvals_magnitude -from dipy.data import default_sphere, get_fnames, read_stanford_pve_maps, small_sphere +from dipy.data import get_sphere, get_fnames, read_stanford_pve_maps, small_sphere, HemiSphere + from dipy.direction import ( BootDirectionGetter as cpu_BootDirectionGetter, ) @@ -60,7 +61,7 @@ from cuslines import ( BACKEND, BootDirectionGetter, - GPUTracker, + Tracker, ProbDirectionGetter, PttDirectionGetter, ) @@ -101,8 +102,8 @@ def get_img(ep2_seq): "--device", type=str, default="gpu", - choices=["cpu", "gpu", "metal", "webgpu"], - help="Whether to use cpu, gpu (auto-detect), metal, or webgpu", + choices=["cpu", "gpu", "metal", "webgpu", "numba"], + help="Whether to use cpu, gpu (auto-detect), metal, webgpu, or numba", ) parser.add_argument( "--sphere", @@ -199,7 +200,7 @@ def get_img(ep2_seq): WebGPUPttDirectionGetter as PttDirectionGetter, ) from cuslines.webgpu import ( - WebGPUTracker as GPUTracker, + WebGPUTracker as Tracker, ) except ImportError: raise RuntimeError( @@ -216,6 +217,24 @@ def get_img(ep2_seq): args.device = "gpu" # use the GPU code path elif args.device == "gpu": print("Using %s backend" % BACKEND) +elif args.device == "numba": + from cuslines.numba import ( + CPUBootDirectionGetter as BootDirectionGetter, + ) + from cuslines.numba import ( + CPUProbDirectionGetter as ProbDirectionGetter, + ) + from cuslines.numba import ( + CPUPttDirectionGetter as PttDirectionGetter, + ) + from cuslines.numba import ( + CPUTracker as Tracker, + ) + print(( + "WARNING: in this script, numba backend only runs probabilistic " + "tractography on csd, ignoring dg and model")) + args.dg = "prob" + args.model = "csd" if args.device == "cpu" and args.write_method != "trk": print("WARNING: only trk write method is implemented for cpu testing.") @@ -298,7 +317,8 @@ def get_img(ep2_seq): if args.sphere == "small": sphere = small_sphere else: - sphere = default_sphere + sphere = get_sphere("repulsion724") + if args.model == "opdt": if args.device == "cpu": model = OpdtModel( @@ -358,16 +378,17 @@ def get_img(ep2_seq): if args.cache_dir != "": np.save(csd_odf_cache_file, data) + if args.dg == "ptt": if args.device == "cpu": - dg = cpu_PTTDirectionGetter() + dg = cpu_PTTDirectionGetter else: # Set FOD to 0 outside mask for probing data[FA < args.fa_threshold, :] = 0 dg = PttDirectionGetter() elif args.dg == "prob": if args.device == "cpu": - dg = cpu_ProbDirectionGetter() + dg = cpu_ProbDirectionGetter else: dg = ProbDirectionGetter() else: @@ -394,21 +415,22 @@ def get_img(ep2_seq): min_separation_angle=args.min_separation_angle, ) - ts = time.time() - streamline_generator = LocalTracking( - dg, tissue_classifier, seed_mask, affine=np.eye(4), step_size=args.step_size - ) - sft = StatefulTractogram(streamline_generator, img, Space.VOX) - n_sls = len(sft.streamlines) - te = time.time() + ts = time.time() + streamline_generator = LocalTracking( + dg, tissue_classifier, seed_mask, affine=np.eye(4), step_size=args.step_size + ) + sft = StatefulTractogram(streamline_generator, img, Space.VOX) + n_sls = len(sft.streamlines) + te = time.time() else: - with GPUTracker( + with Tracker( dg, data, FA, args.fa_threshold, sphere.vertices, sphere.edges, + sphere_symm=isinstance(sphere, HemiSphere), max_angle=args.max_angle * np.pi / 180, step_size=args.step_size, relative_peak_thresh=args.relative_peak_threshold,