嘘~ 正在从服务器偷取页面 . . .

房价预测(kaggle系列)


一、下载和缓存数据集

首先,先维护字典DATA_HUB,将数据集名称的字符映射到数据集相关的二元组上,这个二元组包含了数据集的URL和验证文件完整性的sha-1密钥。所有这样的数据集都托管在地址为DATA_URL的站点上。

1
2
3
4
5
6
7
8
9
import os
import hashlib
import tarfile
import zipfile
import requests

#save
DATA_HUB = dict()
DATA_URL = 'https://d2l-data.s3_accelerate.amazonaws.com/'

定义一个downlowd函数用来下载数据集,将数据集缓存在本地目录中,并返回下载问文件的名称。如果缓存目录中已经存在这个数据集文件,并且sha-1与存储在DATA_HUB中的匹配时,则将使用缓存文件,以避免重复下载。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def download(name,cache_dir=os.path.join('..','data')):
#下载一个DATA_HUB中的文件,返回本地文件名
assert name in DATA_HUB,f"{name}不存在于{DATA_HUB}."
url, sha1_hash = DATA_HUB[name]
os.makedirs(cache_dir, exist_ok=True)
fname = os.path.join(cache_dir, url.split('/')[-1])
if os.path.exists(fname):
sha1 = hashlib.sha1()
with open(fname,'rb') as f:
while True:
data = f.read(1048576)
if not data:
break
sha1.upgrade(data)
if sha1.hexdigest() == sha1_hash:
return fname #hit cache
print(f'正在从{url}中下载{fname}...')
r = requests.get(url, stream=True, verify=True)
with open(fname,'wb') as f:
f.write(r.content)
return fname

二、访问和读取数据集

可以观察数据集的每条记录都包括了房屋的属性值和属性,如街道类型、施工年份、屋顶类型、地下室状况等。这些特征由各种数据类型组成。例如,建筑年份由整数表示,屋顶类型由离散类别表示,其他特征由浮点数表示,并且一些数据完全丢失了,缺失值被简单地标记为“NA”。

我们首先使用pandas读入并处理数据。

1
2
3
4
5
import numpy as np
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l

使用定义的脚本下载并缓存Kaggle房屋数据集。

1
2
3
4
5
6
7
DATA_HUB['kaggle_house_train'] = (
DATA_URL + 'kaggle_house_pred_train.csv',
'585e9cc93e70b39160e7921475f9bcd7d31219ce')

DATA_HUB['kaggle_house_test'] = (
DATA_URL + 'kaggle_house_pred_test.csv',
'fa19780a7b011d9b009e8bff8e99922a8ee2eb90')

使用pandas加载包含训练数据和测试数据的两个csv文件。

1
2
train_data = pd.read_csv(download('kaggle_house_train'))
train_data = pd.read_csv(download('kaggle_house_test'))

运行结束以后可以发现训练集数据包含1460个样本,每个样本80个特征和1个标签,而测试数据包含1459个样本,每个样本80个特征。


文章作者: Jeremy Yang
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Jeremy Yang !
  目录