Preface
自1950年以来,世界人口基本呈现线性增长.拿世界人口做线性回归的范例是相当合适的.
Prepare
我们通过wiki上的世界人口统计为数据集做处理.但是我们目前一般无法直接访问wiki,所以不妨通过这个网址来访问到我们需要的数据.wiki_population
我们通过pandas的read_html方法把数据读下来.
import pandas as pd
url = r'https://encyclopedia.thefreedictionary.com/World+population+estimates'
df = pd.read_html(url , header = 0 , attrs = {"class":"wikitable"})[2] # 这个网页的第二张表格是我们需要的
得到如下数据:
这里我选取美国人口调查局的调查结果作为样例(第二列).
注意这里的df是pandas的DataFrame,这是整张表格.我们把它换成张量并提取我们想要的前两列.
import torch
years = torch.tensor(df.iloc[:,0] , dtype=torch.float32)
populations=torch.tensor(df.iloc[:,1],dtype=torch.float32)
print(years)
print(populations)
如下所示
至此,准备工作轻松完成了.
可以验证一下是否基本呈现线性规律
from matplotlib import pyplot as plt
plt.figure(figsize=(20,8) , dpi = 80)
plt.scatter(years,populations)
plt.show()
因此做线性回归是合理的.
lstsq 方法求解
lstsq是最轻松的方法求解线性回归,代码量短小实用.
import torch
x = torch.stack([years,torch.ones_like(years)],1) # 拼接上多余的‘1’列
y = populations
wr,_ = torch.lstsq(y,x)
slope,intercept = wr[:2,0]
result = 'populations = {:.2e} * year + {:.2e}'.format(slope,intercept)
print("回归结果:"+result)
代码简洁整齐,比较受欢迎.
汇总以上代码
# 数据收集
import torch
import pandas as pd
url = r'https://encyclopedia.thefreedictionary.com/World+population+estimates'
df = pd.read_html(url , header = 0 , attrs = {"class":"wikitable"})[2]
# 处理
years = torch.tensor(df.iloc[:,0] , dtype=torch.float32)
populations=torch.tensor(df.iloc[:,1],dtype=torch.float32)
# 回归
x = torch.stack([years,torch.ones_like(years)],1)
y = populations
wr,_ = torch.lstsq(y,x)
slope,intercept = wr[:2,0]
result = 'populations = {:.2e} * year + {:.2e}'.format(slope,intercept)
print("回归结果:"+result)
当数据集已经过大以至于无法装入内存,不妨push到服务器上去测试,也可以换成求解极值的思想,借助Adam优化器实现,这样可以规避数据过多内存不够的问题.
Adam优化方法
Adam优化器做线性回归,代码稍长,但也比较直观.本处暂略,日后填坑.
– 后记 写分享的时候一定要及时保存下来,我因为没有及时保存导致我又得重写一次,额外花费了时间,所以时间紧张日后补充
???????????? 卧槽 牛逼牛逼牛逼