Tensorflow如何加载变量名和变量值
这篇文章主要为大家展示了“Tensorflow如何加载变量名和变量值”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Tensorflow如何加载变量名和变量值”这篇文章吧。
则从这些checkpoint文件中加载变量名和变量值代码如下:
model_dir='./ckpt-182802'importtensorflowastffromtensorflow.pythonimportpywrap_tensorflowreader=pywrap_tensorflow.NewCheckpointReader(model_dir)var_to_shape_map=reader.get_variable_to_shape_map()forkeyinvar_to_shape_map:print("tensor_name:",key)print(reader.get_tensor(key))#RemovethisisyouwanttoprintonlyvariablenamesMnist
下面将给出一个基于卷积神经网络的手写数字识别样例:
#-*-coding:utf-8-*-importtensorflowastffromtensorflow.examples.tutorials.mnistimportinput_datafromtensorflow.python.frameworkimportgraph_utillog_dir='./tensorboard'mnist=input_data.read_data_sets(train_dir="./mnist_data",one_hot=True)iftf.gfile.Exists(log_dir):tf.gfile.DeleteRecursively(log_dir)tf.gfile.MakeDirs(log_dir)#定义输入数据mnist图片大小28*28*1=784,None表示batch_sizex=tf.placeholder(dtype=tf.float32,shape=[None,28*28],name="input")#定义标签数据,mnist共10类y_=tf.placeholder(dtype=tf.float32,shape=[None,10],name="y_")#将数据调整为二维数据,w*H*c--->28*28*1,-1表示N张image=tf.reshape(x,shape=[-1,28,28,1])#第一层,卷积核={5*5*1*32},池化核={2*2*1,1*2*2*1}w1=tf.Variable(initial_value=tf.random_normal(shape=[5,5,1,32],stddev=0.1,dtype=tf.float32,name="w1"))b1=tf.Variable(initial_value=tf.zeros(shape=[32]))conv1=tf.nn.conv2d(input=image,filter=w1,strides=[1,1,1,1],padding="SAME",name="conv1")relu1=tf.nn.relu(tf.nn.bias_add(conv1,b1),name="relu1")pool1=tf.nn.max_pool(value=relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")#shape={None,14,14,32}#第二层,卷积核={5*5*32*64},池化核={2*2*1,1*2*2*1}w2=tf.Variable(initial_value=tf.random_normal(shape=[5,5,32,64],stddev=0.1,dtype=tf.float32,name="w2"))b2=tf.Variable(initial_value=tf.zeros(shape=[64]))conv2=tf.nn.conv2d(input=pool1,filter=w2,strides=[1,1,1,1],padding="SAME")relu2=tf.nn.relu(tf.nn.bias_add(conv2,b2),name="relu2")pool2=tf.nn.max_pool(value=relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="pool2")#shape={None,7,7,64}#FC1w3=tf.Variable(initial_value=tf.random_normal(shape=[7*7*64,1024],stddev=0.1,dtype=tf.float32,name="w3"))b3=tf.Variable(initial_value=tf.zeros(shape=[1024]))#关键,进行reshapeinput3=tf.reshape(pool2,shape=[-1,7*7*64],name="input3")fc1=tf.nn.relu(tf.nn.bias_add(value=tf.matmul(input3,w3),bias=b3),name="fc1")#shape={None,1024}#FC2w4=tf.Variable(initial_value=tf.random_normal(shape=[1024,10],stddev=0.1,dtype=tf.float32,name="w4"))b4=tf.Variable(initial_value=tf.zeros(shape=[10]))fc2=tf.nn.bias_add(value=tf.matmul(fc1,w4),bias=b4,name="logit")#shape={None,10}#定义交叉熵损失#使用softmax将NN计算输出值表示为概率y=tf.nn.softmax(fc2,name="out")#定义交叉熵损失函数cross_entropy=tf.nn.softmax_cross_entropy_with_logits(logits=fc2,labels=y_)loss=tf.reduce_mean(cross_entropy)tf.summary.scalar('Cross_Entropy',loss)#定义solvertrain=tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)forvarintf.trainable_variables():printvar#train=tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)#定义正确值,判断二者下标index是否相等correct_predict=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))#定义如何计算准确率accuracy=tf.reduce_mean(tf.cast(correct_predict,dtype=tf.float32),name="accuracy")tf.summary.scalar('Training_ACC',accuracy)#定义初始化opmerged=tf.summary.merge_all()init=tf.global_variables_initializer()saver=tf.train.Saver()#训练NNwithtf.Session()assession:session.run(fetches=init)writer=tf.summary.FileWriter(log_dir,session.graph)#定义记录日志的位置foriinrange(0,500):xs,ys=mnist.train.next_batch(100)session.run(fetches=train,feed_dict={x:xs,y_:ys})ifi%10==0:train_accuracy,summary=session.run(fetches=[accuracy,merged],feed_dict={x:xs,y_:ys})writer.add_summary(summary,i)print(i,"accuracy=",train_accuracy)'''#训练完成后,将网络中的权值转化为常量,形成常量graph,注意:需要x与labelconstant_graph=graph_util.convert_variables_to_constants(sess=session,input_graph_def=session.graph_def,output_node_names=['out','y_','input'])#将带权值的graph序列化,写成pb文件存储起来withtf.gfile.FastGFile("lenet.pb",mode='wb')asf:f.write(constant_graph.SerializeToString())'''saver.save(session,'./ckpt')
补充:查看tensorflow产生的checkpoint文件内容的方法
tensorflow在保存权重模型时多使用tf.train.Saver().save 函数进行权重保存,保存的ckpt文件无法直接打开,但tensorflow提供了相关函数 tf.train.NewCheckpointReader 可以对ckpt文件进行权重查看。
importosfromtensorflow.pythonimportpywrap_tensorflowcheckpoint_path=os.path.join('modelckpt',"fc_nn_model")#Readdatafromcheckpointfilereader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path)var_to_shape_map=reader.get_variable_to_shape_map()#Printtensornameandvaluesforkeyinvar_to_shape_map:print("tensor_name:",key)print(reader.get_tensor(key))
其中‘modelckpt'是存放.ckpt文件的文件夹,"fc_nn_model"是文件名。
var_to_shape_map是一个字典,其中的键值是变量名,对应的值是该变量的形状,如
{‘LSTM_input/bias_LSTM/Adam_1':[128]}
想要查看某变量值时,需要调用get_tensor函数,即输入以下代码:
reader.get_tensor('LSTM_input/bias_LSTM/Adam_1')
以上是“Tensorflow如何加载变量名和变量值”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注亿速云行业资讯频道!
声明:本站所有文章资源内容,如无特殊说明或标注,均为采集网络资源。如若本站内容侵犯了原著者的合法权益,可联系本站删除。