# 百度点石 (Baidu DianShi)

Competition link: http://dianshi.baidu.com/dianshi/pc/competition/22/submit

# 1. Basic Data Processing

## PathFilter

A `PathFilter` can be used to sample our data:

1. RandomPathFilter: samples paths at random, which also acts as a shuffle
2. BalancedPathFilter: samples paths at random while also correcting class imbalance
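The effect of BalancedPathFilter can be sketched in plain Java, independent of DataVec (the class and method names below are hypothetical illustrations, not DataVec API): shuffle each class's paths, then cap every class at the size of the smallest one.

```Java
import java.util.*;

public class BalancedSampleSketch {
    // Hypothetical sketch: cap every label at the size of the smallest class.
    static List<String> balance(Map<String, List<String>> pathsByLabel, long seed) {
        int minClass = pathsByLabel.values().stream().mapToInt(List::size).min().orElse(0);
        Random rng = new Random(seed);
        List<String> sampled = new ArrayList<>();
        for (List<String> paths : pathsByLabel.values()) {
            List<String> copy = new ArrayList<>(paths);
            Collections.shuffle(copy, rng);            // random sampling, like RandomPathFilter
            sampled.addAll(copy.subList(0, minClass)); // equal count per class
        }
        Collections.shuffle(sampled, rng);             // final shuffle across classes
        return sampled;
    }

    public static void main(String[] args) {
        Map<String, List<String>> data = new LinkedHashMap<>();
        data.put("cat", Arrays.asList("c1.jpg", "c2.jpg", "c3.jpg", "c4.jpg"));
        data.put("dog", Arrays.asList("d1.jpg", "d2.jpg"));
        System.out.println(balance(data, 42).size()); // 2 classes * 2 per class = 4
    }
}
```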
## PipelineImageTransform

```Java
List<Pair<ImageTransform, Double>> pipeline = Arrays.asList(new Pair<>(cropTransform, 0.9),
        new Pair<>(flipTransform, 0.9),
        new Pair<>(rotateTransform0, 1.0),
        new Pair<>(rotateTransform30, 0.9),
        new Pair<>(rotateTransform90, 0.9),
        new Pair<>(rotateTransform120, 0.9),
        new Pair<>(warpTransform, 0.9));
```

The `Double` in each pair is the probability that the preceding image-augmentation transform is executed.
```Java
// {@link org.datavec.image.transform.PipelineImageTransform.doTransform}
@Override
protected ImageWritable doTransform(ImageWritable image, Random random) {
    if (shuffle) {
        Collections.shuffle(imageTransforms);
    }

    currentTransforms.clear();

    // execute each item in the pipeline
    for (Pair<ImageTransform, Double> tuple : imageTransforms) {
        if (tuple.getSecond() == 1.0 || rng.nextDouble() < tuple.getSecond()) { // probability of execution
            currentTransforms.add(tuple.getFirst());
            image = random != null ? tuple.getFirst().transform(image, random)
                            : tuple.getFirst().transform(image);
        }
    }

    return image;
}
```
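The probability check above (`tuple.getSecond() == 1.0 || rng.nextDouble() < tuple.getSecond()`) can be exercised with a small stand-in pipeline; the string-appending "transforms" below are stand-ins for illustration, not real `ImageTransform`s:

```Java
import java.util.*;
import java.util.function.UnaryOperator;

public class PipelineProbabilityDemo {
    public static void main(String[] args) {
        // Each entry: a transform stub paired with its execution probability.
        List<Map.Entry<UnaryOperator<String>, Double>> pipeline = Arrays.asList(
                new AbstractMap.SimpleEntry<>((UnaryOperator<String>) s -> s + "+crop", 0.9),
                new AbstractMap.SimpleEntry<>((UnaryOperator<String>) s -> s + "+rotate", 1.0));

        Random rng = new Random(123);
        int rotateRuns = 0;
        for (int trial = 0; trial < 1000; trial++) {
            String image = "img";
            for (Map.Entry<UnaryOperator<String>, Double> e : pipeline) {
                // Same rule as PipelineImageTransform.doTransform
                if (e.getValue() == 1.0 || rng.nextDouble() < e.getValue()) {
                    image = e.getKey().apply(image);
                }
            }
            if (image.endsWith("+rotate")) rotateRuns++;
        }
        System.out.println(rotateRuns); // 1000: probability 1.0 always executes
    }
}
```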
# 2. Memory Management

Official documentation: https://deeplearning4j.org/docs/latest/deeplearning4j-config-memory

`-Xms`/`-Xmx` bound the on-heap memory, while the `javacpp` properties bound the off-heap memory used by ND4J:

```
-Xms2G -Xmx2G -Dorg.bytedeco.javacpp.maxbytes=10G -Dorg.bytedeco.javacpp.maxphysicalbytes=10G
```
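These limits can be sanity-checked from inside the JVM (a minimal sketch; only the heap limit is visible through `Runtime`, while the JavaCPP limits are plain system properties):

```Java
public class HeapCheck {
    public static void main(String[] args) {
        // Approximate -Xmx as seen by the running JVM, in megabytes.
        long maxHeapMb = Runtime.getRuntime().maxMemory() / (1024 * 1024);
        System.out.println("Max heap (MB): " + maxHeapMb);
        // The JavaCPP limits are plain system properties; unset unless passed via -D.
        System.out.println("javacpp.maxbytes = "
                + System.getProperty("org.bytedeco.javacpp.maxbytes", "(not set)"));
    }
}
```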
# 3. Early Stopping

## 1. Create a ModelSaver

Specifies where the best model found during training is saved:

1. InMemoryModelSaver: keeps the best model in memory
2. LocalFileModelSaver: saves to a local directory; only supports `MultiLayerNetwork` networks
3. LocalFileGraphSaver: saves to a local directory; only supports `ComputationGraph` networks
## 2. Configure Early Stopping

1. epochTerminationConditions: conditions that terminate training at epoch boundaries
2. evaluateEveryNEpochs: how many epochs to train between model evaluations
3. scoreCalculator: computes the evaluation score
   i. org.deeplearning4j.earlystopping.scorecalc.RegressionScoreCalculator for regression tasks
   ii. ClassificationScoreCalculator for classification tasks
4. modelSaver: where the model is stored
5. iterationTerminationConditions: conditions checked at every iteration
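How these pieces fit together can be sketched without DL4J: evaluate every N epochs, remember the best score (the point where a ModelSaver would persist the model), and stop after a fixed number of epochs without improvement. The scores below are made-up data:

```Java
public class EarlyStoppingSketch {
    public static void main(String[] args) {
        double[] valScore = {0.90, 0.70, 0.55, 0.50, 0.52, 0.58, 0.61}; // made-up losses
        int evaluateEveryNEpochs = 1;
        int maxEpochsWithoutImprovement = 2; // an epoch termination condition
        double bestScore = Double.MAX_VALUE;
        int bestEpoch = -1, sinceImprovement = 0;

        for (int epoch = 0; epoch < valScore.length; epoch++) {
            if (epoch % evaluateEveryNEpochs != 0) continue; // cf. evaluateEveryNEpochs
            if (valScore[epoch] < bestScore) {               // lower loss is better
                bestScore = valScore[epoch];
                bestEpoch = epoch;                           // a ModelSaver would save here
                sinceImprovement = 0;
            } else if (++sinceImprovement >= maxEpochsWithoutImprovement) {
                System.out.println("Stopping at epoch " + epoch);
                break;
            }
        }
        System.out.println("Best epoch: " + bestEpoch + ", score: " + bestScore);
        // stops at epoch 5; best epoch 3 with score 0.5
    }
}
```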
## 3. Inspect the Early Stopping Result

```Java
// Conduct early stopping training:
EarlyStoppingResult result = trainer.fit();
System.out.println("Termination reason: " + result.getTerminationReason());
System.out.println("Termination details: " + result.getTerminationDetails());
System.out.println("Total epochs: " + result.getTotalEpochs());
System.out.println("Best epoch number: " + result.getBestModelEpoch());
System.out.println("Score at best epoch: " + result.getBestModelScore());

// Print score vs. epoch
Map<Integer,Double> scoreVsEpoch = result.getScoreVsEpoch();
List<Integer> list = new ArrayList<>(scoreVsEpoch.keySet());
Collections.sort(list);
System.out.println("Score vs. Epoch:");
for (Integer i : list) {
    System.out.println(i + "\t" + scoreVsEpoch.get(i));
}
```
# 4. Transfer Learning

## 1. Load the Original Network

```Java
// Build the pretrained model
ZooModel zooModel = VGG16.builder().build();
ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained();
```
## 2. Adjust the Training Hyperparameters

1. updater
2. learning rate
3. random seed: for reproducible runs

```Java
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
    .updater(new Nesterovs(0.1, 0.9))
    .seed(123)
    .build();
```
## 3. Modify the Network Architecture

### 3.1 setFeatureExtractor

Specifies the frozen/trainable boundary: the named layer and all layers before it are frozen, while the layers after it remain trainable (non-frozen).

### 3.2 Structural Changes

1. Shape exceptions usually only occur between different layers: adjust the layer structure and parameters according to the exception message
2. Use `removeVertexKeepConnections`, `addLayer`, or `addVertex` to change the network structure
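Putting 3.1 and 3.2 together, a sketch (not a definitive implementation) of such an edit with DL4J's `TransferLearning.GraphBuilder`; the layer names `"fc2"` and `"predictions"` assume the zoo VGG16 naming, and `numClasses` is a placeholder for the competition's class count:

```Java
// Sketch only: assumes vgg16 and fineTuneConf from the snippets above,
// and a hypothetical numClasses for the competition data.
ComputationGraph transferred = new TransferLearning.GraphBuilder(vgg16)
        .fineTuneConfiguration(fineTuneConf)
        .setFeatureExtractor("fc2")                 // freeze up to and including fc2
        .removeVertexKeepConnections("predictions") // drop the original 1000-way output
        .addLayer("predictions",
                new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(4096).nOut(numClasses)
                        .activation(Activation.SOFTMAX).build(),
                "fc2")
        .build();
```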
### Transfer Learning Strategy

1. Drop the fully connected layers -> Global Average Pooling -> classify with it in place of the FC layers
2. Unfreeze and train some of the convolutional layers -> improve the model's own feature extraction
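Global Average Pooling, which step 1 uses in place of the FC layers, simply averages each channel's feature map down to a single value; a minimal plain-Java sketch:

```Java
public class GlobalAvgPoolSketch {
    // Collapse each [H x W] feature map to its mean value: [C][H][W] -> [C].
    static double[] globalAveragePool(double[][][] featureMaps) {
        double[] pooled = new double[featureMaps.length];
        for (int c = 0; c < featureMaps.length; c++) {
            double sum = 0;
            int count = 0;
            for (double[] row : featureMaps[c]) {
                for (double v : row) { sum += v; count++; }
            }
            pooled[c] = sum / count;
        }
        return pooled;
    }

    public static void main(String[] args) {
        double[][][] maps = {
                {{1, 2}, {3, 4}},   // channel 0 -> mean 2.5
                {{0, 0}, {0, 8}}    // channel 1 -> mean 2.0
        };
        double[] out = globalAveragePool(maps);
        System.out.println(out[0] + " " + out[1]); // 2.5 2.0
    }
}
```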