Commit 88f109f

CfromBU authored
[distGB] Fix the problem when a graph has few nodes or edges in distributed partition (#7824)

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>

1 parent d92c98d commit 88f109f

File tree: 4 files changed, +272 -5 lines changed

- tests/tools/test_dist_part.py (+192 -1)
- tools/distpartitioning/convert_partition.py (+51 -3)
- tools/distpartitioning/data_shuffle.py (+24 -1)
- tools/distpartitioning/dataset_utils.py (+5)

tests/tools/test_dist_part.py (+192 -1)

@@ -3,6 +3,7 @@
 import tempfile
 
 import dgl
+import dgl.backend as F
 
 import numpy as np
 import pyarrow.parquet as pq
@@ -19,7 +20,8 @@
 
 from distpartitioning import array_readwriter
 from distpartitioning.utils import generate_read_list
-from pytest_utils import create_chunked_dataset
+from pytest_utils import chunk_graph, create_chunked_dataset
+from scipy import sparse as spsp
 
 from tools.verification_utils import (
     verify_graph_feats,
@@ -202,6 +204,103 @@ def test_chunk_graph_arbitrary_chunks(
     )
 
 
+def create_mini_chunked_dataset(
+    root_dir,
+    num_chunks,
+    data_fmt,
+    edges_fmt,
+    vector_rows,
+    few_entity="node",
+    **kwargs,
+):
+    num_nodes = {"n1": 1000, "n2": 1010, "n3": 1020}
+    etypes = [
+        ("n1", "r1", "n2"),
+        ("n2", "r1", "n1"),
+        ("n1", "r2", "n3"),
+        ("n2", "r3", "n3"),
+    ]
+    node_items = ["n1", "n2", "n3"]
+    edges_coo = {}
+    for etype in etypes:
+        src_ntype, _, dst_ntype = etype
+        arr = spsp.random(
+            num_nodes[src_ntype],
+            num_nodes[dst_ntype],
+            density=0.001,
+            format="coo",
+            random_state=100,
+        )
+        edges_coo[etype] = (arr.row, arr.col)
+    edge_items = []
+    if few_entity == "edge":
+        edges_coo[("n1", "a0", "n2")] = (
+            torch.tensor([0, 1]),
+            torch.tensor([1, 0]),
+        )
+        edges_coo[("n1", "a1", "n3")] = (
+            torch.tensor([0, 1]),
+            torch.tensor([1, 0]),
+        )
+        edge_items.append(("n1", "a0", "n2"))
+        edge_items.append(("n1", "a1", "n3"))
+    elif few_entity == "node":
+        edges_coo[("n1", "r_few", "n_few")] = (
+            torch.tensor([0, 1]),
+            torch.tensor([1, 0]),
+        )
+        edges_coo[("a0", "a01", "n_1")] = (
+            torch.tensor([0, 1]),
+            torch.tensor([1, 0]),
+        )
+        edge_items.append(("n1", "r_few", "n_few"))
+        edge_items.append(("a0", "a01", "n_1"))
+        node_items.append("n_few")
+        node_items.append("n_1")
+        num_nodes["n_few"] = 2
+        num_nodes["n_1"] = 2
+    g = dgl.heterograph(edges_coo)
+
+    node_data = {}
+    edge_data = {}
+    # save feature
+    input_dir = os.path.join(root_dir, "data_test")
+
+    for ntype in node_items:
+        os.makedirs(os.path.join(input_dir, ntype))
+        feat = np.random.randn(num_nodes[ntype], 3)
+        feat_path = os.path.join(input_dir, f"{ntype}/feat.npy")
+        with open(feat_path, "wb") as f:
+            np.save(f, feat)
+        g.nodes[ntype].data["feat"] = torch.from_numpy(feat)
+        node_data[ntype] = {"feat": feat_path}
+
+    for etype in set(edge_items):
+        os.makedirs(os.path.join(input_dir, etype[1]))
+        num_edge = len(edges_coo[etype][0])
+        feat = np.random.randn(num_edge, 4)
+        feat_path = os.path.join(input_dir, f"{etype[1]}/feat.npy")
+        with open(feat_path, "wb") as f:
+            np.save(f, feat)
+        g.edges[etype].data["feat"] = torch.from_numpy(feat)
+        edge_data[etype] = {"feat": feat_path}
+
+    output_dir = os.path.join(root_dir, "chunked-data")
+    chunk_graph(
+        g,
+        "mag240m",
+        node_data,
+        edge_data,
+        num_chunks=num_chunks,
+        output_path=output_dir,
+        data_fmt=data_fmt,
+        edges_fmt=edges_fmt,
+        vector_rows=vector_rows,
+        **kwargs,
+    )
+    return g
+
+
 def _test_pipeline(
     num_chunks,
     num_parts,
@@ -373,6 +472,98 @@ def test_pipeline_feature_format(data_fmt):
     _test_pipeline(4, 4, 4, data_fmt=data_fmt)
 
 
+@pytest.mark.parametrize(
+    "num_chunks, num_parts, world_size",
+    [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]],
+)
+@pytest.mark.parametrize("few_entity", ["node", "edge"])
+def test_partition_hetero_few_entity(
+    num_chunks,
+    num_parts,
+    world_size,
+    few_entity,
+    graph_formats=None,
+    data_fmt="numpy",
+    edges_fmt="csv",
+    vector_rows=False,
+    num_chunks_nodes=None,
+    num_chunks_edges=None,
+    num_chunks_node_data=None,
+    num_chunks_edge_data=None,
+):
+    with tempfile.TemporaryDirectory() as root_dir:
+        g = create_mini_chunked_dataset(
+            root_dir,
+            num_chunks,
+            few_entity=few_entity,
+            data_fmt=data_fmt,
+            edges_fmt=edges_fmt,
+            vector_rows=vector_rows,
+            num_chunks_nodes=num_chunks_nodes,
+            num_chunks_edges=num_chunks_edges,
+            num_chunks_node_data=num_chunks_node_data,
+            num_chunks_edge_data=num_chunks_edge_data,
+        )
+
+        # Step1: graph partition
+        in_dir = os.path.join(root_dir, "chunked-data")
+        output_dir = os.path.join(root_dir, "parted_data")
+        os.system(
+            "python3 tools/partition_algo/random_partition.py "
+            "--in_dir {} --out_dir {} --num_partitions {}".format(
+                in_dir, output_dir, num_parts
+            )
+        )
+
+        # Step2: data dispatch
+        partition_dir = os.path.join(root_dir, "parted_data")
+        out_dir = os.path.join(root_dir, "partitioned")
+        ip_config = os.path.join(root_dir, "ip_config.txt")
+        with open(ip_config, "w") as f:
+            for i in range(world_size):
+                f.write(f"127.0.0.{i + 1}\n")
+
+        cmd = "python3 tools/dispatch_data.py"
+        cmd += f" --in-dir {in_dir}"
+        cmd += f" --partitions-dir {partition_dir}"
+        cmd += f" --out-dir {out_dir}"
+        cmd += f" --ip-config {ip_config}"
+        cmd += " --ssh-port 22"
+        cmd += " --process-group-timeout 60"
+        cmd += " --save-orig-nids"
+        cmd += " --save-orig-eids"
+        cmd += f" --graph-formats {graph_formats}" if graph_formats else ""
+        os.system(cmd)
+
+        # read original node/edge IDs
+        def read_orig_ids(fname):
+            orig_ids = {}
+            for i in range(num_parts):
+                ids_path = os.path.join(out_dir, f"part{i}", fname)
+                part_ids = load_tensors(ids_path)
+                for type, data in part_ids.items():
+                    if type not in orig_ids:
+                        orig_ids[type] = data
+                    else:
+                        orig_ids[type] = torch.cat((orig_ids[type], data))
+            return orig_ids
+
+        orig_nids = read_orig_ids("orig_nids.dgl")
+        orig_eids = read_orig_ids("orig_eids.dgl")
+
+        # load partitions and verify
+        part_config = os.path.join(out_dir, "metadata.json")
+        for i in range(num_parts):
+            part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition(
+                part_config, i
+            )
+            verify_partition_data_types(part_g)
+            verify_partition_formats(part_g, graph_formats)
+            verify_graph_feats(
+                g, gpb, part_g, node_feats, edge_feats, orig_nids, orig_eids
+            )
+
+
 def test_utils_generate_read_list():
     read_list = generate_read_list(10, 4)
     assert np.array_equal(read_list[0], np.array([0, 1, 2]))
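The new test exercises exactly the case the commit fixes: a node type (or edge type) with only two members spread over up to 11 partitions, so most partitions own zero entities of that type. A minimal, self-contained sketch of why that happens under random partitioning (the seed and counts below are illustrative, not part of the test):

# Illustrative only: 2 nodes of a "few" type assigned randomly to 4 partitions.
# At least two partitions end up with no nodes of that type, which is the
# situation the partition pipeline previously mishandled.
import random

random.seed(0)
num_parts = 4
owners = [random.randrange(num_parts) for _ in range(2)]
empty_parts = [p for p in range(num_parts) if p not in owners]
print(f"owners={owners}, partitions with zero nodes of this type={empty_parts}")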

tools/distpartitioning/convert_partition.py (+51 -3)

@@ -9,6 +9,7 @@
 import dgl.graphbolt as gb
 import numpy as np
 import torch as th
+import torch.distributed as dist
 from dgl import EID, ETYPE, NID, NTYPE
 
 from dgl.distributed.constants import DGL2GB_EID, GB_DST_ID
@@ -355,6 +356,34 @@ def _process_partition_gb(
     return indptr, indices[sorted_idx], edge_ids[sorted_idx]
 
 
+def _update_node_map(node_map_val, end_ids_per_rank, id_ntypes, prev_last_id):
+    """This function is modified from '_update_node_edge_map' in dgl.distributed.partition."""
+    # Update the node_map_val to be contiguous.
+    rank = dist.get_rank()
+    prev_end_id = (
+        end_ids_per_rank[rank - 1].item() if rank > 0 else prev_last_id
+    )
+    ntype_ids = {ntype: ntype_id for ntype_id, ntype in enumerate(id_ntypes)}
+    for ntype_id in list(ntype_ids.values()):
+        ntype = id_ntypes[ntype_id]
+        start_id = node_map_val[ntype][0][0]
+        end_id = node_map_val[ntype][0][1]
+        if not (start_id == -1 and end_id == -1):
+            continue
+        prev_ntype_id = (
+            ntype_ids[ntype] - 1
+            if ntype_ids[ntype] > 0
+            else max(ntype_ids.values())
+        )
+        prev_ntype = id_ntypes[prev_ntype_id]
+        if ntype_ids[ntype] == 0:
+            node_map_val[ntype][0][0] = prev_end_id
+        else:
+            node_map_val[ntype][0][0] = node_map_val[prev_ntype][0][1]
+        node_map_val[ntype][0][1] = node_map_val[ntype][0][0]
+    return node_map_val[ntype][0][-1]
+
+
 def create_graph_object(
     tot_node_count,
     tot_edge_count,
@@ -368,6 +397,7 @@ def create_graph_object(
     edgeid_offset,
     node_typecounts,
     edge_typecounts,
+    last_ids={},
     return_orig_nids=False,
     return_orig_eids=False,
     use_graphbolt=False,
@@ -512,12 +542,30 @@
     shuffle_global_nid_range = (shuffle_global_nids[0], shuffle_global_nids[-1])
 
     # Determine the node ID ranges of different node types.
+    prev_last_id = last_ids.get(part_id - 1, 0)
     for ntype_name in global_nid_ranges:
         ntype_id = ntypes_map[ntype_name]
         type_nids = shuffle_global_nids[ntype_ids == ntype_id]
-        node_map_val[ntype_name].append(
-            [int(type_nids[0]), int(type_nids[-1]) + 1]
-        )
+        if len(type_nids) == 0:
+            node_map_val[ntype_name].append([-1, -1])
+        else:
+            node_map_val[ntype_name].append(
+                [int(type_nids[0]), int(type_nids[-1]) + 1]
+            )
+            last_id = th.tensor(
+                [max(prev_last_id, int(type_nids[-1]) + 1)], dtype=th.int64
+            )
+    id_ntypes = list(global_nid_ranges.keys())
+
+    gather_last_ids = [
+        th.zeros(1, dtype=th.int64) for _ in range(dist.get_world_size())
+    ]
+
+    dist.all_gather(gather_last_ids, last_id)
+    prev_last_id = _update_node_map(
+        node_map_val, gather_last_ids, id_ntypes, prev_last_id
+    )
+    last_ids[part_id] = prev_last_id
 
     # process edges
     memory_snapshot("CreateDGLObj_AssignEdgeData: ", part_id)
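To see what `_update_node_map` accomplishes, here is a hedged, single-process sketch (no `torch.distributed`, made-up values): a node type that received no nodes on a partition is first recorded as `[-1, -1]`, then rewritten into an empty-but-contiguous range anchored at the end of the preceding type's range, so the global node map stays monotonic.

# Single-process illustration of the contiguity fix; values are hypothetical.
id_ntypes = ["n1", "n_few"]
node_map_val = {
    "n1": [[100, 150]],   # this partition owns shuffled IDs 100..149 of type "n1"
    "n_few": [[-1, -1]],  # no "n_few" nodes landed on this partition
}

for ntype_id, ntype in enumerate(id_ntypes):
    start_id, end_id = node_map_val[ntype][0]
    if (start_id, end_id) != (-1, -1):
        continue  # non-empty ranges are left untouched
    # Anchor the empty range at the end of the previous type's range.
    prev_ntype = id_ntypes[ntype_id - 1]
    anchor = node_map_val[prev_ntype][0][1]
    node_map_val[ntype][0] = [anchor, anchor]

print(node_map_val)  # {'n1': [[100, 150]], 'n_few': [[150, 150]]}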

tools/distpartitioning/data_shuffle.py (+24 -1)

@@ -489,6 +489,10 @@ def exchange_feature(
         feat_dims_dtype.append(DATA_TYPE_ID[torch.float32])
         feature_dimension = 0
 
+    feature_dimension_tensor = torch.tensor([feature_dimension])
+    dist.all_reduce(feature_dimension_tensor, op=dist.ReduceOp.MAX)
+    feature_dimension = feature_dimension_tensor.item()
+
     logging.debug(f"Sending the feature shape information - {feat_dims_dtype}")
     all_dims_dtype = allgather_sizes(
         feat_dims_dtype, world_size, num_parts, return_sizes=True
@@ -553,7 +557,11 @@
             else:
                 cur_features[local_feat_key] = output_feat_list
                 cur_global_ids[local_feat_key] = output_id_list
-
+        else:
+            cur_features[local_feat_key] = torch.empty(
+                (0, feature_dimension), dtype=torch.float32
+            )
+            cur_global_ids[local_feat_key] = torch.empty((0,), dtype=torch.int64)
     return cur_features, cur_global_ids
 
 
@@ -1301,6 +1309,7 @@ def prepare_local_data(src_data, local_part_id):
     if params.graph_formats:
         graph_formats = params.graph_formats.split(",")
 
+    prev_last_ids = {}
    for local_part_id in range(params.num_parts // world_size):
         # Synchronize for each local partition of the graph object.
         dist.barrier()
@@ -1340,6 +1349,7 @@ def prepare_local_data(src_data, local_part_id):
                 schema_map[constants.STR_NUM_NODES_PER_TYPE],
             ),
             edge_typecounts,
+            prev_last_ids,
             return_orig_nids=params.save_orig_nids,
             return_orig_eids=params.save_orig_eids,
             use_graphbolt=params.use_graphbolt,
@@ -1390,6 +1400,19 @@ def prepare_local_data(src_data, local_part_id):
         ] = json_metadata
         memory_snapshot("MetadataCreateComplete: ", rank)
 
+        last_id_tensor = torch.tensor(
+            [prev_last_ids[rank + (local_part_id * world_size)]],
+            dtype=torch.int64,
+        )
+        gather_list = [
+            torch.zeros(1, dtype=torch.int64) for _ in range(world_size)
+        ]
+        dist.all_gather(gather_list, last_id_tensor)
+        for rank_id, last_id in enumerate(gather_list):
+            prev_last_ids[
+                rank_id + (local_part_id * world_size)
+            ] = last_id.item()
+
     if rank == 0:
         # get meta-data from all partitions and merge them on rank-0
         metadata_list = gather_metadata_json(output_meta_json, rank, world_size)
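The `all_reduce` with `ReduceOp.MAX` exists so that a rank holding no features for a key still learns the global feature width; otherwise its placeholder tensor would have the wrong trailing dimension and later concatenation of per-rank shards would fail. A small standalone illustration of that failure mode (plain `torch`, no process group needed):

# Why the width matters: concatenation along dim 0 requires matching trailing dims.
import torch

shard = torch.randn(5, 3)              # features received by another rank
placeholder_ok = torch.empty((0, 3))   # empty, but with the agreed width of 3
print(torch.cat([shard, placeholder_ok]).shape)  # torch.Size([5, 3])

placeholder_bad = torch.empty((0, 0))  # width unknown to this rank
# torch.cat([shard, placeholder_bad])  # would raise a size-mismatch error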

tools/distpartitioning/dataset_utils.py (+5)

@@ -547,6 +547,11 @@ def get_dataset(
             autogenerate_column_names=True,
         )
         parse_options = pyarrow.csv.ParseOptions(delimiter=" ")
+
+        if os.path.getsize(edge_file) == 0:
+            # If getsize() == 0, the file is empty, indicating that the partition doesn't have this attribute.
+            # The src_ids and dst_ids should remain empty.
+            continue
         with pyarrow.csv.open_csv(
             edge_file,
             read_options=read_options,
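The guard works because a zero-byte edge file simply means the partition has no edges of that relation, and `pyarrow.csv` cannot parse an empty file, so it must be skipped before opening. A hedged sketch of the same pattern in isolation; `read_edges` and its return convention are hypothetical, introduced only for illustration:

import os

import pyarrow.csv


def read_edges(edge_file):
    # A zero-byte file means this partition holds no edges of the relation;
    # skip it so pyarrow is never asked to parse an empty CSV.
    if os.path.getsize(edge_file) == 0:
        return [], []
    table = pyarrow.csv.read_csv(
        edge_file,
        read_options=pyarrow.csv.ReadOptions(autogenerate_column_names=True),
        parse_options=pyarrow.csv.ParseOptions(delimiter=" "),
    )
    return table.column(0).to_pylist(), table.column(1).to_pylist()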
