Mxnet-Model.fit()源码

 def fit(self, X, y=None, eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local', logger=None,
            work_load_list=None, monitor=None, eval_end_callback=LogValidationMetricsCallback(),
            eval_batch_end_callback=None):

训练集 Training data.——X: 格式:DataIter, or numpy.ndarray/NDArray
If X is a DataIter, the name or (if name not available) the position of its outputs should match the corresponding variable names defined in the symbolic graph.

训练集标签 Training set label.——y : 格式:numpy.ndarray/NDArray, optional
If Xis numpy.ndarray or NDArray, y is required to be set.
While y can be 1D or 2D (with 2nd dimension as 1), its first dimension must be the same as X,
i.e. the number of data points and labels should be equal.

验证集——eval_data : 格式:DataIter or numpy.ndarray/list/NDArray pair
If eval_data is numpy.ndarray/list/NDArray pair, it should be (valid_data, valid_label).

验证标准The evaluation metric. ——eval_metric 格式 : metric.EvalMetric or str or callable
This could be the name of evaluation metric or a custom evaluation function that returns statistics based on a minibatch.

回调函数(epoch结束时执行)——** epoch_end_callback** 格式: callable(epoch, symbol, arg_params, aux_states)
可以用来每个epoch保存一下模型(checkpoint)

回调函数(batch结束时执行)——batch_end_callback 格式: callable(epoch)
A callback that is invoked at end of each batch for purposes of printing.

将参数存储到哪
kvstore: KVStore or str, optional
The KVStore or a string kvstore type: 'local', 'dist_sync', 'dist_async'
In default uses 'local', often no need to change for single machiine.
KVStore behavior
- 'local', multi-devices on a single machine, will automatically choose best type.
- 'dist_sync', multiple machines communicating via BSP.
- 'dist_async', multiple machines with asynchronous communication.

是否打印日志
logger : logging logger, optional
When not specified, default logger will be used.

    work_load_list : float or int, optional
        The list of work load for different devices,
        in the same order as `ctx`.
    Note

    """
        data = self._init_iter(X, y, is_train=True)
        eval_data = self._init_eval_iter(eval_data)
        if self.sym_gen:
            self.symbol = self.sym_gen(data.default_bucket_key) # pylint: disable=no-member
            self._check_arguments()
        self.kwargs["sym"] = self.symbol
        arg_names, param_names, aux_names = \
                self._init_params(data.provide_data+data.provide_label)

        # setup metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)
        # create kvstore
        (kvstore, update_on_kvstore) = _create_kvstore(
            kvstore, len(self.ctx), self.arg_params)
        param_idx2name = {}
        if update_on_kvstore:
            param_idx2name.update(enumerate(param_names))
        else:
            for i, n in enumerate(param_names):
                for k in range(len(self.ctx)):
                    param_idx2name[i*len(self.ctx)+k] = n
        self.kwargs["param_idx2name"] = param_idx2name

        # init optmizer
        if isinstance(self.optimizer, str):
            batch_size = data.batch_size
            if kvstore and 'dist' in kvstore.type and '_async' not in kvstore.type:
                batch_size *= kvstore.num_workers
            optimizer = opt.create(self.optimizer,
                                  rescale_grad=(1.0/batch_size),
                                  **(self.kwargs))
        elif isinstance(self.optimizer, opt.Optimizer):
            if not optimizer.idx2name:
                optimizer.idx2name = param_idx2name.copy()
            optimizer = self.optimizer

        # do training
        _train_multi_device(self.symbol, self.ctx, arg_names, param_names, aux_names,
                            self.arg_params, self.aux_params,
                            begin_epoch=self.begin_epoch, end_epoch=self.num_epoch,
                            epoch_size=self.epoch_size,
                            optimizer=optimizer,
                            train_data=data, eval_data=eval_data,
                            eval_metric=eval_metric,
                            epoch_end_callback=epoch_end_callback,
                            batch_end_callback=batch_end_callback,
                            kvstore=kvstore, update_on_kvstore=update_on_kvstore,
                            logger=logger, work_load_list=work_load_list, monitor=monitor,
                            eval_end_callback=eval_end_callback,
                            eval_batch_end_callback=eval_batch_end_callback,
                            sym_gen=self.sym_gen)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 知识点 抽象类 abstract 所有的类都是用来描绘对象的,如果一个类中没有包含足够的信息来描绘一个具体的对象...
    一方通行不会慌阅读 235评论 0 0
  • 2009年,第一次见到你,班长,那会是多讨厌你啊,欺负别人,大声说话,那么傲,没有一点点我喜欢的人的样子,是什么时...
    不如初阅读 315评论 0 1
  • 12/9武丽娟感恩日志 感恩引领右脑的老师给孩子用心辅导,我在旁边观看也有很大感受,课后的沟通也是受益匪浅!孩子的...
    花布鱼阅读 249评论 0 0
  • Of open minds as open as a trap
    吴叔八道阅读 171评论 0 0