200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > Python 实现 周志华 《机器学习》 BP算法(高级版)

Python 实现 周志华 《机器学习》 BP算法(高级版)

时间:2021-04-21 03:21:41

相关推荐

Python 实现 周志华 《机器学习》 BP算法(高级版)

习题5.5: 试编程实现标准BP算法和累积BP算法,在西瓜数据集3.0上分别用这两个算法训练一个单隐层网络,并进行比较

算法的主要思想来自周志华《机器学习》上讲BP算法的部分,实现了书上介绍的标准BP算法和累积BP算法,对于西瓜数据集3.0,已经把文字部分的取值变为离散的数字了

如果要求解亦或问题,把下面的代码注释取消即可

x = np.mat( '1,1,2,2;\1,2,1,2\').Tx = np.array(x)y=np.mat('0,1,1,0')y = np.array(y).T

之前写过一版(戳这里查看初级版),全是通过for循环自己慢慢修改参数,这一版借助numpy矩阵运算的操作,使得代码量大大简化,并且运行的时间也比之前的版本快不少。

#!/usr/bin/python #-*- coding:utf-8 -*- ############################ #File Name: bp-watermelon3.py#Author: No One #E-mail: 1130395634@ #Created Time: -02-23 13:30:35############################import numpy as npimport mathfrom sys import argvx = np.mat( '2,3,3,2,1,2,3,3,3,2,1,1,2,1,3,1,2;\1,1,1,1,1,2,2,2,2,3,3,1,2,2,2,1,1;\2,3,2,3,2,2,2,2,3,1,1,2,2,3,2,2,3;\3,3,3,3,3,3,2,3,2,3,1,1,2,2,3,1,2;\1,1,1,1,1,2,2,2,2,3,3,3,1,1,2,3,2;\1,1,1,1,1,2,2,1,1,2,1,2,1,1,2,1,1;\0.697,0.774,0.634,0.668,0.556,0.403,0.481,0.437,0.666,0.243,0.245,0.343,0.639,0.657,0.360,0.593,0.719;\0.460,0.376,0.264,0.318,0.215,0.237,0.149,0.211,0.091,0.267,0.057,0.099,0.161,0.198,0.370,0.042,0.103\').Tx = np.array(x)y = np.mat('1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0')y = np.array(y).T'''x = np.mat( '1,1,2,2;\1,2,1,2\').Tx = np.array(x)y=np.mat('0,1,1,0')y = np.array(y).T'''xrow, xcol = x.shapeyrow, ycol = y.shapenp.random.seed(0)print 'x: \n', xprint 'y: \n', ydef sigmoid(x):return 1.0 / (1.0 + np.exp(-x))def printParam(v, w, t0, t1):print 'v:', vprint 'w: ', wprint 't0: ', t0print 't1: ', t1def bpa(x, y, n_hidden_layer, r, error, n_max_train):printprint 'all bp algorithm'print '------------------------------------'print 'init param'[xrow, xcol] = x.shape[yrow, ycol] = y.shapev = np.random.random((xcol, n_hidden_layer))w = np.random.random((n_hidden_layer, ycol))t0 = np.random.random((1, n_hidden_layer))t1 = np.random.random((1, ycol))print '---------- train begins ----------'n_train = 0yo = 0loss = 0while 1:b = sigmoid(x.dot(v) - t0)yo = sigmoid(b.dot(w) - t1)loss = sum((yo - y)**2) / xrowif loss < error or n_train > n_max_train:breakn_train += 1# update paramg = yo * (1 - yo) * (y - yo)w += r * b.T.dot(g)t1 -= r * g.sum(axis = 0) e = b * (1 - b) * g * w.Tv += r * x.T.dot(e)t0 -= r * e.sum(axis = 0)if n_train % 10000 == 0:print 'train count: ', n_trainprint np.hstack((y, yo))printprint '---------- train ends ----------'print 'train count = ', n_trainyo = yo.tolist()print '---------- learned param: ----------'printParam(v, w, t0, t1)print '---------- result: ----------'print np.hstack((y, yo))print 'loss: ', loss def bps(x, y, n_hidden_layer, r, error, n_max_train):printprint 'standard bp algorithm'print '------------------------------------'print 'init param'[xrow, xcol] = x.shape[yrow, ycol] = y.shapev = np.random.random((xcol, n_hidden_layer))w = np.random.random((n_hidden_layer, ycol))t0 = np.random.random((1, n_hidden_layer))t1 = np.random.random((1, ycol))print '---------- train begins ----------'n_train = 0tag = 0yo = 0loss = 0while 1:for k in range(len(x)):b = sigmoid(x.dot(v) - t0)yo = sigmoid(b.dot(w) - t1)loss = sum((yo - y)**2) / xrowif loss < error or n_train > n_max_train:tag = 1breakb = b[k]b = b.reshape(1,b.size)n_train += 1g = yo[k] * (1 - yo[k]) * (y[k] - yo[k])g = g.reshape(1,g.size)w += r * b.T.dot(g)t1 -= r * ge = b * (1 - b) * g * w.Tv += r * x[k].reshape(1, x[k].size).T.dot(e)t0 -= r * eif n_train % 10000 == 0:print 'train count: ', n_trainprint np.hstack((y, yo))if tag:breakprintprint '---------- train ends ----------'print 'train count = ', n_trainyo = yo.tolist()print '---------- learned param: ----------'printParam(v, w, t0, t1)print '---------- result: ----------'print np.hstack((y, yo))print 'loss: ', loss r = 0.1error = 0.001n_max_train = 1000000n_hidden_layer = 5n = int(argv[1])if n == 1:bpa(x, y, n_hidden_layer, r, error, n_max_train)elif n == 2:bps(x, y, n_hidden_layer, r, error, n_max_train)else:print '命令行参数错误'

命令行输入: python test.py 1 # 1表示运行累积bp算法,2表示标准bp算法

结果如下

---------- train ends ----------train count = 10472---------- learned param: ----------v: [[ 0.73242941 3.65170127 0.59713105 0.53589607 4.26680198][-0.4797 0.38050143 0.88684761 0.96043754 -5.04922845][-3.53478658 -2.43632002 0.5617708 0.91791984 -3.99160595][ 2.72776748 -3.03747142 0.82596831 0.7629904 3.58719733][-0.49817982 0.11022257 0.45451013 0.77793538 -2.06655661][ 1.31898792 2.91731759 0.94006976 0.51654301 4.99262637][-2.87599092 -1.20602034 0.4544875 0.56614856 -2.14762434][ 3.31012315 2.37538414 0.61649596 0.94252131 0.76600351]]w: [[ -7.57093041][ -4.87555553][ -0.60132992][ -1.24255911][ 11.7515]]t0: [[ 2.86694039 1.63790548 0.13229702 0.31939894 1.83144759]]t1: [[ 1.86987471]]---------- result: ----------[[ 1.00000000e+00 9.93190538e-01][ 1.00000000e+00 9.99558269e-01][ 1.00000000e+00 9.73273387e-01][ 1.00000000e+00 9.98817906e-01][ 1.00000000e+00 9.95520603e-01][ 1.00000000e+00 9.58776391e-01][ 1.00000000e+00 9.26738291e-01][ 1.00000000e+00 9.78479082e-01][ 0.00000000e+00 5.84289232e-03][ 0.00000000e+00 6.31392712e-03][ 0.00000000e+00 8.31158755e-04][ 0.00000000e+00 1.51786116e-03][ 0.00000000e+00 2.72394938e-02][ 0.00000000e+00 2.37542259e-02][ 0.00000000e+00 7.79277689e-02][ 0.00000000e+00 1.85295127e-02][ 0.00000000e+00 2.97535714e-02]]loss: [ 0.00099981]

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。