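When using torch.nn.Linear, suppose the input to the layer is X, a matrix of shape (N, in_features), and the layer's weight
matrix is W, a matrix of shape (out_features, in_features). Then the output of the layer is $Y = XW^\top + b$, a matrix of
shape (N, out_features), where b is the bias vector of length out_features (omitted if the layer is created with bias=False).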
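If X is a tensor with shape (*, in_features), where * stands for any number of leading dimensions, all the dimensions except
the last one will be treated as part of the batch dimension, and the output has shape (*, out_features).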
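
The shape rules above can be checked with a short sketch. The sizes (4 input features, 3 output features, batches of 5 and 2 x 5) are arbitrary values chosen for illustration and are not from the original text; the manual computation assumes the layer keeps its default bias.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the text above)
in_features, out_features = 4, 3
layer = torch.nn.Linear(in_features, out_features)

with torch.no_grad():
    # 2-D input of shape (N, in_features)
    X = torch.randn(5, in_features)
    Y = layer(X)

    # Same result computed by hand: Y = X W^T + b
    Y_manual = X @ layer.weight.T + layer.bias
    matches_2d = torch.allclose(Y, Y_manual, atol=1e-6)

    # Input with an extra leading dimension: shape (2, 5, in_features)
    X3 = torch.randn(2, 5, in_features)
    Y3 = layer(X3)

    # Flattening the leading dimensions into one batch axis gives the same values
    Y3_flat = layer(X3.reshape(-1, in_features)).reshape(2, 5, out_features)
    matches_nd = torch.allclose(Y3, Y3_flat, atol=1e-6)

print(tuple(layer.weight.shape))  # weight is stored as (out_features, in_features)
print(tuple(Y.shape), tuple(Y3.shape))
print(matches_2d, matches_nd)
```

Because the weight is stored as (out_features, in_features), the transpose in the manual computation is what makes the matrix product well defined; only the last dimension of the input has to match in_features.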