I have the following information (a data frame) in Python:
product  baskets  scaling_factor
12345    475      95.5
12345    108      57.7
12345    2        1.4
12345    38       21.9
12345    320      88.8
I want to run the following nonlinear regression and estimate the parameters a, b and c.
The equation I want to fit:
scaling_factor = a - (b*np.exp(c*baskets))
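For reference, here is the model written out together with its partial derivatives. A Gauss-Newton solver uses these derivatives to linearize the model at each step. The notation is mine and is not from the question: x_i are the baskets, y_i the scaling factors, θ = (a, b, c), and J is the Jacobian with J_ij = ∂f(x_i; θ)/∂θ_j.

```latex
\begin{aligned}
f(x; a, b, c) &= a - b\,e^{c x} \\
\frac{\partial f}{\partial a} &= 1, \qquad
\frac{\partial f}{\partial b} = -e^{c x}, \qquad
\frac{\partial f}{\partial c} = -b\,x\,e^{c x} \\
J^{\top} J \,\Delta\theta &= J^{\top} r, \qquad r_i = y_i - f(x_i; \theta), \qquad \theta \leftarrow \theta + \Delta\theta
\end{aligned}
```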
In SAS we usually run the following model (using the Gauss-Newton method):
proc nlin data=scaling_factors;
   parms a=100 b=100 c=-0.09;
   model scaling_factor = a - (b * (exp(c*baskets)));
   output out=scaling_equation_parms parms=a b c;
run;
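To compare with the SAS Gauss-Newton fit, here is a hand-rolled Gauss-Newton sketch in NumPy. It is only an illustration. scipy's curve_fit does not work this way: it defaults to Levenberg-Marquardt when no bounds are given. The function names (model, jacobian, gauss_newton), the step-halving rule and the stopping tolerance are my own hypothetical choices. The starting values are the answer's initial guess, not the SAS ones.

```python
import numpy as np


def model(x, a, b, c):
    """Model from the question: scaling_factor = a - b * exp(c * baskets)."""
    return a - b * np.exp(c * x)


def jacobian(x, a, b, c):
    """Columns are df/da, df/db and df/dc evaluated at every x."""
    e = np.exp(c * x)
    return np.column_stack([np.ones_like(e), -e, -b * x * e])


def gauss_newton(x, y, theta0, max_iter=100, tol=1e-10):
    """Hypothetical Gauss-Newton with step halving; returns (a, b, c)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    theta = np.asarray(theta0, dtype=float)
    sse = np.sum((y - model(x, *theta)) ** 2)
    for _iteration in range(max_iter):
        r = y - model(x, *theta)  # residuals at the current parameters
        J = jacobian(x, *theta)
        # Solve the linearized least-squares problem J @ delta ~= r.
        delta, *_ = np.linalg.lstsq(J, r, rcond=None)
        step = 1.0
        # Halve the step until the sum of squares decreases (at most 30 times).
        for _halving in range(30):
            candidate = theta + step * delta
            new_sse = np.sum((y - model(x, *candidate)) ** 2)
            if new_sse < sse:
                break
            step /= 2.0
        else:
            break  # no improving step found: treat as converged
        converged = sse - new_sse <= tol * (sse + tol)
        theta, sse = candidate, new_sse
        if converged:
            break
    return theta


baskets = np.array([475, 108, 2, 38, 320])
scaling_factor = np.array([95.5, 57.7, 1.4, 21.9, 88.8])
theta = gauss_newton(baskets, scaling_factor, [100.0, 100.0, -0.01])
print("a, b, c =", theta)
```

The step halving guarantees that the sum of squares never increases from one iteration to the next. Plain Gauss-Newton gives no such guarantee when the starting values are poor.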
Is there a similar way to estimate the parameters in Python using nonlinear regression, and how can I see the plot in Python?
Solution
Agreeing with Chris Mueller, I would also use scipy, specifically scipy.optimize.curve_fit.
The code reads:
### the two commented-out lines were required on my Linux machine;
### the Qt4Agg backend is not available in recent matplotlib versions, so pick another backend if you need one
# import matplotlib
# matplotlib.use('Qt4Agg')
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit  # we could import more, but this is what we need

### defining your fit function
def func(x, a, b, c):
    return a - b * np.exp(c * x)

### OP's data
baskets = np.array([475, 108, 2, 38, 320])
scaling_factor = np.array([95.5, 57.7, 1.4, 21.9, 88.8])

### let us guess some start values
initialGuess = [100, 100, -.01]
guessedFactors = [func(x, *initialGuess) for x in baskets]

### making the actual fit
popt, pcov = curve_fit(func, baskets, scaling_factor, p0=initialGuess)
# one may want to print popt as well
print(pcov)

### preparing data for showing the fit
basketCont = np.linspace(min(baskets), max(baskets), 50)
fittedData = [func(x, *popt) for x in basketCont]

### preparing the figure
fig1 = plt.figure(1)
ax = fig1.add_subplot(1, 1, 1)

### the three sets of data to plot (markers only for the unsorted data points)
ax.plot(baskets, scaling_factor, linestyle='', marker='o', color='r', label="data")
ax.plot(baskets, guessedFactors, linestyle='', marker='^', color='b', label="initial guess")
ax.plot(basketCont, fittedData, linestyle='-', color='#900000',
        label="fit with ({0:0.2g},{1:0.2g},{2:0.2g})".format(*popt))

### beautification
ax.legend(loc=0, title="graphs", fontsize=12)
ax.set_ylabel("factor")
ax.set_xlabel("baskets")
ax.grid()
ax.set_title(r"$\mathrm{curve}_\mathrm{fit}$")

### putting the covariance matrix nicely
tab = [['{:.2g}'.format(j) for j in i] for i in pcov]
the_table = plt.table(cellText=tab, colWidths=[0.2] * 3, loc='upper right',
                      bbox=[0.483, 0.35, 0.5, 0.25])
plt.text(250, 65, 'covariance:', size=12)

### putting the plot
plt.show()
### done
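SAS prints approximate standard errors for the parameters. With curve_fit, comparable one-sigma errors can be read off the square roots of the diagonal of pcov. The sketch below also computes R² for the fit. These are asymptotic approximations, and with 5 points and 3 parameters (2 degrees of freedom) they should be read cautiously. If the data sit in a pandas DataFrame df, the arrays can be obtained with df['baskets'].to_numpy() and df['scaling_factor'].to_numpy(). The names perr and r_squared are just illustrative.

```python
import numpy as np
from scipy.optimize import curve_fit


def func(x, a, b, c):
    # Same model as in the question: a - b * exp(c * x)
    return a - b * np.exp(c * x)


baskets = np.array([475, 108, 2, 38, 320])
scaling_factor = np.array([95.5, 57.7, 1.4, 21.9, 88.8])
initial_guess = [100, 100, -0.01]

popt, pcov = curve_fit(func, baskets, scaling_factor, p0=initial_guess)

# One-sigma standard errors of a, b, c from the diagonal of the covariance matrix.
perr = np.sqrt(np.diag(pcov))

# Coefficient of determination (R^2) of the fitted curve.
residuals = scaling_factor - func(baskets, *popt)
ss_res = np.sum(residuals ** 2)
ss_tot = np.sum((scaling_factor - scaling_factor.mean()) ** 2)
r_squared = 1.0 - ss_res / ss_tot

for name, value, err in zip("abc", popt, perr):
    print(f"{name} = {value:.4g} +/- {err:.2g}")
print(f"R^2 = {r_squared:.4f}")
```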