ENAS的代码详解——cifar10\controller.py

定义控制器的类成员变量,并进行控制器和采样器的初始化。

class ConvController(Controller):
  def __init__(self,
               num_branches=6,
               num_layers=4,
               num_blocks_per_branch=8,
               lstm_size=32,
               lstm_num_layers=2,
               lstm_keep_prob=1.0,
               tanh_constant=None,
               temperature=None,
               lr_init=1e-3,
               lr_dec_start=0,
               lr_dec_every=100,
               lr_dec_rate=0.9,
               l2_reg=0,
               clip_mode=None,
               grad_bound=None,
               use_critic=False,
               bl_dec=0.999,
               optim_algo="adam",
               sync_replicas=False,
               num_aggregate=None,
               num_replicas=None,
               name="controller"):

    print "-" * 80
    print "Building ConvController"

    self.num_branches = num_branches
    self.num_layers = num_layers
    self.num_blocks_per_branch = num_blocks_per_branch
    self.lstm_size = lstm_size
    self.lstm_num_layers = lstm_num_layers 
    self.lstm_keep_prob = lstm_keep_prob
    self.tanh_constant = tanh_constant
    self.temperature = temperature
    self.lr_init = lr_init
    self.lr_dec_start = lr_dec_start
    self.lr_dec_every = lr_dec_every
    self.lr_dec_rate = lr_dec_rate
    self.l2_reg = l2_reg
    self.clip_mode = clip_mode
    self.grad_bound = grad_bound
    self.use_critic = use_critic
    self.bl_dec = bl_dec
    self.optim_algo = optim_algo
    self.sync_replicas = sync_replicas
    self.num_aggregate = num_aggregate
    self.num_replicas = num_replicas
    self.name = name

    self._create_params()
    self._build_sampler()

新建控制器的变量。

  def _create_params(self):
    with tf.variable_scope(self.name):
      with tf.variable_scope("lstm"):
        self.w_lstm = []
        for layer_id in xrange(self.lstm_num_layers):
          with tf.variable_scope("layer_{}".format(layer_id)):
            w = tf.get_variable("w", [2 * self.lstm_size, 4 * self.lstm_size])
            self.w_lstm.append(w)

      self.num_configs = (2 ** self.num_blocks_per_branch) - 1
      with tf.variable_scope("embedding"):
        self.g_emb = tf.get_variable("g_emb", [1, self.lstm_size])
        self.w_emb = tf.get_variable("w", [self.num_blocks_per_branch,
                                           self.lstm_size])

      with tf.variable_scope("softmax"):
        self.w_soft = tf.get_variable("w", [self.lstm_size,
                                            self.num_blocks_per_branch])
      with tf.variable_scope("critic"):
        self.w_critic = tf.get_variable("w", [self.lstm_size, 1])

这一段实在不知道细节上是什么意思……大概就是初始化采样器。

  def _build_sampler(self):
    """Build the sampler ops and the log_prob ops."""

    arc_seq = []
    sample_log_probs = []
    all_h = []

    # sampler ops
    inputs = self.g_emb
    prev_c = [tf.zeros([1, self.lstm_size], dtype=tf.float32)
              for _ in xrange(self.lstm_num_layers)]
    prev_h = [tf.zeros([1, self.lstm_size], dtype=tf.float32)
              for _ in xrange(self.lstm_num_layers)]
    for layer_id in xrange(self.num_layers):
      for branch_id in xrange(self.num_branches):
        next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm)
        all_h.append(tf.stop_gradient(next_h[-1]))

        logits = tf.matmul(next_h[-1], self.w_soft)
        if self.temperature is not None:
          logits /= self.temperature
        if self.tanh_constant is not None:
          logits = self.tanh_constant * tf.tanh(logits)

        config_id = tf.multinomial(logits, 1)
        config_id = tf.to_int32(config_id)
        config_id = tf.reshape(config_id, [1])
        arc_seq.append(config_id)
        log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=config_id)
        sample_log_probs.append(log_prob)

        inputs = tf.nn.embedding_lookup(self.w_emb, config_id)
    arc_seq = tf.concat(arc_seq, axis=0)
    self.sample_arc = arc_seq

    self.sample_log_probs = tf.concat(sample_log_probs, axis=0)
    self.ppl = tf.exp(tf.reduce_sum(self.sample_log_probs) /
                      tf.to_float(self.num_layers * self.num_branches))
    self.all_h = all_h

建立训练器,基于Actor-Critic框架,具体可以百度。


  def build_trainer(self, child_model):
    # actor
    child_model.build_valid_rl()
    self.valid_acc = (tf.to_float(child_model.valid_shuffle_acc) /
                      tf.to_float(child_model.batch_size))
    self.reward = self.valid_acc

    if self.use_critic:
      # critic
      all_h = tf.concat(self.all_h, axis=0)
      value_function = tf.matmul(all_h, self.w_critic)
      advantage = value_function - self.reward
      critic_loss = tf.reduce_sum(advantage ** 2)
      self.baseline = tf.reduce_mean(value_function)
      self.loss = -tf.reduce_mean(self.sample_log_probs * advantage)

      critic_train_step = tf.Variable(
          0, dtype=tf.int32, trainable=False, name="critic_train_step")
      critic_train_op, _, _, _ = get_train_ops(
        critic_loss,
        [self.w_critic],
        critic_train_step,
        clip_mode=None,
        lr_init=1e-3,
        lr_dec_start=0,
        lr_dec_every=int(1e9),
        optim_algo="adam",
        sync_replicas=False)
    else:
      # or baseline
      self.sample_log_probs = tf.reduce_sum(self.sample_log_probs)
      self.baseline = tf.Variable(0.0, dtype=tf.float32, trainable=False)
      baseline_update = tf.assign_sub(
        self.baseline, (1 - self.bl_dec) * (self.baseline - self.reward))
      with tf.control_dependencies([baseline_update]):
        self.reward = tf.identity(self.reward)
      self.loss = self.sample_log_probs * (self.reward - self.baseline)

    self.train_step = tf.Variable(
        0, dtype=tf.int32, trainable=False, name="train_step")
    tf_variables = [var for var in tf.trainable_variables()
                    if var.name.startswith(self.name)
                      and "w_critic" not in var.name]
    print "-" * 80
    for var in tf_variables:
      print var
    self.train_op, self.lr, self.grad_norm, self.optimizer = get_train_ops(
      self.loss,
      tf_variables,
      self.train_step,
      clip_mode=self.clip_mode,
      grad_bound=self.grad_bound,
      l2_reg=self.l2_reg,
      lr_init=self.lr_init,
      lr_dec_start=self.lr_dec_start,
      lr_dec_every=self.lr_dec_every,
      lr_dec_rate=self.lr_dec_rate,
      optim_algo=self.optim_algo,
      sync_replicas=self.sync_replicas,
      num_aggregate=self.num_aggregate,
      num_replicas=self.num_replicas)

    if self.use_critic:
      self.train_op = tf.group(self.train_op, critic_train_op)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,294评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,493评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,790评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,595评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,718评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,906评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,053评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,797评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,250评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,570评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,711评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,388评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,018评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,796评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,023评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,461评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,595评论 2 350

推荐阅读更多精彩内容

  • Actor的创建&引用&声明周期 1.创建actor 定义一个Actor类 要定义自己的Actor类,需要继承Ac...
    这个该叫什么呢阅读 1,002评论 0 1
  • Actor系统的实体 在Actor系统中,actor之间具有树形的监管结构,并且actor可以跨多个网络节点进行透...
    JasonDing阅读 3,334评论 2 6
  • 目录 [TOC] 引言 量化交易是指以先进的数学模型替代人为的主观判断,利用计算机技术从庞大的历史数据中海选能带来...
    雷达熊阅读 969评论 0 2
  • 婆娑世界基本都是烟雨的画廊 一笔朦胧之美,一笔勾魂摄魄的感觉 谁能看得透彻,画得清晰流畅 自然的线条已经替我们勾勒...
    子墨的简书阅读 541评论 2 3
  • 曾经以为 我是一支红烛 点燃自己 照亮你的征程 后来发现 我不过是 温暖你书桌一角的台灯 而你需要的是 全部的阳光...
    火流苏_阅读 163评论 2 3