我有一组数据,我用kind =’cubic’进行插值.
我想找到这个三次插值函数的最大值.
目前我所做的只是找到插值数据数组中的最大值,但我想知道作为对象的插值函数是否可以区分以找到它的极值?
码:
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
x_axis = np.array([ 2.14414414,2.15270826,2.16127238,2.1698365,2.17840062,2.18696474,2.19552886,2.20409298,2.2126571,2.22122122])
y_axis = np.array([ 0.67958442,0.89628424,0.78904004,3.93404167,6.46422317,6.40459954,3.80216674,0.69641825,0.89675386,0.64274198])
f = interp1d(x_axis,y_axis,kind = 'cubic')
x_new = np.linspace(x_axis[0],x_axis[-1],100)
fig = plt.subplots()
plt.plot(x_new,f(x_new))
>使用4度样条进行插值,以便可以轻松找到其导数的根.
>使用三次样条(通常更可取),并为其派生的根编写自定义函数.
我在下面介绍两种解决方案
4度样条
使用InterpolatedUnivariateSpline.which有.derivative方法返回一个三次样条,可以应用.roots方法.
from scipy.interpolate import InterpolatedUnivariateSpline
f = InterpolatedUnivariateSpline(x_axis,k=4)
cr_pts = f.derivative().roots()
cr_pts = np.append(cr_pts,(x_axis[0],x_axis[-1])) # also check the endpoints of the interval
cr_vals = f(cr_pts)
min_index = np.argmin(cr_vals)
max_index = np.argmax(cr_vals)
print("Maximum value {} at {}\nMinimum value {} at {}".format(cr_vals[max_index],cr_pts[max_index],cr_vals[min_index],cr_pts[min_index]))
输出:
Maximum value 6.779687224066201 at 2.1824928509277037
Minimum value 0.34588448400295346 at 2.2075868177297036
立方样条
我们需要一个二次样条曲线根的自定义函数.这是(在下面解释).
def quadratic_spline_roots(spl):
roots = []
knots = spl.get_knots()
for a,b in zip(knots[:-1],knots[1:]):
u,v,w = spl(a),spl((a+b)/2),spl(b)
t = np.roots([u+w-2*v,w-u,2*v])
t = t[np.isreal(t) & (np.abs(t) <= 1)]
roots.extend(t*(b-a)/2 + (b+a)/2)
return np.array(roots)
现在完全如上所述,除了使用自定义解算器.
from scipy.interpolate import InterpolatedUnivariateSpline
f = InterpolatedUnivariateSpline(x_axis,k=4)
cr_pts = quadratic_spline_roots(f.derivative())
cr_pts = np.append(cr_pts,cr_pts[min_index]))
输出:
Maximum value 6.782781181150518 at 2.1824928579767167
Minimum value 0.45017143148176136 at 2.2070746522580795
第一种方法中输出的轻微差异不是错误; 4度样条和3度样条有点不同.
quadratic_spline_roots的说明
假设我们知道-1,1处的二次多项式的值是u,w.它在[-1,1]区间的根源是什么?通过一些代数,我们可以发现多项式是
((u+w-2*v) * x**2 + (w-u) * x + 2*v) / 2
现在可以使用二次公式,但最好使用np.roots,因为它还将处理前导系数为零的情况.然后将根过滤到-1到1之间的实数.最后,如果间隔是[a,b]而不是[-1,1],则进行线性变换.
额外:中间范围的三次样条宽度
假设我们想要找到样条曲线取值等于其最大值和最小值(即其中间值)的平均值的位置.那么我们肯定应该使用三次样条进行插值,因为现在需要根方法.人们不能只做(f – mid_range).roots(),因为在SciPy中不支持向样条添加常量.而是从y_axis – mid_range构建一个向下移动的样条曲线.
mid_range = (cr_vals[max_index] + cr_vals[min_index])/2
f_shifted = InterpolatedUnivariateSpline(x_axis,y_axis - mid_range,k=3)
roots = f_shifted.roots()
print("Mid-range attained from {} to {}".format(roots.min(),roots.max()))
Mid-range attained from 2.169076230034363 to 2.195974299834667