# Baidu Dianshi

Competition link: http://dianshi.baidu.com/dianshi/pc/competition/22/submit

# 1. Basic Data Processing

## PathFilter

A `PathFilter` is used to sample our data:

1. RandomPathFilter: samples paths randomly, which also serves to shuffle the data
2. BalancedPathFilter: samples paths randomly while also correcting class imbalance

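A hedged sketch of how the two filters are typically wired into a `FileSplit` (the `parentDir` layout with one subdirectory per label and the 80/20 split weights are assumptions; constructor overloads vary across DataVec versions, so check the Javadoc):

```Java
Random rng = new Random(123);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
// Assumed layout: parentDir/<label>/<image files>
FileSplit files = new FileSplit(new File("parentDir"), NativeImageLoader.ALLOWED_FORMATS, rng);

// 1. Shuffle only:
RandomPathFilter randomFilter = new RandomPathFilter(rng, NativeImageLoader.ALLOWED_FORMATS);

// 2. Shuffle and balance: each label is capped so classes end up evenly represented
BalancedPathFilter balancedFilter =
        new BalancedPathFilter(rng, NativeImageLoader.ALLOWED_FORMATS, labelMaker);

// Apply the filter while splitting 80/20 into train/test:
InputSplit[] split = files.sample(balancedFilter, 80, 20);
InputSplit trainData = split[0];
InputSplit testData = split[1];
```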
## PipelineImageTransform
```Java
// Example transforms (parameters are illustrative):
ImageTransform cropTransform = new CropImageTransform(10);
ImageTransform flipTransform = new FlipImageTransform(new Random(123));
ImageTransform rotateTransform0 = new RotateImageTransform(0);
ImageTransform rotateTransform30 = new RotateImageTransform(30);
ImageTransform rotateTransform90 = new RotateImageTransform(90);
ImageTransform rotateTransform120 = new RotateImageTransform(120);
ImageTransform warpTransform = new WarpImageTransform(new Random(123), 42);

List<Pair<ImageTransform, Double>> pipeline = Arrays.asList(new Pair<>(cropTransform, 0.9),
        new Pair<>(flipTransform, 0.9),
        new Pair<>(rotateTransform0, 1.0),
        new Pair<>(rotateTransform30, 0.9),
        new Pair<>(rotateTransform90, 0.9),
        new Pair<>(rotateTransform120, 0.9),
        new Pair<>(warpTransform, 0.9));
ImageTransform transform = new PipelineImageTransform(pipeline, false);
```

The `Double` in each pair is the probability that the corresponding augmentation transform is applied.

```Java
// {@link org.datavec.image.transform.PipelineImageTransform.doTransform}
@Override
protected ImageWritable doTransform(ImageWritable image, Random random) {
    if (shuffle) {
        Collections.shuffle(imageTransforms);
    }

    currentTransforms.clear();

    // execute each item in the pipeline
    for (Pair<ImageTransform, Double> tuple : imageTransforms) {
        if (tuple.getSecond() == 1.0 || rng.nextDouble() < tuple.getSecond()) { // probability of execution
            currentTransforms.add(tuple.getFirst());
            image = random != null ? tuple.getFirst().transform(image, random)
                            : tuple.getFirst().transform(image);
        }
    }

    return image;
}
```

# 2. Memory Management

Official documentation: https://deeplearning4j.org/docs/latest/deeplearning4j-config-memory

ND4J keeps tensors off-heap, so the heap and off-heap limits are set separately via JVM flags:

```shell
-Xms2G -Xmx2G -Dorg.bytedeco.javacpp.maxbytes=10G -Dorg.bytedeco.javacpp.maxphysicalbytes=10G
```
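As a small sanity check (pure JDK, no DL4J required), the heap cap and the `javacpp` properties can be read back at runtime:

```Java
public class MemoryCheck {
    public static void main(String[] args) {
        // -Xmx caps the JVM heap; ND4J tensors live off-heap and are capped
        // separately by the org.bytedeco.javacpp.* system properties.
        long heapMax = Runtime.getRuntime().maxMemory();
        System.out.println("Heap max (bytes): " + heapMax);
        System.out.println("javacpp maxbytes: "
                + System.getProperty("org.bytedeco.javacpp.maxbytes", "(not set)"));
        System.out.println("javacpp maxphysicalbytes: "
                + System.getProperty("org.bytedeco.javacpp.maxphysicalbytes", "(not set)"));
    }
}
```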

# 3. Early Stopping

## 1. Create a ModelSaver

Specifies where the best model found during training is saved:

1. InMemoryModelSaver: keeps the best model in memory
2. LocalFileModelSaver: saves to a local directory; only supports `MultiLayerNetwork` models
3. LocalFileGraphSaver: saves to a local directory; only supports `ComputationGraph` models

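A minimal sketch of the two file-free and file-based savers (the `modelDir` path is hypothetical and must already exist):

```Java
// Keep the best model in memory:
EarlyStoppingModelSaver<MultiLayerNetwork> inMemory = new InMemoryModelSaver<>();

// Persist the best model to disk:
EarlyStoppingModelSaver<MultiLayerNetwork> onDisk = new LocalFileModelSaver("modelDir");
```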
## 2. Configure Early Stopping

 1. epochTerminationConditions: conditions that end training, checked once per epoch
 2. evaluateEveryNEpochs: how many epochs to train between model evaluations
 3. scoreCalculator: computes the model's evaluation score
    i. org.deeplearning4j.earlystopping.scorecalc.RegressionScoreCalculator: score for regression tasks
    ii. ClassificationScoreCalculator: score for classification tasks
 4. modelSaver: where the model is stored
 5. iterationTerminationConditions: conditions checked at every iteration

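The items above can be sketched as one builder call (hedged: `netConf`, `trainIter`, `testIter`, and the `"modelDir"` path are hypothetical stand-ins; the condition classes are from `org.deeplearning4j.earlystopping`):

```Java
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
        new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                // stop after 50 epochs, or after 5 epochs without score improvement
                .epochTerminationConditions(new MaxEpochsTerminationCondition(50),
                        new ScoreImprovementEpochTerminationCondition(5))
                // hard wall-clock limit, checked every iteration
                .iterationTerminationConditions(
                        new MaxTimeIterationTerminationCondition(2, TimeUnit.HOURS))
                // score = average loss on the test iterator
                .scoreCalculator(new DataSetLossCalculator(testIter, true))
                .evaluateEveryNEpochs(1)
                .modelSaver(new LocalFileModelSaver("modelDir"))
                .build();

EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, netConf, trainIter);
```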
## 3. Retrieve the Early-Stopping Result
```Java
//Conduct early stopping training:
EarlyStoppingResult result = trainer.fit();
System.out.println("Termination reason: " + result.getTerminationReason());
System.out.println("Termination details: " + result.getTerminationDetails());
System.out.println("Total epochs: " + result.getTotalEpochs());
System.out.println("Best epoch number: " + result.getBestModelEpoch());
System.out.println("Score at best epoch: " + result.getBestModelScore());

//Print score vs. epoch
Map<Integer,Double> scoreVsEpoch = result.getScoreVsEpoch();
List<Integer> list = new ArrayList<>(scoreVsEpoch.keySet());
Collections.sort(list);
System.out.println("Score vs. Epoch:");
for (Integer i : list) {
    System.out.println(i + "\t" + scoreVsEpoch.get(i));
}
```

# 4. Transfer Learning

## 1. Load the Original Network

```Java
// Load VGG16 from the model zoo with pretrained weights
ZooModel zooModel = VGG16.builder().build();
ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained();
```

## 2. Adjust Training Hyperparameters

 1. updater
 2. learning rate
 3. random seed: for reproducibility

```Java
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
        .updater(new Nesterovs(0.1, 0.9))
        .seed(123)
        .build();
```

## 3. Modify the Network Architecture

### 3.1 setFeatureExtractor

Specifies the boundary layer: that layer and everything before it are frozen; the layers after it remain trainable.

### 3.2 Changing the Structure

1. Shape mismatches generally surface at the junction between layers: adjust the layer structure and parameters according to the exception message
2. Use `removeVertexKeepConnections` together with `addLayer` or `addVertex` to change the network structure

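Putting 3.1 and 3.2 together, a hedged sketch in the style of the DL4J transfer-learning examples (the layer names `"fc2"`/`"predictions"` follow the zoo VGG16 and should be verified with `vgg16.summary()`; `numClasses` is a hypothetical stand-in):

```Java
ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(vgg16)
        .fineTuneConfiguration(fineTuneConf)
        .setFeatureExtractor("fc2")                 // everything up to fc2 stays frozen
        .removeVertexKeepConnections("predictions") // drop the 1000-way ImageNet classifier
        .addLayer("predictions",
                new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(4096).nOut(numClasses)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SOFTMAX)
                        .build(), "fc2")
        .build();
```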
### Transfer Learning Ideas

1. Drop the fully connected layers -> global average pooling -> classify without the heavy FC layers
2. Unfreeze and train some of the convolutional layers -> improve the model's own feature extraction
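Both ideas can be sketched with the same `TransferLearning.GraphBuilder` (hedged: the `block4_pool`/`block5_pool`/`flatten` layer names follow the zoo VGG16 and should be checked with `vgg16.summary()`; `numClasses` is hypothetical; `nIn(512)` matches the 512 channels of `block5_pool` after average pooling):

```Java
ComputationGraph gapNet = new TransferLearning.GraphBuilder(vgg16)
        .fineTuneConfiguration(fineTuneConf)
        .setFeatureExtractor("block4_pool")          // idea 2: leave block5 convolutions trainable
        .removeVertexAndConnections("predictions")   // idea 1: drop the FC head entirely
        .removeVertexAndConnections("fc2")
        .removeVertexAndConnections("fc1")
        .removeVertexAndConnections("flatten")
        .addLayer("gap", new GlobalPoolingLayer.Builder(PoolingType.AVG).build(), "block5_pool")
        .addLayer("predictions",
                new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(512).nOut(numClasses)
                        .activation(Activation.SOFTMAX)
                        .build(), "gap")
        .setOutputs("predictions")
        .build();
```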