from __future__ import absolute_import from . import backend as K from .utils.generic_utils import get_from_module class Constraint(object): def __call__(self, p): return p def get_config(self): return {'name': self.__class__.__name__} class MaxNorm(Constraint): """MaxNorm weight constraint. Constrains the weights incident to each hidden unit to have a norm less than or equal to a desired value. # Arguments m: the maximum norm for the incoming weights. axis: integer, axis along which to calculate weight norms. For instance, in a `Dense` layer the weight matrix has shape `(input_dim, output_dim)`, set `axis` to `0` to constrain each weight vector of length `(input_dim,)`. In a `Convolution2D` layer with `dim_ordering="tf"`, the weight tensor has shape `(rows, cols, input_depth, output_depth)`, set `axis` to `[0, 1, 2]` to constrain the weights of each filter tensor of size `(rows, cols, input_depth)`. # References - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting Srivastava, Hinton, et al. 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) """ def __init__(self, m=2, axis=0): self.m = m self.axis = axis def __call__(self, p): norms = K.sqrt(K.sum(K.square(p), axis=self.axis, keepdims=True)) desired = K.clip(norms, 0, self.m) p *= (desired / (K.epsilon() + norms)) return p def get_config(self): return {'name': self.__class__.__name__, 'm': self.m, 'axis': self.axis} class NonNeg(Constraint): """Constrains the weights to be non-negative. """ def __call__(self, p): p *= K.cast(p >= 0., K.floatx()) return p class UnitNorm(Constraint): """Constrains the weights incident to each hidden unit to have unit norm. # Arguments axis: integer, axis along which to calculate weight norms. For instance, in a `Dense` layer the weight matrix has shape `(input_dim, output_dim)`, set `axis` to `0` to constrain each weight vector of length `(input_dim,)`. In a `Convolution2D` layer with `dim_ordering="tf"`, the weight tensor has shape `(rows, cols, input_depth, output_depth)`, set `axis` to `[0, 1, 2]` to constrain the weights of each filter tensor of size `(rows, cols, input_depth)`. """ def __init__(self, axis=0): self.axis = axis def __call__(self, p): return p / (K.epsilon() + K.sqrt(K.sum(K.square(p), axis=self.axis, keepdims=True))) def get_config(self): return {'name': self.__class__.__name__, 'axis': self.axis} class MinMaxNorm(Constraint): """MinMaxNorm weight constraint. Constrains the weights incident to each hidden unit to have the norm between a lower bound and an upper bound. # Arguments low: the minimum norm for the incoming weights. high: the maximum norm for the incoming weights. rate: rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate * norm.clip(low, high). Effectively, this means that rate=1.0 stands for strict enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step to slowly move towards a value inside the desired interval. axis: integer, axis along which to calculate weight norms. For instance, in a `Dense` layer the weight matrix has shape `(input_dim, output_dim)`, set `axis` to `0` to constrain each weight vector of length `(input_dim,)`. In a `Convolution2D` layer with `dim_ordering="tf"`, the weight tensor has shape `(rows, cols, input_depth, output_depth)`, set `axis` to `[0, 1, 2]` to constrain the weights of each filter tensor of size `(rows, cols, input_depth)`. """ def __init__(self, low=0.0, high=1.0, rate=1.0, axis=0): self.low = low self.high = high self.rate = rate self.axis = axis def __call__(self, p): norms = K.sqrt(K.sum(K.square(p), axis=self.axis, keepdims=True)) desired = self.rate * K.clip(norms, self.low, self.high) + (1 - self.rate) * norms p *= (desired / (K.epsilon() + norms)) return p def get_config(self): return {'name': self.__class__.__name__, 'low': self.low, 'high': self.high, 'rate': self.rate, 'axis': self.axis} # Aliases. maxnorm = MaxNorm nonneg = NonNeg unitnorm = UnitNorm def get(identifier, kwargs=None): return get_from_module(identifier, globals(), 'constraint', instantiate=True, kwargs=kwargs)