全局种子
tf.random.set_seed(116)
针对程序重新运行或者tf.function(类似于re-run of a program),保证随机操作顺序相同
例如1
重新运行程序
tf.random.set_seed(1234)
print(tf.random.uniform([1])) # generates 'A1'
print(tf.random.uniform([1])) # generates 'A2'
(now close the program and run it again)
tf.random.set_seed(1234)
print(tf.random.uniform([1])) # generates 'A1'
print(tf.random.uniform([1])) # generates 'A2'
例如2
定义函数tf.function
tf.random.set_seed(1234)
@tf.function
def f():
a = tf.random.uniform([1])
b = tf.random.uniform([1])
return a, b
@tf.function
def g():
a = tf.random.uniform([1])
b = tf.random.uniform([1])
return a, b
print(f()) # prints '(A1, A2)'
print(g()) # prints '(A1, A2)'
output
(<tf.Tensor: id=20, shape=(1,), dtype=float32, numpy=array([0.96046877], dtype=float32)>, <tf.Tensor: id=21, shape=(1,), dtype=float32, numpy=array([0.85591054], dtype=float32)>)
(<tf.Tensor: id=41, shape=(1,), dtype=float32, numpy=array([0.96046877], dtype=float32)>, <tf.Tensor: id=42, shape=(1,), dtype=float32, numpy=array([0.85591054], dtype=float32)>)
操作种子
tf.random.truncated_normal([4,3], stddev=0.1, seed=1)
例如1
内部计数器,每次执行时会增加,产生不同的结果
print(tf.random.uniform([1], seed=1)) # generates 'A1'
print(tf.random.uniform([1], seed=1)) # generates 'A2'
(now close the program and run it again)
print(tf.random.uniform([1], seed=1)) # generates 'A1'
print(tf.random.uniform([1], seed=1)) # generates 'A2'
例如2
多个相同操作种子包含在tf.funtion中,因操作时间不长,共享相同的计数器
@tf.function
def foo():
a = tf.random.uniform([1], seed=1)
b = tf.random.uniform([1], seed=1)
return a, b
print(foo()) # prints '(A1, A1)'
print(foo()) # prints '(A2, A2)'
output
(<tf.Tensor: id=20, shape=(1,), dtype=float32, numpy=array([0.2390374], dtype=float32)>, <tf.Tensor: id=21, shape=(1,), dtype=float32, numpy=array([0.2390374], dtype=float32)>)
(<tf.Tensor: id=22, shape=(1,), dtype=float32, numpy=array([0.22267115], dtype=float32)>, <tf.Tensor: id=23, shape=(1,), dtype=float32, numpy=array([0.22267115], dtype=float32)>)
@tf.function
def bar():
a = tf.random.uniform([1])#不设置操作种子
b = tf.random.uniform([1])
return a, b
print(bar()) # prints '(A1, A2)'
print(bar()) # prints '(A3, A4)'
全局种子+操作种子
全局种子会重置计数器tf.random.set_seed()
tf.random.set_seed(1234)
print(tf.random.uniform([1], seed=1)) # generates 'A1'
print(tf.random.uniform([1], seed=1)) # generates 'A2'
tf.random.set_seed(1234)
print(tf.random.uniform([1], seed=1)) # generates 'A1'
print(tf.random.uniform([1], seed=1)) # generates 'A2'
相当于关闭了程序re-run
附注
以下三种随机操作顺序不同:
1全局+操作
tf.random.set_seed(1234)
print(tf.random.uniform([1], seed=1))
output
tf.Tensor([0.1689806], shape=(1,), dtype=float32)
2全局
tf.random.set_seed(1234)
print(tf.random.uniform([1))
output
tf.Tensor([0.5380393], shape=(1,), dtype=float32)
3操作
print(tf.random.uniform([1], seed=1))
output
tf.Tensor([0.2390374], shape=(1,), dtype=float32)