创新互联www.cdcxhl.cn八线动态BGP香港云服务器提供商,新人活动买多久送多久,划算不套路!
本篇文章为大家展示了tensorflow实现读取网络weight和bias的方法,代码简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。
(1) 获取参数的变量名。可以使用一下函数获取变量名:
def vars_generate1(self,scope_name_var): return [var for var in tf.global_variables() if scope_name_var in var.name ]
输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。
(2) 利用session读取变量的值:
def get_weight(self): full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) ##一定要先初始化变量 print(sess.run(full_connect_variable[0]))