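When using torch.nn.Linear, suppose the input for the layer is $X$, where $X$ is a matrix of shape $(N, \text{in\_features})$, and the weight matrix for this layer is $W$, which is a matrix of shape $(\text{out\_features}, \text{in\_features})$. Then the output for the layer is $Y = XW^\top + b$, a matrix of shape $(N, \text{out\_features})$, where $b$ is the bias vector (omitted when the layer is created with bias=False).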
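As a quick check of the relation between input, weight, and output, here is a minimal sketch. The sizes (5 samples, 4 input features, 3 output features) and the variable names are arbitrary choices for illustration only.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes only: 4 input features, 3 output features.
layer = torch.nn.Linear(in_features=4, out_features=3)

# X has shape (N, in_features) with N = 5 samples.
X = torch.randn(5, 4)

# Output of the layer, shape (N, out_features).
Y = layer(X)

# Same computation written out by hand: layer.weight is stored with
# shape (out_features, in_features), so it is transposed before use.
manual = X @ layer.weight.T + layer.bias

print(layer.weight.shape)  # shape of W
print(Y.shape)             # shape of the output
print(torch.allclose(Y, manual, atol=1e-6))
```

Note that the weight is stored as (out_features, in_features), which is why the transpose appears in the formula.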
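If $X$ is a tensor with more than two dimensions, with shape $(d_1, \dots, d_k, \text{in\_features})$, all the dimensions except the last will be treated as part of the batch dimension, and the output has shape $(d_1, \dots, d_k, \text{out\_features})$.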
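The batch behaviour can be illustrated with a small sketch. The shape (2, 7, 4) is an arbitrary example, and flattening the leading dimensions by hand is only a way to show the equivalence, not something the layer requires.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes only: 4 input features, 3 output features.
layer = torch.nn.Linear(in_features=4, out_features=3)

# Input with two leading (batch-like) dimensions and in_features last.
X = torch.randn(2, 7, 4)

# The layer acts on the last dimension only.
Y = layer(X)

# Equivalent computation: merge the leading dimensions into one batch
# dimension, apply the layer, then restore the original leading shape.
Y_flat = layer(X.reshape(-1, 4)).reshape(2, 7, 3)

print(Y.shape)
print(torch.allclose(Y, Y_flat, atol=1e-6))
```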