How do different data loading methods affect training speed? #48

Open
Tobyzai opened this issue Jun 5, 2024 · 2 comments

Comments

Tobyzai commented Jun 5, 2024

I changed the data loading to the commonly used approach of reading each batch from disk by path index, instead of first converting all the data to .npy files and loading everything into a variable at once.
The original approach needs a large amount of memory but trains fast. The rewritten loading method uses little memory, but training becomes very slow and keeps stalling.

What I want to ask is: what causes this? Is Mamba only suited to the preloading style of data loading? If the data has to be preloaded to get good speed, that seems pointless for real-time inference in practice.
Below is the rewritten data loading code (most models use this kind of loader):

import os
import random

import cv2
import numpy as np
import torch
from scipy import ndimage
from torch.utils.data import Dataset


class isic_loader(Dataset):
    """Dataset class for Brats datasets."""

    def __init__(self, path_Data, root, train=True, Test=False):
        super(isic_loader, self).__init__()
        self.train = train
        self.path_Data = path_Data
        self.root = root
        # train / val / test all read the same annotation file here
        if train:
            self.data_list = self.load_anno()
        else:
            if Test:
                self.data_list = self.load_anno()
            else:
                self.data_list = self.load_anno()

    def __getitem__(self, indx):
        item = self.data_list[indx]
        img, seg = self.item_loader(item)
        return img, seg

    def random_rot_flip(self, image, label):
        # random 90-degree rotation followed by a random horizontal/vertical flip
        k = np.random.randint(0, 4)
        image = np.rot90(image, k)
        label = np.rot90(label, k)
        axis = np.random.randint(0, 2)
        image = np.flip(image, axis=axis).copy()
        label = np.flip(label, axis=axis).copy()
        return image, label

    def random_rotate(self, image, label):
        # rotate by a random angle between 20 and 80 degrees
        angle = np.random.randint(20, 80)
        image = ndimage.rotate(image, angle, order=0, reshape=False)
        label = ndimage.rotate(label, angle, order=0, reshape=False)
        return image, label

    def __len__(self):
        return len(self.data_list)

    def load_anno(self):
        # each line of the annotation file holds "<image_path> <mask_path>"
        data_list = []
        with open(self.path_Data, 'r') as fp:
            data_list.extend([x.strip().split(' ') for x in fp.readlines()])
        return data_list

    def item_loader(self, item):
        full_paths = [os.path.join(self.root, x) for x in item]
        img_path, gt_path = full_paths
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        seg = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

        if self.train:
            img = img / 255.
            seg = seg / 255.

            if random.random() > 0.5:
                img, seg = self.random_rot_flip(img, seg)
            if random.random() > 0.5:
                img, seg = self.random_rotate(img, seg)

            img = cv2.resize(img, (960, 544), interpolation=cv2.INTER_LINEAR)
            seg = cv2.resize(seg, (960, 544), interpolation=cv2.INTER_NEAREST)
            seg = np.expand_dims(seg, axis=2)
        else:
            img = img / 255.
            seg = seg / 255.
            img = cv2.resize(img, (1280, 1024), interpolation=cv2.INTER_LINEAR)
            seg = cv2.resize(seg, (1280, 1024), interpolation=cv2.INTER_NEAREST)
            seg = np.expand_dims(seg, axis=2)

        # HWC -> CHW tensors
        seg = torch.tensor(seg.copy())
        img = torch.tensor(img.copy())
        img = img.permute(2, 0, 1)
        seg = seg.permute(2, 0, 1)

        return img, seg
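
For completeness, the dataset is consumed through a standard PyTorch DataLoader. A minimal sketch of how it might be constructed (the file paths, batch size, and worker count below are placeholders, not the exact values from my run):

from torch.utils.data import DataLoader

# Placeholder paths and hyperparameters, only to illustrate the usage.
train_dataset = isic_loader(path_Data='train_list.txt', root='./data', train=True)
train_loader = DataLoader(train_dataset,
                          batch_size=8,       # placeholder batch size
                          shuffle=True,
                          num_workers=4,      # worker processes overlap disk reads with GPU compute
                          pin_memory=True)

for img, seg in train_loader:
    pass  # forward/backward pass would go here
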
wurenkai (Owner) commented Jun 6, 2024

Hi, a different data loading method by itself should not affect the model's training speed. You need to check the size of the images fed into the model and your data preprocessing code. In particular, an oversized image scale will increase the training time of any model.


Tobyzai commented Jun 6, 2024

I tried changing the image size to (256, 256), but the same phenomenon occurs during training: it runs very fast for a moment, then there is a long pause during which the GPU shows no activity. After the pause, the GPU works only briefly, while the training progress updates.
But using the same data loading method with other models (CNN or Transformer based) does not show this behavior. What could the possible reasons be?
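
To narrow this down, one option is to time the data loading and the GPU step separately. A rough sketch (the train_loader and model objects here are placeholders for the actual ones):

import time
import torch

# Rough sketch for separating data-loading time from GPU compute time.
# `train_loader` and `model` are placeholders for the actual objects.
torch.cuda.synchronize()
t_prev = time.time()
for step, (img, seg) in enumerate(train_loader):
    t_data = time.time() - t_prev            # time spent waiting for the next batch

    img = img.float().cuda(non_blocking=True)
    seg = seg.float().cuda(non_blocking=True)
    out = model(img)                          # placeholder forward pass
    torch.cuda.synchronize()
    t_step = time.time() - t_prev - t_data    # time spent on the GPU side

    print(f'step {step}: data {t_data:.3f}s, compute {t_step:.3f}s')
    t_prev = time.time()
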
