Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/alphajudge/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def main() -> None:
action="store_true",
help="Do not write per-model PAE heatmap PNG files",
)
p.add_argument(
"--skip_biophysical_scores",
action="store_true",
help="Skip expensive biophysical calculations (hydrogen bonds, salt bridges, disulfides, shape complementarity, buried surface area, solvation energy) to save time",
)
p.add_argument(
"--per_run_csv_name",
default="interfaces.csv",
Expand All @@ -51,6 +56,7 @@ def main() -> None:
force_recompute=args.force_recompute,
per_run_csv_name=args.per_run_csv_name,
skip_pae_png=args.skip_pae_png,
skip_biophysical_scores=args.skip_biophysical_scores,
)
else:
p.error("Provide PATHS")
Expand Down
4 changes: 3 additions & 1 deletion src/alphajudge/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class Confidence:
pae_matrix: list[list[float]]
pae_matrix: np.ndarray # Keep as numpy array for memory efficiency (7-10x reduction)
max_pae: float
iptm: float | None
ptm: float | None
Expand Down
6 changes: 4 additions & 2 deletions src/alphajudge/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(self, chain1, chain2, complex_ctx: Complex):
self.chain2 = list(chain2)

if not self.chain1 or not self.chain2:
self._pae = np.asarray(self.c.conf.pae_matrix)
# pae_matrix is already a numpy array for memory efficiency
self._pae = self.c.conf.pae_matrix
self._rim = self.c._res_index_map
self._cid = self.c._chain_indices_by_id
self._cid1_id = ""
Expand All @@ -48,7 +49,8 @@ def __init__(self, chain1, chain2, complex_ctx: Complex):
self._avg_pae = 0.0
return

self._pae = np.asarray(self.c.conf.pae_matrix)
# pae_matrix is already a numpy array for memory efficiency
self._pae = self.c.conf.pae_matrix
self._rim = self.c._res_index_map
self._cid = self.c._chain_indices_by_id

Expand Down
2 changes: 1 addition & 1 deletion src/alphajudge/parsers/af2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def load_model(model: str):

plddt = self._plddt(chains, rim)
return struct, Confidence(
pae_matrix=pae.tolist(), max_pae=max_pae,
pae_matrix=pae, max_pae=max_pae,
iptm=iptm, ptm=ptm, iptm_ptm=iptm_ptm, confidence_score=conf,
plddt_residue=plddt,
)
Expand Down
4 changes: 2 additions & 2 deletions src/alphajudge/parsers/af3.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def pf(x: str | None) -> float:
return [f"seed-{r['seed']}_sample-{r['sample']}" for r in rows if 'seed' in r and 'sample' in r]

@staticmethod
def _normalize_pae_af3(matrix: dict, chains, cid) -> tuple[list[list[float]], float]:
def _normalize_pae_af3(matrix: dict, chains, cid) -> tuple[np.ndarray, float]:
total = sum(len(cid[c.id]) for c in chains)
pae = np.full((total, total), 100.0, dtype=float)
max_pae = float('nan')
Expand Down Expand Up @@ -109,4 +109,4 @@ def _normalize_pae_af3(matrix: dict, chains, cid) -> tuple[list[list[float]], fl
else:
raise ValueError("unknown AF3 confidences schema")

return pae.tolist(), float(max_pae)
return pae, float(max_pae)
73 changes: 43 additions & 30 deletions src/alphajudge/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def _save_pae_heatmap(
provided in `chain_boundaries` (typically between different chains).
"""
try:
mtx = np.array(pae_matrix, dtype=float)
# pae_matrix is already a numpy array for memory efficiency
mtx = pae_matrix if isinstance(pae_matrix, np.ndarray) else np.array(pae_matrix, dtype=float)
if mtx.size == 0:
logger.warning(f"empty PAE matrix; skipping heatmap for {out_file}")
return
Expand Down Expand Up @@ -70,6 +71,7 @@ def process(
*,
per_run_csv_name: str = "interfaces.csv",
skip_pae_png: bool = False,
skip_biophysical_scores: bool = False,
) -> Path | None:
d = Path(directory)
parser = pick_parser(d)
Expand Down Expand Up @@ -99,41 +101,47 @@ def process(
f"{iface.chain1[0].get_parent().id}_{iface.chain2[0].get_parent().id}"
)
iptm_val = iface.iptm_chainpair if iface.iptm_chainpair is not None else confidence.iptm
rows.append(
{
"jobs": job,
"model_used": m,
"interface": label,
"iptm_ptm": float(confidence.iptm_ptm)
if confidence.iptm_ptm is not None
else float("nan"),
"iptm": float(iptm_val) if iptm_val is not None else float("nan"),
"ptm": float(confidence.ptm)
if confidence.ptm is not None
else float("nan"),
"confidence_score": float(confidence.confidence_score)
if confidence.confidence_score is not None
else float("nan"),
"pDockQ/mpDockQ": global_score,
"average_interface_pae": iface.average_interface_pae,
"interface_average_plddt": iface.average_interface_plddt,
"interface_num_intf_residues": iface.num_intf_residues,
"interface_polar": iface.polar,
"interface_hydrophobic": iface.hydrophobic,
"interface_charged": iface.charged,
"interface_contact_pairs": iface.contact_pairs,
"interface_score": iface.score_complex,
"interface_pDockQ2": pd2,
"interface_ipSAE": iface.ipsae(),
"interface_LIS": iface.lis(),

row_data = {
"jobs": job,
"model_used": m,
"interface": label,
"iptm_ptm": float(confidence.iptm_ptm)
if confidence.iptm_ptm is not None
else float("nan"),
"iptm": float(iptm_val) if iptm_val is not None else float("nan"),
"ptm": float(confidence.ptm)
if confidence.ptm is not None
else float("nan"),
"confidence_score": float(confidence.confidence_score)
if confidence.confidence_score is not None
else float("nan"),
"pDockQ/mpDockQ": global_score,
"average_interface_pae": iface.average_interface_pae,
"interface_average_plddt": iface.average_interface_plddt,
"interface_num_intf_residues": iface.num_intf_residues,
"interface_polar": iface.polar,
"interface_hydrophobic": iface.hydrophobic,
"interface_charged": iface.charged,
"interface_contact_pairs": iface.contact_pairs,
"interface_score": iface.score_complex,
"interface_pDockQ2": pd2,
"interface_ipSAE": iface.ipsae(),
"interface_LIS": iface.lis(),
}

# Add expensive metrics only if not skipped
if not skip_biophysical_scores:
row_data.update({
"interface_hb": iface.hb,
"interface_sb": iface.sb,
"interface_ss": iface.ss,
"interface_sc": iface.sc,
"interface_area": iface.int_area,
"interface_solv_en": iface.int_solv_en,
}
)
})

rows.append(row_data)

# Compute chain boundaries for separator lines on PAE heatmap
chain_boundaries: list[float] = []
Expand Down Expand Up @@ -208,6 +216,7 @@ def _process_one_run(
force_recompute: bool,
per_run_csv_name: str,
skip_pae_png: bool,
skip_biophysical_scores: bool,
) -> tuple[str, list[dict]]:
"""
Worker: process a single run dir (or reuse interfaces.csv) and optionally return rows for aggregation.
Expand Down Expand Up @@ -242,6 +251,7 @@ def _process_one_run(
ipsae_pae_cutoff,
per_run_csv_name=per_run_csv_name,
skip_pae_png=skip_pae_png,
skip_biophysical_scores=skip_biophysical_scores,
)

if want_summary and out_path is not None:
Expand All @@ -265,6 +275,7 @@ def process_many(
force_recompute: bool = False,
per_run_csv_name: str = "interfaces.csv",
skip_pae_png: bool = False,
skip_biophysical_scores: bool = False,
) -> Path | None:
"""
Process one or more directories. Optionally recurse into nested directories
Expand Down Expand Up @@ -336,6 +347,7 @@ def process_many(
force_recompute,
per_run_csv_name,
skip_pae_png,
skip_biophysical_scores,
)
if summary_csv and rows:
aggregated_rows.extend(rows)
Expand All @@ -358,6 +370,7 @@ def process_many(
force_recompute,
per_run_csv_name,
skip_pae_png,
skip_biophysical_scores,
)
for d in unique_run_dirs
]
Expand Down