diff --git a/cuslines/cuda_c/generate_streamlines_cuda.cu b/cuslines/cuda_c/generate_streamlines_cuda.cu index 68c32d0..9b9462f 100644 --- a/cuslines/cuda_c/generate_streamlines_cuda.cu +++ b/cuslines/cuda_c/generate_streamlines_cuda.cu @@ -148,7 +148,7 @@ __device__ int get_direction_prob_d(curandStatePhilox4_32_10_t *st, dir.y*sphere_vertices[i].y+ dir.z*sphere_vertices[i].z; - if (FABS(dot) < cos_similarity) { + if (APPLY_ABS_IF_SYM(dot) < cos_similarity) { __pmf_data_sh[i] = 0.0; } } diff --git a/cuslines/cuda_c/globals.h b/cuslines/cuda_c/globals.h index 505f5cc..31017b6 100644 --- a/cuslines/cuda_c/globals.h +++ b/cuslines/cuda_c/globals.h @@ -67,6 +67,12 @@ #endif +#if FULL_BASIS == 1 + #define APPLY_ABS_IF_SYM(x) (x) +#else + #define APPLY_ABS_IF_SYM(x) FABS(x) +#endif + #define MIN(x,y) (((x)<(y))?(x):(y)) #define MAX(x,y) (((x)>(y))?(x):(y)) #define POW2(n) (1 << (n)) diff --git a/cuslines/cuda_c/tracking_helpers.cu b/cuslines/cuda_c/tracking_helpers.cu index b9d9e20..f5b2569 100644 --- a/cuslines/cuda_c/tracking_helpers.cu +++ b/cuslines/cuda_c/tracking_helpers.cu @@ -206,9 +206,10 @@ __device__ int peak_directions_d(const REAL_T *__restrict__ odf, int j = 0; for(; j < k; j++) { - const REAL_T cos = FABS(abc.x*dirs[j].x+ - abc.y*dirs[j].y+ - abc.z*dirs[j].z); + const REAL_T cos = APPLY_ABS_IF_SYM( + abc.x*dirs[j].x+ + abc.y*dirs[j].y+ + abc.z*dirs[j].z); if (cos > cos_similarity) { break; } diff --git a/cuslines/cuda_python/cu_direction_getters.py b/cuslines/cuda_python/cu_direction_getters.py index 466c0e0..54adb9e 100644 --- a/cuslines/cuda_python/cu_direction_getters.py +++ b/cuslines/cuda_python/cu_direction_getters.py @@ -82,6 +82,7 @@ def compile_program(self, gpu_tracker, debug: bool = False): "RNG_SEED": str(int(gpu_tracker.rng_seed)), "SAMPLM_NR": str(int(gpu_tracker.samplm_nr)), "NUM_EDGES": str(int(gpu_tracker.nedges)), + "FULL_BASIS": "1" if gpu_tracker.full_basis else "0", "EXCESS_ALLOC_FACT": str(int(EXCESS_ALLOC_FACT)), "MAX_SLINES_PER_SEED": str(int(MAX_SLINES_PER_SEED)), "MAX_SLINE_LEN": str(int(MAX_SLINE_LEN)), diff --git a/cuslines/cuda_python/cu_tractography.py b/cuslines/cuda_python/cu_tractography.py index ebee530..1f0afcb 100644 --- a/cuslines/cuda_python/cu_tractography.py +++ b/cuslines/cuda_python/cu_tractography.py @@ -40,6 +40,7 @@ def __init__( stop_threshold: float, sphere_vertices: np.ndarray, sphere_edges: np.ndarray, + full_basis: bool = False, max_angle: float = radians(60), step_size: float = 0.5, min_pts=0, @@ -70,6 +71,9 @@ def __init__( Vertices of the sphere used for direction sampling. sphere_edges : np.ndarray Edges of the sphere used for direction sampling. + full_basis : bool, optional + Whether to use full basis for spherical harmonics + default: False max_angle : float, optional Maximum angle (in radians) between steps default: radians(60) @@ -143,6 +147,7 @@ def __init__( self.rng_seed = int(rng_seed) self.rng_offset = int(rng_offset) self.chunk_size = int(chunk_size) + self.full_basis = bool(full_basis) avail = checkCudaErrors(runtime.cudaGetDeviceCount()) if self.ngpus > avail: