'''
Loss.py implements the loss functions.
It mainly provides cross entropy: NLLLoss for multi-class
classification and BECLoss (binary cross-entropy) for binary tasks.
'''
import numpy as np
from Module import Module
class NLLLoss(Module):
    # Negative log-likelihood loss for multi-class classification;
    # expects log-softmax inputs, i.e. logsoftmax([p1, p2, ..., pk])
    def __init__(self, size_average=True):
        super(NLLLoss, self).__init__()
        self.size_average = size_average

    # Compute the cross-entropy loss from log-softmax predictions
    def cal_loss(self, prediction, labels):
        '''
        prediction: predicted log-probabilities, shape [batch, k]
        labels: dataset labels, either class indices of shape [batch]
                (e.g. [2, 3, 8, 9]) or one-hot rows of shape [batch, k]
                (e.g. [0, 0, 1, 0] marks class index 2)
        size_average: whether to average the loss over the batch
        '''
        self.labels = labels
        self.prediction = prediction
        self.batchsize = self.prediction.shape[0]
        self.loss = 0
        # Check whether the labels are one-hot encoded
        if labels.ndim > 1:  # one-hot: [[y1, y2, ..., yk], ...]
            # Only the true-class log-probability contributes per sample
            self.loss = -np.sum(self.prediction * self.labels)
        elif labels.ndim == 1:  # class indices: [batch]
            for i in range(self.batchsize):
                self.loss -= prediction[i, labels[i]]
        # Average over all samples to obtain the final loss
        if self.size_average:
            self.loss = self.loss / self.batchsize
        return self.loss

    def gradient(self):
        # d(loss)/d(log p_i) = -y_i; class-index labels are one-hot encoded first
        self.eta = self.labels.copy()
        if self.eta.ndim == 1:
            one_hot = np.zeros_like(self.prediction)
            one_hot[np.arange(self.batchsize), self.eta] = 1
            self.eta = one_hot
        # Match the forward reduction: scale by the batch size when averaging
        self.eta_next = -self.eta / self.batchsize if self.size_average else -self.eta
        return self.eta_next
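
# A minimal usage sketch, not part of the original file: the helper name
# demo_nll and the sample numbers below are illustrative assumptions.
def demo_nll():
    # Two samples over three classes; rows sum to 1 before taking the log
    logp = np.log(np.array([[0.7, 0.2, 0.1],
                            [0.1, 0.8, 0.1]]))
    labels = np.array([0, 1])
    criterion = NLLLoss()
    print('loss:', criterion.cal_loss(logp, labels))  # -(log 0.7 + log 0.8) / 2 ≈ 0.2899
    print('grad:\n', criterion.gradient())  # -one_hot(labels) / batchsize
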
class BECLoss(Module):
    # Binary cross-entropy loss for two-class classification
    def __init__(self, size_average=True):
        super(BECLoss, self).__init__()
        self.size_average = size_average

    def forward(self, prediction, labels):
        # Clip predictions away from 0 and 1 so log() stays finite
        self.prediction = np.clip(prediction, 1e-12, 1 - 1e-12)
        self.batchsize = self.prediction.shape[0]
        self.labels = labels
        # loss = -[y*log(p) + (1-y)*log(1-p)], summed over the batch
        self.loss = self.labels * np.log(self.prediction) + (1 - self.labels) * np.log(1 - self.prediction)
        self.loss = -np.sum(self.loss)
        if self.size_average:
            self.loss /= self.batchsize
        return self.loss

    def gradient(self):
        # d(loss)/dp = (p - y) / (p * (1 - p)), scaled by the batch size when averaging
        grad = (self.prediction - self.labels) / (self.prediction * (1 - self.prediction))
        return grad / self.batchsize if self.size_average else grad
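
# A minimal usage sketch, not part of the original file: the helper name
# demo_bce and the sample numbers below are illustrative assumptions.
def demo_bce():
    p = np.array([0.9, 0.2, 0.7])  # sigmoid-style probabilities
    y = np.array([1.0, 0.0, 1.0])  # binary targets
    criterion = BECLoss()
    print('loss:', criterion.forward(p, y))  # -(log 0.9 + log 0.8 + log 0.7) / 3 ≈ 0.2284
    print('grad:', criterion.gradient())  # (p - y) / (p * (1 - p)) / batchsize
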
def test_NLLLoss():
    print('-----NLLLoss test-----')
    logp = np.log(np.array([[0.7, 0.3], [0.4, 0.6]]))
    print(NLLLoss().cal_loss(logp, np.array([0, 1])))  # expect ~0.4338

if __name__ == '__main__':
    test_NLLLoss()