OpenCV python sklearn随机超参数搜索的实现

yipeiwu_com5年前Python基础

本文介绍了OpenCV python sklearn随机超参数搜索的实现,分享给大家,具体如下:

"""
房价预测数据集 使用sklearn执行超参数搜索
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import tensorflow as tf
from tensorflow_core.python.keras.api._v2 import keras # 不能使用 python
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from scipy.stats import reciprocal

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

# 0.打印导入模块的版本
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, sklearn, pd, tf, keras:
  print("%s version:%s" % (module.__name__, module.__version__))


# 显示学习曲线
def plot_learning_curves(his):
  pd.DataFrame(his.history).plot(figsize=(8, 5))
  plt.grid(True)
  plt.gca().set_ylim(0, 1)
  plt.show()


# 1.加载数据集 california 房价
housing = fetch_california_housing()

print(housing.DESCR)
print(housing.data.shape)
print(housing.target.shape)

# 2.拆分数据集 训练集 验证集 测试集
x_train_all, x_test, y_train_all, y_test = train_test_split(
  housing.data, housing.target, random_state=7)
x_train, x_valid, y_train, y_valid = train_test_split(
  x_train_all, y_train_all, random_state=11)

print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
print(x_test.shape, y_test.shape)

# 3.数据集归一化
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.fit_transform(x_valid)
x_test_scaled = scaler.fit_transform(x_test)


# 创建keras模型
def build_model(hidden_layers=1, # 中间层的参数
        layer_size=30,
        learning_rate=3e-3):
  # 创建网络层
  model = keras.models.Sequential()
  model.add(keras.layers.Dense(layer_size, activation="relu",
                 input_shape=x_train.shape[1:]))
 # 隐藏层设置
  for _ in range(hidden_layers - 1):
    model.add(keras.layers.Dense(layer_size,
                   activation="relu"))
  model.add(keras.layers.Dense(1))

  # 优化器学习率
  optimizer = keras.optimizers.SGD(lr=learning_rate)
  model.compile(loss="mse", optimizer=optimizer)

  return model


def main():
  # RandomizedSearchCV

  # 1.转化为sklearn的model
  sk_learn_model = keras.wrappers.scikit_learn.KerasRegressor(build_model)

  callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]

  history = sk_learn_model.fit(x_train_scaled, y_train, epochs=100,
                 validation_data=(x_valid_scaled, y_valid),
                 callbacks=callbacks)
  # 2.定义超参数集合
  # f(x) = 1/(x*log(b/a)) a <= x <= b
  param_distribution = {
    "hidden_layers": [1, 2, 3, 4],
    "layer_size": np.arange(1, 100),
    "learning_rate": reciprocal(1e-4, 1e-2),
  }

  # 3.执行超搜索参数
  # cross_validation:训练集分成n份, n-1训练, 最后一份验证.
  random_search_cv = RandomizedSearchCV(sk_learn_model, param_distribution,
                     n_iter=10,
                     cv=3,
                     n_jobs=1)
  random_search_cv.fit(x_train_scaled, y_train, epochs=100,
             validation_data=(x_valid_scaled, y_valid),
             callbacks=callbacks)
  # 4.显示超参数
  print(random_search_cv.best_params_)
  print(random_search_cv.best_score_)
  print(random_search_cv.best_estimator_)

  model = random_search_cv.best_estimator_.model
  print(model.evaluate(x_test_scaled, y_test))

  # 5.打印模型训练过程
  plot_learning_curves(history)


if __name__ == '__main__':
  main()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python while 循环使用的简单实例

while循环是在Python中的循环结构之一。 while循环继续,直到表达式变为假。表达的是一个逻辑表达式,必须返回一个true或false值,本文章向码农介绍Python whil...

解决PyCharm的Python.exe已经停止工作的问题

今天遇到一个问题,就是用pycharm运行python程序,老是会出现Python.exe已停止的对话框。后来我到处在网上搜原因,网上给出的解决办法也有很多种。最终帮我解决问题的就是:打...

python 常用的基础函数

Python: 1. print()函数:打印字符串 2. raw_input()函数:从用户键盘捕获字符 3. len()函数:计算字符长度 4. format(12.3654,'6....

Python运维自动化之nginx配置文件对比操作示例

Python运维自动化之nginx配置文件对比操作示例

本文实例讲述了Python运维自动化之nginx配置文件对比操作。分享给大家供大家参考,具体如下: 文件差异对比diff.py #!/usr/bin/env python # imp...

给Python的Django框架下搭建的BLOG添加RSS功能的教程

前些天有位网友建议我在博客中添加RSS订阅功能,觉得挺好,所以自己抽空看了一下如何在Django中添加RSS功能,发现使用Django中的syndication feed framewo...