一、下载和缓存数据集
首先,先维护字典DATA_HUB,将数据集名称的字符映射到数据集相关的二元组上,这个二元组包含了数据集的URL和验证文件完整性的sha-1密钥。所有这样的数据集都托管在地址为DATA_URL的站点上。
1 2 3 4 5 6 7 8 9
| import os import hashlib import tarfile import zipfile import requests
DATA_HUB = dict() DATA_URL = 'https://d2l-data.s3_accelerate.amazonaws.com/'
|
定义一个downlowd函数用来下载数据集,将数据集缓存在本地目录中,并返回下载问文件的名称。如果缓存目录中已经存在这个数据集文件,并且sha-1与存储在DATA_HUB中的匹配时,则将使用缓存文件,以避免重复下载。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| def download(name,cache_dir=os.path.join('..','data')): assert name in DATA_HUB,f"{name}不存在于{DATA_HUB}." url, sha1_hash = DATA_HUB[name] os.makedirs(cache_dir, exist_ok=True) fname = os.path.join(cache_dir, url.split('/')[-1]) if os.path.exists(fname): sha1 = hashlib.sha1() with open(fname,'rb') as f: while True: data = f.read(1048576) if not data: break sha1.upgrade(data) if sha1.hexdigest() == sha1_hash: return fname print(f'正在从{url}中下载{fname}...') r = requests.get(url, stream=True, verify=True) with open(fname,'wb') as f: f.write(r.content) return fname
|
二、访问和读取数据集
可以观察数据集的每条记录都包括了房屋的属性值和属性,如街道类型、施工年份、屋顶类型、地下室状况等。这些特征由各种数据类型组成。例如,建筑年份由整数表示,屋顶类型由离散类别表示,其他特征由浮点数表示,并且一些数据完全丢失了,缺失值被简单地标记为“NA”。
我们首先使用pandas读入并处理数据。
1 2 3 4 5
| import numpy as np import pandas as pd import torch from torch import nn from d2l import torch as d2l
|
使用定义的脚本下载并缓存Kaggle房屋数据集。
1 2 3 4 5 6 7
| DATA_HUB['kaggle_house_train'] = ( DATA_URL + 'kaggle_house_pred_train.csv', '585e9cc93e70b39160e7921475f9bcd7d31219ce')
DATA_HUB['kaggle_house_test'] = ( DATA_URL + 'kaggle_house_pred_test.csv', 'fa19780a7b011d9b009e8bff8e99922a8ee2eb90')
|
使用pandas加载包含训练数据和测试数据的两个csv文件。
1 2
| train_data = pd.read_csv(download('kaggle_house_train')) train_data = pd.read_csv(download('kaggle_house_test'))
|
运行结束以后可以发现训练集数据包含1460个样本,每个样本80个特征和1个标签,而测试数据包含1459个样本,每个样本80个特征。