修改这段代码，预测 2004–2024 年的股票价格，并把 2004–2024 年预测的股票价格和实际的股票价格画到同一张图上进行对比。
import datetime

import matplotlib.pyplot as plt
import matplotlib.style as style
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
def parse_date(date_string: str) -> pd.Timestamp:
    """Convert an underscore-separated date string into a pandas Timestamp.

    Args:
        date_string: A date whose fields are separated by underscores,
            e.g. ``"2004_01_02"``. Underscores are swapped for hyphens
            before the string is handed to ``pd.Timestamp``.

    Returns:
        The parsed ``pd.Timestamp``.
    """
    return pd.Timestamp(date_string.replace('_', '-'))
# Location of the price history; must contain a 'Date' column and a 'Close' column.
CSV_PATH = 'D:/LSTMdata.csv'
# Share of the rows used as the forecast horizon (number of rows to look ahead).
FORECAST_FRACTION = 0.02


def load_prices(csv_path: str = CSV_PATH) -> pd.DataFrame:
    """Read the price CSV and return it indexed by date in ascending order.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A DataFrame whose index is a DatetimeIndex built by applying
        ``parse_date`` to every value of the 'Date' column.
    """
    prices = pd.read_csv(csv_path, index_col='Date')
    # read_csv(date_parser=...) is deprecated since pandas 2.0, so the index
    # is converted explicitly instead.
    prices.index = pd.DatetimeIndex(
        prices.index.map(parse_date), name=prices.index.name
    )
    return prices.sort_index()


def fit_and_forecast(df: pd.DataFrame, predict_count: int, random_state=None) -> tuple:
    """Fit a linear regression that predicts 'Close' ``predict_count`` rows ahead.

    Every column of ``df`` is standardised and used as a feature. The target
    of a row is the 'Close' value ``predict_count`` rows later, so the last
    ``predict_count`` rows have no target; they are the rows the forecast is
    produced from.

    Args:
        df: Price history sorted by date; all columns must be numeric.
        predict_count: Forecast horizon in rows.
        random_state: Forwarded to ``train_test_split``; ``None`` keeps the
            split random on every run.

    Returns:
        ``(score, predictions)`` where ``score`` is ``model.score`` on the
        held-out 20 % of rows and ``predictions`` is an array holding one
        forecast per row of the final ``predict_count`` rows.

    Raises:
        ValueError: If ``predict_count`` is not in ``[1, len(df))``; the
            slices below would otherwise be empty.
    """
    if not 1 <= predict_count < len(df):
        raise ValueError(
            f'predict_count must be between 1 and {len(df) - 1}, got '
            f'{predict_count}; the data set is too small for this horizon'
        )

    features = StandardScaler().fit_transform(df)
    target = df['Close'].shift(-predict_count).iloc[:-predict_count]
    known_features = features[:-predict_count]
    recent_features = features[-predict_count:]

    # NOTE: train_test_split shuffles rows by default, so the held-out rows
    # are interleaved in time with the training rows.
    x_train, x_test, y_train, y_test = train_test_split(
        known_features, target, test_size=0.2, random_state=random_state
    )
    model = LinearRegression()
    model.fit(x_train, y_train)
    return model.score(x_test, y_test), model.predict(recent_features)


def append_future_predictions(df: pd.DataFrame, predictions) -> pd.DataFrame:
    """Return ``df`` extended with one future row per predicted value.

    The new rows are dated on consecutive calendar days following the last
    index entry of ``df``. They carry the prediction in a 'predict' column
    and NaN in every other column; existing rows get NaN in 'predict'.
    ``df`` itself is not modified.

    Args:
        df: Frame with a non-empty DatetimeIndex sorted ascending.
        predictions: Sequence of predicted values, earliest first.

    Returns:
        A new DataFrame containing the original rows followed by the
        prediction rows.
    """
    values = np.asarray(predictions, dtype=float)
    if values.size == 0:
        unchanged = df.copy()
        unchanged['predict'] = np.nan
        return unchanged

    # Stay in pandas date arithmetic: converting through POSIX timestamps and
    # datetime.fromtimestamp() would shift the dates by the local UTC offset.
    first_day = df.index[-1] + datetime.timedelta(days=1)
    future_index = pd.date_range(
        start=first_day, periods=values.size, freq='D', name=df.index.name
    )
    future = pd.DataFrame({'predict': values}, index=future_index)
    # Columns missing on either side are filled with NaN by concat.
    return pd.concat([df, future])


def plot_forecast(df: pd.DataFrame) -> None:
    """Plot the 'Close' and 'predict' columns of ``df`` on one chart."""
    style.use('ggplot')
    df['Close'].plot()
    df['predict'].plot()
    plt.legend()
    plt.show()


def main() -> None:
    """Load the data, fit the model, report its score and plot the forecast."""
    df = load_prices()
    predict_count = int(len(df) * FORECAST_FRACTION)
    score, predictions = fit_and_forecast(df, predict_count)
    print(f'Score on held-out rows: {score:.4f}')
    plot_forecast(append_future_predictions(df, predictions))


if __name__ == '__main__':
    main()