深入理解XGBoost:高效机器学习算法与进阶
上QQ阅读APP看书,第一时间看更新

2.3 示例:XGBoost告诉你蘑菇是否有毒

本节通过一个简单的示例,介绍如何使用XGBoost解决机器学习问题。该示例使用的是XGBoost自带的数据集(位于/demo/data文件夹下),该数据集描述的是不同蘑菇的相关特征,如大小、颜色等,并且每一种蘑菇都会被标记为可食用的(标记为0)或有毒的(标记为1)。我们的任务是对蘑菇特征数据进行学习,训练相关模型,然后利用训练好的模型预测未知的蘑菇样本是否有毒。下面用XGBoost解决该问题,代码如下:

In [102]: 
import xgboost as xgb
# 数据读取
xgb_train = xgb.DMatrix('${XGBOOST_PATH}/demo/data/agaricus.txt.train ')
xgb_test = xgb.DMatrix('${XGBOOST_PATH}/demo/data/agaricus.txt.test ')

# 定义模型训练参数
params = {
          "objective": "binary:logistic",
          "booster": "gbtree",
          "max_depth": 3
         }
# 训练轮数
num_round = 5

# 训练过程中实时输出评估结果
watchlist = [(xgb_train, 'train'), (xgb_test, 'test')]

# 模型训练
model = xgb.train(params, xgb_train, num_round, watchlist)

首先读取训练集数据和测试集数据(其中${XGBOOST_PATH}代表XGBoost的根目录路径),XGBoost会将数据加载为自定义的矩阵DMatrix。数据加载完毕后,定义模型训练参数(后续章节会详细介绍这些参数表示的意义),然后对模型进行训练,训练过程的输出如图2-13所示。

图2-13 训练过程的输出

由图2-13可以看到,XGBoost训练过程中实时输出了训练集和测试集的错误率评估结果。随着训练的进行,训练集和测试集的错误率均在不断下降,说明模型对于特征数据的学习是十分有效的。最后,模型训练完毕后,即可通过训练好的模型对测试集数据进行预测。预测代码如下:

In [103]: 
# 模型预测
preds = model.predict(xgb_test)
preds
Out [103]: 
array([ 0.10455427,  0.80366629,  0.10455427, ...,  0.89609396,
        0.10285233,  0.89609396], dtype=float32)

可以看到,预测结果为一个浮点数的数组,数组的大小和测试集的样本数量是一致的。数组中的值均在0~1区间内,每个值对应一个样本。该值可以看作模型对该样本的预测概率,即模型认为该蘑菇是有毒蘑菇的概率。