
Commit

refactor

www committed Feb 24, 2023
1 parent 6506dcd commit 754a947
Showing 4 changed files with 44 additions and 61 deletions.
25 changes: 9 additions & 16 deletions v2/api_demo.py
@@ -18,8 +18,7 @@
os.environ["RWKV_JIT_ON"] = '1'

from rwkv.model import RWKV
from rwkv.utils import TOKENIZER
tokenizer = TOKENIZER("20B_tokenizer.json")
from rwkv.utils import PIPELINE

########################################################################################################
#
@@ -59,27 +58,21 @@
out, state = model.forward([310, 247], state)
print(out.detach().cpu().numpy()) # same result as above

########################################################################################################

pipeline = PIPELINE(model, "20B_tokenizer.json")

prompt = "This is the best"
print(prompt, end='')

########################################################################################################
#
# 1. It's slow (not optimized yet) when your prompt is long. Better keep it as short as possible (for now).
# 2. Reuse the state (use deepcopy to clone it) when you are running the same prompt multiple times.
# 3. Use ctx4096 models if you need long ctx.
#
def generate(prompt, max_new_tokens, state=None):
out = ''
all_tokens = []
for i in range(max_new_tokens):
out, state = model.forward(tokenizer.encode(prompt) if i == 0 else [token], state)
token = tokenizer.sample_logits(out, None, None, temperature=1.0, top_p=0.8)
all_tokens += [token]
tmp = tokenizer.decode(all_tokens)
if '\ufffd' not in tmp: # is it a valid utf-8 string?
out = tmp
return out
completion = pipeline.generate(prompt, max_new_tokens=20)

prompt = "What I would like to say is: "
print(prompt, end='')
completion = generate(prompt, max_new_tokens=20)
print(completion)

# input('done. press Ctrl+C to exit')
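With this refactor, tokenization, sampling, and generation all live behind the PIPELINE object instead of a standalone TOKENIZER. The sketch below shows how the new api_demo.py flow might be exercised together with the state-reuse tip from the comments above; the checkpoint name is a placeholder and the deepcopy-based reuse is an illustrative assumption, not code from this commit.

# Minimal sketch, assuming the rwkv package, a local RWKV checkpoint, and 20B_tokenizer.json.
from copy import deepcopy
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

model = RWKV(model='RWKV-4-Pile-169M-20220807-8023', strategy='cpu fp32')  # placeholder checkpoint
pipeline = PIPELINE(model, "20B_tokenizer.json")

# Feed a shared context once and keep its recurrent state...
shared_context = "Here is a short story about a robot.\n"
_, state = model.forward(pipeline.encode(shared_context), None)

# ...then clone that state per run (the deepcopy tip from the comments above),
# so the shared context never has to be re-processed.
for suffix in ["The robot said:", "In the end,"]:
    completion = pipeline.generate(suffix, max_new_tokens=20, state=deepcopy(state))
    print(suffix, completion)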
50 changes: 19 additions & 31 deletions v2/chat.py
@@ -101,14 +101,8 @@

print(f'\n{CHAT_LANG} - {args.strategy} - QA_PROMPT {QA_PROMPT}')
from rwkv.model import RWKV
from rwkv.utils import TOKENIZER
tokenizer = TOKENIZER("20B_tokenizer.json")

args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
from rwkv.utils import PIPELINE

MODEL_NAME = args.MODEL_NAME

if CHAT_LANG == 'English':
@@ -238,13 +232,14 @@

print(f'Loading model - {MODEL_NAME}')
model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
pipeline = PIPELINE(model, "20B_tokenizer.json")

model_tokens = []
model_state = None

AVOID_REPEAT_TOKENS = []
for i in AVOID_REPEAT:
dd = tokenizer.encode(i)
dd = pipeline.encode(i)
assert len(dd) == 1
AVOID_REPEAT_TOKENS += dd

@@ -257,7 +252,7 @@ def run_rnn(tokens, newline_adj = 0):
model_tokens += tokens
out, model_state = model.forward(tokens, model_state)

# print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]')
# print(f'### model ###\n{tokens}\n[{pipeline.decode(model_tokens)}]')

out[0] = -999999999 # disable <|endoftext|>
out[187] += newline_adj # adjust \n probability
@@ -287,7 +282,7 @@ def load_all_stat(srv, name):
# Run inference
print(f'\nRun prompt...')

out = run_rnn(tokenizer.encode(init_prompt))
out = run_rnn(pipeline.encode(init_prompt))
save_all_stat('', 'chat_init', out)
gc.collect()
torch.cuda.empty_cache()
@@ -305,9 +300,6 @@ def on_message(message):
srv = 'dummy_server'

msg = message.replace('\\n','\n').strip()
# if len(msg) > 1000:
# reply_msg('your message is too long (max 1000 tokens)')
# return

x_temp = GEN_TEMP
x_top_p = GEN_TOP_P
@@ -339,15 +331,15 @@ def on_message(message):
# print(f'### prompt ###\n[{new}]')
model_state = None
model_tokens = []
out = run_rnn(tokenizer.encode(new))
out = run_rnn(pipeline.encode(new))
save_all_stat(srv, 'gen_0', out)

elif msg[:4].lower() == '+qq ':
new = '\nQ: ' + msg[4:].strip() + '\nA:'
# print(f'### prompt ###\n[{new}]')
model_state = None
model_tokens = []
out = run_rnn(tokenizer.encode(new))
out = run_rnn(pipeline.encode(new))
save_all_stat(srv, 'gen_0', out)

elif msg[:4].lower() == '+qa ':
@@ -357,7 +349,7 @@ def on_message(message):
new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
# print(f'### qa ###\n[{new}]')

out = run_rnn(tokenizer.encode(new))
out = run_rnn(pipeline.encode(new))
save_all_stat(srv, 'gen_0', out)

elif msg.lower() == '+++':
@@ -376,10 +368,8 @@ def on_message(message):
begin = len(model_tokens)
out_last = begin
for i in range(FREE_GEN_LEN+100):
token = tokenizer.sample_logits(
token = pipeline.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p=x_top_p,
)
Expand All @@ -388,14 +378,14 @@ def on_message(message):
else:
out = run_rnn([token])

xxx = tokenizer.decode(model_tokens[out_last:])
xxx = pipeline.decode(model_tokens[out_last:])
if '\ufffd' not in xxx: # avoid utf-8 display issues
print(xxx, end='', flush=True)
out_last = begin + i + 1
if i >= FREE_GEN_LEN:
break
print('\n')
# send_msg = tokenizer.decode(model_tokens[begin:]).strip()
# send_msg = pipeline.decode(model_tokens[begin:]).strip()
# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
save_all_stat(srv, 'gen_1', out)
@@ -410,7 +400,7 @@ def on_message(message):
out = load_all_stat(srv, 'chat')
new = f"{user}{interface} {msg}\n\n{bot}{interface}"
# print(f'### add ###\n[{new}]')
out = run_rnn(tokenizer.encode(new), newline_adj=-999999999)
out = run_rnn(pipeline.encode(new), newline_adj=-999999999)
save_all_stat(srv, 'chat_pre', out)

begin = len(model_tokens)
@@ -425,26 +415,24 @@ def on_message(message):
newline_adj = 0
else:
newline_adj = (i - CHAT_LEN_LONG) * 0.25 # MUST END THE GENERATION
token = tokenizer.sample_logits(
token = pipeline.sample_logits(
out,
model_tokens,
args.ctx_len,
temperature=x_temp,
top_p=x_top_p,
)
out = run_rnn([token], newline_adj=newline_adj)

xxx = tokenizer.decode(model_tokens[out_last:])
xxx = pipeline.decode(model_tokens[out_last:])
if '\ufffd' not in xxx: # avoid utf-8 display issues
print(xxx, end='', flush=True)
out_last = begin + i + 1

send_msg = tokenizer.decode(model_tokens[begin:])
send_msg = pipeline.decode(model_tokens[begin:])
if '\n\n' in send_msg:
send_msg = send_msg.strip()
break

# send_msg = tokenizer.decode(model_tokens[begin:]).strip()
# send_msg = pipeline.decode(model_tokens[begin:]).strip()
# if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!!
# send_msg = send_msg[:-len(f'{user}{interface}')].strip()
# break
Expand All @@ -453,7 +441,7 @@ def on_message(message):
# break

# print(f'{model_tokens}')
# print(f'[{tokenizer.decode(model_tokens)}]')
# print(f'[{pipeline.decode(model_tokens)}]')

# print(f'### send ###\n[{send_msg}]')
# reply_msg(send_msg)
@@ -462,7 +450,7 @@ def on_message(message):
print(HELP_MSG)
print(f'{CHAT_LANG} - {args.MODEL_NAME} - {args.strategy}')

print(f'{tokenizer.decode(model_tokens)}'.replace(f'\n\n{bot}',f'\n{bot}'), end='')
print(f'{pipeline.decode(model_tokens)}'.replace(f'\n\n{bot}',f'\n{bot}'), end='')

while True:
msg = prompt(f'{user}{interface} ')
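The recurring change throughout chat.py is the sample_logits call site: the new PIPELINE method no longer takes the token history or the context length, only the logits and the sampling parameters. A hypothetical before/after for a single sampling step (where out is the logits tensor returned by run_rnn):

# Old API, removed in this commit:
# token = tokenizer.sample_logits(out, model_tokens, args.ctx_len, temperature=x_temp, top_p=x_top_p)

# New API:
token = pipeline.sample_logits(out, temperature=x_temp, top_p=x_top_p)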
1 change: 0 additions & 1 deletion v2/rwkv/model.py
@@ -156,7 +156,6 @@ def __init__(self, model, strategy):
w[x] = w[x].pin_memory() # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :)
except:
print('Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower.')
pass
elif DEVICE != 'cpu':
w[x] = w[x].to(device=DEVICE)

29 changes: 16 additions & 13 deletions v2/rwkv/utils.py
@@ -8,18 +8,9 @@
from torch.nn import functional as F
from tokenizers import Tokenizer

time_slot = {}
time_ref = time.time_ns()

def record_time(name):
if name not in time_slot:
time_slot[name] = 1e20
tt = (time.time_ns() - time_ref) / 1e9
if tt < time_slot[name]:
time_slot[name] = tt

class TOKENIZER():
def __init__(self, WORD_NAME):
class PIPELINE():
def __init__(self, model, WORD_NAME):
self.model = model
self.tokenizer = Tokenizer.from_file(WORD_NAME)

def refine_context(self, context):
@@ -38,7 +29,7 @@ def encode(self, x):
def decode(self, x):
return self.tokenizer.decode(x)

def sample_logits(self, logits, x, ctx_len, temperature=1.0, top_p=1.0):
def sample_logits(self, logits, temperature=1.0, top_p=1.0):
probs = F.softmax(logits.float(), dim=-1)

if probs.device == torch.device('cpu'):
@@ -61,3 +52,15 @@ def sample_logits(self, logits, x, ctx_len, temperature=1.0, top_p=1.0):
probs = probs.pow(1.0 / temperature)
out = torch.multinomial(probs, num_samples=1)[0]
return int(out)

def generate(self, prompt, max_new_tokens, state=None):
out = ''
all_tokens = []
for i in range(max_new_tokens):
out, state = self.model.forward(self.encode(prompt) if i == 0 else [token], state)
token = self.sample_logits(out, temperature=1.0, top_p=0.8)
all_tokens += [token]
tmp = self.decode(all_tokens)
if '\ufffd' not in tmp: # is it a valid utf-8 string?
out = tmp
return out
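The body of sample_logits between the CPU-device check and the temperature step is collapsed in this diff. For orientation only, a generic top-p (nucleus) filter over the softmaxed probabilities typically looks like the sketch below; it illustrates the technique and is an assumed helper, not the repository's hidden code.

import torch

def top_p_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # Illustrative nucleus-sampling filter (assumed helper, not part of rwkv).
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Probability of the last token that still fits inside the top-p mass.
    keep = cumulative >= top_p
    cutoff = sorted_probs[keep.nonzero()[0]] if keep.any() else sorted_probs[-1:]
    # Zero out everything below the cutoff, then renormalize.
    probs = torch.where(probs < cutoff, torch.zeros_like(probs), probs)
    return probs / probs.sum()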
