Source code for bsv.plot_connectivity_matrix

"""Connectivity matrix heatmap visualization."""

import warnings

import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import zoom

from .atlas_utils import load_atlas, find_structure_indices


def plot_connectivity_matrix(
    projection_data,
    allen_atlas_path,
    target_regions,
    metric='mean',
    atlas_type='allen',
    atlas_resolution=10,
    normalize_rows=False,
    normalize_cols=False,
    cmap='viridis',
    annotate=True,
    figsize=None,
    title=None,
):
    """Create an N x M connectivity matrix heatmap.

    Build a matrix showing connectivity strength from N source regions to
    M target regions and display it as a heatmap.

    Parameters
    ----------
    projection_data : dict
        Dictionary mapping source region acronyms to tuples returned by
        ``fetch_connectivity_data``: ``(combined_projection, combined_info,
        individual_projections, experiment_info)``. Only the first element
        (the projection array) is used; a bare projection array is also
        accepted in place of the tuple. A 4-D projection array is averaged
        over its last axis (one volume per group) before the metric is
        computed. Matrix rows follow the dict's iteration order.
    allen_atlas_path : str
        Path to the directory containing atlas annotation volume and
        structure tree files.
    target_regions : list of str
        List of target region acronyms (e.g., ``['CP', 'ACB', 'SNr']``).
        A region with no matching structures yields an all-zero column and
        emits a ``UserWarning``.
    metric : str, optional
        Connectivity metric to compute:

        - ``'mean'`` (default): Mean projection density in target region.
        - ``'max'``: Maximum projection density in target region.
        - ``'volume'``: Total projection volume (sum of density * voxel volume).
    atlas_type : str, optional
        Atlas type (default ``'allen'``).
    atlas_resolution : int, optional
        Atlas resolution in micrometres (default 10).
    normalize_rows : bool, optional
        If True, normalize each row (source) to [0, 1] by dividing by row max.
    normalize_cols : bool, optional
        If True, normalize each column (target) to [0, 1] by dividing by col
        max. When both flags are set, rows are normalized first.
    cmap : str, optional
        Matplotlib colormap name (default ``'viridis'``).
    annotate : bool, optional
        If True, display numeric values in each cell (default True).
    figsize : tuple of float, optional
        Figure size ``(width, height)``. If None, auto-calculated from
        matrix size.
    title : str, optional
        Plot title. If None, no title is displayed.

    Returns
    -------
    connectivity_matrix : numpy.ndarray
        2D array of shape ``(n_sources, n_targets)`` with connectivity values.
        Cells whose target mask is empty or holds only NaN are 0.
    fig : matplotlib.figure.Figure
        The matplotlib Figure object.

    Raises
    ------
    ValueError
        If ``metric`` is not ``'mean'``, ``'max'`` or ``'volume'``, or if
        ``projection_data`` or ``target_regions`` is empty.

    Examples
    --------
    >>> from bsv import fetch_connectivity_data, find_connectivity_experiments
    >>> from bsv import plot_connectivity_matrix
    >>> source_regions = ['VISam', 'VISp']
    >>> projection_data = {}
    >>> for src in source_regions:
    ...     exp_ids = find_connectivity_experiments(injection_structure=src)
    ...     data = fetch_connectivity_data(
    ...         experiment_ids=exp_ids,
    ...         save_location='./cache',
    ...         file_name=f'{src}_proj'
    ...     )
    ...     projection_data[src] = data
    >>> target_regions = ['CP', 'ACB']
    >>> matrix, fig = plot_connectivity_matrix(
    ...     projection_data=projection_data,
    ...     allen_atlas_path='./atlas',
    ...     target_regions=target_regions,
    ...     metric='mean'
    ... )
    """
    if metric not in ('mean', 'max', 'volume'):
        raise ValueError(f"metric must be 'mean', 'max', or 'volume', got '{metric}'")

    # Fix row/column order once. Materializing target_regions also lets
    # callers pass any iterable, since it is traversed several times below.
    source_regions = list(projection_data.keys())
    target_regions = list(target_regions)
    # Fail early (and before the expensive atlas load): an empty axis would
    # otherwise only surface later as a zero-size reduction error.
    if not source_regions:
        raise ValueError('projection_data must contain at least one source region')
    if not target_regions:
        raise ValueError('target_regions must contain at least one region')

    av, st = load_atlas(allen_atlas_path, atlas_type=atlas_type,
                        atlas_resolution=atlas_resolution)

    # The projection grid is assumed to be 100 um (shape (132, 80, 114));
    # the annotation volume is resampled onto it.
    proj_resolution = 100  # um
    av_downsampled = _downsample_annotation(av, atlas_resolution, proj_resolution)
    target_masks = _build_target_masks(av_downsampled, st, target_regions)

    # Voxel volume in mm^3 (100 um = 0.1 mm per side)
    voxel_volume = (proj_resolution / 1000.0) ** 3

    connectivity_matrix = np.zeros((len(source_regions), len(target_regions)))
    # Resampled masks are cached per (target, projection shape) so each mask
    # is resized at most once per distinct shape rather than once per source.
    resized_masks = {}
    for i, src in enumerate(source_regions):
        proj = _extract_projection(projection_data[src])
        for j, tgt in enumerate(target_regions):
            key = (tgt, proj.shape)
            if key not in resized_masks:
                resized_masks[key] = _resample_mask(target_masks[tgt], proj.shape)
            connectivity_matrix[i, j] = _region_metric(
                proj[resized_masks[key]], metric, voxel_volume
            )

    if normalize_rows:
        connectivity_matrix = _normalize_by_max(connectivity_matrix, axis=1)
    if normalize_cols:
        connectivity_matrix = _normalize_by_max(connectivity_matrix, axis=0)

    fig = _draw_heatmap(
        connectivity_matrix,
        source_regions,
        target_regions,
        metric,
        normalized=normalize_rows or normalize_cols,
        cmap=cmap,
        annotate=annotate,
        figsize=figsize,
        title=title,
    )
    return connectivity_matrix, fig


def _downsample_annotation(av, atlas_resolution, proj_resolution):
    """Resample the annotation volume from atlas resolution to projection resolution."""
    scale_factor = atlas_resolution / proj_resolution
    if scale_factor == 1.0:
        return av
    # order=0 (nearest neighbour) keeps structure labels intact; any higher
    # order would interpolate between unrelated label values.
    # NOTE(review): zoom's default grid_mode=False aligns corner voxels, so
    # samples are not exactly at block centres -- confirm this is acceptable.
    return zoom(av, scale_factor, order=0)


def _build_target_masks(annotation, structure_tree, target_regions):
    """Map each target acronym to a boolean mask over ``annotation``."""
    masks = {}
    for region in target_regions:
        if region in masks:  # duplicate acronym: reuse the mask already built
            continue
        indices = find_structure_indices(structure_tree, region)
        # Explicit emptiness test: works whether the helper returns a list or
        # a NumPy array (``not array`` is ambiguous for multi-element arrays).
        if indices is None or len(indices) == 0:
            # stacklevel=3 attributes the warning to the caller of
            # plot_connectivity_matrix rather than to this helper.
            warnings.warn(f"No structures found matching '{region}'", stacklevel=3)
            masks[region] = np.zeros(annotation.shape, dtype=bool)
        else:
            masks[region] = np.isin(annotation, indices)
    return masks


def _extract_projection(data):
    """Return the projection volume stored in one ``projection_data`` value."""
    # fetch_connectivity_data returns a tuple whose first element is the
    # combined projection; a bare array(-like) is accepted as well.
    proj = data[0] if isinstance(data, tuple) else data
    proj = np.asanyarray(proj)
    # Average across groups if multiple groups exist (groups on the last axis).
    # NOTE(review): np.mean propagates NaN across groups, unlike the nan-aware
    # per-region statistics -- confirm this is intended.
    if proj.ndim == 4:
        proj = np.mean(proj, axis=3)
    return proj


def _resample_mask(mask, shape):
    """Return ``mask`` resampled to ``shape`` with nearest-neighbour zoom."""
    if mask.shape == shape:
        return mask
    factors = [new / old for new, old in zip(shape, mask.shape)]
    return zoom(mask.astype(float), factors, order=0) > 0.5


def _region_metric(values, metric, voxel_volume):
    """Reduce the projection values inside one target region to a scalar."""
    # An empty mask or an all-NaN region scores 0; this also avoids the
    # RuntimeWarning nanmean/nanmax emit on all-NaN input.
    if values.size == 0 or np.all(np.isnan(values)):
        return 0.0
    if metric == 'mean':
        return np.nanmean(values)
    if metric == 'max':
        return np.nanmax(values)
    # 'volume': summed density times the volume of one voxel.
    return np.nansum(values) * voxel_volume


def _normalize_by_max(matrix, axis):
    """Divide ``matrix`` by its maximum along ``axis``."""
    peak = matrix.max(axis=axis, keepdims=True)
    peak[peak == 0] = 1  # all-zero rows/columns stay zero instead of dividing by zero
    return matrix / peak


def _draw_heatmap(matrix, source_regions, target_regions, metric, *,
                  normalized, cmap, annotate, figsize, title):
    """Render ``matrix`` as a labelled heatmap and return the new Figure."""
    n_sources, n_targets = matrix.shape
    if figsize is None:
        figsize = (max(6, n_targets * 1.2), max(4, n_sources * 0.8))
    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(matrix, cmap=cmap, aspect='auto')

    ax.set_xticks(np.arange(n_targets))
    ax.set_yticks(np.arange(n_sources))
    ax.set_xticklabels(target_regions)
    ax.set_yticklabels(source_regions)
    # Rotate x labels for readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')

    metric_labels = {
        'mean': 'Mean projection density',
        'max': 'Max projection density',
        'volume': 'Total projection volume',
    }
    cbar_label = metric_labels[metric]
    if normalized:
        # Normalized values are unitless fractions of a row/column maximum.
        cbar_label += ' (normalized)'
    elif metric == 'volume':
        cbar_label += ' (mm³)'
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(cbar_label)

    if annotate:
        _annotate_cells(ax, matrix, metric)

    ax.set_xlabel('Target Region')
    ax.set_ylabel('Source Region')
    if title:
        ax.set_title(title)

    plt.tight_layout()
    return fig


def _annotate_cells(ax, matrix, metric):
    """Write each cell's value on ``ax`` in a contrasting text colour."""
    # Min-max rescale to [0, 1] (a constant matrix maps to all zeros) to pick
    # the text colour from the cell's position on the colormap.
    lo, hi = matrix.min(), matrix.max()
    scaled = matrix - lo
    if hi > lo:
        scaled = scaled / (hi - lo)
    # Values below this threshold are shown in scientific notation.
    sci_threshold = 0.01 if metric == 'volume' else 0.001
    for (i, j), value in np.ndenumerate(matrix):
        # Use white text on dark cells, black on light
        text_color = 'white' if scaled[i, j] < 0.5 else 'black'
        if value < sci_threshold:
            text = f'{value:.2e}'
        elif metric == 'volume':
            text = f'{value:.3f}'
        else:
            text = f'{value:.4f}'
        ax.text(j, i, text, ha='center', va='center', color=text_color, fontsize=8)