博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
70行代码实现的神经网络算法
阅读量:6601 次
发布时间:2019-06-24

本文共 2335 字,大约阅读时间需要 7 分钟。

hot3.png

1、算法实现

public class BpDeep{    public double[][] layer;//神经网络各层节点    public double[][] layerErr;//神经网络各节点误差    public double[][][] layer_weight;//各层节点权重    public double[][][] layer_weight_delta;//各层节点权重动量    public double mobp;//动量系数    public double rate;//学习系数    public BpDeep(int[] layernum, double rate, double mobp){        this.mobp = mobp;        this.rate = rate;        layer = new double[layernum.length][];        layerErr = new double[layernum.length][];        layer_weight = new double[layernum.length][][];        layer_weight_delta = new double[layernum.length][][];        Random random = new Random();        for(int l=0;l
0){ for(int j=0;j
0?layerErr[l+1][i]*layer_weight[l][j][i]:0; layer_weight_delta[l][j][i]= mobp*layer_weight_delta[l][j][i]+rate*layerErr[l+1][i]*layer[l][j];//隐含层动量调整 layer_weight[l][j][i]+=layer_weight_delta[l][j][i];//隐含层权重调整 if(j==layerErr[l].length-1){ layer_weight_delta[l][j+1][i]= mobp*layer_weight_delta[l][j+1][i]+rate*layerErr[l+1][i];//截距动量调整 layer_weight[l][j+1][i]+=layer_weight_delta[l][j+1][i];//截距权重调整 } } layerErr[l][j]=z*layer[l][j]*(1-layer[l][j]);//记录误差 } } } public void train(double[] in, double[] tar){ double[] out = computeOut(in); updateWeight(tar); }}

2、算法测试

public class BpDeepTest{    public static void main(String[] args){        //初始化神经网络的基本配置        //第一个参数是一个整型数组,表示神经网络的层数和每层节点数,比如{3,10,10,10,10,2}表示输入层是3个节点,输出层是2个节点,中间有4层隐含层,每层10个节点        //第二个参数是学习步长,第三个参数是动量系数        BpDeep bp = new BpDeep(new int[]{2,10,2}, 0.15, 0.8);        //设置样本数据,对应上面的4个二维坐标数据        double[][] data = new double[][]{
{1,2},{2,2},{1,1},{2,1}}; //设置目标数据,对应4个坐标数据的分类 double[][] target = new double[][]{
{1,0},{0,1},{0,1},{1,0}}; //迭代训练5000次 for(int n=0;n<5000;n++) for(int i=0;i

3、执行结果

[1.0, 2.0]:[0.9782137336790337, 0.021683706747676907][2.0, 2.0]:[0.02140104439139772, 0.9785416755641893][1.0, 1.0]:[0.016850236680035113, 0.9835668738330479][2.0, 1.0]:[0.9809725214354169, 0.018824324694218176][3.0, 1.0]:[0.9985448434744455, 0.0013163425493131222]

 

转载于:https://my.oschina.net/u/2391658/blog/700073

你可能感兴趣的文章
如何查看已委派控制的用户及具体权限
查看>>
Kotlin从入门到放弃(四)——协程下
查看>>
You should be here !
查看>>
WKWebView捕获HTML弹出的Alert和Confirm
查看>>
Hyper-V Server NUMA
查看>>
NT/2000下删日志的方法
查看>>
Gradle 1.12用户指南翻译——第四十八章. Wrapper 插件
查看>>
Oracle 10.2.0.4(5)EM不能启动的解决方案
查看>>
【我们都爱Paul Hegarty】斯坦福IOS8公开课个人笔记10 Property List
查看>>
windows 2003系统上部署Exchange 2007邮件服务器 (一)
查看>>
开启Mysql远程访问权限
查看>>
saltstack学习四:自定义modules
查看>>
如何设置RHEL6 ADSL(pppoe)实现接入宽带网络
查看>>
Extjs 无法decode 带有 \n 的字符串
查看>>
SQL Server 审核(Audit)-- 使用审核的注意事项
查看>>
偷梁换柱 暗渡陈仓 一招搞定360安全卫士无法启动
查看>>
Win7下启用Telnet方法
查看>>
Linux简单的DNS服务器配置
查看>>
LD算法获取字符串相似度
查看>>
kvm虚拟化学习笔记(十六)之kvm虚拟化存储池配置
查看>>