LLM4TS_ad.py
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from transformers import AutoModel, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.models.llama.configuration_llama import LlamaConfig
from einops import rearrange
from layers.Embed import DataEmbedding, DataEmbedding_wo_time
from layers.RevIN import RevIN


class Model(nn.Module):
    def __init__(self, configs):
        super().__init__()
        self.is_llm = configs.is_llm
        self.pretrain = configs.pretrain
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.patch_len = configs.patch_len
        self.stride = configs.stride
        self.patch_num = (configs.seq_len - self.patch_len) // self.stride + 1
        self.d_ff = configs.d_ff
        self.revin = configs.revin
        if self.revin:
            # Optional reversible instance normalization (RevIN) over the input channels
            self.revin_layer = RevIN(configs.enc_in, affine=configs.affine, subtract_last=configs.subtract_last)
        if configs.is_llm:
            if configs.pretrain:
                # Load a pretrained backbone; keep attention and hidden states for inspection
                if "gpt2" in configs.llm:
                    self.gpt = GPT2Model.from_pretrained(configs.llm, output_attentions=True, output_hidden_states=True)
                elif "llama" in configs.llm:
                    self.gpt = LlamaModel.from_pretrained(configs.llm, output_attentions=True, output_hidden_states=True)
                else:
                    raise NotImplementedError
            else:
                # Train the backbone from randomly initialized weights
                print("------------------no pretrain------------------")
                if "gpt2" in configs.llm:
                    self.gpt = GPT2Model(GPT2Config())
                elif "llama" in configs.llm:
                    self.gpt = LlamaModel(LlamaConfig())
                else:
                    raise NotImplementedError

            # Keep only the first `llm_layers` transformer blocks of the backbone
            if "gpt2" in configs.llm:
                self.gpt.h = self.gpt.h[:configs.llm_layers]
                print("gpt2 = {}".format(self.gpt))
            elif "llama" in configs.llm:
                self.gpt.layers = self.gpt.layers[:configs.llm_layers]
                print("llama2 = {}".format(self.gpt))
            else:
                raise NotImplementedError
        self.in_layer = nn.Linear(self.patch_len, configs.d_model)
        self.ln_proj = nn.LayerNorm(configs.d_ff)
        self.out_layer = nn.Linear(configs.d_ff, configs.enc_in, bias=True)
        self.enc_embedding = DataEmbedding(configs.enc_in * self.patch_len, configs.d_model, configs.embed,
                                           configs.freq, configs.dropout)
        if configs.freeze:
            # Only parameters whose names contain one of the listed substrings stay trainable
            if configs.c_pt:
                layers_train = configs.pt_layers.split('_')
            elif configs.sft:
                layers_train = configs.sft_layers.split('_')
            else:
                raise NotImplementedError
            for i, (name, param) in enumerate(self.gpt.named_parameters()):
                tag = 0
                for layer_train in layers_train:
                    if layer_train in name:
                        tag = 1
                if tag:
                    param.requires_grad = True
                else:
                    param.requires_grad = False

        # Put the trainable sub-modules into training mode
        # (the RevIN layer only exists when configs.revin is set)
        layers = [self.gpt, self.in_layer, self.out_layer]
        if self.revin:
            layers.append(self.revin_layer)
        for layer in layers:
            layer.train()

    def forward(self, x_enc):
        B, L, M = x_enc.shape

        # Segment-wise normalization: split the series into segments of length 25 and
        # standardize each segment with its own mean and standard deviation
        seg_num = 25
        x_enc = rearrange(x_enc, 'b (n s) m -> b n s m', s=seg_num)
        means = x_enc.mean(2, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=2, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev
        x_enc = rearrange(x_enc, 'b n s m -> b (n s) m')

        # Zero-pad the channel dimension to the backbone's embedding width (768 for GPT-2)
        # and feed each time step directly as a token embedding
        enc_out = torch.nn.functional.pad(x_enc, (0, 768 - x_enc.shape[-1]))
        outputs = self.gpt(inputs_embeds=enc_out).last_hidden_state
        outputs = outputs[:, :, :self.d_ff]
        dec_out = self.out_layer(outputs)

        # De-normalization: restore each segment's original scale and mean
        dec_out = rearrange(dec_out, 'b (n s) m -> b n s m', s=seg_num)
        dec_out = dec_out * stdev[:, :, 0, :].unsqueeze(2).repeat(1, 1, seg_num, 1)
        dec_out = dec_out + means[:, :, 0, :].unsqueeze(2).repeat(1, 1, seg_num, 1)
        dec_out = rearrange(dec_out, 'b n s m -> b (n s) m')
        return dec_out
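

# --- Minimal usage sketch (not part of the original file) ---
# This block only illustrates how the Model above might be instantiated and called.
# The config field names mirror the attributes read in __init__; every value below is
# an illustrative assumption, not a setting taken from the repository's scripts.
# With pretrain=False no weights are downloaded, so the sketch runs offline.
if __name__ == "__main__":
    from types import SimpleNamespace

    configs = SimpleNamespace(
        is_llm=True, pretrain=False, llm="gpt2", llm_layers=6,   # randomly initialized GPT-2, first 6 blocks
        pred_len=0, seq_len=100, patch_len=1, stride=1,
        d_model=768, d_ff=768, enc_in=55,                        # enc_in = number of input channels
        revin=False, affine=False, subtract_last=False,
        embed="timeF", freq="h", dropout=0.1,
        freeze=False, c_pt=False, sft=False, pt_layers="", sft_layers="",
    )
    model = Model(configs)

    # The forward pass reshapes the series into segments of length 25,
    # so seq_len must be a multiple of 25 and enc_in must not exceed 768.
    x = torch.randn(2, configs.seq_len, configs.enc_in)   # (batch, time, channels)
    out = model(x)                                         # reconstruction, same shape as the input
    print(out.shape)                                       # torch.Size([2, 100, 55])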