1
+ import numpy as np
2
+ import matplotlib .pyplot as plt
3
+ from mpl_toolkits .mplot3d import Axes3D
4
+ import random
5
# A minimal linear-regression demo whose optimizer is classic SGD.
rate = 0.2  # learning rate: gradient-descent step size for both a and b
def da(y, y_p, x):
    """Gradient of the half-squared error 0.5*(y - y_p)**2 w.r.t. slope a."""
    return (y_p - y) * x
+
10
def db(y, y_p):
    """Gradient of the half-squared error 0.5*(y - y_p)**2 w.r.t. intercept b."""
    return y_p - y
12
def calc_loss(a, b, x, y):
    """Half mean squared error of the line y = a*x + b over arrays x, y."""
    residual = y - (a * x + b)      # element-wise prediction error
    return np.sum(residual ** 2) / (2 * len(x))
17
def draw_hill(x, y):
    """Evaluate the loss surface on a 100x100 grid of candidate (a, b) pairs.

    Parameters
    ----------
    x, y : sequences of floats; converted to numpy arrays internally.

    Returns
    -------
    [A, B, all_sse] where A, B are ``np.meshgrid`` coordinate matrices over
    [-20, 20] and all_sse[ai][bi] is the loss for (a[ai], b[bi]).
    The caller transposes all_sse before plotting (matrix rows grow downward,
    plot origin is bottom-left).
    """
    a = np.linspace(-20, 20, 100)
    b = np.linspace(-20, 20, 100)
    # Fixed: removed leftover debug print(a) that spammed stdout.
    x = np.array(x)
    y = np.array(y)

    all_sse = np.zeros(shape=(len(a), len(b)))
    for ai, a0 in enumerate(a):
        for bi, b0 in enumerate(b):
            all_sse[ai][bi] = calc_loss(a=a0, b=b0, x=x, y=y)

    a, b = np.meshgrid(a, b)

    return [a, b, all_sse]
35
+
36
def shuffle_data(x, y):
    """Shuffle x and y in place with the same permutation, keeping pairs aligned."""
    seed = random.random()
    for seq in (x, y):
        random.seed(seed)   # identical seed + identical length => identical permutation
        random.shuffle(seq)
43
# Toy data set for the regression demo.
x = [30, 35, 37, 59, 70, 76, 88, 100]
y = [1100, 1423, 1377, 1800, 2304, 2588, 3495, 4839]

# Min-max normalisation of both axes into [0, 1].
x_min, x_max = min(x), max(x)
y_min, y_max = min(y), max(y)
x = [(v - x_min) / (x_max - x_min) for v in x]
y = [(v - y_min) / (y_max - y_min) for v in y]

ha, hb, hallSSE = draw_hill(x, y)
# Important: transpose the loss grid.  The matrix is filled top-left to
# bottom-right, while the plot's origin sits at the bottom-left.
hallSSE = hallSSE.T

# Initial guesses for the line y = a*x + b.
a = 10.0
b = -20.0
62
fig = plt.figure(1, figsize=(12, 8))

# Panel 1: 3-D surface of the loss over the (a, b) grid.
ax = fig.add_subplot(2, 2, 1, projection='3d')
# NOTE(review): set_top_view() is an internal Axes3D method, not public API —
# it may disappear in newer matplotlib releases; confirm the installed version.
ax.set_top_view()
ax.plot_surface(ha, hb, hallSSE, rstride=2, cstride=2, cmap='rainbow')

# Panel 2: filled + labelled contour map of the same loss surface.
plt.subplot(2, 2, 2)
# NOTE(review): ta and tb are computed but never used below — dead code?
ta = np.linspace(-20, 20, 100)
tb = np.linspace(-20, 20, 100)
plt.contourf(ha, hb, hallSSE, 15, alpha=0.5, cmap=plt.cm.hot)
C = plt.contour(ha, hb, hallSSE, 15, colors='black')
plt.clabel(C, inline=True)
plt.xlabel('a')
plt.ylabel('b')

plt.ion()  # interactive mode on, so the training loop can redraw each step
80
+
81
# Training loop: 500 gradient steps, redrawing all four panels each step.
all_loss = []   # per-step loss history for panel 4
all_step = []   # matching step indices
last_a = a      # previous (a, b), used to draw the descent path in panel 2
last_b = b
step = 1
while step <= 500:
    loss = 0
    all_da = 0
    all_db = 0
    shuffle_data(x, y)
    for i in range(0, len(x)):
        y_p = a * x[i] + b
        # NOTE(review): plain `=` (not `+=`) means only the LAST shuffled
        # sample's loss and gradients survive this loop — effectively
        # single-sample SGD, with the earlier iterations wasted.  Confirm
        # whether accumulation over the batch was intended.
        loss = (y[i] - y_p) * (y[i] - y_p) / 2
        all_da = da(y[i], y_p, x[i])
        all_db = db(y[i], y_p)
    #loss_ = calc_loss(a = a,b=b,x=np.array(x),y=np.array(y))
    #loss = loss/len(x)

    # Panel 1: current loss point on the 3-D surface.
    ax.scatter(a, b, loss, color='black')
    # Panel 2: current (a, b) plus a segment from the previous position,
    # tracing the descent path on the contour map.
    plt.subplot(2, 2, 2)
    plt.scatter(a, b, s=5, color='blue')
    plt.plot([last_a, a], [last_b, b], color='aqua')
    # Panel 3: data points and the current regression line.
    plt.subplot(2, 2, 3)
    plt.plot(x, y)
    plt.plot(x, y, 'o')
    x_ = np.linspace(0, 1, 2)   # two endpoints suffice for a straight line
    y_draw = a * x_ + b
    plt.plot(x_, y_draw)
    # Panel 4: loss-vs-step curve.
    all_loss.append(loss)
    all_step.append(step)
    plt.subplot(2, 2, 4)
    plt.plot(all_step, all_loss, color='orange')
    plt.xlabel("step")
    plt.ylabel("loss")

    last_a = a
    last_b = b

    # Parameter update (gradient descent step).
    a = a - rate * all_da
    b = b - rate * all_db

    # NOTE(review): `step % 1 == 0` is always true — presumably a leftover
    # knob for printing/redrawing every N steps; confirm intended interval.
    if step % 1 == 0:
        print("step: ", step, " loss: ", loss)
        plt.show()
        plt.pause(0.01)
    step = step + 1
plt.show()
# Keep the final figure open essentially forever (blocks the process).
plt.pause(99999999999)
0 commit comments