XGBoost Iris Classification Hello World
Goal
Use XGBoost to classify the classic Iris dataset, covering training, cross-validation, and feature-importance visualization.
Full code (save as xgboost_iris.py)
import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
# ─── 1. Load data ───
iris = load_iris()
X, y = iris.data, iris.target
# Stratified 80/20 split keeps the three classes balanced in both sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
# ─── 2. Train the model ───
model = xgb.XGBClassifier(
    n_estimators=100,
    max_depth=4,
    learning_rate=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    objective="multi:softprob",
    random_state=42,
    eval_metric="mlogloss",
)
# eval_set only records mlogloss per boosting round; no early stopping is configured
model.fit(
    X_train, y_train,
    eval_set=[(X_train, y_train), (X_test, y_test)],
    verbose=False,
)
# ─── 3. Evaluate ───
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {acc:.4f}")
print("\nClassification report:\n", classification_report(y_test, y_pred, target_names=iris.target_names))
# ─── 4. 5-fold cross-validation ───
# cross_val_score clones the estimator and refits it on each fold
cv_scores = cross_val_score(model, X, y, cv=5)
print(f"5-fold CV mean accuracy: {cv_scores.mean():.4f} (±{cv_scores.std():.4f})")
# ─── 5. Feature importance ───
xgb.plot_importance(model, importance_type="gain")
plt.title("XGBoost Feature Importance (Gain)")
plt.tight_layout()
plt.savefig("feature_importance.png")
plt.show()
# feature_importances_ is normalized so the scores sum to 1
print("\nFeature importance (Gain):")
for name, score in zip(iris.feature_names, model.feature_importances_):
    print(f"  {name}: {score:.4f}")
How to run
pip install xgboost scikit-learn matplotlib
python xgboost_iris.py
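Before running the script it can help to confirm that the three packages are actually installed and which versions were picked up. The sketch below uses only the standard library; the helper name installed_versions is hypothetical and not part of the tutorial.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as used in the pip install command above.
REQUIRED = ("xgboost", "scikit-learn", "matplotlib")


def installed_versions(packages=REQUIRED):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


versions = installed_versions()
for name, ver in versions.items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```

Recording these versions alongside the results makes it easier to explain small differences in the reported numbers.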
Expected output (exact numbers depend on the installed xgboost and scikit-learn versions)
Test accuracy: 0.9667

Classification report:
               precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       0.91      1.00      0.95        10
   virginica       1.00      0.90      0.95        10

    accuracy                           0.97        30

5-fold CV mean accuracy: 0.9533 (±0.0400)

Feature importance (Gain):
  sepal length (cm): 0.1234
  sepal width (cm): 0.0876
  petal length (cm): 0.4567
  petal width (cm): 0.3323