Source code for bsv.plot_injections_combined

import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull
from matplotlib.path import Path

from .atlas_utils import load_atlas, find_structure_indices, get_structure_color


[docs] def plot_injections_combined(experiment_imgs, allen_atlas_path, input_regions, number_of_slices, number_of_pixels, plane, region_only, smoothing, color_limits, color, normalization_method, experiment_region_info=None, atlas_type='allen', atlas_resolution=10): """Plot injection sites for multiple regions in a combined grid. Parameters ---------- experiment_imgs : numpy.ndarray Projection density array from :func:`fetch_connectivity_data`. allen_atlas_path : str Path to the Allen CCF atlas directory. input_regions : list of str Source region acronyms to display. number_of_slices : int Number of slices per region. number_of_pixels : int Number of 2D histogram bins per axis for each slice. The physical bin size adapts to the spatial extent of the region (bin_size ≈ region_extent_voxels / number_of_pixels × atlas_resolution µm). plane : str ``'coronal'`` or ``'sagittal'``. region_only : bool Mask to the region boundary. smoothing : float Gaussian smoothing sigma in pixels. color_limits : str or list Colour scale specification. color : list or None RGB colour(s). normalization_method : str Normalization label. experiment_region_info : dict, optional Per-experiment metadata. atlas_type : str, optional Atlas type (default ``'allen'``). atlas_resolution : int, optional Atlas resolution in micrometres (10 or 20). """ av, st = load_atlas(allen_atlas_path, atlas_type, atlas_resolution) n_regions = len(input_regions) fig, axes = plt.subplots(n_regions, number_of_slices, figsize=(3 * number_of_slices, 3 * n_regions), squeeze=False) fig.patch.set_facecolor('white') fig.suptitle('Injection Sites by Region', fontsize=14, fontweight='bold') # Collapse hemispheres half = experiment_imgs.shape[2] // 2 collapsed = experiment_imgs[:, :, :half] + experiment_imgs[:, :, ::-1][:, :, :half] global_vmax = np.nanmax(collapsed) if np.nanmax(collapsed) > 0 else 1 for i_reg in range(n_regions): region = input_regions[i_reg] curr_idx = find_structure_indices(st, region) if not curr_idx: print(f'Warning: No structure found for region {region}') continue region_color = get_structure_color(st, curr_idx[0]) # Get structure limits half_ml = av.shape[2] // 2 structure_mask = np.isin(av[:, :, :half_ml], curr_idx) ap_vals, _, ml_vals = np.where(structure_mask) if plane == 'coronal': limits = [ap_vals.min(), ap_vals.max()] else: limits = [ml_vals.min(), ml_vals.max()] step = (limits[1] - limits[0]) / number_of_slices chunks = [limits[0] + j * step for j in range(number_of_slices + 1)] for i_chunk in range(number_of_slices): ax = axes[i_reg, i_chunk] chunk_start = int(round(chunks[i_chunk])) chunk_end = int(round(chunks[i_chunk + 1])) if plane == 'coronal': region_area = np.isin(av[chunk_start:chunk_end + 1, :, :half_ml], curr_idx) ml_loc, ap_loc, dv_loc = np.where(region_area.transpose(2, 0, 1)) ap_loc += chunk_start bnd_x, bnd_y = ml_loc, dv_loc else: region_area = np.isin(av[:, :, chunk_start:chunk_end + 1], curr_idx) ml_loc, ap_loc, dv_loc = np.where(region_area.transpose(2, 0, 1)) ml_loc += chunk_start bnd_x, bnd_y = ap_loc, dv_loc if len(bnd_x) < 3: ax.axis('off') continue x_range = [bnd_x.min(), bnd_x.max()] y_range = [bnd_y.min(), bnd_y.max()] if x_range[1] - x_range[0] == 0 or y_range[1] - y_range[0] == 0: ax.axis('off') continue x_edges = np.linspace(x_range[0], x_range[1], number_of_pixels + 1) y_edges = np.linspace(y_range[0], y_range[1], number_of_pixels + 1) this_diff = np.mean(np.diff(chunks)) if plane == 'coronal': ap_s = max(0, int(round((chunks[i_chunk] - this_diff) / 10))) ap_e = min(collapsed.shape[0] - 1, int(round((chunks[i_chunk] + this_diff) / 10))) y_idx = np.clip((y_edges / 10).astype(int), 0, collapsed.shape[1] - 1) x_idx = np.clip((x_edges / 10).astype(int), 0, collapsed.shape[2] - 1) data_slice = collapsed[ap_s:ap_e + 1][:, y_idx][:, :, x_idx] mean_data = np.nanmean(data_slice, axis=0) projtemp = mean_data.T else: ax.axis('off') continue # Mask outside region try: hull = ConvexHull(np.column_stack([bnd_x, bnd_y])) hull_pts = np.column_stack([bnd_x, bnd_y])[hull.vertices] path = Path(hull_pts) mask = np.zeros(projtemp.shape[:2], dtype=bool) for ix in range(len(x_edges)): for iy in range(len(y_edges)): mask[ix, iy] = path.contains_point([x_edges[ix], y_edges[iy]]) projtemp[~mask] = np.nan except Exception: pass ax.imshow(projtemp.T, origin='upper' if plane == 'coronal' else 'lower', extent=[x_edges[0], x_edges[-1], y_edges[-1], y_edges[0]], cmap='gray_r', vmin=0, vmax=global_vmax, aspect='equal') ax.set_facecolor('0.5') # Plot boundary try: hull = ConvexHull(np.column_stack([bnd_x, bnd_y])) hull_pts_plot = np.column_stack([bnd_x, bnd_y])[np.append(hull.vertices, hull.vertices[0])] ax.plot(hull_pts_plot[:, 0], hull_pts_plot[:, 1], color=region_color, linewidth=2) except Exception: pass ax.set_xticks([]) ax.set_yticks([]) ax.axis('off') if i_chunk == 0: ax.set_ylabel(region, fontweight='bold', fontsize=10, rotation=90) ax.axis('on') ax.set_xticks([]) ax.set_yticks([]) for spine in ax.spines.values(): spine.set_visible(False) if i_reg == 0: slice_ara = int(round(np.mean(chunks[i_chunk:i_chunk + 2]) / 10)) ax.set_title(f'ARA {slice_ara}', fontsize=9) plt.tight_layout() plt.show(block=False)