environment.yml中列出了以下环境
# run: conda env create --file environment.yml
name: attention_tsp
channels: // 指的是conda的channel
- pytorch
dependencies:
- python>=3.8 // 最好使用3.8 或 3.9 太高版本会引起依赖冲突 例如 pytorch tensorflow不支持3.9以后的
- anaconda // 类似Maven 因为py各种包的兼容性太差 canda可以同时构建多套环境 类似docker
- tqdm //进度条显示工具 https://github.com/tqdm/tqdm https://blog.csdn.net/langb2014/article/d etails/54798823
- pytorch //做机器学习的核心框架 官网:https://pytorch.org/docs/stable/optim.html 中文版(不全):https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-optim/#torchoptim
//pytorch 如果安装不成功需要看一下cuda版本 当然显卡不行最好上cpu 我1070Ti的显卡cuda是12.5 需要自己去源码打包 官网只有10.1 windows的
- torchvision //包含了目前流行的数据集,模型结构和常用的图片转换工具。
- cuda91 //显卡核心 cpu的用不着
- pip //装包工具 刚需 可以最新版
- pip:
- tensorboard_logger //数据可视化框架 tenserflow的轻量级版本,直接使用conda估计装不上,需要在项目路径pip install一下
这个是Jupyter 一个note工具 可以在文档里跑代码那种感觉 同时支持markdown
在run里可以看到大体执行逻辑
其中涉及到的专业名词在此做简单介绍和资料补充
# Set the device
opts.device = torch.device("cuda:0" if opts.use_cuda else "cpu")
这个地方是加载cpu 还是gpu驱动的
# Initialize model
model_class = {
'attention': AttentionModel,
'pointer': PointerNetwork
}.get(opts.model, None)
# 这里是做了一个字典 去加载不同的model,这个model都是继承了torch.nn.Model 这个get方法就是字典取值
# AttentionModel [注意力机制](https://easyai.tech/ai-definition/attention/)
# PointerNetwork [指针网络](https://blog.csdn.net/qq_38556984/article/details/107574587)
assert model_class is not None, "Unknown model: {}".format(model_class)
model = model_class(
opts.embedding_dim,
opts.hidden_dim,
problem,
n_encode_layers=opts.n_encode_layers,
mask_inner=True,
mask_logits=True,
normalization=opts.normalization,
tanh_clipping=opts.tanh_clipping,
checkpoint_encoder=opts.checkpoint_encoder,
shrink_size=opts.shrink_size
).to(opts.device)
这一步主要是初始化模型
# Initialize optimizer
# 优化器 https://www.cnblogs.com/guoyaohua/p/8542554.html
# 本文使用的是Adam优化器
# lr-> learning rate 学习率 https://blog.csdn.net/u012526436/article/details/90486021
# epoch 代
optimizer = optim.Adam(
[{'params': model.parameters(), 'lr': opts.lr_model}]
+ (
[{'params': baseline.get_learnable_parameters(), 'lr': opts.lr_critic}]
if len(baseline.get_learnable_parameters()) > 0
else []
)
)