Skip to content

Commit a434006

Browse files
committedJul 16, 2017
更新所有实验程序到正确绘图版本
1 parent 31e3df4 commit a434006

7 files changed

+323
-461
lines changed
 

‎p2 SGD.py

-126
This file was deleted.

‎p2 origin SGD.py

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
from mpl_toolkits.mplot3d import Axes3D
4+
import random
5+
# Minimal linear-regression example; the optimizer is classic SGD
# (stochastic gradient descent).
rate = 0.2 # learning rate: multiplied into each gradient when a and b are updated
7+
def da(y, y_p, x):
    """Return the gradient of the per-sample loss with respect to the slope a.

    The model is y_p = a * x + b and the per-sample loss is (y - y_p)**2 / 2,
    so the derivative with respect to a is (y - y_p) * (-x).
    """
    residual = y - y_p
    return residual * (-x)
9+
10+
def db(y, y_p):
    """Return the gradient of the per-sample loss with respect to the intercept b.

    With y_p = a * x + b and loss (y - y_p)**2 / 2, the derivative with
    respect to b is simply the negated residual.
    """
    residual = y - y_p
    return -residual
12+
def calc_loss(a, b, x, y):
    """Return the halved mean squared error of the line y = a * x + b.

    Despite the historical "SSE" naming, the value is
    sum((y - (a * x + b))**2) / (2 * len(x)), i.e. the sum of squared
    residuals divided by twice the number of samples.

    Args:
        a: Slope of the line.
        b: Intercept of the line.
        x: 1-D sequence or array of inputs.
        y: 1-D sequence or array of targets, same length as x.

    Returns:
        The scalar loss.
    """
    # Coerce to arrays so plain lists/tuples work too; `a * x` on a Python
    # list with a float slope would otherwise raise TypeError.
    x = np.asarray(x)
    y = np.asarray(y)
    residual = y - (a * x + b)
    # Element-wise square, then reduce with NumPy instead of the builtin sum.
    return np.sum(residual ** 2) / (2 * len(x))
17+
def draw_hill(x, y):
    """Evaluate the loss over a grid of (a, b) values for surface/contour plots.

    Both a and b are sampled at 100 evenly spaced points in [-20, 20].

    Args:
        x: Sequence of inputs.
        y: Sequence of targets, same length as x.

    Returns:
        A list [a_mesh, b_mesh, allSSE]. a_mesh and b_mesh come from
        np.meshgrid(a, b). allSSE is filled as allSSE[a index][b index],
        which is the transpose of the meshgrid layout, so callers must
        transpose allSSE before plotting it against the meshes.
    """
    a = np.linspace(-20, 20, 100)
    b = np.linspace(-20, 20, 100)
    x = np.array(x)
    y = np.array(y)

    allSSE = np.zeros(shape=(len(a), len(b)))
    for ai, a0 in enumerate(a):
        for bi, b0 in enumerate(b):
            allSSE[ai][bi] = calc_loss(a=a0, b=b0, x=x, y=y)

    a, b = np.meshgrid(a, b)

    return [a, b, allSSE]
35+
36+
def shuffle_data(x, y):
    """Shuffle x and y in place with the same random permutation.

    The pairing between x[i] and y[i] is preserved: whatever position a
    sample's x moves to, its y moves to the same position.

    Args:
        x: Mutable sequence of inputs; reordered in place.
        y: Mutable sequence of targets, same length as x; reordered in place.

    Raises:
        ValueError: If x and y have different lengths, since a one-to-one
            pairing is then undefined.
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    # Draw one permutation and apply it to both sequences. This avoids
    # reseeding the module-level random generator, which would clobber its
    # state for every other user of `random`.
    order = list(range(len(x)))
    random.shuffle(order)
    x[:] = [x[i] for i in order]
    y[:] = [y[i] for i in order]
43+
# Simulated data
x = [30, 35, 37, 59, 70, 76, 88, 100]
y = [1100, 1423, 1377, 1800, 2304, 2588, 3495, 4839]

# Min-max normalisation: rescale both series into [0, 1].
x_max, x_min = max(x), min(x)
y_max, y_min = max(y), min(y)

# Slice assignment keeps the same list objects while replacing every element.
x[:] = [(value - x_min) / (x_max - x_min) for value in x]
y[:] = [(value - y_min) / (y_max - y_min) for value in y]
56+
57+
# Evaluate the loss over the (a, b) grid used by the surface and contour plots.
[ha, hb, hallSSE] = draw_hill(x, y)
# Important: transpose the losses. draw_hill fills the matrix as
# [a index][b index], while the meshgrid arrays ha/hb vary a along columns and
# b along rows, so the matrix must be transposed to line up with them.
hallSSE = hallSSE.T

# Initial values of a and b
a = 10.0
b = -20.0
fig = plt.figure(1, figsize=(12, 8))

# Panel 1: loss surface
ax = fig.add_subplot(2, 2, 1, projection='3d')
# NOTE(review): set_top_view is not part of the documented Axes3D API and may
# be missing in some matplotlib versions; call it only where it exists so the
# script does not die with AttributeError.
if hasattr(ax, 'set_top_view'):
    ax.set_top_view()
ax.plot_surface(ha, hb, hallSSE, rstride=2, cstride=2, cmap='rainbow')

# Panel 2: contour map of the same loss surface
plt.subplot(2, 2, 2)
plt.contourf(ha, hb, hallSSE, 15, alpha=0.5, cmap=plt.cm.hot)
C = plt.contour(ha, hb, hallSSE, 15, colors='black')
plt.clabel(C, inline=True)
plt.xlabel('a')
plt.ylabel('b')

plt.ion()  # interactive mode on, so the figure can be updated while training runs
80+
81+
# History of (step, loss) pairs for the loss curve in panel 4.
all_loss = []
all_step = []
last_a = a
last_b = b

# Loop invariants. The samples never change, only their order does, so sort
# them once for the line drawn through the data in panel 3; drawing the line
# through the shuffled order would zigzag between points.
data_x, data_y = zip(*sorted(zip(x, y)))
x_ = np.linspace(0, 1, 2)

for step in range(1, 501):
    # One SGD step uses one random sample: shuffle the paired data in place
    # and take the final element. (Iterating over every sample while
    # overwriting loss and gradients each time leaves exactly this sample's
    # values, so it is selected directly.)
    shuffle_data(x, y)
    x_i, y_i = x[-1], y[-1]
    y_p = a * x_i + b
    loss = (y_i - y_p) * (y_i - y_p) / 2
    all_da = da(y_i, y_p, x_i)
    all_db = db(y_i, y_p)

    # Panel 1: mark the current (a, b, loss) point on the surface
    ax.scatter(a, b, loss, color='black')
    # Panel 2: mark the current point and the path from the previous one
    plt.subplot(2, 2, 2)
    plt.scatter(a, b, s=5, color='blue')
    plt.plot([last_a, a], [last_b, b], color='aqua')
    # Panel 3: the data and the current regression line
    plt.subplot(2, 2, 3)
    plt.plot(data_x, data_y)
    plt.plot(data_x, data_y, 'o')
    y_draw = a * x_ + b
    plt.plot(x_, y_draw)
    # Panel 4: loss versus step
    all_loss.append(loss)
    all_step.append(step)
    plt.subplot(2, 2, 4)
    plt.plot(all_step, all_loss, color='orange')
    plt.xlabel("step")
    plt.ylabel("loss")

    last_a = a
    last_b = b

    # Update the parameters
    a = a - rate * all_da
    b = b - rate * all_db

    print("step: ", step, " loss: ", loss)
    # pause() redraws the active figure and runs the GUI event loop briefly.
    plt.pause(0.01)

# Leave interactive mode so show() blocks until the window is closed. Pausing
# for an enormous interval instead can overflow a GUI backend's timer.
plt.ioff()
plt.show()

0 commit comments

Comments
 (0)
Please sign in to comment.