Skip to content

Commit 3bd83c7

Browse files
committed
datasets update
1 parent 5cf3aa5 commit 3bd83c7

File tree

9 files changed

+78
-62
lines changed

9 files changed

+78
-62
lines changed

GOOD/data/good_datasets/good_arxiv.py

+9 -7
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
import torch
1515
from munch import Munch
16+
from ogb.nodeproppred import PygNodePropPredDataset
1617
from torch_geometric.data import Data
1718
from torch_geometric.data import InMemoryDataset, extract_zip
1819
from torch_geometric.utils import degree, to_undirected
@@ -97,7 +98,7 @@ def __init__(self, root: str, domain: str, shift: str = 'no_shift', transform=No
9798
self.domain = domain
9899
self.metric = 'Accuracy'
99100
self.task = 'Multi-label classification'
100-
self.url = 'https://drive.google.com/file/d/1-Wq7PoHTAiLsos20bLlq_xNvrV5AHSWu/view?usp=sharing'
101+
self.url = 'https://drive.google.com/file/d/1r1OTQJ5YxQAAYJiYfyDmCknmpVmiUksi/view?usp=sharing'
101102

102103
self.generate = generate
103104

@@ -243,11 +244,6 @@ def get_covariate_shift_graph(self, sorted_data_list, graph):
243244

244245
train_list, ood_val_list, ood_test_list = train_val_test_list
245246

246-
num_id_test = int(num_data * id_test_ratio)
247-
random.shuffle(train_list)
248-
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
249-
-2 * num_id_test: - num_id_test], \
250-
train_list[- num_id_test:]
251247
# Compose domains to environments
252248
num_env_train = 10
253249
num_per_env = len(train_list) // num_env_train
@@ -260,6 +256,12 @@ def get_covariate_shift_graph(self, sorted_data_list, graph):
260256
cur_domain_id = data.domain_id
261257
data.env_id = cur_env_id
262258

259+
num_id_test = int(num_data * id_test_ratio)
260+
random.shuffle(train_list)
261+
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
262+
-2 * num_id_test: - num_id_test], \
263+
train_list[- num_id_test:]
264+
263265
return self.assign_masks(train_list, ood_val_list, ood_test_list, id_val_list, id_test_list, graph)
264266

265267
def get_concept_shift_graph(self, sorted_domain_split_data_list, graph):
@@ -408,7 +410,7 @@ def get_domain_sorted_indices(self, graph, domain='degree'):
408410
return sorted_data_list, sorted_domain_split_data_list
409411

410412
def process(self):
411-
from ogb.nodeproppred import PygNodePropPredDataset
413+
412414
dataset = PygNodePropPredDataset(root=self.root, name='ogbn-arxiv')
413415
graph = dataset[0]
414416
graph.edge_index = to_undirected(graph.edge_index, graph.num_nodes)

GOOD/data/good_datasets/good_cora.py

+7 -6
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def __init__(self, root: str, domain: str, shift: str = 'no_shift', transform=No
9494
self.domain = domain
9595
self.metric = 'Accuracy'
9696
self.task = 'Multi-label classification'
97-
self.url = 'https://drive.google.com/file/d/1VD1nGDvLBn2xpYAp12irBLkTRRZ282Qm/view?usp=sharing'
97+
self.url = 'https://drive.google.com/file/d/1OyMOwT4bn_4fLdpl5B3ie18OmGsUNQxS/view?usp=sharing'
9898

9999
self.generate = generate
100100

@@ -230,11 +230,6 @@ def get_covariate_shift_graph(self, sorted_data_list, graph):
230230

231231
train_list, ood_val_list, ood_test_list = train_val_test_list
232232

233-
num_id_test = int(num_data * id_test_ratio)
234-
random.shuffle(train_list)
235-
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
236-
-2 * num_id_test: - num_id_test], \
237-
train_list[- num_id_test:]
238233
# Compose domains to environments
239234
num_env_train = 10
240235
num_per_env = len(train_list) // num_env_train
@@ -247,6 +242,12 @@ def get_covariate_shift_graph(self, sorted_data_list, graph):
247242
cur_domain_id = data.domain_id
248243
data.env_id = cur_env_id
249244

245+
num_id_test = int(num_data * id_test_ratio)
246+
random.shuffle(train_list)
247+
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
248+
-2 * num_id_test: - num_id_test], \
249+
train_list[- num_id_test:]
250+
250251
return self.assign_masks(train_list, ood_val_list, ood_test_list, id_val_list, id_test_list, graph)
251252

252253
def get_concept_shift_graph(self, sorted_domain_split_data_list, graph):

GOOD/data/good_datasets/good_hiv.py

+13 -8
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ class DomainGetter():
2323
r"""
2424
A class containing methods for data domain extraction.
2525
"""
26+
2627
def __init__(self):
2728
pass
2829

@@ -81,7 +82,7 @@ def __init__(self, root: str, domain: str, shift: str = 'no_shift', subset: str
8182
self.domain = domain
8283
self.metric = 'ROC-AUC'
8384
self.task = 'Binary classification'
84-
self.url = 'https://drive.google.com/file/d/1GNc0HUee5YQH4Vtlk8ZbDjyJBYTEyabo/view?usp=sharing'
85+
self.url = 'https://drive.google.com/file/d/1CoOqYCuLObnG5M0D8a2P2NyL61WjbCzo/view?usp=sharing'
8586

8687
self.generate = generate
8788

@@ -163,11 +164,6 @@ def get_covariate_shift_list(self, sorted_data_list):
163164

164165
train_list, ood_val_list, ood_test_list = train_val_test_list
165166

166-
num_id_test = int(num_data * test_ratio)
167-
random.shuffle(train_list)
168-
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
169-
-2 * num_id_test: - num_id_test], \
170-
train_list[- num_id_test:]
171167
# Compose domains to environments
172168
num_env_train = 10
173169
num_per_env = len(train_list) // num_env_train
@@ -179,6 +175,13 @@ def get_covariate_shift_list(self, sorted_data_list):
179175
cur_env_id += 1
180176
cur_domain_id = data.domain_id
181177
data.env_id = cur_env_id
178+
179+
num_id_test = int(num_data * test_ratio)
180+
random.shuffle(train_list)
181+
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
182+
-2 * num_id_test: - num_id_test], \
183+
train_list[- num_id_test:]
184+
182185
all_env_list = [train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]
183186

184187
return all_env_list
@@ -379,9 +382,11 @@ def load(dataset_root: str, domain: str, shift: str = 'no_shift', generate: bool
379382
train_dataset = GOODHIV(root=dataset_root,
380383
domain=domain, shift=shift, subset='train', generate=generate)
381384
id_val_dataset = GOODHIV(root=dataset_root,
382-
domain=domain, shift=shift, subset='id_val', generate=generate) if shift != 'no_shift' else None
385+
domain=domain, shift=shift, subset='id_val',
386+
generate=generate) if shift != 'no_shift' else None
383387
id_test_dataset = GOODHIV(root=dataset_root,
384-
domain=domain, shift=shift, subset='id_test', generate=generate) if shift != 'no_shift' else None
388+
domain=domain, shift=shift, subset='id_test',
389+
generate=generate) if shift != 'no_shift' else None
385390
val_dataset = GOODHIV(root=dataset_root,
386391
domain=domain, shift=shift, subset='val', generate=generate)
387392
test_dataset = GOODHIV(root=dataset_root,

GOOD/data/good_datasets/good_pcba.py

+8 -6
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ def __init__(self, root: str, domain: str, shift: str = 'no_shift', subset: str
8181
self.domain = domain
8282
self.metric = 'Average Precision'
8383
self.task = 'Binary classification'
84-
self.url = 'https://drive.google.com/file/d/1BGhI153CcJ1wuNAp7nQOhR9jGkmF-jwb/view?usp=sharing'
84+
self.url = 'https://drive.google.com/file/d/1WGieOjtgNXtGoO6o1EGhKrZj0zWU7AJl/view?usp=sharing'
8585

8686
self.generate = generate
8787

@@ -161,11 +161,6 @@ def get_covariate_shift_list(self, sorted_data_list):
161161

162162
train_list, ood_val_list, ood_test_list = train_val_test_list
163163

164-
num_id_test = int(num_data * test_ratio)
165-
random.shuffle(train_list)
166-
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
167-
-2 * num_id_test: - num_id_test], \
168-
train_list[- num_id_test:]
169164
# Compose domains to environments
170165
num_env_train = 10
171166
num_per_env = len(train_list) // num_env_train
@@ -177,6 +172,13 @@ def get_covariate_shift_list(self, sorted_data_list):
177172
cur_env_id += 1
178173
cur_domain_id = data.domain_id
179174
data.env_id = cur_env_id
175+
176+
num_id_test = int(num_data * test_ratio)
177+
random.shuffle(train_list)
178+
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
179+
-2 * num_id_test: - num_id_test], \
180+
train_list[- num_id_test:]
181+
180182
all_env_list = [train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]
181183

182184
return all_env_list

GOOD/data/good_datasets/good_sst2.py

+18 -16
Original file line number | Diff line number | Diff line change
@@ -10,19 +10,17 @@
1010
import gdown
1111
import numpy as np
1212
import torch
13+
from dig.xgraph.dataset import SentiGraphDataset
1314
from munch import Munch
14-
from rdkit import Chem
15-
from rdkit.Chem.Scaffolds import MurckoScaffold
1615
from torch_geometric.data import InMemoryDataset, extract_zip, Data
17-
from torch_geometric.datasets import MoleculeNet
1816
from tqdm import tqdm
19-
from dig.xgraph.dataset import SentiGraphDataset
2017

2118

2219
class DomainGetter():
2320
r"""
2421
A class containing methods for data domain extraction.
2522
"""
23+
2624
def __init__(self):
2725
pass
2826

@@ -60,7 +58,7 @@ def __init__(self, root: str, domain: str, shift: str = 'no_shift', subset: str
6058
self.domain = domain
6159
self.metric = 'Accuracy'
6260
self.task = 'Binary classification'
63-
self.url = 'https://drive.google.com/file/d/1e2GmmeN-mN6X5KL6t8CosBujS1kfjeNS/view?usp=sharing'
61+
self.url = 'https://drive.google.com/file/d/1lGNMbQebKIbS-NnbPxmY4_uDGI7EWXBP/view?usp=sharing'
6462

6563
self.generate = generate
6664

@@ -140,12 +138,6 @@ def get_covariate_shift_list(self, sorted_data_list):
140138

141139
train_list, ood_val_list, ood_test_list = train_val_test_list
142140

143-
id_test_ratio = 0.15
144-
num_id_test = int(len(train_list) * id_test_ratio)
145-
random.shuffle(train_list)
146-
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
147-
-2 * num_id_test: - num_id_test], \
148-
train_list[- num_id_test:]
149141
# Compose domains to environments
150142
num_env_train = 10
151143
num_per_env = len(train_list) // num_env_train
@@ -157,6 +149,14 @@ def get_covariate_shift_list(self, sorted_data_list):
157149
cur_env_id += 1
158150
cur_domain_id = data.domain_id
159151
data.env_id = cur_env_id
152+
153+
id_test_ratio = 0.15
154+
num_id_test = int(len(train_list) * id_test_ratio)
155+
random.shuffle(train_list)
156+
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
157+
-2 * num_id_test: - num_id_test], \
158+
train_list[- num_id_test:]
159+
160160
all_env_list = [train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]
161161

162162
return all_env_list
@@ -354,15 +354,17 @@ def load(dataset_root: str, domain: str, shift: str = 'no_shift', generate: bool
354354
meta_info.model_level = 'graph'
355355

356356
train_dataset = GOODSST2(root=dataset_root,
357-
domain=domain, shift=shift, subset='train', generate=generate)
357+
domain=domain, shift=shift, subset='train', generate=generate)
358358
id_val_dataset = GOODSST2(root=dataset_root,
359-
domain=domain, shift=shift, subset='id_val', generate=generate) if shift != 'no_shift' else None
359+
domain=domain, shift=shift, subset='id_val',
360+
generate=generate) if shift != 'no_shift' else None
360361
id_test_dataset = GOODSST2(root=dataset_root,
361-
domain=domain, shift=shift, subset='id_test', generate=generate) if shift != 'no_shift' else None
362+
domain=domain, shift=shift, subset='id_test',
363+
generate=generate) if shift != 'no_shift' else None
362364
val_dataset = GOODSST2(root=dataset_root,
363-
domain=domain, shift=shift, subset='val', generate=generate)
365+
domain=domain, shift=shift, subset='val', generate=generate)
364366
test_dataset = GOODSST2(root=dataset_root,
365-
domain=domain, shift=shift, subset='test', generate=generate)
367+
domain=domain, shift=shift, subset='test', generate=generate)
366368

367369
meta_info.dim_node = train_dataset.num_node_features
368370
meta_info.dim_edge = train_dataset.num_edge_features

GOOD/data/good_datasets/good_twitch.py

+7 -6
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def __init__(self, root: str, domain: str, shift: str = 'no_shift', transform=No
8282
assert domain in ['language']
8383
self.metric = 'ROC-AUC'
8484
self.task = 'Binary classification'
85-
self.url = 'https://drive.google.com/file/d/1PuO-pWsVFfCwiXx7TzKP12-QXAx6vz4O/view?usp=sharing'
85+
self.url = 'https://drive.google.com/file/d/1wii9CWmtTAUofNTgg-GkpRz_iECcbQzK/view?usp=sharing'
8686

8787
self.generate = generate
8888

@@ -212,11 +212,6 @@ def get_covariate_shift_graph(self, sorted_data_list, graph):
212212

213213
train_list, ood_val_list, ood_test_list = train_val_test_list
214214

215-
num_id_test = int(num_data * id_test_ratio)
216-
random.shuffle(train_list)
217-
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
218-
-2 * num_id_test: - num_id_test], \
219-
train_list[- num_id_test:]
220215
# Compose domains to environments
221216
num_env_train = 10
222217
num_per_env = len(train_list) // num_env_train
@@ -229,6 +224,12 @@ def get_covariate_shift_graph(self, sorted_data_list, graph):
229224
cur_domain_id = data.domain_id
230225
data.env_id = cur_env_id
231226

227+
num_id_test = int(num_data * id_test_ratio)
228+
random.shuffle(train_list)
229+
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], train_list[
230+
-2 * num_id_test: - num_id_test], \
231+
train_list[- num_id_test:]
232+
232233
return self.assign_masks(train_list, ood_val_list, ood_test_list, id_val_list, id_test_list, graph)
233234

234235
def get_concept_shift_graph(self, sorted_domain_split_data_list, graph):

GOOD/data/good_datasets/good_webkb.py

+7 -6
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ def __init__(self, root: str, domain: str, shift: str = 'no_shift', transform=No
8181
assert domain in ['university']
8282
self.metric = 'Accuracy'
8383
self.task = 'Multi-label classification'
84-
self.url = 'https://drive.google.com/file/d/1tatdDrcwZAS2iUZujB4AEsTvPF-3LYoX/view?usp=sharing'
84+
self.url = 'https://drive.google.com/file/d/1DOdUOzAMBtcHXTphrWrKhNWPxzMDNvnb/view?usp=sharing'
8585

8686
self.generate = generate
8787

@@ -215,11 +215,6 @@ def get_covariate_shift_graph(self, sorted_data_list, graph):
215215
ood_val_list = ood_test_list[: len(ood_test_list) // 2]
216216
ood_test_list = ood_test_list[len(ood_test_list) // 2:]
217217

218-
num_id_test = int(num_data * id_test_ratio)
219-
random.shuffle(train_list)
220-
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], \
221-
train_list[-2 * num_id_test: - num_id_test], \
222-
train_list[- num_id_test:]
223218
# Compose domains to environments
224219
num_env_train = 2
225220
num_per_env = len(train_list) // num_env_train
@@ -232,6 +227,12 @@ def get_covariate_shift_graph(self, sorted_data_list, graph):
232227
cur_domain_id = data.domain_id
233228
data.env_id = cur_env_id
234229

230+
num_id_test = int(num_data * id_test_ratio)
231+
random.shuffle(train_list)
232+
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], \
233+
train_list[-2 * num_id_test: - num_id_test], \
234+
train_list[- num_id_test:]
235+
235236
return self.assign_masks(train_list, ood_val_list, ood_test_list, id_val_list, id_test_list, graph)
236237

237238
def get_concept_shift_graph(self, sorted_domain_split_data_list, graph):

GOOD/data/good_datasets/good_zinc.py

+8 -6
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ def __init__(self, root: str, domain: str, shift: str = 'no_shift', subset: str
8383
self.domain = domain
8484
self.metric = 'MAE'
8585
self.task = 'Regression'
86-
self.url = 'https://drive.google.com/file/d/1IDxJdFJXPngH1vK06jZqzTvmy865BHn2/view?usp=sharing'
86+
self.url = 'https://drive.google.com/file/d/1CHR0I1JcNoBqrqFicAZVKU3213hbsEPZ/view?usp=sharing'
8787

8888
self.generate = generate
8989

@@ -164,11 +164,6 @@ def get_covariate_shift_list(self, sorted_data_list):
164164

165165
train_list, ood_val_list, ood_test_list = train_val_test_list
166166

167-
num_id_test = int(num_data * test_ratio)
168-
random.shuffle(train_list)
169-
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], \
170-
train_list[-2 * num_id_test: - num_id_test], \
171-
train_list[- num_id_test:]
172167
# Compose domains to environments
173168
num_env_train = 10
174169
num_per_env = len(train_list) // num_env_train
@@ -180,6 +175,13 @@ def get_covariate_shift_list(self, sorted_data_list):
180175
cur_env_id += 1
181176
cur_domain_id = data.domain_id
182177
data.env_id = cur_env_id
178+
179+
num_id_test = int(num_data * test_ratio)
180+
random.shuffle(train_list)
181+
train_list, id_val_list, id_test_list = train_list[: -2 * num_id_test], \
182+
train_list[-2 * num_id_test: - num_id_test], \
183+
train_list[- num_id_test:]
184+
183185
all_env_list = [train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]
184186

185187
return all_env_list

test/test_reproduce_sample/test_regenerate_datasets.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ def regenerate_dataset(config_path):
100100
return regenerator.config.dataset.dataset_name
101101

102102
for dataset_path in dataset_paths:
103-
if 'GOODSST2' in dataset_path:
103+
if 'GOODSST2' in dataset_path or 'GOODArxiv' in dataset_path:
104104
return
105105
dataset_name = regenerate_dataset(dataset_path)
106106
# release regenerate datasets space

0 commit comments

Comments
 (0)