pytorch系列|3步入门篇

1. 定义网络

继承nn.Module类,实现init和forward方法.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  super(Net,self).__init__()
  self.fc1 = nn.Linear(240,120) #输入长度240,输出长度120
  self.fc2 = nn.Linear(120,60)
  self.fc3 = nn.Linear(60,10)

def forward(self,x):
  ##定义网络中的层该怎么连接
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x

net = Net()

As a rule of thumb, you can put inside the forward method all the layers that do not have any weights to be updated. On the other hand, you should put all the layers that have weights to be updated inside the __init__.

2. 加载数据

pytorch使用了DatasetDataLoader来接入数据.我们可自定义Dataset来接入自己的数据集.

import torch
import pandas as pd
from torch.utils.data import Dataset,DataLoader

class ExampleDataset(Dataset):
  def __init__(self,csv_file):
    self.data_frame = pd.read_csv(csv_file)
 def __len__(self):
  return len(self.data_frame)
def __getitem__(self,idx):
  return self.data_frame[idx]

example_dataset = ExampleDataset('my_datasets.csv')

example_data_loader = DataLoader(example_dataset,batch_size = 10,shuffle=True,num_workers=2) # num_workers: used to load the data in parallel


## Loop over data
## enumerate() return index and value.
for batch_index, batch in enumerate(example_data_loader):
  print(batch_index,batch)

3. 训练网络
import torch.optim as optim
import torch.nn as nn

## instantiate network
net = Net()
## optimizer
optimizer = optim.SGD(net.parameters(),lr=1e-3)

## define loss function
criterion = nn.MSELoss() ## nn.CrossEntropyLoss()

for epoch in range(10): ## epoches = 10
  for i,batch in enumerate(example_data_loader):
    # get the inputs
    data,targets = batch

    # zero the gradient buffes
    optimizer.zero_grad()
    ## passes the data through the network
    output = net.forward(data)
    ## calculates the loss
    loss = criterion(output,target)
    ## propagates the loss back.
    loss.backward()
    ## update all weights of the network
    optimizer.step()


更多学习资料
参考
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容