diff --git a/.codeocean/datasets.json b/.codeocean/datasets.json index 272299e..abe9ce5 100644 --- a/.codeocean/datasets.json +++ b/.codeocean/datasets.json @@ -4,6 +4,10 @@ { "id": "68ef27d7-9d95-40ce-9e40-7de93dccf5f8", "mount": "LCNE-patchseq-ephys" + }, + { + "id": "c5d5e922-1863-4869-99ff-fdce3e10edc7", + "mount": "fonts" } ] } \ No newline at end of file diff --git a/assets/fonts/Helvetica.ttf b/assets/fonts/Helvetica.ttf new file mode 100644 index 0000000..c2c74cb Binary files /dev/null and b/assets/fonts/Helvetica.ttf differ diff --git a/code/run b/code/run new file mode 100755 index 0000000..713757c --- /dev/null +++ b/code/run @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -ex + +# This is the master script for the capsule. When you click "Reproducible Run", the code in this file will execute. +python -m LCNE_patchseq_analysis.figures.main_pca_tau diff --git a/environment/Dockerfile b/environment/Dockerfile index e342d3b..88fa063 100644 --- a/environment/Dockerfile +++ b/environment/Dockerfile @@ -15,8 +15,8 @@ ARG GIT_ASKPASS ARG GIT_ACCESS_TOKEN COPY git-askpass / -RUN pip install --no-cache-dir "git+https://github.com/AllenNeuralDynamics/LCNE-patchseq-analysis.git@main#egg=LCNE-patchseq-analysis[panel]" - +RUN pip install -U --no-cache-dir \ + "git+https://github.com/AllenNeuralDynamics/LCNE-patchseq-analysis.git@b8c709294e970e35beda633757255b00d579f832#egg=LCNE-patchseq-analysis" COPY postInstall / RUN /postInstall diff --git a/environment/postInstall b/environment/postInstall index ccf4bef..810e5ef 100755 --- a/environment/postInstall +++ b/environment/postInstall @@ -1,29 +1,17 @@ #!/usr/bin/env bash set -e -# check if code-server is installed, and then install extensions into specified directory -if code-server --disable-telemetry --version; then - if [ ! -d "/.vscode/extensions" ] - then - echo "Directory /.vscode/extensions DOES NOT exists." - mkdir -p /.vscode/extensions/ - fi - - code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension ms-python.python - code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension ms-toolsai.jupyter - code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension njpwerner.autodocstring - code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension KevinRose.vsc-python-indent - code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension mhutchie.git-graph - code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension zhuangtongfa.material-theme - code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension ms-python.black-formatter - code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension randomfractalsinc.vscode-data-preview - # code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension ryanluker.vscode-coverage-gutters - - curl -L -o copilot_1.161.zip https://github.com/user-attachments/files/16859733/copilot_1.161.zip - unzip copilot_1.161.zip - code-server --disable-telemetry --extensions-dir=/.vscode/extensions --install-extension GitHub.copilot-1.161.0.vsix - rm copilot_1.161.zip GitHub.copilot-1.161.0.vsix - - else - echo "code-server not found" - fi \ No newline at end of file +HOME_DIR="${HOME:-/root}" +BASHRC="${HOME_DIR}/.bashrc" + +cat >> "$BASHRC" <<'EOF' + +# LCNE editable install (runs once on first login, src/ only available at runtime) +if [ ! -f "$HOME/.lcne_editable_setup_done" ]; then + echo "Installing local package in editable mode from /root/capsule..." + pip install -e /root/capsule + echo "Editable install complete" + + touch "$HOME/.lcne_editable_setup_done" +fi +EOF diff --git a/pyproject.toml b/pyproject.toml index 9c40e61..fb6f720 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,9 @@ dependencies = [ 'requests', 'trimesh', 'statsmodels', - 'matplotlib_venn' + 'matplotlib_venn', + 'scipy', + 'scikit-learn', ] [project.optional-dependencies] diff --git a/src/LCNE_patchseq_analysis/efel/core.py b/src/LCNE_patchseq_analysis/efel/core.py index dc1a739..51b40c0 100644 --- a/src/LCNE_patchseq_analysis/efel/core.py +++ b/src/LCNE_patchseq_analysis/efel/core.py @@ -159,9 +159,7 @@ def extract_spike_waveforms( DataFrame containing spike waveforms """ peak_times = ( - features_dict["df_features_per_spike"] - .reset_index() - .set_index("sweep_number")["peak_time"] + features_dict["df_features_per_spike"].reset_index().set_index("sweep_number")["peak_time"] ) # Time can be determined by the sampling rate @@ -176,9 +174,9 @@ def extract_spike_waveforms( v = raw_trace["V"] for peak_time in peak_times_this_sweep: - idx = np.where( - (t >= peak_time + spike_window[0]) & (t < peak_time + spike_window[1]) - )[0] + idx = np.where((t >= peak_time + spike_window[0]) & (t < peak_time + spike_window[1]))[ + 0 + ] v_this = v[idx] vs.append(v_this) @@ -312,9 +310,7 @@ def extract_features_using_efel( # Append stimulus to raw_traces (doing here because eFEL cannot handle it) for raw_trace in raw_traces: raw_trace["stimulus"] = raw.get_stimulus(raw_trace["sweep_number"][0]) - df_peri_stimulus_raw_traces = extract_peri_stimulus_raw_traces( - raw_traces, features_dict - ) + df_peri_stimulus_raw_traces = extract_peri_stimulus_raw_traces(raw_traces, features_dict) # -- Enrich df_sweeps -- df_sweeps = raw.df_sweeps.copy() @@ -323,9 +319,9 @@ def extract_features_using_efel( "spike_count": "efel_num_spikes", "first_spike_AP_width": "efel_first_spike_AP_width", } - _df_to_df_sweeps = features_dict["df_features_per_sweep"][ - list(col_to_df_sweeps.keys()) - ].rename(columns=col_to_df_sweeps) + _df_to_df_sweeps = features_dict["df_features_per_sweep"][list(col_to_df_sweeps.keys())].rename( + columns=col_to_df_sweeps + ) df_sweeps = df_sweeps.merge(_df_to_df_sweeps, on="sweep_number", how="left") # Add metadata to features_dict @@ -377,9 +373,7 @@ def extract_efel_one( except Exception as e: import traceback - error_message = ( - f"Error processing {ephys_roi_id}: {str(e)}\n{traceback.format_exc()}" - ) + error_message = f"Error processing {ephys_roi_id}: {str(e)}\n{traceback.format_exc()}" logger.error(error_message) return error_message diff --git a/src/LCNE_patchseq_analysis/efel/pipeline.py b/src/LCNE_patchseq_analysis/efel/pipeline.py index 1677df0..3e8d9d3 100644 --- a/src/LCNE_patchseq_analysis/efel/pipeline.py +++ b/src/LCNE_patchseq_analysis/efel/pipeline.py @@ -15,9 +15,7 @@ logger = logging.getLogger(__name__) -def extract_efel_features_in_parallel( - skip_existing: bool = True, skip_errors: bool = True -): +def extract_efel_features_in_parallel(skip_existing: bool = True, skip_errors: bool = True): """Extract eFEL features in parallel.""" def get_roi_ids(): @@ -25,9 +23,7 @@ def get_roi_ids(): return df_meta["ephys_roi_id_tab_master"] def check_existing(ephys_roi_id): - return os.path.exists( - f"{RESULTS_DIRECTORY}/features/{int(ephys_roi_id)}_efel.h5" - ) + return os.path.exists(f"{RESULTS_DIRECTORY}/features/{int(ephys_roi_id)}_efel.h5") return run_parallel_processing( process_func=extract_efel_one, @@ -39,15 +35,11 @@ def check_existing(ephys_roi_id): ) -def generate_sweep_plots_in_parallel( - skip_existing: bool = True, skip_errors: bool = True -): +def generate_sweep_plots_in_parallel(skip_existing: bool = True, skip_errors: bool = True): """Generate sweep plots in parallel.""" def check_existing(ephys_roi_id): - return os.path.exists( - f"{RESULTS_DIRECTORY}/plots/{int(ephys_roi_id)}/all_success" - ) + return os.path.exists(f"{RESULTS_DIRECTORY}/plots/{int(ephys_roi_id)}/all_success") return run_parallel_processing( process_func=generate_sweep_plots_one, @@ -58,9 +50,7 @@ def check_existing(ephys_roi_id): ) -def extract_cell_level_stats_in_parallel( - skip_errors: bool = True, if_generate_plots: bool = True -): +def extract_cell_level_stats_in_parallel(skip_errors: bool = True, if_generate_plots: bool = True): """Extract cell-level statistics from all available eFEL features files in parallel.""" # ---- Extract cell-level stats ---- @@ -111,9 +101,9 @@ def extract_cell_level_stats_in_parallel( ) # ---- Merge into Brian's spreadsheet ---- - df_ephys_metadata = load_ephys_metadata( - if_from_s3=False, combine_roi_ids=True - ).rename(columns={"ephys_roi_id_tab_master": "ephys_roi_id"}) + df_ephys_metadata = load_ephys_metadata(if_from_s3=False, combine_roi_ids=True).rename( + columns={"ephys_roi_id_tab_master": "ephys_roi_id"} + ) df_merged = df_ephys_metadata.merge(df_cell_stats, on="ephys_roi_id", how="left") # ---- Post-processing ---- @@ -127,9 +117,7 @@ def extract_cell_level_stats_in_parallel( # Remove "first_spike_" in all column names df_merged.columns = [col.replace("first_spike_", "") for col in df_merged.columns] # Add EFEL_prefix to all columns that has @ in its name - df_merged.columns = [ - f"efel_{col}" if "@" in col else col for col in df_merged.columns - ] + df_merged.columns = [f"efel_{col}" if "@" in col else col for col in df_merged.columns] # ---- Save the summary table to disk ---- save_path = f"{RESULTS_DIRECTORY}/cell_stats/cell_level_stats.csv" @@ -148,9 +136,7 @@ def extract_cell_level_stats_in_parallel( save_path = f"{RESULTS_DIRECTORY}/cell_stats/cell_level_last_spike_waveforms.pkl" df_cell_representative_last_spike_waveforms.to_pickle(save_path) - logger.info( - f"Successfully extracted cell-level stats for {len(df_cell_stats)} cells!" - ) + logger.info(f"Successfully extracted cell-level stats for {len(df_cell_stats)} cells!") logger.info(f"Summary table saved to {save_path}") return df_merged diff --git a/src/LCNE_patchseq_analysis/efel/plot.py b/src/LCNE_patchseq_analysis/efel/plot.py index 2105e62..9ee7b25 100644 --- a/src/LCNE_patchseq_analysis/efel/plot.py +++ b/src/LCNE_patchseq_analysis/efel/plot.py @@ -285,9 +285,7 @@ def plot_overlaid_spikes( + df_spike_feature["peak_voltage"].loc[i] ) / 2 - AP_duration_half_width = df_spike_feature["AP_duration_half_width"].loc[ - i - ] + AP_duration_half_width = df_spike_feature["AP_duration_half_width"].loc[i] ax_v.plot(half_rise_time, half_voltage, "mo", ms=10) ax_v.plot( [half_rise_time, half_rise_time + AP_duration_half_width], @@ -307,9 +305,7 @@ def plot_overlaid_spikes( _t_after_begin = np.where(t >= t_begin)[0] if len(_t_after_begin) > 0: # Sometimes t_begin is None begin_ind = _t_after_begin[0] - ax_phase.plot( - v[begin_ind], dvdt[begin_ind], "go", ms=10, label="AP_begin" - ) + ax_phase.plot(v[begin_ind], dvdt[begin_ind], "go", ms=10, label="AP_begin") ax_phase.axhline( efel_settings["DerivativeThreshold"], color="g", @@ -325,9 +321,7 @@ def plot_overlaid_spikes( ax_phase.plot(xx, yy, "g--", label="AP_phaseslope") # Phase plot: AP_peak_upstroke - ax_phase.axhline( - peak_upstroke, color="c", linestyle="--", label="AP_peak_upstroke" - ) + ax_phase.axhline(peak_upstroke, color="c", linestyle="--", label="AP_peak_upstroke") # Phase plot: AP_peak_downstroke ax_phase.axhline( @@ -348,9 +342,7 @@ def plot_overlaid_spikes( ax_phase.set_xlabel("Voltage (mV)") ax_phase.set_ylabel("dv/dt (mV/ms)") ax_phase.set_title("Phase Plots") - ax_phase.legend( - loc="best", fontsize=12, title="1st spike features", title_fontsize=13 - ) + ax_phase.legend(loc="best", fontsize=12, title="1st spike features", title_fontsize=13) ax_phase.grid(True) fig.tight_layout() @@ -374,13 +366,9 @@ def plot_sweep_summary(features_dict: Dict[str, Any], save_dir: str) -> None: has_spikes = df_sweep_feature["spike_count"] > 0 df_spike_feature = ( - features_dict["df_features_per_spike"].loc[sweep_number] - if has_spikes - else None - ) - df_sweep_meta = features_dict["df_sweeps"].query( - "sweep_number == @sweep_number" + features_dict["df_features_per_spike"].loc[sweep_number] if has_spikes else None ) + df_sweep_meta = features_dict["df_sweeps"].query("sweep_number == @sweep_number") sweep_this = ( features_dict["df_peri_stimulus_raw_traces"] .query("sweep_number == @sweep_number") @@ -388,9 +376,7 @@ def plot_sweep_summary(features_dict: Dict[str, Any], save_dir: str) -> None: ) # Plot raw sweep - fig_sweep = plot_sweep_raw( - sweep_this, df_sweep_meta, df_sweep_feature, df_spike_feature - ) + fig_sweep = plot_sweep_raw(sweep_this, df_sweep_meta, df_sweep_feature, df_spike_feature) fig_sweep.savefig( f"{save_dir}/{ephys_roi_id}/{ephys_roi_id}_sweep_{sweep_number}.png", dpi=400, @@ -399,9 +385,7 @@ def plot_sweep_summary(features_dict: Dict[str, Any], save_dir: str) -> None: # Plot spikes if present if has_spikes: - spike_this = features_dict["df_spike_waveforms"].query( - "sweep_number == @sweep_number" - ) + spike_this = features_dict["df_spike_waveforms"].query("sweep_number == @sweep_number") fig_spikes = plot_overlaid_spikes( spike_this, sweep_this, @@ -431,9 +415,7 @@ def generate_sweep_plots_one(ephys_roi_id: str): except Exception as e: import traceback - error_message = ( - f"Error processing {ephys_roi_id}: {str(e)}\n{traceback.format_exc()}" - ) + error_message = f"Error processing {ephys_roi_id}: {str(e)}\n{traceback.format_exc()}" logger.error(error_message) return error_message @@ -505,12 +487,8 @@ def plot_cell_summary( # noqa: C901 # -- Spikes -- color_used = [] - df_second_spike_waveforms = features_dict.get( - "df_second_spike_waveforms", pd.DataFrame() - ) - df_last_spike_waveforms = features_dict.get( - "df_last_spike_waveforms", pd.DataFrame() - ) + df_second_spike_waveforms = features_dict.get("df_second_spike_waveforms", pd.DataFrame()) + df_last_spike_waveforms = features_dict.get("df_last_spike_waveforms", pd.DataFrame()) for label, settings in spikes_to_show.items(): sweep_number = settings["sweep_number"] # noqa: F841 @@ -538,9 +516,7 @@ def plot_cell_summary( # noqa: C901 ) width_values = [] if "AP_duration_half_width" in df_spike_feature.columns: - width_values.append( - df_spike_feature["AP_duration_half_width"].values[0] - ) + width_values.append(df_spike_feature["AP_duration_half_width"].values[0]) df_second_spike_feature = features_dict["df_features_per_spike"].query( "sweep_number == @sweep_number and spike_idx == 1" @@ -549,9 +525,7 @@ def plot_cell_summary( # noqa: C901 not df_second_spike_feature.empty and "AP_duration_half_width" in df_second_spike_feature.columns ): - width_values.append( - df_second_spike_feature["AP_duration_half_width"].values[0] - ) + width_values.append(df_second_spike_feature["AP_duration_half_width"].values[0]) df_last_spike_feature = pd.DataFrame() if not df_last_spike_waveforms.empty: @@ -559,21 +533,17 @@ def plot_cell_summary( # noqa: C901 "sweep_number == @sweep_number" ) if len(df_last_spikes_raw) > 0: - last_spike_idx = df_last_spikes_raw.index.get_level_values( - "spike_idx" - )[0] - df_last_spike_feature = features_dict[ - "df_features_per_spike" - ].loc[(sweep_number, last_spike_idx)] + last_spike_idx = df_last_spikes_raw.index.get_level_values("spike_idx")[0] + df_last_spike_feature = features_dict["df_features_per_spike"].loc[ + (sweep_number, last_spike_idx) + ] if isinstance(df_last_spike_feature, pd.Series): df_last_spike_feature = df_last_spike_feature.to_frame().T if ( not df_last_spike_feature.empty and "AP_duration_half_width" in df_last_spike_feature.columns ): - width_values.append( - df_last_spike_feature["AP_duration_half_width"].values[0] - ) + width_values.append(df_last_spike_feature["AP_duration_half_width"].values[0]) if width_values: width_text = ", ".join([f"{value:.2f}" for value in width_values]) @@ -613,9 +583,7 @@ def plot_cell_summary( # noqa: C901 ) if not df_last_spike_waveforms.empty: - df_last_spikes_raw = df_last_spike_waveforms.query( - "sweep_number == @sweep_number" - ) + df_last_spikes_raw = df_last_spike_waveforms.query("sweep_number == @sweep_number") if len(df_last_spikes_raw) > 0: v_last_spike = df_last_spikes_raw.values[0] t_last_spike = df_last_spikes_raw.columns.values @@ -666,12 +634,8 @@ def plot_cell_summary( # noqa: C901 fig.tight_layout() fig.savefig("./tmp.png") - fig.savefig( - f"{RESULTS_DIRECTORY}/cell_stats/{ephys_roi_id}_cell_summary.svg", dpi=500 - ) - fig.savefig( - f"{RESULTS_DIRECTORY}/cell_stats/{ephys_roi_id}_cell_summary.png", dpi=500 - ) + fig.savefig(f"{RESULTS_DIRECTORY}/cell_stats/{ephys_roi_id}_cell_summary.svg", dpi=500) + fig.savefig(f"{RESULTS_DIRECTORY}/cell_stats/{ephys_roi_id}_cell_summary.png", dpi=500) plt.close(fig) return fig diff --git a/src/LCNE_patchseq_analysis/efel/population.py b/src/LCNE_patchseq_analysis/efel/population.py index afb1b0a..5565e64 100644 --- a/src/LCNE_patchseq_analysis/efel/population.py +++ b/src/LCNE_patchseq_analysis/efel/population.py @@ -47,8 +47,7 @@ def _get_min_or_aver(df_this, aggregate_method): if aggregate_method == "aver": # All SubThreshold sweeps return df.query( - "stimulus_code.str.contains('SubThresh')" - "and stimulus_name == 'Long Square'" + "stimulus_code.str.contains('SubThresh')" "and stimulus_name == 'Long Square'" ) elif isinstance(aggregate_method, int): return df.query( @@ -102,12 +101,8 @@ def extract_cell_level_stats_one(ephys_roi_id: str, if_generate_plots: bool = Tr ) df_raw_spikes = features_dict["df_spike_waveforms"] - df_second_spike_waveforms = features_dict.get( - "df_second_spike_waveforms", pd.DataFrame() - ) - df_last_spike_waveforms = features_dict.get( - "df_last_spike_waveforms", pd.DataFrame() - ) + df_second_spike_waveforms = features_dict.get("df_second_spike_waveforms", pd.DataFrame()) + df_last_spike_waveforms = features_dict.get("df_last_spike_waveforms", pd.DataFrame()) cell_stats_dict = {} cell_representative_spike_waveforms = [] @@ -116,12 +111,10 @@ def extract_cell_level_stats_one(ephys_roi_id: str, if_generate_plots: bool = Tr cell_representative_last_spike_waveforms = [] second_spike_features = [ - feature.replace("first_spike_", "second_spike_") - for feature in EXTRACT_SPIKE_FEATURES + feature.replace("first_spike_", "second_spike_") for feature in EXTRACT_SPIKE_FEATURES ] last_spike_features = [ - feature.replace("first_spike_", "last_spike_") - for feature in EXTRACT_SPIKE_FEATURES + feature.replace("first_spike_", "last_spike_") for feature in EXTRACT_SPIKE_FEATURES ] spike_features_to_extract = ( EXTRACT_SPIKE_FEATURES + second_spike_features + last_spike_features @@ -142,47 +135,35 @@ def extract_cell_level_stats_one(ephys_roi_id: str, if_generate_plots: bool = Tr mean_values = df_sweep[features_to_extract].mean() # Create a dictionary with feature names and their mean values feature_this = { - f"{feature} @ {key}": value - for feature, value in mean_values.items() + f"{feature} @ {key}": value for feature, value in mean_values.items() } cell_stats_dict.update(feature_this) # --- Extract cell-representative spike waveforms --- # Note: averages across all spikes in the selected sweeps. - df_spikes = df_raw_spikes.query( - "sweep_number in @df_sweep.sweep_number.values" - ) + df_spikes = df_raw_spikes.query("sweep_number in @df_sweep.sweep_number.values") if not df_spikes.empty: averaged_spike_waveforms = pd.DataFrame(df_spikes.mean()).T averaged_spike_waveforms.index = pd.MultiIndex.from_tuples( [(ephys_roi_id, key)], names=["ephys_roi_id", "extract_from"], ) - cell_representative_spike_waveforms.append( - averaged_spike_waveforms - ) + cell_representative_spike_waveforms.append(averaged_spike_waveforms) df_first_spikes = df_raw_spikes.query( "sweep_number in @df_sweep.sweep_number.values and spike_idx == 0" ) if not df_first_spikes.empty: - averaged_first_spike_waveforms = pd.DataFrame( - df_first_spikes.mean() - ).T - averaged_first_spike_waveforms.index = ( - pd.MultiIndex.from_tuples( - [(ephys_roi_id, key)], - names=["ephys_roi_id", "extract_from"], - ) + averaged_first_spike_waveforms = pd.DataFrame(df_first_spikes.mean()).T + averaged_first_spike_waveforms.index = pd.MultiIndex.from_tuples( + [(ephys_roi_id, key)], + names=["ephys_roi_id", "extract_from"], ) cell_representative_first_spike_waveforms.append( averaged_first_spike_waveforms ) - if ( - feature_type is EXTRACT_SPIKE_FROMS - and not df_second_spike_waveforms.empty - ): + if feature_type is EXTRACT_SPIKE_FROMS and not df_second_spike_waveforms.empty: df_second_spikes = df_second_spike_waveforms.query( "sweep_number in @df_sweep.sweep_number.values" ) @@ -190,32 +171,23 @@ def extract_cell_level_stats_one(ephys_roi_id: str, if_generate_plots: bool = Tr averaged_second_spike_waveforms = pd.DataFrame( df_second_spikes.mean() ).T - averaged_second_spike_waveforms.index = ( - pd.MultiIndex.from_tuples( - [(ephys_roi_id, key)], - names=["ephys_roi_id", "extract_from"], - ) + averaged_second_spike_waveforms.index = pd.MultiIndex.from_tuples( + [(ephys_roi_id, key)], + names=["ephys_roi_id", "extract_from"], ) cell_representative_second_spike_waveforms.append( averaged_second_spike_waveforms ) - if ( - feature_type is EXTRACT_SPIKE_FROMS - and not df_last_spike_waveforms.empty - ): + if feature_type is EXTRACT_SPIKE_FROMS and not df_last_spike_waveforms.empty: df_last_spikes = df_last_spike_waveforms.query( "sweep_number in @df_sweep.sweep_number.values" ) if not df_last_spikes.empty: - averaged_last_spike_waveforms = pd.DataFrame( - df_last_spikes.mean() - ).T - averaged_last_spike_waveforms.index = ( - pd.MultiIndex.from_tuples( - [(ephys_roi_id, key)], - names=["ephys_roi_id", "extract_from"], - ) + averaged_last_spike_waveforms = pd.DataFrame(df_last_spikes.mean()).T + averaged_last_spike_waveforms.index = pd.MultiIndex.from_tuples( + [(ephys_roi_id, key)], + names=["ephys_roi_id", "extract_from"], ) cell_representative_last_spike_waveforms.append( averaged_last_spike_waveforms @@ -225,9 +197,7 @@ def extract_cell_level_stats_one(ephys_roi_id: str, if_generate_plots: bool = Tr cell_stats_dict, index=pd.Index([ephys_roi_id], name="ephys_roi_id") ) if cell_representative_spike_waveforms: - df_cell_representative_spike_waveforms = pd.concat( - cell_representative_spike_waveforms - ) + df_cell_representative_spike_waveforms = pd.concat(cell_representative_spike_waveforms) else: df_cell_representative_spike_waveforms = pd.DataFrame() @@ -258,9 +228,7 @@ def extract_cell_level_stats_one(ephys_roi_id: str, if_generate_plots: bool = Tr if not if_generate_plots: return "Success", { "df_cell_stats": df_cell_stats, - "df_cell_representative_spike_waveforms": ( - df_cell_representative_spike_waveforms - ), + "df_cell_representative_spike_waveforms": (df_cell_representative_spike_waveforms), "df_cell_representative_first_spike_waveforms": ( df_cell_representative_first_spike_waveforms ), @@ -310,20 +278,14 @@ def extract_cell_level_stats_one(ephys_roi_id: str, if_generate_plots: bool = Tr sweeps_to_show=to_plot["sweeps"], spikes_to_show=to_plot["spikes"], info_text=info_text, - region_color=REGION_COLOR_MAPPER.get( - df_this["injection region"].lower(), "black" - ), + region_color=REGION_COLOR_MAPPER.get(df_this["injection region"].lower(), "black"), ) - logger.info( - f"Successfully generated cell-level summary plots for {ephys_roi_id}!" - ) + logger.info(f"Successfully generated cell-level summary plots for {ephys_roi_id}!") return "Success", { "df_cell_stats": df_cell_stats, - "df_cell_representative_spike_waveforms": ( - df_cell_representative_spike_waveforms - ), + "df_cell_representative_spike_waveforms": (df_cell_representative_spike_waveforms), "df_cell_representative_first_spike_waveforms": ( df_cell_representative_first_spike_waveforms ), @@ -337,15 +299,11 @@ def extract_cell_level_stats_one(ephys_roi_id: str, if_generate_plots: bool = Tr except Exception as e: import traceback - error_message = ( - f"Error processing {ephys_roi_id}: {str(e)}\n{traceback.format_exc()}" - ) + error_message = f"Error processing {ephys_roi_id}: {str(e)}\n{traceback.format_exc()}" logger.error(error_message) return error_message if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - status, cell_stats = extract_cell_level_stats_one( - "1428138882", if_generate_plots=True - ) + status, cell_stats = extract_cell_level_stats_one("1428138882", if_generate_plots=True) diff --git a/src/LCNE_patchseq_analysis/figures/__init__.py b/src/LCNE_patchseq_analysis/figures/__init__.py index 6846c7a..796a3db 100644 --- a/src/LCNE_patchseq_analysis/figures/__init__.py +++ b/src/LCNE_patchseq_analysis/figures/__init__.py @@ -1,9 +1,24 @@ """Init figures package""" +from pathlib import Path + import matplotlib as mpl +from matplotlib import font_manager + +REPO_ROOT = Path(__file__).resolve().parents[3] +FONT_CANDIDATES = [ + Path("/data/fonts/Helvetica.ttf"), + REPO_ROOT / "assets" / "fonts" / "Helvetica.ttf", +] +for font_path in FONT_CANDIDATES: + if font_path.exists(): + font_manager.fontManager.addfont(str(font_path)) + break + +DEFAULT_FONT_FAMILY = "Helvetica" mpl.rcParams["svg.fonttype"] = "none" -mpl.rcParams["font.family"] = "Helvetica" +mpl.rcParams["font.family"] = DEFAULT_FONT_FAMILY GLOBAL_FILTER = ( "(`jem-status_reporter` == 'Positive') & " @@ -84,7 +99,7 @@ def _region_sort_key(region): return sorted(region, key=_region_sort_key) -def set_plot_style(base_size: int = 11, font_family: str = "Helvetica"): +def set_plot_style(base_size: int = 11, font_family: str = DEFAULT_FONT_FAMILY): # Seaborn first (it may overwrite some rcParams) # sns.set_theme(context="paper", style="white", font_scale=1.0) mpl.rcParams.update( diff --git a/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_eFEL.ipynb b/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_eFEL.ipynb deleted file mode 100644 index 1a6d3ea..0000000 --- a/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_eFEL.ipynb +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b678dfe6b25bef43a99567b623ca98418c927d18e5d935cca231ec5168aaa3bc -size 1166835 diff --git a/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_waveform.ipynb b/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_waveform.ipynb deleted file mode 100644 index 1cec896..0000000 --- a/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_waveform.ipynb +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ed0daaff92439816a1e0d074f710ae58b147023521ee34f6eaef960f679df545 -size 202883 diff --git a/src/LCNE_patchseq_analysis/figures/cached/__init__.py b/src/LCNE_patchseq_analysis/figures/cached/__init__.py new file mode 100644 index 0000000..1f0f66c --- /dev/null +++ b/src/LCNE_patchseq_analysis/figures/cached/__init__.py @@ -0,0 +1 @@ +"""Cached figure scripts moved from figures/ root.""" diff --git a/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_eFEL.py b/src/LCNE_patchseq_analysis/figures/cached/asymmetry_index_from_eFEL.py similarity index 92% rename from src/LCNE_patchseq_analysis/figures/asymmetry_index_from_eFEL.py rename to src/LCNE_patchseq_analysis/figures/cached/asymmetry_index_from_eFEL.py index a499bbd..06d21df 100644 --- a/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_eFEL.py +++ b/src/LCNE_patchseq_analysis/figures/cached/asymmetry_index_from_eFEL.py @@ -7,7 +7,9 @@ from LCNE_patchseq_analysis import RESULTS_DIRECTORY from LCNE_patchseq_analysis.data_util.metadata import load_ephys_metadata from LCNE_patchseq_analysis.figures import GLOBAL_FILTER, set_plot_style -from LCNE_patchseq_analysis.figures.fig_3c import _generate_multi_feature_scatter_plots +from LCNE_patchseq_analysis.figures.cached.fig_3c import ( + _generate_multi_feature_scatter_plots, +) from LCNE_patchseq_analysis.figures.util import save_figure from LCNE_patchseq_analysis.population_analysis.anova import anova_features @@ -65,9 +67,7 @@ def main(if_save_figure: bool = True): # 2) Compute asymmetry columns from all available rise/fall eFEL pairs. add_efel_asymmetry_columns(df_meta_filtered) - asymmetry_cols = sorted( - [col for col in df_meta_filtered.columns if "_asymmetry @" in col] - ) + asymmetry_cols = sorted([col for col in df_meta_filtered.columns if "_asymmetry @" in col]) asymmetry_features = [{col: format_asymmetry_name(col)} for col in asymmetry_cols] # 3) Run ANCOVA to test projection effects while controlling for y. @@ -90,9 +90,7 @@ def main(if_save_figure: bool = True): n_cols=4, ) if if_save_figure: - output_dir = os.path.join( - RESULTS_DIRECTORY, "figures", "asymmetry_eFEL_summary" - ) + output_dir = os.path.join(RESULTS_DIRECTORY, "figures", "asymmetry_eFEL_summary") save_figure( fig, output_dir=output_dir, diff --git a/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_waveform.py b/src/LCNE_patchseq_analysis/figures/cached/asymmetry_index_from_waveform.py similarity index 90% rename from src/LCNE_patchseq_analysis/figures/asymmetry_index_from_waveform.py rename to src/LCNE_patchseq_analysis/figures/cached/asymmetry_index_from_waveform.py index e6490c7..bbe1609 100644 --- a/src/LCNE_patchseq_analysis/figures/asymmetry_index_from_waveform.py +++ b/src/LCNE_patchseq_analysis/figures/cached/asymmetry_index_from_waveform.py @@ -1,9 +1,9 @@ import logging import os +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import matplotlib.pyplot as plt from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler @@ -15,10 +15,14 @@ GLOBAL_FILTER, set_plot_style, ) +from LCNE_patchseq_analysis.figures.cached.fig_3c import ( + _generate_multi_feature_scatter_plots, +) from LCNE_patchseq_analysis.figures.util import generate_violin_plot, save_figure -from LCNE_patchseq_analysis.figures.fig_3c import _generate_multi_feature_scatter_plots -from LCNE_patchseq_analysis.pipeline_util.s3 import get_public_representative_spikes -from LCNE_patchseq_analysis.pipeline_util.s3 import load_mesh_from_s3 +from LCNE_patchseq_analysis.pipeline_util.s3 import ( + get_public_representative_spikes, + load_mesh_from_s3, +) from LCNE_patchseq_analysis.population_analysis.anova import anova_features logger = logging.getLogger(__name__) @@ -38,9 +42,7 @@ def compute_spike_metrics(times_ms: np.ndarray, voltages: np.ndarray) -> dict: peak_deriv_idx = int(np.argmax(dv)) post_peak_slice = dv[peak_deriv_idx + 1 :] trough_deriv_idx = ( - int(peak_deriv_idx + 1 + np.argmin(post_peak_slice)) - if post_peak_slice.size > 0 - else None + int(peak_deriv_idx + 1 + np.argmin(post_peak_slice)) if post_peak_slice.size > 0 else None ) fall_time_ms = ( @@ -52,9 +54,7 @@ def compute_spike_metrics(times_ms: np.ndarray, voltages: np.ndarray) -> dict: peak_idx = int(np.argmax(voltages)) post_peak_voltage = voltages[peak_idx + 1 :] trough_idx = ( - int(peak_idx + 1 + np.argmin(post_peak_voltage)) - if post_peak_voltage.size > 0 - else None + int(peak_idx + 1 + np.argmin(post_peak_voltage)) if post_peak_voltage.size > 0 else None ) kick_idx = None @@ -77,16 +77,10 @@ def compute_spike_metrics(times_ms: np.ndarray, voltages: np.ndarray) -> dict: else: kick_idx = int(local_maxima[-1]) - rise_time_ms = ( - times_ms[peak_deriv_idx] - times_ms[kick_idx] - if kick_idx is not None - else np.nan - ) + rise_time_ms = times_ms[peak_deriv_idx] - times_ms[kick_idx] if kick_idx is not None else np.nan time_asymmetry = ( - rise_time_ms / fall_time_ms - if np.isfinite(fall_time_ms) and fall_time_ms != 0 - else np.nan + rise_time_ms / fall_time_ms if np.isfinite(fall_time_ms) and fall_time_ms != 0 else np.nan ) return { @@ -105,9 +99,7 @@ def compute_spike_metrics(times_ms: np.ndarray, voltages: np.ndarray) -> dict: def compute_metrics_for_table(df_spikes: pd.DataFrame, spike_type: str): """Compute timing metrics and keep waveform vectors for PCA.""" df_spikes = df_spikes.copy() - df_spikes.index = df_spikes.index.set_levels( - df_spikes.index.levels[0].astype(str), level=0 - ) + df_spikes.index = df_spikes.index.set_levels(df_spikes.index.levels[0].astype(str), level=0) times_ms = df_spikes.columns.to_numpy(dtype=float) records = [] waveforms = [] @@ -163,14 +155,10 @@ def add_pca_variants(df_metrics: pd.DataFrame, ephys_cols: list[str]) -> pd.Data idx = df_group.index waveforms = np.vstack(df_group["waveform_vector"].to_numpy()) derivatives = np.vstack(df_group["dvdt_vector"].to_numpy()) - derived = df_group[["rise_time_ms", "fall_time_ms", "time_asymmetry"]].to_numpy( - dtype=float - ) + derived = df_group[["rise_time_ms", "fall_time_ms", "time_asymmetry"]].to_numpy(dtype=float) if ephys_cols: ephys_values = ( - df_group[ephys_cols] - .apply(pd.to_numeric, errors="coerce") - .to_numpy(dtype=float) + df_group[ephys_cols].apply(pd.to_numeric, errors="coerce").to_numpy(dtype=float) ) else: ephys_values = None @@ -178,9 +166,7 @@ def add_pca_variants(df_metrics: pd.DataFrame, ephys_cols: list[str]) -> pd.Data df_metrics.loc[idx, "pc1_dvdt"] = compute_pc1(derivatives) derived_ephys = ( - np.column_stack([derived, ephys_values]) - if ephys_values is not None - else derived + np.column_stack([derived, ephys_values]) if ephys_values is not None else derived ) df_metrics.loc[idx, "pc1_derived_ephys"] = compute_pc1(derived_ephys) @@ -286,9 +272,7 @@ def sanitize_filename(text: str) -> str: def plot_summary_by_region(df_metrics: pd.DataFrame, output_dir: str) -> None: """Summarize time asymmetry and PC1 by injection region.""" os.makedirs(output_dir, exist_ok=True) - for (spike_type, extract_from), df_group in df_metrics.groupby( - ["spike_type", "extract_from"] - ): + for (spike_type, extract_from), df_group in df_metrics.groupby(["spike_type", "extract_from"]): df_group = df_group.dropna(subset=["injection region"]) if df_group.empty: continue @@ -340,9 +324,7 @@ def format_waveform_feature_name(col_name: str) -> str: def plot_waveform_asymmetry_grid(df_metrics: pd.DataFrame, output_dir: str) -> None: """Generate a multi-panel scatter grid for waveform time asymmetry.""" - df_base = df_metrics[["ephys_roi_id", "injection region", "y"]].drop_duplicates( - "ephys_roi_id" - ) + df_base = df_metrics[["ephys_roi_id", "injection region", "y"]].drop_duplicates("ephys_roi_id") df_features = df_metrics[ ["ephys_roi_id", "extract_from", "spike_type", "time_asymmetry"] ].copy() @@ -463,9 +445,7 @@ def plot_pc_on_mesh(df_metrics: pd.DataFrame, output_dir: str) -> None: def run_anova_by_group(df_metrics: pd.DataFrame) -> pd.DataFrame: """Run ANCOVA for time asymmetry and PC1 across groups.""" results = [] - for (spike_type, extract_from), df_group in df_metrics.groupby( - ["spike_type", "extract_from"] - ): + for (spike_type, extract_from), df_group in df_metrics.groupby(["spike_type", "extract_from"]): features = ["time_asymmetry", "pc1_all"] df_anova = anova_features( df_group, @@ -503,18 +483,14 @@ def main(): metrics_all.append(df_metrics) # 2) Render a subset of diagnostic examples for manual inspection. - example_pool = df_metrics.dropna( - subset=["kick_idx", "peak_deriv_idx", "trough_deriv_idx"] - ) + example_pool = df_metrics.dropna(subset=["kick_idx", "peak_deriv_idx", "trough_deriv_idx"]) if example_pool.empty: example_rows = df_metrics.head(0) else: example_rows = example_pool.sample( n=min(example_count, len(example_pool)), random_state=42 ) - output_dir = os.path.join( - RESULTS_DIRECTORY, "figures", "asymmetry_waveform_examples" - ) + output_dir = os.path.join(RESULTS_DIRECTORY, "figures", "asymmetry_waveform_examples") for _, row in example_rows.iterrows(): plot_example_spike( df_spikes, @@ -551,9 +527,7 @@ def main(): # 6) Generate summary plots, mesh projections, and a large asymmetry grid. plot_summary_by_region( df_metrics_all, - output_dir=os.path.join( - RESULTS_DIRECTORY, "figures", "asymmetry_waveform_summary" - ), + output_dir=os.path.join(RESULTS_DIRECTORY, "figures", "asymmetry_waveform_summary"), ) plot_pc_on_mesh( df_metrics_all, @@ -561,20 +535,14 @@ def main(): ) plot_waveform_asymmetry_grid( df_metrics_all, - output_dir=os.path.join( - RESULTS_DIRECTORY, "figures", "asymmetry_waveform_summary" - ), + output_dir=os.path.join(RESULTS_DIRECTORY, "figures", "asymmetry_waveform_summary"), ) # 7) Save metrics and ANCOVA results. output_dir = os.path.join(RESULTS_DIRECTORY, "analysis") os.makedirs(output_dir, exist_ok=True) - df_metrics_all.to_csv( - os.path.join(output_dir, "waveform_asymmetry_metrics.csv"), index=False - ) - df_anova.to_csv( - os.path.join(output_dir, "waveform_asymmetry_anova.csv"), index=False - ) + df_metrics_all.to_csv(os.path.join(output_dir, "waveform_asymmetry_metrics.csv"), index=False) + df_anova.to_csv(os.path.join(output_dir, "waveform_asymmetry_anova.csv"), index=False) logger.info("Saved metrics to %s", output_dir) diff --git a/src/LCNE_patchseq_analysis/figures/fig_3a.py b/src/LCNE_patchseq_analysis/figures/cached/fig_3a.py old mode 100755 new mode 100644 similarity index 92% rename from src/LCNE_patchseq_analysis/figures/fig_3a.py rename to src/LCNE_patchseq_analysis/figures/cached/fig_3a.py index 5df4bc7..974b9dd --- a/src/LCNE_patchseq_analysis/figures/fig_3a.py +++ b/src/LCNE_patchseq_analysis/figures/cached/fig_3a.py @@ -4,7 +4,11 @@ from LCNE_patchseq_analysis import REGION_COLOR_MAPPER from LCNE_patchseq_analysis.data_util.metadata import load_ephys_metadata -from LCNE_patchseq_analysis.figures.util import generate_ccf_plot, generate_violin_plot, save_figure +from LCNE_patchseq_analysis.figures.util import ( + generate_ccf_plot, + generate_violin_plot, + save_figure, +) # Configure logging logger = logging.getLogger() @@ -27,7 +31,12 @@ def figure_3a_ccf_sagittal( """ fig, ax = generate_ccf_plot( - df_meta, filter_query, view="sagittal", ax=ax, show_marginal_x=True, show_marginal_y=True + df_meta, + filter_query, + view="sagittal", + ax=ax, + show_marginal_x=True, + show_marginal_y=True, ) if if_save_figure: diff --git a/src/LCNE_patchseq_analysis/figures/fig_3b.py b/src/LCNE_patchseq_analysis/figures/cached/fig_3b.py similarity index 100% rename from src/LCNE_patchseq_analysis/figures/fig_3b.py rename to src/LCNE_patchseq_analysis/figures/cached/fig_3b.py diff --git a/src/LCNE_patchseq_analysis/figures/fig_3c.py b/src/LCNE_patchseq_analysis/figures/cached/fig_3c.py similarity index 93% rename from src/LCNE_patchseq_analysis/figures/fig_3c.py rename to src/LCNE_patchseq_analysis/figures/cached/fig_3c.py index c757816..21aeb4a 100644 --- a/src/LCNE_patchseq_analysis/figures/fig_3c.py +++ b/src/LCNE_patchseq_analysis/figures/cached/fig_3c.py @@ -159,7 +159,10 @@ def _generate_multi_feature_scatter_plots( # noqa: C901 def figure_3c_tau_comparison( - df_meta: pd.DataFrame, filter_query: str | None = None, if_save_figure: bool = True, ax=None + df_meta: pd.DataFrame, + filter_query: str | None = None, + if_save_figure: bool = True, + ax=None, ): """ Generate and save violin plot for ipfx_tau grouped by injection region (Figure 3B). @@ -205,7 +208,10 @@ def figure_3c_tau_comparison( def figure_3c_latency_comparison( - df_meta: pd.DataFrame, filter_query: str | None = None, if_save_figure: bool = True, ax=None + df_meta: pd.DataFrame, + filter_query: str | None = None, + if_save_figure: bool = True, + ax=None, ): """ Generate and save violin plot for ipfx_latency grouped by injection region (Figure 3B). @@ -240,13 +246,21 @@ def figure_3c_latency_comparison( ax.set_ylabel("Latency to first spike\nat rheobase (s)") if if_save_figure: - save_figure(fig, filename="fig_3c_violinplot_ipfx_latency", dpi=300, formats=("png", "svg")) + save_figure( + fig, + filename="fig_3c_violinplot_ipfx_latency", + dpi=300, + formats=("png", "svg"), + ) print("Figure saved as fig_3c_violinplot_ipfx_latency.png/.svg") return fig, ax def sup_figure_3c_all_ipfx_features( - df_meta: pd.DataFrame, filter_query: str | None = None, if_save_figure: bool = True, ax=None + df_meta: pd.DataFrame, + filter_query: str | None = None, + if_save_figure: bool = True, + ax=None, ): """ Generate and save scatter plots for all ipfx features vs anatomical y coordinate. @@ -280,7 +294,10 @@ def sup_figure_3c_all_ipfx_features( def sup_figure_3b_all_gene_features( - df_meta: pd.DataFrame, filter_query: str | None = None, if_save_figure: bool = True, ax=None + df_meta: pd.DataFrame, + filter_query: str | None = None, + if_save_figure: bool = True, + ax=None, ): """ Generate and save scatter plots for all gene features vs anatomical y coordinate. @@ -313,7 +330,10 @@ def sup_figure_3b_all_gene_features( def sup_figure_3d_morphology( - df_meta: pd.DataFrame, filter_query: str | None = None, if_save_figure: bool = True, ax=None + df_meta: pd.DataFrame, + filter_query: str | None = None, + if_save_figure: bool = True, + ax=None, ): """ Generate and save scatter plots for all morphology features vs anatomical y coordinate. diff --git a/src/LCNE_patchseq_analysis/figures/main_figure.py b/src/LCNE_patchseq_analysis/figures/cached/main_figure.py similarity index 88% rename from src/LCNE_patchseq_analysis/figures/main_figure.py rename to src/LCNE_patchseq_analysis/figures/cached/main_figure.py index 48fa6ec..241b9f9 100644 --- a/src/LCNE_patchseq_analysis/figures/main_figure.py +++ b/src/LCNE_patchseq_analysis/figures/cached/main_figure.py @@ -1,15 +1,15 @@ import matplotlib.pyplot as plt from LCNE_patchseq_analysis.figures import set_plot_style -from LCNE_patchseq_analysis.figures.fig_3a import ( +from LCNE_patchseq_analysis.figures.cached.fig_3a import ( figure_3a_ccf_sagittal, sup_figure_3a_ccf_coronal, ) -from LCNE_patchseq_analysis.figures.fig_3b import ( +from LCNE_patchseq_analysis.figures.cached.fig_3b import ( figure_3b_imputed_MERFISH, figure_3b_imputed_scRNAseq, ) -from LCNE_patchseq_analysis.figures.fig_3c import figure_3c_tau_comparison +from LCNE_patchseq_analysis.figures.cached.fig_3c import figure_3c_tau_comparison from LCNE_patchseq_analysis.figures.util import save_figure set_plot_style(base_size=12, font_family="Helvetica") @@ -67,14 +67,17 @@ def generate_main_figure( if if_save_figure: save_figure( - fig, filename="main_figure", dpi=300, formats=("png", "svg"), bbox_inches="tight" + fig, + filename="main_figure", + dpi=300, + formats=("png", "svg"), + bbox_inches="tight", ) print("Figure saved as main_figure.png/.svg") return fig if __name__ == "__main__": - from LCNE_patchseq_analysis.data_util.metadata import load_ephys_metadata df_meta = load_ephys_metadata(if_from_s3=True, if_with_seq=True) diff --git a/src/LCNE_patchseq_analysis/figures/sup_figures.py b/src/LCNE_patchseq_analysis/figures/cached/sup_figures.py similarity index 80% rename from src/LCNE_patchseq_analysis/figures/sup_figures.py rename to src/LCNE_patchseq_analysis/figures/cached/sup_figures.py index 10adc7f..4dd3c96 100644 --- a/src/LCNE_patchseq_analysis/figures/sup_figures.py +++ b/src/LCNE_patchseq_analysis/figures/cached/sup_figures.py @@ -1,7 +1,7 @@ """Aggregated access to supplemental figure generators.""" -from LCNE_patchseq_analysis.figures.fig_3a import sup_figure_3a_ccf_coronal # noqa: F401 -from LCNE_patchseq_analysis.figures.fig_3c import ( # noqa: F401 +from LCNE_patchseq_analysis.figures.cached.fig_3a import sup_figure_3a_ccf_coronal # noqa: F401 +from LCNE_patchseq_analysis.figures.cached.fig_3c import ( # noqa: F401 sup_figure_3b_all_gene_features, sup_figure_3c_all_ipfx_features, sup_figure_3d_morphology, diff --git a/src/LCNE_patchseq_analysis/figures/main_pca_tau.py b/src/LCNE_patchseq_analysis/figures/main_pca_tau.py new file mode 100644 index 0000000..5c999fa --- /dev/null +++ b/src/LCNE_patchseq_analysis/figures/main_pca_tau.py @@ -0,0 +1,639 @@ +"""Spike waveform PCA analysis and figure generation. + +Reproduces the "Spike Analysis" panel from LCNE-patchseq-viz as static +matplotlib figures, using the same default settings. + +Interactive version: +https://hanhou-patchseq.hf.space/patchseq_panel_viz?tab=1 +""" + +import logging + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from scipy.stats import mannwhitneyu, ttest_ind +from sklearn.decomposition import PCA + +from LCNE_patchseq_analysis import REGION_COLOR_MAPPER +from LCNE_patchseq_analysis.data_util.mesh import plot_mesh +from LCNE_patchseq_analysis.figures import set_plot_style, sort_region +from LCNE_patchseq_analysis.figures.util import save_figure +from LCNE_patchseq_analysis.pipeline_util.s3 import ( + get_public_representative_spikes, + load_mesh_from_s3, +) +from LCNE_patchseq_analysis.population_analysis.spikes import ( + extract_representative_spikes, +) + +logger = logging.getLogger(__name__) + +# --- Default parameters matching the viz app widget defaults --- +DEFAULT_SPIKE_TYPE = "average" +DEFAULT_EXTRACT_FROM = "long_square_rheo, min" +DEFAULT_SPIKE_RANGE = (-3, 6) # ms +DEFAULT_NORMALIZE_WINDOW_V = (-2, 4) # ms +DEFAULT_NORMALIZE_WINDOW_DVDT = (-2, 0) # ms + +SPINAL_REGIONS = ["C5", "Spinal cord"] +CORTEX_REGIONS = ["Cortex", "PL", "PL, MOs"] +CB_REGIONS = ["Crus 1", "Cerebellum", "CB"] + + +def spike_pca_analysis( + df_meta: pd.DataFrame, + df_spikes: pd.DataFrame | None = None, + spike_type: str = DEFAULT_SPIKE_TYPE, + extract_from: str = DEFAULT_EXTRACT_FROM, + spike_range: tuple = DEFAULT_SPIKE_RANGE, + normalize_window_v: tuple = DEFAULT_NORMALIZE_WINDOW_V, + normalize_window_dvdt: tuple = DEFAULT_NORMALIZE_WINDOW_DVDT, + filtered_df_meta: pd.DataFrame | None = None, +) -> dict: + """Run PCA on spike waveforms. + + Parameters + ---------- + df_meta : pd.DataFrame + Metadata with columns including 'ephys_roi_id', 'injection region', + 'X (A --> P)', 'Y (D --> V)', and optionally ipfx_tau columns. + df_spikes : pd.DataFrame, optional + Pre-loaded spike waveforms. If None, loads from S3 using spike_type. + spike_type : str + Which spike to use: "average", "first", "second", or "last". + extract_from : str + Stimulus source to extract spikes from. + spike_range : tuple + (start, end) in ms — time window of spike waveform to analyse. + normalize_window_v : tuple + (start, end) in ms — window for V min-max normalization. + normalize_window_dvdt : tuple + (start, end) in ms — window for dV/dt min-max normalization. + filtered_df_meta : pd.DataFrame, optional + If provided, restrict analysis to these cells. + + Returns + ------- + dict with keys: df_v_proj, df_v_norm, reducer, tau_col + """ + if df_spikes is None: + df_spikes = get_public_representative_spikes(spike_type) + + # Extract normalised, peak-aligned waveforms + df_v_norm, _df_dvdt_norm = extract_representative_spikes( + df_spikes=df_spikes, + extract_from=extract_from, + if_normalize_v=True, + normalize_window_v=normalize_window_v, + if_normalize_dvdt=True, + normalize_window_dvdt=normalize_window_dvdt, + if_smooth_dvdt=False, + if_align_dvdt_peaks=True, + filtered_df_meta=filtered_df_meta if filtered_df_meta is not None else df_meta, + ) + + # Filter to spike_range (cast columns to numeric for static type checkers) + time_cols = np.asarray( + pd.to_numeric(df_v_norm.columns, errors="coerce"), dtype=float + ) + in_window = ( + np.isfinite(time_cols) + & (time_cols >= spike_range[0]) + & (time_cols <= spike_range[1]) + ) + df_v_norm = df_v_norm.loc[:, in_window] + + # PCA + v = df_v_norm.values + reducer = PCA() + v_proj = reducer.fit_transform(v) + n_components = 5 + columns = [f"PCA{i}" for i in range(1, n_components + 1)] + df_v_proj = pd.DataFrame(v_proj[:, :n_components], index=df_v_norm.index) + df_v_proj.columns = columns + + # Merge metadata + tau_col = [c for c in df_meta.columns if "ipfx_tau" in c][0] + x_col, y_col = ("X (A --> P)", "Y (D --> V)") + if x_col not in df_meta.columns: + x_col, y_col = ("x", "y") + + merge_cols = ["ephys_roi_id", "injection region", x_col, y_col, tau_col] + df_v_proj = df_v_proj.merge(df_meta[merge_cols], on="ephys_roi_id", how="left") + + if x_col == "x": + df_v_proj = df_v_proj.rename(columns={"x": "X (A --> P)", "y": "Y (D --> V)"}) + + # ipfx_tau is in seconds; convert to milliseconds for plotting. + tau_ms_col = "membrane_time_constant_ms" + tau_vals = np.asarray( + pd.to_numeric(df_v_proj[tau_col], errors="coerce"), dtype=float + ) + df_v_proj[tau_ms_col] = tau_vals * 1000.0 + + return { + "df_v_proj": df_v_proj, + "df_v_norm": df_v_norm, + "reducer": reducer, + "tau_col": tau_ms_col, + } + + +# --------------------------------------------------------------------------- +# Private plotting helpers +# --------------------------------------------------------------------------- + + +def _plot_pca_scatter(ax, df_v_proj, marker_size=50): + """PC1 vs PC2 scatter coloured by projection target.""" + for region in df_v_proj["injection region"].unique(): + sub = df_v_proj.query("`injection region` == @region") + color = REGION_COLOR_MAPPER.get(region, "gray") + ax.scatter( + sub["PCA1"], + sub["PCA2"], + c=color, + s=marker_size, + alpha=0.8, + label=f"{region} (n={len(sub)})", + edgecolors="none", + ) + + ax.set_xlabel("PC1") + ax.set_ylabel("PC2") + ax.set_aspect("equal") + ax.legend(fontsize=6, loc="best", framealpha=0.5) + sns.despine(ax=ax, trim=True) + + +def _plot_waveform_overlay(ax, df_v_norm, df_v_proj): + """Overlay normalized waveforms for the three projection targets.""" + x = np.asarray(pd.to_numeric(df_v_norm.columns, errors="coerce"), dtype=float) + id_col = "ephys_roi_id" if "ephys_roi_id" in df_v_proj.columns else None + + for label, region_set, color in [ + ("Spinal cord", SPINAL_REGIONS, REGION_COLOR_MAPPER["Spinal cord"]), + ("Cortex", CORTEX_REGIONS, REGION_COLOR_MAPPER["Cortex"]), + ("Cerebellum", CB_REGIONS, REGION_COLOR_MAPPER["Cerebellum"]), + ]: + mask = df_v_proj["injection region"].isin(region_set) + ids = ( + df_v_proj.loc[mask, id_col].to_numpy() + if id_col + else df_v_proj.index.to_numpy() + ) + traces = df_v_norm.loc[df_v_norm.index.isin(ids)] + y = traces.to_numpy(dtype=float) + + for trace in y: + ax.plot(x, trace, color=color, alpha=0.2, linewidth=1) + + if len(y) > 0: + ax.plot( + x, + np.nanmean(y, axis=0), + color=color, + linewidth=3, + label=f"{label} (n={len(y)})", + ) + + ax.set_xlabel("Time to peak (ms)") + ax.set_ylabel("") + ax.set_yticks([]) + ax.legend(fontsize=6, loc="best", framealpha=0.6) + sns.despine(ax=ax, trim=True) + + +def _plot_group_hist(ax, groups, bins=18, alpha=1.0): + """Overlayed histograms for each projection group. + + Parameters + ---------- + groups : list of (label, data_array, color) tuples + """ + all_values = np.concatenate([data for _, data, _ in groups]) + shared_bins = np.linspace(all_values.min(), all_values.max(), bins + 1) + bin_width = shared_bins[1] - shared_bins[0] + centers = (shared_bins[:-1] + shared_bins[1:]) / 2 + + n_groups = len(groups) + bar_width = bin_width * 0.3 + offsets = (np.arange(n_groups) - (n_groups - 1) / 2) * (bar_width * 1.15) + + for i, (label, data, color) in enumerate(groups): + counts, _ = np.histogram(data, bins=shared_bins) + ax.bar( + centers + offsets[i], + counts, + width=bar_width, + alpha=alpha, + color=color, + edgecolor="none", + label=f"{label} (n={len(data)})", + align="center", + ) + + ax.set_ylabel("Count") + ax.legend(loc="best", framealpha=0.6) + sns.despine(ax=ax, trim=True) + + +def _plot_group_cdf(ax, groups): + """Separate cumulative distribution plot for each projection group.""" + for label, data, color in groups: + sorted_data = np.sort(data) + cdf = np.arange(1, len(sorted_data) + 1) / len(sorted_data) + ax.step( + sorted_data, + cdf, + where="post", + color=color, + linewidth=1.6, + alpha=1.0, + label=f"{label} (n={len(data)})", + ) + + ax.set_ylim(0, 1) + ax.set_ylabel("Cumulative fraction") + ax.legend(loc="best", framealpha=0.6) + sns.despine(ax=ax, trim=True) + + +def _fmt_pval(p): + if pd.isna(p): + return "n/a" + stars = "***" if p < 1e-3 else "**" if p < 1e-2 else "*" if p < 5e-2 else "" + p_txt = f"{p:.2e}" if p < 1e-3 else f"{p:.3f}" + return f"{p_txt}{stars}" + + +def _plot_pairwise_stats_table(ax, groups, title): + """Pairwise rank-sum and t-test p-values table.""" + rows = [] + for i in range(len(groups)): + for j in range(i + 1, len(groups)): + label_i, data_i, _ = groups[i] + label_j, data_j, _ = groups[j] + + if len(data_i) < 2 or len(data_j) < 2: + mw_p = np.nan + tt_p = np.nan + else: + _, mw_p = mannwhitneyu(data_i, data_j, alternative="two-sided") + _, tt_p = ttest_ind(data_i, data_j, equal_var=False) + + rows.append([f"{label_i} vs {label_j}", _fmt_pval(mw_p), _fmt_pval(tt_p)]) + + ax.axis("off") + table = ax.table( + cellText=rows, + colLabels=["Pair", "Ranksum p", "t-test p"], + loc="center", + cellLoc="left", + colLoc="left", + colWidths=[0.55, 0.19, 0.19], + ) + table.auto_set_font_size(False) + table.set_fontsize(7) + table.scale(1.0, 1.15) + ax.set_title(title) + + +def _plot_violin_strip(ax, groups, marker_size=15, alpha=0.5, seed=42): + """Violin with jittered points for each projection group.""" + rng = np.random.default_rng(seed) + for idx, (_, data, color) in enumerate(groups): + vp = ax.violinplot( + dataset=[data], + positions=[idx], + widths=0.7, + showmeans=False, + showmedians=False, + showextrema=False, + ) + for body in vp["bodies"]: + body.set_facecolor(color) + body.set_edgecolor("black") + body.set_alpha(0.25) + + median = np.median(data) + ax.hlines(median, idx - 0.2, idx + 0.2, color="black", linewidth=2, zorder=4) + + jitter = rng.uniform(-0.15, 0.15, len(data)) + ax.scatter( + np.full(len(data), idx) + jitter, + data, + s=marker_size, + c=color, + alpha=alpha, + edgecolors="black", + linewidths=0.3, + zorder=3, + ) + + ax.set_xticks(range(len(groups))) + ax.set_xticklabels([f"{lbl}\n(n={len(d)})" for lbl, d, _ in groups]) + plt.setp(ax.get_xticklabels(), rotation=45, ha="right") + ax.set_xlim(-0.6, len(groups) - 0.4) + sns.despine(ax=ax, trim=True) + + +def _plot_spatial_map( + ax, + df_v_proj, + color_col, + cmap=None, + label=None, + marker_size=50, + use_projection_edgecolor=False, + reserve_colorbar_space=False, +): + """Scatter on LC mesh (sagittal view). + + - Continuous mode: color_col is numeric -> use colormap + colorbar. + - Categorical mode: color_col == 'injection region' -> use REGION_COLOR_MAPPER + legend. + """ + mesh = load_mesh_from_s3() + plot_mesh(ax, mesh, direction="sagittal", meshcol="lightgray") + + if color_col == "injection region": + regions = sort_region(df_v_proj["injection region"].dropna().unique()) + for region in regions: + sub = df_v_proj[df_v_proj["injection region"] == region] + color = REGION_COLOR_MAPPER.get(region, "gray") + ax.scatter( + sub["X (A --> P)"], + sub["Y (D --> V)"], + c=color, + s=marker_size, + edgecolors="black", + linewidths=1, + alpha=0.7, + label=f"{region} (n={len(sub)})", + ) + ax.legend(fontsize=6, loc="best", framealpha=0.6) + if reserve_colorbar_space: + dummy = plt.cm.ScalarMappable(cmap="Greys") + cbar = plt.colorbar(dummy, ax=ax, shrink=0.7) + cbar.set_ticks([]) + cbar.ax.set_ylabel("") + else: + vals = pd.to_numeric(df_v_proj[color_col], errors="coerce") + edgecolors = ( + [REGION_COLOR_MAPPER.get(r, "black") for r in df_v_proj["injection region"]] + if use_projection_edgecolor + else "black" + ) + sc = ax.scatter( + df_v_proj["X (A --> P)"], + df_v_proj["Y (D --> V)"], + c=vals, + cmap=cmap, + s=marker_size, + edgecolors=edgecolors, + linewidths=1, + alpha=0.7, + ) + plt.colorbar(sc, ax=ax, shrink=0.7, label=label) + + ax.set_xlabel("Anterior-posterior (μm)") + ax.set_ylabel("Dorsal-ventral (μm)") + ax.locator_params(axis="x", nbins=4) + ax.set_aspect("equal") + sns.despine(ax=ax, trim=True) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def figure_spike_pca( + df_meta: pd.DataFrame, + df_spikes: pd.DataFrame | None = None, + spike_type: str = DEFAULT_SPIKE_TYPE, + extract_from: str = DEFAULT_EXTRACT_FROM, + spike_range: tuple = DEFAULT_SPIKE_RANGE, + normalize_window_v: tuple = DEFAULT_NORMALIZE_WINDOW_V, + normalize_window_dvdt: tuple = DEFAULT_NORMALIZE_WINDOW_DVDT, + filtered_df_meta: pd.DataFrame | None = None, + use_projection_edgecolor: bool = False, + if_save_figure: bool = True, + figsize: tuple = (15, 20), +): + """Generate the spike PCA figure (4 rows, 13 panels). + + Row 1: PCA scatter | normalized waveform overlays + Row 2: PC1 violin | PC1 histogram | PC1 CDF | PC1 pairwise stats + Row 3: membrane time constant violin | membrane time constant histogram | + membrane time constant CDF | membrane time constant pairwise stats + Row 4: PC1 spatial | membrane time constant spatial | projection target spatial + + Parameters + ---------- + df_meta : pd.DataFrame + Metadata (from load_ephys_metadata). + df_spikes : pd.DataFrame, optional + Pre-loaded spike waveforms. Loaded from S3 if None. + spike_type, extract_from, spike_range, normalize_window_v, + normalize_window_dvdt : + Analysis parameters (see spike_pca_analysis). + filtered_df_meta : pd.DataFrame, optional + Cell-level filter. + use_projection_edgecolor : bool + Whether to use projection-target color on marker edges in CCF plots. + if_save_figure : bool + Whether to save figure to results/figures/. + figsize : tuple + Figure size in inches. + + Returns + ------- + (fig, axes_dict, results) + """ + set_plot_style(base_size=12) + + results = spike_pca_analysis( + df_meta=df_meta, + df_spikes=df_spikes, + spike_type=spike_type, + extract_from=extract_from, + spike_range=spike_range, + normalize_window_v=normalize_window_v, + normalize_window_dvdt=normalize_window_dvdt, + filtered_df_meta=filtered_df_meta, + ) + df_v_proj = results["df_v_proj"] + df_v_norm = results["df_v_norm"] + tau_col = results["tau_col"] + + # --- Build figure --- + fig = plt.figure(figsize=figsize) + gs = fig.add_gridspec(4, 1, height_ratios=[1.2, 1, 1, 1.5], hspace=0.5) + gs_row1 = gs[0].subgridspec(1, 3, width_ratios=[0.7, 1, 1], wspace=0.3) + # Narrower middle rows with side spacers. + gs_row2 = gs[1].subgridspec(1, 4, width_ratios=[0.4, 0.55, 0.55, 0.7], wspace=0.35) + gs_row3 = gs[2].subgridspec(1, 4, width_ratios=[0.4, 0.55, 0.55, 0.7], wspace=0.35) + gs_row4 = gs[3].subgridspec(1, 3, width_ratios=[1, 1, 1], wspace=0.3) + axes = [ + fig.add_subplot(gs_row1[0, 0]), + fig.add_subplot(gs_row1[0, 1]), + fig.add_subplot(gs_row2[0, 0]), + fig.add_subplot(gs_row2[0, 1]), + fig.add_subplot(gs_row2[0, 2]), + fig.add_subplot(gs_row2[0, 3]), + fig.add_subplot(gs_row3[0, 0]), + fig.add_subplot(gs_row3[0, 1]), + fig.add_subplot(gs_row3[0, 2]), + fig.add_subplot(gs_row3[0, 3]), + fig.add_subplot(gs_row4[0, 0]), + fig.add_subplot(gs_row4[0, 1]), + fig.add_subplot(gs_row4[0, 2]), + ] + axes_dict = {} + + # Panel 1: overlaid normalized waveforms by projection target + ax = axes[0] + axes_dict["waveform_overlay"] = ax + _plot_waveform_overlay(ax, df_v_norm, df_v_proj) + ax.set_title("Normalized spike waveforms") + + # Panel 2: PCA scatter + ax = axes[1] + axes_dict["pca_scatter"] = ax + _plot_pca_scatter(ax, df_v_proj) + ax.set_title("PC1 vs PC2 from normalized spike waveforms") + + pc1_groups = _build_projection_groups(df_v_proj, "PCA1") + tau_groups = _build_projection_groups(df_v_proj, tau_col) + + # Panel 3: PC1 violin by projection target + ax = axes[2] + axes_dict["pca1_violin"] = ax + _plot_violin_strip(ax, pc1_groups) + ax.set_title("PC1") + ax.set_ylabel("PC1") + + # Panel 4: PC1 histogram by projection target + ax = axes[3] + axes_dict["pca1_hist"] = ax + _plot_group_hist(ax, pc1_groups) + ax.set_title("PC1") + ax.set_xlabel("PC1") + + # Panel 5: PC1 cumulative distribution by projection target + ax = axes[4] + axes_dict["pca1_cdf"] = ax + _plot_group_cdf(ax, pc1_groups) + ax.set_title("PC1") + ax.set_xlabel("PC1") + + # Panel 6: PC1 pairwise stats + ax = axes[5] + axes_dict["pca1_stats"] = ax + _plot_pairwise_stats_table(ax, pc1_groups, "PC1 pairwise stats") + + # Panel 7: membrane time constant violin by projection target + ax = axes[6] + axes_dict["tau_violin"] = ax + _plot_violin_strip(ax, tau_groups) + ax.set_title("Membrane time constant") + ax.set_ylabel("Membrane time constant (ms)") + + # Panel 8: membrane time constant histogram by projection target + ax = axes[7] + axes_dict["tau_hist"] = ax + _plot_group_hist(ax, tau_groups) + ax.set_title("Membrane time constant") + ax.set_xlabel("Membrane time constant (ms)") + + # Panel 9: membrane time constant cumulative distribution by projection target + ax = axes[8] + axes_dict["tau_cdf"] = ax + _plot_group_cdf(ax, tau_groups) + ax.set_title("Membrane time constant") + ax.set_xlabel("Membrane time constant (ms)") + + # Panel 10: membrane time constant pairwise stats + ax = axes[9] + axes_dict["tau_stats"] = ax + _plot_pairwise_stats_table(ax, tau_groups, "Membrane time constant pairwise stats") + + # Panel 11: PC1 in X/Y space with LC mesh + ax = axes[10] + axes_dict["pca1_spatial"] = ax + _plot_spatial_map( + ax, + df_v_proj, + "PCA1", + "RdBu_r", + "PC1", + use_projection_edgecolor=use_projection_edgecolor, + ) + ax.set_title("PC1 in CCF space") + + # Panel 12: membrane time constant in X/Y space with LC mesh + ax = axes[11] + axes_dict["tau_spatial"] = ax + _plot_spatial_map( + ax, + df_v_proj, + tau_col, + "inferno", + "Membrane time constant (ms)", + use_projection_edgecolor=use_projection_edgecolor, + ) + ax.set_title("Membrane time constant (ms) in CCF space") + + # Panel 13: projection target in X/Y space with LC mesh + ax = axes[12] + axes_dict["projection_spatial"] = ax + _plot_spatial_map(ax, df_v_proj, "injection region", reserve_colorbar_space=True) + ax.set_title("Projection target in CCF space") + + fig.tight_layout() + + if if_save_figure: + save_figure( + fig, filename="main_pca_tau", formats=("png", "svg"), bbox_inches="tight" + ) + + return fig, axes_dict, results + + +def _build_projection_groups(df_v_proj, value_col): + """Build (label, values, color) groups for spinal cord, cortex, and CB.""" + groups = [] + for grp_label, region_set, color in [ + ("Spinal cord", SPINAL_REGIONS, REGION_COLOR_MAPPER["Spinal cord"]), + ("Cortex", CORTEX_REGIONS, REGION_COLOR_MAPPER["Cortex"]), + ("Cerebellum", CB_REGIONS, REGION_COLOR_MAPPER["Cerebellum"]), + ]: + mask = df_v_proj["injection region"].isin(region_set) + vals_series = pd.Series( + pd.to_numeric(df_v_proj.loc[mask, value_col], errors="coerce") + ) + vals = vals_series.dropna().to_numpy() + groups.append((grp_label, vals, color)) + return groups + + +if __name__ == "__main__": + from LCNE_patchseq_analysis.data_util.metadata import load_ephys_metadata + from LCNE_patchseq_analysis.figures import GLOBAL_FILTER + + logging.basicConfig(level=logging.INFO) + logger.info("Loading metadata...") + df_meta = load_ephys_metadata(if_from_s3=True, if_with_seq=True) + df_meta = df_meta.query(GLOBAL_FILTER) + logger.info(f"Loaded metadata with shape: {df_meta.shape}") + + logger.info("Loading spike waveforms...") + df_spikes = get_public_representative_spikes("average") + + logger.info("Generating spike PCA figure...") + fig, axes_dict, results = figure_spike_pca( + df_meta, df_spikes, filtered_df_meta=df_meta + ) diff --git a/src/LCNE_patchseq_analysis/figures/util.py b/src/LCNE_patchseq_analysis/figures/util.py index f38c7cb..3bd0678 100644 --- a/src/LCNE_patchseq_analysis/figures/util.py +++ b/src/LCNE_patchseq_analysis/figures/util.py @@ -386,7 +386,13 @@ def generate_scatter_plot( # Plot confidence band ax.fill_between( - x_vals, ci_lower, ci_upper, color="lightgray", alpha=0.3, zorder=3, label="95% CI" + x_vals, + ci_lower, + ci_upper, + color="lightgray", + alpha=0.3, + zorder=3, + label="95% CI", ) ax.plot( @@ -458,7 +464,12 @@ def generate_ccf_plot( # NoQA: C901 view = (view or "").strip().lower() if view == "sagittal": - x_key, y_key, mesh_direction, x_label = "x", "y", "sagittal", "Anterior-posterior (μm)" + x_key, y_key, mesh_direction, x_label = ( + "x", + "y", + "sagittal", + "Anterior-posterior (μm)", + ) elif view == "coronal": x_key, y_key, mesh_direction, x_label = "z", "y", "coronal", "Left-right (μm)" else: @@ -518,7 +529,7 @@ def generate_ccf_plot( # NoQA: C901 for region in sorted_regions: color_key = region if region in REGION_COLOR_MAPPER else region.lower() color = REGION_COLOR_MAPPER.get(color_key, "gray") - label_text = f"{region} (n={sum(df_filtered['injection region']==region)})" # noqa: E225 + label_text = f"{region} (n={sum(df_filtered['injection region'] == region)})" # noqa: E225 legend_elements.append( Line2D( [0], @@ -583,7 +594,11 @@ def generate_ccf_plot( # NoQA: C901 def generate_violin_plot( - df_to_use: pd.DataFrame, y_col: str, color_col: str, color_palette_dict: dict, ax=None + df_to_use: pd.DataFrame, + y_col: str, + color_col: str, + color_palette_dict: dict, + ax=None, ): """ Create a violin plot to compare data distributions across groups using matplotlib/seaborn. @@ -734,7 +749,10 @@ def save_figure( List of saved file paths in the same order as formats. """ if output_dir is None: - output_dir = os.path.dirname(os.path.abspath(__file__)) + "/../../../results/figures" + if os.getenv("CO_CAPSULE_ID"): + output_dir = "/results" + else: + output_dir = os.path.dirname(os.path.abspath(__file__)) + "/../../../results/figures" os.makedirs(output_dir, exist_ok=True) diff --git a/src/LCNE_patchseq_analysis/pipeline_util/s3.py b/src/LCNE_patchseq_analysis/pipeline_util/s3.py index 792b8d9..3666d31 100644 --- a/src/LCNE_patchseq_analysis/pipeline_util/s3.py +++ b/src/LCNE_patchseq_analysis/pipeline_util/s3.py @@ -15,9 +15,7 @@ s3 = s3fs.S3FileSystem(anon=True) # All on public bucket -S3_PUBLIC_URL_BASE = ( - "https://aind-scratch-data.s3.us-west-2.amazonaws.com/aind-patchseq-data" -) +S3_PUBLIC_URL_BASE = "https://aind-scratch-data.s3.us-west-2.amazonaws.com/aind-patchseq-data" S3_PATH_BASE = "aind-scratch-data/aind-patchseq-data" @@ -124,9 +122,7 @@ def get_public_representative_spikes( filename = spike_map.get(spike_type) if filename is None: valid_types = ", ".join(sorted(spike_map)) - raise ValueError( - f"Unknown spike_type '{spike_type}'. Valid options: {valid_types}." - ) + raise ValueError(f"Unknown spike_type '{spike_type}'. Valid options: {valid_types}.") s3_url = f"{S3_PUBLIC_URL_BASE}/efel/cell_stats/{filename}" if check_s3_public_url_exists(s3_url): return pd.read_pickle(s3_url) @@ -163,9 +159,7 @@ def get_public_mapmycells(filename="mapmycells_20250618.csv"): # Add a new column "subclass_category" based on if "subclass_name" == "251 NTS Dbh Glut" df["subclass_category"] = df["subclass_name"].apply( - lambda x: "251 NTS Dbh Glut" - if x == "251 NTS Dbh Glut" - else "Non-Dbh cells" + lambda x: "251 NTS Dbh Glut" if x == "251 NTS Dbh Glut" else "Non-Dbh cells" ) return df except Exception as e: