OpenCV python sklearn随机超参数搜索的实现

yipeiwu_com6年前Python基础

本文介绍了OpenCV python sklearn随机超参数搜索的实现,分享给大家,具体如下:

"""
房价预测数据集 使用sklearn执行超参数搜索
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import tensorflow as tf
from tensorflow_core.python.keras.api._v2 import keras # 不能使用 python
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from scipy.stats import reciprocal

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

# 0.打印导入模块的版本
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, sklearn, pd, tf, keras:
  print("%s version:%s" % (module.__name__, module.__version__))


# 显示学习曲线
def plot_learning_curves(his):
  pd.DataFrame(his.history).plot(figsize=(8, 5))
  plt.grid(True)
  plt.gca().set_ylim(0, 1)
  plt.show()


# 1.加载数据集 california 房价
housing = fetch_california_housing()

print(housing.DESCR)
print(housing.data.shape)
print(housing.target.shape)

# 2.拆分数据集 训练集 验证集 测试集
x_train_all, x_test, y_train_all, y_test = train_test_split(
  housing.data, housing.target, random_state=7)
x_train, x_valid, y_train, y_valid = train_test_split(
  x_train_all, y_train_all, random_state=11)

print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
print(x_test.shape, y_test.shape)

# 3.数据集归一化
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.fit_transform(x_valid)
x_test_scaled = scaler.fit_transform(x_test)


# 创建keras模型
def build_model(hidden_layers=1, # 中间层的参数
        layer_size=30,
        learning_rate=3e-3):
  # 创建网络层
  model = keras.models.Sequential()
  model.add(keras.layers.Dense(layer_size, activation="relu",
                 input_shape=x_train.shape[1:]))
 # 隐藏层设置
  for _ in range(hidden_layers - 1):
    model.add(keras.layers.Dense(layer_size,
                   activation="relu"))
  model.add(keras.layers.Dense(1))

  # 优化器学习率
  optimizer = keras.optimizers.SGD(lr=learning_rate)
  model.compile(loss="mse", optimizer=optimizer)

  return model


def main():
  # RandomizedSearchCV

  # 1.转化为sklearn的model
  sk_learn_model = keras.wrappers.scikit_learn.KerasRegressor(build_model)

  callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]

  history = sk_learn_model.fit(x_train_scaled, y_train, epochs=100,
                 validation_data=(x_valid_scaled, y_valid),
                 callbacks=callbacks)
  # 2.定义超参数集合
  # f(x) = 1/(x*log(b/a)) a <= x <= b
  param_distribution = {
    "hidden_layers": [1, 2, 3, 4],
    "layer_size": np.arange(1, 100),
    "learning_rate": reciprocal(1e-4, 1e-2),
  }

  # 3.执行超搜索参数
  # cross_validation:训练集分成n份, n-1训练, 最后一份验证.
  random_search_cv = RandomizedSearchCV(sk_learn_model, param_distribution,
                     n_iter=10,
                     cv=3,
                     n_jobs=1)
  random_search_cv.fit(x_train_scaled, y_train, epochs=100,
             validation_data=(x_valid_scaled, y_valid),
             callbacks=callbacks)
  # 4.显示超参数
  print(random_search_cv.best_params_)
  print(random_search_cv.best_score_)
  print(random_search_cv.best_estimator_)

  model = random_search_cv.best_estimator_.model
  print(model.evaluate(x_test_scaled, y_test))

  # 5.打印模型训练过程
  plot_learning_curves(history)


if __name__ == '__main__':
  main()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Django ORM 常用字段与不常用字段汇总

Django ORM 常用字段与不常用字段汇总

常用字段 AutoField: int 自增列,必须填入参数 primary_key=True 如果没有写 AutoField,则会自动创建一个列名为 id 的列 from dja...

pygame实现贪吃蛇游戏(上)

pygame实现贪吃蛇游戏(上)

本文实例为大家分享了pygame贪吃蛇游戏的具体代码,供大家参考,具体内容如下 1.准备工作 我们已经初始化了一个400*400的界面,为方便看我们的游戏,我们先在界面上画40*40的格...

python处理文本文件并生成指定格式的文件

import os import sys import string #以指定模式打开指定文件,获取文件句柄 def getFileIns(filePath,model)...

Python随机生成彩票号码的方法

本文实例讲述了Python随机生成彩票号码的方法。分享给大家供大家参考。具体如下: 前些日子在淘宝上买了一阵子彩票,每次都是使用淘宝的机选,每次一注。后来觉得不如自己写一个机选的程序有意...

Python通过future处理并发问题

Python通过future处理并发问题

future初识 通过下面脚本来对future进行一个初步了解: 例子1:普通通过循环的方式 import os import time import sys import re...