diff --git a/stroo.py b/stroo.py index 835bca0..8b66a99 100644 --- a/stroo.py +++ b/stroo.py @@ -69,23 +69,24 @@ def train_model(data,n=3): ngrams=multigramm(data,n=n) data=ngramtrafo(data,ngrams,n=n) + pm=0.0 + while pm**2<0.0001: + #tensorflow stuff (not at all optimised) + inp=keras.Input(data.shape[1:]) + q=inp + q=keras.layers.Dense(10,activation="relu",use_bias=False)(q) + q=keras.layers.Dense(4,activation="relu",use_bias=False)(q) + q=keras.layers.Dense(1,activation="relu",use_bias=False)(q) - #tensorflow stuff (not at all optimised) - inp=keras.Input(data.shape[1:]) - q=inp - q=keras.layers.Dense(10,activation="relu",use_bias=False)(q) - q=keras.layers.Dense(4,activation="relu",use_bias=False)(q) - q=keras.layers.Dense(1,activation="relu",use_bias=False)(q) + model=keras.models.Model(inp,q) - model=keras.models.Model(inp,q) - - model.compile("adam","mse") - model.fit(data,np.ones(len(data),dtype="float"), + model.compile("adam","mse") + model.fit(data,np.ones(len(data),dtype="float"), batch_size=100, epochs=50, validation_split=0.1) - pm=np.mean(model.predict(data)) + pm=np.mean(model.predict(data)) return stroo(model,ngrams,n=n,m=pm)