Skip to content

Commit

Permalink
removal of dangerous assumption of maintenance of parallel array-like…
Browse files Browse the repository at this point in the history
…s of equal length; force the caller to pair up names and mask images
  • Loading branch information
vreuter committed Feb 19, 2025
1 parent 2b4b2f6 commit 7d0a14c
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions bin/cli/assign_spots_to_nucs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

NUC_LABEL_COL = "nucleusNumber"

FieldOfViewName: TypeAlias = str


def _determine_labels(
rois: pd.DataFrame,
Expand Down Expand Up @@ -111,8 +113,7 @@ def spot_in_nuc(row: Union[pd.Series, dict], nuc_label_img: np.ndarray) -> int:
def add_nucleus_labels(
*,
rois_table: pd.DataFrame,
mask_images: Sequence[da.Array],
fov_names: Iterable[str],
mask_images: list[tuple[FieldOfViewName, da.Array]],
nuclei_drift_file: Path,
spots_drift_file: Path,
) -> pd.DataFrame:
Expand All @@ -132,7 +133,7 @@ def query_table_for_fov(table: pd.DataFrame) -> Callable[[str], pd.DataFrame]:

subtables: list[pd.DataFrame] = []

for i, pos in tqdm.tqdm(enumerate(fov_names)):
for pos, nuc_mask_image in tqdm.tqdm(mask_images):
rois = get_rois(pos)
if len(rois) == 0:
logging.warning(f"No ROIs for FOV: {pos}")
Expand All @@ -144,8 +145,8 @@ def query_table_for_fov(table: pd.DataFrame) -> Callable[[str], pd.DataFrame]:
filter_kwargs = {"nuc_drifts": nuc_drifts, "spot_drifts": spot_drifts}
# TODO: this array indexing is sensitive to whether the mask and class images have the dummy time and channel dimensions or not.
# See: https://github.com/gerlichlab/looptrace/issues/247
logging.info(f"Assigning nuclei labels for spots from FOV: {pos}")
rois = _determine_labels(rois, nuc_label_img=mask_images[i].compute(), new_col=NUC_LABEL_COL, **filter_kwargs)
logging.info("Assigning nuclei labels for spots from FOV: %s", pos)
rois = _determine_labels(rois, nuc_label_img=nuc_mask_image, new_col=NUC_LABEL_COL, **filter_kwargs)
subtables.append(rois.copy())

return pd.concat(subtables).sort_values([FIELD_OF_VIEW_COLUMN, "timepoint"])
Expand All @@ -154,10 +155,10 @@ def query_table_for_fov(table: pd.DataFrame) -> Callable[[str], pd.DataFrame]:
def run_labeling(*, rois: pd.DataFrame, image_handler: ImageHandler, nuc_detector: Optional[NucDetector] = None) -> pd.DataFrame:
if nuc_detector is None:
nuc_detector = NucDetector(image_handler)
fov_names: Iterable[str] = image_handler.image_lists[image_handler.spot_input_name]
return add_nucleus_labels(
rois_table=rois,
mask_images=nuc_detector.mask_images,
fov_names=image_handler.image_lists[image_handler.spot_input_name],
mask_images=[(pos, nuc_detector.mask_images[i]) for i, pos in enumerate(fov_names)],
nuclei_drift_file=nuc_detector.drift_correction_file__coarse,
spots_drift_file=image_handler.drift_correction_file__coarse,
)
Expand Down

0 comments on commit 7d0a14c

Please sign in to comment.