45 lines
1.1 KiB
Python
45 lines
1.1 KiB
Python
|
import numpy as np
|
||
|
|
||
|
import tensorflow as tf
|
||
|
from tensorflow import keras
|
||
|
from tensorflow.keras import backend as K
|
||
|
|
||
|
|
||
|
from disttf import GaussianRealisation,BiasedRealisation,BoxRealisation
|
||
|
from disttf import MixtureLayer, MotioLayer, ScaleLayer, SplitLayer
|
||
|
from disttf import NonLinearityLayer, SeperateLayer, RecombineLayer
|
||
|
|
||
|
|
||
|
def gen_model(inputs, splits=10, realisation="gauss", mixture=0, nonlin=False):
    """Build a Keras functional model from the ``disttf`` layers.

    The model takes one input of shape ``(inputs,)`` (batch axis implicit).
    A pair ``(q, v)`` is threaded through the layers. ``q`` starts as the
    model input and ``v`` starts as the scalar constant ``1.0``. The pair
    passes through:

    1. ``SplitLayer``, only when ``splits > 1``.
    2. ``MotioLayer``.
    3. ``ScaleLayer``.

    ``q`` alone then passes through ``mixture`` ``MixtureLayer`` instances,
    each optionally followed by a ``NonLinearityLayer``. Finally, the chosen
    realisation layer consumes ``(q, v)`` and produces the model output.

    A loss of ``-mean(log(|output| + 1e-6))`` is attached with
    ``model.add_loss``. The model is not compiled here.

    Args:
        inputs: Length of the per-sample input vector.
        splits: Forwarded to ``SplitLayer``. The split step is skipped when
            ``splits <= 1``.
        realisation: Output layer to use. One of ``"gauss"``
            (``GaussianRealisation``), ``"biased"`` (``BiasedRealisation``)
            or ``"box"`` (``BoxRealisation``).
        mixture: Number of ``MixtureLayer`` instances to stack.
        nonlin: If true, apply a ``NonLinearityLayer`` after each mixture
            layer.

    Returns:
        The uncompiled ``keras.Model``, with the loss already attached.

    Raises:
        ValueError: If ``realisation`` is not one of the supported names.
            This is checked before any layer is built.
    """
    # Validate first so a bad name fails before any layer is built.
    # The f-string also means a non-string value (e.g. None) raises
    # ValueError, not a TypeError from str concatenation.
    valid_realisations = ("gauss", "biased", "box")
    if realisation not in valid_realisations:
        raise ValueError(f"Unknown realisation type: {realisation}")

    i = keras.layers.Input(shape=(inputs,))
    q = i
    v = tf.constant(1.0, dtype=tf.float32)

    # NOTE(review): indentation was lost in the source this was recovered
    # from. Two nesting choices here are assumptions; confirm against the
    # original repository:
    #   - MotioLayer/ScaleLayer run unconditionally (not only when splits > 1).
    #   - NonLinearityLayer sits inside the mixture loop.
    if splits > 1:
        q, v = SplitLayer(splits)([q, v])
    q, v = MotioLayer()([q, v])
    q, v = ScaleLayer()([q, v])

    for _ in range(mixture):
        q = MixtureLayer()(q)
        if nonlin:
            q = NonLinearityLayer()(q)

    realisation_layers = {
        "gauss": GaussianRealisation,
        "biased": BiasedRealisation,
        "box": BoxRealisation,
    }
    q = realisation_layers[realisation]()([q, v])

    model = keras.Model(inputs=i, outputs=q)

    # Negative mean log-magnitude of the output.
    # The 1e-6 keeps log() finite when an output is exactly zero.
    loss = -K.mean(K.log(K.abs(q) + 1e-6))
    model.add_loss(loss)

    return model
|
||
|
|
||
|
|