SAGNIKMJR/MetaQNN_ImageClassification_PyTorch
MetaQNN_ImageClassification_PyTorch

Implementation of MetaQNN (https://arxiv.org/abs/1611.02167, https://github.com/bowenbaker/metaqnn.git) in PyTorch for image classification, with additions and modifications.

Basic Search Space Specs:

i) Minimum number of Conv/WRN (WideResNet block) layers
ii) Maximum number of Conv/WRN layers
iii) At most one FC layer (the final classifier layer is not counted)
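The constraints above amount to a validity check on a candidate layer sequence. This is a hypothetical sketch, not the repository's code: layers are plain strings ("conv", "wrn", "fc"), `satisfies_search_space` and its parameters are invented for illustration, and Conv and WRN layers are assumed to count towards the same minimum and maximum.

```python
# Hypothetical sketch: layers are simple type strings such as "conv", "wrn", "fc".
# The real repository represents layers as state objects; these names are illustrative only.


def satisfies_search_space(layers, min_conv, max_conv, max_fc=1):
    """Return True if a layer sequence (classifier excluded) obeys the basic limits."""
    # Assumption: Conv and WRN layers share one combined min/max limit
    n_conv = sum(1 for layer in layers if layer in ("conv", "wrn"))
    n_fc = sum(1 for layer in layers if layer == "fc")
    return min_conv <= n_conv <= max_conv and n_fc <= max_fc


if __name__ == "__main__":
    print(satisfies_search_space(["conv", "wrn", "conv", "fc"], min_conv=2, max_conv=8))  # True
    print(satisfies_search_space(["conv", "fc", "fc"], min_conv=2, max_conv=8))  # False
```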

Additions/Modifications:

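i) Optional greedy version of the Q-learning update rule for shorter search schedules. To use it, swap the commented-out "modified update" lines in for the standard Q-learning updates in both methods below: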

# Excerpt: methods of the Q-learning class (they rely on self.qstore, self.enum and self.args)
def __update_q_value_sequence(self, states, termination_reward):
    # The final transition (into the terminal state) receives the termination reward
    self.__update_q_value(states[-2], states[-1], termination_reward)
    # Walk backwards through the remaining transitions of the sampled architecture
    for i in reversed(range(len(states) - 2)):
        # NOTE: standard Q-learning update (set a proper Q-learning rate in cmdparser.py)
        self.__update_q_value(states[i], states[i + 1], 0)

        # NOTE: modified update for shorter search schedules (does not use the Q-learning rate)
        # self.__update_q_value(states[i], states[i + 1], termination_reward)

def __update_q_value(self, start_state, to_state, reward):
    # Add unseen states (and their available actions) to the Q-table
    if start_state.as_tuple() not in self.qstore.q:
        self.enum.enumerate_state(start_state, self.qstore.q)
    if to_state.as_tuple() not in self.qstore.q:
        self.enum.enumerate_state(to_state, self.qstore.q)

    actions = self.qstore.q[start_state.as_tuple()]['actions']
    values = self.qstore.q[start_state.as_tuple()]['utilities']

    # A terminal state has no future value
    max_over_next_states = max(self.qstore.q[to_state.as_tuple()]['utilities']) if to_state.terminate != 1 else 0
    action_between_states = self.enum.transition_to_action(start_state, to_state).as_tuple()
    idx = actions.index(action_between_states)

    # NOTE: standard Q-learning update (set a proper Q-learning rate in cmdparser.py)
    # Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
    values[idx] = values[idx] + self.args.q_learning_rate * \
        (reward + self.args.q_discount_factor * max_over_next_states - values[idx])

    # NOTE: modified update for shorter search schedules (does not use the Q-learning rate);
    # it simplifies to values[idx] = max(reward, values[idx])
    # values[idx] = values[idx] + (max(reward, values[idx]) - values[idx])

    self.qstore.q[start_state.as_tuple()] = {'actions': actions, 'utilities': values}

ii) Skip connections via WideResNet blocks, minimum and maximum limits on the number of conv layers, and some other search space changes for better performance
iii) Resuming from the last Q-learning iteration if the code crashes during a run
iv) Running on a single GPU or on multiple GPUs
v) Automatic calculation of the available GPU memory, skipping an architecture if it does not fit
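The two update rules of item i) can be tried outside the search code with a self-contained toy version over a plain dict Q-table keyed by (state, action). The function names, the trajectory format and the hyperparameter values are hypothetical; unlike the original, the next state's actions are passed in explicitly instead of being enumerated, and unseen Q-values default to 0.0.

```python
# Toy sketch of the standard and greedy (modified) Q-learning updates.
# All names and values here are hypothetical, not the repository's API.
from collections import defaultdict


def q_update(q, s, a, s_next, next_actions, reward, lr, gamma, greedy=False):
    """Apply one update to q[(s, a)]; next_actions is empty for a terminal s_next."""
    if greedy:
        # Modified rule: keep the larger of the reward and the current estimate
        q[(s, a)] = max(reward, q[(s, a)])
        return
    max_next = max(q[(s_next, a2)] for a2 in next_actions) if next_actions else 0.0
    # Standard rule: Q(s,a) <- Q(s,a) + lr * (r + gamma * max_a' Q(s',a') - Q(s,a))
    q[(s, a)] += lr * (reward + gamma * max_next - q[(s, a)])


def update_sequence(q, trajectory, actions_of, termination_reward, lr, gamma, greedy=False):
    """trajectory is a list of (state, action, next_state); the last one ends the episode."""
    s, a, s_next = trajectory[-1]
    q_update(q, s, a, s_next, [], termination_reward, lr, gamma, greedy)
    for s, a, s_next in reversed(trajectory[:-1]):
        # Intermediate transitions get reward 0 (standard) or the final reward (greedy)
        r = termination_reward if greedy else 0.0
        q_update(q, s, a, s_next, actions_of[s_next], r, lr, gamma, greedy)


if __name__ == "__main__":
    actions_of = {"conv": ["stop"]}
    trajectory = [("start", "add_conv", "conv"), ("conv", "stop", "term")]
    q = defaultdict(float)  # unseen entries start at 0.0 in this toy
    update_sequence(q, trajectory, actions_of, 0.8, lr=0.1, gamma=1.0)
    print(dict(q))
```

With a small learning rate, the standard rule moves each Q-value only a step towards its target, and the termination reward reaches earlier transitions only through the discounted max term, so many episodes are needed to propagate it. The greedy rule writes the final reward into every visited transition at once (keeping the larger of the new reward and the old value), which is presumably why it suits shorter search schedules.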
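Resuming after a crash (item iii) boils down to persisting the search state after every finished iteration and reloading it at start-up. The sketch below is a generic pattern under assumptions: the whole state fits in one picklable dict, and the file name, dict keys and helper names are hypothetical rather than the repository's actual format. It is written for Python 3 (`os.replace`), unlike the Python 2.7 the repository targets.

```python
# Hypothetical sketch (Python 3): names, file location and dict keys are illustrative only.
import os
import pickle
import tempfile


def save_checkpoint(path, state):
    """Pickle `state` to a temp file next to `path`, then atomically replace `path`."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)
        os.replace(tmp_path, path)  # an interrupted write leaves the previous checkpoint intact
    except BaseException:
        os.remove(tmp_path)
        raise


def load_checkpoint(path):
    """Return the saved search state, or a fresh one if no checkpoint exists yet."""
    if not os.path.exists(path):
        return {"iteration": 0, "q_table": {}, "replay": []}
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    ckpt = os.path.join(tempfile.gettempdir(), "metaqnn_sketch_checkpoint.pkl")
    state = load_checkpoint(ckpt)  # resumes where a previous run stopped
    for it in range(state["iteration"], 3):
        state["replay"].append({"iteration": it, "accuracy": None})  # placeholder result
        state["iteration"] = it + 1
        save_checkpoint(ckpt, state)  # finished iterations survive a later crash
    print("completed iterations:", state["iteration"])
```

Writing to a temporary file and renaming it means a crash during saving can cost at most the iteration in progress, never the whole history.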

NOTE:

The code supports MNIST, CIFAR10 and CIFAR100; for other datasets, a dataloader has to be added to lib/Datasets/datasets.py.

Installing Code Dependencies:

pip install -r requirements.txt

Running Search:

Use Python 2.7 and PyTorch (torch) 0.4.0.
Look at lib/cmdparser.py for the available command line options, or just run

$ python main.py --help

Finally, run main.py with the desired command line options.

Plotting the Rolling Mean from the Replay Dictionary:

Look at plot/plotReplayDictRollingMean.py for the available command line options, or just run

$ python plot/plotReplayDictRollingMean.py --help

Finally, run plot/plotReplayDictRollingMean.py with the necessary command line arguments.
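A rolling mean smooths the noisy per-architecture accuracies so that the trend over the course of the search becomes visible. A minimal sketch of the computation, assuming the replay dictionary has already been reduced to a chronological list of accuracies; the function name and window size are hypothetical, and plotting is left out:

```python
# Hypothetical sketch: `values` is a chronological list of accuracies extracted
# from the replay dictionary; the plotting script's real options may differ.


def rolling_mean(values, window):
    """Mean of each trailing window of `window` values (a 'valid' moving average)."""
    if window <= 0:
        raise ValueError("window must be positive")
    means = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value that just left the window
        if i >= window - 1:
            means.append(running / window)
    return means


if __name__ == "__main__":
    accuracies = [0.50, 0.60, 0.70, 0.80, 0.90]
    print(rolling_mean(accuracies, window=2))
```

Keeping a running sum makes the computation O(n) regardless of the window size; the first `window - 1` entries produce no output because their window is incomplete.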
