class ImportGraph:
    """Import and run an isolated TF graph.

    Each instance owns a private ``tf.Graph`` plus a ``tf.Session`` bound to
    that graph, so several restored models can live in one process without
    sharing the default graph.
    """

    def __init__(self, loc):
        """Restore the checkpoint saved under ``loc`` into a fresh graph/session.

        Args:
            loc: Checkpoint path prefix. The meta graph is read from
                ``loc + '.meta'`` and variable values are restored from ``loc``.
        """
        # Create a local graph and use it in the session.
        self.graph = tf.Graph()
        config = tf.ConfigProto(log_device_placement=False)
        # Let GPU memory grow on demand instead of reserving it all up front.
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(graph=self.graph, config=config)
        with self.graph.as_default():
            # Import the saved model from location 'loc' into the local graph;
            # clear_devices=True drops device assignments stored in the meta graph.
            saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
            saver.restore(self.sess, loc)
            # There are TWO options for getting the output tensor.
            # BY NAME (used here):
            self.logits = self.graph.get_operation_by_name('proj/Reshape_1').outputs[0]
            # self.activation = self.graph.get_operation_by_name('activation_opt').outputs[0]
            # FROM SAVED COLLECTION:
            # self.activation = tf.get_collection('activation')[0]

    def build_graph(self):
        # NOTE(review): body elided in this snippet; it runs with this
        # instance's private graph installed as the default graph.
        with self.graph.as_default():
            ...

    def close(self):
        """Close the session and drop references to the graph and params.

        Safe to call repeatedly, and safe on a partially constructed instance
        (e.g. when ``__init__`` raised before ``self.sess`` was assigned).
        """
        sess = self.__dict__.pop('sess', None)
        if sess is not None:
            sess.close()
        # ``param`` is not assigned in the code visible here (presumably set
        # by build_graph/train_test), so tolerate it being absent.
        for name in ('graph', 'param'):
            self.__dict__.pop(name, None)

    def __del__(self):
        # Explicitly release resources when the object is collected. close()
        # tolerates missing attributes, so this never raises AttributeError
        # from inside the finalizer.
        self.close()

    def train_test(self, train_data, train_label, test_data, test_label):
        """Train models and return the test accuracy.

        NOTE(review): body elided in this snippet.
        """
        ...