tf.get_variable() 函數的使用

tf.get_variable(name, shape, initializer): name 就是變數的名稱,shape 是變數的維度,initializer是變數初始化的方式,初始化的方式有以下幾種:

tf.constant_initializer:常量初始化函數

tf.random_normal_initializer:正態分佈

tf.truncated_normal_initializer:截取的正態分佈

tf.random_uniform_initializer:均勻分佈

tf.zeros_initializer:全部是0

tf.ones_initializer:全是1

tf.uniform_unit_scaling_initializer:滿足均勻分佈,但不影響輸出數量級的隨機值

import  tensorflow as tf;    
import  numpy as np;    
import  matplotlib.pyplot as plt;    

a1 = tf.get_variable(name= 'a1' , shape=[ 2 , 3 ], initializer=tf.random_normal_initializer(mean= 0 , stddev= 1 ))  
a2 = tf.get_variable(name= 'a2' , shape=[ 1 ], initializer=tf.constant_initializer( 1 ))  
a3 = tf.get_variable(name= 'a3' , shape=[ 2 , 3 ], initializer=tf.ones_initializer())  

with tf.Session() as sess:  
    sess.run(tf.initialize_all_variables())  
    print  sess.run(a1)  
    print  sess.run(a2)  
    print  sess.run(a3)

輸出:

[[ 0.42299312 -0.25459203 -0.88605702]
[ 0.22410156 1.34326422 -0.39722782]]
[ 1.]
[[ 1. 1. 1.]
[ 1. 1. 1.]]

注意:不同的變量之間不能有相同的名字。


tf.variable 與 tf.get_variable 的差別

Variable

tensorflow 中有兩種關於變數的操作,tf.Variable() 與 tf.get_variable()

tf.Variable() 與 tf.get_variable()

tf.Variable()

tf.Variable(initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None, dtype=None, expected_shape=None, import_scope=None)

tf.get_variable()

tf.get_variable(name, shape=None, dtype=None, initializer=None, regularizer=None, trainable=True, collections=None, caching_device=None, partitioner=None, validate_shape=True, custom_getter=None)

使用 tf.Variable() 時,如果命名(name)衝突,系統會自己處理;但若使用 tf.get_variable(),系統會發生 error。

以下是 tf.Variable() 例子

import tensorflow as tf
w_1 = tf.Variable(3,name="w_1")
w_2 = tf.Variable(1,name="w_1")
print w_1.name
print w_2.name
#輸出
#w_1:0
#w_1_1:0

以下是 tf.get_variable() 例子

import tensorflow as tf

w_1 = tf.get_variable(name="w_1",initializer=1)
w_2 = tf.get_variable(name="w_1",initializer=2)
#錯誤訊息
#ValueError: Variable w_1 already exists, disallowed. Did
#you mean to set reuse=True in VarScope?

get_variable() 與 Variable() 的實質區別

import tensorflow as tf

with tf.variable_scope("scope1"):
    w1 = tf.get_variable("w1", shape=[])
    w2 = tf.Variable(0.0, name="w2")
with tf.variable_scope("scope1", reuse=True):
    w1_p = tf.get_variable("w1", shape=[])
    w2_p = tf.Variable(1.0, name="w2")

print(w1 is w1_p, w2 is w2_p)
#輸出
#True  False

由於 tf.Variable() 每次都在創建新對象,所以 reuse=True 和它並沒有什麼關係。對於 get_variable(),來說,如果已經創建的變量對象,就把那個對象返回,如果沒有創建變量對象的話,就創建一個新的。


Reference

[0] https://blog.csdn.net/UESTC_C2_403/article/details/72327321

[1] https://blog.csdn.net/u012436149/article/details/53696970

results matching ""

    No results matching ""