diff --git a/AFQ/_fixes.py b/AFQ/_fixes.py index da8dc081..b6b7b2f4 100644 --- a/AFQ/_fixes.py +++ b/AFQ/_fixes.py @@ -223,7 +223,7 @@ def tensor_odf(evals, evecs, sphere, num_batches=100): def gaussian_weights( - bundle, assignment_idxs=None, n_points=100, return_mahalnobis=False, stat=np.mean + bundle, assignment_idxs=None, n_points=100, return_mahalanobis=False, stat=np.mean ): """ Calculate weights for each streamline/node in a bundle, based on a @@ -270,7 +270,7 @@ def gaussian_weights( n_sls, n_nodes, _ = sls.shape - if n_sls < 15: # Cov^-1 unstable under this amount + def _weighting_failed(): weights = np.ones((n_sls, n_nodes)) logger.warning( ( @@ -278,46 +278,72 @@ def gaussian_weights( "weighting everything evenly" ) ) - if return_mahalnobis: + if return_mahalanobis: return np.full((n_sls, n_nodes), np.nan) else: return weights / np.sum(weights, 0) + + if n_sls < 15: # Cov^-1 unstable under this amount + return _weighting_failed() else: weights = np.zeros((n_sls, n_nodes)) if assignment_idxs is None: - working_groups = np.tile(np.arange(n_nodes), (n_sls, 1)) + mu = stat(sls, axis=0) + diff = sls - mu + + cov = np.einsum("snj,snk->njk", diff, diff) / n_sls + cov = 0.5 * (cov + cov.transpose(0, 2, 1)) + + eigvals, eigvecs = np.linalg.eigh(cov) + eigvals = np.clip(eigvals, 0, None) + + max_ev = eigvals.max(axis=1, keepdims=True) + tol = np.finfo(eigvals.dtype).eps * np.maximum(max_ev, 1e-6) * cov.shape[-1] + inv_eigvals = np.where( + eigvals > tol, 1.0 / np.where(eigvals > tol, eigvals, 1.0), 0.0 + ) + + inv_cov = np.einsum("nij,nj,nkj->nik", eigvecs, inv_eigvals, eigvecs) + + m = np.einsum("snj,njk,snk->sn", diff, inv_cov, diff) + np.clip(m, 0, None, out=m) + weights = np.sqrt(m) + + degenerate = max_ev.ravel() <= 0 + if np.any(degenerate): + return _weighting_failed() else: working_groups = np.asarray(assignment_idxs) - flat_coords = sls.reshape(-1, 3) - flat_groups = working_groups.reshape(-1) - unique_ids = np.unique(flat_groups) + flat_coords = sls.reshape(-1, 3) + flat_groups = working_groups.reshape(-1) + unique_ids = np.unique(flat_groups) - for gid in unique_ids: - mask = flat_groups == gid - group_data = flat_coords[mask] + for gid in unique_ids: + mask = flat_groups == gid + group_data = flat_coords[mask] - if len(group_data) < 15: - continue + if len(group_data) < 15: + continue - mu = stat(group_data, axis=0) - diff = group_data - mu + mu = stat(group_data, axis=0) + diff = group_data - mu - cov = np.cov(group_data.T, ddof=0) + cov = np.cov(group_data.T, ddof=0) - # Ensure positive semi-definite - if np.any(np.linalg.eigvals(cov) < 0): - eigenvalues, eigenvectors = np.linalg.eigh((cov + cov.T) / 2) - eigenvalues[eigenvalues < 0] = 0 - cov = eigenvectors @ np.diag(eigenvalues) @ eigenvectors.T + # Ensure positive semi-definite + if np.any(np.linalg.eigvals(cov) < 0): + eigenvalues, eigenvectors = np.linalg.eigh((cov + cov.T) / 2) + eigenvalues[eigenvalues < 0] = 0 + cov = eigenvectors @ np.diag(eigenvalues) @ eigenvectors.T - if np.any(cov > 0): - weights.ravel()[mask] = np.sqrt( - np.einsum("ij,jk,ik->i", diff, pinvh(cov), diff) - ) + if np.any(cov > 0): + weights.ravel()[mask] = np.sqrt( + np.einsum("ij,jk,ik->i", diff, pinvh(cov), diff) + ) - if return_mahalnobis: + if return_mahalanobis: return weights with np.errstate(divide="ignore"): @@ -325,7 +351,14 @@ def gaussian_weights( w_inv[np.isinf(w_inv)] = 0 denom = np.sum(w_inv, axis=0) - return np.divide(w_inv, denom, out=np.zeros_like(w_inv), where=denom != 0) + w = np.divide(w_inv, denom, out=np.zeros_like(w_inv), 
where=denom != 0) + col_sums = w.sum(axis=0) + if np.max(np.abs(col_sums - 1)) > 1e-3: + return _weighting_failed() + else: + final_sums = w.sum(axis=0, keepdims=True) + np.divide(w, final_sums, out=w, where=final_sums != 0) + return w def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150): diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index 67f05fbd..f80ab6c4 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -105,7 +105,7 @@ def clean_by_orientation_mahalanobis( m_dist = gaussian_weights( fgarray_dists, assignment_idxs=assignment_idxs, - return_mahalnobis=True, + return_mahalanobis=True, n_points=None, stat=np.mean, ) @@ -253,7 +253,7 @@ def clean_bundle( while rounds_elapsed < clean_rounds: # This calculates the Mahalanobis for each streamline/node: m_dist = gaussian_weights( - fgarray, return_mahalnobis=True, n_points=None, stat=stat + fgarray, return_mahalanobis=True, n_points=None, stat=stat ) logger.debug(f"Shape of fgarray: {np.asarray(fgarray).shape}") logger.debug(f"Shape of m_dist: {m_dist.shape}") diff --git a/AFQ/recognition/clustering.py b/AFQ/recognition/clustering.py index 335f68b4..0832a9b6 100644 --- a/AFQ/recognition/clustering.py +++ b/AFQ/recognition/clustering.py @@ -19,13 +19,13 @@ import AFQ.utils.streamlines as aus -@njit(parallel=True) +@njit(parallel=True, fastmath=True, cache=True) def _compute_mean_euclidean_matrix(group_n, group_m): len_n = group_n.shape[0] len_m = group_m.shape[0] num_points = group_n.shape[1] - dist_matrix = np.empty((len_n, len_m), dtype=np.float64) + dist_matrix = np.empty((len_n, len_m), dtype=np.float32) for i in prange(len_n): for j in range(len_m): @@ -144,7 +144,13 @@ def spectral_atlas_label( def subcluster_by_atlas( - sub_trk, mapping, dwi_ref, cluster_indices, atlas_data=None, n_points=20 + sub_trk, + mapping, + dwi_ref, + cluster_indices, + atlas_data=None, + n_points=20, + batch_size=int(5e4), ): """ Use an existing atlas to label a new set of streamlines, and return the @@ -165,6 +171,8 @@ def subcluster_by_atlas( See `afd.read_org800_templates` as a reference. n_points : int, optional Number of points to resample streamlines to for labeling. Default is 20. + batch_size : int, optional + Number of streamlines to process in a batch. Default is 50,000. 
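+        For example, with the default batch_size a pool of 130,000
+        streamlines is labeled in three batches of 50,000, 50,000,
+        and 30,000 streamlines.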
""" if atlas_data is None: @@ -177,13 +185,33 @@ def subcluster_by_atlas( atlas_fgarray = np.array(abu.resample_tg(moved_atlas_sft.streamlines, n_points)) sub_trk.to_rasmm() - sub_fgarray = np.array(abu.resample_tg(sub_trk.streamlines, n_points)) - - cluster_idxs, _ = spectral_atlas_label( - sub_fgarray, - atlas_fgarray, - atlas_data=atlas_data, - cluster_indices=cluster_indices, - ) - - return cluster_idxs + n_sub = len(sub_trk.streamlines) + + if n_sub <= batch_size: + sub_fgarray = np.asarray( + abu.resample_tg(sub_trk.streamlines, n_points), dtype=np.float32 + ) + cluster_idxs, _ = spectral_atlas_label( + sub_fgarray, + atlas_fgarray, + atlas_data=atlas_data, + cluster_indices=cluster_indices, + ) + return cluster_idxs + + all_idxs = np.empty(n_sub, dtype=np.int64) + for start in range(0, n_sub, batch_size): + end = min(start + batch_size, n_sub) + batch_sls = sub_trk.streamlines[start:end] + batch_fgarray = np.asarray( + abu.resample_tg(batch_sls, n_points), dtype=np.float32 + ) + batch_idxs, _ = spectral_atlas_label( + batch_fgarray, + atlas_fgarray, + atlas_data=atlas_data, + cluster_indices=cluster_indices, + ) + all_idxs[start:end] = batch_idxs + + return all_idxs diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index c7326d99..70dbc88f 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -11,6 +11,8 @@ from dipy.segment.featurespeed import ResampleFeature from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric from scipy.ndimage import distance_transform_edt +from tqdm import tqdm +from trx.io import load as load_trx import AFQ.recognition.cleaning as abc import AFQ.recognition.curvature as abv @@ -19,9 +21,14 @@ import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import apply_to_roi_dict from AFQ.recognition.clustering import subcluster_by_atlas +from AFQ.recognition.preprocess import PreprocPlan +from AFQ.recognition.utils import resample_tg from AFQ.utils.streamlines import move_streamlines -criteria_order_pre_other_bundles = [ +# Criteria that are purely per-streamline and safe to run on a chunk +# without needing to see the rest of the tractogram. These run in the +# chunk-local phase. +criteria_order_chunk_local = [ "length", "endpoint_dists", "cross_midline", @@ -32,9 +39,12 @@ "include", "exclude", "curvature", - "recobundles", ] +# RecoBundles needs the whole candidate pool for a bundle, so it runs +# in the global phase even though it's nominally a "pre-other-bundles" +# criterion. +criteria_order_pre_other_bundles = criteria_order_chunk_local + ["recobundles"] criteria_order_post_other_bundles = ["orient_mahal", "isolation_forest", "qb_thresh"] @@ -56,35 +66,35 @@ logger = logging.getLogger("AFQ") -def prob_map(b_sls, bundle_def, preproc_imap, prob_threshold, img, **kwargs): +def prob_map(b_sls, bundle_def, preproc_plan, prob_threshold, img, **kwargs): b_sls.initiate_selection("Prob. Map") fiber_probabilities = dts.values_from_volume( bundle_def["prob_map"].get_fdata(), - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + preproc_plan.fgarray[b_sls.selected_fiber_idxs], img.affine, ) fiber_probabilities = np.mean(fiber_probabilities, -1) b_sls.select(fiber_probabilities > prob_threshold, "Prob. 
Map") -def cross_midline(b_sls, bundle_def, preproc_imap, **kwargs): +def cross_midline(b_sls, bundle_def, preproc_plan, **kwargs): b_sls.initiate_selection("Cross Mid.") - accepted = preproc_imap["crosses"][b_sls.selected_fiber_idxs] + accepted = preproc_plan.crosses[b_sls.selected_fiber_idxs] if not bundle_def["cross_midline"]: accepted = np.invert(accepted) b_sls.select(accepted, "Cross Mid.") -def start(b_sls, bundle_def, preproc_imap, **kwargs): +def start(b_sls, bundle_def, preproc_plan, **kwargs): b_sls.initiate_selection("Startpoint") exact_endpoints = bundle_def.get("exact_endpoints", False) if exact_endpoints: tol = 0 else: - tol = preproc_imap["dist_to_atlas"] + tol = kwargs["dist_to_atlas"] accept_idx = abr.clean_by_endpoints( - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + preproc_plan.fgarray[b_sls.selected_fiber_idxs], bundle_def["start"], 0, tol=tol, @@ -92,7 +102,7 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + preproc_plan.fgarray[b_sls.selected_fiber_idxs], bundle_def["start"], -1, tol=tol, @@ -100,7 +110,7 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): new_accept_idx = np.logical_or(accepted_idx_flipped, accept_idx) special_idx = np.logical_and(accept_idx, accepted_idx_flipped) special_idx_to_flip = abu.manual_orient_sls( - preproc_imap["fgarray"][b_sls.selected_fiber_idxs][special_idx] + preproc_plan.fgarray[b_sls.selected_fiber_idxs][special_idx] ) accepted_idx_flipped[special_idx] = special_idx_to_flip b_sls.reorient(accepted_idx_flipped) @@ -109,16 +119,16 @@ def start(b_sls, bundle_def, preproc_imap, **kwargs): b_sls.select(accept_idx, "Startpoint") -def end(b_sls, bundle_def, preproc_imap, **kwargs): +def end(b_sls, bundle_def, preproc_plan, **kwargs): b_sls.initiate_selection("endpoint") exact_endpoints = bundle_def.get("exact_endpoints", False) if exact_endpoints: tol = 0 else: - tol = preproc_imap["dist_to_atlas"] + tol = kwargs["dist_to_atlas"] accept_idx = abr.clean_by_endpoints( - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + preproc_plan.fgarray[b_sls.selected_fiber_idxs], bundle_def["end"], -1, tol=tol, @@ -126,7 +136,7 @@ def end(b_sls, bundle_def, preproc_imap, **kwargs): ) if not b_sls.oriented_yet: accepted_idx_flipped = abr.clean_by_endpoints( - preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + preproc_plan.fgarray[b_sls.selected_fiber_idxs], bundle_def["end"], 0, tol=tol, @@ -134,7 +144,7 @@ def end(b_sls, bundle_def, preproc_imap, **kwargs): new_accept_idx = np.logical_or(accepted_idx_flipped, accept_idx) special_idx = np.logical_and(accept_idx, accepted_idx_flipped) special_idx_to_flip = abu.manual_orient_sls( - preproc_imap["fgarray"][b_sls.selected_fiber_idxs][special_idx] + preproc_plan.fgarray[b_sls.selected_fiber_idxs][special_idx] ) accepted_idx_flipped[special_idx] = special_idx_to_flip b_sls.reorient(accepted_idx_flipped) @@ -142,31 +152,29 @@ def end(b_sls, bundle_def, preproc_imap, **kwargs): b_sls.select(accept_idx, "endpoint") -def length(b_sls, bundle_def, preproc_imap, **kwargs): +def length(b_sls, bundle_def, preproc_plan, **kwargs): b_sls.initiate_selection("length") min_len = bundle_def["length"].get("min_len", 0) max_len = bundle_def["length"].get("max_len", np.inf) - # No need to use b_sls.selected_fiber_idxs - # because this is first step - sl_lens = preproc_imap["lengths"] + sl_lens = preproc_plan.lengths accept_idx = (sl_lens >= min_len) & (sl_lens <= 
max_len) b_sls.select(accept_idx, "length") -def endpoint_dists(b_sls, bundle_def, preproc_imap, **kwargs): +def endpoint_dists(b_sls, bundle_def, preproc_plan, **kwargs): b_sls.initiate_selection("endpoint_dists") min_dist = bundle_def["endpoint_dists"].get("min_dist", 0) max_dist = bundle_def["endpoint_dists"].get("max_dist", np.inf) - sl_endpoint_dists = preproc_imap["endpoint_dists"][b_sls.selected_fiber_idxs] + sl_endpoint_dists = preproc_plan.endpoint_dists[b_sls.selected_fiber_idxs] accept_idx = (sl_endpoint_dists >= min_dist) & (sl_endpoint_dists <= max_dist) b_sls.select(accept_idx, "endpoint_dists") -def primary_axis(b_sls, bundle_def, img, **kwargs): +def primary_axis(b_sls, bundle_def, **kwargs): b_sls.initiate_selection("orientation") accept_idx = abc.clean_by_orientation( b_sls.get_selected_sls(), @@ -176,18 +184,16 @@ def primary_axis(b_sls, bundle_def, img, **kwargs): b_sls.select(accept_idx, "orientation") -def include(b_sls, bundle_def, preproc_imap, **kwargs): +def include(b_sls, bundle_def, **kwargs): accept_idx = b_sls.initiate_selection("include") flip_using_include = len(bundle_def["include"]) > 1 and not b_sls.oriented_yet if "inc_addtol" in bundle_def: include_roi_tols = [] for inc_tol in bundle_def["inc_addtol"]: - include_roi_tols.append( - (inc_tol / preproc_imap["vox_dim"] + preproc_imap["tol"]) ** 2 - ) + include_roi_tols.append((inc_tol / kwargs["vox_dim"] + kwargs["tol"]) ** 2) else: - include_roi_tols = [preproc_imap["tol"] ** 2] * len(bundle_def["include"]) + include_roi_tols = [kwargs["tol"] ** 2] * len(bundle_def["include"]) inc_results = abr.check_sls_with_inclusion( b_sls.get_selected_sls(), bundle_def["include"], include_roi_tols @@ -205,10 +211,7 @@ def include(b_sls, bundle_def, preproc_imap, **kwargs): roi_closest[:, sl_idx] = sl_closest roi_dists[:, sl_idx] = sl_dists if len(sl_closest) > 1: - # Only accept SLs that, when cut, are meaningful if (len(sl_closest) < 2) or abs(sl_closest[0] - sl_closest[-1]) > 1: - # Flip sl if it is close to second ROI - # before its close to the first ROI if flip_using_include: to_flip[sl_idx] = sl_closest[0] > sl_closest[-1] if to_flip[sl_idx]: @@ -226,10 +229,6 @@ def include(b_sls, bundle_def, preproc_imap, **kwargs): def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs): - """ - Filters streamlines by how well they match - a curve in orientation and shape but not scale - """ accept_idx = b_sls.initiate_selection("curvature") if "sft" in bundle_def["curvature"]: ref_sl = bundle_def["curvature"]["sft"] @@ -253,16 +252,14 @@ def curvature(b_sls, bundle_def, mapping, img, save_intermediates, **kwargs): b_sls.select(accept_idx, "curvature", cut=cut) -def exclude(b_sls, bundle_def, preproc_imap, **kwargs): +def exclude(b_sls, bundle_def, **kwargs): accept_idx = b_sls.initiate_selection("exclude") if "exc_addtol" in bundle_def: exclude_roi_tols = [] for exc_tol in bundle_def["exc_addtol"]: - exclude_roi_tols.append( - (exc_tol / preproc_imap["vox_dim"] + preproc_imap["tol"]) ** 2 - ) + exclude_roi_tols.append((exc_tol / kwargs["vox_dim"] + kwargs["tol"]) ** 2) else: - exclude_roi_tols = [preproc_imap["tol"] ** 2] * len(bundle_def["exclude"]) + exclude_roi_tols = [kwargs["tol"] ** 2] * len(bundle_def["exclude"]) for sl_idx, sl in enumerate(b_sls.get_selected_sls()): if abr.check_sl_with_exclusion(sl, bundle_def["exclude"], exclude_roi_tols): accept_idx[sl_idx] = 1 @@ -320,7 +317,7 @@ def qb_thresh(b_sls, bundle_def, clip_edges, **kwargs): def clean_by_other_bundle( - b_sls, bundle_def, img, 
preproc_imap, other_bundle_name, other_bundle_sls, **kwargs + b_sls, bundle_def, img, other_bundle_name, other_bundle_sls, **kwargs ): cleaned_idx = b_sls.initiate_selection(other_bundle_name) cleaned_idx = 1 @@ -352,7 +349,7 @@ def clean_by_other_bundle( consideration = bundle_def[other_bundle_name].get("consideration", 10.0) if isinstance(consideration, (int, float)): consideration = float(consideration) - consideration = consideration / preproc_imap["vox_dim"] + consideration = consideration / kwargs["vox_dim"] cleaned_idx_core = abo.clean_relative_to_other_core( bundle_def[other_bundle_name]["core"].lower(), @@ -396,19 +393,11 @@ def mahalanobis(b_sls, bundle_def, clip_edges, cleaning_params, **kwargs): b_sls.select(cleaned_idx, "Mahalanobis", cut=cut) -def run_bundle_rec_plan( - bundle_dict, - streamlines, - mapping, - img, - reg_template, - preproc_imap, - bundle_name, - recognized_bundles_dict, - **segmentation_params, -): - # Warp ROIs - logger.info(f"Preparing ROIs for {bundle_name}") +def _prepare_bundle_def(bundle_dict, bundle_name, mapping, img): + """ + Warp ROIs and apply distance-transform conversion + """ + tqdm.write(f"Preparing ROIs for {bundle_name}") start_time = time() bundle_def = dict(bundle_dict.get_b_info(bundle_name)) bundle_def.update( @@ -436,33 +425,11 @@ def check_space(roi): apply_to_recobundles=False, apply_to_prob_map=False, ) - logger.info(f"Time to prep ROIs: {time() - start_time}s") + tqdm.write(f"Time to prep ROIs: {time() - start_time}s") + return bundle_def - if isinstance(streamlines, abu.SlsBeingRecognized): - # This only occurs when your inside a subbundle, - # in which case we want to keep the same SlsBeingRecognized object so that - # we can keep track of the same streamlines and their orientations - b_sls = streamlines - else: - b_sls = abu.SlsBeingRecognized( - streamlines, - logger, - segmentation_params["save_intermediates"], - bundle_name, - img, - len(bundle_def.get("include", [])), - ) - - inputs = {} - inputs["b_sls"] = b_sls - inputs["preproc_imap"] = preproc_imap - inputs["bundle_def"] = bundle_def - inputs["mapping"] = mapping - inputs["img"] = img - inputs["reg_template"] = reg_template - for key, value in segmentation_params.items(): - inputs[key] = value +def _validate_criteria(bundle_def, bundle_name, bundle_dict, recognized_bundles_dict): for potential_criterion in bundle_def.keys(): if ( (potential_criterion not in criteria_order_post_other_bundles) @@ -472,29 +439,97 @@ def check_space(roi): ): if potential_criterion in bundle_dict.bundle_names: raise ValueError( - ( - f"Bundle {potential_criterion} is being used as a criterion in " - f"the definition of bundle {bundle_name}, however this bundle " - "was not found." - " This could because of insufficient streamlines" - ) + f"Bundle {potential_criterion} is being used as a criterion in " + f"the definition of bundle {bundle_name}, however this bundle " + "was not found. 
This could be because of insufficient streamlines" ) else: raise ValueError( - ( - "Invalid criterion in bundle definition:\n" - f"{potential_criterion} in bundle {bundle_name}.\n" - "Valid criteria are:\n" - f"{criteria_order_pre_other_bundles}\n" - f"{criteria_order_post_other_bundles}\n" - f"{recognized_bundles_dict.keys()}\n" - f"{valid_noncriterion}\n" - ) + "Invalid criterion in bundle definition:\n" + f"{potential_criterion} in bundle {bundle_name}.\n" + "Valid criteria are:\n" + f"{criteria_order_pre_other_bundles}\n" + f"{criteria_order_post_other_bundles}\n" + f"{recognized_bundles_dict.keys()}\n" + f"{valid_noncriterion}\n" ) - for criterion in criteria_order_pre_other_bundles: + +def _run_chunk_local( + bundle_def, + chunk_streamlines, + bundle_name, + img, + preproc_plan, + save_intermediates, + vox_dim, + tol, + dist_to_atlas, + **segmentation_params, +): + b_sls = abu.SlsBeingRecognized( + chunk_streamlines, + save_intermediates, + bundle_name, + img, + len(bundle_def.get("include", [])), + ) + + inputs = { + "b_sls": b_sls, + "preproc_plan": preproc_plan, + "bundle_def": bundle_def, + "img": img, + "save_intermediates": save_intermediates, + "vox_dim": vox_dim, + "tol": tol, + "dist_to_atlas": dist_to_atlas, + } + inputs.update(segmentation_params) + + for criterion in criteria_order_chunk_local: if b_sls and criterion in bundle_def: inputs[criterion] = globals()[criterion](**inputs) + + return b_sls + + +def _run_global_phase( + bundle_def, + bundle_name, + b_sls, + fgarray_for_candidates, + candidate_global_idx, + mapping, + img, + reg_template, + preproc_scalars, + recognized_bundles_dict, + vox_dim, + tol, + dist_to_atlas, + is_subbundle=False, + **segmentation_params, +): + if not b_sls: + return + + inputs = { + "b_sls": b_sls, + "preproc_plan": preproc_scalars, + "bundle_def": bundle_def, + "mapping": mapping, + "img": img, + "reg_template": reg_template, + "vox_dim": vox_dim, + "tol": tol, + "dist_to_atlas": dist_to_atlas, + } + inputs.update(segmentation_params) + + if "recobundles" in bundle_def: + recobundles(**inputs) + if b_sls: for o_bundle_name in recognized_bundles_dict.keys(): if o_bundle_name in bundle_def.keys(): @@ -505,9 +540,11 @@ def check_space(roi): o_bundle_name ].get_selected_sls(flip=True), ) + for criterion in criteria_order_post_other_bundles: if b_sls and criterion in bundle_def: inputs[criterion] = globals()[criterion](**inputs) + if b_sls: if "mahal" in bundle_def or ( "isolation_forest" not in bundle_def @@ -516,16 +553,20 @@ def check_space(roi): ): mahalanobis(**inputs) - # If you don't cross the midline, we remove streamliens - # entirely on the wrong side of the midline here after filtering - if b_sls and "cross_midline" in bundle_def and not bundle_def["cross_midline"]: + # Wrong-side-of-midline cleanup. fgarray_for_candidates is in + # candidate-local order; b_sls.selected_fiber_idxs is in global + # order. searchsorted translates between them. 
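+    # A hypothetical example of the translation: with
+    # candidate_global_idx == [3, 7, 42] and
+    # b_sls.selected_fiber_idxs == [7, 42], searchsorted gives
+    # pos == [1, 2], the matching rows of fgarray_for_candidates.
+    # This relies on candidate_global_idx being sorted ascending,
+    # which from_selected guarantees.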
+ if ( + b_sls + and not is_subbundle + and "cross_midline" in bundle_def + and not bundle_def["cross_midline"] + and fgarray_for_candidates is not None + and candidate_global_idx is not None + ): + pos = np.searchsorted(candidate_global_idx, b_sls.selected_fiber_idxs) b_sls.initiate_selection("Wrong side of mid.") - avg_side = np.sign( - np.mean( - preproc_imap["fgarray"][b_sls.selected_fiber_idxs, :, 0], - axis=1, - ) - ) + avg_side = np.sign(np.mean(fgarray_for_candidates[pos, :, 0], axis=1)) majority_side = np.sign(np.sum(avg_side)) b_sls.select(avg_side == majority_side, "Wrong side of mid.") @@ -534,49 +575,208 @@ def check_space(roi): "pyAFQ was unable to consistently orient streamlines " f"in bundle {bundle_name} using the provided ROIs. " "This can be fixed by including at least 2 " - "waypoint ROIs, or by using " - "endpoint ROIs." + "waypoint ROIs, or by using endpoint ROIs." ) - if b_sls: - if "ORG_spectral_subbundles" in bundle_def: - subdict = bundle_def["ORG_spectral_subbundles"] - b_sls.initiate_selection( - ( - f"ORG spectral clustering, {len(subdict.bundle_names)} " - "subbundles being recognized" - ) + if not b_sls: + return + + if "ORG_spectral_subbundles" in bundle_def: + if is_subbundle: + raise ValueError("Nested ORG_spectral_subbundles are not supported.") + subdict = bundle_def["ORG_spectral_subbundles"] + b_sls.initiate_selection( + f"ORG spectral clustering, {len(subdict.bundle_names)} " + "subbundles being recognized" + ) + + sub_sft = StatefulTractogram( + b_sls.get_selected_sls(flip=True), img, Space.RASMM + ) + cluster_labels = subcluster_by_atlas( + sub_sft, mapping, img, subdict.all_cluster_IDs, n_points=40 + ) + + for sub_b_name in subdict.bundle_names: + c_ids = subdict._dict[sub_b_name]["cluster_IDs"] + n_roi = len(subdict._dict[sub_b_name].get("include", [])) + cluster_b_sls = b_sls.copy(sub_b_name, n_roi) + selected = np.zeros(len(b_sls), dtype=bool) + for c_id in c_ids: + selected = np.logical_or(selected, cluster_labels == c_id) + cluster_b_sls.select(selected, f"Clusters {c_ids}") + + sub_bundle_def = _prepare_bundle_def(subdict, sub_b_name, mapping, img) + _validate_criteria( + sub_bundle_def, sub_b_name, subdict, recognized_bundles_dict + ) + _run_global_phase( + sub_bundle_def, + sub_b_name, + cluster_b_sls, + None, + None, + mapping, + img, + reg_template, + preproc_scalars, + recognized_bundles_dict, + vox_dim, + tol, + dist_to_atlas, + is_subbundle=True, + **segmentation_params, + ) + else: + b_sls.bundle_def = bundle_def + recognized_bundles_dict[bundle_name] = b_sls + + +def recognize_bundles( + tg, + bundle_dict, + mapping, + img, + reg_template, + chunk_size, + dist_to_waypoint, + dist_to_atlas, + save_intermediates, + **segmentation_params, +): + if isinstance(tg, str): + tg_path = tg + tg = load_trx(tg_path, img) + else: + tg_path = None + + n_streamlines = len(tg) + recognized_bundles_dict = {} + + tqdm.write( + f"Recognizing bundles over {n_streamlines} streamlines " + f"in chunks of {chunk_size}" + ) + + tol, dist_to_atlas, vox_dim = abu.tolerance_mm_to_vox( + img, dist_to_waypoint, dist_to_atlas + ) + preproc_scalars = { + "vox_dim": vox_dim, + "tol": tol, + "dist_to_atlas": dist_to_atlas, + } + + bundle_defs = {} + survivor_dicts = {} + for bundle_name in bundle_dict.bundle_names: + bd = _prepare_bundle_def(bundle_dict, bundle_name, mapping, img) + bundle_defs[bundle_name] = bd + survivor_dicts[bundle_name] = [] + + total_chunks = (n_streamlines + chunk_size - 1) // chunk_size + for chunk_start in tqdm( + range(0, n_streamlines, 
chunk_size), + total=total_chunks, + desc="Batched Portion of Recognition", + ): + chunk_end = min(chunk_start + chunk_size, n_streamlines) + tqdm.write( + f"Processing chunk {chunk_start}:{chunk_end} of {n_streamlines} " + f"({(chunk_end / n_streamlines) * 100:.2f}%)" + ) + + if tg_path is not None and tg is None: + tg = load_trx(tg_path, img) + chunk_streamlines = tg.streamlines[chunk_start:chunk_end].copy() + if tg_path is not None: + tg.close() + tg = None + + chunk_preproc = PreprocPlan(chunk_streamlines) + + for bundle_name in bundle_dict.bundle_names: + tqdm.write(f"Running chunk-local phase for bundle {bundle_name}") + chunk_b_sls = _run_chunk_local( + bundle_defs[bundle_name], + chunk_streamlines, + bundle_name, + img, + chunk_preproc, + save_intermediates, + mapping=mapping, + reg_template=reg_template, + vox_dim=vox_dim, + tol=tol, + dist_to_atlas=dist_to_atlas, + **segmentation_params, ) + survivor_dicts[bundle_name].append(chunk_b_sls.export_selected(chunk_start)) + del chunk_b_sls + del chunk_preproc, chunk_streamlines + + if tg_path is not None: + tg = load_trx(tg_path, img) + + for bundle_name in bundle_dict.bundle_names: + tqdm.write(f"Running global phase for bundle {bundle_name}") + bundle_def = bundle_defs[bundle_name] + + merged = abu.SlsBeingRecognized.from_selected( + survivor_dicts[bundle_name], + tg.streamlines, + save_intermediates, + bundle_name, + img, + len(bundle_def.get("include", [])), + ) + survivor_dicts[bundle_name] = None # free per-chunk dicts - sub_sft = StatefulTractogram( - b_sls.get_selected_sls(flip=True), img, Space.RASMM + if merged is None: + tqdm.write( + f"Bundle {bundle_name}: 0 candidates after chunk-local filtering" ) - cluster_labels = subcluster_by_atlas( - sub_sft, mapping, img, subdict.all_cluster_IDs, n_points=40 + continue + + _validate_criteria( + bundle_def, bundle_name, bundle_dict, recognized_bundles_dict + ) + + tqdm.write( + f"Bundle {bundle_name}: {len(merged)} candidates after " + "chunk-local filtering" + ) + + need_fgarray = "cross_midline" in bundle_def and not bundle_def["cross_midline"] + if need_fgarray: + candidate_global_idx = np.array(merged.selected_fiber_idxs, dtype=np.int64) + cand_streamlines = [tg.streamlines[int(i)] for i in candidate_global_idx] + start_time = time() + fgarray_for_candidates = np.asarray( + resample_tg(cand_streamlines, 20), dtype=np.float32 ) - clusters_being_recognized = [] - for sub_b_name in subdict.bundle_names: - c_ids = subdict._dict[sub_b_name]["cluster_IDs"] - n_roi = len(subdict._dict[sub_b_name].get("include", [])) - cluster_b_sls = b_sls.copy(sub_b_name, n_roi) - selected = np.zeros(len(b_sls), dtype=bool) - for c_id in c_ids: - selected = np.logical_or(selected, cluster_labels == c_id) - cluster_b_sls.select(selected, f"Clusters {c_ids}") - clusters_being_recognized.append(cluster_b_sls) - - for ii, sub_b_name in enumerate(subdict.bundle_names): - run_bundle_rec_plan( - bundle_def["ORG_spectral_subbundles"], - clusters_being_recognized[ii], - mapping, - img, - reg_template, - preproc_imap, - sub_b_name, - recognized_bundles_dict, - **segmentation_params, - ) + tqdm.write(f"Resampling took {time() - start_time:.2f} seconds") + del cand_streamlines else: - b_sls.bundle_def = bundle_def - recognized_bundles_dict[bundle_name] = b_sls + candidate_global_idx = None + fgarray_for_candidates = None + + _run_global_phase( + bundle_def, + bundle_name, + merged, + fgarray_for_candidates, + candidate_global_idx, + mapping, + img, + reg_template, + preproc_scalars, + recognized_bundles_dict, + 
vox_dim, + tol, + dist_to_atlas, + save_intermediates=save_intermediates, + **segmentation_params, + ) + + return recognized_bundles_dict, n_streamlines diff --git a/AFQ/recognition/preprocess.py b/AFQ/recognition/preprocess.py index 36fe16a5..d6f0a505 100644 --- a/AFQ/recognition/preprocess.py +++ b/AFQ/recognition/preprocess.py @@ -1,79 +1,30 @@ -import logging -from time import time +from functools import cached_property -import immlib import numpy as np import AFQ.recognition.utils as abu -logger = logging.getLogger("AFQ") +class PreprocPlan: + def __init__(self, tg): + self.tg = tg -@immlib.calc("tol", "dist_to_atlas", "vox_dim") -def tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas): - return abu.tolerance_mm_to_vox(img, dist_to_waypoint, input_dist_to_atlas) + @cached_property + def fgarray(self): + return np.asarray(abu.resample_tg(self.tg, 20), dtype=np.float32) + @cached_property + def crosses(self): + return np.logical_and( + np.any(self.fgarray[:, :, 0] > 0, axis=1), + np.any(self.fgarray[:, :, 0] < 0, axis=1), + ) -@immlib.calc("fgarray") -def fgarray(tg): - """ - Streamlines resampled to 20 points. - """ - logger.info("Resampling Streamlines...") - start_time = time() - fg_array = np.array(abu.resample_tg(tg, 20)) - logger.info((f"Streamlines Resampled (time: {time() - start_time}s)")) - return fg_array + @cached_property + def lengths(self): + segments = np.diff(self.fgarray, axis=1) + return np.sum(np.sqrt(np.sum(segments**2, axis=2)), axis=1) - -@immlib.calc("crosses") -def crosses(fgarray): - """ - Classify the streamlines by whether they cross the midline. - Creates a crosses attribute which is an array of booleans. Each boolean - corresponds to a streamline, and is whether or not that streamline - crosses the midline. - """ - return np.logical_and( - np.any(fgarray[:, :, 0] > 0, axis=1), - np.any(fgarray[:, :, 0] < 0, axis=1), - ) - - -@immlib.calc("lengths") -def lengths(fgarray): - """ - Calculate the lengths of the streamlines. - Using resampled fgarray biases lengths to be lower. However, - this is not meant to be a precise selection requirement, and - is more meant for efficiency. - """ - segments = np.diff(fgarray, axis=1) - segment_lengths = np.sqrt(np.sum(segments**2, axis=2)) - return np.sum(segment_lengths, axis=1) - - -@immlib.calc("endpoint_dists") -def endpoint_dists(fgarray): - """ - Calculate the distances between the endpoints of the streamlines. 
- """ - return np.linalg.norm(fgarray[:, 0, :] - fgarray[:, -1, :], axis=1) - - -# Things that can be calculated for multiple bundles at once -# (i.e., for a whole tractogram) go here -def get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas): - preproc_plan = immlib.plan( - tolerance_mm_to_vox=tolerance_mm_to_vox, - fgarray=fgarray, - crosses=crosses, - lengths=lengths, - endpoint_dists=endpoint_dists, - ) - return preproc_plan( - img=img, - tg=tg, - dist_to_waypoint=dist_to_waypoint, - input_dist_to_atlas=dist_to_atlas, - ) + @cached_property + def endpoint_dists(self): + return np.linalg.norm(self.fgarray[:, 0, :] - self.fgarray[:, -1, :], axis=1) diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index d6575dc9..6da23822 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -10,8 +10,7 @@ import AFQ.recognition.sparse_decisions as ars import AFQ.recognition.utils as abu from AFQ.api.bundle_dict import BundleDict -from AFQ.recognition.criteria import run_bundle_rec_plan -from AFQ.recognition.preprocess import get_preproc_plan +from AFQ.recognition.criteria import recognize_bundles from AFQ.utils.path import write_json logger = logging.getLogger("AFQ") @@ -36,13 +35,14 @@ def recognize( dist_to_atlas=4, save_intermediates=None, cleaning_params=None, + chunk_size=int(1e6), ): """ Segment streamlines into bundles. Parameters ---------- - tg : StatefulTractogram, or TrxFile + tg : StatefulTractogram, or path to a TRXfile Tractogram to segment. img : str, nib.Nifti1Image Image for reference. @@ -114,6 +114,12 @@ def recognize( override the default parameters of that method. However, this can be overridden by setting the cleaning parameters in the bundle_dict. Default: {}. + chunk_size : int, optional + Number of streamlines to preprocess at a time. The full + tractogram is processed in chunks of this size to keep peak + memory bounded. Per-chunk surviving candidates are merged + before the global per-bundle filtering steps run. + Default: 1e6. 
References ---------- @@ -166,36 +172,28 @@ def recognize( tg.to_rasmm() - n_streamlines = len(tg) - recognized_bundles_dict = {} - fiber_groups = {} meta = {} - preproc_imap = get_preproc_plan(img, tg, dist_to_waypoint, dist_to_atlas) - - logger.info("Assigning Streamlines to Bundles") - for bundle_name in bundle_dict.bundle_names: - logger.info(f"Finding Streamlines for {bundle_name}") - run_bundle_rec_plan( - bundle_dict, - tg.streamlines, - mapping, - img, - reg_template, - preproc_imap, - bundle_name, - recognized_bundles_dict, - clip_edges=clip_edges, - rb_recognize_params=rb_recognize_params, - prob_threshold=prob_threshold, - refine_reco=refine_reco, - rng=rng, - return_idx=return_idx, - filter_by_endpoints=filter_by_endpoints, - save_intermediates=save_intermediates, - cleaning_params=cleaning_params, - ) + recognized_bundles_dict, n_streamlines = recognize_bundles( + tg, + bundle_dict, + mapping, + img, + reg_template, + chunk_size=chunk_size, + dist_to_waypoint=dist_to_waypoint, + dist_to_atlas=dist_to_atlas, + save_intermediates=save_intermediates, + clip_edges=clip_edges, + rb_recognize_params=rb_recognize_params, + prob_threshold=prob_threshold, + refine_reco=refine_reco, + rng=rng, + return_idx=return_idx, + filter_by_endpoints=filter_by_endpoints, + cleaning_params=cleaning_params, + ) if save_intermediates is not None: os.makedirs(save_intermediates, exist_ok=True) diff --git a/AFQ/recognition/utils.py b/AFQ/recognition/utils.py index f081c16e..cfc0e0c4 100644 --- a/AFQ/recognition/utils.py +++ b/AFQ/recognition/utils.py @@ -1,5 +1,4 @@ import copy -import logging import os.path as op from time import time @@ -9,9 +8,7 @@ from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.streamline import save_tractogram from dipy.tracking.distances import bundles_distances_mdf - -logger = logging.getLogger("AFQ") - +from tqdm import tqdm axes_dict = { "L/R": 0, @@ -137,11 +134,10 @@ def resample_tg(tg, n_points): class SlsBeingRecognized: - def __init__(self, sls, logger, save_intermediates, b_name, ref, n_roi): + def __init__(self, sls, save_intermediates, b_name, ref, n_roi): self.oriented_yet = False self.selected_fiber_idxs = np.arange(len(sls), dtype=np.uint32) self.sls_flipped = np.zeros(len(sls), dtype=np.bool_) - self.logger = logger self.start_time = -1 self.save_intermediates = save_intermediates self.b_name = b_name @@ -151,7 +147,7 @@ def __init__(self, sls, logger, save_intermediates, b_name, ref, n_roi): def initiate_selection(self, clean_name): self.start_time = time() - self.logger.info(f"Filtering by {clean_name}") + tqdm.write(f"Filtering by {clean_name}") return np.zeros(len(self.selected_fiber_idxs), dtype=np.bool_) def select(self, idx, clean_name, cut=False): @@ -162,7 +158,7 @@ def select(self, idx, clean_name, cut=False): if hasattr(self, "roi_dists"): self.roi_dists = self.roi_dists[idx] time_taken = time() - self.start_time - self.logger.info( + tqdm.write( f"After filtering by {clean_name} (time: {time_taken}s), " f"{len(self)} streamlines remain." 
) @@ -234,3 +230,77 @@ def copy(self, new_name, n_roi): new_copy.roi_dists = self.roi_dists.copy() return new_copy + + def export_selected(self, chunk_offset): + return { + "global_idx": ( + self.selected_fiber_idxs.astype(np.int64) + int(chunk_offset) + ).copy(), + "sls_flipped": self.sls_flipped.copy(), + "oriented_yet": self.oriented_yet, + "roi_closest": ( + self.roi_closest.copy() if hasattr(self, "roi_closest") else None + ), + "roi_dists": ( + self.roi_dists.copy() if hasattr(self, "roi_dists") else None + ), + } + + @classmethod + def from_selected( + cls, + survivor_dicts, + full_streamlines, + save_intermediates, + b_name, + ref, + n_roi, + ): + non_empty = [d for d in survivor_dicts if d["global_idx"].size > 0] + if not non_empty: + return None + + global_idx = np.concatenate([d["global_idx"] for d in non_empty]) + sls_flipped = np.concatenate([d["sls_flipped"] for d in non_empty]) + oriented_yet = any(d["oriented_yet"] for d in non_empty) + + if global_idx.size > 1 and not np.all(np.diff(global_idx) > 0): + order = np.argsort(global_idx, kind="stable") + global_idx = global_idx[order] + sls_flipped = sls_flipped[order] + else: + order = None + + has_roi = [d["roi_closest"] is not None for d in non_empty] + if any(has_roi): + if not all(has_roi): + raise RuntimeError( + "Inconsistent roi_closest across chunks for bundle " + f"{b_name}: some chunks have it, some don't. This is a" + " bug in chunked recognition." + ) + roi_closest = np.concatenate([d["roi_closest"] for d in non_empty], axis=0) + roi_dists = np.concatenate([d["roi_dists"] for d in non_empty], axis=0) + if order is not None: + roi_closest = roi_closest[order] + roi_dists = roi_dists[order] + else: + roi_closest = None + roi_dists = None + + inst = cls( + full_streamlines, + save_intermediates, + b_name, + ref, + n_roi, + ) + + inst.selected_fiber_idxs = global_idx.astype(np.uint32, copy=False) + inst.sls_flipped = sls_flipped + inst.oriented_yet = oriented_yet + if roi_closest is not None: + inst.roi_closest = roi_closest + inst.roi_dists = roi_dists + + return inst diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index d40db1d7..3830479f 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -63,7 +63,10 @@ def segment( is_trx = False elif streamlines.endswith(".trx"): is_trx = True - tg = load_trx(streamlines, data_imap["dwi"]) + if segmentation_params["nb_streamlines"] or segmentation_params["nb_points"]: + tg = load_trx(streamlines, data_imap["dwi"]) + else: + tg = streamlines elif streamlines.endswith(".tck.gz"): # uncompress tck.gz to a temporary tck: temp_tck = op.join(mkdtemp(), op.split(streamlines.replace(".gz", ""))[1]) @@ -74,13 +77,6 @@ def segment( # initialize stateful tractogram from tck file: tg = load_tractogram(temp_tck, data_imap["dwi"], bbox_valid_check=False) is_trx = False - if len(tg.streamlines) == 0: - raise ValueError( - f"There are no streamlines in {streamlines}." - " This is likely due to errors in defining the " - " tractography parameters or the" - " seed/PVE masks." - ) if not is_trx: indices_to_remove, _ = tg.remove_invalid_streamlines() diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 8f6a565a..ecaf5cbf 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -128,8 +128,11 @@ def streamlines( if len(sft) == 0: raise ValueError( - "No streamlines were generated. " - "Please check your tracking parameters and input data." + "No streamlines were generated." 
+ " This is likely due to errors in defining the " + " tractography parameters or the" + " seed/PVE masks." + " Please check your tracking parameters and input data." ) return sft, _meta_from_tracking_params( diff --git a/AFQ/tests/test_fixes.py b/AFQ/tests/test_fixes.py index 9501b297..0d5e13df 100644 --- a/AFQ/tests/test_fixes.py +++ b/AFQ/tests/test_fixes.py @@ -10,7 +10,6 @@ import AFQ.data.fetch as afd from AFQ._fixes import gaussian_weights, gwi_odf -from AFQ._fixes import gaussian_weights as gaussian_weights_fast from AFQ.utils.testing import make_dki_data @@ -35,43 +34,58 @@ def test_GQI_fix(): def test_gaussian_weights(): file_dict = afd.read_stanford_hardi_tractography() streamlines = file_dict["tractography_subsampled"] - assert not np.any(np.isnan(gaussian_weights(streamlines[76:92]))) + + weights = gaussian_weights(streamlines) + assert not np.any(np.isnan(weights)) + + # test consistency + assignment_idxs = np.tile(np.arange(100), (len(streamlines), 1)) + assignment_method_weights = gaussian_weights( + streamlines, assignment_idxs=assignment_idxs + ) + + assert np.allclose( + weights, assignment_method_weights[: len(weights)], rtol=1e-6, atol=1e-6 + ) + + assert np.allclose(np.sum(weights), 100) + assert np.allclose(np.sum(assignment_method_weights), 100) def test_mahal_fix(): sls = [ - [[8.0, 53, 39], [8, 50, 39], [8, 45, 39], [30, 41, 61], [28, 61, 38]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [30, 41, 62], [20, 44, 34]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [50, 67, 88], [10, 10, 20]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [35, 43, 65], [25, 55, 35]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [40, 50, 70], [15, 15, 25]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [45, 54, 75], [12, 22, 32]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [32, 48, 68], [28, 58, 40]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [38, 52, 72], [18, 38, 28]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [34, 44, 64], [21, 41, 31]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [36, 46, 66], [23, 53, 33]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [37, 47, 67], [24, 54, 34]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [39, 49, 69], [19, 39, 29]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [33, 53, 73], [22, 42, 32]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [31, 51, 71], [26, 56, 36]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [29, 59, 79], [27, 57, 37]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [28, 58, 78], [17, 47, 27]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [27, 57, 77], [16, 36, 26]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [26, 56, 76], [14, 24, 34]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [25, 55, 75], [13, 23, 33]], - [[8, 53, 39], [8, 50, 39], [8, 45, 39], [24, 54, 74], [11, 21, 31]], + [[30, 41, 61], [28, 61, 38]], + [[30, 41, 62], [20, 44, 34]], + [[50, 67, 88], [10, 10, 20]], + [[35, 43, 65], [25, 55, 35]], + [[40, 50, 70], [15, 15, 25]], + [[45, 54, 75], [12, 22, 32]], + [[32, 48, 68], [28, 58, 40]], + [[38, 52, 72], [18, 38, 28]], + [[34, 44, 64], [21, 41, 31]], + [[36, 46, 66], [23, 53, 33]], + [[37, 47, 67], [24, 54, 34]], + [[39, 49, 69], [19, 39, 29]], + [[33, 53, 73], [22, 42, 32]], + [[31, 51, 71], [26, 56, 36]], + [[29, 59, 79], [27, 57, 37]], + [[28, 58, 78], [17, 47, 27]], + [[27, 57, 77], [16, 36, 26]], + [[26, 56, 76], [14, 24, 34]], + [[25, 55, 75], [13, 23, 33]], + [[24, 54, 74], [11, 21, 31]], ] sls_array = np.asarray(sls).astype(float) results = np.asarray( [ - [0.0, 0.0, 0.0, 1.718654, 1.550252], - [0.0, 0.0, 0.0, 2.202227, 0.7881], - [0.0, 0.0, 0.0, 3.415999, 
2.689814], + [1.718654, 1.550252], + [2.202227, 0.7881], + [3.415999, 2.689814], ] ) npt.assert_array_almost_equal( - gaussian_weights_fast( - sls_array, n_points=None, return_mahalnobis=True, stat=np.mean + gaussian_weights( + sls_array, n_points=None, return_mahalanobis=True, stat=np.mean )[:3], results, )
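
Note on the vectorized Mahalanobis path in AFQ/_fixes.py: the batched
eigh-based pseudo-inverse is intended to reproduce scipy.linalg.pinvh
applied per node, up to the eigenvalue cutoff. Below is a minimal
standalone sketch of that equivalence on synthetic data; all names are
local to the sketch, nothing here is from the codebase:

    import numpy as np
    from scipy.linalg import pinvh

    rng = np.random.default_rng(0)
    n_sls, n_nodes = 50, 5
    sls = rng.normal(size=(n_sls, n_nodes, 3))

    mu = sls.mean(axis=0)
    diff = sls - mu

    # Reference: per-node pinvh of the population covariance, as in
    # the assignment_idxs branch of gaussian_weights.
    ref = np.empty((n_sls, n_nodes))
    for n in range(n_nodes):
        cov_n = np.cov(sls[:, n, :].T, ddof=0)
        ref[:, n] = np.sqrt(
            np.einsum("ij,jk,ik->i", diff[:, n, :], pinvh(cov_n), diff[:, n, :])
        )

    # Batched: one eigh over the stacked per-node covariances, with
    # eigenvalues below a relative tolerance zeroed, as in the patch.
    cov = np.einsum("snj,snk->njk", diff, diff) / n_sls
    cov = 0.5 * (cov + cov.transpose(0, 2, 1))
    eigvals, eigvecs = np.linalg.eigh(cov)
    eigvals = np.clip(eigvals, 0, None)
    max_ev = eigvals.max(axis=1, keepdims=True)
    tol = np.finfo(eigvals.dtype).eps * np.maximum(max_ev, 1e-6) * cov.shape[-1]
    inv_eigvals = np.where(
        eigvals > tol, 1.0 / np.where(eigvals > tol, eigvals, 1.0), 0.0
    )
    inv_cov = np.einsum("nij,nj,nkj->nik", eigvecs, inv_eigvals, eigvecs)
    batched = np.sqrt(np.einsum("snj,njk,snk->sn", diff, inv_cov, diff))

    assert np.allclose(ref, batched, atol=1e-8)

With well-conditioned covariances both paths compute the exact inverse,
so agreement is tight; the relative tolerance only matters when a
node's covariance is near-singular, where the batched path drops the
degenerate directions much as pinvh does.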