Tensorflow Project Template
A simple and well-designed structure is essential for any deep learning project. After a lot of practice and contributing to TensorFlow projects, here's a TensorFlow project template that combines simplicity, best practices for folder structure, and good OOP design. The main idea is that there's a lot of work you repeat every time you start a TensorFlow project, so wrapping all this shared code lets you change only the core idea each time you start a new one.
So here's a simple TensorFlow template that helps you get into your main project faster and focus on your core (model, training, etc.).
Table of Contents
- In a Nutshell
- In Details
    - Project architecture
    - Folder structure
    - Main Components
        - Models
        - Trainer
        - Data Loader
        - Logger
        - Configuration
        - Main
- Future Work
- Contributing
- Acknowledgments
In a Nutshell
Here's how to use this template in a nutshell. For example, assume you want to implement a VGG model; you would do the following:
- In the models folder, create a class named VGGModel that inherits from the "base_model" class:
```python
from base.base_model import BaseModel

class VGGModel(BaseModel):
    def __init__(self, config):
        super(VGGModel, self).__init__(config)
        # call the build_model and init_saver functions
        self.build_model()
        self.init_saver()
```
- Override the two methods "build_model", where you implement the VGG model, and "init_saver", where you define a TensorFlow saver; then call them in the initializer.
```python
def build_model(self):
    # here you build the TensorFlow graph of any model you want and also define the loss
    pass

def init_saver(self):
    # here you initialize the TensorFlow saver that will be used to save checkpoints
    # (tf.train.Saver is the TensorFlow 1.x API; assumes `import tensorflow as tf`)
    self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
```
- In the trainers folder, create a VGGTrainer class that inherits from the "base_train" class:
```python
from base.base_train import BaseTrain

class VGGTrainer(BaseTrain):
    def __init__(self, sess, model, data, config, logger):
        super(VGGTrainer, self).__init__(sess, model, data, config, logger)
```
- Override the two methods "train_step" and "train_epoch", where you write the logic of the training process:
```python
def train_epoch(self):
    """
    Implement the logic of an epoch:
    - loop over the number of iterations in the config and call train_step
    - add any summaries you want using the logger
    """
    pass

def train_step(self):
    """
    Implement the logic of a single training step:
    - run the TensorFlow session
    - return any metrics you need to summarize
    """
    pass
```
- In the main file, create the session and instances of the following objects: "Model", "Logger", "Data_Generator", "Trainer", and the config.
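```python
import tensorflow as tf

# create the TensorFlow session (TensorFlow 1.x API)
sess = tf.Session()
# create an instance of the model you want
model = VGGModel(config)
# create your data generator
data = DataGenerator(config)
# create the TensorBoard logger
logger = Logger(sess, config)
```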
- Pass all these objects to the trainer object, and start training by calling "trainer.train()":
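```python
# create the trainer and pass all the previous components to it
trainer = VGGTrainer(sess, model, data, config, logger)
# here you train your model
trainer.train()
```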
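To see how these pieces fit together without installing TensorFlow, here is a minimal, framework-free sketch of the same contract. It assumes the base classes declare build_model/init_saver and train_epoch/train_step as hooks that subclasses must override, and that the base trainer's train() loops over epochs; the real base_model.py and base_train.py may differ. The config keys num_epochs, num_iter_per_epoch and learning_rate, the ToyModel and ToyTrainer classes, and the plain list standing in for the Logger are all hypothetical, and a single scalar fitted by gradient descent replaces both the VGG network and the TensorFlow session.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class BaseModel(ABC):
    """Hypothetical stand-in for base/base_model.py."""

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def build_model(self):
        """Build the network and define the loss."""

    @abstractmethod
    def init_saver(self):
        """Create the object that saves checkpoints."""


class BaseTrain(ABC):
    """Hypothetical stand-in for base/base_train.py."""

    def __init__(self, sess, model, data, config, logger):
        self.sess = sess
        self.model = model
        self.data = data
        self.config = config
        self.logger = logger

    def train(self):
        # Outer loop: one train_epoch call per epoch in the config.
        for _ in range(self.config.num_epochs):
            self.train_epoch()

    @abstractmethod
    def train_epoch(self):
        """Run one epoch of train_step calls and record summaries."""

    @abstractmethod
    def train_step(self):
        """Run one optimization step and return its metrics."""


class ToyModel(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.build_model()
        self.init_saver()

    def build_model(self):
        # A single scalar weight stands in for the VGG network.
        self.weight = 0.0

    def init_saver(self):
        # No checkpointing in this sketch.
        self.saver = None


class ToyTrainer(BaseTrain):
    def train_epoch(self):
        # Inner loop: num_iter_per_epoch steps, then log the mean loss of the epoch.
        losses = [self.train_step() for _ in range(self.config.num_iter_per_epoch)]
        self.logger.append(sum(losses) / len(losses))

    def train_step(self):
        # One gradient-descent step on loss = (weight - target) ** 2,
        # where self.data holds the target value (a stand-in for a data batch).
        target = self.data
        grad = 2.0 * (self.model.weight - target)
        self.model.weight -= self.config.learning_rate * grad
        return (self.model.weight - target) ** 2


# Wire everything up as in the main file; sess is None because there is no session here.
config = SimpleNamespace(num_epochs=3, num_iter_per_epoch=10, learning_rate=0.1)
model = ToyModel(config)
history = []  # a plain list stands in for the Logger
trainer = ToyTrainer(None, model, 5.0, config, history)
trainer.train()
print(len(history), round(model.weight, 3))
```

Using the abc module is one option among several: a subclass that forgets an override fails with a TypeError as soon as it is instantiated, instead of partway through training.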
You will find a template file and a simple example in the model and trainer folders that show how to try your first model.
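The steps above pass a config object everywhere and read its fields as attributes, for example self.config.max_to_keep. One possible way to build such an object, assuming the configuration is stored as JSON, is to parse it into nested types.SimpleNamespace objects. Every key except max_to_keep is a hypothetical example, and the template's own config loader may work differently.

```python
import json
from types import SimpleNamespace


def load_config(json_text):
    """Parse a JSON string into an object whose keys are attributes."""
    # object_hook turns every JSON object (including nested ones) into a namespace.
    return json.loads(json_text, object_hook=lambda d: SimpleNamespace(**d))


def load_config_file(path):
    """Same as load_config, but reads the JSON from a file."""
    with open(path, "r", encoding="utf-8") as f:
        return load_config(f.read())


# Hypothetical example config; only max_to_keep appears in the steps above.
example = """
{
  "exp_name": "vgg_example",
  "num_epochs": 10,
  "num_iter_per_epoch": 100,
  "learning_rate": 0.001,
  "batch_size": 16,
  "max_to_keep": 5
}
"""

config = load_config(example)
print(config.exp_name, config.max_to_keep)
```

Because nested JSON objects also become namespaces, grouped settings can be read with chained attributes such as config.optimizer.learning_rate, if you choose to structure the file that way.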
In Details
Project architecture
Folder structure
```
+-- base
|   +-- base_model.py     - this file contains the abstract class of the model.
|   +-- base_train.py     - this file contains the abstract class of the trainer.
|
+-- model                 - this folder contains any model of your project.
|   +-- example_model.py
|
+-- trainer               - this folder contains trainers of your project.
|   +-- example_trainer.py
|
+-- mains                 - here's the main(s) of your project (you may need more than one main).
|   +-- example_main.py   - here's an example of main that is responsible for the whole pipeline.
|
+-- data_loader
|   +-- data_generator.py - here's the data_generator that is responsible for all data handling.
|
+-- utils
    +-- logger.py
    +-- any_other_utils_you_need
```
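If you want to start a new project from this layout, a small script can create the skeleton. This is a hypothetical helper, not part of the template: the LAYOUT dictionary mirrors the folder tree above, the any_other_utils_you_need placeholder is skipped, and every file is created empty.

```python
from pathlib import Path

# Hypothetical layout mirroring the folder tree; file contents are left empty.
LAYOUT = {
    "base": ["base_model.py", "base_train.py"],
    "model": ["example_model.py"],
    "trainer": ["example_trainer.py"],
    "mains": ["example_main.py"],
    "data_loader": ["data_generator.py"],
    "utils": ["logger.py"],
}


def scaffold(root):
    """Create the folders and empty files from LAYOUT under root.

    Existing files are never overwritten; returns the paths that were newly created.
    """
    root = Path(root)
    created = []
    for folder, files in LAYOUT.items():
        (root / folder).mkdir(parents=True, exist_ok=True)
        for name in files:
            path = root / folder / name
            if not path.exists():
                path.touch()
                created.append(path)
    return created
```

For example, scaffold("my_project") creates the seven files under my_project/ and returns their paths; running it again returns an empty list because existing files are left untouched.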