Initial commit 2
parent 886fb6966f
commit ee30f68bca
8  .gitignore  vendored  Normal file
@@ -0,0 +1,8 @@
.telegram_token
tami-models/*
!tami-models/MODELS_GO_HERE
tokenizer_vocab/*
!tokenizer_vocab/VOCABS_GO_HERE
librwkv.so
tamilm.log
rwkv_cpp
329  bot_hf_tokenizers_cuda.py  Normal file
@@ -0,0 +1,329 @@
#!/usr/bin/env python
# pylint: disable=unused-argument
# This program is dedicated to the public domain under the CC0 license.

"""
Simple Bot to reply to Telegram messages.

First, a few handler functions are defined. Then, those functions are passed to
the Application and registered at their respective places.
Then, the bot is started and runs until we press Ctrl-C on the command line.

Usage:
Basic Echobot example, repeats messages.
Press Ctrl-C on the command line or send a signal to the process to stop the
bot.
"""

import math
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch, copy, time
from typing import List
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)

import torch.nn as nn
from torch.nn import functional as F

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script

########################################################################################################

from tokenizers import Tokenizer
from tokenizers import decoders
tokenizer = Tokenizer.from_file("tokenizer_vocab/tokenizer-tami-wordpiece-hedc4.json")
tokenizer.decoder = decoders.WordPiece()

args = types.SimpleNamespace()
args.MODEL_NAME = 'tami-models/rwkv-final-maybe-knesset'
args.n_layer = 12
args.n_embd = 512
args.vocab_size = 65536
args.head_size = 64

#context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען"  # "urgently looking for the simplest laptop, with a charger"


########################################################################################################

class RWKV_RNN(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.n_embd = args.n_embd
        self.n_layer = args.n_layer
        self.eval()

        self.z = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
        z = self.z
        z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias'])

        keys = list(z.keys())
        for k in keys:
            if '.time_' in k: z[k] = z[k].squeeze()
            if k.endswith('.time_decay'): z[k] = z[k].float()
            if k.endswith('.time_faaaa'): z[k] = z[k].unsqueeze(-1).float()
        for k in keys:
            if k.endswith('maa_w'):
                z[k.replace('maa_w','maa_wkvrg')] = torch.concat([z[k],z[k.replace('maa_w','maa_k')],z[k.replace('maa_w','maa_v')],z[k.replace('maa_w','maa_r')],z[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1)
                del z[k]
                del z[k.replace('maa_w','maa_k')]
                del z[k.replace('maa_w','maa_v')]
                del z[k.replace('maa_w','maa_r')]
                del z[k.replace('maa_w','maa_g')]

        self.n_head = z['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = z['blocks.0.ln1.weight'].shape[0] // self.n_head
        assert self.head_size == args.head_size

    @MyFunction
    def forward(self, token:int, state:List[torch.Tensor]):
        with torch.no_grad():
            z = self.z
            x = z['emb.weight'][token]
            for i in range(self.n_layer):
                bbb = f'blocks.{i}.'
                att = f'blocks.{i}.att.'
                ffn = f'blocks.{i}.ffn.'

                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
                xx, state[i*3+0], state[i*3+1] = time_mixing(self.n_head, self.head_size, xx, state[i*3+0], state[i*3+1],
                    z[att+'time_maa_x'], z[att+'time_maa_wkvrg'], z[att+'time_maa_w1'], z[att+'time_maa_w2'],
                    z[att+'time_decay_w1'], z[att+'time_decay_w2'], z[att+'time_faaaa'], z[att+'time_decay'],
                    z[att+'key.weight'], z[att+'value.weight'], z[att+'receptance.weight'], z[att+'gate.weight'], z[att+'output.weight'],
                    z[att+'ln_x.weight'], z[att+'ln_x.bias'])
                x = x + xx

                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
                xx, state[i*3+2] = channel_mixing(xx, state[i*3+2],
                    z[ffn+'time_maa_k'], z[ffn+'time_maa_r'],
                    z[ffn+'key.weight'], z[ffn+'value.weight'], z[ffn+'receptance.weight'])
                x = x + xx

            x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
            x = z['head.weight'] @ x
            return x, state

########################################################################################################

def time_mixing__(H:int, N:int, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
    sx = x_prev - x
    xxx = x + sx * maa_x # C
    xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) # C @ C*5L => 5L => 5*1*L
    xxx = torch.bmm(xxx, tm_w2).view(5, -1) # 5*1*L @ 5*L*C => 5*1*C => 5*C
    xxx = xxx + maa_wkvrg
    xxx = xxx * sx.expand(5, -1) + x.expand(5, -1)
    w, k, v, r, g = xxx.unbind(dim=0)

    w = torch.tanh(w @ td_w1) @ td_w2
    w = w.float() + time_decay
    # assert w.dtype == torch.float
    w = torch.exp(-torch.exp(w))

    k = (kw @ k).view(H, N, 1)
    v = (vw @ v).view(H, 1, N)
    r = (rw @ r).view(H, 1, N)
    g = torch.nn.functional.silu(gw @ g)

    kv = (k @ v).float()
    out = r @ (time_faaaa * kv + state).to(torch.bfloat16)

    state = kv + w.view(H, N, 1) * state

    out = torch.nn.functional.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5)
    return ow @ (out * g), x, state
try:
    time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
except:
    time_mixing = torch.jit.script(time_mixing__)

########################################################################################################

def channel_mixing__(x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
    sx = x_prev - x
    k = x + sx * time_maa_k
    r = x + sx * time_maa_r

    r = torch.sigmoid(rw @ r)
    k = torch.relu(kw @ k) ** 2

    return r * (vw @ k), x
try:
    channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
except:
    channel_mixing = torch.jit.script(channel_mixing__)

########################################################################################################

print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

TEMPERATURE = 1.5
NUM_TRIALS = 1
LENGTH_PER_TRIAL = 2048
MIN_P = 0.1
EOS_DECAY_DIV = 25
EOS_IGNORE = 25

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')

global_init_state = [None for _ in range(args.n_layer * 3)]
for i in range(args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev
    global_init_state[i*3+0] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")
    global_init_state[i*3+1] = torch.zeros((args.n_embd // args.head_size, args.head_size, args.head_size), dtype=torch.float, requires_grad=False, device="cuda")
    global_init_state[i*3+2] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")

def generate_text(context_str):
    context_tokens = tokenizer.encode(context_str).ids
    state = global_init_state.copy()
    for token in context_tokens:
        out, state = model.forward(token, state)

    ########################################################################################################

    all_tokens = []
    all_tokens += context_tokens

    added_tokens_count = 0
    for i in range(LENGTH_PER_TRIAL):
        added_tokens_count += 1

        # applying temperature param
        scores = out / TEMPERATURE

        min_p = MIN_P

        # convert logits to probabilities for min_p filtering
        probs = torch.softmax(scores.float(), dim=-1)

        # decaying eos token to induce longer generation
        if (added_tokens_count <= EOS_IGNORE):
            decay_factor = 0
        else:
            decay_factor = 1.0 - math.exp(-((added_tokens_count-EOS_IGNORE)/EOS_DECAY_DIV))
        probs[0] *= decay_factor

        # Get the probability of the top token for each sequence in the batch
        max_prob, _ = (probs).max(dim=-1, keepdim=True)
        # Calculate the actual min_p threshold by scaling min_p with the top token's probability
        scaled_min_p = min_p * max_prob
        # Create a mask for tokens that have a probability less than the scaled min_p
        indices_to_remove = probs < scaled_min_p

        scores.masked_fill_(indices_to_remove, float("-Inf"))
        #print(scores[scores != float("-Inf")])

        # convert min_p filtered logits to probabilities
        probs = torch.softmax(scores.float(), dim=-1)
        # reapply eos token decay to new probabilities
        probs[0] *= decay_factor

        token = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)

        #print(token)

        all_tokens += [token]
        if len(all_tokens) > 3:
            # if the last 3 tokens are all the same break
            if all_tokens[-1] == all_tokens[-2] == all_tokens[-3]: break

        if token == 0: break

        out, state = model.forward(token, state)

    torch.cuda.synchronize()

    # we generated all the tokens we could
    output_string = tokenizer.decode(all_tokens)
    logger.info(f"Generated {len(all_tokens)} = ctx:{len(context_tokens)} + gen:{added_tokens_count} tokens, Result:{output_string}")

    # insert bolding of the input string TODO: wonky
    #output_string = output_string[:len(context_str)] + "**" + output_string[len(context_str):]

    return output_string

import logging

from telegram import ForceReply, Update
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
# Read telegram token from .telegram_token
with open(".telegram_token") as f:
    TOKEN = f.read().strip()


# Enable logging
logging.basicConfig(
    filename='tamilm.log',
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
# set higher logging level for httpx to avoid all GET and POST requests being logged
logging.getLogger("httpx").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)


# Define a few command handlers. These usually take the two arguments update and
# context.
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /start is issued."""
    await update.message.reply_text("use /complete to complete your hebrew text in groupchats")


async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /help is issued."""
    await update.message.reply_text("use /complete to complete your hebrew text")

async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Complete the user's text when the command /complete is issued."""
    message_text = ' '.join(context.args)

    # if this is a reply to a message include it in the prompt
    if update.message.reply_to_message is not None:
        message_text = update.message.reply_to_message.text + ' ' + message_text

    generated_responce = generate_text(message_text)
    #logger.info(generated_responce)
    await update.message.reply_text(generated_responce)


async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Complete messages sent in private chats."""
    message_text = update.message.text
    if update.message.chat.type == "private":
        generated_responce = generate_text(message_text)
        #(generated_responce)
        await update.message.reply_text(generated_responce)
    #await update.message.reply_text(message_text)


def main() -> None:
    """Start the bot."""
    # Create the Application and pass it your bot's token.
    application = Application.builder().token(TOKEN).build()

    # on different commands - answer in Telegram
    application.add_handler(CommandHandler("start", start))
    application.add_handler(CommandHandler("help", help_command))
    application.add_handler(CommandHandler("complete", complete_command))

    # on non command i.e. message - generate a completion on Telegram
    application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo))

    print("BOT READY")

    # Run the bot until the user presses Ctrl-C
    application.run_polling(allowed_updates=Update.ALL_TYPES)

if __name__ == "__main__":
    main()
336  bot_tokenmonster_cuda.py  Normal file
@@ -0,0 +1,336 @@
#!/usr/bin/env python
# pylint: disable=unused-argument
# This program is dedicated to the public domain under the CC0 license.

"""
Simple Bot to reply to Telegram messages.

First, a few handler functions are defined. Then, those functions are passed to
the Application and registered at their respective places.
Then, the bot is started and runs until we press Ctrl-C on the command line.

Usage:
Basic Echobot example, repeats messages.
Press Ctrl-C on the command line or send a signal to the process to stop the
bot.
"""
import tokenmonster
tokenizer = tokenmonster.load("tokenizer_vocab/pol_vocab_exported")

import math
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch, copy, time
from typing import List
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)

import torch.nn as nn
from torch.nn import functional as F

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script

########################################################################################################

args = types.SimpleNamespace()
args.MODEL_NAME = 'tami-models/asshole-rwkv-6t'
args.n_layer = 12
args.n_embd = 768
args.vocab_size = 500
args.head_size = 64

#context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען"  # "urgently looking for the simplest laptop, with a charger"


########################################################################################################

class RWKV_RNN(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.n_embd = args.n_embd
        self.n_layer = args.n_layer
        self.eval()

        self.z = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
        z = self.z
        z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias'])

        keys = list(z.keys())
        for k in keys:
            if '.time_' in k: z[k] = z[k].squeeze()
            if k.endswith('.time_decay'): z[k] = z[k].float()
            if k.endswith('.time_faaaa'): z[k] = z[k].unsqueeze(-1).float()
        for k in keys:
            if k.endswith('maa_w'):
                z[k.replace('maa_w','maa_wkvrg')] = torch.concat([z[k],z[k.replace('maa_w','maa_k')],z[k.replace('maa_w','maa_v')],z[k.replace('maa_w','maa_r')],z[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1)
                del z[k]
                del z[k.replace('maa_w','maa_k')]
                del z[k.replace('maa_w','maa_v')]
                del z[k.replace('maa_w','maa_r')]
                del z[k.replace('maa_w','maa_g')]

        self.n_head = z['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = z['blocks.0.ln1.weight'].shape[0] // self.n_head
        assert self.head_size == args.head_size

    @MyFunction
    def forward(self, token:int, state:List[torch.Tensor]):
        with torch.no_grad():
            z = self.z
            x = z['emb.weight'][token]
            for i in range(self.n_layer):
                bbb = f'blocks.{i}.'
                att = f'blocks.{i}.att.'
                ffn = f'blocks.{i}.ffn.'

                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
                xx, state[i*3+0], state[i*3+1] = time_mixing(self.n_head, self.head_size, xx, state[i*3+0], state[i*3+1],
                    z[att+'time_maa_x'], z[att+'time_maa_wkvrg'], z[att+'time_maa_w1'], z[att+'time_maa_w2'],
                    z[att+'time_decay_w1'], z[att+'time_decay_w2'], z[att+'time_faaaa'], z[att+'time_decay'],
                    z[att+'key.weight'], z[att+'value.weight'], z[att+'receptance.weight'], z[att+'gate.weight'], z[att+'output.weight'],
                    z[att+'ln_x.weight'], z[att+'ln_x.bias'])
                x = x + xx

                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
                xx, state[i*3+2] = channel_mixing(xx, state[i*3+2],
                    z[ffn+'time_maa_k'], z[ffn+'time_maa_r'],
                    z[ffn+'key.weight'], z[ffn+'value.weight'], z[ffn+'receptance.weight'])
                x = x + xx

            x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
            x = z['head.weight'] @ x
            return x, state

########################################################################################################

def time_mixing__(H:int, N:int, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
    sx = x_prev - x
    xxx = x + sx * maa_x # C
    xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) # C @ C*5L => 5L => 5*1*L
    xxx = torch.bmm(xxx, tm_w2).view(5, -1) # 5*1*L @ 5*L*C => 5*1*C => 5*C
    xxx = xxx + maa_wkvrg
    xxx = xxx * sx.expand(5, -1) + x.expand(5, -1)
    w, k, v, r, g = xxx.unbind(dim=0)

    w = torch.tanh(w @ td_w1) @ td_w2
    w = w.float() + time_decay
    # assert w.dtype == torch.float
    w = torch.exp(-torch.exp(w))

    k = (kw @ k).view(H, N, 1)
    v = (vw @ v).view(H, 1, N)
    r = (rw @ r).view(H, 1, N)
    g = torch.nn.functional.silu(gw @ g)

    kv = (k @ v).float()
    out = r @ (time_faaaa * kv + state).to(torch.bfloat16)

    state = kv + w.view(H, N, 1) * state

    out = torch.nn.functional.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5)
    return ow @ (out * g), x, state
try:
    time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
except:
    time_mixing = torch.jit.script(time_mixing__)

########################################################################################################

def channel_mixing__(x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
    sx = x_prev - x
    k = x + sx * time_maa_k
    r = x + sx * time_maa_r

    r = torch.sigmoid(rw @ r)
    k = torch.relu(kw @ k) ** 2

    return r * (vw @ k), x
try:
    channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
except:
    channel_mixing = torch.jit.script(channel_mixing__)

########################################################################################################

print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

TEMPERATURE = 1.5
NUM_TRIALS = 1
LENGTH_PER_TRIAL = 2048
MIN_P = 0.1
EOS_DECAY_DIV = 25
EOS_IGNORE = 25

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')

global_init_state = [None for _ in range(args.n_layer * 3)]
for i in range(args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev
    global_init_state[i*3+0] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")
    global_init_state[i*3+1] = torch.zeros((args.n_embd // args.head_size, args.head_size, args.head_size), dtype=torch.float, requires_grad=False, device="cuda")
    global_init_state[i*3+2] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")

def generate_text(context_str):
    context_tokens = tokenizer.tokenize(context_str)
    print(context_tokens)
    state = global_init_state.copy()
    for token in context_tokens:
        out, state = model.forward(token, state)

    ########################################################################################################

    all_tokens = []
    all_tokens += context_tokens.tolist()

    added_tokens_count = 0
    for i in range(LENGTH_PER_TRIAL):
        added_tokens_count += 1

        # applying temperature param
        scores = out / TEMPERATURE

        min_p = MIN_P

        # convert logits to probabilities for min_p filtering
        probs = torch.softmax(scores.float(), dim=-1)

        # decaying eos token to induce longer generation
        if (added_tokens_count <= EOS_IGNORE):
            decay_factor = 0
        else:
            decay_factor = 1.0 - math.exp(-((added_tokens_count-EOS_IGNORE)/EOS_DECAY_DIV))
        probs[0] *= decay_factor

        # Get the probability of the top token for each sequence in the batch
        max_prob, _ = (probs).max(dim=-1, keepdim=True)
        # Calculate the actual min_p threshold by scaling min_p with the top token's probability
        scaled_min_p = min_p * max_prob
        # Create a mask for tokens that have a probability less than the scaled min_p
        indices_to_remove = probs < scaled_min_p

        scores.masked_fill_(indices_to_remove, float("-Inf"))
        #print(scores[scores != float("-Inf")])

        # convert min_p filtered logits to probabilities
        probs = torch.softmax(scores.float(), dim=-1)
        # reapply eos token decay to new probabilities
        probs[0] *= decay_factor

        token = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)

        #print(token)

        all_tokens += [token]
        if len(all_tokens) > 3:
            # if the last 3 tokens are all the same break
            if all_tokens[-1] == all_tokens[-2] == all_tokens[-3]: break

        # find if any one token is overrepresented in the generated output and then assume the output is repeating
        if len(all_tokens) > 256:
            if any(num > len(all_tokens) / 4 for num in np.unique_counts(all_tokens).counts): break

        if token == 0: break

        out, state = model.forward(token, state)

    torch.cuda.synchronize()

    # we generated all the tokens we could
    output_string = tokenizer.decode(all_tokens)
    logger.info(f"Generated {len(all_tokens)} = ctx:{len(context_tokens)} + gen:{added_tokens_count} tokens, Result:{output_string}")

    # insert bolding of the input string TODO: wonky
    #output_string = output_string[:len(context_str)] + "**" + output_string[len(context_str):]

    return output_string

import logging

from telegram import ForceReply, Update
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
# Read telegram token from .telegram_token
with open(".telegram_token") as f:
    TOKEN = f.read().strip()


# Enable logging
logging.basicConfig(
    filename='tamilm.log',
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
# set higher logging level for httpx to avoid all GET and POST requests being logged
logging.getLogger("httpx").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)


# Define a few command handlers. These usually take the two arguments update and
# context.
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /start is issued."""
    await update.message.reply_text("use /complete to complete your hebrew text in groupchats")


async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /help is issued."""
    await update.message.reply_text("use /complete to complete your hebrew text")

async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Complete the user's text when the command /complete is issued."""
    message_text = ' '.join(context.args)

    # if this is a reply to a message include it in the prompt
    if update.message.reply_to_message is not None:
        message_text = update.message.reply_to_message.text + ' ' + message_text

    # format input into req structure
    message_text = f"[REQ]\n{message_text}\n[RES]\n"

    generated_responce = generate_text(message_text)
    #logger.info(generated_responce)
    await update.message.reply_text(generated_responce)


async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Complete messages sent in private chats."""
    message_text = update.message.text
    if update.message.chat.type == "private":
        # format input into req structure
        message_text = f"[REQ]\n{message_text}\n[RES]\n"
        generated_responce = generate_text(message_text)
        #(generated_responce)
        await update.message.reply_text(generated_responce)
    #await update.message.reply_text(message_text)


def main() -> None:
    """Start the bot."""
    # Create the Application and pass it your bot's token.
    application = Application.builder().token(TOKEN).build()

    # on different commands - answer in Telegram
    application.add_handler(CommandHandler("start", start))
    application.add_handler(CommandHandler("help", help_command))
    application.add_handler(CommandHandler("complete", complete_command))

    # on non command i.e. message - generate a completion on Telegram
    application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo))

    print("BOT READY")

    # Run the bot until the user presses Ctrl-C
    application.run_polling(allowed_updates=Update.ALL_TYPES)

if __name__ == "__main__":
    main()
302  bot_tokenmonster_rwkv_cpp.py  Normal file
@@ -0,0 +1,302 @@
#!/usr/bin/env python
# pylint: disable=unused-argument
# This program is dedicated to the public domain under the CC0 license.

"""
Simple Bot to reply to Telegram messages.

First, a few handler functions are defined. Then, those functions are passed to
the Application and registered at their respective places.
Then, the bot is started and runs until we press Ctrl-C on the command line.

Usage:
Basic Echobot example, repeats messages.
Press Ctrl-C on the command line or send a signal to the process to stop the
bot.
"""

from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model

import tokenmonster
tokenizer = tokenmonster.load("tokenizer_vocab/pol_vocab_exported")

import math
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch, copy, time
from typing import List
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)

import torch.nn as nn
from torch.nn import functional as F

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script

########################################################################################################

args = types.SimpleNamespace()
args.MODEL_NAME = 'tami-models/assholerwkv-240.ggml'
args.n_layer = 12
args.n_embd = 768
args.vocab_size = 500
args.head_size = 64

#context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען"  # "urgently looking for the simplest laptop, with a charger"


########################################################################################################

print(f'\nUsing RWKV cpp. Loading {args.MODEL_NAME} ...')
#model = RWKV_RNN(args)

# Load the model.
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, args.MODEL_NAME)

TEMPERATURE = 0.9
NUM_TRIALS = 1
LENGTH_PER_TRIAL = 768
MIN_P = 0.05
EOS_DECAY_DIV = 25
EOS_IGNORE = 100

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')

# docs for inputs, https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor.__call__
def dry_token_penalty(input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    multiplier = 5 # Controls the magnitude of the penalty for the shortest penalized sequences

    base = 1.75 # Controls how fast the penalty grows with increasing sequence length

    allowed_length = 2 # Longest sequence that can be repeated without being penalized

    sequence_breakers = [10,4943,4944] # vocab idxs that break up sequences. aka newline, period, question mark etc etc

    _range = 0 # repetition penalty range, how far back in tokens do we go
    # 0 means it's the entire supplied sequence

    if _range > 0:
        input_ids = input_ids[:, -_range:]

    for input_ids_row, scores_row in zip(input_ids, scores):
        # Raw integer must be extracted here to check for set membership.
        last_token = input_ids_row[-1].item()

        if last_token in sequence_breakers: #TODO: Patch to use idxs
            continue

        # Exclude the last token as it always matches.
        match_indices = (input_ids_row[:-1] == last_token).nonzero()

        # Stores the maximum matching sequence length
        # for each token immediately following the sequence in the input.
        match_lengths = {}

        for i in match_indices:
            next_token = input_ids_row[i+1].item()

            if next_token in sequence_breakers:
                continue

            # We have already found that `last_token` matches at this index,
            # so the match is at least of length 1.
            match_length = 1

            # Extend the match backwards as far as possible.
            while True:
                j = i - match_length
                if j < 0:
                    # Start of input reached.
                    break

                previous_token = input_ids_row[-(match_length+1)].item()
                if input_ids_row[j] != previous_token:
                    # Start of match reached.
                    break

                if previous_token in sequence_breakers:
                    # Sequence-breaking token reached.
                    break

                match_length += 1

            if next_token in match_lengths:
                match_lengths[next_token] = max(match_length, match_lengths[next_token])
            else:
                match_lengths[next_token] = match_length

        # Apply penalties.
        for token, match_length in match_lengths.items():
            if match_length >= allowed_length:
                penalty = multiplier * base ** (match_length - allowed_length)
                scores_row[token] -= penalty

    return scores

def generate_text(context_str):
    context_tokens = tokenizer.tokenize(context_str)
    print(context_tokens)
    out, state = model.eval_sequence(context_tokens, None, None, None, use_numpy=False)

    ########################################################################################################

    all_tokens = []
    all_tokens += context_tokens.tolist()

    added_tokens_count = 0
    for i in range(LENGTH_PER_TRIAL):
        added_tokens_count += 1

        out = dry_token_penalty(torch.tensor([all_tokens], dtype=torch.long), [out])[0] # returns a tensor which is a list of lists

        # applying temperature param
        scores = out / TEMPERATURE

        min_p = MIN_P

        # convert logits to probabilities for min_p filtering
        probs = torch.softmax(scores.float(), dim=-1)

        # decaying eos token to induce longer generation
        if (added_tokens_count <= EOS_IGNORE):
            decay_factor = 0
        else:
            decay_factor = 1.0 - math.exp(-((added_tokens_count-EOS_IGNORE)/EOS_DECAY_DIV))
        probs[0] *= decay_factor

        # Get the probability of the top token for each sequence in the batch
        max_prob, _ = (probs).max(dim=-1, keepdim=True)
        # Calculate the actual min_p threshold by scaling min_p with the top token's probability
        scaled_min_p = min_p * max_prob
        # Create a mask for tokens that have a probability less than the scaled min_p
        indices_to_remove = probs < scaled_min_p

        scores.masked_fill_(indices_to_remove, float("-Inf"))
        #print(scores[scores != float("-Inf")])

        # convert min_p filtered logits to probabilities
        probs = torch.softmax(scores.float(), dim=-1)
        # reapply eos token decay to new probabilities
        probs[0] *= decay_factor

        # apply dry here because I think it's independent of the probs across the vocab
        # aka the probs don't influence other tokens
        #probs = dry_token_penalty(all_tokens,probs)

        token = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)

        #print(token)

        all_tokens += [token]
        if len(all_tokens) > 3:
            # if the last 3 tokens are all the same break
            if all_tokens[-1] == all_tokens[-2] == all_tokens[-3]: break

        if token == 0: break

        out, state = model.eval_sequence([token], state, None, None, use_numpy=False)

    #torch.cuda.synchronize()  # not needed anymore: inference runs on CPU via rwkv.cpp

    # we generated all the tokens we could
    output_string = tokenizer.decode(all_tokens)
    logger.info(f"Generated {len(all_tokens)} = ctx:{len(context_tokens)} + gen:{added_tokens_count} tokens, Result:{output_string}")

    # insert bolding of the input string TODO: wonky
    #output_string = output_string[:len(context_str)] + "**" + output_string[len(context_str):]

    # replace \n with newline in output_string
    output_string = output_string.replace("\\n", "\n")

    return output_string

import logging

from telegram import ForceReply, Update
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
# Read telegram token from .telegram_token
with open(".telegram_token") as f:
    TOKEN = f.read().strip()


# Enable logging
logging.basicConfig(
    filename='tamilm.log',
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
# set higher logging level for httpx to avoid all GET and POST requests being logged
logging.getLogger("httpx").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)


# Define a few command handlers. These usually take the two arguments update and
# context.
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /start is issued."""
    await update.message.reply_text("use /complete to complete your hebrew text in groupchats")


async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /help is issued."""
    await update.message.reply_text("use /complete to complete your hebrew text")

async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Complete the user's text when the command /complete is issued."""
    message_text = ' '.join(context.args)

    # if this is a reply to a message include it in the prompt
    if update.message.reply_to_message is not None:
        message_text = update.message.reply_to_message.text + ' ' + message_text

    # format input into req structure
    message_text = f"[REQ]\n{message_text}\n[RES]\n"

    generated_responce = generate_text(message_text)
    #logger.info(generated_responce)
    await update.message.reply_text(generated_responce)


async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Complete messages sent in private chats."""
    message_text = update.message.text
    if update.message.chat.type == "private":
        # format input into req structure
        message_text = f"[REQ]\n{message_text}\n[RES]\n"
        generated_responce = generate_text(message_text)
        #(generated_responce)
        await update.message.reply_text(generated_responce)
    #await update.message.reply_text(message_text)


def main() -> None:
    """Start the bot."""
    # Create the Application and pass it your bot's token.
    application = Application.builder().token(TOKEN).build()

    # on different commands - answer in Telegram
    application.add_handler(CommandHandler("start", start))
    application.add_handler(CommandHandler("help", help_command))
    application.add_handler(CommandHandler("complete", complete_command))

    # on non command i.e. message - generate a completion on Telegram
    application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo))

    print("BOT READY")

    # Run the bot until the user presses Ctrl-C
    application.run_polling(allowed_updates=Update.ALL_TYPES)

if __name__ == "__main__":
    main()
0  tami-models/MODELS_GO_HERE  Normal file
0  tokenizer_vocab/VOCABS_GO_HERE  Normal file