Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 106 additions & 37 deletions src/tmap/visualization/templates/base.html.j2
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,7 @@ function decodeGzipBuffer(buffer, dtype) {
switch (dtype) {
case 'uint16': return new Uint16Array(aligned);
case 'uint32': return new Uint32Array(aligned);
case 'uint8': return new Uint8Array(aligned);
case 'float32': return new Float32Array(aligned);
case 'int32': return new Int32Array(aligned);
default: throw new Error('Unknown dtype: ' + dtype);
Expand Down Expand Up @@ -995,6 +996,8 @@ const hasEdges = nEdges > 0;
const parsedEdgeWidth = Number(metadata.edgeWidth);
const edgeWidth = Number.isFinite(parsedEdgeWidth) && parsedEdgeWidth > 0 ? parsedEdgeWidth : 2;
const edgeStrokeStyle = metadata.edgeStrokeStyle || 'rgba(0, 0, 0, 0.5)';
const edgeColorMode = metadata.edgeColorMode || 'single';
const hasPerEdgeColors = edgeColorMode === 'per-edge';

updateProgress(40, 'Loading columns...');

Expand Down Expand Up @@ -1189,18 +1192,17 @@ async function loadEdges() {
const eS = raw.subarray(0, nEdges);
const eT = raw.subarray(nEdges, 2 * nEdges);

// Interleaved position buffer from point coordinates
const positions = new Float32Array(n * 2);
for (let i = 0; i < n; i++) {
positions[i * 2] = x[i];
positions[i * 2 + 1] = y[i];
}

// Element index buffer: pairs of vertex indices per edge
const elements = new Uint32Array(nEdges * 2);
for (let i = 0; i < nEdges; i++) {
elements[i * 2] = eS[i];
elements[i * 2 + 1] = eT[i];
let edgeRgb = null;
if (hasPerEdgeColors) {
{% if inline_data %}
const edgeColorRaw = Uint8Array.from(atob("{{ inline_edge_colors }}"), c => c.charCodeAt(0));
edgeRgb = decodeGzipBuffer(edgeColorRaw.buffer, 'uint8');
{% else %}
edgeRgb = decodeGzipBuffer(await fetch('./edge_colors.bin').then(r => r.arrayBuffer()), 'uint8');
{% endif %}
if (edgeRgb.length !== nEdges * 3) {
throw new Error('Expected ' + (nEdges * 3) + ' edge color bytes, got ' + edgeRgb.length);
}
}

edgeRegl = createREGL({
Expand All @@ -1209,27 +1211,91 @@ async function loadEdges() {
attributes: { alpha: true, premultipliedAlpha: false, antialias: true },
});

const vertSrc = 'precision highp float; attribute vec2 position; uniform vec2 u_scale; uniform vec2 u_offset; void main() { gl_Position = vec4(position * u_scale + u_offset, 0.0, 1.0); }';
const fragSrc = 'precision mediump float; uniform vec4 u_color; void main() { gl_FragColor = u_color; }';

drawEdgesCmd = edgeRegl({
vert: vertSrc,
frag: fragSrc,
attributes: {
position: { buffer: edgeRegl.buffer(positions), size: 2 },
},
elements: edgeRegl.elements({ data: elements, primitive: 'lines', type: 'uint32' }),
uniforms: {
u_scale: edgeRegl.prop('scale'),
u_offset: edgeRegl.prop('offset'),
u_color: edgeRegl.prop('color'),
},
blend: {
enable: true,
func: { srcRGB: 'src alpha', srcAlpha: 1, dstRGB: 'one minus src alpha', dstAlpha: 1 },
},
depth: { enable: false },
});
if (edgeRgb) {
const positions = new Float32Array(nEdges * 4);
const colors = new Float32Array(nEdges * 8);
for (let i = 0; i < nEdges; i++) {
const s = eS[i];
const t = eT[i];
const p = i * 4;
positions[p] = x[s];
positions[p + 1] = y[s];
positions[p + 2] = x[t];
positions[p + 3] = y[t];

const c = i * 3;
const r = edgeRgb[c] / 255;
const g = edgeRgb[c + 1] / 255;
const b = edgeRgb[c + 2] / 255;
const a = edgeColor[3];
const out = i * 8;
colors[out] = r;
colors[out + 1] = g;
colors[out + 2] = b;
colors[out + 3] = a;
colors[out + 4] = r;
colors[out + 5] = g;
colors[out + 6] = b;
colors[out + 7] = a;
}

const vertSrc = 'precision highp float; attribute vec2 position; attribute vec4 edge_color; varying vec4 v_color; uniform vec2 u_scale; uniform vec2 u_offset; void main() { v_color = edge_color; gl_Position = vec4(position * u_scale + u_offset, 0.0, 1.0); }';
const fragSrc = 'precision mediump float; varying vec4 v_color; void main() { gl_FragColor = v_color; }';

drawEdgesCmd = edgeRegl({
vert: vertSrc,
frag: fragSrc,
attributes: {
position: { buffer: edgeRegl.buffer(positions), size: 2 },
edge_color: { buffer: edgeRegl.buffer(colors), size: 4 },
},
primitive: 'lines',
count: nEdges * 2,
uniforms: {
u_scale: edgeRegl.prop('scale'),
u_offset: edgeRegl.prop('offset'),
},
blend: {
enable: true,
func: { srcRGB: 'src alpha', srcAlpha: 1, dstRGB: 'one minus src alpha', dstAlpha: 1 },
},
depth: { enable: false },
});
} else {
const positions = new Float32Array(n * 2);
for (let i = 0; i < n; i++) {
positions[i * 2] = x[i];
positions[i * 2 + 1] = y[i];
}

const elements = new Uint32Array(nEdges * 2);
for (let i = 0; i < nEdges; i++) {
elements[i * 2] = eS[i];
elements[i * 2 + 1] = eT[i];
}

const vertSrc = 'precision highp float; attribute vec2 position; uniform vec2 u_scale; uniform vec2 u_offset; void main() { gl_Position = vec4(position * u_scale + u_offset, 0.0, 1.0); }';
const fragSrc = 'precision mediump float; uniform vec4 u_color; void main() { gl_FragColor = u_color; }';

drawEdgesCmd = edgeRegl({
vert: vertSrc,
frag: fragSrc,
attributes: {
position: { buffer: edgeRegl.buffer(positions), size: 2 },
},
elements: edgeRegl.elements({ data: elements, primitive: 'lines', type: 'uint32' }),
uniforms: {
u_scale: edgeRegl.prop('scale'),
u_offset: edgeRegl.prop('offset'),
u_color: edgeRegl.prop('color'),
},
blend: {
enable: true,
func: { srcRGB: 'src alpha', srcAlpha: 1, dstRGB: 'one minus src alpha', dstAlpha: 1 },
},
depth: { enable: false },
});
}

edgesLoaded = true;
edgeNeedsRedraw = true;
Expand All @@ -1246,11 +1312,12 @@ function renderEdges() {
const xRange = xd[1] - xd[0] || 1;
const yRange = yd[1] - yd[0] || 1;

drawEdgesCmd({
const drawProps = {
scale: [2 / xRange, 2 / yRange],
offset: [-(xd[1] + xd[0]) / xRange, -(yd[1] + yd[0]) / yRange],
color: edgeColor,
});
};
if (!hasPerEdgeColors) drawProps.color = edgeColor;
drawEdgesCmd(drawProps);
edgeNeedsRedraw = false;
}

Expand Down Expand Up @@ -2509,7 +2576,9 @@ tbTheme.addEventListener('click', () => {
isDarkTheme = !isDarkTheme;
document.documentElement.dataset.theme = isDarkTheme ? '' : 'light';
scatterplot.set({ backgroundColor: isDarkTheme ? darkBg : lightBg });
edgeColor = isDarkTheme ? parseEdgeColor(edgeStrokeStyle) : [1, 1, 1, edgeColor[3]];
if (!hasPerEdgeColors) {
edgeColor = isDarkTheme ? parseEdgeColor(edgeStrokeStyle) : [1, 1, 1, edgeColor[3]];
}
edgeNeedsRedraw = true;
if (typeof renderEdges === 'function') renderEdges();
tbTheme.innerHTML = isDarkTheme
Expand Down
51 changes: 51 additions & 0 deletions src/tmap/visualization/tmapviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,17 @@ def _hex_to_css_rgba(hex_color: str, alpha: float = 1.0) -> str:
return f"rgba({r}, {g}, {b}, {alpha_str})"


def _hex_colors_to_rgb_uint8(colors: Sequence[str]) -> NDArray[np.uint8]:
"""Convert hex colors to an ``(n, 3)`` uint8 RGB array."""
arr = np.empty((len(colors), 3), dtype=np.uint8)
for idx, color in enumerate(colors):
normalized = _normalize_hex_color(color).lstrip("#")
arr[idx, 0] = int(normalized[0:2], 16)
arr[idx, 1] = int(normalized[2:4], 16)
arr[idx, 2] = int(normalized[4:6], 16)
return arr


# TODO(ISS-014): Implement categorical=True preserves listed colors when available
def _colormap_to_hex(name: str) -> list[str]:
"""
Expand Down Expand Up @@ -411,6 +422,7 @@ def __init__(self) -> None:
self._points_array: np.ndarray | None = None # Shape: (n, 2)
self._edges_s: np.ndarray | None = None
self._edges_t: np.ndarray | None = None
self._edge_colors: np.ndarray | None = None
self._layout_keys: list[str] = []
self._labels_keys: list[str] = []
self._smiles_column: str | None = None
Expand Down Expand Up @@ -1159,6 +1171,28 @@ def set_edges(

self._edges_s = s_arr
self._edges_t = t_arr
self._edge_colors = None

def set_edge_colors(self, colors: Sequence[str]) -> None:
"""Set one hex color per edge for visualization templates.

Args:
colors: Hex color strings for each edge in the current edge order.

Raises:
ValueError: If edges are not set or the color count differs from
the edge count.
"""
if self._edges_s is None or self._edges_t is None:
raise ValueError("set_edges must be called before set_edge_colors")

edge_colors = _hex_colors_to_rgb_uint8(colors)
if len(edge_colors) != len(self._edges_s):
raise ValueError(
f"Edge color count must match edge count. "
f"Got {len(edge_colors)} colors for {len(self._edges_s)} edges."
)
self._edge_colors = edge_colors

def set_edge_style(
self,
Expand Down Expand Up @@ -1608,12 +1642,18 @@ def to_html(self, template_name: str = "base.html.j2") -> str:

# Pack edges if present
edges_b64 = ""
edge_colors_b64 = ""
edge_color_mode = "single"
n_edges = 0
if self._edges_s is not None and self._edges_t is not None:
n_edges = len(self._edges_s)
edges_combined = np.concatenate([self._edges_s, self._edges_t]).astype(np.uint32)
edges_compressed = gzip.compress(edges_combined.tobytes(), compresslevel=6)
edges_b64 = base64.b64encode(edges_compressed).decode("ascii")
if self._edge_colors is not None:
edge_color_mode = "per-edge"
edge_colors_compressed = gzip.compress(self._edge_colors.tobytes(), compresslevel=6)
edge_colors_b64 = base64.b64encode(edge_colors_compressed).decode("ascii")

# Build metadata (same flat structure as write_static)
layout_options = list(self._layout_keys)
Expand Down Expand Up @@ -1642,6 +1682,8 @@ def to_html(self, template_name: str = "base.html.j2") -> str:
"opacity": self.opacity,
"edgeStrokeStyle": _hex_to_css_rgba(self.edge_color, self.edge_opacity),
"edgeWidth": self.edge_width,
"edgeColorMode": edge_color_mode,
"edgeColorDtype": "uint8" if edge_color_mode == "per-edge" else None,
"backgroundColor": _hex_to_rgba(self.background_color),
"layoutOptions": layout_options,
"labelOptions": label_options,
Expand Down Expand Up @@ -1683,6 +1725,7 @@ def to_html(self, template_name: str = "base.html.j2") -> str:
inline_coords=coords_b64,
inline_columns=columns_b64,
inline_edges=edges_b64,
inline_edge_colors=edge_colors_b64,
)

def write_html(
Expand Down Expand Up @@ -1765,11 +1808,16 @@ def write_static(

# Edges
n_edges = 0
edge_color_mode = "single"
if self._edges_s is not None and self._edges_t is not None:
n_edges = len(self._edges_s)
edges_combined = np.concatenate([self._edges_s, self._edges_t]).astype(np.uint32)
edges_compressed = gzip.compress(edges_combined.tobytes(), compresslevel=6)
(output_dir / "edges.bin").write_bytes(edges_compressed)
if self._edge_colors is not None:
edge_color_mode = "per-edge"
edge_colors_compressed = gzip.compress(self._edge_colors.tobytes(), compresslevel=6)
(output_dir / "edge_colors.bin").write_bytes(edge_colors_compressed)

# Columns
columns_meta: dict[str, dict[str, Any]] = {}
Expand Down Expand Up @@ -1847,6 +1895,8 @@ def write_static(
"opacity": self.opacity,
"edgeStrokeStyle": _hex_to_css_rgba(self.edge_color, self.edge_opacity),
"edgeWidth": self.edge_width,
"edgeColorMode": edge_color_mode,
"edgeColorDtype": "uint8" if edge_color_mode == "per-edge" else None,
"backgroundColor": _hex_to_rgba(self.background_color),
"layoutOptions": layout_options,
"labelOptions": label_options,
Expand Down Expand Up @@ -1890,6 +1940,7 @@ def write_static(
inline_coords="",
inline_columns={},
inline_edges="",
inline_edge_colors="",
)
(output_dir / "index.html").write_text(html, encoding="utf-8")

Expand Down
61 changes: 61 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,67 @@ def test_custom_edge_style_in_header(self, viz_with_data):
assert meta["edgeStrokeStyle"] == "rgba(255, 0, 51, 0.35)"
assert meta["edgeWidth"] == 4.5

def test_set_edge_colors_updates_values(self, viz_with_data):
"""Per-edge colors should be normalized to packed RGB bytes."""
viz, data = viz_with_data
viz.set_edges([0, 1], [1, 2])
viz.set_edge_colors(["#f03", "#00aa11"])

assert viz._edge_colors is not None
assert viz._edge_colors.tolist() == [[255, 0, 51], [0, 170, 17]]

def test_set_edge_colors_requires_edges(self, viz_with_data):
"""Per-edge colors require an existing edge array."""
viz, data = viz_with_data

with pytest.raises(ValueError, match="set_edges must be called"):
viz.set_edge_colors(["#ffffff"])

def test_set_edge_colors_mismatched_length(self, viz_with_data):
"""One color must be provided for each edge."""
viz, data = viz_with_data
viz.set_edges([0, 1, 2], [1, 2, 3])

with pytest.raises(ValueError, match="must match edge count"):
viz.set_edge_colors(["#ffffff"])

def test_per_edge_colors_serialized_to_html(self, viz_with_data):
"""Inline HTML should advertise the per-edge color payload."""
viz, data = viz_with_data
viz.add_color_layout("value", data["continuous"])
viz.set_edges([0, 1, 2], [1, 2, 3])
viz.set_edge_colors(["#f03", "#00aa11", "#0000ff"])

html = viz.to_html()

match = re.search(
r"const metadata = ({.*?});",
html,
re.DOTALL,
)
assert match is not None
meta = json.loads(match.group(1))
assert meta["edgeColorMode"] == "per-edge"
assert meta["edgeColorDtype"] == "uint8"
assert "edgeColorRaw" in html

def test_per_edge_colors_serialized_to_static(self, viz_with_data, tmp_path):
"""Static output should write edge color bytes beside edge indices."""
viz, data = viz_with_data
viz.set_edges([0, 1], [1, 2])
viz.set_edge_colors(["#f03", "#00aa11"])

out = viz.write_static(tmp_path / "out")
meta = json.loads((out / "metadata.json").read_text())
colors = np.frombuffer(
gzip.decompress((out / "edge_colors.bin").read_bytes()),
dtype=np.uint8,
).reshape(-1, 3)

assert meta["edgeColorMode"] == "per-edge"
assert meta["edgeColorDtype"] == "uint8"
assert colors.tolist() == [[255, 0, 51], [0, 170, 17]]


class TestEdgeStyle:
"""Tests for edge style configuration."""
Expand Down