TensorFlow.js is an open-source library developed by Google for running machine learning models and deep learning neural networks in the browser or in Node.js.
The .setWeights() method sets the weights of a layer from the given tensors.
Syntax:
setWeights(weights)
Parameters:
- weights: The list of tensors to assign, of type tf.Tensor[]. The number of tensors and the shape of each one must match the layer's existing weights. In other words, they must match the tensors returned by the layer's getWeights() method.
Return Value: It returns void.
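setWeights() expects the new tensors to match getWeights() in count and shape, so it can help to check this before calling it. The following is a minimal plain-JavaScript sketch. `weightShapesMatch` is a hypothetical helper and not part of the TensorFlow.js API. It compares shape arrays, such as the `shape` property of each tf.Tensor, rather than the tensors themselves.

```javascript
// Hypothetical helper (not part of TensorFlow.js): checks that a list of new
// weight shapes has the same count and the same shapes as the current ones.
function weightShapesMatch(currentShapes, newShapes) {
  // The number of weight tensors must be the same.
  if (currentShapes.length !== newShapes.length) {
    return false;
  }
  // Each pair of shapes must have the same rank and the same dimensions.
  return currentShapes.every((shape, i) =>
    shape.length === newShapes[i].length &&
    shape.every((dim, j) => dim === newShapes[i][j])
  );
}

// Shapes of a dense layer with 11 inputs and 2 units: kernel [11, 2], bias [2].
const current = [[11, 2], [2]];
console.log(weightShapesMatch(current, [[11, 2], [2]])); // true
console.log(weightShapesMatch(current, [[2, 11], [2]])); // false: kernel transposed
console.log(weightShapesMatch(current, [[11, 2]]));      // false: bias missing
```

With a real layer, the current shapes could be collected as `layer.getWeights().map((w) => w.shape)` and the new ones the same way from the tensors you plan to pass to setWeights().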
Example 1:
```javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// Creating a model
const model = tf.sequential();

// Adding a dense layer with 11 inputs and 2 units
model.add(tf.layers.dense({units: 2, inputShape: [11]}));

// Calling setWeights(): a kernel of shape [11, 2] and a bias of shape [2]
model.layers[0].setWeights([tf.truncatedNormal([11, 2]), tf.zeros([2])]);

// Compiling the model
model.compile({loss: 'categoricalCrossentropy', optimizer: 'sgd'});

// Printing the kernel returned by getWeights()
model.layers[0].getWeights()[0].print();
```
Output:
```
Tensor
    [[-0.5969906, -0.1883931],
     [0.8569255 , -0.49416  ],
     [0.1157023 , 0.1150239 ],
     [-0.4052143, 1.9936075 ],
     [0.3090054 , 0.7212474 ],
     [0.4626641 , -0.7287846],
     [0.4352857 , -0.5195332],
     [0.4626429 , 0.0216295 ],
     [-0.1110666, -0.5997615],
     [-0.5083916, -0.3582681],
     [-0.2847465, 1.184485 ]]
```
Here, the truncatedNormal() method creates a tf.Tensor with values sampled from a truncated normal distribution, the zeros() method creates a tf.Tensor with all elements set to 0, and the getWeights() method returns the layer's current weights, so printing its first element shows the kernel set by setWeights(). Because the kernel is sampled randomly, the printed values differ from run to run.
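Every value in the output above lies between -2 and 2. According to the TensorFlow.js documentation, tf.truncatedNormal() samples from a normal distribution (mean 0 and standard deviation 1 by default) but drops and re-picks any value more than two standard deviations from the mean. The sketch below illustrates that idea in plain JavaScript using rejection sampling. `truncatedNormalSample` is a hypothetical helper, and it uses the Box-Muller transform, which is not necessarily how TensorFlow.js generates its samples.

```javascript
// Hypothetical helper (not the TensorFlow.js implementation): draws `count`
// values from a normal distribution with the given mean and standard
// deviation, re-picking any value more than 2 standard deviations away.
function truncatedNormalSample(count, mean = 0, stdDev = 1) {
  // One standard normal draw via the Box-Muller transform.
  function standardNormal() {
    // 1 - Math.random() lies in (0, 1], so Math.log never receives 0.
    const u1 = 1 - Math.random();
    const u2 = Math.random();
    return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
  }

  const values = [];
  while (values.length < count) {
    const z = standardNormal();
    // Reject draws beyond 2 standard deviations and try again.
    if (Math.abs(z) <= 2) {
      values.push(mean + stdDev * z);
    }
  }
  return values;
}

// 22 values, enough to fill an [11, 2] kernel like the one in Example 1.
const sample = truncatedNormalSample(22);
console.log(sample.every((v) => Math.abs(v) <= 2)); // true
```

Rejection sampling keeps the shape of the normal curve inside the cut-off, which is why the kernel above contains no extreme values.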
Example 2:
```javascript
// Importing the tensorflow.js library
import * as tf from "@tensorflow/tfjs";

// Creating a model
const model = tf.sequential();

// Adding layers; the first layer's inputShape defines the model input
model.add(tf.layers.dense({units: 1, inputShape: [5], batchSize: 1, dtype: 'int32'}));
model.add(tf.layers.dense({units: 2, inputShape: [6], batchSize: 5}));
model.add(tf.layers.dense({units: 3, inputShape: [7], batchSize: 8}));
model.add(tf.layers.dense({units: 4, inputShape: [8], batchSize: 12}));

// Calling setWeights(): layer 0 has a [5, 1] kernel and a [1] bias;
// layer 1 receives the 1 output of layer 0, so its kernel is [1, 2]
model.layers[0].setWeights([tf.ones([5, 1]), tf.zeros([1])]);
model.layers[1].setWeights([tf.ones([1, 2]), tf.zeros([2])]);

// Printing the kernel and bias of each layer using getWeights()
model.layers[0].getWeights()[0].print();
model.layers[0].getWeights()[1].print();
model.layers[1].getWeights()[0].print();
model.layers[1].getWeights()[1].print();
```
Output:
```
Tensor
    [[1],
     [1],
     [1],
     [1],
     [1]]
Tensor
    [0]
Tensor
    [[1, 1],]
Tensor
    [0, 0]
```
Here, the ones() method creates a tf.Tensor with all elements set to 1.
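In Example 2 the second layer is declared with inputShape: [6], yet the printed output shows its kernel as [[1, 1],], a [1, 2] tensor: its input dimension follows the previous layer's single unit. The sketch below computes the expected weight shapes for a stack of dense layers under that rule, assuming every layer uses a bias (the default for tf.layers.dense). `denseStackWeightShapes` is a hypothetical helper, not part of TensorFlow.js.

```javascript
// Hypothetical helper (not part of TensorFlow.js): returns the expected
// [kernelShape, biasShape] for each dense layer in a sequential stack.
// Assumes every layer uses a bias, which is the default for dense layers.
function denseStackWeightShapes(inputDim, unitsPerLayer) {
  const shapes = [];
  // Input size of the first layer; later layers take the previous layer's units.
  let prevDim = inputDim;
  for (const units of unitsPerLayer) {
    // The kernel maps the previous output size to this layer's units.
    shapes.push([[prevDim, units], [units]]);
    // This layer's output size becomes the next layer's input size.
    prevDim = units;
  }
  return shapes;
}

// Example 2: inputShape [5] on the first layer, then 1, 2, 3 and 4 units.
const shapes = denseStackWeightShapes(5, [1, 2, 3, 4]);
console.log(JSON.stringify(shapes[0])); // [[5,1],[1]]
console.log(JSON.stringify(shapes[1])); // [[1,2],[2]]
```

These shapes match the tensors passed to setWeights() in Example 2: tf.ones([5, 1]) and tf.zeros([1]) for the first layer, and tf.ones([1, 2]) and tf.zeros([2]) for the second.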
Reference: https://js.tensorflow.org/api/latest/#tf.layers.Layer.setWeights