python – Tensorflow. Nonlinear regression


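I have these features and labels, and they are not linear enough for a linear solution to do the job. I trained an SVR(kernel='rbf') model from sklearn, but now it's time to do it with TensorFlow, and I find it hard to tell what I should write to get the same or a better result.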

[Figure: the fitted curves plotted against the data]

Do you see that lazy orange line there? It does not fill you with determination.

The code itself:

import pandas as pd
import numpy as np
import tensorflow as tf
import tqdm
import matplotlib.pyplot as plt
from omnicomm_data.test_data import get_model,clean_df
import os
from sklearn import preprocessing

graph = tf.get_default_graph()

# tf variables
x_ = tf.placeholder(name="input",shape=[None,1],dtype=np.float32)
y_ = tf.placeholder(name="output",dtype=np.float32)
w = tf.Variable(tf.random_normal([]),name='weight')
b = tf.Variable(tf.random_normal([]),name='bias')
lin_model = tf.add(tf.multiply(x_,w),b)

#loss
loss = tf.reduce_mean(tf.pow(lin_model - y_,2),name='loss')
tf.summary.scalar('loss',loss)  # logged to TensorBoard in the training loop
summaries = tf.summary.merge_all()
train_step = tf.train.GradientDescentOptimizer(0.000000025).minimize(loss)

#nonlinear part: tanh applied to the linear model
nonlin_model = tf.tanh(tf.add(tf.multiply(x_,w),b))
nonlin_loss = tf.reduce_mean(tf.pow(nonlin_model - y_,2),name='cost')
train_step_nonlin = tf.train.GradientDescentOptimizer(0.000000025).minimize(nonlin_loss)


# pandas data
df_train = pd.read_csv('me_rate.csv',header=None)

liters = df_train.iloc[:,0].values.reshape(-1,1)
parrots = df_train.iloc[:,1].values.reshape(-1,1)

#model for prediction
mms = preprocessing.MinMaxScaler()
rbf = get_model(path_to_model)  # the trained sklearn SVR(kernel='rbf'); path_to_model is defined elsewhere


n_epochs = 200
train_errors = []
non_train_errors = []
test_errors = []

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    summary_writer = tf.summary.FileWriter('./logs',sess.graph)  # TensorBoard log directory (placeholder path)
    for i in tqdm.tqdm(range(n_epochs)):

        _,train_err,summ = sess.run([train_step,loss,summaries],feed_dict={x_: parrots,y_: liters})
        summary_writer.add_summary(summ,i)
        train_errors.append(train_err)

        _,non_train_err = sess.run([train_step_nonlin,nonlin_loss],feed_dict={x_: parrots,y_: liters})
        non_train_errors.append(non_train_err)


    plt.plot(list(range(n_epochs)),train_errors,label='train_lin')
    plt.plot(list(range(n_epochs)),non_train_errors,label='train_nonlin')
    plt.legend()
    print(train_errors[:10])
    print(non_train_errors[:10])
    plt.show()

    plt.scatter(parrots,liters,label='actual data')
    plt.plot(parrots,sess.run(lin_model,feed_dict={x_: parrots}),label='linear (tf)')
    plt.plot(parrots,sess.run(nonlin_model,feed_dict={x_: parrots}),label='nonlinear (tf)')
    plt.plot(parrots,rbf.predict(mms.fit_transform(parrots)),label='rbf (sklearn)')
    plt.legend()
    plt.show()
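Both training steps above use `tf.train.GradientDescentOptimizer`, which applies plain gradient descent: each parameter moves by `-learning_rate * gradient`. As a rough way to see what one step does at the tiny learning rate chosen here, the following NumPy sketch (not TensorFlow) writes the update out by hand for the linear model `w*x + b` and its mean squared error. The data is made up, since `me_rate.csv` is not available, and `mse_grads`/`gd_step` are illustrative helpers, not part of any library.

```python
import numpy as np

def mse_grads(x, y, w, b):
    """Loss mean((w*x + b - y)**2) and its gradients w.r.t. w and b."""
    err = w * x + b - y
    loss = np.mean(err ** 2)
    dw = np.mean(2.0 * err * x)
    db = np.mean(2.0 * err)
    return loss, dw, db

def gd_step(x, y, w, b, lr):
    """One plain gradient-descent update; returns new w, new b and the loss before the step."""
    loss, dw, db = mse_grads(x, y, w, b)
    return w - lr * dw, b - lr * db, loss

# Made-up data standing in for parrots (x) and liters (y)
rng = np.random.default_rng(0)
x = rng.uniform(0, 100, size=(50, 1))
y = 0.5 * x + 3.0

w, b = 0.0, 0.0
for lr in (2.5e-8, 1e-4):
    w1, b1, loss = gd_step(x, y, w, b, lr)
    print(f"lr={lr:g}: loss={loss:.2f}, w {w} -> {w1:.3g}, b {b} -> {b1:.3g}")
```

Comparing the parameter changes for the two learning rates shows how the step size scales with `lr`; with unscaled inputs, the range of learning rates that still converges also depends on the magnitude of `x`.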

How do I motivate the orange line?

Update, after the changes.

Here is the code:

import pandas as pd
import numpy as np
import tensorflow as tf
import tqdm
import matplotlib.pyplot as plt
from omnicomm_data.test_data import get_model
import os
from sklearn import preprocessing

graph = tf.get_default_graph()

# tf variables
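x_ = tf.placeholder(name="input",shape=[None,1],dtype=np.float32)
y_ = tf.placeholder(name="output",dtype=np.float32)
w = tf.Variable(tf.random_normal([]),name='weight')
b = tf.Variable(tf.random_normal([]),name='bias')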

# nonlinear: tanh of the input, then scaled by w and shifted by b
nonlin_model = tf.add(tf.multiply(tf.tanh(x_),w),b)
nonlin_loss = tf.reduce_mean(tf.pow(nonlin_model - y_,2),name='cost')
train_step_nonlin = tf.train.GradientDescentOptimizer(0.01).minimize(nonlin_loss)


# pandas data
df_train = pd.read_csv('me_rate.csv',header=None)


liters = df_train.iloc[:,0].values.reshape(-1,1)
parrots = df_train.iloc[:,1].values.reshape(-1,1)


#model for prediction
mms = preprocessing.MinMaxScaler()
rbf = get_model(path_to_model)  # the trained sklearn SVR(kernel='rbf'); path_to_model is defined elsewhere


nz = preprocessing.MaxAbsScaler()  # normalize the input because of tanh
norm_parrots = nz.fit_transform(parrots)
print(norm_parrots)

n_epochs = 20000
train_errors = []
non_train_errors = []
test_errors = []
weights = []
biases = []

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in tqdm.tqdm(range(n_epochs)):

        _,non_train_err,weight,bias = sess.run([train_step_nonlin,nonlin_loss,w,b],feed_dict={x_: norm_parrots,y_: liters})
        non_train_errors.append(non_train_err)
        weights.append(weight)
        biases.append(bias)


    plt.scatter(norm_parrots,liters,label='actual data')

    plt.plot(norm_parrots,sess.run(nonlin_model,feed_dict={x_: norm_parrots}),c='orange',label='nonlinear (tf)')
    plt.plot(norm_parrots,rbf.predict(mms.fit_transform(parrots)),label='rbf (sklearn)')
    plt.legend()
    plt.show()

[Figure: actual data with the nonlinear (tf) and rbf (sklearn) fits]


As you can clearly see, the orange line has improved somewhat (it is not as good as the rbf yet, but it just needs more work).
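The updated model applies `tanh` to the normalized input and then scales and shifts the result. The NumPy sketch below mirrors that setup outside TensorFlow: divide by the maximum absolute value (which is what `MaxAbsScaler` does per feature), then fit `w * tanh(x) + b` by full-batch gradient descent with the same learning rate (0.01) and epoch count (20000). The curved data is synthetic and the helper names `max_abs_scale` and `fit_scaled_tanh` are hypothetical. Because `tanh(x)` is fixed here, the model is linear in `w` and `b`, so the gradient-descent result can be checked against an ordinary least-squares fit.

```python
import numpy as np

def max_abs_scale(x):
    """Scale each column by its maximum absolute value (what MaxAbsScaler does)."""
    return x / np.abs(x).max(axis=0)

def fit_scaled_tanh(x, y, lr=0.01, n_epochs=20000, seed=0):
    """Fit y ~ w * tanh(x) + b by full-batch gradient descent on the MSE."""
    rng = np.random.default_rng(seed)
    w, b = rng.normal(), rng.normal()
    t = np.tanh(x)  # the input transform is fixed; only w and b are trained
    losses = []
    for _ in range(n_epochs):
        err = w * t + b - y
        losses.append(float(np.mean(err ** 2)))
        w -= lr * np.mean(2.0 * err * t)
        b -= lr * np.mean(2.0 * err)
    return w, b, losses

# Made-up curved data standing in for parrots (x) and liters (y)
rng = np.random.default_rng(1)
parrots = np.sort(rng.uniform(0, 4000, size=(200, 1)), axis=0)
liters = 60 * np.tanh(parrots / 2000) + rng.normal(0, 1, size=parrots.shape)

norm_parrots = max_abs_scale(parrots)
w, b, losses = fit_scaled_tanh(norm_parrots, liters)
print(f"w={w:.2f}, b={b:.2f}, first loss={losses[0]:.1f}, last loss={losses[-1]:.2f}")
```

This also suggests why the orange line likely still trails the rbf curve: with only `w` and `b` trainable, the model can stretch and shift a single fixed `tanh` shape vertically, but it cannot change where that shape bends.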
Best answer
You are using tf.tanh as the activation, which means your output is limited to the range [-1, 1]. Therefore it will never fit your data.
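A quick NumPy check of this point: whatever `w` and `b` are, `tanh(w*x + b)` stays within [-1, 1], while putting a scale and shift after `tanh` (as the updated code in the question does with `tf.tanh(x_)` followed by a multiply and an add) lets the output cover [b - |w|, b + |w|]. The input range and parameter values below are arbitrary examples.

```python
import numpy as np

x = np.linspace(-1000.0, 1000.0, 2001)  # wide input range, like unscaled data

# tanh as the final op: output stays within [-1, 1] whatever w and b are
for w, b in [(1.0, 0.0), (50.0, -3.0), (0.001, 10.0)]:
    out = np.tanh(w * x + b)
    print(f"tanh({w}*x + {b}) range: [{out.min():.3f}, {out.max():.3f}]")

# Scaling and shifting after tanh: output spans [b - |w|, b + |w|]
w, b = 60.0, 5.0
out = w * np.tanh(x) + b
print(f"{w}*tanh(x) + {b} range: [{out.min():.1f}, {out.max():.1f}]")
```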

Edit: I removed a part pointing out a typo, since it has already been fixed.
