Source code for bsv.plot_upstream_projectome

import os
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

from .atlas_utils import load_atlas, find_structure_indices, load_projection_info
from .fetch_connectivity_images import fetch_connectivity_images, PROJECTION_GRID_SIZE
from .fetch_connectivity_summary import fetch_connectivity_summary, load_injection_summary


def _match_source_region(abbrev, source_regions):
    """Return the entry of *source_regions* that *abbrev* belongs to, or ``None``.

    An exact match always wins; otherwise the first region (in the caller's
    order) that is a prefix of *abbrev* is used, so ``'ACAd'`` maps to the
    parent source region ``'ACA'``.
    """
    abbrev = str(abbrev)
    if abbrev in source_regions:
        return abbrev
    for src in source_regions:
        if abbrev.startswith(src):
            return src
    return None


def _parse_injection_voxels(raw_coords, atlas_resolution):
    """Parse an ``'[ap, dv, ml]'`` µm coordinate string into atlas voxel units.

    Returns ``(-1.0, -1.0, -1.0)`` when the value cannot be parsed; negative
    coordinates are later treated as "no injection site to plot".
    """
    try:
        coords = [float(x) for x in str(raw_coords).strip('[]').split(',')]
        ap_um, dv_um, ml_um = coords[:3]
    except (TypeError, ValueError):
        return -1.0, -1.0, -1.0
    return (ap_um / atlas_resolution,
            dv_um / atlas_resolution,
            ml_um / atlas_resolution)


def _parse_injection_volume(raw_volume, default=0.1):
    """Return *raw_volume* as a finite float, or *default* if missing/invalid."""
    try:
        volume = float(raw_volume)
    except (TypeError, ValueError):
        return default
    # A NaN volume would otherwise poison the min/max dot-size scaling.
    return volume if np.isfinite(volume) else default


def plot_upstream_projectome(experiment_ids, source_regions, target_region,
                             save_location, allen_atlas_path,
                             atlas_resolution=10, atlas_type='allen',
                             slice_thickness=2, density_percentile=99,
                             static_ap=None, save_path=None,
                             save_gif=None, gif_ap_step=2, gif_fps=8,
                             gif_width=720):
    """Interactive coronal slice viewer of upstream projections into a target region.

    Displays a Jupyter widget with a coronal AP slider showing:

    - Atlas structure boundaries rendered at full atlas resolution
    - Per-source-region coloured heatmaps of projection density in the target
      region voxels (intensity ∝ projection density, colour = source region)
    - Injection site scatter dots sized by injection volume

    Parameters
    ----------
    experiment_ids : list of int
        Experiment IDs, as returned by :func:`find_connectivity_experiments`
        with *target_regions* set.
    source_regions : list of str
        Source region acronyms to display (subset of upstream regions).
    target_region : str
        Target brain region acronym (e.g. ``'CP'``).
    save_location : str
        Directory where ``density.raw`` and injection summary files are cached.
    allen_atlas_path : str
        Path to the Allen CCF atlas directory.
    atlas_resolution : int, optional
        Atlas resolution in micrometres (default 10).
    atlas_type : str, optional
        Atlas type (default ``'allen'``).
    slice_thickness : int, optional
        Number of 100 µm grid slices averaged either side of the displayed AP
        index when computing the projection heatmap (default 2).
    density_percentile : float, optional
        Percentile of non-zero target-voxel density values used as the colour
        saturation ceiling (default 99). Lower values make dim signals
        brighter; raise toward 100 to avoid saturation on bright spots.
    static_ap : int or None, optional
        If given, render this single AP slice (100 µm grid index) as a static
        figure instead of launching the interactive widget — used to produce
        documentation figures headlessly.
    save_path : str or None, optional
        When *static_ap* is set, save the rendered figure to this path.
    save_gif : str or None, optional
        If given, render an animated GIF scrolling through the AP slices that
        contain the target region and write it to this path (requires
        Pillow). Skips the interactive widget and takes precedence over
        *static_ap*.
    gif_ap_step : int, optional
        Step between AP slices when building the GIF (default 2).
    gif_fps : int, optional
        Frames per second of the GIF (default 8).
    gif_width : int, optional
        Width in pixels to downscale GIF frames to (default 720).

    Raises
    ------
    ValueError
        If *atlas_resolution* is coarser than the 100 µm projection grid.
    ImportError
        If *save_gif* is given but Pillow is not installed.
    """
    source_regions = list(source_regions)
    AP, DV, ML = PROJECTION_GRID_SIZE

    # Number of full-res atlas voxels per 100 µm projection grid voxel
    factor = 100 // atlas_resolution
    if factor < 1:
        raise ValueError(
            "atlas_resolution must be at most 100 µm (the projection grid "
            f"resolution); got {atlas_resolution}")

    # ------------------------------------------------------------------ atlas
    print("Loading atlas...")
    av, st = load_atlas(allen_atlas_path, atlas_type, atlas_resolution)

    # Full-resolution atlas dimensions (DV x ML for display)
    DV_full = av.shape[1]
    ML_full = av.shape[2]

    target_idx = find_structure_indices(st, target_region)

    # Target mask at 100 µm grid resolution (for density computation)
    av_grid = av[::factor, ::factor, ::factor][:AP, :DV, :ML]
    target_mask = np.isin(av_grid, target_idx)  # (AP, DV, ML) bool
    del av_grid  # free memory — only needed for target_mask
    target_present = bool(target_mask.any())

    # ---------------------------------------------------------- experiment map
    projection_info = load_projection_info()
    exp_to_region = {}
    exp_rows = {}  # cache the metadata row so it is looked up only once
    for exp_id in experiment_ids:
        rows = projection_info[projection_info['id'] == exp_id]
        if len(rows) == 0:
            continue
        row = rows.iloc[0]
        region = _match_source_region(row['structure_abbrev'], source_regions)
        if region is not None:
            exp_to_region[exp_id] = region
            exp_rows[exp_id] = row
    valid_ids = [e for e in experiment_ids if e in exp_to_region]

    # ------------------------------------------------------- colours (tab10)
    # Index tab10's discrete palette directly (blue, orange, green, ...) so the
    # source regions get maximally-distinct hues — this matters where two
    # regions overlap the same target voxel and their colours alpha-blend.
    # ``plt.get_cmap`` is used because ``plt.cm.get_cmap`` was removed in
    # Matplotlib 3.9.
    cmap = plt.get_cmap('tab10')
    region_colors = {reg: np.array(cmap(i % 10)[:3])
                     for i, reg in enumerate(source_regions)}

    # ----------------------------------------- load density volumes & metadata
    print(f"Loading {len(valid_ids)} experiment volumes...")
    densities = {}      # exp_id → (AP, DV, ML) float32
    inj_ap_coord = {}   # in atlas voxel units (atlas_resolution µm each)
    inj_dv_coord = {}
    inj_ml_coord = {}
    inj_vol = {}

    for exp_id in valid_ids:
        exp_dir = os.path.join(save_location, str(exp_id))
        raw_path = os.path.join(exp_dir, 'density.raw')
        if not os.path.exists(raw_path):
            fetch_connectivity_images(exp_id, exp_dir)
        # The download may fail; such experiments simply contribute no density.
        if os.path.exists(raw_path):
            densities[exp_id] = np.fromfile(raw_path, dtype='<f4').reshape(
                PROJECTION_GRID_SIZE, order='F')

        row = exp_rows[exp_id]
        # CSV coordinates are in µm; convert to atlas voxel indices
        (inj_ap_coord[exp_id],
         inj_dv_coord[exp_id],
         inj_ml_coord[exp_id]) = _parse_injection_voxels(
            row.get('injection_coordinates'), atlas_resolution)
        inj_vol[exp_id] = _parse_injection_volume(row.get('injection_volume'))

    # -------------------------------- pre-compute mean density per source region
    region_density = {}
    for reg in source_regions:
        reg_ids = [e for e in valid_ids
                   if exp_to_region.get(e) == reg and e in densities]
        if reg_ids:
            stack = np.stack([densities[e] for e in reg_ids], axis=0)
            region_density[reg] = np.mean(stack, axis=0)
        else:
            region_density[reg] = np.zeros(PROJECTION_GRID_SIZE, dtype=np.float32)

    # Normalisation ceiling: the ``density_percentile``-th percentile of
    # non-zero target-voxel densities across all source regions, so dim
    # signals are still clearly visible.
    if source_regions and target_present:
        all_target_vals = np.concatenate(
            [region_density[reg][target_mask] for reg in source_regions])
    else:
        # np.concatenate rejects an empty list, so short-circuit here.
        all_target_vals = np.array([])
    nonzero = all_target_vals[all_target_vals > 0]
    global_max = (float(np.percentile(nonzero, density_percentile))
                  if nonzero.size else 1.0)
    if global_max == 0:
        global_max = 1.0

    # ---------------------------------------------- dot-size scaling (20-200)
    vols_arr = np.array([inj_vol.get(e, 0.1) for e in valid_ids], dtype=float)
    if vols_arr.size:
        v_min = vols_arr.min()
        v_span = vols_arr.max() - v_min
        if v_span == 0:
            v_span = 1e-9  # all volumes equal → every dot gets the minimum size
        dot_sizes = 20 + 180 * (vols_arr - v_min) / v_span
    else:
        dot_sizes = np.zeros(0)

    # ------------------------------- full-resolution atlas background helper
    def _atlas_bg(ap_idx):
        """(DV_full, ML_full) grayscale at full atlas resolution."""
        # Use the middle atlas slice of the 100 µm window
        ap_atlas = min(ap_idx * factor + factor // 2, av.shape[0] - 1)
        sl = av[ap_atlas]  # (DV_full, ML_full) structure indices
        brain = sl > 0
        target = np.isin(sl, target_idx)

        # A pixel is a boundary if it differs from a horizontal or vertical
        # neighbour; both pixels of each differing pair are marked.
        h = sl[:, :-1] != sl[:, 1:]
        v = sl[:-1, :] != sl[1:, :]
        bnd = np.zeros(sl.shape, dtype=bool)
        bnd[:, :-1] |= h
        bnd[:, 1:] |= h
        bnd[:-1, :] |= v
        bnd[1:, :] |= v
        bnd &= brain

        # White background + brain, light-gray target region, black boundaries
        img = np.ones(sl.shape, dtype=float)
        img[target] = 0.85
        img[bnd] = 0.0
        return img

    # ----------------------------------------------------------- draw function
    def _draw(ap_idx, show=True):
        """Render one coronal slice and return its figure.

        ``show=False`` skips ``plt.show()`` so the caller can save or
        rasterise the figure before the backend gets a chance to close it.
        """
        fig, ax = plt.subplots(figsize=(9, 5))
        fig.patch.set_facecolor('white')
        ax.set_facecolor('white')

        # Full-resolution atlas background (DV_full x ML_full)
        ax.imshow(_atlas_bg(ap_idx), cmap='gray', vmin=0, vmax=1,
                  origin='upper', aspect='equal', interpolation='nearest',
                  extent=[0, ML_full, DV_full, 0])

        # Coloured heatmaps — density at 100 µm, upsampled to full atlas res
        ap_lo = max(0, ap_idx - slice_thickness)
        ap_hi = min(AP - 1, ap_idx + slice_thickness)
        tm = target_mask[ap_idx]  # (DV, ML) at 100 µm
        tm_any = bool(tm.any())

        def _mean_target_density(reg):
            if not tm_any:
                return 0.0
            return float(np.mean(region_density[reg][ap_idx][tm]))

        # Draw weakest first so the strongest input ends up on top.
        for reg in sorted(source_regions, key=_mean_target_density):
            density_slice = np.mean(region_density[reg][ap_lo:ap_hi + 1], axis=0)
            alpha = np.clip(density_slice / global_max, 0, 1)
            alpha[~tm] = 0.0

            # Upsample to full atlas resolution then trim to exact atlas shape
            alpha_up = np.repeat(np.repeat(alpha, factor, axis=0), factor, axis=1)
            alpha_up = alpha_up[:DV_full, :ML_full]

            rgba = np.zeros((DV_full, ML_full, 4), dtype=float)
            rgba[:, :, :3] = region_colors[reg]
            # Slice-assign so an atlas slightly larger than grid*factor
            # leaves its margin transparent instead of raising.
            rgba[:alpha_up.shape[0], :alpha_up.shape[1], 3] = alpha_up
            ax.imshow(rgba, origin='upper', aspect='equal',
                      interpolation='nearest',
                      extent=[0, ML_full, DV_full, 0])

        # Injection site scatter (coords in atlas voxel units)
        ap_center = ap_idx * factor + factor / 2.0
        ap_window = slice_thickness * factor + factor  # half-window in atlas voxels
        for i, exp_id in enumerate(valid_ids):
            if abs(inj_ap_coord.get(exp_id, -1) - ap_center) > ap_window:
                continue
            reg = exp_to_region.get(exp_id)
            if reg is None:
                continue
            ml_val = inj_ml_coord.get(exp_id, -1)
            dv_val = inj_dv_coord.get(exp_id, -1)
            if ml_val < 0 or dv_val < 0:
                continue  # coordinates were missing or unparseable
            ax.scatter(ml_val, dv_val, s=dot_sizes[i], c=[region_colors[reg]],
                       edgecolors='black', linewidths=0.5, zorder=10, alpha=0.9)

        # Legend (input / source regions, colour-coded)
        handles = [
            ax.scatter([], [], s=80, c=[region_colors[r]],
                       edgecolors='black', linewidths=0.5, label=r)
            for r in source_regions
        ]
        if handles:
            leg = ax.legend(handles=handles, loc='upper right',
                            title='Input regions', framealpha=0.6,
                            labelcolor='black', fontsize=9)
            leg.get_frame().set_facecolor('white')
            leg.get_title().set_color('black')

        ax.set_xlim(0, ML_full)
        ax.set_ylim(DV_full, 0)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(
            f'Target region: {target_region}'
            f' AP = {ap_idx * 100} µm (CCF space)',
            color='black', fontsize=11, pad=6)
        plt.tight_layout()
        if show:
            plt.show()
        return fig

    # ------------------------------------- animated GIF (scroll through AP)
    if save_gif is not None:
        try:
            from PIL import Image
        except ImportError as exc:
            raise ImportError(
                "Saving a GIF requires Pillow (pip install pillow).") from exc

        present = np.where(target_mask.any(axis=(1, 2)))[0]
        if present.size:
            gif_lo, gif_hi = int(present.min()), int(present.max())
        else:
            gif_lo, gif_hi = 0, AP - 1
        ap_indices = range(gif_lo, gif_hi + 1, max(1, int(gif_ap_step)))

        frames = []
        for ap in ap_indices:
            fig = _draw(ap, show=False)
            try:
                fig.canvas.draw()
                # buffer_rgba already carries the true pixel shape (h, w, 4)
                pixels = np.asarray(fig.canvas.buffer_rgba())
                img = Image.fromarray(pixels[:, :, :3].copy())
            finally:
                plt.close(fig)
            if gif_width and img.width > gif_width:
                img = img.resize(
                    (int(gif_width), round(img.height * gif_width / img.width)))
            frames.append(img.convert('P', palette=Image.ADAPTIVE, colors=256))

        if frames:
            frames[0].save(save_gif, save_all=True, append_images=frames[1:],
                           duration=int(1000 / max(gif_fps, 1)), loop=0,
                           optimize=True)
            print(f'Saved GIF: {save_gif} ({len(frames)} frames)')
        return

    # ------------------------------------- static render (docs / headless use)
    if static_ap is not None:
        fig = _draw(int(static_ap), show=False)
        if save_path:
            # Save through the figure handle *before* showing: some backends
            # close the figure on show, which would leave nothing to save.
            fig.savefig(save_path, facecolor='white', bbox_inches='tight')
            print(f'Saved: {save_path}')
        plt.show()
        return

    # --------------------------------------------------------- widget assembly
    out = widgets.Output()
    slider = widgets.IntSlider(
        min=0, max=AP - 1, value=AP // 2,
        description='AP (100 µm)',
        continuous_update=False,
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='520px'),
    )

    def _on_change(change):
        with out:
            clear_output(wait=True)
            _draw(slider.value)

    slider.observe(_on_change, names='value')
    display(widgets.VBox([slider, out]))
    with out:
        _draw(slider.value)