"""
Name: func.py
Author: Yuxiang LI (li.yuxiang.nj@gmail.com)
Date: 25/03/2016

Description: GPU version of neural network (fast)
"""

from keras.models import model_from_json
from func import int_sqrt


def load_model(model, weights):
    """
    Load pre-trained Keras model

    Reads the architecture from a JSON description, loads the weights from
    an h5 file, and derives the input/output patch sizes from the first and
    last layer configurations.

    :param model: path to the json file for model description
    :param weights: path to the h5 file for model weights
    :return: dict with keys
        'patch_in'  -- int_sqrt of the first layer's input_shape[0]
        'patch_out' -- int_sqrt of the last layer's output_dim
        'predict'   -- the loaded model's bound ``predict`` method
    """
    # Use a context manager so the JSON file handle is closed promptly
    # instead of being left open until garbage collection.
    with open(model) as json_file:
        mod = model_from_json(json_file.read())
    mod.load_weights(weights)

    # Fetch the config once rather than rebuilding it for each lookup.
    # NOTE(review): assumes a config format where get_config() returns a dict
    # with a 'layers' list whose entries carry 'input_shape' / 'output_dim'
    # (older Keras Sequential models) -- confirm against the Keras version used.
    layers = mod.get_config()['layers']
    # Presumably input/output are flattened square patches, so int_sqrt
    # recovers the patch side length -- TODO confirm.
    return {'patch_in': int_sqrt(layers[0]['input_shape'][0]),
            'patch_out': int_sqrt(layers[-1]['output_dim']),
            'predict': mod.predict}



