10
10
import gdown
11
11
import numpy as np
12
12
import torch
13
+ from dig .xgraph .dataset import SentiGraphDataset
13
14
from munch import Munch
14
- from rdkit import Chem
15
- from rdkit .Chem .Scaffolds import MurckoScaffold
16
15
from torch_geometric .data import InMemoryDataset , extract_zip , Data
17
- from torch_geometric .datasets import MoleculeNet
18
16
from tqdm import tqdm
19
- from dig .xgraph .dataset import SentiGraphDataset
20
17
21
18
22
19
class DomainGetter ():
23
20
r"""
24
21
A class containing methods for data domain extraction.
25
22
"""
23
+
26
24
def __init__ (self ):
27
25
pass
28
26
@@ -60,7 +58,7 @@ def __init__(self, root: str, domain: str, shift: str = 'no_shift', subset: str
60
58
self .domain = domain
61
59
self .metric = 'Accuracy'
62
60
self .task = 'Binary classification'
63
- self .url = 'https://drive.google.com/file/d/1e2GmmeN-mN6X5KL6t8CosBujS1kfjeNS /view?usp=sharing'
61
+ self .url = 'https://drive.google.com/file/d/1lGNMbQebKIbS-NnbPxmY4_uDGI7EWXBP /view?usp=sharing'
64
62
65
63
self .generate = generate
66
64
@@ -140,12 +138,6 @@ def get_covariate_shift_list(self, sorted_data_list):
140
138
141
139
train_list , ood_val_list , ood_test_list = train_val_test_list
142
140
143
- id_test_ratio = 0.15
144
- num_id_test = int (len (train_list ) * id_test_ratio )
145
- random .shuffle (train_list )
146
- train_list , id_val_list , id_test_list = train_list [: - 2 * num_id_test ], train_list [
147
- - 2 * num_id_test : - num_id_test ], \
148
- train_list [- num_id_test :]
149
141
# Compose domains to environments
150
142
num_env_train = 10
151
143
num_per_env = len (train_list ) // num_env_train
@@ -157,6 +149,14 @@ def get_covariate_shift_list(self, sorted_data_list):
157
149
cur_env_id += 1
158
150
cur_domain_id = data .domain_id
159
151
data .env_id = cur_env_id
152
+
153
+ id_test_ratio = 0.15
154
+ num_id_test = int (len (train_list ) * id_test_ratio )
155
+ random .shuffle (train_list )
156
+ train_list , id_val_list , id_test_list = train_list [: - 2 * num_id_test ], train_list [
157
+ - 2 * num_id_test : - num_id_test ], \
158
+ train_list [- num_id_test :]
159
+
160
160
all_env_list = [train_list , ood_val_list , ood_test_list , id_val_list , id_test_list ]
161
161
162
162
return all_env_list
@@ -354,15 +354,17 @@ def load(dataset_root: str, domain: str, shift: str = 'no_shift', generate: bool
354
354
meta_info .model_level = 'graph'
355
355
356
356
train_dataset = GOODSST2 (root = dataset_root ,
357
- domain = domain , shift = shift , subset = 'train' , generate = generate )
357
+ domain = domain , shift = shift , subset = 'train' , generate = generate )
358
358
id_val_dataset = GOODSST2 (root = dataset_root ,
359
- domain = domain , shift = shift , subset = 'id_val' , generate = generate ) if shift != 'no_shift' else None
359
+ domain = domain , shift = shift , subset = 'id_val' ,
360
+ generate = generate ) if shift != 'no_shift' else None
360
361
id_test_dataset = GOODSST2 (root = dataset_root ,
361
- domain = domain , shift = shift , subset = 'id_test' , generate = generate ) if shift != 'no_shift' else None
362
+ domain = domain , shift = shift , subset = 'id_test' ,
363
+ generate = generate ) if shift != 'no_shift' else None
362
364
val_dataset = GOODSST2 (root = dataset_root ,
363
- domain = domain , shift = shift , subset = 'val' , generate = generate )
365
+ domain = domain , shift = shift , subset = 'val' , generate = generate )
364
366
test_dataset = GOODSST2 (root = dataset_root ,
365
- domain = domain , shift = shift , subset = 'test' , generate = generate )
367
+ domain = domain , shift = shift , subset = 'test' , generate = generate )
366
368
367
369
meta_info .dim_node = train_dataset .num_node_features
368
370
meta_info .dim_edge = train_dataset .num_edge_features
0 commit comments