Alink漫谈(十九) :源码解析 之 分位点离散化Quantile

前端之家收集整理的这篇文章主要介绍了Alink漫谈(十九) :源码解析 之 分位点离散化Quantile前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。

Alink漫谈(十九) :源码解析 之 分位点离散化Quantile

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将带领大家来分析Alink中 Quantile 的实现。

因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

本文缘由是因为想分析GBDT,发现GBDT涉及到Quantile的使用,所以只能先分析Quantile 。

0x01 背景概念

1.1 离散化

离散化:就是把无限空间中有限的个体映射到有限的空间中(分箱处理)。数据离散化操作大多是针对连续数据进行的,处理之后的数据值域分布将从连续属性变为离散属性

离散化方式会影响后续数据建模和应用效果

  • 使用决策树往往倾向于少量的离散化区间,过多的离散化将使得规则过多受到碎片区间的影响。
  • 关联规则需要对所有特征一起离散化,关联规则关注的是所有特征的关联关系,如果对每个列单独离散化将失去整体规则性。

连续数据的离散化结果可以分为两类:

  • 一类是将连续数据划分为特定区间的集合,例如{(0,10],(10,20],(20,50],(50,100]};
  • 一类是将连续数据划分为特定类,例如类1、类2、类3;

1.2 分位数

分位数(Quantile),亦称分位点,是指将一个随机变量的概率分布范围分为几个等份的数值点,常用的有中位数(即二分位数)、四分位数、百分位数等。

假如有1000个数字(正数),这些数字的5%,30%,50%,70%,99%分位数分别是 [3.0,5.0,6.0,9.0,12.0],这表明

  • 有5%的数字分布在0-3.0之间
  • 有25%的数字分布在3.0-5.0之间
  • 有20%的数字分布在5.0-6.0之间
  • 有20%的数字分布在6.0-9.0之间
  • 有29%的数字分布在9.0-12.0之间
  • 有1%的数字大于12.0

这就是分位数的统计学理解。

因此求解某一组数字中某个数的分位数,只需要将该组数字进行排序,然后再统计小于等于该数的个数,除以总的数字个数即可。

确定p分位数位置的两种方法

  • position = (n+1)p
  • position = 1 + (n-1)p

1.3 四分位数

这里我们用四分位数做进一步说明。

四分位数 概念:把给定的乱序数值由小到大排列并分成四等份,处于三个分割点位置的数值就是四分位数。

第1四分位数 (Q1),又称“较小四分位数”,等于该样本中所有数值由小到大排列后第25%的数字。

第2四分位数 (Q2),又称“中位数”,等于该样本中所有数值由小到大排列后第50%的数字。

第3四分位数 (Q3),又称“较大四分位数”,等于该样本中所有数值由小到大排列后第75%的数字。

四分位距(InterQuartile Range,IQR)= 第3四分位数与第1四分位数的差距。

0x02 示例代码

Alink中完成分位数功能的是QuantileDiscretizerQuantileDiscretizer输入连续的特征列,输出分箱的类别特征。

  • 分位点离散可以计算选定列的分位点,然后使用这些分位点进行离散化。生成选中列对应的q-quantile,其中可以所有列指定一个,也可以每一列对应一个。
  • 分箱数(所需离散的数目,即分为几段)是通过参数numBuckets(桶数目)来指定的。 箱的范围是通过使用近似算法来得到的。

本文示例代码如下。

@H_301_164@public class QuantileDiscretizerExample { public static void main(String[] args) throws Exception { NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(1001,2000,"col0"); // 就是把1001 ~ 2000 这个连续数值分段 Pipeline pipeline = new Pipeline() .add(new QuantileDiscretizer() .setNumBuckets(6) // 指定分箱数数目 .setSelectedCols(new String[]{"col0"})); List<Row> result = pipeline.fit(numSeqSourceBatchOp).transform(numSeqSourceBatchOp).collect(); System.out.println(result); } }

输出

@H_301_164@[0,1,2,3,4,5,..... 0,1 ..... 5,5]

0x03 总体逻辑

我们首先给出总体逻辑图例

@H_301_164@-------------------------------- 准备阶段 -------------------------------- │ │ │ ┌───────────────────┐ │ getSelectedCols │ 获取需要分位的列名字 └───────────────────┘ │ │ │ ┌─────────────────────┐ │ quantileNum │ 获取分箱数 └─────────────────────┘ │ │ │ ┌──────────────────────┐ │ Preprocessing.select │ 从输入中根据列名字select出数据 └──────────────────────┘ │ │ │ -------------------------------- 预处理阶段 -------------------------------- │ │ │ ┌──────────────────────┐ │ quantile │ 后续步骤 就是 计算分位数 └──────────────────────┘ │ │ │ ┌────────────────────────────────┐ │ countElementsPerPartition │ 在每一个partition中获取该分区的所有元素个数 └────────────────────────────────┘ │ <task id,count in this task> │ │ ┌──────────────────────┐ │ sum(1) │ 这里对第二个参数,即"count in this task"进行累积,得出所有元素的个数 └──────────────────────┘ │ │ │ ┌──────────────────────┐ │ map │ 取出所有元素个数,cnt在后续会使用 └──────────────────────┘ │ │ │ │ ┌──────────────────────┐ │ missingCount │ 分区查找应选的列中,有哪些数据没有被查到,比如zeroAsMissing,null,isNaN └──────────────────────┘ │ │ │ ┌────────────────┐ │ mapPartition │ 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来 └────────────────┘ │ <idx in row,item in row>,即<row中第几个元素,元素> │ │ ┌──────────────┐ │ pSort │ 将flatten数据进行排序 └──────────────┘ │ 返回的是二元组 │ f0: dataset which is indexed by partition id │ f1: dataset which has partition id and count │ │ -------------------------------- 计算阶段 -------------------------------- │ │ │ ┌─────────────────┐ │ MultiQuantile │ 后续都是具体计算步骤 └─────────────────┘ │ │ │ ┌─────────────────┐ │ open │ 从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序) └─────────────────┘ │ │ │ ┌─────────────────┐ │ mapPartition │ 具体计算 └─────────────────┘ │ │ │ ┌─────────────────┐ │ groupBy(0) │ 依据 列idx 分组 └─────────────────┘ │ │ │ ┌─────────────────┐ │ reduceGroup │ 归并排序 └─────────────────┘ │set(Tuple2<column idx,真实数据值>) │ │ -------------------------------- 序列化模型 -------------------------------- │ │ │ ┌──────────────┐ │ reduceGroup │ 分组归并 └──────────────┘ │ │ │ ┌─────────────────┐ │ SerializeModel │ 序列化模型 └─────────────────┘

下面图片是为了在手机上缩放适配展示。

QuantileDiscretizerTrainBatchOp.linkFrom如下:

@H_301_164@public QuantileDiscretizerTrainBatchOp linkFrom(BatchOperator<?>... inputs) { BatchOperator<?> in = checkAndGetFirst(inputs); // 示例中设置了 .setSelectedCols(new String[]{"col0"}));, 所以这里 quantileColNames 的数值是"col0 String[] quantileColNames = getSelectedCols(); int[] quantileNum = null; // 示例中设置了 .setNumBuckets(6),所以这里 quantileNum 是 quantileNum = {int[1]@2705} 0 = 6 if (getParams().contains(QuantileDiscretizerTrainParams.NUM_BUCKETS)) { quantileNum = new int[quantileColNames.length]; Arrays.fill(quantileNum,getNumBuckets()); } else { quantileNum = Arrays.stream(getNumBucketsArray()).mapToInt(Integer::intValue).toArray(); } /* filter the selected column from input */ // 获取了 选择的列 "col0" DataSet<Row> input = Preprocessing.select(in,quantileColNames).getDataSet(); // 计算分位数 DataSet<Row> quantile = quantile( input,quantileNum,getParams().get(HasRoundMode.ROUND_MODE),getParams().get(Preprocessing.ZERO_AS_MISSING) ); // 序列化模型 quantile = quantile.reduceGroup( new SerializeModel( getParams(),quantileColNames,TableUtil.findColTypesWithAssertAndHint(in.getSchema(),quantileColNames),BinTypes.BinDivideType.QUANTILE ) ); /* set output */ setOutput(quantile,new QuantileDiscretizerModelDataConverter().getModelSchema()); return this; }

其总体逻辑如下:

  • 获取需要分位的列名字
  • 获取分箱数
  • 从输入中根据列名字select出数据
  • 调用 quantile 计算分位数
    • 调用 countElementsPerPartition 在每一个partition中获取该分区的所有元素个数,返回<task id,count in this task>,然后 对于元素个数进行累积 sum(1) ,即"count in this task"进行累积,得出所有元素的个数 cnt;
    • 分区查找应选的列中,有哪些数据没有被查到,从代码看,是zeroAsMissing,isNaN这几种情况,然后依据 partition id 进行分组 groupBy(0) 累积求和,得到 missingCount;
    • 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了, 返回flatten = <idx in row,即<row中第几个元素,元素>;
    • 将flatten数据进行排序,pSort是大规模分区排序,此时还没有分类。pSort返回的是二元组sortedData,f0: dataset which is indexed by partition id,f1: dataset which has partition id and count;
    • 调用 MultiQuantile ,对 sortedData.f0(f0: dataset which is indexed by partition id) 进行计算分位数;具体是分区计算 mapPartition:
      • 累积,得到当前 task 的起始位置,即 n 个输入数据中从哪个数据开始计算;
      • 根据 taskId 从 counts 中得到了本 task 应该处理哪些数据,即数据的start,end位置;
      • 把数据插入 allRows.add(value); value 可认为是 <partition id,真实数据>;
      • 调用 QIndex 计算分位数元数据;quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6;
      • 遍历一直到分箱数,每次循环 调用 qIndex.genIndex(j) 获取每个分箱的index。然后依据这个分箱的index从输入数据中获取真实数据值,这个 真实数据值 就是 真实数据的index。比如连续区域是 1001 ~ 2000,分成 6 份,则第一份调用 qIndex.genIndex(j) 得到 167,则根据167,获取真实数据是 1001 + 167 = 1168,即在 1001 ~ 2000 中,第一个分位index 是 1168.
    • 依据 列idx 分组,得到 set(Tuple2<column idx,真实数据值>);
  • 序列化模型

0x04 训练

4.1 quantile

训练是通过 quantile 完成的,大致包含以下步骤。

  • 调用 countElementsPerPartition 在每一个partition中获取该分区的所有元素个数,返回<task id,isNaN这几种情况,然后依据 partition id 进行分组 groupBy(0) 累积求和,得到 missingCount;
  • 把输入数据Row打散,对于Row中的子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了,返回flatten = <idx in row,f1: dataset which has partition id and count;
  • 调用 MultiQuantile ,对 sortedData.f0(f0: dataset which is indexed by partition id) 进行计算分位数。

具体如下

@H_301_164@public static DataSet<Row> quantile( DataSet<Row> input,final int[] quantileNum,final HasRoundMode.RoundMode roundMode,final boolean zeroAsMissing) { /* instance count of dataset */ // countElementsPerPartition 的作用是:在每一个partition中获取该分区的所有元素个数,返回<task id,count in this task>。 DataSet<Long> cnt = DataSetUtils .countElementsPerPartition(input) .sum(1) // 这里对第二个参数,即"count in this task"进行累积,得出所有元素的个数。 .map(new MapFunction<Tuple2<Integer,Long>,Long>() { @Override public Long map(Tuple2<Integer,Long> value) throws Exception { return value.f1; // 取出所有元素个数 } }); // cnt在后续会使用 /* missing count of columns */ // 会查找应选的列中,有哪些数据没有被查到,从代码看,是zeroAsMissing,isNaN这几种情况 DataSet<Tuple2<Integer,Long>> missingCount = input .mapPartition(new RichMapPartitionFunction<Row,Tuple2<Integer,Long>>() { public void mapPartition(Iterable<Row> values,Collector<Tuple2<Integer,Long>> out) { StreamSupport.stream(values.spliterator(),false) .flatMap(x -> { long[] counts = new long[x.getArity()]; Arrays.fill(counts,0L); // 如果发现有数据没有查到,就增加counts for (int i = 0; i < x.getArity(); ++i) { if (x.getField(i) == null || (zeroAsMissing && ((Number) x.getField(i)).doubleValue() == 0.0) || Double.isNaN(((Number)x.getField(i)).doubleValue())) { counts[i]++; } } return IntStream.range(0,x.getArity()) .mapToObj(y -> Tuple2.of(y,counts[y])); }) .collect(Collectors.groupingBy( x -> x.f0,Collectors.mapping(x -> x.f1,Collectors.reducing((a,b) -> a + b)) ) ) .entrySet() .stream() .map(x -> Tuple2.of(x.getKey(),x.getValue().get())) .forEach(out::collect); } }) .groupBy(0) //按第一个元素分组 .reduce(new RichReduceFunction<Tuple2<Integer,Long>>() { @Override public Tuple2<Integer,Long> reduce(Tuple2<Integer,Long> value1,Long> value2) { return Tuple2.of(value1.f0,value1.f1 + value2.f1); //累积求和 } }); /* flatten dataset to 1d */ // 把输入数据打散。 DataSet<PairComparable> flatten = input .mapPartition(new RichMapPartitionFunction<Row,PairComparable>() { PairComparable pairBuff; public void mapPartition(Iterable<Row> values,Collector<PairComparable> out) { for (Row value : values) { // 遍历分区内所有输入元素 for (int i = 0; i < value.getArity(); ++i) { // 如果输入元素Row本身包含多个子元素 pairBuff.first = i; // 则对于这些子元素按照Row内顺序一一发送出来,这就做到了把Row类型给flatten了 if (value.getField(i) == null || (zeroAsMissing && ((Number) value.getField(i)).doubleValue() == 0.0) || Double.isNaN(((Number)value.getField(i)).doubleValue())) { pairBuff.second = null; } else { pairBuff.second = (Number) value.getField(i); } out.collect(pairBuff); // 返回<idx in row,即<row中第几个元素,元素> } } } }); /* sort data */ // 将flatten数据进行排序,pSort是大规模分区排序,此时还没有分类 // pSort返回的是二元组,f0: dataset which is indexed by partition id,f1: dataset which has partition id and count. Tuple2<DataSet<PairComparable>,DataSet<Tuple2<Integer,Long>>> sortedData = SortUtilsNext.pSort(flatten); /* calculate quantile */ return sortedData.f0 //f0: dataset which is indexed by partition id .mapPartition(new MultiQuantile(quantileNum,roundMode)) .withBroadcastSet(sortedData.f1,"counts") //f1: dataset which has partition id and count .withBroadcastSet(cnt,"totalCnt") .withBroadcastSet(missingCount,"missingCounts") .groupBy(0) // 依据 列idx 分组 .reduceGroup(new RichGroupReduceFunction<Tuple2<Integer,Number>,Row>() { @Override public void reduce(Iterable<Tuple2<Integer,Number>> values,Collector<Row> out) { TreeSet<Number> set = new TreeSet<>(new Comparator<Number>() { @Override public int compare(Number o1,Number o2) { return SortUtils.OBJECT_COMPARATOR.compare(o1,o2); } }); int id = -1; for (Tuple2<Integer,Number> val : values) { // Tuple2<column idx,数据> id = val.f0; set.add(val.f1); } // runtime变量 set = {TreeSet@9379} size = 5 0 = {Long@9389} 167 // 就是第 0 列的第一段 idx 1 = {Long@9392} 333 // 就是第 0 列的第二段 idx 2 = {Long@9393} 500 3 = {Long@9394} 667 4 = {Long@9382} 833 out.collect(Row.of(id,set.toArray(new Number[0]))); } }); }

下面会对几个重点函数做说明。

4.2 countElementsPerPartition

countElementsPerPartition 的作用是:在每一个partition中获取该分区的所有元素个数。

@H_301_164@public static <T> DataSet<Tuple2<Integer,Long>> countElementsPerPartition(DataSet<T> input) { return input.mapPartition(new RichMapPartitionFunction<T,Long>>() { @Override public void mapPartition(Iterable<T> values,Long>> out) throws Exception { long counter = 0; for (T value : values) { counter++; // 在每一个partition中获取该分区的所有元素个数 } out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(),counter)); } }); }

4.3 MultiQuantile

MultiQuantile用来计算具体的分位点。

open函数中会从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)等等。

mapPartition函数则做具体计算,大致步骤如下:

  • 累积,得到当前 task 的起始位置,即 n 个输入数据中从哪个数据开始计算;
  • 根据 taskId 从 counts 中得到了本 task 应该处理哪些数据,即数据的start,end位置;
  • 把数据插入 allRows.add(value); value 可认为是 <partition id,真实数据>;
  • 调用 QIndex 计算分位数元数据;quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6;
  • 遍历一直到分箱数,每次循环 调用 qIndex.genIndex(j) 获取每个分箱的index。然后依据这个分箱的index从输入数据中获取真实数据值,这个 真实数据值 就是 真实数据的index。比如连续区域是 1001 ~ 2000,分成 6 份,则第一份调用 qIndex.genIndex(j) 得到 167,则根据167,获取真实数据是 1001 + 167 = 1168,即在 1001 ~ 2000 中,第一个分位index 是 1168;

具体代码是:

@H_301_164@public static class MultiQuantile extends RichMapPartitionFunction<PairComparable,Number>> { private List<Tuple2<Integer,Long>> counts; private List<Tuple2<Integer,Long>> missingCounts; private long totalCnt = 0; private int[] quantileNum; private HasRoundMode.RoundMode roundType; private int taskId; @Override public void open(Configuration parameters) throws Exception { // 从广播中获取变量,初步处理counts(排序),totalCnt,missingCounts(排序)。 // 之前设置广播变量.withBroadcastSet(sortedData.f1,"counts"),其中 f1 的格式是: dataset which has partition id and count,所以就是用 partition id来排序 this.counts = getRuntimeContext().getBroadcastVariableWithInitializer( "counts",new BroadcastVariableInitializer<Tuple2<Integer,List<Tuple2<Integer,Long>>>() { @Override public List<Tuple2<Integer,Long>> initializeBroadcastVariable( Iterable<Tuple2<Integer,Long>> data) { ArrayList<Tuple2<Integer,Long>> sortedData = new ArrayList<>(); for (Tuple2<Integer,Long> datum : data) { sortedData.add(datum); } //排序 sortedData.sort(Comparator.comparing(o -> o.f0)); // runtime的数据如下,本机有4核,所以数据分为4个 partition,每个partition的数据分别为251,250,250,250 sortedData = {ArrayList@9347} size = 4 0 = {Tuple2@9350} "(0,251)" // partition 0,数据个数是251 1 = {Tuple2@9351} "(1,250)" 2 = {Tuple2@9352} "(2,250)" 3 = {Tuple2@9353} "(3,250)" return sortedData; } }); this.totalCnt = getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt",new BroadcastVariableInitializer<Long,Long>() { @Override public Long initializeBroadcastVariable(Iterable<Long> data) { return data.iterator().next(); } }); this.missingCounts = getRuntimeContext().getBroadcastVariableWithInitializer( "missingCounts",Long>> data) { return StreamSupport.stream(data.spliterator(),false) .sorted(Comparator.comparing(o -> o.f0)) .collect(Collectors.toList()); } } ); taskId = getRuntimeContext().getIndexOfThisSubtask(); // runtime的数据如下 this = {QuantileDiscretizerTrainBatchOp$MultiQuantile@9348} counts = {ArrayList@9347} size = 4 0 = {Tuple2@9350} "(0,251)" 1 = {Tuple2@9351} "(1,250)" 2 = {Tuple2@9352} "(2,250)" 3 = {Tuple2@9353} "(3,250)" missingCounts = {ArrayList@9375} size = 1 0 = {Tuple2@9381} "(0,0)" totalCnt = 1001 quantileNum = {int[1]@9376} 0 = 6 roundType = {HasRoundMode$RoundMode@9377} "ROUND" taskId = 2 } @Override public void mapPartition(Iterable<PairComparable> values,Number>> out) throws Exception { long start = 0; long end; int curListIndex = -1; int size = counts.size(); // 分成4份,所以这里是4 for (int i = 0; i < size; ++i) { int curId = counts.get(i).f0; // 取出输入元素中的 partition id if (curId == taskId) { curListIndex = i; // 当前 task 对应哪个 partition id break; // 到了当前task,就可以跳出了 } start += counts.get(i).f1; // 累积,得到当前 task 的起始位置,即1000个数据中从哪个数据开始计算 } // 根据 taskId 从counts中得到了本 task 应该处理哪些数据,即数据的start,end位置 // 本 partition 是 0,其中有251个数据 end = start + counts.get(curListIndex).f1; // end = 起始位置 + 此partition的数据个数 ArrayList<PairComparable> allRows = new ArrayList<>((int) (end - start)); for (PairComparable value : values) { allRows.add(value); // value 可认为是 <partition id,真实数据> } allRows.sort(Comparator.naturalOrder()); // runtime变量 start = 0 curListIndex = 0 size = 4 end = 251 allRows = {ArrayList@9406} size = 251 0 = {PairComparable@9408} first = {Integer@9397} 0 second = {Long@9434} 0 1 = {PairComparable@9409} first = {Integer@9397} 0 second = {Long@9435} 1 2 = {PairComparable@9410} first = {Integer@9397} 0 second = {Long@9439} 2 ...... // size = ((251 - 1) / 1001 - 0 / 1001) + 1 = 1 size = (int) ((end - 1) / totalCnt - start / totalCnt) + 1; int localStart = 0; for (int i = 0; i < size; ++i) { int fIdx = (int) (start / totalCnt + i); int subStart = 0; int subEnd = (int) totalCnt; if (i == 0) { subStart = (int) (start % totalCnt); // 0 } if (i == size - 1) { subEnd = (int) (end % totalCnt == 0 ? totalCnt : end % totalCnt); // 251 } if (totalCnt - missingCounts.get(fIdx).f1 == 0) { localStart += subEnd - subStart; continue; } QIndex qIndex = new QIndex( totalCnt - missingCounts.get(fIdx).f1,quantileNum[fIdx],roundType); // runtime变量 qIndex = {QuantileDiscretizerTrainBatchOp$QIndex@9548} totalCount = 1001.0 q1 = 0.16666666666666666 roundMode = {HasRoundMode$RoundMode@9377} "ROUND" // 遍历,一直到分箱数。 for (int j = 1; j < quantileNum[fIdx]; ++j) { // 获取每个分箱的index long index = qIndex.genIndex(j); // j = 1 ---> index = 167,就是把 1001 个分为6段,第一段终点是167 //对应本 task = 0,subStart = 0,subEnd = 251。则index = 167,直接从allRows获取第167个,数值是 1168。因为连续区域是 1001 ~ 2000,所以第167个对应数值就是1168 //如果本 task = 1,subStart = 251,subEnd = 501。则index = 333,直接从allRows获取第 (333 + 0 - 251)= 第 82 个,获取其中的数值。这里因为数值区域是 1001 ~ 2000,所以数值是1334。 if (index >= subStart && index < subEnd) { // idx刚刚好在本分区的数据中 PairComparable pairComparable = allRows.get( (int) (index + localStart - subStart)); // // runtime变量 pairComparable = {PairComparable@9581} first = {Integer@9507} 0 // first是column idx second = {Long@9584} 167 // 真实数据 out.collect(Tuple2.of(pairComparable.first,pairComparable.second)); } } localStart += subEnd - subStart; } } }

4.4 QIndex

其中 QIndex 是本文关键所在,就是具体计算分位数。

  • 构造函数中会得倒所有元素个数,每段大小;
  • genIndex函数中会具体计算,比如假设还是6段,则如果取第一段,则k=1,其index为 (1/6 * (1001 - 1) * 1) = 167
@H_301_164@public static class QIndex { private double totalCount; private double q1; private HasRoundMode.RoundMode roundMode; public QIndex(double totalCount,int quantileNum,HasRoundMode.RoundMode type) { this.totalCount = totalCount; // 1001,所有元素的个数 this.q1 = 1.0 / (double) quantileNum; // 1.0 / 6 = 16666666666666666。quantileNum是分成几段,q1就是每一段的大小。如果分成6段,则每一段的大小是1/6 this.roundMode = type; } public long genIndex(int k) { // 假设还是6段,则如果取第一段,则k=1,其index为 (1/6 * (1001 - 1) * 1) = 167 return roundMode.calc(this.q1 * (this.totalCount - 1.0) * (double) k); } }

0x05 输出模型

输出模型是通过 reduceGroup 调用 SerializeModel 来完成。

具体逻辑是:

  • 先构建分箱点元数据信息;
  • 然后序列化成模型;
@H_301_164@// 序列化模型 quantile = quantile.reduceGroup( new SerializeModel( getParams(),BinTypes.BinDivideType.QUANTILE ) );

SerializeModel 的具体实现是:

@H_301_164@public static class SerializeModel implements GroupReduceFunction<Row,Row> { private Params Meta; private String[] colNames; private TypeInformation<?>[] colTypes; private BinTypes.BinDivideType binDivideType; @Override public void reduce(Iterable<Row> values,Collector<Row> out) throws Exception { Map<String,FeatureBorder> m = new HashMap<>(); for (Row val : values) { int index = (int) val.getField(0); Number[] splits = (Number[]) val.getField(1); m.put( colNames[index],QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder( colNames[index],colTypes[index],splits,Meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),binDivideType ) ); } for (int i = 0; i < colNames.length; ++i) { if (m.containsKey(colNames[i])) { continue; } m.put( colNames[i],QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder( colNames[i],colTypes[i],binDivideType ) ); } QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter(m,Meta); model.save(model,out); } }

这里用到了 FeatureBorder 类。

数据分箱是按照某种规则将数据进行分类。就像可以将水果按照大小进行分类,售卖不同的价格一样。

FeatureBorder 就是专门为了 Featureborder for binning,discrete Featureborder and continuous Featureborder。

我们能够看出来,该分箱对应的列名,index,各个分割点。

@H_301_164@m = {HashMap@9380} size = 1 "col0" -> {FeatureBorder@9438} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"

0x06 预测

预测是在 QuantileDiscretizerModelMapper 中完成的。

6.1 加载模型

模型数据是

@H_301_164@model = {QuantileDiscretizerModelDataConverter@9582} Meta = {Params@9670} "Params {selectedCols=["col0"],version="v2",numBuckets=6}" data = {HashMap@9584} size = 1 "col0" -> {FeatureBorder@9676} "{"binDivideType":"QUANTILE","binCount":6}"

loadModel会完成加载。

@H_301_164@@Override public void loadModel(List<Row> modelRows) { QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter(); model.load(modelRows); for (int i = 0; i < mapperBuilder.paramsBuilder.selectedCols.length; i++) { FeatureBorder border = model.data.get(mapperBuilder.paramsBuilder.selectedCols[i]); List<Bin.BaseBin> norm = border.bin.normBins; int size = norm.size(); Long maxIndex = norm.get(0).getIndex(); Long lastIndex = norm.get(size - 1).getIndex(); for (int j = 0; j < norm.size(); ++j) { if (maxIndex < norm.get(j).getIndex()) { maxIndex = norm.get(j).getIndex(); } } long maxIndexWithNull = Math.max(maxIndex,border.bin.nullBin.getIndex()); switch (mapperBuilder.paramsBuilder.handleInvalidStrategy) { case KEEP: mapperBuilder.vectorSize.put(i,maxIndexWithNull + 1); break; case SKIP: case ERROR: mapperBuilder.vectorSize.put(i,maxIndex + 1); break; default: throw new UnsupportedOperationException("Unsupported now."); } if (mapperBuilder.paramsBuilder.dropLast) { mapperBuilder.dropIndex.put(i,lastIndex); } mapperBuilder.discretizers[i] = createQuantileDiscretizer(border,model.Meta); } mapperBuilder.setAssembledVectorSize(); }

加载中,最后调用 createQuantileDiscretizer 生成 LongQuantileDiscretizer。这就是针对Long类型的离散器。

@H_301_164@public static class LongQuantileDiscretizer implements NumericQuantileDiscretizer { long[] bounds; boolean isLeftOpen; int[] boundIndex; int nullIndex; boolean zeroAsMissing; @Override public int findIndex(Object number) { if (number == null) { return nullIndex; } long lVal = ((Number) number).longValue(); if (isMissing(lVal,zeroAsMissing)) { return nullIndex; } int hit = Arrays.binarySearch(bounds,lVal); if (isLeftOpen) { hit = hit >= 0 ? hit - 1 : -hit - 2; } else { hit = hit >= 0 ? hit : -hit - 2; } return boundIndex[hit]; } }

其数值如下:

@H_301_164@this = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768} bounds = {long[7]@9757} 0 = -9223372036854775807 1 = 1168 2 = 1334 3 = 1501 4 = 1667 5 = 1834 6 = 9223372036854775807 isLeftOpen = true boundIndex = {int[7]@9743} 0 = 0 // -9223372036854775807 ~ 1168 之间对应的最终分箱离散值是 0 1 = 1 2 = 2 3 = 3 4 = 4 5 = 5 6 = 5 // 1834 ~ 9223372036854775807 之间对应的最终分箱离散值是 5 nullIndex = 6 zeroAsMissing = false

6.2 预测

预测 QuantileDiscretizerModelMapper 的 DiscretizerMapperBuilder 完成。

@H_301_164@Row map(Row row){ // 这里的 row 举例是: row = {Row@9743} "1003" for (int i = 0; i < paramsBuilder.selectedCols.length; i++) { int colIdxInData = selectedColIndicesInData[i]; Object val = row.getField(colIdxInData); int foundIndex = discretizers[i].findIndex(val); // 找到 1003对应的index,就是调用Discretizer完成,这里找到 foundIndex 是0 predictIndices[i] = (long) foundIndex; } return paramsBuilder.outputColsHelper.getResultRow( row,setResultRow( predictIndices,paramsBuilder.encode,dropIndex,vectorSize,paramsBuilder.dropLast,assembledVectorSize) // 最后返回离散值是0 ); } this = {QuantileDiscretizerModelMapper$DiscretizerMapperBuilder@9744} paramsBuilder = {QuantileDiscretizerModelMapper$DiscretizerParamsBuilder@9752} selectedColIndicesInData = {int[1]@9754} vectorSize = {HashMap@9758} size = 1 dropIndex = {HashMap@9759} size = 1 assembledVectorSize = {Integer@9760} 6 discretizers = {QuantileDiscretizerModelMapper$NumericQuantileDiscretizer[1]@9761} 0 = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768} bounds = {long[7]@9776} isLeftOpen = true boundIndex = {int[7]@9777} nullIndex = 6 zeroAsMissing = false predictIndices = {Long[1]@9763}

0xFF 参考

QuantileDiscretizer的用法

Spark QuantileDiscretizer 分位数离散器

机器学习——数据离散化(时间离散,多值离散化,分位数,聚类法,频率区间,二值化)

如何通俗地理解分位数?

分位数通俗理解

Python解释数学系列——分位数Quantile

spark之QuantileDiscretizer源码解析

猜你在找的大数据相关文章