Alink漫谈(二十) :卡方检验源码解析
0x00 摘要
Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将带领大家来分析 Alink 中 卡方检验 的实现。
因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。
0x01 背景概念
问题:在南方的小明要去北京读书,小明的父母都担心起来,北方的生活能习惯吗?是不是只能吃面?他们的脑子里面提出了很多很多的假设,是不是都要验证一下呢。
1.1 假设检验
统计假设是指我们对总体的猜测或判断,比如中国广为流行的地域划分,北方人是不是都吃面食?江浙是不是都喜欢吃甜食?江南的妹子是不是脾气都很好?北方的男生是不是都大条?
为了证明,或者说绝对确定某个假设是正确或错误的,我们需要绝对知识,也就是说我们需要检查所有的样本,比如问问所有的北方人是不是都吃面食,或者说问问江浙一带的女生是不是都不吵架?
但是我们无法把所有样本都统计一遍。
假设检验就是为了解决这个问题而诞生的,它使用随机的样本来判断假设是否有理。
这来源于一个最基本的思路,那就是统计是从有限样本推断总体。既然我们不可能知道所有样本,换句话说我们不可能知道总体到底是什么统计性质,那么我们只能拿有限的样本做文章。
假设检验依据的原理是小概率事件原理。小概率事件是一个事件的发生概率,由于概率小,它在一次试验中是几乎不可能发生的。例如,拿我们生活的经验来说,只买一次彩票就中大奖的几率是很小的,所以还没看到过偶尔买彩票就中大奖的报道。
1.2 H0和H1是什么?
小明父母的种种假设到底对不对呢?对是一种结果,不对是另一种结果。在假设检验中,我们也有两部分,一部分叫做H0,一部分是H1。
H是英文单词hypothesis的第一个字母,H0代表了原假设(null hypothesis),H1是备择假设(alternative hypothesis)。换句话说,我们想在两个假设中选一个,尽管这个“选”不是通常意义下的“二选一”。
- H0 原假设又叫零假设,一般来说,我们把认为想收集证据反对的假设称为0假设,比如太阳绕着地球转,鸟是不会飞的。小明这里的 H0 就是 “北方不吃米饭”,是小明父母不希望,想拒绝的。
- H1又叫备选假设,一般来说,我们都希望备选假设,也就是H1为真,比如小明的父母希望北方也是吃饭的。
H0和H1并不是不能互换,但是选取的H0原则是H0必须是一个可以被拒绝的假设。对于一个假设检验问题,如果H1是一个不能被拒绝的假设,那么H0和H1不能互换。
什么样的H1是不能被拒绝的假设?比如下面这个问题:总体是一个班级的所有同学一次的考试成绩,原假设H0是全班同学的平均成绩为80分,H1是全班同学的平均成绩不是80分。在这个框架中,H0是一个可以被拒绝的假设,H1是一个不能被拒绝的假设。
为什么这么说?如果H0选做全班同学的平均成绩不是80分,那么哪怕你知道全班几乎所有人的成绩,不妨认为平均是80分,只要有一个人的成绩你不知道,那么知道其他人的成绩对于你拒绝H0没有任何帮助。只有在这个人恰好80分的情况下,原假设是不成立的,这个人是81或者79,原假设都成立。从另一个角度,也可以认为当H0是一个不能被拒绝的假设的时候,H0太过宽泛,和H1区分不了。
一般来说,假设检验的结果有下面两个:
- H0错,H1对(用专业的话来说,就是接受H1,拒绝H0,因为有足够多的样本支持H1,比如说小明爸妈问了5,6个到北方读书的人,都说食堂面和饭都有。他们大可以说,北方也是吃饭的)。
- H1错(用专业的话说,不拒绝H0,因为证据不够。小明爸妈问了5个到北方读书的人,1个说食堂面和饭都有,那食堂到底有没有呢?心里可是直打鼓了。)
有一点要注意,无法拒绝H0并不是说H0为真,而是说我们的证据不足,无法证明H1。我们经常在法庭上听到,证据不足,无罪释放,但是,这个人到底有没有罪,还是要打个问号的。
1.3 P值 (P-value)
P值,也就是常见到的 P-value。P 值是一种概率,指的是在 H0 假设为真的前提下,样本结果出现的概率。即 p值是在原假设成立的基础上计算的。
如果 P-value 很小,则说明在原假设为真的前提下,样本结果出现的概率很小,甚至很极端,这就反过来说明了原假设很大概率是错误的。
另外一个角度想,p值也是弃真错误的概率。也就是这个假设成立的情况下,出现这个糟糕结果的概率,也当然就是如果此时把h0拒绝,出现错误的概率。只要p值足够小,我们就认为此时拒绝h0,出错的概率很小,那就干脆把h0拒绝好了。
通常,会设置一个显著性水平(significance level) alpha 与 P-value 进行比较,如果 P-value < alpha ,则说明在显著性水平 alpha 下拒绝原假设,alpha 通常情况下设置为0.05。
假如我们比较某地区男、女性的饮食口味是否存在差异,则 H0 是 "男女的饮食口味相同,不存在差异"。
最后得出 P=0.283 > 0.05,在α=0.05水平上不拒绝零假设,即不能认为该地区男女的饮食口味不同。
1.4 交叉表
在统计学中,交叉表是矩阵格式的一种表格,显示变量的(多变量)频率分布。
“交叉表”对象是一个网格,用来根据指定的条件返回值。数据显示在压缩行和列中。这种格式易于比较数据并辨别其趋势。它由三个元素组成:行 / 列 / 摘要字段。
让我们举例说明:
马军头领 | 步兵头领 | |
---|---|---|
二龙山 | 1 | 6 |
少华山 | 3 | 0 |
- “交叉表”中的行沿水平方向延伸(从一侧到另一侧)。在上面的示例中,”二龙山” 是一行。
- “交叉表”中的列沿垂直方向延伸(上下)。在上面的示例中,“马军头领” 是一列。
- 汇总字段位于行和列的交叉处。每个交叉处的值代表对既满足行条件又满足列条件的记录的汇总(求和、计数等)。在上面的示例中,“二龙山”和“马军头领”交叉处的值是1,这是在二龙山上马军头领的数目。
上文中交叉表是按两个变量交叉分类的,该列联表称为两维列联表,若按3个变量交叉分类,所得的列联表称为3维列联表,依次类推。3维及以上的列联表通常称为“多维列联表”或“高维列联表”,而一维列联表就是频数分布表。
1.5 卡方
交叉分类所得的表格称为“列联表”,统计推断(检验)则要使用列联表分析的方法------卡方检验。
卡方检验,主要用于检验统计样本的实际观测值与理论推断值之间的偏离程度,或者是检验一批数据是否与某种理论分布相符合。
Alink 文档中给出的是:卡方独立性检验是检验两个因素(各有两项或以上的分类)之间是否相互影响的问题,其零假设是两因素之间相互独立。
1.5.1 公式
卡方值是卡方检验时用到的检验统计量,卡方值越大,说明观测值与理论值之间的偏离就越大;反之,二者偏差越小。实际应用时,可以根据卡方值计算 P-value,从而选择拒绝或者接受原假设。
公式如下:
1.5.2 基本思想
卡方检验最基本的思想就是通过观察实际值与理论值的偏差来确定理论的正确与否。
具体做的时候常常先假设两个变量确实是独立的(行话就叫做“原假设”),然后观察实际值(也可以叫做观察值)与理论值(这个理论值是指“如果两者确实独立”的情况下应该有的值)的偏差程度。
- 如果偏差足够小,我们就认为误差是很自然的样本误差,是测量手段不够精确导致或者偶然发生的,两者确确实实是独立的,此时就接受原假设。
- 如果偏差大到一定程度,使得这样的误差不太可能是偶然产生或者测量不精确所致,我们就认为两者实际上是相关的,即否定原假设,而接受备择假设。
1.5.3 实现过程
卡方分析的方法:
- 假设两个变量是相互独立,互不关联的。这在统计上称为原假设;
- 对于调查中得到的两个变量的数据,用一个表格的形式来表示它们的分布(频数和百分数),这里的频数叫观测频数,这种表格叫列联表;
- 如果原假设成立,在这个前提下,可以计算出上面列联表中每个格子里的频数应该是多少,这叫期望频数;
- 比较观测频数与期望频数的差,如果两者的差越大,表明实际情况与原假设相去甚远;差越小,表明实际情况与原假设越相近。这种差值用一个卡方统计量来表示;
- 对卡方值进行检验,如果卡方检验的结果不显著,则不能拒绝原假设,即两变量是相互独立、互不关联的,如果卡方检验的结果显著,则拒绝原假设,即两变量间存在某种关联,至于是如何关联的,这要看列联表中数据的分布形态。
具体实现过程:
- 按照假设检验的步骤,首先我们需要确定原假设 H0(null hypothesis):原假设是变量独立的,实际观测频率和理论频率一致。
- 其次我们根据实际观测的联连表,去求理论的联连表;卡方统计值:X2,记为Statistic;
- 然后选取适合的置信度(一般为95%)同自由度一起确定临界值Critical Value,比较卡方统计值和临界值大小:
- If Statistic >= Critical Value: 认为变量对结果有影响,则拒绝原假设,变量不独立
- If Statistic < Critical Value: 认为变量对结果没有影响,接受原假设,变量独立
1.6 自由度
自由度:取值不受限制的变量的个数。
如何理解这句简单的话呢?给定一组数据,我们来计算不同的统计量,看看自由度的变化。这些数据分别为 1 2 4 6 8. 5个数。
先来求平均值,这几个数据都可以任意变化成其它数据,而我们仍然可以对它们求平均值,它们的平均值也跟着变化。这时自由度为5,也就是说有几个数据自由度就是几。
卡方检验的自由度:
1)如果是独立性检验,那么自由度就等于(a-1)*(b-1),a b表示这两个检验条件的对应的分类数。
2)适合性检验,类别数减去1。此处相当于约束条件只有一个。
卡方检验只有在用笔算查表时使用自由度,软件计算不用担心这个问题,但是最好明白自由度代表着总的变量数目减去约束条件的数目。
0x02 示例代码
本文示例代码如下,这里需要注意的是:
- "col1","col2"是所选择的列;
- "col4"是Label;
@H_103_301@public class ChiSquareTestBatchOpExample { public static void main(String[] args) throws Exception { Row[] testArray = new Row[]{ Row.of("a",1.1,1.2,1),Row.of("b",0.9,1.0,-2),Row.of("c",-0.01,100),Row.of("d",100.9,0.1,-99),Row.of("a",0.2,0.3,-99) }; String[] colNames = new String[]{"col1","col2","col3","col4"}; MemSourceBatchOp source = new MemSourceBatchOp(Arrays.asList(testArray),colNames); ChiSquareTestBatchOp test = new ChiSquareTestBatchOp() .setSelectedCols("col1","col2") .setLabelCol("col4"); test.linkFrom(source).print(); } }
输出如下:
@H_103_301@col|chisquare_test ---|-------------- col1|{"comment":"chi-square test","df":9.0,"p":0.004301310843500827,"value":24.0} col2|{"comment":"chi-square test","value":24.0}
转换为图表更好理解:
col | chisquare_test |
---|---|
col1 | {"comment":"chi-square test","value":24.0} |
col2 | {"comment":"chi-square test","value":24.0} |
df是自由度,p就是p-value, value就是我们前面说的卡方值,即
0x03 总体逻辑
训练总体逻辑如下:
- 使用 flatMap 做 flatting data to triple。遍历输入Row,然后把Row给flat了,得到三元组<idx in row,value in row,y-label>。比如 对应输入 Row.of("b",-2),则row = {Row@9419} "b,9,-2",因为col1,col2是特征,col4是 label,则发送两个三元组是 <0,b,-2>,<1,-2>;
- 使用 toTable 把前面处理的dataSet再进行转换,生成一张表 data。{"col","feature","label"} 就对应着我们之前的三元组;
- 对 data 进行 计算交叉表 和 卡方校验;
- groupBy("col,feature,label") 进行分类排序;
- select("col,label,count(1) as count2")) 得出 feature 的个数作为count2;
- groupBy("col").reduceGroup 再根据col排序,归并;
- 得到 <feature,y-label> : "count of feature" 这个map;
- Crosstab.convert(map) 利用map来做交叉表;
- map(new ChiSquareTestFromCrossTable()) 利用交叉表来构建卡方检验;
- test(crossTabWithId) 这里进行计算,其中会调用 org.apache.commons.math3.distribution.GammaDistribution.cumulativeProbability 进行Gamma计算;
0x04 训练
还是老套路,直奔ChiSquareTestBatchOp的linkFrom函数。
代码是缩减版,但原本就非常简单,获取“选择的列”和“Y列”,然后用输入数据进行训练检验。
深入看下去却很有难度。
@H_103_301@public ChiSquareTestBatchOp linkFrom(BatchOperator<?>... inputs) { BatchOperator<?> in = checkAndGetFirst(inputs); String[] selectedColNames = getSelectedCols(); String labelColName = getLabelCol(); this.setOutputTable(ChiSquareTestUtil.buildResult( ChiSquareTestUtil.test(in,selectedColNames,labelColName),getMLEnvironmentId())); return this; }
最后会辗转进入到 ChiSquareTest.test,这里才是真章。
@H_103_301@public static DataSet<Row> test(BatchOperator in,String[] selectedColNames,String labelColName) { in = in.select(ArrayUtils.add(selectedColNames,labelColName)); return ChiSquareTest.test(in.getDataSet(),in.getMLEnvironmentId()); }
4.1 ChiSquareTest
- 输入:in 的最后一列是label,其余列是所选择的特征列;
- 输出:有三列,分别是 1th is colId,2th is pValue,3th is chi-square value;
这里的总体逻辑是:
- 使用 flatMap 做 flatting data to triple。遍历输入Row,然后把Row给flat了,得到三元组<idx in row,y-label>;
- 使用 toTable 把前面处理的dataSet再进行转换,生成一张表 data。{"col","label"} 就对应着我们之前的三元组;
- 对 data 进行 计算交叉表 和 卡方校验;
具体代码如下:
@H_103_301@protected static DataSet<Row> test(DataSet<Row> in,Long sessionId) { //flatting data to triple. //这里就是遍历输入Row,然后把Row给flat了,得到三元组<idx in row,y-label> //比如 对应输入 Row.of("b",-2> DataSet<Row> dataSet = in .flatMap(new FlatMapFunction<Row,Row>() { @Override public void flatMap(Row row,Collector<Row> result) { int n = row.getArity() - 1; String nStr = String.valueOf(row.getField(n)); for (int i = 0; i < n; i++) { Row out = new Row(3); out.setField(0,i); out.setField(1,String.valueOf(row.getField(i))); out.setField(2,nStr); result.collect(out); } } }); // 把前面处理的dataSet再进行转换,生成一张表。{"col","label"} 就对应着我们之前的三元组 Table data = DataSetConversionUtil.toTable( sessionId,dataSet,new String[]{"col","label"},new TypeInformation[]{Types.INT,Types.STRING,Types.STRING}); // 对 data 进行 计算交叉表 和 卡方校验 //calculate cross table and chiSquare test. return DataSetConversionUtil.fromTable(sessionId,data .groupBy("col,label") //分类排序 .select("col,count(1) as count2")) // 为了得出 feature 的个数作为count2 .groupBy("col").reduceGroup( // 再根据col排序 new GroupReduceFunction<Row,Tuple2<Integer,Crosstab>>() { @Override public void reduce(Iterable<Row> iterable,Collector<Tuple2<Integer,Crosstab>> collector) { Map<Tuple2<String,String>,Long> map = new HashMap<>(); int colIdx = -1; for (Row row : iterable) { // 假如有如下,row = {Row@9684} "0,a,1,2",他对应了两个 Row.of("a",就是 <col,count(1)>,就是 <'a'是第0列,'a',对应 y-label是 1, 'a' 有两个> map.put(Tuple2.of(row.getField(1).toString(),row.getField(2).toString()),(long) row.getField(3)); colIdx = (Integer) row.getField(0); } // 得到 <feature,y-label> : "count of feature" 这个map map = {HashMap@9676} size = 4 {Tuple2@9688} "(a,1)" -> {Long@9689} 2 {Tuple2@9690} "(b,-2)" -> {Long@9689} 2 {Tuple2@9691} "(d,-99)" -> {Long@9689} 2 {Tuple2@9692} "(c,100)" -> {Long@9689} 2 // 利用map来做交叉表 collector.collect(new Tuple2<>(colIdx,Crosstab.convert(map))); } }) .map(new ChiSquareTestFromCrossTable()); // 构建卡方检验 }
4.2 Crosstab
上面代码中,使用 collector.collect(new Tuple2<>(colIdx,Crosstab.convert(map)));
来构建交叉表。
Crosstab 就是 Cross Tabulations reflects the relationship between two variables。即以map key为横轴,纵轴,value作为数值,就是feature和label之间的交叉。
@H_103_301@public static Crosstab convert(Map<Tuple2<String,Long> maps) { Crosstab crosstab = new Crosstab(); //get row tags and col tags Set<Tuple2<String,String>> sets = maps.keySet(); Set<String> rowTags = new HashSet<>(); // 拿到行,列 Set<String> colTags = new HashSet<>(); for (Tuple2<String,String> tuple2 : sets) { rowTags.add(tuple2.f0); colTags.add(tuple2.f1); } crosstab.rowTags = new ArrayList<>(rowTags); crosstab.colTags = new ArrayList<>(colTags); int rowLen = crosstab.rowTags.size(); int colLen = crosstab.colTags.size(); //compute value crosstab.data = new long[rowLen][colLen]; for (Map.Entry<Tuple2<String,Long> entry : maps.entrySet()) { int rowIdx = crosstab.rowTags.indexOf(entry.getKey().f0); int colIdx = crosstab.colTags.indexOf(entry.getKey().f1); crosstab.data[rowIdx][colIdx] = entry.getValue(); } return crosstab; }
这里输入输出如下
@H_103_301@// 输入如下 maps = {HashMap@9676} size = 4 {Tuple2@9688} "(a,100)" -> {Long@9689} 2 // 交叉表如下 crosstab = {Crosstab@9703} colTags = {ArrayList@9720} size = 4 0 = "1" 1 = "100" 2 = "-2" 3 = "-99" rowTags = {ArrayList@9721} size = 4 0 = "a" 1 = "b" 2 = "c" 3 = "d" data = {long[4][]@9713} 0 = {long[4]@9722} 0 = 2 1 = 0 2 = 0 3 = 0 1 = {long[4]@9723} 0 = 0 1 = 0 2 = 2 3 = 0 2 = {long[4]@9724} 0 = 0 1 = 2 2 = 0 3 = 0 3 = {long[4]@9725} 0 = 0 1 = 0 2 = 0 3 = 2
构造出来交叉表如下:
1 | 100 | -2 | -99 | |
---|---|---|---|---|
a | 2 | |||
b | 2 | |||
c | 2 | |||
d | 2 |
4.3 构建卡方检验
4.1中,有 .map(new ChiSquareTestFromCrossTable());
,这里就是根据collector.collect(new Tuple2<>(colIdx,Crosstab.convert(map)));
交叉表构建卡方检验。
@H_103_301@/** * calculate chi-square test value from cross table. */ public static class ChiSquareTestFromCrossTable implements MapFunction<Tuple2<Integer,Crosstab>,Row> { @Override public Row map(Tuple2<Integer,Crosstab> crossTabWithId) throws Exception { Tuple4 tuple4 = test(crossTabWithId); // f0 is id of cross table,f1 is pValue,f2 is chi-square Value,f3 is df Row row = new Row(4); row.setField(0,tuple4.f0); row.setField(1,tuple4.f1); row.setField(2,tuple4.f2); row.setField(3,tuple4.f3); return row; } }
test(crossTabWithId)是关键点,其中 distribution.cumulativeProbability 最后调用到 org.apache.commons.math3.distribution.GammaDistribution.cumulativeProbability。
这里能够看到
- df 的定义就是 (double)(rowLen - 1) * (colLen - 1),即(行 - 1)*(列 - 1)。
- 卡方值就是严格按照定义来构建的。
- p-value是 调用到 org.apache.commons.math3.distribution.GammaDistribution.cumulativeProbability。
@H_103_301@/** * @param crossTabWithId: f0 is id,f1 is cross table * @return tuple4: f0 is id which is id of cross table,f3 is df */ protected static Tuple4<Integer,Double,Double> test(Tuple2<Integer,Crosstab> crossTabWithId) { int colIdx = crossTabWithId.f0; Crosstab crosstab = crossTabWithId.f1; int rowLen = crosstab.rowTags.size(); int colLen = crosstab.colTags.size(); //compute row sum and col sum 计算出列的数值和,行的数值和 double[] rowSum = crosstab.rowSum(); double[] colSum = crosstab.colSum(); double n = crosstab.sum(); //compute statistic value 计算统计值 double chiSq = 0; for (int i = 0; i < rowLen; i++) { for (int j = 0; j < colLen; j++) { double nij = rowSum[i] * colSum[j] / n; double temp = crosstab.data[i][j] - nij; chiSq += temp * temp / nij; // 就是按照定义来构建卡方值 } } //set result double p; if (rowLen <= 1 || colLen <= 1) { p = 1; } else { ChiSquaredDistribution distribution = new ChiSquaredDistribution(null,(rowLen - 1) * (colLen - 1)); p = 1.0 - distribution.cumulativeProbability(Math.abs(chiSq)); } // return tuple4: f0 is id which is id of cross table,f3 is df return Tuple4.of(colIdx,p,chiSq,(double)(rowLen - 1) * (colLen - 1)); } // runtime是 tuple4 = {Tuple4@9842} "(0,0.004301310843500827,24.0,9.0)" f0 = {Integer@9843} 0 f1 = {Double@9844} 0.004301310843500827 f2 = {Double@9847} 24.0 f3 = {Double@9848} 9.0
0xFF 参考
卡方检验(Chi_square_test): 原理及python实现
Spark MLlib基本算法【相关性分析、卡方检验、总结器】