
Facing scaling issue for GSpMM kernel operation on x86 machine #7845

Open

choudhary-devang opened this issue Dec 10, 2024 · 4 comments

choudhary-devang commented Dec 10, 2024

CPU: Ice Lake

| Input Shape | 1 core (ms) | 2 cores (ms) | 4 cores (ms) | 8 cores (ms) | 16 cores (ms) | 32 cores (ms) |
|---|---|---|---|---|---|---|
| [11658, 128] | 150.9 | 113.41 | 89.54 | 81.23 | 82.71 | 166.92 |
| [11410, 128] | 150.56 | 109.93 | 87.33 | 80.08 | 82.58 | 157.84 |
| [11070, 128] | 145.78 | 109.69 | 88.50 | 81.51 | 82.56 | 165.82 |
| [11172, 128] | 144.81 | 109.81 | 87.17 | 79.12 | 81.70 | 162.82 |
| [10938, 128] | 143.74 | 109.54 | 86.50 | 79.06 | 81.25 | 160.59 |
| [11096, 128] | 141.87 | 109.44 | 86.16 | 78.73 | 82.35 | 164.38 |
| [10740, 128] | 136.91 | 105.59 | 85.67 | 78.73 | 80.23 | 165.82 |
| [10464, 128] | 139.91 | 105.08 | 84.59 | 77.68 | 79.87 | 159.14 |
| [10168, 128] | 139.27 | 102.07 | 83.99 | 76.57 | 80.32 | 166.89 |
| [10066, 128] | 131.42 | 101.85 | 83.10 | 76.48 | 78.29 | 162.82 |
| [10092, 128] | 134.06 | 102.55 | 83.88 | 75.25 | 78.29 | 164.52 |
| [10482, 128] | 133.30 | 104.64 | 84.94 | 77.78 | 81.43 | 165.23 |
| [10206, 128] | 133.43 | 103.74 | 83.10 | 76.26 | 80.88 | 159.14 |
| [10398, 128] | 132.03 | 102.55 | 83.88 | 76.85 | 79.89 | 157.77 |
| [10344, 128] | 135.24 | 103.36 | 83.99 | 76.64 | 80.07 | 160.32 |
| [10266, 128] | 134.60 | 102.71 | 84.27 | 76.64 | 79.05 | 157.83 |
| [10590, 128] | 132.03 | 104.56 | 84.73 | 77.95 | 78.88 | 165.23 |
| [10144, 128] | 136.40 | 100.83 | 83.52 | 76.44 | 78.88 | 160.32 |
| [10194, 128] | 133.39 | 102.65 | 82.80 | 76.33 | 92.43 | 164.38 |

These are inference benchmarking results taken from the MeshGNN model.
As can be observed from the results, the SpMM operation does not scale properly beyond 8 cores, and regresses at 32 cores.

Script which I used:

(screenshot of the benchmarking script)
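
For reference, here is a minimal sketch of what the timing loop does (not the exact script from the screenshot; the random graph, edge count, and node-feature shape below are placeholders standing in for the actual MeshGNN graph, and the GSpMM call is assumed to be update_all(copy_u, sum)):

import time

import dgl
import dgl.function as fn
import torch

# Placeholder graph standing in for the MeshGNN mesh; the node-feature
# shape matches one row of the table above, the edge count is made up.
num_nodes, feat_dim = 11658, 128
graph = dgl.rand_graph(num_nodes, num_nodes * 8)
graph.ndata["h"] = torch.randn(num_nodes, feat_dim)

# Warm-up so one-time overheads are not measured.
for _ in range(10):
    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "out"))

# Time the sparse aggregation (this lowers to the CPU GSpMM kernel).
start = time.perf_counter()
for _ in range(200):
    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "out"))
elapsed_ms = (time.perf_counter() - start) / 200 * 1e3
print(f"threads={torch.get_num_threads()}  GSpMM avg: {elapsed_ms:.2f} ms")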

Setup details:

dgl = 2.5 (built from source)
torch = 2.5 (pip installed)

@choudhary-devang (Author)

@itaraban, @jermainewang, @BarclayII, could you please look into this?


This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you

@itaraban (Collaborator)

@choudhary-devang, could you please share more information about the CPU model?

choudhary-devang (Author) commented Jan 13, 2025

@itaraban, thanks for responding.
I tried different models for this experiment; you can check with the DAGNN model example at dgl/examples/pytorch/dagnn/main.py.
I just added profiling to this script.
Results for the DAGNN model on different core counts:

(screenshot of per-core profiling results for the DAGNN model)

Experimented on a 64-core machine (c6i.16xlarge), exporting OMP_NUM_THREADS to vary the core count (see the sweep sketch after the updated script below).
As we can observe from the results, it does not scale beyond 8 cores; similar behavior has been observed in several other models.

Updated part for DAGNN inference in dgl/examples/pytorch/dagnn/main.py:

(screenshot of the modified main function; full script below)

I have only changed the main function in the script to run inference, and to observe the effect properly I iterate 200 times.

If you require anything else or more details, please let me know.
Updated script:

import argparse
import dgl.function as fn
import dgl
import numpy as np
import torch
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from torch import nn
from torch.nn import functional as F, Parameter
from tqdm import trange
from utils import evaluate, generate_random_seeds, set_random_state

class DAGNNConv(nn.Module):
    def __init__(self, in_dim, k):
        super(DAGNNConv, self).__init__()
        self.s = Parameter(torch.FloatTensor(in_dim, 1))
        self.k = k
        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain("sigmoid")
        nn.init.xavier_uniform_(self.s, gain=gain)

    def forward(self, graph, feats):
        with graph.local_scope():
            results = [feats]
            degs = graph.in_degrees().float()
            norm = torch.pow(degs, -0.5)
            norm = norm.to(feats.device).unsqueeze(1)
            for _ in range(self.k):
                feats = feats * norm
                graph.ndata["h"] = feats
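                # The copy_u/sum message passing below is what lowers to DGL's
                # CPU GSpMM kernel, i.e. the operation being profiled here.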
                graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                feats = graph.ndata["h"]
                feats = feats * norm
                results.append(feats)
            H = torch.stack(results, dim=1)
            S = F.sigmoid(torch.matmul(H, self.s))
            S = S.permute(0, 2, 1)
            H = torch.matmul(S, H).squeeze()
            return H


class MLPLayer(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, activation=None, dropout=0):
        super(MLPLayer, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=bias)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        gain = 1.0
        if self.activation is F.relu:
            gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.linear.weight, gain=gain)
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, feats):
        feats = self.dropout(feats)
        feats = self.linear(feats)
        if self.activation:
            feats = self.activation(feats)

        return feats


class DAGNN(nn.Module):
    def __init__(
        self,
        k,
        in_dim,
        hid_dim,
        out_dim,
        bias=True,
        activation=F.relu,
        dropout=0,
    ):
        super(DAGNN, self).__init__()
        self.mlp = nn.ModuleList()
        self.mlp.append(
            MLPLayer(
                in_dim=in_dim,
                out_dim=hid_dim,
                bias=bias,
                activation=activation,
                dropout=dropout,
            )
        )
        self.mlp.append(
            MLPLayer(
                in_dim=hid_dim,
                out_dim=out_dim,
                bias=bias,
                activation=None,
                dropout=dropout,
            )
        )
        self.dagnn = DAGNNConv(in_dim=out_dim, k=k)

    def forward(self, graph, feats):
        for layer in self.mlp:
            feats = layer(feats)
        feats = self.dagnn(graph, feats)
        return feats


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    elif args.dataset == "Pubmed":
        dataset = PubmedGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))
    # dgl.use_libxsmm(False)
    graph = dataset[0]
    graph = graph.add_self_loop()

    # check cuda
    if args.gpu >= 0 and torch.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # retrieve masks for train/validation/test
    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze().to(device)
    val_idx = torch.nonzero(val_mask, as_tuple=False).squeeze().to(device)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze().to(device)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = DAGNN(
        k=args.k,
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        dropout=args.dropout,
    )
    model = model.to(device)
    # Step 3: Infer model =================================================================== #
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        record_shapes=True,
        with_stack=True
    ) as prof:

        for i in range(200):
            output = model(graph, feats)
    print(prof.key_averages(group_by_input_shape=True).table(row_limit=-1)) 
    return 


if __name__ == "__main__":
    """
    DAGNN Model Hyperparameters
    """
    parser = argparse.ArgumentParser(description="DAGNN")
    # data source params
    parser.add_argument(
        "--dataset",
        type=str,
        default="Cora",
        choices=["Cora", "Citeseer", "Pubmed"],
        help="Name of dataset.",
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument("--runs", type=int, default=1, help="Training runs.")
    parser.add_argument(
        "--epochs", type=int, default=1500, help="Training epochs."
    )
    parser.add_argument(
        "--early-stopping",
        type=int,
        default=100,
        help="Patient epochs to wait before early stopping.",
    )
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--lamb", type=float, default=0.005, help="L2 reg.")
    # model params
    parser.add_argument(
        "--k", type=int, default=12, help="Number of propagation layers."
    )
    parser.add_argument(
        "--hid-dim", type=int, default=64, help="Hidden layer dimensionalities."
    )
    parser.add_argument("--dropout", type=float, default=0.8, help="dropout")
    args = parser.parse_args()
    print(args)

    random_seeds = generate_random_seeds(seed=1222, nums=args.runs)

    for run in range(args.runs):
        set_random_state(random_seeds[run])
        main(args)
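
The thread sweep itself is driven by exporting OMP_NUM_THREADS before launching the script, roughly as in the sketch below (an illustration only: each thread count runs in a fresh process so the setting is picked up before torch/dgl initialize their OpenMP pool; "main.py" refers to the modified script above):

import os
import subprocess

# Hypothetical sweep driver: run the modified main.py once per thread count,
# each time in a fresh process with OMP_NUM_THREADS exported.
for n in [1, 2, 4, 8, 16, 32]:
    env = dict(os.environ, OMP_NUM_THREADS=str(n))
    subprocess.run(["python", "main.py", "--dataset", "Cora"], env=env, check=True)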
