
Python28-11 CatBoost gradient boosting algorithm

2024-07-12



CatBoost is a machine learning algorithm based on gradient boosting, developed by Yandex (a Russian Internet company whose search engine once held more than 60% of the Russian market and which also provides other Internet products and services). CatBoost is particularly good at handling categorical features and is designed to reduce overfitting and target (data) leakage. The name CatBoost stands for "Categorical Boosting"; the algorithm was originally designed to perform well on data containing many categorical features.

Features of CatBoost

  1. Handling categorical features: CatBoost can directly process categorical features without additional encoding (such as one-hot encoding).

  2. Reducing overfitting: CatBoost uses a new method for processing categorical features that effectively reduces overfitting.

  3. Efficiency: CatBoost performs well in both training speed and prediction speed.

  4. Supports CPU and GPU training: CatBoost can run on the CPU or use the GPU for accelerated training.

  5. Automatic handling of missing values: CatBoost handles missing values in numerical features automatically, without additional preprocessing steps.
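For contrast with item 1, this is the kind of manual preprocessing step that CatBoost makes unnecessary. The minimal pandas sketch below one-hot encodes a string-typed `origin` column on three rows (only the first `weight` value comes from the dataset preview; the others are made up for illustration). With CatBoost, the column would instead stay as a single column and be declared as categorical through the `cat_features` parameter.

```python
import pandas as pd

# A few illustrative rows; origin is stored as strings ("1" = USA, "2" = Europe, "3" = Japan)
df = pd.DataFrame({
    "weight": [3504.0, 2372.0, 2130.0],
    "origin": ["1", "2", "3"],
})

# Manual one-hot encoding: one new 0/1 column per distinct origin value
one_hot = pd.get_dummies(df, columns=["origin"], prefix="origin")
print(list(one_hot.columns))
# ['weight', 'origin_1', 'origin_2', 'origin_3']
```

The number of one-hot columns grows with the number of distinct categories, which is one reason native categorical handling is attractive for high-cardinality features.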

The core principle of CatBoost

The core principle of CatBoost is based on the gradient boosted decision tree (GBDT), but it has made innovations in handling categorical features and avoiding overfitting. The following are some key technical points:

  1. Categorical feature processing

    • CatBoost replaces each categorical value with a target statistic, i.e. an estimate of the average target value for that category (a form of “mean encoding”, also known as “target encoding”).

    • Computing this average over the whole training set would include each example's own target value and leak it into the feature. CatBoost reduces this risk by computing the statistic for each example only from the examples that precede it in a random permutation of the training data, smoothed toward a prior value.

    • As a result, no example's encoded value is computed directly from its own target.

  2. Ordered Boosting

    • To prevent target leakage and overfitting, CatBoost processes the training data in a fixed random order.

    • Ordered boosting randomly permutes the training data and, for each example, estimates its gradient (residual) with a model built only from the examples that precede it in that permutation. For a given example the model only "sees the past" and never uses that example's own target ("future information") to make decisions about it.

  3. Computational Optimization

    • CatBoost speeds up the feature calculation process by pre-computation and caching.

    • It supports CPU and GPU training and can perform well on large-scale datasets.
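The two ordering ideas above (target statistics for categorical features, and ordered boosting) can be illustrated with a small plain-Python sketch. It is a simplified illustration under stated assumptions, not CatBoost's implementation: `ordered_target_statistics`, `greedy_target_mean` and `ordered_residuals` are hypothetical helpers, the smoothing weight `a` and the `prior` are free parameters, and the "model" in `ordered_residuals` is just the running mean of earlier targets rather than a boosted tree ensemble. CatBoost itself uses several permutations and approximates the "model trained on earlier examples" efficiently instead of literally training one model per prefix.

```python
import random


def ordered_target_statistics(categories, targets, prior, a=1.0, order=None, seed=0):
    """Encode each row using only rows that come earlier in a permutation.

    value = (sum of earlier targets in the same category + a * prior)
            / (count of earlier rows in the same category + a)
    """
    n = len(categories)
    if order is None:
        order = list(range(n))
        random.Random(seed).shuffle(order)  # random permutation of the rows
    sums, counts = {}, {}
    encoded = [0.0] * n
    for idx in order:
        cat = categories[idx]
        s, k = sums.get(cat, 0.0), counts.get(cat, 0)
        encoded[idx] = (s + a * prior) / (k + a)  # row idx's own target is not used yet
        sums[cat] = s + targets[idx]
        counts[cat] = k + 1
    return encoded


def greedy_target_mean(categories, targets):
    """Naive target encoding: category mean over ALL rows, including the row itself."""
    sums, counts = {}, {}
    for cat, t in zip(categories, targets):
        sums[cat] = sums.get(cat, 0.0) + t
        counts[cat] = counts.get(cat, 0) + 1
    return [sums[c] / counts[c] for c in categories]


def ordered_residuals(targets, order):
    """Residual of each row w.r.t. a toy model fit only on rows earlier in `order`.

    The toy model predicts the mean of the targets seen so far (0.0 before any row).
    """
    residuals = [0.0] * len(targets)
    running_sum, seen = 0.0, 0
    for idx in order:
        prediction = running_sum / seen if seen else 0.0  # never uses targets[idx]
        residuals[idx] = targets[idx] - prediction
        running_sum += targets[idx]
        seen += 1
    return residuals


cats, labels = ["A", "B", "A", "A"], [1, 0, 0, 1]
print(ordered_target_statistics(cats, labels, prior=0.5, order=[0, 1, 2, 3]))
# [0.5, 0.5, 0.75, 0.5]
print(greedy_target_mean(cats, labels))
# [0.6666666666666666, 0.0, 0.6666666666666666, 0.6666666666666666]
print(ordered_residuals([2.0, 4.0, 6.0], order=[0, 1, 2]))
# [2.0, 2.0, 3.0]
```

In the greedy encoding, category B occurs only once, so its encoded value (0.0) is exactly its own label: the feature leaks the target. In the ordered version, every value is computed before the row's own target is added to the running statistics, and the same holds for the residuals.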

Basic usage of CatBoost

The following is a basic example of using CatBoost for a regression task. We use the Auto MPG (Miles Per Gallon) dataset, a classic regression dataset commonly used in machine learning and statistical analysis. It records the fuel efficiency (miles per gallon) of different car models together with several related features.

Dataset characteristics:

  • mpg: Miles per gallon (target variable).

  • cylinders: Number of cylinders in the engine.

  • displacement: Engine displacement (cubic inches).

  • horsepower: Engine power (horsepower).

  • weight: Vehicle weight in pounds.

  • acceleration: Acceleration time from 0 to 60 mph (seconds).

  • model_year: The model year of the vehicle, given as two digits (e.g. 70 = 1970).

  • origin: Vehicle origin (1=United States, 2=Europe, 3=Japan).

The first few rows of the dataset:

        mpg  cylinders  displacement  horsepower  weight  acceleration  model_year  origin
    0  18.0          8         307.0       130.0  3504.0          12.0          70       1
    1  15.0          8         350.0       165.0  3693.0          11.5          70       1
    2  18.0          8         318.0       150.0  3436.0          11.0          70       1
    3  16.0          8         304.0       150.0  3433.0          12.0          70       1
    4  17.0          8         302.0       140.0  3449.0          10.5          70       1

Code example:

    import pandas as pd  # Pandas, for data handling
    import numpy as np  # NumPy, for numerical computation
    from sklearn.model_selection import train_test_split  # splits the data into training and test sets
    from sklearn.metrics import mean_squared_error, mean_absolute_error  # MSE and MAE, for evaluating the model
    from catboost import CatBoostRegressor  # CatBoost's regressor, for regression tasks
    import matplotlib.pyplot as plt  # Matplotlib, for plotting
    import seaborn as sns  # Seaborn, for statistical plots
    # Set NumPy's random seed so results are reproducible
    np.random.seed(42)
    # Load the Auto MPG dataset from the UCI Machine Learning Repository
    url = "http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data"
    column_names = ['mpg', 'cylinders', 'displacement', 'horsepower', 'weight', 'acceleration', 'model_year', 'origin']
    # '?' marks missing values; comment='\t' drops the tab-separated car name at the end of each line
    data = pd.read_csv(url, names=column_names, na_values='?', comment='\t', sep=' ', skipinitialspace=True)
    # Show the first few rows of the dataset
    print(data.head())
    # Handle missing values (drop incomplete rows)
    data = data.dropna()
    # Features and target variable
    X = data.drop('mpg', axis=1)  # feature variables
    y = data['mpg']  # target variable
    # Convert the categorical features to strings; CatBoost handles them natively
    # once they are declared through cat_features (see model.fit below)
    cat_features = ['cylinders', 'model_year', 'origin']
    X['cylinders'] = X['cylinders'].astype(str)
    X['model_year'] = X['model_year'].astype(str)
    X['origin'] = X['origin'].astype(str)
    # Split the dataset into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    # Define the CatBoost regressor
    model = CatBoostRegressor(
        iterations=1000,  # maximum number of boosting iterations
        learning_rate=0.1,  # learning rate
        depth=6,  # tree depth
        loss_function='RMSE',  # loss function
        verbose=100  # print training progress every 100 iterations
    )
    # Train the model; stop early if the eval-set RMSE does not improve for 50 iterations.
    # Note: the test set also serves as the early-stopping eval set here, so the
    # test metrics below are slightly optimistic.
    model.fit(X_train, y_train, cat_features=cat_features,
              eval_set=(X_test, y_test), early_stopping_rounds=50)
    # Make predictions
    y_pred = model.predict(X_test)
    # Evaluate model performance
    mse = mean_squared_error(y_test, y_pred)  # mean squared error
    mae = mean_absolute_error(y_test, y_pred)  # mean absolute error
    # Print the evaluation results
    print(f'Mean Squared Error (MSE): {mse:.4f}')
    print(f'Mean Absolute Error (MAE): {mae:.4f}')
    # Plot true values against predictions
    plt.figure(figsize=(10, 6))
    plt.scatter(y_test, y_pred, alpha=0.5)  # scatter plot
    plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], '--k')  # y = x reference line
    plt.xlabel('True Values')  # x-axis label
    plt.ylabel('Predictions')  # y-axis label
    plt.title('True Values vs Predictions')  # plot title
    plt.show()
    # Visualize feature importances
    feature_importances = model.get_feature_importance()  # importance of each feature
    feature_names = X.columns  # feature names, in training-column order
    plt.figure(figsize=(10, 6))
    sns.barplot(x=feature_importances, y=feature_names)  # horizontal bar chart of importances
    plt.title('Feature Importances')  # plot title
    plt.show()
    # Output
    '''
        mpg  cylinders  displacement  horsepower  weight  acceleration
    0  18.0          8         307.0       130.0  3504.0          12.0
    1  15.0          8         350.0       165.0  3693.0          11.5
    2  18.0          8         318.0       150.0  3436.0          11.0
    3  16.0          8         304.0       150.0  3433.0          12.0
    4  17.0          8         302.0       140.0  3449.0          10.5
       model_year  origin
    0          70       1
    1          70       1
    2          70       1
    3          70       1
    4          70       1
    0: learn: 7.3598113 test: 6.6405869 best: 6.6405869 (0) total: 1.7ms remaining: 1.69s
    100: learn: 1.5990203 test: 2.3207830 best: 2.3207666 (94) total: 132ms remaining: 1.17s
    200: learn: 1.0613606 test: 2.2319632 best: 2.2284239 (183) total: 272ms remaining: 1.08s
    Stopped by overfitting detector  (50 iterations wait)
    bestTest = 2.21453232
    bestIteration = 238
    Shrink model to first 239 iterations.
    Mean Squared Error (MSE): 4.9042
    Mean Absolute Error (MAE): 1.6381
    <Figure size 1000x600 with 1 Axes>
    <Figure size 1000x600 with 1 Axes>
    '''

Mean Squared Error (MSE): the average of the squared differences between the predicted and actual values. The smaller the value, the better the model performs. Here the MSE is 4.9042.

Mean Absolute Error (MAE): the average of the absolute differences between the predicted and actual values, expressed in the same unit as the target (MPG). The smaller the value, the better the model performs. Here the MAE is 1.6381, i.e. the predictions are off by about 1.64 MPG on average.
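Both definitions are easy to check by hand. The sketch below computes them on three made-up MPG pairs; the helpers `mse` and `mae` are written out for illustration only (the article's code uses scikit-learn's `mean_squared_error` and `mean_absolute_error`, which compute the same quantities). It also takes the square root of the reported MSE: RMSE = √4.9042 ≈ 2.2145 MPG, which agrees with `bestTest = 2.21453232` in the training log. That is expected here, because the test set doubled as the evaluation set and the model was shrunk to its best iteration.

```python
import math


def mse(y_true, y_pred):
    """Mean squared error: average of squared differences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true, y_pred):
    """Mean absolute error: average of absolute differences."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Three illustrative (made-up) true/predicted MPG pairs; errors are 1, -2 and -0.5
y_true = [18.0, 15.0, 18.0]
y_pred = [17.0, 17.0, 18.5]
print(mse(y_true, y_pred))            # 1.75   = (1 + 4 + 0.25) / 3
print(round(mae(y_true, y_pred), 4))  # 1.1667 = (1 + 2 + 0.5) / 3

# RMSE is the square root of MSE and is in MPG units; for the MSE reported above:
print(round(math.sqrt(4.9042), 4))    # 2.2145
```

Because MSE squares the errors, a single large miss raises it much more than it raises MAE, which is why the two metrics are worth reporting together.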

[Figure: True Values vs Predictions]

  1. Scatter plot: Each point in the figure represents one test sample. The horizontal axis shows the sample's true MPG and the vertical axis shows the model's predicted MPG.

  2. Diagonal: The black dashed line is the 45-degree line y = x, which represents ideal predictions, i.e. the predicted value equals the true value.

  3. Distribution of points:

    • Close to the diagonal: the predicted value is very close to the true value, so the prediction is accurate.

    • Far from the diagonal: there is a large gap between the predicted and true values, so the prediction is inaccurate.

Most of the points are concentrated near the diagonal, which indicates that the model predicts well; a few points lie farther from it, meaning that the predictions for those samples differ noticeably from their true values.
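Whether a point is "far" from the diagonal can also be checked numerically: a point (true, predicted) lies on the line y = x exactly when the two values are equal, so |predicted - true| is its vertical distance from the diagonal. The hypothetical helper `far_from_diagonal` below flags samples whose absolute error exceeds a chosen threshold; the data and the 3.0 MPG threshold are made up for illustration, and in practice the threshold is a judgment call (for example, a few times the MAE).

```python
def far_from_diagonal(y_true, y_pred, threshold):
    """Return positions of samples whose prediction misses the true value by more than `threshold`.

    |pred - true| is the vertical distance of the point (true, pred) from the line y = x.
    """
    return [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if abs(p - t) > threshold]


# Illustrative (made-up) values, in MPG; absolute errors are 0.5, 0.8, 6.0 and 0.3
y_true = [18.0, 15.0, 31.0, 24.0]
y_pred = [17.5, 15.8, 25.0, 24.3]
print(far_from_diagonal(y_true, y_pred, threshold=3.0))  # [2]
```

With the article's variables, the same check would be `far_from_diagonal(y_test.to_numpy(), y_pred, threshold=3.0)`, which returns positions within the test set; those rows of `X_test` can then be inspected to see what the poorly predicted cars have in common.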

[Figure: Feature Importances]

  1. Bar chart: Each bar represents the importance of one feature in the model. The longer the bar, the larger the feature's contribution to the model's predictions.

  2. Feature names: The names of all features are listed on the y-axis.

  3. Feature importance values: The relative importance of each feature is shown on the x-axis.

From the figure we can see:

  1. model_year: The most important feature, indicating that a car's model year has a large influence on its predicted fuel efficiency.

  2. weight: The car's weight is the second most important feature and also has a significant impact on fuel efficiency.

  3. displacement and horsepower: The engine's displacement and power also contribute substantially to the fuel-efficiency predictions.
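The ranking read off the bar chart can also be produced numerically by pairing the values returned by `model.get_feature_importance()` with `X.columns` and sorting them. The hypothetical helper `rank_importances` below does this with pandas; the importance values are placeholders chosen for illustration only and are not the values behind the figure.

```python
import pandas as pd


def rank_importances(names, importances):
    """Return importances as a Series indexed by feature name, largest first."""
    return pd.Series(importances, index=names).sort_values(ascending=False)


# Placeholder values for illustration only (NOT the values from the figure)
names = ["cylinders", "displacement", "horsepower", "weight",
         "acceleration", "model_year", "origin"]
values = [5.0, 15.0, 14.0, 20.0, 6.0, 36.0, 4.0]

ranked = rank_importances(names, values)
print(ranked.index.tolist())
# ['model_year', 'weight', 'displacement', 'horsepower', 'acceleration', 'cylinders', 'origin']
```

With the article's model, `ranked = rank_importances(X.columns, model.get_feature_importance())` followed by `sns.barplot(x=ranked.values, y=ranked.index)` would draw the same chart with the bars sorted from most to least important, which makes the ranking easier to read.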

In this example, we applied CatBoost to the Auto MPG dataset to build a regression model that predicts a car's fuel efficiency (miles per gallon, MPG).

The above content is summarized from the Internet. If it is helpful, please forward it. See you next time!