引用的包以及作用--内部使用

environment.yml中列出了以下环境

# run: conda env create --file environment.yml
name: attention_tsp
channels:   // 指的是conda的channel
- pytorch   
dependencies:
- python>=3.8 // 最好使用3.8 或 3.9  太高版本会引起依赖冲突  例如 pytorch tensorflow不支持3.9以后的
- anaconda  // 类似Maven 因为py各种包的兼容性太差 canda可以同时构建多套环境 类似docker
- tqdm  //进度条显示工具  https://github.com/tqdm/tqdm    https://blog.csdn.net/langb2014/article/d etails/54798823
- pytorch  //做机器学习的核心框架  官网:https://pytorch.org/docs/stable/optim.html 中文版(不全):https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-optim/#torchoptim
//pytorch 如果安装不成功需要看一下cuda版本 当然显卡不行最好上cpu  我1070Ti的显卡cuda是12.5 需要自己去源码打包 官网只有10.1 windows的
- torchvision  //包含了目前流行的数据集,模型结构和常用的图片转换工具。
- cuda91 //显卡核心 cpu的用不着
- pip //装包工具 刚需 可以最新版
- pip:
  - tensorboard_logger  //数据可视化框架 tenserflow的轻量级版本,直接使用conda估计装不上,需要在项目路径pip install一下
1720356080835.png

这个是Jupyter 一个note工具 可以在文档里跑代码那种感觉 同时支持markdown

在run里可以看到大体执行逻辑
其中涉及到的专业名词在此做简单介绍和资料补充

  # Set the device
    opts.device = torch.device("cuda:0" if opts.use_cuda else "cpu")

这个地方是加载cpu 还是gpu驱动的

 # Initialize model
    model_class = {
        'attention': AttentionModel,
        'pointer': PointerNetwork
   }.get(opts.model, None)  
# 这里是做了一个字典 去加载不同的model,这个model都是继承了torch.nn.Model  这个get方法就是字典取值
# AttentionModel   [注意力机制](https://easyai.tech/ai-definition/attention/)
# PointerNetwork  [指针网络](https://blog.csdn.net/qq_38556984/article/details/107574587)
    assert model_class is not None, "Unknown model: {}".format(model_class)
    model = model_class(
        opts.embedding_dim,
        opts.hidden_dim,
        problem,
        n_encode_layers=opts.n_encode_layers,
        mask_inner=True,
        mask_logits=True,
        normalization=opts.normalization,
        tanh_clipping=opts.tanh_clipping,
        checkpoint_encoder=opts.checkpoint_encoder,
        shrink_size=opts.shrink_size
    ).to(opts.device)

这一步主要是初始化模型

# Initialize optimizer  
# 优化器  https://www.cnblogs.com/guoyaohua/p/8542554.html
# 本文使用的是Adam优化器
# lr-> learning rate 学习率  https://blog.csdn.net/u012526436/article/details/90486021
# epoch 代 
    optimizer = optim.Adam(
        [{'params': model.parameters(), 'lr': opts.lr_model}]
        + (
            [{'params': baseline.get_learnable_parameters(), 'lr': opts.lr_critic}]
            if len(baseline.get_learnable_parameters()) > 0
            else []
        )
    )
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。