[专知-Java Deeplearning4j 深度学习教程 03] 使用多层神经网络分类 MNIST 数据集:图文-代码

【导读】主题链路知识是我们专知的核心功能之一,为用户提供AI领域系统性的知识学习服务,一站式学习人工智能的知识,包含人工智能( 机器学习、自然语言处理、计算机视觉等)、大数据、编程语言、系统架构。使用请访问专知 进行主题搜索查看 - 桌面电脑访问http://www.zhuanzhi.ai, 手机端访问http://www.zhuanzhi.ai 或关注微信公众号后台回复" 专知"进入专知,搜索主题查看。继Pytorch教程后,我们推出面向Java程序员的深度学习教程DeepLearning4J。Deeplearning4j的案例和资料很少,官方的doc文件也非常简陋,基本上所有的类和函数的都没有解释。为此,我们推出来自中科院自动化所专知小组博士生Hujun与Sanglei创作的-分布式Java开源深度学习框架Deeplearning4j学习教程包括以下:

MNIST数据集

图片

MNIST由手写数字图片组成,包含0-9十种数字,常被用作测试机器学习算法性能的基准数据集。MNIST包含了一个有60000张图片的训练集和一个有10000张图片的测试集。深度学习在MNIST上可以达到99.7%的准确率。

Deeplearning4j中直接集成了MNIST数据集,例如可以直接用下面的代码加载训练集和测试集:


DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);

DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);

DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);

DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);

神经网络结构

图片

本教程使用具有1个隐藏层的MLP作为网络的结构,使用RELU作为隐藏层的激活函数,使用SOFTMAX作为输出层的激活函数。

从图中可以看出,网络具有输入层、隐藏层和输出层一共3层,但在代码编写时,会将该网络看作由2个层组成(2次变换):

  • Layer 0: 一个Dense Layer(全连接层),由输入层进行线性变换变为隐藏层,并使用RELU对变换结果进行激活。用公式表达形式为H= relu(XW_0 + b_0),其中:

    • X: 输入层,是形状为[batch_size, input_dim]的矩阵,矩阵的每行对应一个样本,每列对应一个特征(一个像素)
    • H: 隐藏层的输出,是形状为[batch_size, hidden_dim]的矩阵,矩阵的每行对应一个样本隐藏层的输出
    • relu: 使用RELU激活函数进行激活
    • W_0: 形状为[input_dim, hidden_dim]的矩阵,是全连接层线性变换的参数
    • b_0: 形状为[hidden_dim]的矩阵,是全连接层线性变换的参数(偏置)
  • Layer 1: 一个Dense Layer(全连接层),由隐藏层进行线性变换为输出层,并使用SOFTMAX对变换结果进行激活。用公式表达形式为:OUTPUT = softmax(HW_1 + b_1),其中:

    • OUTPUT: 输出层,是形状为[batch_size, output_dim]的矩阵,矩阵的每行对应一个样本,每列对应样本属于某类的概率。例如该例子中第0列表示输入手写数字为1的概率。
    • softmax: 使用SOFTMAX激活函数进行激活
    • W_1: 形状为[hidden_dim, output_dim]的矩阵,是全连接层线性变换的参数
    • b_1: 形状为[output_dim]的矩阵,是全连接层线性变换的参数(偏置)

神经网络的训练过程,即神经网络参数的调整过程。待参数能够很好地预测测试集中样本的类别(label),神经网络就训练成功了。

代码


import org.nd4j.linalg.activations.Activation;  
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;  
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;  
import org.deeplearning4j.eval.Evaluation;  
import org.deeplearning4j.nn.api.OptimizationAlgorithm;  
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;  
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;  
import org.deeplearning4j.nn.conf.Updater;  
import org.deeplearning4j.nn.conf.layers.DenseLayer;  
import org.deeplearning4j.nn.conf.layers.OutputLayer;  
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;  
import org.deeplearning4j.nn.weights.WeightInit;  
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;  
import org.nd4j.linalg.api.ndarray.INDArray;  
import org.nd4j.linalg.dataset.DataSet;  
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;  
import org.slf4j.Logger;import org.slf4j.LoggerFactory;/\*\*

\**
本示例使用Deeplearning4j构建了一个多层感知器(MLP)来进行手写数字(MNIST)的识别

该示例中的神经网络只有1个隐藏层
输入层的维度是numRows\*numColumns(图像像素行数\*图像像素列数),即每个手写数字图像的像素数量(28\*28)

隐藏层的大小为1000,使用RELU作为激活函数

输出层为SOFTMAX层,用于表示输入图像属于每个分类的概率(概率总和为1)

\**

public class MLPMnistSingleLayerExample {      
       private static Logger log =
LoggerFactory.getLogger(MLPMnistSingleLayerExample.class);      
         
       public static void main(String[] args) throws Exception {          
           //number of rows and columns in the input pictures

           final int numRows = 28;          
           final int numColumns = 28;          
           int outputNum = 10; // 手写字符类别的数量

           int batchSize = 128; //
batch大小,一个batch中的输入使用相同的神经网络参数

           int rngSeed = 123; //
设置一个随机种子,使得每次跑程序获得的随机值相同

           int numEpochs = 15; // 训练时每扫描一遍数据集算一个Epoch

           //Deeplearning4j内置的MNIST数据集

           DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize,
true, rngSeed);

           DataSetIterator mnistTest = new MnistDataSetIterator(batchSize,
false, rngSeed);

           log.info("Build model....");

           MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(rngSeed) 
// 为模型设置随机种子

               // 使用随机梯度下降作为优化算法

             
 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)

               .iterations(1)

               .learningRate(0.006) // 设置学习速率

               .updater(Updater.NESTEROVS)

               .regularization(true).l2(1e-4)
//设置L2正则系数,设置L2正则可以降低过拟合的程度

               .list() //开始构建MLP网络(多层感知器)

               .layer(0, new DenseLayer.Builder() //设置第一个Dense层

                       .nIn(numRows \* numColumns) //输入为28\*28

                       .nOut(1000) //输出为1000

                       .activation(Activation.RELU) //使用RELU激活

                       .weightInit(WeightInit.XAVIER) //设置初始化方法

                       .build())

               .layer(1, new
OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
//设置第二个Dense层,OutputLayer也是Dense层

                       .nIn(1000) //输入为1000

                       .nOut(outputNum) //输出为10,即手写数字的类别数量

                       .activation(Activation.SOFTMAX) //使用SOFTMAX激活

                       .weightInit(WeightInit.XAVIER)

                       .build())

               .pretrain(false).backprop(true) //进行反向传播,不进行预训练

               .build();

           MultiLayerNetwork model = new MultiLayerNetwork(conf);

           model.init();        //每隔1个iteration就输出一次score

           model.setListeners(new ScoreIterationListener(1));

           log.info("Train model....");          
           for( int i=0; i\<numEpochs; i++ ){

               model.fit(mnistTrain);

           }

           log.info("Evaluate model....");

           Evaluation eval = new Evaluation(outputNum); //创建一个评价器

           while(mnistTest.hasNext()){

               DataSet next = mnistTest.next();

               INDArray output = model.output(next.getFeatureMatrix());
//模型的预测结果

               eval.eval(next.getLabels(), output);
//根据真实的结果和模型的预测结果对模型进行评价

           }

           log.info(eval.stats());

           log.info("\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*Example
finished\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*\*");

   }

}

运行代码,输出如下:

图片

图片

请继续关注“DeepLearning4j”教程。

完整系列搜索查看,请PC登录

www.zhuanzhi.ai, 搜索“**DeepLearning4j**”即可得。


对DeepLearning4j教程感兴趣的同学,欢迎进入我们的专知DeepLearning4j主题群一起交流、学习、讨论,扫一扫如下群二维码即可进入(先加微信小助手weixinhao: Rancho_Fang,注明Deeplearning4j)。

展开全文
相关主题
Top
微信扫码咨询专知VIP会员