本文作者:何通,SupStat Inc(总部在纽约,中国分部为北京数博思达信息科技有限公司)数据科学家,加拿大Simon Fraser University计算机学院研究生,研究兴趣为数据挖掘和生物信息学。
引言
在数据分析的过程中,我们经常需要对数据建模并做预测。在众多的选择中,randomForest, gbm和glmnet是三个尤其流行的R包,它们在Kaggle的各大数据挖掘竞赛中的出现频率独占鳌头,被坊间人称为R数据挖掘包中的三驾马车。根据我的个人经验,gbm包比同样是使用树模型的randomForest包占用的内存更少,同时训练速度较快,尤其受到大家的喜爱。在python的机器学习库sklearn里也有GradientBoostingClassifier的存在。
Boosting分类器属于集成学习模型,它基本思想是把成百上千个分类准确率较低的树模型组合起来,成为一个准确率很高的模型。这个模型会不断地迭代,每次迭代就生成一颗新的树。对于如何在每一步生成合理的树,大家提出了很多的方法,我们这里简要介绍由Friedman提出的Gradient Boosting Machine。它在生成每一棵树的时候采用梯度下降的思想,以之前生成的所有树为基础,向着最小化给定目标函数的方向多走一步。在合理的参数设置下,我们往往要生成一定数量的树才能达到令人满意的准确率。在数据集较大较复杂的时候,我们可能需要几千次迭代运算,如果生成一个树模型需要几秒钟,那么这么多迭代的运算耗时,应该能让你专心地想静静……
现在,我们希望能通过xgboost工具更好地解决这个问题。xgboost的全称是eXtreme Gradient Boosting。正如其名,它是Gradient Boosting Machine的一个c++实现,作者为正在华盛顿大学研究机器学习的大牛陈天奇。他在研究中深感自己受制于现有库的计算速度和精度,因此在一年前开始着手搭建xgboost项目,并在去年夏天逐渐成型。xgboost最大的特点在于,它能够自动利用CPU的多线程进行并行,同时在算法上加以改进提高了精度。它的处女秀是Kaggle的希格斯子信号识别竞赛,因为出众的效率与较高的预测准确度在比赛论坛中引起了参赛选手的广泛关注,在1700多支队伍的激烈竞争中占有一席之地。随着它在Kaggle社区知名度的提高,最近也有队伍借助xgboost在比赛中夺得第一。
为了方便大家使用,陈天奇将xgboost封装成了python库。我有幸和他合作,制作了xgboost工具的R语言接口,并将其提交到了CRAN上。也有用户将其封装成了julia库。python和R接口的功能一直在不断更新,大家可以通过下文了解大致的功能,然后选择自己最熟悉的语言进行学习。
功能介绍
一、基础功能
首先,我们从github上安装这个包:
devtools::install_github('dmlc/xgboost',subdir='R-package')
动手时间到!第一步,运行下面的代码载入样例数据:
require(xgboost)
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
train <- agaricus.train
test <- agaricus.test
这份数据需要我们通过一些蘑菇的若干属性判断这个品种是否有毒。数据以1或0来标记某个属性存在与否,所以样例数据为稀疏矩阵类型:
class(train$data)
[1] "dgCMatrix"
attr(,"package")
[1] "Matrix"
不用担心,xgboost支持稀疏矩阵作为输入。下面就是训练模型的命令
bst <- xgboost(data = train$data, label = train$label, max.depth = 2, eta = 1,nround = 2, objective = "binary:logistic")
[0] train-error:0.046522
[1] train-error:0.022263
我们迭代了两次,可以看到函数输出了每一次迭代模型的误差信息。这里的数据是稀疏矩阵,当然也支持普通的稠密矩阵。如果数据文件太大不希望读进R中,我们也可以通过设置参数data = 'path_to_file'
使其直接从硬盘读取数据并分析。目前支持直接从硬盘读取libsvm格式的文件。
做预测只需要一句话:
pred <- predict(bst, test$data)
做交叉验证的函数参数与训练函数基本一致,只需要在原有参数的基础上设置nfold
:
cv.res <- xgb.cv(data = train$data, label = train$label, max.depth = 2, eta = 1, nround = 2, objective = "binary:logistic",
nfold = 5)
[0] train-error:0.046522+0.001102 test-error:0.046523+0.004410
[1] train-error:0.022264+0.000864 test-error:0.022266+0.003450
cv.res
train.error.mean train.error.std test.error.mean test.error.std
1: 0.046522 0.001102 0.046523 0.004410
2: 0.022264 0.000864 0.022266 0.003450
交叉验证的函数会返回一个data.table类型的结果,方便我们监控训练集和测试集上的表现,从而确定最优的迭代步数。
二、高速准确
上面的几行代码只是一个入门,使用的样例数据没法表现出xgboost高效准确的能力。xgboost通过如下的优化使得效率大幅提高:
- xgboost借助OpenMP,能自动利用单机CPU的多核进行并行计算。需要注意的是,Mac上的Clang对OpenMP的支持较差,所以默认情况下只能单核运行。
- xgboost自定义了一个数据矩阵类DMatrix,会在训练开始时进行一遍预处理,从而提高之后每次迭代的效率。
在尽量保证所有参数都一致的情况下,我们使用希格斯子竞赛的数据做了对照实验。
Model and Parameter | gbm | xgboost(1 thread) | xgboost(2 threads) | xgboost(4 threads) | xgboost(8 threads) |
---|---|---|---|---|---|
Time (in secs) | 761.48 | 450.22 | 102.41 | 44.18 | 34.04 |
以上实验使用的CPU是i7-4700MQ。python的sklearn速度与gbm相仿。如果想要自己对这个结果进行测试,可以在比赛的官方网站下载数据,并参考这份demo中的代码。
除了明显的速度提升外,xgboost在比赛中的效果也非常好。在这个竞赛初期,大家惊讶地发现R和python中的gbm竟然难以突破组织者预设的benchmark。而xgboost横空出世,用不到一分钟的训练时间便打入当时的top 10,引起了大家的兴趣与关注。准确度提升的主要原因在于,xgboost的模型和传统的GBDT相比加入了对于模型复杂度的控制以及后期的剪枝处理,使得学习出来的模型更加不容易过拟合。更多算法上的细节可以参考这份陈天奇给出的介绍性讲义。
三、进阶特征
除了速度快精度高,xgboost还有一些很有用的进阶特性。下面的“demo”链接对应着相应功能的简单样例代码。
- 只要能够求出目标函数的梯度和Hessian矩阵,用户就可以自定义训练模型时的目标函数。demo
- 允许用户在交叉验证时自定义误差衡量方法,例如回归中使用RMSE还是RMSLE,分类中使用AUC,分类错误率或是F1-score。甚至是在希格斯子比赛中的“奇葩”衡量标准AMS。demo
- 交叉验证时可以返回模型在每一折作为预测集时的预测结果,方便构建ensemble模型。demo
- 允许用户先迭代1000次,查看此时模型的预测效果,然后继续迭代1000次,最后模型等价于一次性迭代2000次。demo
- 可以知道每棵树将样本分类到哪片叶子上,facebook介绍过如何利用这个信息提高模型的表现。demo
- 可以计算变量重要性并画出树状图。demo
- 可以选择使用线性模型替代树模型,从而得到带L1+L2惩罚的线性回归或者logistic回归。demo
这些丰富的功能来源于对日常使用场景的总结,数据挖掘比赛需求以及许多用户给出的精彩建议。
四、未来计划
现在,机器学习工具在实用中会不可避免地遇到“单机性能不够”的问题。目前,xgboost的多机分布式版本正在开发当中。基础设施搭建完成之日,便是新一轮R包开始设计与升级之时。
结语
我为xgboost制作R接口的目的就是希望引进好的工具,让大家使用R的时候心情更愉悦。总结下来,xgboost的特点有三个:速度快,效果好,功能多,希望它能受到大家的喜爱,成为一驾新的马车。
xgboost功能较多,参数设置比较繁杂,希望在上手之后有更全面了解的读者可以参考项目wiki。欢迎大家多多交流,在项目issue区提出疑问与建议。我们也邀请有兴趣的读者提交代码完善功能,让xgboost成为更好用的工具。
另外,在使用github开发的过程中,我深切地感受到了协作写代码带来的变化。一群人在一起的时候,可以写出更有效率的代码,在丰富的使用场景中发现新的需求,在极端情况发现隐藏很深的bug,甚至在主代码手拖延症较为忙碌的时候有人挺身而出拿下一片issue。这样的氛围,能让一个语言社区的交流丰富起来,从而充满生命力地活跃下去。
发表/查看评论