discarding beads in nuclei; close #400; related to #403
vreuter committed Feb 19, 2025
1 parent f9b382a commit b09a240
Showing 13 changed files with 238 additions and 83 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project will adhere to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [v0.13.0] - Unreleased

### Added
* Discard detected beads that fall within nuclear regions. See [Issue 400](https://github.com/gerlichlab/looptrace/issues/400) and [Issue 403](https://github.com/gerlichlab/looptrace/issues/403).
This is designed to prevent accidentally using a FISH spot as a bead, which is especially likely when doing single-channel tracing, in which beads and FISH signal are captured at the same wavelength.
* Add a proximity filter between beads and FISH spots: discard any putative FISH spot which lies too close to a bead. This again guards against confusing the two signals, especially during single-channel tracing; a minimal illustrative sketch follows this list.
See [Issue 400](https://github.com/gerlichlab/looptrace/issues/400) and [Issue 403](https://github.com/gerlichlab/looptrace/issues/403).
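
The gist of both new filters can be summarised with a short sketch. This is a hypothetical, simplified illustration rather than the code added in this commit: it assumes ROI tables with `zc`/`yc`/`xc` centroid columns, a 3D nuclear label image in which 0 means "outside any nucleus", and a distance threshold in (isotropic) pixels. Out-of-bounds handling and anisotropic scaling, which real data would need, are omitted here.

```python
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree

def discard_beads_in_nuclei(beads: pd.DataFrame, nuc_labels: np.ndarray) -> pd.DataFrame:
    """Keep only beads whose rounded centroid falls on a zero (non-nuclear) label."""
    zyx = beads[["zc", "yc", "xc"]].round().astype(int).to_numpy()
    in_nucleus = nuc_labels[zyx[:, 0], zyx[:, 1], zyx[:, 2]] > 0
    return beads.loc[~in_nucleus]

def discard_spots_near_beads(spots: pd.DataFrame, beads: pd.DataFrame, min_dist: float) -> pd.DataFrame:
    """Drop any FISH spot whose nearest bead lies closer than min_dist pixels."""
    tree = cKDTree(beads[["zc", "yc", "xc"]].to_numpy())
    dist, _ = tree.query(spots[["zc", "yc", "xc"]].to_numpy(), k=1)
    return spots.loc[dist >= min_dist]
```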

## [v0.12.1] - 2025-02-19

### Fixed
2 changes: 1 addition & 1 deletion Dockerfile
@@ -15,7 +15,7 @@ RUN apt-get update -y && \
RUN mkdir /looptrace
WORKDIR /looptrace
COPY . /looptrace
RUN mv /looptrace/target/scala-3.5.2/looptrace-assembly-0.12.1.jar /looptrace/looptrace
RUN mv /looptrace/target/scala-3.5.2/looptrace-assembly-0.13.0-SNAPSHOT.jar /looptrace/looptrace

# Install new-ish R and necessary packages.
RUN echo "Installing R..." && \
7 changes: 2 additions & 5 deletions bin/cli/analyse_bead_discard_reasons.py
@@ -13,7 +13,7 @@
from gertils import ExtantFolder

from looptrace.ImageHandler import BeadRoisFilenameSpecification
from looptrace.bead_roi_generation import FAIL_CODE_COLUMN_NAME, INTENSITY_COLUMN_NAME, BeadRoiParameters
from looptrace.bead_roi_generation import INDEX_COLUMN_NAME, FAIL_CODE_COLUMN_NAME, INTENSITY_COLUMN_NAME, BeadRoiParameters


HISTOGRAM_FILENAME = "bead_rois_discard_analysis.json"
@@ -69,10 +69,7 @@ def workflow(*, root_folder: Path, output_folder: Path, reference_timepoint: Opt
assert spec.purpose is None, f"Non-null purpose for bead ROIs spec: {spec.purpose}"

logging.debug("Parsing bead ROIs file %s (%s)...", f, str(spec))
# Here we leave the index_col=0 in place since these aren't like the FISH spot ROI indices.
# This is because these are generated by a simple .to_csv() call on a DataFrame.
# See: bead_roi_generation.generate_all_bead_rois_from_getter
data = pd.read_csv(f, index_col=0)
data = pd.read_csv(f, index_col=INDEX_COLUMN_NAME)

total_roi_count += data.shape[0]
data[FAIL_CODE_COLUMN_NAME] = data[FAIL_CODE_COLUMN_NAME].fillna("")
164 changes: 121 additions & 43 deletions bin/cli/assign_spots_to_nucs.py
@@ -1,12 +1,12 @@
"""Filter detected spots for overlap with detected nuclei."""

import argparse
from enum import Enum
import logging
from pathlib import Path
from typing import *

import dask.array as da
from expression import Option
import numpy as np
import pandas as pd
import tqdm
@@ -16,37 +16,54 @@
from looptrace import FIELD_OF_VIEW_COLUMN
from looptrace.ImageHandler import ImageHandler
from looptrace.NucDetector import NucDetector
from looptrace.image_processing_functions import (
X_CENTER_COLNAME,
Y_CENTER_COLNAME,
Z_CENTER_COLNAME,
)

__author__ = "Kai Sandvold Beckwith"
__credits__ = ["Kai Sandvold Beckwith", "Vince Reuter"]


NUC_LABEL_COL = "nucleusNumber"

FieldOfViewName: TypeAlias = str

def _filter_rois_in_nucs(

def _determine_labels(
*,
field_of_view: FieldOfViewName,
rois: pd.DataFrame,
nuc_label_img: np.ndarray,
new_col: str,
*,
nuc_drifts: pd.DataFrame,
nuc_drift: pd.DataFrame,
spot_drifts: pd.DataFrame,
) -> pd.DataFrame:
timepoint: Option[int],
) -> pd.DataFrame:
"""
Check whether a spot lies inside a segmented nucleus.
Arguments
---------
field_of_view: FieldOfViewName
Name of the field of view from which the passed ROIs table comes
rois : pd.DataFrame
ROI table to check, usually FISH spots (from regional barcodes)
nuc_label_img : np.ndarray
2D/3D label images, where 0 is outside any nuclear region and >0 inside a nuclear region
new_col : str
The name of the new column in the ROI table
nuc_drifts : pd.DataFrame
Data table with information on drift correction for nuclei
nuc_drift : pd.DataFrame
Data table with information on drift correction for nuclei, for the given FOV;
since we only image nuclei at one time point, this should be a table with a
single row
spot_drifts : pd.DataFrame
Data table with information on drift correction for FISH spots
timepoint : Option[int]
Optionally, a single timepoint from which the given ROIs table comes;
this will be the case (single timepoint) when it's a bead ROIs subtable
which is being passed to this function
Returns
-------
Expand All @@ -56,14 +73,14 @@ def _filter_rois_in_nucs(
# We're only interested here in adding a nucleus (or other region) ID column, so keep all other data.
new_rois = rois.copy()

def spot_in_nuc(row: Union[pd.Series, dict], nuc_label_img: np.ndarray):
base_idx = (int(row["yc"]), int(row["xc"]))
def spot_in_nuc(row: Union[pd.Series, dict], nuc_label_img: np.ndarray) -> int:
base_idx = (int(row[Y_CENTER_COLNAME]), int(row[X_CENTER_COLNAME]))
num_dim: int = len(nuc_label_img.shape)
if num_dim == 2:
idx = base_idx
else:
try:
idx_px_z = 0 if nuc_label_img.shape[-3] == 1 else int(row["zc"]) # Flat in z dimension?
idx_px_z = 0 if nuc_label_img.shape[-3] == 1 else int(row[Z_CENTER_COLNAME]) # Flat in z dimension?
except IndexError as e:
logging.error(f"IndexError ({e}) trying to get z-axis length from images with shape {nuc_label_img}")
raise
Expand All @@ -84,83 +101,143 @@ def spot_in_nuc(row: Union[pd.Series, dict], nuc_label_img: np.ndarray):
spot_label = 0
return int(spot_label)

# Check the type and content of the spots table.
if not isinstance(rois, pd.DataFrame):
raise TypeError(f"Spots table is not a data frame, but {type(rois).__name__}")
if FIELD_OF_VIEW_COLUMN in rois.columns:
logging.debug("Checking field of view column (%s) against given value: %s", FIELD_OF_VIEW_COLUMN, field_of_view)
match list(rois[FIELD_OF_VIEW_COLUMN].unique()):
case [obs_spot_fov]:
# NB: here we do NOT .removesuffix(".zarr"), because this should already have been done if this field is present for this table.
if obs_spot_fov != field_of_view:
raise ValueError(f"Given FOV is {field_of_view}, but FOV (from column {FIELD_OF_VIEW_COLUMN}) in ROIs table is different: {obs_spot_fov}")
case obs_fovs:
raise ValueError(f"Expected exactly 1 unique FOV for nucleus label assignment, but got {len(obs_fovs)} in ROIs table: {obs_fovs}")
else:
logging.debug("Field of view column (%s) is absent, so no FOV validation will be done for ROIs", FIELD_OF_VIEW_COLUMN)

# Check the type and content of the spots drifts table.
if not isinstance(spot_drifts, pd.DataFrame):
raise TypeError(f"Nuclear drift for FOV {field_of_view} is not a data frame, but {type(spot_drifts).__name__}")
match list(spot_drifts[FIELD_OF_VIEW_COLUMN].unique()):
case [raw_obs_spot_drift_fov]:
obs_spot_drift_fov: FieldOfViewName = raw_obs_spot_drift_fov.removesuffix(".zarr")
if obs_spot_drift_fov != field_of_view:
raise ValueError(f"Given FOV is {field_of_view}, but FOV (from column {FIELD_OF_VIEW_COLUMN}) in spot drifts table is different: {obs_spot_drift_fov}")
case obs_fovs:
raise ValueError(f"Expected exactly 1 unique FOV for nucleus label assignment, but got {len(obs_fovs)} in spot drifts table: {obs_fovs}")

# Check the type, shape, and content of the nuclei drift table.
if not isinstance(nuc_drift, pd.DataFrame):
raise TypeError(f"Nuclear drift for FOV {field_of_view} is not a data frame, but {type(nuc_drift).__name__}")
if nuc_drift.shape[0] != 1:
raise DimensionalityError(f"Nuclear drift for FOV {field_of_view} is not exactly 1 row, but {nuc_drift.shape[0]} rows!")
else:
obs_nuc_fov: FieldOfViewName = list(nuc_drift[FIELD_OF_VIEW_COLUMN])[0].removesuffix(".zarr")
if obs_nuc_fov != field_of_view:
raise ValueError(f"Given FOV is {field_of_view}, but FOV (from column {FIELD_OF_VIEW_COLUMN}) in nuclear drift table is different: {obs_nuc_fov}")

# Use either a single, fixed timepoint passed as an argument, or extract the timepoint from each row (ROI).
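# When `timepoint` is Some(t), `map` builds a constant function so every row gets t;
# when it is Nothing, `default_value` falls back to reading each row's "timepoint" column.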
get_roi_time: Callable[[pd.Series], int] = timepoint.map(lambda t: (lambda _: t)).default_value(lambda r: r["timepoint"])

# Remove the labels column if it already exists.
new_rois.drop(columns=[new_col], inplace=True, errors="ignore")

rois_shifted = new_rois.copy()
shifts = []
shift_column_names = ["z", "y", "x"]
center_column_names = [Z_CENTER_COLNAME, Y_CENTER_COLNAME, X_CENTER_COLNAME]
drift_column_names = ["zDriftCoarsePixels", "yDriftCoarsePixels", "xDriftCoarsePixels"]
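# For each ROI, take the difference between the nuclei image's coarse drift and the coarse
# drift at the ROI's own timepoint, then subtract that difference from the ROI centers so
# that spot and nucleus coordinates are compared in a drift-aligned frame.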

for _, row in rois_shifted.iterrows():
curr_pos_name = row[FIELD_OF_VIEW_COLUMN]
raw_nuc_drift_match = nuc_drifts[nuc_drifts[FIELD_OF_VIEW_COLUMN] == curr_pos_name]
if not isinstance(raw_nuc_drift_match, pd.DataFrame):
raise TypeError(f"Nuclear drift for FOV {curr_pos_name} is not a data frame, but {type(raw_nuc_drift_match).__name__}")
if not raw_nuc_drift_match.shape[0] == 1:
raise DimensionalityError(f"Nuclear drift for FOV {curr_pos_name} is not exactly 1 row, but {raw_nuc_drift_match.shape[0]} rows!")
drift_target = raw_nuc_drift_match[["zDriftCoarsePixels", "yDriftCoarsePixels", "xDriftCoarsePixels"]].to_numpy()
drift_roi = spot_drifts[(spot_drifts[FIELD_OF_VIEW_COLUMN] == curr_pos_name) & (spot_drifts["timepoint"] == row["timepoint"])][["zDriftCoarsePixels", "yDriftCoarsePixels", "xDriftCoarsePixels"]].to_numpy()
drift_target = nuc_drift[drift_column_names].to_numpy()
drift_roi = spot_drifts[spot_drifts["timepoint"] == get_roi_time(row)][drift_column_names].to_numpy()
shift = drift_target - drift_roi
shifts.append(shift[0])
shifts = pd.DataFrame(shifts, columns=["z", "y", "x"])
rois_shifted[["zc", "yc", "xc"]] = rois_shifted[["zc", "yc", "xc"]].to_numpy() - shifts[["z","y","x"]].to_numpy()
shifts = pd.DataFrame(shifts, columns=shift_column_names)
rois_shifted[center_column_names] = rois_shifted[center_column_names].to_numpy() - shifts[shift_column_names].to_numpy()

# Store the vector of nucleus IDs in a new column on the original ROI table.
new_rois.loc[:, new_col] = rois_shifted.apply(spot_in_nuc, nuc_label_img=nuc_label_img, axis=1)

return new_rois


def _add_nucleus_labels(
def add_nucleus_labels(
*,
rois_table: pd.DataFrame,
mask_images: Sequence[da.Array],
fov_names: Iterable[str],
mask_images: list[tuple[FieldOfViewName, da.Array]],
nuclei_drift_file: Path,
spots_drift_file: Path,
timepoint: Option[int],
) -> pd.DataFrame:

def query_table_for_fov(table: pd.DataFrame) -> Callable[[str], pd.DataFrame]:
def query_table_for_fov(table: pd.DataFrame) -> Callable[[FieldOfViewName], pd.DataFrame]:
return (lambda fov: table.query('fieldOfView == @fov'))

get_rois: Callable[[str], pd.DataFrame] = query_table_for_fov(rois_table)
get_rois: Callable[[FieldOfViewName], pd.DataFrame]
combine_subtables: Callable[[list[pd.DataFrame]], pd.DataFrame]
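# Two modes: with no fixed timepoint (e.g. FISH spot ROIs), subset the ROI table per FOV
# and re-concatenate the labelled subtables at the end; with a fixed timepoint (e.g. a bead
# ROIs subtable from a single timepoint), use the given table as-is and expect exactly one
# labelled subtable back.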
if timepoint.is_none():
get_rois = query_table_for_fov(rois_table)
combine_subtables = lambda ts: pd.concat(ts).sort_values([FIELD_OF_VIEW_COLUMN, "timepoint"])
else:
get_rois = lambda _: rois_table
def combine_subtables(ts: list[pd.DataFrame]) -> pd.DataFrame:
match ts:
case [t]:
return t
case list():
raise ValueError(f"Expected exactly one subtable but got {len(ts)}")
case _:
raise TypeError(f"Expected to be combining a list of subtables, but got {type(ts).__name__}")

logging.info("Reading drift file for nuclei: %s", nuclei_drift_file)
drift_table_nuclei = pd.read_csv(nuclei_drift_file, index_col=False)
get_nuc_drifts: Callable[[str], pd.DataFrame] = query_table_for_fov(drift_table_nuclei)
get_nuc_drift: Callable[[FieldOfViewName], pd.DataFrame] = query_table_for_fov(drift_table_nuclei)

logging.info("Reading coarse-drift file for spots: %s", spots_drift_file)
drift_table_spots = pd.read_csv(spots_drift_file, index_col=False)
get_spot_drifts: Callable[[str], pd.DataFrame] = query_table_for_fov(drift_table_spots)
get_spot_drifts: Callable[[FieldOfViewName], pd.DataFrame] = query_table_for_fov(drift_table_spots)

subtables: list[pd.DataFrame] = []

for i, pos in tqdm.tqdm(enumerate(fov_names)):
for pos, nuc_mask_image in tqdm.tqdm(mask_images):
fov_name: FieldOfViewName = pos.removesuffix(".zarr")
rois = get_rois(pos)
if len(rois) == 0:
logging.warning(f"No ROIs for FOV: {pos}")
logging.warning("No ROIs for FOV: %s", fov_name)
continue

nuc_drifts: pd.DataFrame = get_nuc_drifts(pos)
spot_drifts = get_spot_drifts(pos)

filter_kwargs = {"nuc_drifts": nuc_drifts, "spot_drifts": spot_drifts}
# TODO: this array indexing is sensitive to whether the mask and class images have the dummy time and channel dimensions or not.
# See: https://github.com/gerlichlab/looptrace/issues/247
logging.info(f"Assigning nuclei labels for spots from FOV: {pos}")
rois = _filter_rois_in_nucs(rois, nuc_label_img=mask_images[i].compute(), new_col=NUC_LABEL_COL, **filter_kwargs)
logging.info("Assigning nuclei labels for spots from FOV: %s", fov_name)
rois = _determine_labels(
field_of_view=fov_name,
rois=rois,
nuc_label_img=nuc_mask_image,
new_col=NUC_LABEL_COL,
nuc_drift=get_nuc_drift(pos),
spot_drifts=get_spot_drifts(pos),
timepoint=timepoint,
)
subtables.append(rois.copy())

return pd.concat(subtables).sort_values([FIELD_OF_VIEW_COLUMN, "timepoint"])
return combine_subtables(subtables)


def run_labeling(*, rois: pd.DataFrame, image_handler: ImageHandler, nuc_detector: Optional[NucDetector] = None) -> pd.DataFrame:
def run_labeling(
*,
rois: pd.DataFrame,
image_handler: ImageHandler,
timepoint: Option[int],
nuc_detector: Optional[NucDetector] = None,
) -> pd.DataFrame:
if nuc_detector is None:
nuc_detector = NucDetector(image_handler)
return _add_nucleus_labels(
fov_names: Iterable[str] = image_handler.image_lists[image_handler.spot_input_name]
return add_nucleus_labels(
rois_table=rois,
mask_images=nuc_detector.mask_images,
fov_names=image_handler.image_lists[image_handler.spot_input_name],
mask_images=[(pos, nuc_detector.mask_images[i]) for i, pos in enumerate(fov_names)],
nuclei_drift_file=nuc_detector.drift_correction_file__coarse,
spots_drift_file=image_handler.drift_correction_file__coarse,
timepoint=timepoint,
)


@@ -191,6 +268,7 @@ def workflow(
rois=pd.read_csv(input_file, index_col=False),
image_handler=H,
nuc_detector=N,
timepoint=Option.Nothing(),
)

if all_rois.shape[0] == 0: