From c33f89b92e446a8d9ad46b0166f3c8ec26337657 Mon Sep 17 00:00:00 2001 From: Eric Peterson Date: Mon, 2 Sep 2024 22:13:05 -0700 Subject: [PATCH 1/3] Initial commit for a multiprocess fitting script --- WrapImage/nifti_wrapper.py | 217 +++++++++++------- conftest.py | 46 ++++ tests/utilities/__init__.py | 0 tests/utilities/unit_tests/__init__.py | 0 .../unit_tests/test_diffusion_utils.py | 54 +++++ tests/utilities/unit_tests/test_file_io.py | 23 ++ utilities/process/__init__.py | 0 utilities/process/diffusion_utils.py | 202 ++++++++++++++++ utilities/process/file_io.py | 59 +++++ 9 files changed, 522 insertions(+), 79 deletions(-) create mode 100644 tests/utilities/__init__.py create mode 100644 tests/utilities/unit_tests/__init__.py create mode 100644 tests/utilities/unit_tests/test_diffusion_utils.py create mode 100644 tests/utilities/unit_tests/test_file_io.py create mode 100644 utilities/process/__init__.py create mode 100644 utilities/process/diffusion_utils.py create mode 100644 utilities/process/file_io.py diff --git a/WrapImage/nifti_wrapper.py b/WrapImage/nifti_wrapper.py index 5eae67ee..404e0caa 100644 --- a/WrapImage/nifti_wrapper.py +++ b/WrapImage/nifti_wrapper.py @@ -1,64 +1,13 @@ import argparse -import json -import os -import nibabel as nib +from utilities.process.file_io import read_nifti_file, read_bval_file, read_bvec_file, save_nifti_file +from utilities.process.diffusion_utils import find_directions, find_shells, normalize_series from src.wrappers.OsipiBase import OsipiBase import numpy as np from tqdm import tqdm +from tqdm.contrib.concurrent import process_map +from functools import partial -def read_nifti_file(input_file): - """ - For reading the 4d nifti image - """ - nifti_img = nib.load(input_file) - return nifti_img.get_fdata(), nifti_img.header - -def read_json_file(json_file): - """ - For reading the json file - """ - - if not os.path.exists(json_file): - raise FileNotFoundError(f"File '{json_file}' not found.") - - with open(json_file, "r") as f: - try: - json_data = json.load(f) - except json.JSONDecodeError as e: - raise ValueError(f"Error decoding JSON in file '{json_file}': {e}") - - return json_data - -def read_bval_file(bval_file): - """ - For reading the bval file - """ - if not os.path.exists(bval_file): - raise FileNotFoundError(f"File '{bval_file}' not found.") - - bval_data = np.genfromtxt(bval_file, dtype=float) - return bval_data - -def read_bvec_file(bvec_file): - """ - For reading the bvec file - """ - if not os.path.exists(bvec_file): - raise FileNotFoundError(f"File '{bvec_file}' not found.") - - bvec_data = np.genfromtxt(bvec_file) - bvec_data = np.transpose(bvec_data) # Transpose the array - return bvec_data - -def save_nifti_file(data, output_file, affine=None, **kwargs): - """ - For saving the 3d nifti images of the output of the algorithm - """ - if affine is None: - affine = np.eye(data.ndim + 1) - output_img = nib.nifti1.Nifti1Image(data, affine , **kwargs) - nib.save(output_img, output_file) def loop_over_first_n_minus_1_dimensions(arr): """ @@ -75,6 +24,28 @@ def loop_over_first_n_minus_1_dimensions(arr): flat_view = arr[idx].flatten() yield idx, flat_view +def generate_data(data, bvals, b0_indices, groups, total_iteration): + num_directions = groups.shape[1] + data = data.reshape(total_iteration, -1) + for idx in range(total_iteration): + for dir in range(num_directions): + # print('yielding') + yield (data[idx, groups[:, dir]].flatten(), bvals[:, groups[:, dir]].ravel(), b0_indices[:, groups[:, dir]].ravel()) + +def osipi_fit(fitfunc, data_bvals): + data, bvals, b0_indices = data_bvals + data = normalize_series(data, b0_indices) + # print(f'data.shape {data.shape} data {data} bvals {bvals}') + return fitfunc(data, bvals) + +# def osipi_fit(fitfunc, bvals, data, f_image, Dp_image, D_image, index): +# bval_index = len(f_image) % len(bvals) +# print(f'data.shape {data.shape} index {index} data[index] {data[index]} bvals.shape {bvals.shape} bval_index {bval_index} bvals {bvals[:, bval_index]}') +# [f_fit, Dp_fit, D_fit] = fitfunc(data[index], bvals[:, bval_index]) +# f_image[index] = f_fit +# Dp_image[index] = Dp_fit +# D_image[index] = D_fit + if __name__ == "__main__": @@ -82,45 +53,133 @@ def loop_over_first_n_minus_1_dimensions(arr): parser.add_argument("input_file", type=str, help="Path to the input 4D NIfTI file.") parser.add_argument("bvec_file", type=str, help="Path to the b-vector file.") parser.add_argument("bval_file", type=str, help="Path to the b-value file.") - parser.add_argument("--affine", type=float, nargs="+", help="Affine matrix for NIfTI image.") + parser.add_argument("--nproc", type=int, default=0, help="Number of processes to use, -1 disabled multprocessing, 0 automatically determines number, >0 uses that number.") + parser.add_argument("--group_directions", default=False, action="store_true", help="Fit all directions together") + parser.add_argument("--affine", type=float, default=None, nargs="+", help="Affine matrix for NIfTI image.") parser.add_argument("--algorithm", type=str, default="OJ_GU_seg", help="Select the algorithm to use.") - parser.add_argument("--algorithm_args", nargs=argparse.REMAINDER, help="Additional arguments for the algorithm.") + parser.add_argument("--algorithm_args", default={}, nargs=argparse.REMAINDER, help="Additional arguments for the algorithm.") + args = parser.parse_args() try: # Read the 4D NIfTI file data, _ = read_nifti_file(args.input_file) + data = data[0::4, 0::4, 0::2, :] + print(f'data.shape {data.shape}') # Read the b-vector, and b-value files bvecs = read_bvec_file(args.bvec_file) bvals = read_bval_file(args.bval_file) + # print(f'bvals.size {bvals.shape} bvecs.size {bvecs.shape}') + print(bvals) + print(bvecs) + shells, bval_indices, b0_indices = find_shells(bvals) + num_b0 = np.count_nonzero(b0_indices) + print(shells) + print(bval_indices) + # print(b0_indices) + # print(bvecs) + + # print('vectors') + vectors, bvec_indices, groups = find_directions(bvecs, b0_indices) + print(vectors) + print(bvec_indices) + print(f'groups {groups}') + + # split_bval_bvec(bvec_indices, num_vectors) + # quit() + # Pass additional arguments to the algorithm - - fit = OsipiBase(algorithm=args.algorithm) - f_image = [] - Dp_image = [] - D_image = [] + fit = OsipiBase(algorithm=args.algorithm, **args.algorithm_args) + + # n = data.ndim + output_shape = list(data.shape[:-1]) + # if args.group_directions: + # input_data = data + # input_bvals = np.atleast_2d(bvals) + # else: + # num_directions = groups.shape[1] + # measurements = np.count_nonzero(groups[:, 0]) + # print(f"group_length {num_directions}") + # input_shape = output_shape.copy() + + # print(f"groups[:, 0] {groups[:,0]} {np.count_nonzero(groups[:, 0])}") + # input_shape.append(num_directions) + # input_shape.append(measurements) + # output_shape.append(num_directions) + # print(f"input_shape {input_shape}") + # input_data = np.zeros(input_shape) + # input_bvals = np.zeros([measurements, num_directions]) + # for group_idx in range(num_directions): + # print(f"group {group_idx} {groups[:, group_idx]}") + # input_data[..., group_idx, :] = data[..., groups[:, group_idx]] + # input_bvals[:, group_idx] = bvals[groups[:, group_idx]] + # if args.group_directions: + # input_data = data + # input_bvals = np.atleast_2d(bvals) + # else: + input_data = data + input_bvals = np.atleast_2d(bvals) + b0_indices = np.atleast_2d(b0_indices) + print(f"data.shape {data.shape}") + print(f"input_data.shape {input_data.shape}") + + + + + voxel_iteration = np.prod(output_shape) + group_iteration = groups.shape[1] + total_iteration = voxel_iteration * group_iteration + output_shape.append(group_iteration) + f_image = np.zeros(output_shape) + Dp_image = np.zeros(output_shape) + D_image = np.zeros(output_shape) + print(f_image.shape) # This is necessary for the tqdm to display progress bar. - n = data.ndim - total_iteration = np.prod(data.shape[:n-1]) - for idx, view in tqdm(loop_over_first_n_minus_1_dimensions(data), desc=f"{args.algorithm} is fitting", dynamic_ncols=True, total=total_iteration): - [f_fit, Dp_fit, D_fit] = fit.osipi_fit(view, bvals) - f_image.append(f_fit) - Dp_image.append(Dp_fit) - D_image.append(D_fit) - - # Convert lists to NumPy arrays - f_image = np.array(f_image) - Dp_image = np.array(Dp_image) - D_image = np.array(D_image) - - # Reshape arrays if needed - f_image = f_image.reshape(data.shape[:data.ndim-1]) - Dp_image = Dp_image.reshape(data.shape[:data.ndim-1]) - D_image = D_image.reshape(data.shape[:data.ndim-1]) + + # total_iteration = np.prod(data.shape[:n-1]) + print(f'input_bvals {input_bvals}') + print(f'voxel_iteration {voxel_iteration} input_data.shape {input_data.shape}') + # print(f'input_data[5000] {input_data.reshape(total_iteration, -1)[5000]}') + + + # fit_partial = partial(osipi_fit, fit.osipi_fit, input_bvals, input_data.reshape(total_iteration, -1), f_image.reshape(total_iteration), Dp_image.reshape(total_iteration), D_image.reshape(total_iteration)) + fit_partial = partial(osipi_fit, fit.osipi_fit) + + + if args.nproc >= 0: + print('multiprocess fitting') + gd = generate_data(input_data, input_bvals, b0_indices, groups, voxel_iteration) + map_args = [fit_partial, gd] + chunksize = round(total_iteration / args.nproc) if args.nproc > 0 else round(total_iteration / 128) + print(f'chunksize {chunksize}') + map_kwargs = {'desc':f"{args.algorithm} is fitting", 'dynamic_ncols':True, 'total':total_iteration, 'chunksize':chunksize} + if args.nproc > 0: + map_kwargs['max_workers'] = args.nproc + result = process_map(*map_args, **map_kwargs) + output = np.asarray(result) + print(f'output.shape {output.shape}') + output = output.reshape([*output_shape, 3]) + f_image = output[..., 0] + print(f'f_img.shape {f_image.shape}') + Dp_image = output[..., 1] + D_image = output[..., 2] + # print(result) + # if args.nproc == 0: # TODO: can this be done more elegantly, I just want to omit a single parameter here + # process_map(fit_partial, range(total_iteration), desc=f"{args.algorithm} is fitting", dynamic_ncols=True, total=total_iteration) + # else: + # process_map(fit_partial, range(total_iteration), max_workers=args.nproc, desc=f"{args.algorithm} is fitting", dynamic_ncols=True, total=total_iteration) + else: + for idx, view in tqdm(loop_over_first_n_minus_1_dimensions(data), desc=f"{args.algorithm} is fitting", dynamic_ncols=True, total=total_iteration): + [f_fit, Dp_fit, D_fit] = fit.osipi_fit(view, bvals) + f_image[idx] = f_fit + Dp_image[idx] = Dp_fit + D_image[idx] = D_fit + + print("finished fitting") save_nifti_file(f_image, "f.nii.gz", args.affine) save_nifti_file(Dp_image, "dp.nii.gz", args.affine) diff --git a/conftest.py b/conftest.py index 086585c2..f6405515 100644 --- a/conftest.py +++ b/conftest.py @@ -2,6 +2,10 @@ import pathlib import json import csv +import tempfile +import os +import random +import numpy as np # import datetime @@ -178,3 +182,45 @@ def data_list(filename): bvals = bvals['bvalues'] for name, data in all_data.items(): yield name, bvals, data + +@pytest.fixture +def bval_bvec_info(): + shells = [0, 10, 20, 50, 100, 200, 500, 1000] + # random.shuffle(shells) + bvals = np.concatenate((shells, random.choices(shells, k=10)), axis=0) + + vecs = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0.707, 0.707, 0], [0.5, 0.5, 0.5], [0, 0.707, 0.707], [0.707, 0, 0.707]] + for idx in range(len(vecs)): + if np.linalg.norm(vecs[idx]) != 0: + vecs[idx] = vecs[idx]/np.linalg.norm(vecs[idx]) + bvecs = [] + vecs_idx = 1 # the first index is needed for the true output, but not needed here + for idx in range(len(bvals)): + if bvals[idx] == 0: + bvecs.append(np.asarray([0, 0, 0])) + elif vecs_idx < len(vecs): + bvecs.append(vecs[vecs_idx]) + vecs_idx += 1 + else: + bvecs.append(random.choice(vecs[1:])) # don't put a b0 in where it shouldn't be + print(f'raw bvals {bvals}') + print(f'raw bvecs {bvecs}') + + + with tempfile.NamedTemporaryFile(mode='wt', delete=False) as fp_val, tempfile.NamedTemporaryFile(mode='wt', delete=False) as fp_vec: + writer = csv.writer(fp_val, delimiter=' ') + for bval in bvals: + writer.writerow((bval,)) + fp_val.close() + writer = csv.writer(fp_vec, delimiter=' ') + for bvec in bvecs: + writer.writerow(bvec) + fp_vec.close() + yield (fp_val.name, np.asarray(shells), bvals, fp_vec.name, np.asarray(vecs), np.asarray(bvecs)) + os.unlink(fp_val.name) # use NamedTemporaryFile with delete_on_close with later python versions + os.path.exists(fp_val.name) + os.unlink(fp_vec.name) # use NamedTemporaryFile with delete_on_close with later python versions + os.path.exists(fp_vec.name) + + + diff --git a/tests/utilities/__init__.py b/tests/utilities/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utilities/unit_tests/__init__.py b/tests/utilities/unit_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utilities/unit_tests/test_diffusion_utils.py b/tests/utilities/unit_tests/test_diffusion_utils.py new file mode 100644 index 00000000..7b9c8e46 --- /dev/null +++ b/tests/utilities/unit_tests/test_diffusion_utils.py @@ -0,0 +1,54 @@ +import numpy as np +import numpy.testing as npt +from utilities.process.file_io import read_bval_file, read_bvec_file +from utilities.process.diffusion_utils import find_shells, find_directions, normalize_series + + +#TODO: test without b0 +#TODO: test symmetry +def test_read_bval_bvec(bval_bvec_info): + bval_name, shells, bvals, bvec_name, directions, bvecs = bval_bvec_info + saved_bvals = read_bval_file(bval_name) + npt.assert_equal(bvals, np.asarray(saved_bvals)) + saved_shells, bval_indices, b0 = find_shells(saved_bvals) + npt.assert_equal(shells, saved_shells, "Shells do not match") + npt.assert_equal(saved_bvals, [saved_shells[index] for index in bval_indices], "Bvalue indices are incorrect") + + saved_bvecs = read_bvec_file(bvec_name) + npt.assert_allclose(np.asarray(bvecs), np.asarray(saved_bvecs), err_msg="Incorrectly saved bvectors") + vectors, bvec_indices, groups = find_directions(saved_bvecs, b0) + assert vectors.shape[0] == groups.shape[1] + 1, "Number of vectors is correct" + assert vectors.shape == np.asarray(directions).shape, "Number of elements in directions does not match" + directions_set = set() + for direction in directions: + directions_set.add(tuple(direction)) + vectors_set = set() + for vector in vectors: + vectors_set.add(tuple(vector)) + assert directions_set == vectors_set, "Elements in directions does not match" + npt.assert_equal(saved_bvecs, [vectors[index] for index in bvec_indices], "Bvector indices are incorrect") + +def test_normalization(): + original = np.atleast_2d([[10, 10], [10, 10], [5, 5], [5, 5]]).T + + indices = [True, False, False, False] + updated = normalize_series(original.copy(), indices) + npt.assert_allclose(original / 10, updated, err_msg="Normalization with 1 point failed") + + indices = [True, True, False, False] + updated = normalize_series(original.copy(), indices) + npt.assert_allclose(original / 10, updated, err_msg="Normalization with 2 points failed") + + indices = [False, True, True, False] + updated = normalize_series(original.copy(), indices) + npt.assert_allclose(original / 7.5, updated, err_msg="Normalization with 2 different points failed") + + indices = [False, False, False, True] + updated = normalize_series(original.copy(), indices) + npt.assert_allclose(original / 5, updated, err_msg="Normalization with 1 final point failed") + + original = np.asarray([10, 5]) + indices = [True, False] + + updated = normalize_series(original.copy(), indices) + npt.assert_allclose(original / 10, updated, err_msg="Normalization of 1D failed") diff --git a/tests/utilities/unit_tests/test_file_io.py b/tests/utilities/unit_tests/test_file_io.py new file mode 100644 index 00000000..27b5fe0c --- /dev/null +++ b/tests/utilities/unit_tests/test_file_io.py @@ -0,0 +1,23 @@ +import tempfile +import os +import numpy as np +import numpy.testing as npt +from utilities.process.file_io import save_nifti_file, read_nifti_file, read_bval_file, read_bvec_file + + +def test_nifti_read_write(): + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, 'my_nifti.nii.gz') + data = np.random.rand(7, 8, 9) + save_nifti_file(data, path) + assert os.path.exists(path), "Nifti file does not exist" + saved_data, saved_hdr = read_nifti_file(path) + npt.assert_equal(data, saved_data, "Nifti data does not match") + +def test_read_bval_bvec(bval_bvec_info): + bval_name, shells, bvals, bvec_name, directions, bvecs = bval_bvec_info + assert bvecs.shape[1] == 3, "Bvec input is not Nx3" + saved_bvals = read_bval_file(bval_name) + npt.assert_equal(bvals, np.asarray(saved_bvals), "Bvalues do not match") + saved_bvecs = read_bvec_file(bvec_name) + npt.assert_allclose(bvecs, np.asarray(saved_bvecs), err_msg="Bvectors do not match") diff --git a/utilities/process/__init__.py b/utilities/process/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/utilities/process/diffusion_utils.py b/utilities/process/diffusion_utils.py new file mode 100644 index 00000000..23f3d78f --- /dev/null +++ b/utilities/process/diffusion_utils.py @@ -0,0 +1,202 @@ +import numpy as np + + +def angle_between(v1, v2): + """ Returns the angle in radians between vectors 'v1' and 'v2':: + + >>> angle_between((1, 0, 0), (0, 1, 0)) + 1.5707963267948966 + >>> angle_between((1, 0, 0), (1, 0, 0)) + 0.0 + >>> angle_between((1, 0, 0), (-1, 0, 0)) + 3.141592653589793 + """ + nv1 = np.linalg.norm(v1) + nv2 = np.linalg.norm(v2) + if nv1 == 0 or nv2 == 0: + return 0 + v1_u = v1 / nv1 + v2_u = v2 / nv2 + return np.degrees(np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))) + +def find_shells(bvals_raw, rtol=0.01, atol=1): + """ Automatically finds shells from a bvalue file. + + Given a raw list of bvalues and optional relative and absolute tolerances, + return the unique shells, shell indices, and b0 image indices. + + Parameters + ------------ + bvals_raw : list + List of bvalues as acquired + rtol : float + Relative tolerance of bvalues + atol : float + Absolute tolerance of bvalues + + Returns + ---------- + shells : list + The mean values of the shells + indices : list + The indices of the shells in the original bvalue list + b0 : list + A boolean array of locations of b0 values + """ + # if we know the number of values in each shell we could use np.partition + # but I don't think that's very useful. + + bvals_raw = np.asarray(bvals_raw) + assert len(bvals_raw.shape) == 1, "Must be a 1D array" + + # group the bvalues based on tolerances + bvals_indices = np.argsort(bvals_raw) # sort order + bvals = bvals_raw[bvals_indices] + index = np.zeros_like(bvals, dtype=int) + idx_lower = 0 + for idx_higher in range(1, len(bvals)): + if np.allclose(bvals[idx_lower], bvals[idx_higher], rtol, atol): + index[idx_higher] = index[idx_lower] + else: + index[idx_higher] = index[idx_lower] + 1 + idx_lower = idx_higher + + # invert the sort to return the shell indices of the unsorted bvalues + bvals_inverse_sort = np.empty_like(bvals_indices) + bvals_inverse_sort[bvals_indices] = np.arange(bvals_indices.size) + + # find the mean b-values of the shells + shells = [np.mean(bvals[index==idx]) for idx in set(index)] + indices = index[bvals_inverse_sort] + return shells, indices, indices Date: Tue, 3 Sep 2024 09:32:03 -0700 Subject: [PATCH 2/3] Fix doctest failure --- utilities/process/diffusion_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utilities/process/diffusion_utils.py b/utilities/process/diffusion_utils.py index 23f3d78f..953b49bb 100644 --- a/utilities/process/diffusion_utils.py +++ b/utilities/process/diffusion_utils.py @@ -2,14 +2,14 @@ def angle_between(v1, v2): - """ Returns the angle in radians between vectors 'v1' and 'v2':: + """ Returns the angle in degrees between vectors 'v1' and 'v2':: >>> angle_between((1, 0, 0), (0, 1, 0)) - 1.5707963267948966 + 90.0 >>> angle_between((1, 0, 0), (1, 0, 0)) 0.0 >>> angle_between((1, 0, 0), (-1, 0, 0)) - 3.141592653589793 + 180.0 """ nv1 = np.linalg.norm(v1) nv2 = np.linalg.norm(v2) From 2a4690954779cedec16ee071c90a71c5ad8627e8 Mon Sep 17 00:00:00 2001 From: Eric Peterson Date: Thu, 17 Oct 2024 20:52:10 -0700 Subject: [PATCH 3/3] Added docstrings --- WrapImage/nifti_wrapper.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/WrapImage/nifti_wrapper.py b/WrapImage/nifti_wrapper.py index 404e0caa..daab0eae 100644 --- a/WrapImage/nifti_wrapper.py +++ b/WrapImage/nifti_wrapper.py @@ -25,6 +25,19 @@ def loop_over_first_n_minus_1_dimensions(arr): yield idx, flat_view def generate_data(data, bvals, b0_indices, groups, total_iteration): + """ + Generates data samples for a multiprocess fitting + + Args: + data: The raw data to be sampled + bvals: The bvalues + b0_indices: The b0 indices in the data + groups: The group indices in the data + total_iterations: The total number of iterations to generate + + Yields: + A tuple containing matching: data, bvalues, and b0_indices + """ num_directions = groups.shape[1] data = data.reshape(total_iteration, -1) for idx in range(total_iteration): @@ -33,6 +46,16 @@ def generate_data(data, bvals, b0_indices, groups, total_iteration): yield (data[idx, groups[:, dir]].flatten(), bvals[:, groups[:, dir]].ravel(), b0_indices[:, groups[:, dir]].ravel()) def osipi_fit(fitfunc, data_bvals): + """ + Fit the data using the provided fit function + + Args: + fitfunc: The fit function + data_bvals: The tuple of data, bvals, and b0_indices + + Returns: + The fitted values + """ data, bvals, b0_indices = data_bvals data = normalize_series(data, b0_indices) # print(f'data.shape {data.shape} data {data} bvals {bvals}')