% ========================================
%
% hidteach.m (Matlab code) updated 3-10-03
%
% Purpose:
%----------
% * This code demonstrates our hidden-node teaching algorithm that
%   can render the 7-4-1 MLP (only four hidden nodes) ``insensitive''
%   to starting parameters randomly initialized in a small range
%   for solving the 7-bit parity problem (128 data).
%
%
% Authors:
%----------
% Eiji Mizutani (eiji@wayne.cs.nthu.edu.tw/eiji@biosys2.me.berkeley.edu)
% Dept. of Computer Science, National Tsing Hua University, Taiwan
%
% and
%
% Stuart E. Dreyfus (dreyfus@ieor.berkeley.edu)
% Dept. of IEOR, University of California at Berkeley
%
%
% NOTE:
%------
%* The MLP has only four hidden nodes (37 weights in total) to solve the
%  seven-bit parity problem, owing to our finding (concerning N-bit parity):
%
%     ``For N odd, (N+1)/2 hidden nodes suffice for solution.''
%
%  See ref.[3] listed below for its proof.
%
%* Our hidden-node teaching algorithm employs a steepest descent-type
%  pattern-by-pattern learning mode (i.e., incremental gradient method)
%  with a fixed momentum term (0.8), derived from the Kelley-Bryson
%  optimal control gradient formula (late 1950s); see ref.[2].
%
%* Because of the pattern-by-pattern learning mode, the data should be
%  randomly sorted: our data set (128 randomly sorted data), pari7D,
%  is attached at the end of this file; store it into ``pari7D.dat.''
%
%* Hyperbolic tangent (tanh) functions are used at both the hidden
%  and output layers.
%
%* Hidden-node saturation is checked during MLP learning using
%  two variables:
%  ``NumOfSatuNodes'', which counts, for each hidden node, the # of
%  patterns that saturate that node, and ``NumSatuPtns'', which counts
%  the # of patterns that make all hidden nodes ``saturated.''
%
%* Once the RMSE (root mean squared error) becomes less than 0.6,
%  what we call ``Switching'' occurs, usually before epoch 2,000.
%
%* The RMSE is computed in batch mode at the end of each epoch, so the
%  data set is passed through twice per epoch (once for learning, once
%  for evaluation), at some cost in efficiency.
%
%* By default, our criterion for correctness, MyCriterion, is set equal
%  to 0.01 (i.e., MyCriterion=0.01). In this case,
%
%     the posed problem could ALWAYS be SOLVED at epoch 3,200 or so
%
%  when the initial weight range is [-0.25,+0.25].
%  If MyCriterion is set equal to 0.8, then the required # of epochs would
%  be nearly 19,200 (see Table 1 in ref.[1] below), with a smaller RMSE.
%  In either case, the required # of epochs does not vary greatly across
%  random initial weights, which is a sign of the ``insensitivity to
%  initial weights'' developed in the 7-4-1 MLP. To make the RMSE
%  virtually zero with no additional epochs, use the scaling-up method
%  (see Section 4-D in ref.[1]).
%
%-----------------------------------------------------------------------
%
% For details, refer to the following references, all available at
%
%    http://www.ieor.berkeley.edu/people/dreyfus.html.
%
% ------------
%[1] ``MLP's hidden-node saturation and insensitivity to initial weights
%     in two classification benchmark problems: parity and two-spirals''
%     Eiji Mizutani and Stuart E. Dreyfus
%     In Proceedings of IEEE Int'l Conf. on NN (IJCNN'2002), Honolulu USA
%
%[2] ``On derivation of MLP backpropagation from the Kelley-Bryson
%     optimal control gradient formula and its application''
%     Eiji Mizutani, Stuart E. Dreyfus, and Kenichi Nishio
%     In Proceedings of IEEE Int'l Conf. on NN (IJCNN'2000), Como Italy
%
%[3] ``On dynamic programming-like recursive gradient formula for
%     alleviating hidden-node saturation in the parity problem''
%     Eiji Mizutani, Stuart E. Dreyfus, and J.-S. Roger Jang
%     In Proceedings of the 8th Bellman Continuum, 2000, Taiwan
%
%[4] ``On complexity analysis of supervised MLP-learning for
%     algorithmic comparisons''
%     Eiji Mizutani and Stuart E. Dreyfus
%     In Proceedings of IEEE Int'l Conf. on NN (IJCNN'2001), Washington DC USA
%
%
% (Revision note)
% 1st revision: June, 2000
% 2nd revision: Nov., 2001
% 3rd revision: May, 2002
% 4th revision: March, 2003 (a new data set attached at the end of file)
%
% < Exercises >
%---------------
%(1) Observe changes of the ``hidden-node saturation pattern'' just
%    by looking at ``NumOfSatuNodes:'' printed by this code, especially
%    before and after ``Switching.''
%
%(2) Modify the code to monitor the behavior without the ``Switching''
%    mechanism, and explain why the MLP may not solve the problem; pay
%    attention to ``NumSatuPtns'' printed by this code.
%
%=======================================================================
clear all; close all;

%%%%%%%%%%%%%%%%%%%%%%%%%%%
%                         %
% Setups for MLP-learning %
%                         %
%%%%%%%%%%%%%%%%%%%%%%%%%%%
keep_init_wts = 0;               %% 1: Store the initial weights into a file;
                                 %% 0: Do not store them.
weight_range = 0.25;             %% Range for initial weights
EtaH = 0.01;                     %% Learning rate between input & hidden layers
EtaO = 0.001;                    %% Learning rate between hidden & output layers
Momentum = 0.8;
Report = 50;
mlp_config = [7 4 1];            %% Use 7-4-1 MLP!!!
RMSE_for_Switching = 0.6;        %% Switching criterion in RMSE
MyCriterion = 0.01;              %% Criterion for correctness:
                                 %% see Eqn(4) in ref.[1]
EtaReduceRateAtSwitching = 0.2;  %% Upon switching, EtaH is reduced
                                 %% by 80% (multiplied by 0.2).
Signal_for_hidden_target = 0.6;  %% Target for a subset of hidden
                                 %% nodes: see Eqn(8) in ref.[1].
SaturationCriterion = 0.999;
max_epoch = 25000;
error_goal = 0.01;               %% in terms of RMSE
%=======================================================================

load pari7D.dat   %%% Load the 128 (7-bit parity) randomly sorted data

in_n     = mlp_config(1);   % Number of input nodes  = 7
hidden_n = mlp_config(2);   % Number of hidden nodes = 4
out_n    = mlp_config(3);   % Number of output nodes = 1

[data_n, col_n] = size( pari7D );
if in_n + out_n ~= col_n,
  error('The # of columns of the data matrix does not match your MLP!');
end

INPUT  = pari7D(:, 1:in_n);              % input data
TARGET = pari7D(:, in_n+1:in_n+out_n);   % target data

%%% Count the # of saturated patterns per hidden node
NumOfSatuNodes = zeros(hidden_n,1);

fprintf('=============== MLP setup ================================\n');
fprintf('MLP structure : %d x %d x %d\n',in_n,hidden_n,out_n);
fprintf('Goal=%.3f LimitEpoch=%d Report=%d\n',error_goal,max_epoch, Report);
fprintf('EtaH=%f EtaO=%f Momentum=%f\n',EtaH, EtaO, Momentum);
fprintf('SaturationCriterion=%f WtsRange=%f \n',SaturationCriterion, weight_range);
fprintf('RMSE_for_Switching=%f EtaReduceRateAtSwitching=%f \n',RMSE_for_Switching,EtaReduceRateAtSwitching);

if keep_init_wts == 1 && exist('weight.mat', 'file'),
  %%% When the initial weight file exists, then use it.
  load weight
else
  %%% Initialize the weight matrices, W1 and W2:
  %%% W1 = 8 x 4 matrix <=== (# of input nodes + 1) x (# of hidden nodes)
  %%% W2 = 5 x 1 vector <=== (# of hidden nodes + 1) x (# of output nodes)
  %%% The last row contains the weights connected to the bias node.
  W1 = weight_range*2*(rand(in_n+1,hidden_n) - 0.5);
  W2 = weight_range*2*(rand(hidden_n+1,out_n) - 0.5);
  if keep_init_wts == 1,
    save weight W1 W2;
  end
end

dW1_old = zeros(size(W1));
dW2_old = zeros(size(W2));
RMSE = -ones(max_epoch, 1);   %% Root Mean Squared Error
CNT  = -ones(max_epoch, 1);   %% # of ``incorrect'' patterns
one = 1.0;                    %%% The bias node always outputs the constant 1.
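
%-----------------------------------------------------------------------
% (Illustrative sketch only; not part of the original algorithm.)
% The loops below test correctness and hidden-node saturation one
% pattern and one hidden node at a time. Under our reading of those
% loops, the same tests can be written in vectorized form, shown here
% on made-up toy values. All demo_* names are hypothetical and used only
% in this block; nothing here touches the training variables.
%-----------------------------------------------------------------------
demo_crit = 0.01;    % plays the role of MyCriterion
demo_sat  = 0.999;   % plays the role of SaturationCriterion

% Correctness, as coded below (cf. Eqn(4) in ref.[1]): an ON target
% (T > 0.1) is wrong if A2 < crit; an OFF target is wrong if A2 > -crit.
demo_T  = [ 1;   -1;     1;   -1];    % toy targets
demo_A2 = [0.5; -0.3; 0.005;  0.2];   % toy network outputs
demo_wrong = (demo_T > 0.1  & demo_A2 <  demo_crit) | ...
             (demo_T <= 0.1 & demo_A2 > -demo_crit);
demo_NumWrong = sum(demo_wrong);      % patterns 3 and 4 are wrong -> 2

% Saturation: rows = patterns, columns = hidden nodes.
% |A1| > SaturationCriterion is the same test as the two if-branches below.
demo_A1 = [0.9995 -0.9999  0.5     -0.2;
           0.9991 -0.9992 -0.9999   0.99995];
demo_satmask = abs(demo_A1) > demo_sat;
demo_NumOfSatuNodes = sum(demo_satmask, 1)';      % per node: [2; 2; 1; 1]
demo_NumSatuPtns    = sum(all(demo_satmask, 2));  % patterns with ALL nodes saturated: 1

% Sign rule of the hidden-node target in the teaching step (as coded):
% target = -0.6*T when ptn+hid is even, +0.6*T otherwise.
demo_ptn = 1; demo_Tp = 1; demo_hid = 1:4;
demo_hid_target = 0.6 * demo_Tp * (1 - 2*(mod(demo_ptn + demo_hid, 2) == 0));
% demo_hid_target is [-0.6 0.6 -0.6 0.6]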
%=======================================================================
% Compute the initial BATCH RMSE and check hidden-node saturation
%=======================================================================
NumWrongPtns = 0;                      % Reset the # of incorrect patterns.
NumOfSatuNodes = 0 * NumOfSatuNodes;   % Reset # of saturated hidden nodes
NumSatuPtns = 0;                       % Reset the # of patterns that render
                                       % ALL the hidden nodes saturated;
flag_reduce_etaH = 1;

for ptn = 1:data_n,
  A0 = INPUT(ptn,:);
  T  = TARGET(ptn,:);
  %%%%%%%% Forward pass
  A1 = tanh([A0 one]*W1);   % hidden-node activations
  A2 = tanh([A1 one]*W2);   % terminal-node activations
  %%%%%%%% Hidden-node saturation check
  flag = 0;
  for hid = 1:hidden_n,
    if A1(hid) > SaturationCriterion,
      NumOfSatuNodes(hid) = NumOfSatuNodes(hid) + 1;
      flag = flag + 1;
    else
      if A1(hid) < -SaturationCriterion,
        NumOfSatuNodes(hid) = NumOfSatuNodes(hid) + 1;
        flag = flag + 1;
      end
    end
  end
  if flag == hidden_n,
    NumSatuPtns = NumSatuPtns + 1;
  end
  residu(ptn) = T - A2;
  if T > 0.1,   %%% If the target is ON (=1.0)
    if A2 < MyCriterion,
      NumWrongPtns = NumWrongPtns + 1;
    end
  else
    if A2 > -MyCriterion,
      NumWrongPtns = NumWrongPtns + 1;
    end
  end
end
rmse = sqrt(sum(residu.^2)/data_n);
RMSE(1) = rmse;
fprintf('Start! Epoch 0: RMSE= %.6f WrongPtns= %d NumSatuPtns=%d\n',rmse,NumWrongPtns,NumSatuPtns);
fprintf('NumOfSatuNodes: ');
for hid = 1:hidden_n,
  fprintf('%d ',NumOfSatuNodes(hid));
end
fprintf('\n');

%=======================================================================
fprintf('********************************************************* \n');
fprintf('**************** Start MLP-learning ********************* \n');
fprintf('********************************************************* \n');

for i = 2:max_epoch,
  for ptn = 1:data_n,
    A0 = INPUT(ptn,:);    %%% A0 is a row vector of inputs.
    T  = TARGET(ptn,:);   %%% T is a scalar of the desired output.
    %=======================================================================
    % Forward pass to compute node activations
    %=======================================================================
    A1 = tanh([A0 one]*W1);   %% Activations at the hidden layer;
    A2 = tanh([A1 one]*W2);   %% Activations at the output layer;

    %=======================================================================
    % Backward pass to propagate sensitivity
    %=======================================================================
    %%% After-node sensitivity, Xi_2, at the terminal layer
    Xi_2 = A2 - T;   %%% See Eqn(5) in ref.[4].

    %%% Change from the after-node sensitivity, Xi_2,
    %%% to the before-node sensitivity, Delta_2; see Eqn(7) in ref.[4].
    Delta_2 = Xi_2.*(1+A2).*(1-A2);

    %%% Gradients for the weights between terminal and hidden layers
    Grad_2 = [A1 one]' * Delta_2;   %%% See Eqn(8) in ref.[4].

    %%% Backward Sensitivity Propagation: See Eqn(6) in ref.[4].
    Xi_1 = W2(1:hidden_n,:) * Delta_2;
    %%% Xi_1= Delta_2 * W2(1:hidden_n,:)';
    %%% Xi_1= Xi_2.*(1+A2).*(1-A2).*W2(1:hidden_n,:)';

    %=======================================================================
    % *** Hidden-node teaching: see Eqn(8) in ref.[1].
    % if(!((ptn+epoch)%2)) hidden-target = -0.6 * target;
    % else                 hidden-target = +0.6 * target;
    % (The code below alternates the sign with mod(ptn+hid,2).)
    %=======================================================================
    if RMSE(i-1) < RMSE_for_Switching,
      if flag_reduce_etaH == 1,
        EtaH = EtaH * EtaReduceRateAtSwitching;
        fprintf('****** Switching occurs at Epoch= %d\n',i);
        fprintf('****** Reduce EtaH down to %.9f\n',EtaH);
        flag_reduce_etaH = 0;
      end
      for hid = 3:hidden_n,
        flag = 0;
        if A1(hid) > SaturationCriterion, flag = 1; end
        if A1(hid) < -SaturationCriterion, flag = 1; end
        if flag == 1,
          if mod((ptn+hid),2) == 0,
            hid_target = - Signal_for_hidden_target * T;
          else
            hid_target = Signal_for_hidden_target * T;
          end
          Xi_1(hid) = Xi_1(hid) + A1(hid) - hid_target;
        end
      end
    else
      for hid = 1:hidden_n-2,
        if mod((ptn+hid),2) == 0,
          hid_target = - Signal_for_hidden_target * T;
        else
          hid_target = Signal_for_hidden_target * T;
        end
        Xi_1(hid) = Xi_1(hid) + A1(hid) - hid_target;
      end
    end   %%%%%%%% End for hidden-node teaching

    %%% Change from the after-node sensitivity, Xi_1,
    %%% to the before-node sensitivity, Delta_1; see Eqn(7) in ref.[4].
    Delta_1 = (1+A1).*(1-A1).*Xi_1';

    %%% The outer product yields gradients (in matrix form)
    %%% for the weights between hidden and input layers
    Grad_1 = [A0 one]' * Delta_1;   %%% see Eqn(8) in ref.[4]

    %=======================================================================
    % Update parameters
    %=======================================================================
    dW2 = -EtaO * Grad_2 + Momentum * dW2_old;   %%% output wts
    dW1 = -EtaH * Grad_1 + Momentum * dW1_old;   %%% hidden wts
    W2 = W2 + dW2;
    W1 = W1 + dW1;
    dW2_old = dW2;
    dW1_old = dW1;
  end   %%%%%%%%% End_for ptn-loop

  %=======================================================================
  % After parameter updates,
  % compute the batch RMSE and check hidden-node saturation.
  %=======================================================================
  NumWrongPtns = 0;                      % # of incorrect patterns
  NumOfSatuNodes = 0 * NumOfSatuNodes;   % # of saturated hidden nodes
  NumSatuPtns = 0;                       % # of saturated patterns
  for ptn = 1:data_n,
    A0 = INPUT(ptn,:);
    T  = TARGET(ptn,:);
    %%%%%%%% Forward pass
    A1 = tanh([A0 one]*W1);   % hidden activations
    A2 = tanh([A1 one]*W2);   % output activations
    %%%%%%%% Saturation check
    flag = 0;
    for hid = 1:hidden_n,
      if A1(hid) > SaturationCriterion,
        NumOfSatuNodes(hid) = NumOfSatuNodes(hid) + 1;
        flag = flag + 1;
      else
        if A1(hid) < -SaturationCriterion,
          NumOfSatuNodes(hid) = NumOfSatuNodes(hid) + 1;
          flag = flag + 1;
        end
      end
    end
    if flag == hidden_n,
      NumSatuPtns = NumSatuPtns + 1;
    end
    residu(ptn) = T - A2;
    if T > 0.1,   %%% If the target is ON (=1.0)
      if A2 < MyCriterion,
        NumWrongPtns = NumWrongPtns + 1;
      end
    else
      if A2 > -MyCriterion,
        NumWrongPtns = NumWrongPtns + 1;
      end
    end
  end   %% End for ptn-loop
  RMSE(i) = sqrt(sum(residu.^2)/data_n);
  CNT(i) = NumWrongPtns;
  if mod(i,Report) == 0,
    fprintf('Epoch %.0f: RMSE= %.6f WrongPtns= %d NumSatuPtns=%d\n',i, RMSE(i),NumWrongPtns,NumSatuPtns);
    fprintf('NumOfSatuNodes: ');
    for hid = 1:hidden_n,
      fprintf('%d ',NumOfSatuNodes(hid));
    end
    fprintf('\n');
  end
  %%%% Check if any stopping criterion is satisfied:
  if RMSE(i)