Source code for bsv.plot_connectivity

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from matplotlib.path import Path
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from scipy.ndimage import gaussian_filter

from .atlas_utils import load_atlas, find_structure_indices, get_structure_color


[docs] def plot_connectivity(experiment_data, allen_atlas_path, output_region, number_of_chunks, number_of_pixels, plane, region_only, smoothing, color_limits, color, normalization_info='unknown', input_regions=None, region_groups=None, experiment_region_info=None, normalize_by_group=False, custom_slices=None, slice_averaging=0, atlas_type='allen', atlas_resolution=10): """Plot 2D projection density slices for a target brain region. Parameters ---------- experiment_data : numpy.ndarray Projection density array from :func:`fetch_connectivity_data`. allen_atlas_path : str Path to the Allen CCF atlas directory. output_region : str Target region acronym to visualize (e.g. ``'CP'``). number_of_chunks : int Number of evenly spaced slices to display. number_of_pixels : int Number of 2D histogram bins per axis for each slice. The physical bin size adapts to the spatial extent of the region (bin_size ≈ region_extent_voxels / number_of_pixels × atlas_resolution µm). plane : str ``'coronal'`` or ``'sagittal'``. region_only : bool Mask the display to the target region boundary. smoothing : float Gaussian smoothing sigma in pixels (0 for none). color_limits : str or list ``'global'``, ``'per_slice'``, or ``[min, max]``. color : list or None RGB colour(s) for region groups, or None for default. normalization_info : str, optional Label for the normalization used during data fetch. input_regions : list of str, optional Source region acronyms for grouped display. region_groups : list of int, optional Group assignment per input region. experiment_region_info : dict, optional Per-experiment metadata from :func:`fetch_connectivity_data`. normalize_by_group : bool, optional Normalize each group independently. custom_slices : list of int, optional Specific slice indices instead of evenly spaced. slice_averaging : int, optional Number of adjacent slices to average around each custom slice. atlas_type : str, optional Atlas type (default ``'allen'``). atlas_resolution : int, optional Atlas resolution in micrometres (10 or 20). Returns ------- proj_array : numpy.ndarray Projection matrix for each slice. proj_coords : list Coordinate information for each slice. """ if input_regions is None: input_regions = [] if region_groups is None: region_groups = [] if custom_slices is None: custom_slices = [] av, st = load_atlas(allen_atlas_path, atlas_type, atlas_resolution) atlas_slice_spacing = atlas_resolution # Ensure output_region is a string if isinstance(output_region, (list, tuple)): output_region = output_region[0] # Find structure indices for output region curr_plot_structure_idx = find_structure_indices(st, output_region) plot_structure_color = get_structure_color(st, curr_plot_structure_idx[0]) # Handle region groups if input_regions: n_regions = len(input_regions) if region_groups: if len(region_groups) != n_regions: raise ValueError(f'Length of region_groups ({len(region_groups)}) must match input_regions ({n_regions})') unique_groups = sorted(set(region_groups)) n_region_groups = len(unique_groups) region_groups_cell = [[j for j, rg in enumerate(region_groups) if rg == g] for g in unique_groups] else: n_region_groups = n_regions region_groups_cell = [[i] for i in range(n_regions)] else: n_region_groups = 1 region_groups_cell = [[0]] # Get chunk limits # Only use left hemisphere of atlas for structure limits half_ml = av.shape[2] // 2 structure_mask = np.isin(av[:, :, :half_ml], curr_plot_structure_idx) ap_vals, _, ml_vals = np.where(structure_mask) if custom_slices: custom_indices = [s * atlas_slice_spacing for s in custom_slices] chunks_region = [] for ci in custom_indices: chunks_region.append(ci - slice_averaging * atlas_slice_spacing) chunks_region.append(ci + slice_averaging * atlas_slice_spacing) chunks_region = sorted(set(chunks_region)) number_of_chunks = len(custom_slices) else: if plane == 'coronal': limits = [ap_vals.min(), ap_vals.max()] else: limits = [ml_vals.min(), ml_vals.max()] step = (limits[1] - limits[0]) / number_of_chunks chunks_region = [limits[0] + i * step for i in range(number_of_chunks + 1)] # Second pass: get ML x DV (coronal) or AP x DV (sagittal) bins boundary_projection = [None] * number_of_chunks projection_view_bins = [None] * number_of_chunks for i_chunk in range(number_of_chunks): chunk_start = int(round(chunks_region[i_chunk])) chunk_end = int(round(chunks_region[i_chunk + 1])) if plane == 'coronal': region_area = np.isin(av[chunk_start:chunk_end + 1, :, :half_ml], curr_plot_structure_idx) # AP, DV, ML -> get ML, AP, DV coordinates ml_loc, ap_loc, dv_loc = np.where(region_area.transpose(2, 0, 1)) ap_loc = ap_loc + chunk_start x_coords = ml_loc y_coords = dv_loc else: region_area = np.isin(av[:, :, chunk_start:chunk_end + 1], curr_plot_structure_idx) ml_loc, ap_loc, dv_loc = np.where(region_area.transpose(2, 0, 1)) ml_loc = ml_loc + chunk_start x_coords = ap_loc y_coords = dv_loc if len(x_coords) < 3: projection_view_bins[i_chunk] = [np.array([0]), np.array([0])] boundary_projection[i_chunk] = np.array([]) continue # Compute boundary using convex hull from scipy.spatial import ConvexHull points = np.column_stack([x_coords, y_coords]) try: hull = ConvexHull(points) hull_indices = np.append(hull.vertices, hull.vertices[0]) except Exception: hull_indices = np.arange(len(x_coords)) boundary_projection[i_chunk] = hull_indices x_min, x_max = x_coords.min(), x_coords.max() y_min, y_max = y_coords.min(), y_coords.max() x_edges = np.linspace(x_min, x_max, number_of_pixels + 1) y_edges = np.linspace(y_min, y_max, number_of_pixels + 1) projection_view_bins[i_chunk] = [x_edges, y_edges] # Collapse hemispheres half_slices = experiment_data.shape[2] // 2 if experiment_data.ndim == 4: n_groups = experiment_data.shape[3] collapsed = experiment_data[:, :, :half_slices, :] + experiment_data[:, :, ::-1, :][:, :, :half_slices, :] else: n_groups = 1 collapsed = experiment_data[:, :, :half_slices] + experiment_data[:, :, ::-1][:, :, :half_slices] # Number of plot rows: one per group present in the data. Grouping by # injection AP/ML/DV (or by region) in fetch_connectivity_data adds groups # along axis 3, and each becomes its own row. Falls back to the region-group # count so region-based grouping and the ungrouped single-row case are unchanged. n_rows = max(n_region_groups, n_groups) # Conversion factor: atlas voxels per projection grid voxel # Atlas is at atlas_resolution um, projection grid is at 100 um atlas_to_grid = 100 / atlas_resolution # Extract projection data for each chunk projection_matrix = [None] * number_of_chunks for i_chunk in range(number_of_chunks): x_edges = projection_view_bins[i_chunk][0] y_edges = projection_view_bins[i_chunk][1] if 0 in x_edges or 0 in y_edges: if n_groups > 1: projection_matrix[i_chunk] = np.zeros((number_of_pixels + 1, number_of_pixels + 1, n_groups)) else: projection_matrix[i_chunk] = np.zeros((number_of_pixels + 1, number_of_pixels + 1)) continue this_diff = np.mean(np.diff(chunks_region)) if plane == 'coronal': ap_start = max(0, int(round((chunks_region[i_chunk] - this_diff) / atlas_to_grid))) ap_end = min(collapsed.shape[0] - 1, int(round((chunks_region[i_chunk] + this_diff) / atlas_to_grid))) y_idx = np.clip((y_edges / atlas_to_grid).astype(int), 0, collapsed.shape[1] - 1) x_idx = np.clip((x_edges / atlas_to_grid).astype(int), 0, collapsed.shape[2] - 1 if collapsed.ndim >= 3 else 0) if collapsed.ndim == 4: data_slice = collapsed[ap_start:ap_end + 1][:, y_idx][:, :, x_idx, :] mean_data = np.nanmean(data_slice, axis=0) # Average over AP projtemp = mean_data.transpose(1, 0, 2) # DV x ML x groups -> ML x DV x groups else: data_slice = collapsed[ap_start:ap_end + 1][:, y_idx][:, :, x_idx] mean_data = np.nanmean(data_slice, axis=0) projtemp = mean_data.T else: ml_start = max(0, int(round((chunks_region[i_chunk] - this_diff) / atlas_to_grid))) ml_end = min(collapsed.shape[2] - 1 if collapsed.ndim >= 3 else 0, int(round((chunks_region[i_chunk] + this_diff) / atlas_to_grid))) x_idx = np.clip((x_edges / atlas_to_grid).astype(int), 0, collapsed.shape[0] - 1) y_idx = np.clip((y_edges / atlas_to_grid).astype(int), 0, collapsed.shape[1] - 1) if collapsed.ndim == 4: data_slice = collapsed[x_idx][:, y_idx][:, :, ml_start:ml_end + 1, :] mean_data = np.nanmean(data_slice, axis=2) projtemp = mean_data.transpose(1, 0, 2) else: data_slice = collapsed[x_idx][:, y_idx][:, :, ml_start:ml_end + 1] mean_data = np.nanmean(data_slice, axis=2) projtemp = mean_data.T projection_matrix[i_chunk] = projtemp # Group-wise normalization if normalize_by_group and projection_matrix[0].ndim == 3 and projection_matrix[0].shape[2] > 1: n_g = projection_matrix[0].shape[2] group_max = np.zeros(n_g) for i_g in range(n_g): for i_chunk in range(number_of_chunks): group_max[i_g] = max(group_max[i_g], np.nanmax(projection_matrix[i_chunk][:, :, i_g])) for i_chunk in range(number_of_chunks): for i_g in range(n_g): if group_max[i_g] > 0: projection_matrix[i_chunk][:, :, i_g] /= group_max[i_g] # Compute global color limits from data global_vmax = 0 for i_chunk in range(number_of_chunks): vals = projection_matrix[i_chunk] if vals is not None: global_vmax = max(global_vmax, np.nanmax(vals)) if global_vmax == 0: global_vmax = 1 # Plot fig, axes = plt.subplots(n_rows, number_of_chunks, figsize=(3 * number_of_chunks, 3 * n_rows), squeeze=False) fig.patch.set_facecolor('white') fig.canvas.manager.set_window_title('Fluorescence intensity') slice_aras = np.zeros(number_of_chunks) for i_chunk in range(number_of_chunks): # Get boundary for region masking chunk_start = int(round(chunks_region[i_chunk])) chunk_end = int(round(chunks_region[i_chunk + 1])) if plane == 'coronal': region_area = np.isin(av[chunk_start:chunk_end + 1, :, :half_ml], curr_plot_structure_idx) ml_loc, ap_loc, dv_loc = np.where(region_area.transpose(2, 0, 1)) else: region_area = np.isin(av[:, :, chunk_start:chunk_end + 1], curr_plot_structure_idx) ml_loc, ap_loc, dv_loc = np.where(region_area.transpose(2, 0, 1)) x_edges = projection_view_bins[i_chunk][0] y_edges = projection_view_bins[i_chunk][1] # Build in-polygon mask if plane == 'coronal': bnd_x, bnd_y = ml_loc, dv_loc else: bnd_x, bnd_y = ap_loc, dv_loc is_in = _build_region_mask(x_edges, y_edges, bnd_x, bnd_y) for i_rg in range(n_rows): ax = axes[i_rg, i_chunk] if projection_matrix[i_chunk].ndim == 3 and projection_matrix[i_chunk].shape[2] > i_rg: avg_data = projection_matrix[i_chunk][:, :, i_rg] else: avg_data = projection_matrix[i_chunk] if projection_matrix[i_chunk].ndim == 2 else projection_matrix[i_chunk][:, :, 0] # Mask outside region masked_data = avg_data.copy() masked_data[~is_in] = np.nan # Apply Gaussian smoothing if requested if smoothing and smoothing > 0: # Replace NaN with 0 for filtering, then restore NaN mask temp = np.nan_to_num(masked_data, nan=0.0) temp = gaussian_filter(temp, sigma=smoothing) temp[np.isnan(masked_data)] = np.nan masked_data = temp ax.imshow(masked_data.T, origin='upper' if plane == 'coronal' else 'lower', extent=[x_edges[0], x_edges[-1], y_edges[-1], y_edges[0]] if plane == 'coronal' else [x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]], cmap='gray_r', vmin=0, vmax=global_vmax, aspect='equal') ax.set_facecolor('0.5') # Plot boundary if len(bnd_x) >= 3: from scipy.spatial import ConvexHull pts = np.column_stack([bnd_x, bnd_y]) try: hull = ConvexHull(pts) hull_pts = pts[np.append(hull.vertices, hull.vertices[0])] ax.plot(hull_pts[:, 0], hull_pts[:, 1], color=plot_structure_color, linewidth=3.5) except Exception: pass ax.set_xticks([]) ax.set_yticks([]) ax.axis('off') # 1 mm scale bar, placed just *below* the bottom-right panel so it never # overlaps the data. Axes are in atlas voxels (atlas_resolution µm each) # with aspect='equal', so 1 mm = 1000/res voxels. if i_rg == n_rows - 1 and i_chunk == number_of_chunks - 1: bar_units = 1000.0 / atlas_resolution scalebar = AnchoredSizeBar( ax.transData, bar_units, '1 mm', loc='upper right', bbox_to_anchor=(1.0, -0.02), bbox_transform=ax.transAxes, pad=0.1, borderpad=0.0, sep=4, color='black', frameon=False, size_vertical=max(bar_units * 0.04, 0.5), fontproperties=fm.FontProperties(size=14)) ax.add_artist(scalebar) # ARA slice level if custom_slices: this_slice_ara = custom_slices[i_chunk] else: this_slice_ara = int(round(np.nanmean(chunks_region[i_chunk:i_chunk + 2]) / atlas_to_grid)) slice_aras[i_chunk] = this_slice_ara if i_rg == 0: if i_chunk == 0: prefix = 'ARA level (cor.): ' if plane == 'coronal' else 'ARA level (sag.): ' ax.set_title(f'{prefix}{this_slice_ara}', fontsize=16) else: ax.set_title(str(this_slice_ara), fontsize=16) if i_chunk == 0 and input_regions and i_rg < len(region_groups_cell): regions_in_group = region_groups_cell[i_rg] group_names = [input_regions[r] for r in regions_in_group] label = '+'.join(group_names) ax.text(-0.15, 0.5, label, transform=ax.transAxes, fontweight='bold', fontsize=18, ha='right', va='center', rotation=90) elif i_chunk == 0 and n_groups > 1: # Per-row sub-label: the exact injection-site coordinate of this group # (e.g. "AP 7560 µm"). The shared "Injection location (CCF)" header is # added once after the loop. Falls back to a generic group index. label = f'Group {i_rg + 1}' if experiment_region_info: centers = experiment_region_info.get('group_centers') axis = experiment_region_info.get('grouping_axis') if (centers is not None and axis and i_rg < len(centers) and np.isfinite(centers[i_rg])): label = f'{axis} {centers[i_rg]:.0f} µm' ax.text(-0.15, 0.5, label, transform=ax.transAxes, fontweight='bold', fontsize=13, ha='right', va='center', rotation=90) plt.tight_layout() # Single shared header for coordinate-grouped rows (AP/ML/DV); the per-row # sub-labels above give the actual coordinate of each level. if (not input_regions and n_groups > 1 and experiment_region_info and experiment_region_info.get('group_centers') is not None and experiment_region_info.get('grouping_axis')): fig.text(0.015, 0.5, 'Injection location (CCF)', rotation=90, va='center', ha='center', fontweight='bold', fontsize=16) plt.show(block=False) # Build return arrays if n_groups <= 1: proj_array = np.stack([projection_matrix[s] if projection_matrix[s].ndim == 2 else projection_matrix[s][:, :, 0] for s in range(number_of_chunks)], axis=-1) else: proj_array = np.stack(projection_matrix, axis=-1) proj_coords = [] for s in range(number_of_chunks): coords = list(projection_view_bins[s]) + [slice_aras[s] * 10] proj_coords.append(coords) return proj_array, proj_coords
def _build_region_mask(x_edges, y_edges, bnd_x, bnd_y): """Build boolean mask of pixels inside the region boundary.""" nx = len(x_edges) ny = len(y_edges) is_in = np.zeros((nx, ny), dtype=bool) if len(bnd_x) < 3: return is_in from scipy.spatial import ConvexHull pts = np.column_stack([bnd_x, bnd_y]) try: hull = ConvexHull(pts) hull_pts = pts[hull.vertices] path = Path(hull_pts) for ix in range(nx): for iy in range(ny): is_in[ix, iy] = path.contains_point([x_edges[ix], y_edges[iy]]) except Exception: is_in[:] = True return is_in