diff --git a/src/tmap/visualization/templates/base.html.j2 b/src/tmap/visualization/templates/base.html.j2 index f667b6d..12e3ac3 100644 --- a/src/tmap/visualization/templates/base.html.j2 +++ b/src/tmap/visualization/templates/base.html.j2 @@ -929,6 +929,7 @@ function decodeGzipBuffer(buffer, dtype) { switch (dtype) { case 'uint16': return new Uint16Array(aligned); case 'uint32': return new Uint32Array(aligned); + case 'uint8': return new Uint8Array(aligned); case 'float32': return new Float32Array(aligned); case 'int32': return new Int32Array(aligned); default: throw new Error('Unknown dtype: ' + dtype); @@ -995,6 +996,8 @@ const hasEdges = nEdges > 0; const parsedEdgeWidth = Number(metadata.edgeWidth); const edgeWidth = Number.isFinite(parsedEdgeWidth) && parsedEdgeWidth > 0 ? parsedEdgeWidth : 2; const edgeStrokeStyle = metadata.edgeStrokeStyle || 'rgba(0, 0, 0, 0.5)'; +const edgeColorMode = metadata.edgeColorMode || 'single'; +const hasPerEdgeColors = edgeColorMode === 'per-edge'; updateProgress(40, 'Loading columns...'); @@ -1189,18 +1192,17 @@ async function loadEdges() { const eS = raw.subarray(0, nEdges); const eT = raw.subarray(nEdges, 2 * nEdges); - // Interleaved position buffer from point coordinates - const positions = new Float32Array(n * 2); - for (let i = 0; i < n; i++) { - positions[i * 2] = x[i]; - positions[i * 2 + 1] = y[i]; - } - - // Element index buffer: pairs of vertex indices per edge - const elements = new Uint32Array(nEdges * 2); - for (let i = 0; i < nEdges; i++) { - elements[i * 2] = eS[i]; - elements[i * 2 + 1] = eT[i]; + let edgeRgb = null; + if (hasPerEdgeColors) { +{% if inline_data %} + const edgeColorRaw = Uint8Array.from(atob("{{ inline_edge_colors }}"), c => c.charCodeAt(0)); + edgeRgb = decodeGzipBuffer(edgeColorRaw.buffer, 'uint8'); +{% else %} + edgeRgb = decodeGzipBuffer(await fetch('./edge_colors.bin').then(r => r.arrayBuffer()), 'uint8'); +{% endif %} + if (edgeRgb.length !== nEdges * 3) { + throw new Error('Expected ' + (nEdges * 3) + ' edge color bytes, got ' + edgeRgb.length); + } } edgeRegl = createREGL({ @@ -1209,27 +1211,91 @@ async function loadEdges() { attributes: { alpha: true, premultipliedAlpha: false, antialias: true }, }); - const vertSrc = 'precision highp float; attribute vec2 position; uniform vec2 u_scale; uniform vec2 u_offset; void main() { gl_Position = vec4(position * u_scale + u_offset, 0.0, 1.0); }'; - const fragSrc = 'precision mediump float; uniform vec4 u_color; void main() { gl_FragColor = u_color; }'; - - drawEdgesCmd = edgeRegl({ - vert: vertSrc, - frag: fragSrc, - attributes: { - position: { buffer: edgeRegl.buffer(positions), size: 2 }, - }, - elements: edgeRegl.elements({ data: elements, primitive: 'lines', type: 'uint32' }), - uniforms: { - u_scale: edgeRegl.prop('scale'), - u_offset: edgeRegl.prop('offset'), - u_color: edgeRegl.prop('color'), - }, - blend: { - enable: true, - func: { srcRGB: 'src alpha', srcAlpha: 1, dstRGB: 'one minus src alpha', dstAlpha: 1 }, - }, - depth: { enable: false }, - }); + if (edgeRgb) { + const positions = new Float32Array(nEdges * 4); + const colors = new Float32Array(nEdges * 8); + for (let i = 0; i < nEdges; i++) { + const s = eS[i]; + const t = eT[i]; + const p = i * 4; + positions[p] = x[s]; + positions[p + 1] = y[s]; + positions[p + 2] = x[t]; + positions[p + 3] = y[t]; + + const c = i * 3; + const r = edgeRgb[c] / 255; + const g = edgeRgb[c + 1] / 255; + const b = edgeRgb[c + 2] / 255; + const a = edgeColor[3]; + const out = i * 8; + colors[out] = r; + colors[out + 1] = g; + colors[out + 2] = b; + colors[out + 3] = a; + colors[out + 4] = r; + colors[out + 5] = g; + colors[out + 6] = b; + colors[out + 7] = a; + } + + const vertSrc = 'precision highp float; attribute vec2 position; attribute vec4 edge_color; varying vec4 v_color; uniform vec2 u_scale; uniform vec2 u_offset; void main() { v_color = edge_color; gl_Position = vec4(position * u_scale + u_offset, 0.0, 1.0); }'; + const fragSrc = 'precision mediump float; varying vec4 v_color; void main() { gl_FragColor = v_color; }'; + + drawEdgesCmd = edgeRegl({ + vert: vertSrc, + frag: fragSrc, + attributes: { + position: { buffer: edgeRegl.buffer(positions), size: 2 }, + edge_color: { buffer: edgeRegl.buffer(colors), size: 4 }, + }, + primitive: 'lines', + count: nEdges * 2, + uniforms: { + u_scale: edgeRegl.prop('scale'), + u_offset: edgeRegl.prop('offset'), + }, + blend: { + enable: true, + func: { srcRGB: 'src alpha', srcAlpha: 1, dstRGB: 'one minus src alpha', dstAlpha: 1 }, + }, + depth: { enable: false }, + }); + } else { + const positions = new Float32Array(n * 2); + for (let i = 0; i < n; i++) { + positions[i * 2] = x[i]; + positions[i * 2 + 1] = y[i]; + } + + const elements = new Uint32Array(nEdges * 2); + for (let i = 0; i < nEdges; i++) { + elements[i * 2] = eS[i]; + elements[i * 2 + 1] = eT[i]; + } + + const vertSrc = 'precision highp float; attribute vec2 position; uniform vec2 u_scale; uniform vec2 u_offset; void main() { gl_Position = vec4(position * u_scale + u_offset, 0.0, 1.0); }'; + const fragSrc = 'precision mediump float; uniform vec4 u_color; void main() { gl_FragColor = u_color; }'; + + drawEdgesCmd = edgeRegl({ + vert: vertSrc, + frag: fragSrc, + attributes: { + position: { buffer: edgeRegl.buffer(positions), size: 2 }, + }, + elements: edgeRegl.elements({ data: elements, primitive: 'lines', type: 'uint32' }), + uniforms: { + u_scale: edgeRegl.prop('scale'), + u_offset: edgeRegl.prop('offset'), + u_color: edgeRegl.prop('color'), + }, + blend: { + enable: true, + func: { srcRGB: 'src alpha', srcAlpha: 1, dstRGB: 'one minus src alpha', dstAlpha: 1 }, + }, + depth: { enable: false }, + }); + } edgesLoaded = true; edgeNeedsRedraw = true; @@ -1246,11 +1312,12 @@ function renderEdges() { const xRange = xd[1] - xd[0] || 1; const yRange = yd[1] - yd[0] || 1; - drawEdgesCmd({ + const drawProps = { scale: [2 / xRange, 2 / yRange], offset: [-(xd[1] + xd[0]) / xRange, -(yd[1] + yd[0]) / yRange], - color: edgeColor, - }); + }; + if (!hasPerEdgeColors) drawProps.color = edgeColor; + drawEdgesCmd(drawProps); edgeNeedsRedraw = false; } @@ -2509,7 +2576,9 @@ tbTheme.addEventListener('click', () => { isDarkTheme = !isDarkTheme; document.documentElement.dataset.theme = isDarkTheme ? '' : 'light'; scatterplot.set({ backgroundColor: isDarkTheme ? darkBg : lightBg }); - edgeColor = isDarkTheme ? parseEdgeColor(edgeStrokeStyle) : [1, 1, 1, edgeColor[3]]; + if (!hasPerEdgeColors) { + edgeColor = isDarkTheme ? parseEdgeColor(edgeStrokeStyle) : [1, 1, 1, edgeColor[3]]; + } edgeNeedsRedraw = true; if (typeof renderEdges === 'function') renderEdges(); tbTheme.innerHTML = isDarkTheme diff --git a/src/tmap/visualization/tmapviz.py b/src/tmap/visualization/tmapviz.py index adadf85..dc5ae5c 100644 --- a/src/tmap/visualization/tmapviz.py +++ b/src/tmap/visualization/tmapviz.py @@ -217,6 +217,17 @@ def _hex_to_css_rgba(hex_color: str, alpha: float = 1.0) -> str: return f"rgba({r}, {g}, {b}, {alpha_str})" +def _hex_colors_to_rgb_uint8(colors: Sequence[str]) -> NDArray[np.uint8]: + """Convert hex colors to an ``(n, 3)`` uint8 RGB array.""" + arr = np.empty((len(colors), 3), dtype=np.uint8) + for idx, color in enumerate(colors): + normalized = _normalize_hex_color(color).lstrip("#") + arr[idx, 0] = int(normalized[0:2], 16) + arr[idx, 1] = int(normalized[2:4], 16) + arr[idx, 2] = int(normalized[4:6], 16) + return arr + + # TODO(ISS-014): Implement categorical=True preserves listed colors when available def _colormap_to_hex(name: str) -> list[str]: """ @@ -411,6 +422,7 @@ def __init__(self) -> None: self._points_array: np.ndarray | None = None # Shape: (n, 2) self._edges_s: np.ndarray | None = None self._edges_t: np.ndarray | None = None + self._edge_colors: np.ndarray | None = None self._layout_keys: list[str] = [] self._labels_keys: list[str] = [] self._smiles_column: str | None = None @@ -1159,6 +1171,28 @@ def set_edges( self._edges_s = s_arr self._edges_t = t_arr + self._edge_colors = None + + def set_edge_colors(self, colors: Sequence[str]) -> None: + """Set one hex color per edge for visualization templates. + + Args: + colors: Hex color strings for each edge in the current edge order. + + Raises: + ValueError: If edges are not set or the color count differs from + the edge count. + """ + if self._edges_s is None or self._edges_t is None: + raise ValueError("set_edges must be called before set_edge_colors") + + edge_colors = _hex_colors_to_rgb_uint8(colors) + if len(edge_colors) != len(self._edges_s): + raise ValueError( + f"Edge color count must match edge count. " + f"Got {len(edge_colors)} colors for {len(self._edges_s)} edges." + ) + self._edge_colors = edge_colors def set_edge_style( self, @@ -1608,12 +1642,18 @@ def to_html(self, template_name: str = "base.html.j2") -> str: # Pack edges if present edges_b64 = "" + edge_colors_b64 = "" + edge_color_mode = "single" n_edges = 0 if self._edges_s is not None and self._edges_t is not None: n_edges = len(self._edges_s) edges_combined = np.concatenate([self._edges_s, self._edges_t]).astype(np.uint32) edges_compressed = gzip.compress(edges_combined.tobytes(), compresslevel=6) edges_b64 = base64.b64encode(edges_compressed).decode("ascii") + if self._edge_colors is not None: + edge_color_mode = "per-edge" + edge_colors_compressed = gzip.compress(self._edge_colors.tobytes(), compresslevel=6) + edge_colors_b64 = base64.b64encode(edge_colors_compressed).decode("ascii") # Build metadata (same flat structure as write_static) layout_options = list(self._layout_keys) @@ -1642,6 +1682,8 @@ def to_html(self, template_name: str = "base.html.j2") -> str: "opacity": self.opacity, "edgeStrokeStyle": _hex_to_css_rgba(self.edge_color, self.edge_opacity), "edgeWidth": self.edge_width, + "edgeColorMode": edge_color_mode, + "edgeColorDtype": "uint8" if edge_color_mode == "per-edge" else None, "backgroundColor": _hex_to_rgba(self.background_color), "layoutOptions": layout_options, "labelOptions": label_options, @@ -1683,6 +1725,7 @@ def to_html(self, template_name: str = "base.html.j2") -> str: inline_coords=coords_b64, inline_columns=columns_b64, inline_edges=edges_b64, + inline_edge_colors=edge_colors_b64, ) def write_html( @@ -1765,11 +1808,16 @@ def write_static( # Edges n_edges = 0 + edge_color_mode = "single" if self._edges_s is not None and self._edges_t is not None: n_edges = len(self._edges_s) edges_combined = np.concatenate([self._edges_s, self._edges_t]).astype(np.uint32) edges_compressed = gzip.compress(edges_combined.tobytes(), compresslevel=6) (output_dir / "edges.bin").write_bytes(edges_compressed) + if self._edge_colors is not None: + edge_color_mode = "per-edge" + edge_colors_compressed = gzip.compress(self._edge_colors.tobytes(), compresslevel=6) + (output_dir / "edge_colors.bin").write_bytes(edge_colors_compressed) # Columns columns_meta: dict[str, dict[str, Any]] = {} @@ -1847,6 +1895,8 @@ def write_static( "opacity": self.opacity, "edgeStrokeStyle": _hex_to_css_rgba(self.edge_color, self.edge_opacity), "edgeWidth": self.edge_width, + "edgeColorMode": edge_color_mode, + "edgeColorDtype": "uint8" if edge_color_mode == "per-edge" else None, "backgroundColor": _hex_to_rgba(self.background_color), "layoutOptions": layout_options, "labelOptions": label_options, @@ -1890,6 +1940,7 @@ def write_static( inline_coords="", inline_columns={}, inline_edges="", + inline_edge_colors="", ) (output_dir / "index.html").write_text(html, encoding="utf-8") diff --git a/tests/test_visualization.py b/tests/test_visualization.py index cb31f1e..b96512e 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -912,6 +912,67 @@ def test_custom_edge_style_in_header(self, viz_with_data): assert meta["edgeStrokeStyle"] == "rgba(255, 0, 51, 0.35)" assert meta["edgeWidth"] == 4.5 + def test_set_edge_colors_updates_values(self, viz_with_data): + """Per-edge colors should be normalized to packed RGB bytes.""" + viz, data = viz_with_data + viz.set_edges([0, 1], [1, 2]) + viz.set_edge_colors(["#f03", "#00aa11"]) + + assert viz._edge_colors is not None + assert viz._edge_colors.tolist() == [[255, 0, 51], [0, 170, 17]] + + def test_set_edge_colors_requires_edges(self, viz_with_data): + """Per-edge colors require an existing edge array.""" + viz, data = viz_with_data + + with pytest.raises(ValueError, match="set_edges must be called"): + viz.set_edge_colors(["#ffffff"]) + + def test_set_edge_colors_mismatched_length(self, viz_with_data): + """One color must be provided for each edge.""" + viz, data = viz_with_data + viz.set_edges([0, 1, 2], [1, 2, 3]) + + with pytest.raises(ValueError, match="must match edge count"): + viz.set_edge_colors(["#ffffff"]) + + def test_per_edge_colors_serialized_to_html(self, viz_with_data): + """Inline HTML should advertise the per-edge color payload.""" + viz, data = viz_with_data + viz.add_color_layout("value", data["continuous"]) + viz.set_edges([0, 1, 2], [1, 2, 3]) + viz.set_edge_colors(["#f03", "#00aa11", "#0000ff"]) + + html = viz.to_html() + + match = re.search( + r"const metadata = ({.*?});", + html, + re.DOTALL, + ) + assert match is not None + meta = json.loads(match.group(1)) + assert meta["edgeColorMode"] == "per-edge" + assert meta["edgeColorDtype"] == "uint8" + assert "edgeColorRaw" in html + + def test_per_edge_colors_serialized_to_static(self, viz_with_data, tmp_path): + """Static output should write edge color bytes beside edge indices.""" + viz, data = viz_with_data + viz.set_edges([0, 1], [1, 2]) + viz.set_edge_colors(["#f03", "#00aa11"]) + + out = viz.write_static(tmp_path / "out") + meta = json.loads((out / "metadata.json").read_text()) + colors = np.frombuffer( + gzip.decompress((out / "edge_colors.bin").read_bytes()), + dtype=np.uint8, + ).reshape(-1, 3) + + assert meta["edgeColorMode"] == "per-edge" + assert meta["edgeColorDtype"] == "uint8" + assert colors.tolist() == [[255, 0, 51], [0, 170, 17]] + class TestEdgeStyle: """Tests for edge style configuration."""