# API
# tf.add_to_collection()
# tf.get_collection()
# Purpose:
# preserve intermediate results (e.g. variables) in a named collection so that
# network problems, such as a loss that never changes, can be inspected.
import tensorflow as tf
# Two standalone variables, each drawn from a truncated normal (stddev 0.1).
# Both are registered under the 'network' collection key so they can later be
# retrieved together with tf.get_collection('network').
a = tf.get_variable(
    name='a',
    shape=[2, 3, 3, 3],
    dtype=tf.float32,
    initializer=tf.truncated_normal_initializer(stddev=0.1),
)
b = tf.get_variable(
    name='b',
    shape=[2, 2],
    dtype=tf.float32,
    initializer=tf.truncated_normal_initializer(stddev=0.1),
)

tf.add_to_collection('network', a)
tf.add_to_collection('network', b)
def fc(input, category):
    """Build a fully-connected layer over a batched tensor.

    Every non-batch dimension of ``input`` is flattened into a single feature
    axis, then ``flat_input @ w + b`` is returned, where ``w`` and ``b`` are
    freshly created variables.

    Args:
        input: Tensor whose first dimension is the batch. All remaining
            dimensions must be statically known; the batch dimension may be
            static or dynamic (``None``).
        category: Number of output units.

    Returns:
        A tensor of shape ``[batch, category]``.

    Raises:
        ValueError: if any non-batch dimension of ``input`` is unknown.
    """
    feature_dims = input.shape.as_list()[1:]
    if any(dim is None for dim in feature_dims):
        raise ValueError(
            'fc() needs statically known non-batch dimensions, got shape {}'
            .format(input.shape))

    # Number of features per example after flattening (works for any rank,
    # not only [batch, height, width, channel]).
    input_size = 1
    for dim in feature_dims:
        input_size *= dim
    output_size = category

    w_init = tf.truncated_normal([input_size, output_size], stddev=0.1)
    w = tf.Variable(w_init, name='w')
    b = tf.Variable(tf.ones([output_size]))

    # -1 lets TensorFlow infer the batch size, so a dynamic (None) batch
    # dimension is supported as well as a static one.
    flat_input = tf.reshape(input, [-1, input_size])
    return tf.matmul(flat_input, w) + b
# Build the fully-connected layer on `a` before opening the session; the
# session runs on the same default graph either way.
output = fc(a, 2)

with tf.Session() as se:
    # Variables (including the ones fc() created) must be initialized
    # before any of them can be evaluated.
    se.run(tf.global_variables_initializer())

    # Dump name, static shape and current value of everything registered
    # under the 'network' collection key.
    for i in tf.get_collection(key='network'):
        print('name: {}, shape: {}'.format(i.name, i.shape))
        print(se.run(i))

    print('--------------')
    print(se.run(output))