-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexample4.py
68 lines (54 loc) · 1.98 KB
/
example4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
from tinygrad import nn, Tensor, TinyJit
from tinygrad.helpers import getenv, trange
# Size of the training range (numbers 1..NUMBERS); also the default number of training steps.
NUMBERS = 1_000
class Model:
    """Two-layer perceptron producing one sigmoid probability per input row.

    Layout: Linear(10 -> 16) -> ReLU -> Linear(16 -> 1) -> sigmoid.
    """

    def __init__(self):
        # Hidden projection: 10 input features -> 16 hidden units.
        self.l1 = nn.Linear(10, 16)
        # Output projection: 16 hidden units -> 1 logit.
        self.l2 = nn.Linear(16, 1)

    def __call__(self, x: Tensor):
        hidden = self.l1(x).relu()
        return self.l2(hidden).sigmoid()
def _one_hot_last_digit(numbers) -> np.ndarray:
    """Return float32 one-hot encodings of the last decimal digit of *numbers*.

    A scalar input yields a shape-(10,) vector; a 1-D array of length k yields (k, 10).
    """
    return np.eye(10, dtype=np.float32)[np.asarray(numbers) % 10]


def _even_labels(numbers) -> np.ndarray:
    """Return a float32 (k, 1) column holding 1.0 where the number is even, else 0.0."""
    return ((np.asarray(numbers) % 2) == 0).astype(np.float32).reshape(-1, 1)


if __name__ == "__main__":
    # Training set: the integers 1..NUMBERS inclusive.
    # Inputs and labels are float32 to match the float Linear weights and the BCE loss.
    num_array_train = np.arange(1, NUMBERS + 1)
    X_train_tensor = Tensor(_one_hot_last_digit(num_array_train))
    Y_train_tensor = Tensor(_even_labels(num_array_train))

    # Test set: the 200 integers following the training range, so the two are disjoint
    # (starting at NUMBERS would reuse the last training sample).
    num_array_test = np.arange(NUMBERS + 1, NUMBERS + 201)
    X_test_tensor = Tensor(_one_hot_last_digit(num_array_test))
    Y_test_tensor = Tensor(_even_labels(num_array_test))

    model = Model()
    opt = nn.optim.Adam(nn.state.get_parameters(model), lr=0.001)

    @TinyJit
    @Tensor.train()
    def train_step() -> Tensor:
        """Run one full-batch Adam step on the training set and return the loss."""
        outputs = model(X_train_tensor)
        loss = outputs.binary_crossentropy(Y_train_tensor).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        return loss

    @TinyJit
    @Tensor.test()
    def get_test_acc() -> Tensor:
        """Return the test accuracy, in percent, as a scalar Tensor.

        The computation stays in Tensor ops so TinyJit can capture and replay it.
        A numpy value computed inside the jitted function is not part of the
        captured kernels and would come back stale on replay. Predictions and
        labels are both (k, 1), so the comparison is element-wise rather than a
        (k, k) broadcast.
        """
        preds = (model(X_test_tensor) > 0.5).float()
        return (preds == Y_test_tensor).mean() * 100

    test_acc = float("nan")
    for i in (t := trange(getenv("STEPS", NUMBERS))):
        loss = train_step()
        if i % 10 == 9:  # evaluate every 10 steps
            test_acc = get_test_acc().item()
        t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")

    def predict(n):
        """Classify integer *n* as "even" or "odd" using the trained model on its last digit."""
        prob = model(Tensor(_one_hot_last_digit(n))).item()
        return "even" if prob > 0.5 else "odd"

    test_numbers = [17, 44, 62, 503, 12321]
    for num in test_numbers:
        print(f"Number {num} - {predict(num)}")