前言
首先简要回顾一下SVM算法:如下图所示,寻找一个超平面划分数据,使得两类数据到超平面的距离均大余等于1/||w||。
其数学描述为:
正则化系数C
实际使用时,并不能保证所有数据被完美划分,例如在-例中混杂了一个+例,标准的SVM就无法求解,如下图所示。
另外,即使数据可以完美划分开,还需要考虑过拟合的问题,即如下图所示,如果完全拟合训练数据,非要将所有正例和反例分开,则分界线为橙色,当需要识别新样本时(紫色x),明显它更符合负例,但是会被划分为正例,即发生过拟合了。
有关机器学习中过拟合及解决方法的问题可以参考:机器学习中的偏差和方差
当ξ=0时,限制条件与以前相同,即划分正确;当0<ξ<1时,样本距离超平面距离小于1/||w||,但此时会在代价函数中有惩罚;当ξ>1时,代表样本越过超平面,划分错误,惩罚会更大。此时模型(参数w)由边界上的样本点及ξ!=0的样本点共同决定。
本实例绘制了在不同正则化系数C下,SVM选择最佳分界线。随机生成160个数据,80个正例,80个反例。选择两个C,为10和0.05,分别代表非正则化/正则化。实际上标准SVM对应C=无穷。划分结果如下,画圈的代表支持向量:
从划分结果可以看出,当C=0.05时,此时允许训练误差更大,支持向量更多,同时正负例的间隔也更大。
import numpy as np import matplotlib.pyplot as plt from sklearn import svm # we create 160 separable points np.random.seed(0) X = np.r_[np.random.randn(80,2) - [2,2],np.random.randn(80,2) + [2,2]] Y = [0] * 80 + [1] * 80 # figure number fignum = 1 # fit the model for name,penalty in (('unreg',10),('reg',0.05)): clf = svm.SVC(kernel='linear',C=penalty) #线性核函数 clf.fit(X,Y) # get the separating hyperplane w = clf.coef_[0] a = -w[0] / w[1] xx = np.linspace(-5,5) yy = a * xx - (clf.intercept_[0]) / w[1] #超平面 # plot the parallels to the separating hyperplane that pass through the # support vectors margin = 1 / np.sqrt(np.sum(clf.coef_ ** 2)) yy_down = yy + a * margin #下平面 yy_up = yy - a * margin #上平面 # plot the line,the points,and the nearest vectors to the plane plt.figure(fignum,figsize=(4,3)) plt.clf() plt.plot(xx,yy,'k-') #实线 plt.plot(xx,yy_down,'k--') #虚线 plt.plot(xx,yy_up,'k--') plt.scatter(clf.support_vectors_[:,0],clf.support_vectors_[:,1],s=80,facecolors='none',zorder=10) #标注支持向量 plt.scatter(X[:,X[:,c=Y,zorder=10,cmap=plt.cm.Paired) #绘点 plt.axis('tight') x_min = -4.8 x_max = 4.2 y_min = -6 y_max = 6 XX,YY = np.mgrid[x_min:x_max:200j,y_min:y_max:200j] Z = clf.predict(np.c_[XX.ravel(),YY.ravel()]) # Put the result into a color plot Z = Z.reshape(XX.shape) plt.figure(fignum,3)) plt.pcolormesh(XX,YY,Z,cmap=plt.cm.Paired) plt.xlim(x_min,x_max) plt.ylim(y_min,y_max) title = 'C = ' + str(penalty) plt.title(title) plt.xticks(()) plt.yticks(()) fignum = fignum + 1 plt.show()