forked from Soptq/LANA-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
102 lines (91 loc) · 5.73 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from torch.utils.data import Dataset, DataLoader
import torch
import config
import numpy as np
from tqdm import tqdm
# Dataset definition
class DKTDataset(Dataset):
    """Fixed-length training samples cut from per-sample interaction histories.

    Every element of ``group`` is unpacked into eight parallel sequences
    (exercise, part, correctness, elapsed_time, lag_time_s, lag_time_m,
    lag_time_d, p_explanation).  Histories longer than ``max_seq`` are cut
    into non-overlapping windows, each window is prefixed with
    ``config.START``, and ``__getitem__`` left-pads every window with zeros
    to exactly ``max_seq`` entries.
    """

    def __init__(self, group, max_seq, min_seq, overlap_seq):
        """Build the list of START-prefixed windows.

        Args:
            group: sized iterable whose items unpack into the eight parallel
                sequences listed in the class docstring.
            max_seq: number of interactions per returned sample.
            min_seq: histories (and leading remainders) shorter than this
                are skipped.
            overlap_seq: stored on the instance only; the windowing below
                never overlaps windows.
        """
        self.samples = group
        self.max_seq = max_seq
        self.min_seq = min_seq
        self.overlap_seq = overlap_seq
        self.data = []

        # A window with zero interactions carries no information and cannot
        # be left-padded with a negative slice, so never store one even when
        # min_seq is 0 or negative.
        shortest = max(self.min_seq, 1)

        for (exercise, part, correctness, elapsed_time, lag_time_s,
             lag_time_m, lag_time_d, p_explanation) in tqdm(
                self.samples, total=len(self.samples), desc="Loading Dataset"):
            fields = (exercise, part, correctness, elapsed_time,
                      lag_time_s, lag_time_m, lag_time_d, p_explanation)
            content_len = len(exercise)
            if content_len < shortest:
                continue  # skip sequence with too few contents

            if content_len <= self.max_seq:
                # Short enough to be a single sample.
                self.data.append(self._prepend_start(fields))
                continue

            # The leading remainder becomes its own (shorter) window when it
            # is long enough; otherwise those first items are dropped.
            initial = content_len % self.max_seq
            if initial >= shortest:
                self.data.append(
                    self._prepend_start(fields, slice(None, initial)))

            # content_len - initial is a multiple of max_seq, so this yields
            # exactly content_len // max_seq full windows.
            for start in range(initial, content_len, self.max_seq):
                self.data.append(
                    self._prepend_start(
                        fields, slice(start, start + self.max_seq)))

    @staticmethod
    def _prepend_start(fields, window=slice(None)):
        """Return each field restricted to ``window`` and prefixed with START."""
        return tuple(np.append([config.START], field[window])
                     for field in fields)

    @staticmethod
    def _mask_first(raw):
        """Return ``raw`` minus its START prefix, with the first real value
        replaced by that START marker."""
        # NOTE(review): presumably the first interaction of a window has no
        # meaningful elapsed/lag time -- confirm against the preprocessing.
        return np.append(raw[0], raw[2:])

    def __len__(self):
        """Return the number of stored windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, label)`` for window ``idx``.

        ``inputs`` is a dict of int64 arrays of length ``max_seq`` keyed by
        ``content_id``, ``part``, ``correctness``, ``elapsed_time``,
        ``lag_time_s``, ``lag_time_m``, ``lag_time_d`` and ``prior_explan``;
        ``label`` is an int64 array of the same length.  Windows shorter
        than ``max_seq`` are left-padded with zeros.
        """
        (raw_content_ids, raw_part, raw_correctness, raw_elapsed_time,
         raw_lag_time_s, raw_lag_time_m, raw_lag_time_d,
         raw_p_explan) = self.data[idx]
        # Number of real interactions, i.e. the length without START.
        filled = len(raw_content_ids) - 1

        def left_padded(values):
            out = np.zeros(self.max_seq, dtype=np.int64)
            # out[-0:] would address the whole array, so an empty window
            # must be left as pure padding instead of being assigned.
            if filled > 0:
                out[-filled:] = values
            return out

        _input = {
            "content_id": left_padded(raw_content_ids[1:]),  # drop START
            "part": left_padded(raw_part[1:]),
            # Shifted right by one: position i holds the response to item
            # i - 1, with START in the first slot.
            "correctness": left_padded(raw_correctness[:-1]),
            "elapsed_time": left_padded(self._mask_first(raw_elapsed_time)),
            "lag_time_s": left_padded(self._mask_first(raw_lag_time_s)),
            "lag_time_m": left_padded(self._mask_first(raw_lag_time_m)),
            "lag_time_d": left_padded(self._mask_first(raw_lag_time_d)),
            "prior_explan": left_padded(raw_p_explan[1:]),
        }
        # NOTE(review): the "- 2" presumably undoes an offset that reserves
        # the low values for padding and config.START -- confirm upstream.
        label = left_padded(raw_correctness[1:] - 2)
        return _input, label