import torch
import torch.nn as nn
class SpatioTemporalCSIPredictor(nn.Module):
def __init__(self, input_seq_len=8, pred_horizon=3, spatial_dim=64):
super().__init__()
self.temporal_encoder = nn.LSTM(input_size=spatial_dim, hidden_size=128, num_layers=2, batch_first=True)
self.spatial_conv = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.attention = nn.MultiheadAttention(embed_dim=128, num_heads=8)
self.decoder = nn.Linear(128, spatial_dim)
def forward(self, x):
# x: [B, T, H, W] CSI tensor over time
B, T, H, W = x.shape
x_flat = x.view(B, T, -1) # Flatten spatial dimensions
# Temporal modeling
lstm_out, _ = self.temporal_encoder(x_flat) # [B, T, 128]
# Self-attention across time steps
attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
# Predict future CSI
pred = self.decoder(attn_out[:, -1, :]) # Use last hidden state
return pred.view(B, H, W)
4. 系统级流程与决策闭环构建
graph TD
A[高速移动终端上报部分CSI] --> B{边缘AI推理节点}
B --> C[加载预训练LSTM-GAN混合模型]
C --> D[输入历史N个时隙CSI序列]
D --> E[输出未来M个时隙CSI预测值]
E --> F[计算瞬时信道容量C(t+τ)]
F --> G[动态AMC策略选择]
G --> H[调整MCS等级与PRB分配]
H --> I[下发调度指令至gNB]
I --> J[实际传输并采集误差反馈]
J --> K[在线微调预测模型参数]
K --> B