如何調整參數

在上一章最後提到,更改gamma值導致overfitting,那就竟gamma要調成多少呢?我們來看看

from sklearn.learning_curve import  validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np

digits = load_digits()
X = digits.data
y = digits.target

param_range = np.logspace(-6, -2.3, 5)
train_loss, test_loss = validation_curve(
        SVC(), X, y, param_name='gamma', param_range=param_range, cv=10,
        scoring='mean_squared_error')

train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

plt.plot(param_range, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(param_range, test_loss_mean, 'o-', color="g",label="Cross-validation")


plt.xlabel("gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

跟上一章的程式碼很像,只是觀察的變數面變成gamma,我們將取-6到-2.3之間的5個值當作測試範圍並將這些結果繪製成圖從圖中可以知道,gamma在0.0005之後會有overfitting的問題,所以在選擇gamma時,因該在0.0003到0.00005左右

results matching ""

    No results matching ""