-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathdataloaders.py
74 lines (60 loc) · 3.18 KB
/
dataloaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage as nd
import random
import SimpleITK as sitk
import torch
import torch.utils as utils
def collate_to_list_unsupervised(batch):
    """Split a batch of ``(source, target)`` pairs into two lists of 2-D tensors.

    Each tensor is re-viewed as ``(size(0), size(1))``. This succeeds only when
    the tensor holds exactly ``size(0) * size(1)`` elements, e.g. a 2-D tensor
    or one with trailing singleton dimensions. Samples are returned as lists
    rather than stacked, so they may differ in shape. Presumably this is meant
    to be passed as a ``DataLoader`` ``collate_fn``.

    Returns:
        ``(sources, targets)``: two lists with one tensor per sample.
    """
    def flatten_to_2d(tensor):
        rows, cols = tensor.size(0), tensor.size(1)
        return tensor.view(rows, cols)

    source_list = [flatten_to_2d(sample[0]) for sample in batch]
    target_list = [flatten_to_2d(sample[1]) for sample in batch]
    return source_list, target_list
def collate_to_list_segmentation(batch):
    """Split a batch of 4-tuples into four lists of 2-D tensors.

    Each sample is ``(source, target, source_mask, target_mask)``. Every tensor
    is re-viewed as ``(size(0), size(1))``, which requires it to hold exactly
    ``size(0) * size(1)`` elements. Lists are returned instead of stacked
    tensors, so samples may differ in shape. Presumably this is meant as a
    ``DataLoader`` ``collate_fn``.

    Returns:
        ``(sources, targets, source_masks, target_masks)``: four lists with one
        tensor per sample.
    """
    def as_matrix(tensor):
        return tensor.view(tensor.size(0), tensor.size(1))

    # Field k of every sample goes into the k-th list, built field by field.
    return tuple(
        [as_matrix(sample[field]) for sample in batch]
        for field in range(4)
    )
class UnsupervisedLoader(utils.data.Dataset):
    """Dataset of ``(source, target)`` image pairs stored as ``.mha`` files.

    Every sub-directory of ``data_path`` is one case and must contain
    ``source.mha`` and ``target.mha``. Images are read with SimpleITK and
    returned as ``float32`` tensors.

    Args:
        data_path: Directory whose sub-directories are the cases.
        transforms: Optional callable invoked as ``transforms(source, target)``
            on the NumPy arrays. It must return a 3-tuple; the first two items
            replace ``source`` and ``target`` and the third is discarded.
        randomly_swap: If True, source and target are swapped with
            probability ~0.5 on each access.
    """

    def __init__(self, data_path, transforms=None, randomly_swap=False):
        self.data_path = data_path
        # os.listdir returns entries in arbitrary, platform-dependent order.
        # Sort so the index -> case mapping is reproducible. Keep only
        # directories, since each id is opened as a case directory below.
        self.all_ids = sorted(
            entry
            for entry in os.listdir(self.data_path)
            if os.path.isdir(os.path.join(self.data_path, entry))
        )
        self.transforms = transforms
        self.randomly_swap = randomly_swap

    def __len__(self):
        return len(self.all_ids)

    @staticmethod
    def _read_array(path):
        """Read an image file with SimpleITK and return it as a NumPy array."""
        return sitk.GetArrayFromImage(sitk.ReadImage(path))

    def __getitem__(self, idx):
        """Return ``(source, target)`` float32 tensors for case ``idx``."""
        case_dir = os.path.join(self.data_path, self.all_ids[idx])
        source = self._read_array(os.path.join(case_dir, "source.mha"))
        target = self._read_array(os.path.join(case_dir, "target.mha"))
        if self.transforms is not None:
            source, target, _ = self.transforms(source, target)
        # The RNG is only drawn when swapping is enabled. A draw <= 0.5 swaps,
        # which is exactly the complement of the original `> 0.5 -> keep`.
        if self.randomly_swap and random.random() <= 0.5:
            source, target = target, source
        return (
            torch.from_numpy(source.astype(np.float32)),
            torch.from_numpy(target.astype(np.float32)),
        )
class SegmentationLoader(utils.data.Dataset):
    """Dataset of image pairs with their segmentation masks, stored as ``.mha``.

    Every sub-directory of ``data_path`` is one case and must contain
    ``source.mha``, ``target.mha``, ``source_mask.mha`` and
    ``target_mask.mha``. All four are read with SimpleITK and returned as
    ``float32`` tensors (masks included).

    Args:
        data_path: Directory whose sub-directories are the cases.
    """

    # File names read for each case, in the order they are returned.
    _FILE_NAMES = ("source.mha", "target.mha", "source_mask.mha", "target_mask.mha")

    def __init__(self, data_path):
        self.data_path = data_path
        # os.listdir returns entries in arbitrary, platform-dependent order.
        # Sort so the index -> case mapping is reproducible. Keep only
        # directories, since each id is opened as a case directory below.
        self.all_ids = sorted(
            entry
            for entry in os.listdir(self.data_path)
            if os.path.isdir(os.path.join(self.data_path, entry))
        )

    def __len__(self):
        return len(self.all_ids)

    @staticmethod
    def _load_float_tensor(path):
        """Read an image file with SimpleITK and return it as a float32 tensor."""
        array = sitk.GetArrayFromImage(sitk.ReadImage(path))
        return torch.from_numpy(array.astype(np.float32))

    def __getitem__(self, idx):
        """Return ``(source, target, source_mask, target_mask)`` for case ``idx``."""
        case_dir = os.path.join(self.data_path, self.all_ids[idx])
        return tuple(
            self._load_float_tensor(os.path.join(case_dir, file_name))
            for file_name in self._FILE_NAMES
        )