上QQ阅读APP看书,第一时间看更新
2.3 示例:XGBoost告诉你蘑菇是否有毒
本节通过一个简单的示例,介绍如何使用XGBoost解决机器学习问题。该示例使用的是XGBoost自带的数据集(位于/demo/data文件夹下),该数据集描述的是不同蘑菇的相关特征,如大小、颜色等,并且每一种蘑菇都会被标记为可食用的(标记为0)或有毒的(标记为1)。我们的任务是对蘑菇特征数据进行学习,训练相关模型,然后利用训练好的模型预测未知的蘑菇样本是否有毒。下面用XGBoost解决该问题,代码如下:
In [102]: import xgboost as xgb # 数据读取 xgb_train = xgb.DMatrix('${XGBOOST_PATH}/demo/data/agaricus.txt.train ') xgb_test = xgb.DMatrix('${XGBOOST_PATH}/demo/data/agaricus.txt.test ') # 定义模型训练参数 params = { "objective": "binary:logistic", "booster": "gbtree", "max_depth": 3 } # 训练轮数 num_round = 5 # 训练过程中实时输出评估结果 watchlist = [(xgb_train, 'train'), (xgb_test, 'test')] # 模型训练 model = xgb.train(params, xgb_train, num_round, watchlist)
首先读取训练集数据和测试集数据(其中${XGBOOST_PATH}代表XGBoost的根目录路径),XGBoost会将数据加载为自定义的矩阵DMatrix。数据加载完毕后,定义模型训练参数(后续章节会详细介绍这些参数表示的意义),然后对模型进行训练,训练过程的输出如图2-13所示。
图2-13 训练过程的输出
由图2-13可以看到,XGBoost训练过程中实时输出了训练集和测试集的错误率评估结果。随着训练的进行,训练集和测试集的错误率均在不断下降,说明模型对于特征数据的学习是十分有效的。最后,模型训练完毕后,即可通过训练好的模型对测试集数据进行预测。预测代码如下:
In [103]: # 模型预测 preds = model.predict(xgb_test) preds Out [103]: array([ 0.10455427, 0.80366629, 0.10455427, ..., 0.89609396, 0.10285233, 0.89609396], dtype=float32)
可以看到,预测结果为一个浮点数的数组,数组的大小和测试集的样本数量是一致的。数组中的值均在0~1区间内,每个值对应一个样本。该值可以看作模型对该样本的预测概率,即模型认为该蘑菇是有毒蘑菇的概率。