Weight initialization:
Any deep neural network has “weights” associated with every connection. A weight can be viewed as how important an input feature is (with either a positive or a negative correlation) in producing the final output.
Of course, these weights are what a network learns during its training phase. But in order to update these numbers, we must start somewhere. So the question is: how should we initialize the weights of a network?
There are various ways to do so, for example:
- Assigning the same value to all weights (symmetric assignment)
- Randomly assigning all weights
- Assigning weights by sampling from a distribution
Let's look at each of the above cases and conclude with some techniques that are widely used to initialize weights in deep networks.
The first case is $W^{k}_{ij} = c$ for all $i, j, k$: every weight, in every layer and between any two nodes, has the same value c (for example, zero). Since the value is the same everywhere, all the neurons in a layer are symmetric (each neuron together with all its subsequent connections): they compute the same output, receive the same gradient updates, and therefore stay identical throughout training. We want each neuron to learn a different feature, and this initialization won't let that happen.
To see this in practice, try the following:
- Using TensorFlow Playground, initialize all weights to zero.
- Let the model train.
- Notice how all the weights originating from a given neuron are the same.
- Each neuron in the first (input) layer gets its own weight value, shared by all of its outgoing connections, but between any later layers L-1 and L all the weights are the same, because those neurons are symmetric.
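The Playground experiment can also be reproduced in a few lines of plain Python. This is a minimal sketch, assuming a small hypothetical network (2 inputs, 3 tanh hidden units, 1 linear output, no biases, squared-error loss) and a made-up training set; the function name `train_constant_init` is illustrative. It uses the nonzero constant c = 0.5 because with all-zero weights and no biases this particular network would not learn at all. After plain gradient descent, every hidden neuron still has exactly the same incoming weights.

```python
import math

def train_constant_init(c=0.5, steps=50, lr=0.1):
    """Train a tiny hypothetical 2-3-1 tanh network whose weights all start at c."""
    n_in, n_hidden = 2, 3
    # W1[j][i]: weight from input i to hidden unit j; W2[j]: hidden unit j to output.
    W1 = [[c] * n_in for _ in range(n_hidden)]
    W2 = [c] * n_hidden
    # Hypothetical training examples: ((x1, x2), target).
    data = [((1.0, -2.0), 0.5), ((0.3, 0.8), -1.0), ((-1.5, 0.2), 1.0)]
    for _ in range(steps):
        for x, y in data:
            # Forward pass.
            h = [math.tanh(sum(W1[j][i] * x[i] for i in range(n_in))) for j in range(n_hidden)]
            y_hat = sum(W2[j] * h[j] for j in range(n_hidden))
            # Backward pass for the loss 0.5 * (y_hat - y) ** 2.
            d_out = y_hat - y
            for j in range(n_hidden):
                # Gradient w.r.t. the hidden pre-activation; tanh'(z) = 1 - tanh(z) ** 2.
                d_h = d_out * W2[j] * (1.0 - h[j] ** 2)
                W2[j] -= lr * d_out * h[j]
                for i in range(n_in):
                    W1[j][i] -= lr * d_h * x[i]
    return W1, W2

W1, W2 = train_constant_init()
print(W1)  # every row (one per hidden neuron) is identical
print(W2)
print(all(row == W1[0] for row in W1), all(w == W2[0] for w in W2))  # True True
```

Within a row, the weight from input 1 and the weight from input 2 generally end up different (the inputs differ), which matches the Playground observation: each input neuron's outgoing weights share one value, while the hidden neurons never become distinguishable.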
Of course, assigning random values removes the symmetry problem, but random weight assignment has a drawback of its own. Weights with extremely large magnitudes (large positive or large negative values) push the input of a sigmoid into its flat regions, where the sigmoid's derivative is extremely close to zero. This causes vanishing gradients and hinders the weight updates at every iteration.
Let's visualize this in TensorFlow Playground as well:
- Assign some large positive weights and some large negative weights.
- Notice how training seems to stall even after numerous epochs.
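To see why large weights stall training, here is a small sketch in plain Python that evaluates the sigmoid derivative σ'(z) = σ(z)(1 - σ(z)) at the pre-activation z = w·x for a fixed, made-up input x while the weight magnitude grows. The helper names are illustrative. Every gradient backpropagated through that neuron is multiplied by σ'(z), so a value near zero means almost no weight update.

```python
import math

def sigmoid(z):
    # Numerically stable logistic function (avoids overflow in math.exp).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def sigmoid_grad(z):
    # sigma'(z) = sigma(z) * (1 - sigma(z)); its maximum is 0.25 at z = 0.
    s = sigmoid(z)
    return s * (1.0 - s)

def grad_at_scale(scale, x=(0.5, -1.2, 0.8), signs=(1.0, -1.0, 1.0)):
    """Sigmoid derivative at z = w . x when every |w_i| equals `scale` (x is made up)."""
    z = sum(scale * sign * xi for sign, xi in zip(signs, x))
    return sigmoid_grad(z)

for scale in (0.1, 1.0, 10.0, 100.0):
    print(f"|w| = {scale:>5}: sigmoid'(w.x) = {grad_at_scale(scale):.3e}")
```

The derivative peaks at 0.25 when z = 0 and falls off rapidly as |z| grows; at |w| = 100 it underflows to exactly 0.0 in double precision, so that neuron's weights stop receiving updates.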
The need for better methods led to the development of weight initialization techniques tailored to specific activation functions. Xavier initialization is used for tanh activations, and its logic is as follows:
Xavier initialization tries to initialize the weights of the network so that the neurons' activation functions do not start out in saturated or dead regions. We want to initialize the weights with random values that are neither “too small” nor “too large.”
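NOTE: there are two variations of Xavier initialization, depending on the sampling technique used (uniform or normal):
- Xavier normal: $W \sim \mathcal{N}\left(0,\ \frac{2}{\text{input} + \text{output}}\right)$, where the second argument is the variance
- Xavier uniform: $W \sim U\left[-\sqrt{\frac{6}{\text{input} + \text{output}}},\ \sqrt{\frac{6}{\text{input} + \text{output}}}\right]$
Here, for any neuron, input = number of incoming connections and output = number of outgoing connections.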
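Both Xavier variants can be sketched in plain Python, assuming the standard Glorot formulas: a normal distribution with mean 0 and variance 2 / (input + output), and a uniform distribution on [-sqrt(6 / (input + output)), sqrt(6 / (input + output))]. The function names are illustrative; deep learning frameworks provide their own built-in versions.

```python
import math
import random

def xavier_normal(n_in, n_out, rng=random):
    """W ~ N(0, 2 / (n_in + n_out)); returns an n_out x n_in weight matrix."""
    std = math.sqrt(2.0 / (n_in + n_out))
    return [[rng.gauss(0.0, std) for _ in range(n_in)] for _ in range(n_out)]

def xavier_uniform(n_in, n_out, rng=random):
    """W ~ U[-limit, limit] with limit = sqrt(6 / (n_in + n_out))."""
    limit = math.sqrt(6.0 / (n_in + n_out))
    return [[rng.uniform(-limit, limit) for _ in range(n_in)] for _ in range(n_out)]

# Example: a layer with 256 incoming and 128 outgoing connections.
rng = random.Random(42)
W = xavier_uniform(n_in=256, n_out=128, rng=rng)
print(len(W), len(W[0]))  # 128 256
```

The uniform limit is chosen so that both variants have the same variance: a uniform distribution on [-a, a] has variance a² / 3, and (6 / (input + output)) / 3 = 2 / (input + output).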
But why does this work?
An interviewer may well ask how this odd-looking formula came into existence. The answer lies in the following explanation.
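- Xavier initialization tries to ensure that the dot product w·x fed to tanh is neither too large nor too small.
- To keep this quantity under control, we can always control the x values through normalization, which gives the input data zero mean and unit variance.
- Now all that Xavier initialization does is ensure that the variance of the net input, $\mathrm{Var}(net_j) = \mathrm{Var}(w \cdot x)$, is 1. Assuming the weights and inputs are independent with zero mean, $\mathrm{Var}(net_j) = \sum_{i=1}^{n} \mathrm{Var}(w_i)\,\mathrm{Var}(x_i) = n\,\mathrm{Var}(w)$ when $\mathrm{Var}(x_i) = 1$, so we need $\mathrm{Var}(w) = \frac{1}{n}$.
In this formula, n is the number of incoming connections. To keep the gradients stable during the backward pass as well, the research paper also takes the number of outgoing connections into account and uses the average of the two counts, $\mathrm{Var}(w) = \frac{1}{(\text{input} + \text{output})/2} = \frac{2}{\text{input} + \text{output}}$, hence the final formula.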
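The variance argument can be checked numerically with a small Monte Carlo sketch, assuming independent zero-mean Gaussian weights and standardized Gaussian inputs (the helper `net_input_variance` is hypothetical). With Var(w) = 1/n the net input has variance close to 1; with unit-variance weights it grows to about n, which would push tanh deep into saturation.

```python
import math
import random
import statistics

def net_input_variance(n, weight_var, trials=5000, seed=0):
    """Monte Carlo estimate of Var(net_j) = Var(sum_i w_i * x_i).

    Assumes w_i ~ N(0, weight_var) and standardized inputs x_i ~ N(0, 1),
    all drawn independently.
    """
    rng = random.Random(seed)
    std_w = math.sqrt(weight_var)
    nets = [
        sum(rng.gauss(0.0, std_w) * rng.gauss(0.0, 1.0) for _ in range(n))
        for _ in range(trials)
    ]
    return statistics.variance(nets)

n = 100  # number of incoming connections
v_scaled = net_input_variance(n, 1.0 / n)  # Var(w) = 1/n
v_naive = net_input_variance(n, 1.0)       # Var(w) = 1
print(v_scaled)  # close to 1
print(v_naive)   # close to n = 100
```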
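In He initialization (used for ReLU activations), we multiply a standard normal random initialization by the term $\sqrt{\frac{2}{\text{size}^{[l-1]}}}$, where $\text{size}^{[l-1]}$, the size of layer l-1, is the number of incoming connections. This gives weights with variance $\frac{2}{\text{size}^{[l-1]}}$; the extra factor of 2 compared with the $\frac{1}{n}$ rule compensates for ReLU setting roughly half of its inputs to zero.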
As with Xavier initialization, He initialization comes in two types, namely He normal and He uniform.
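Here is a plain-Python sketch of both He variants, assuming the standard formulas: standard normal samples multiplied by sqrt(2 / n_in), and a uniform distribution on [-sqrt(6 / n_in), sqrt(6 / n_in)], where n_in is the number of incoming connections. The helper names (`scaled_normal`, `relu_mean_square`) are illustrative. The last function pushes a random input through a stack of ReLU layers to show what the factor of 2 buys; both runs reuse the same random draws, so that factor is the only difference between them.

```python
import math
import random

def scaled_normal(n_in, n_out, gain, rng=random):
    """Standard normal samples multiplied by sqrt(gain / n_in); n_out x n_in matrix."""
    scale = math.sqrt(gain / n_in)
    return [[rng.gauss(0.0, 1.0) * scale for _ in range(n_in)] for _ in range(n_out)]

def he_normal(n_in, n_out, rng=random):
    # He normal: gain = 2, i.e. W ~ N(0, 2 / n_in).
    return scaled_normal(n_in, n_out, 2.0, rng)

def he_uniform(n_in, n_out, rng=random):
    # He uniform: limit = sqrt(6 / n_in), which also gives variance 2 / n_in.
    limit = math.sqrt(6.0 / n_in)
    return [[rng.uniform(-limit, limit) for _ in range(n_in)] for _ in range(n_out)]

def relu_mean_square(gain, width=256, depth=10, seed=0):
    """Mean squared activation after `depth` ReLU layers initialized with scaled_normal."""
    rng = random.Random(seed)
    a = [rng.gauss(0.0, 1.0) for _ in range(width)]
    for _ in range(depth):
        W = scaled_normal(width, width, gain, rng)
        a = [max(0.0, sum(w * ai for w, ai in zip(row, a))) for row in W]
    return sum(ai * ai for ai in a) / width

r_he = relu_mean_square(gain=2.0)     # He scaling: variance 2 / n_in
r_plain = relu_mean_square(gain=1.0)  # same draws without the factor of 2
print(r_he)     # on the order of 1
print(r_plain)  # 2 ** 10 = 1024 times smaller (up to rounding)
```

With the factor of 2, the mean squared activation stays on the order of 1 even after 10 layers; without it, the signal shrinks by a factor of 2 per layer, because ReLU zeroes out roughly half of its inputs.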
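A few other weight initialization techniques also exist, although I have rarely seen interview questions on them. Note that “Kaiming initialization” is not a separate technique: it is just another name for He initialization, after Kaiming He, the first author of the paper that introduced it.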