diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..19e23ee --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.telegram_token +tami-models/* +!tami-models/MODELS_GO_HERE +tokenizer_vocab/* +!tokenizer_vocab/VOCABS_GO_HERE +librwkv.so +tamilm.log +rwkv_cpp \ No newline at end of file diff --git a/bot_hf_tokenizers_cuda.py b/bot_hf_tokenizers_cuda.py new file mode 100644 index 0000000..e8fb063 --- /dev/null +++ b/bot_hf_tokenizers_cuda.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python +# pylint: disable=unused-argument +# This program is dedicated to the public domain under the CC0 license. + +""" +Simple Bot to reply to Telegram messages. + +First, a few handler functions are defined. Then, those functions are passed to +the Application and registered at their respective places. +Then, the bot is started and runs until we press Ctrl-C on the command line. + +Usage: +Basic Echobot example, repeats messages. +Press Ctrl-C on the command line or send a signal to the process to stop the +bot. +""" + +import math +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import numpy as np +np.set_printoptions(precision=4, suppress=True, linewidth=200) +import types, torch, copy, time +from typing import List +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True +# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True +torch._C._jit_set_autocast_mode(False) + +import torch.nn as nn +from torch.nn import functional as F + +MyModule = torch.jit.ScriptModule +MyFunction = torch.jit.script_method +MyStatic = torch.jit.script + +######################################################################################################## + +from tokenizers import Tokenizer +from tokenizers import decoders +tokenizer = Tokenizer.from_file("tokenizer_vocab/tokenizer-tami-wordpiece-hedc4.json") +tokenizer.decoder = decoders.WordPiece() + +args = types.SimpleNamespace() +args.MODEL_NAME = 'tami-models/rwkv-final-maybe-knesset' +args.n_layer = 12 +args.n_embd = 512 +args.vocab_size = 65536 +args.head_size = 64 + +#context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען" + + +######################################################################################################## + +class RWKV_RNN(MyModule): + def __init__(self, args): + super().__init__() + self.args = args + self.n_embd = args.n_embd + self.n_layer = args.n_layer + self.eval() + + self.z = torch.load(args.MODEL_NAME + '.pth', map_location='cuda') + z = self.z + z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias']) + + keys = list(z.keys()) + for k in keys: + if '.time_' in k: z[k] = z[k].squeeze() + if k.endswith('.time_decay'): z[k] = z[k].float() + if k.endswith('.time_faaaa'): z[k] = z[k].unsqueeze(-1).float() + for k in keys: + if k.endswith('maa_w'): + z[k.replace('maa_w','maa_wkvrg')] = torch.concat([z[k],z[k.replace('maa_w','maa_k')],z[k.replace('maa_w','maa_v')],z[k.replace('maa_w','maa_r')],z[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1) + del z[k] + del z[k.replace('maa_w','maa_k')] + del z[k.replace('maa_w','maa_v')] + del z[k.replace('maa_w','maa_r')] + del z[k.replace('maa_w','maa_g')] + + self.n_head = z['blocks.0.att.time_faaaa'].shape[0] + self.head_size = z['blocks.0.ln1.weight'].shape[0] // self.n_head + assert self.head_size == args.head_size + + @MyFunction + def forward(self, token:int, state:List[torch.Tensor]): + with torch.no_grad(): + z = self.z + x = z['emb.weight'][token] + for i in range(self.n_layer): + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' + + xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias']) + xx, state[i*3+0], state[i*3+1] = time_mixing(self.n_head, self.head_size, xx, state[i*3+0], state[i*3+1], + z[att+'time_maa_x'], z[att+'time_maa_wkvrg'], z[att+'time_maa_w1'], z[att+'time_maa_w2'], + z[att+'time_decay_w1'], z[att+'time_decay_w2'], z[att+'time_faaaa'], z[att+'time_decay'], + z[att+'key.weight'], z[att+'value.weight'], z[att+'receptance.weight'], z[att+'gate.weight'], z[att+'output.weight'], + z[att+'ln_x.weight'], z[att+'ln_x.bias']) + x = x + xx + + xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias']) + xx, state[i*3+2] = channel_mixing(xx, state[i*3+2], + z[ffn+'time_maa_k'], z[ffn+'time_maa_r'], + z[ffn+'key.weight'], z[ffn+'value.weight'], z[ffn+'receptance.weight']) + x = x + xx + + x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias']) + x = z['head.weight'] @ x + return x, state + +######################################################################################################## + +def time_mixing__(H:int, N:int, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b): + sx = x_prev - x + xxx = x + sx * maa_x # C + xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) # C @ C*5L => 5L => 5*1*L + xxx = torch.bmm(xxx, tm_w2).view(5, -1) # 5*1*L @ 5*L*C => 5*1*C => 5*C + xxx = xxx + maa_wkvrg + xxx = xxx * sx.expand(5, -1) + x.expand(5, -1) + w, k, v, r, g = xxx.unbind(dim=0) + + w = torch.tanh(w @ td_w1) @ td_w2 + w = w.float() + time_decay + # assert w.dtype == torch.float + w = torch.exp(-torch.exp(w)) + + k = (kw @ k).view(H, N, 1) + v = (vw @ v).view(H, 1, N) + r = (rw @ r).view(H, 1, N) + g = torch.nn.functional.silu(gw @ g) + + kv = (k @ v).float() + out = r @ (time_faaaa * kv + state).to(torch.bfloat16) + + state = kv + w.view(H, N, 1) * state + + out = torch.nn.functional.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5) + return ow @ (out * g), x, state +try: + time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False) +except: + time_mixing = torch.jit.script(time_mixing__) + +######################################################################################################## + +def channel_mixing__(x, x_prev, time_maa_k, time_maa_r, kw, vw, rw): + sx = x_prev - x + k = x + sx * time_maa_k + r = x + sx * time_maa_r + + r = torch.sigmoid(rw @ r) + k = torch.relu(kw @ k) ** 2 + + return r * (vw @ k), x +try: + channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False) +except: + channel_mixing = torch.jit.script(channel_mixing__) + +######################################################################################################## + +print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...') +model = RWKV_RNN(args) + +TEMPERATURE = 1.5 +NUM_TRIALS = 1 +LENGTH_PER_TRIAL = 2048 +MIN_P = 0.1 +EOS_DECAY_DIV = 25 +EOS_IGNORE = 25 + +print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)') + +global_init_state = [None for _ in range(args.n_layer * 3)] +for i in range(args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev + global_init_state[i*3+0] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda") + global_init_state[i*3+1] = torch.zeros((args.n_embd // args.head_size, args.head_size, args.head_size), dtype=torch.float, requires_grad=False, device="cuda") + global_init_state[i*3+2] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda") + +def generate_text(context_str): + context_tokens = tokenizer.encode(context_str).ids + state = global_init_state.copy() + for token in context_tokens: + out, state = model.forward(token, state) + + ######################################################################################################## + + all_tokens = [] + all_tokens += context_tokens + + added_tokens_count = 0 + for i in range(LENGTH_PER_TRIAL): + added_tokens_count+=1 + + # applying temperature param + scores = out / TEMPERATURE + + min_p = MIN_P + + # convert logits to probabilities for min_p filtering + probs = torch.softmax(scores.float(),dim=-1) + + # decaying eos token to induce longer generation + if (added_tokens_count <= EOS_IGNORE): + decay_factor = 0 + else: + decay_factor = 1.0 - math.exp(-((added_tokens_count-EOS_IGNORE)/EOS_DECAY_DIV)) + probs[0] *= decay_factor + + # Get the probability of the top token for each sequence in the batch + max_prob, _ = (probs).max(dim=-1, keepdim=True) + # Calculate the actual min_p threshold by scaling min_p with the top token's probability + scaled_min_p = min_p * max_prob + # Create a mask for tokens that have a probability less than the scaled min_p + indices_to_remove = probs < scaled_min_p + + scores.masked_fill_(indices_to_remove, float("-Inf")) + #print(scores[scores != float("-Inf")]) + + # convert min_p filtered logits to probabilities + probs = torch.softmax(scores.float(),dim=-1) + # reapply eos token decay to new probabilities + probs[0] *= decay_factor + + token = torch.multinomial(probs, num_samples=1).squeeze(dim=-1) + + #print(token) + + all_tokens += [token] + if len(all_tokens) > 3: + # if the last 3 tokens are all the same break + if all_tokens[-1] == all_tokens[-2] == all_tokens[-3]: break + + if token == 0: break + + out, state = model.forward(token, state) + + torch.cuda.synchronize() + + # we generated all the tokens we could + output_string = tokenizer.decode(all_tokens) + logger.info(f"Generated {len(all_tokens)} = ctx:{len(context_tokens)} + gen:{added_tokens_count} tokens, Result:{output_string}") + + # insert bolding of the input string TODO: wonky + #output_string = output_string[:len(context_str)] + "**" + output_string[len(context_str):] + + return output_string + +import logging + +from telegram import ForceReply, Update +from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters +# Read telegram token from .telegram_token +with open(".telegram_token") as f: + TOKEN = f.read().strip() + + +# Enable logging +logging.basicConfig( + filename='tamilm.log', + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) +# set higher logging level for httpx to avoid all GET and POST requests being logged +logging.getLogger("httpx").setLevel(logging.WARNING) + +logger = logging.getLogger(__name__) + + +# Define a few command handlers. These usually take the two arguments update and +# context. +async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /start is issued.""" + await update.message.reply_text("use /complete to complete your hebrew text in groupchats") + + +async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /help is issued.""" + await update.message.reply_text("use /complete to complete your hebrew text") + +async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /help is issued.""" + message_text = ' '.join(context.args) + + # if this is a reply to a message include it in the prompt + if update.message.reply_to_message != None: + message_text = update.message.reply_to_message.text + ' ' + message_text + + generated_responce = generate_text(message_text) + #logger.info(generated_responce) + await update.message.reply_text(generated_responce) + + +async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + message_text = update.message.text + if update.message.chat.type == "private": + generated_responce = generate_text(message_text) + #(generated_responce) + await update.message.reply_text(generated_responce) + """Echo the user message.""" + #await update.message.reply_text(message_text) + + +def main() -> None: + """Start the bot.""" + # Create the Application and pass it your bot's token. + application = Application.builder().token(TOKEN).build() + + # on different commands - answer in Telegram + application.add_handler(CommandHandler("start", start)) + application.add_handler(CommandHandler("help", help_command)) + application.add_handler(CommandHandler("complete", complete_command)) + + # on non command i.e message - echo the message on Telegram + application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo)) + + print("BOT READY") + + # Run the bot until the user presses Ctrl-C + application.run_polling(allowed_updates=Update.ALL_TYPES) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/bot_tokenmonster_cuda.py b/bot_tokenmonster_cuda.py new file mode 100644 index 0000000..2e1a664 --- /dev/null +++ b/bot_tokenmonster_cuda.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python +# pylint: disable=unused-argument +# This program is dedicated to the public domain under the CC0 license. + +""" +Simple Bot to reply to Telegram messages. + +First, a few handler functions are defined. Then, those functions are passed to +the Application and registered at their respective places. +Then, the bot is started and runs until we press Ctrl-C on the command line. + +Usage: +Basic Echobot example, repeats messages. +Press Ctrl-C on the command line or send a signal to the process to stop the +bot. +""" +import tokenmonster +tokenizer = tokenmonster.load("tokenizer_vocab/pol_vocab_exported") + +import math +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import numpy as np +np.set_printoptions(precision=4, suppress=True, linewidth=200) +import types, torch, copy, time +from typing import List +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True +# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True +torch._C._jit_set_autocast_mode(False) + +import torch.nn as nn +from torch.nn import functional as F + +MyModule = torch.jit.ScriptModule +MyFunction = torch.jit.script_method +MyStatic = torch.jit.script + +######################################################################################################## + +args = types.SimpleNamespace() +args.MODEL_NAME = 'tami-models/asshole-rwkv-6t' +args.n_layer = 12 +args.n_embd = 768 +args.vocab_size = 500 +args.head_size = 64 + +#context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען" + + +######################################################################################################## + +class RWKV_RNN(MyModule): + def __init__(self, args): + super().__init__() + self.args = args + self.n_embd = args.n_embd + self.n_layer = args.n_layer + self.eval() + + self.z = torch.load(args.MODEL_NAME + '.pth', map_location='cuda') + z = self.z + z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias']) + + keys = list(z.keys()) + for k in keys: + if '.time_' in k: z[k] = z[k].squeeze() + if k.endswith('.time_decay'): z[k] = z[k].float() + if k.endswith('.time_faaaa'): z[k] = z[k].unsqueeze(-1).float() + for k in keys: + if k.endswith('maa_w'): + z[k.replace('maa_w','maa_wkvrg')] = torch.concat([z[k],z[k.replace('maa_w','maa_k')],z[k.replace('maa_w','maa_v')],z[k.replace('maa_w','maa_r')],z[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1) + del z[k] + del z[k.replace('maa_w','maa_k')] + del z[k.replace('maa_w','maa_v')] + del z[k.replace('maa_w','maa_r')] + del z[k.replace('maa_w','maa_g')] + + self.n_head = z['blocks.0.att.time_faaaa'].shape[0] + self.head_size = z['blocks.0.ln1.weight'].shape[0] // self.n_head + assert self.head_size == args.head_size + + @MyFunction + def forward(self, token:int, state:List[torch.Tensor]): + with torch.no_grad(): + z = self.z + x = z['emb.weight'][token] + for i in range(self.n_layer): + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' + + xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias']) + xx, state[i*3+0], state[i*3+1] = time_mixing(self.n_head, self.head_size, xx, state[i*3+0], state[i*3+1], + z[att+'time_maa_x'], z[att+'time_maa_wkvrg'], z[att+'time_maa_w1'], z[att+'time_maa_w2'], + z[att+'time_decay_w1'], z[att+'time_decay_w2'], z[att+'time_faaaa'], z[att+'time_decay'], + z[att+'key.weight'], z[att+'value.weight'], z[att+'receptance.weight'], z[att+'gate.weight'], z[att+'output.weight'], + z[att+'ln_x.weight'], z[att+'ln_x.bias']) + x = x + xx + + xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias']) + xx, state[i*3+2] = channel_mixing(xx, state[i*3+2], + z[ffn+'time_maa_k'], z[ffn+'time_maa_r'], + z[ffn+'key.weight'], z[ffn+'value.weight'], z[ffn+'receptance.weight']) + x = x + xx + + x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias']) + x = z['head.weight'] @ x + return x, state + +######################################################################################################## + +def time_mixing__(H:int, N:int, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b): + sx = x_prev - x + xxx = x + sx * maa_x # C + xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) # C @ C*5L => 5L => 5*1*L + xxx = torch.bmm(xxx, tm_w2).view(5, -1) # 5*1*L @ 5*L*C => 5*1*C => 5*C + xxx = xxx + maa_wkvrg + xxx = xxx * sx.expand(5, -1) + x.expand(5, -1) + w, k, v, r, g = xxx.unbind(dim=0) + + w = torch.tanh(w @ td_w1) @ td_w2 + w = w.float() + time_decay + # assert w.dtype == torch.float + w = torch.exp(-torch.exp(w)) + + k = (kw @ k).view(H, N, 1) + v = (vw @ v).view(H, 1, N) + r = (rw @ r).view(H, 1, N) + g = torch.nn.functional.silu(gw @ g) + + kv = (k @ v).float() + out = r @ (time_faaaa * kv + state).to(torch.bfloat16) + + state = kv + w.view(H, N, 1) * state + + out = torch.nn.functional.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N) # same as gn(x/8, eps=1e-5) + return ow @ (out * g), x, state +try: + time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False) +except: + time_mixing = torch.jit.script(time_mixing__) + +######################################################################################################## + +def channel_mixing__(x, x_prev, time_maa_k, time_maa_r, kw, vw, rw): + sx = x_prev - x + k = x + sx * time_maa_k + r = x + sx * time_maa_r + + r = torch.sigmoid(rw @ r) + k = torch.relu(kw @ k) ** 2 + + return r * (vw @ k), x +try: + channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False) +except: + channel_mixing = torch.jit.script(channel_mixing__) + +######################################################################################################## + +print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...') +model = RWKV_RNN(args) + +TEMPERATURE = 1.5 +NUM_TRIALS = 1 +LENGTH_PER_TRIAL = 2048 +MIN_P = 0.1 +EOS_DECAY_DIV = 25 +EOS_IGNORE = 25 + +print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)') + +global_init_state = [None for _ in range(args.n_layer * 3)] +for i in range(args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev + global_init_state[i*3+0] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda") + global_init_state[i*3+1] = torch.zeros((args.n_embd // args.head_size, args.head_size, args.head_size), dtype=torch.float, requires_grad=False, device="cuda") + global_init_state[i*3+2] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda") + +def generate_text(context_str): + context_tokens = tokenizer.tokenize(context_str) + print(context_tokens) + state = global_init_state.copy() + for token in context_tokens: + out, state = model.forward(token, state) + + ######################################################################################################## + + all_tokens = [] + all_tokens += context_tokens.tolist() + + added_tokens_count = 0 + for i in range(LENGTH_PER_TRIAL): + added_tokens_count+=1 + + # applying temperature param + scores = out / TEMPERATURE + + min_p = MIN_P + + # convert logits to probabilities for min_p filtering + probs = torch.softmax(scores.float(),dim=-1) + + # decaying eos token to induce longer generation + if (added_tokens_count <= EOS_IGNORE): + decay_factor = 0 + else: + decay_factor = 1.0 - math.exp(-((added_tokens_count-EOS_IGNORE)/EOS_DECAY_DIV)) + probs[0] *= decay_factor + + # Get the probability of the top token for each sequence in the batch + max_prob, _ = (probs).max(dim=-1, keepdim=True) + # Calculate the actual min_p threshold by scaling min_p with the top token's probability + scaled_min_p = min_p * max_prob + # Create a mask for tokens that have a probability less than the scaled min_p + indices_to_remove = probs < scaled_min_p + + scores.masked_fill_(indices_to_remove, float("-Inf")) + #print(scores[scores != float("-Inf")]) + + # convert min_p filtered logits to probabilities + probs = torch.softmax(scores.float(),dim=-1) + # reapply eos token decay to new probabilities + probs[0] *= decay_factor + + token = torch.multinomial(probs, num_samples=1).squeeze(dim=-1) + + #print(token) + + all_tokens += [token] + if len(all_tokens) > 3: + # if the last 3 tokens are all the same break + if all_tokens[-1] == all_tokens[-2] == all_tokens[-3]: break + + # find if any one token is overrepresented in the generated output and then assume the output is repeating + if len(all_tokens) > 256: + if any(num > len(all_tokens/4) for num in np.unique_counts(all_tokens).counts): break + + if token == 0: break + + out, state = model.forward(token, state) + + torch.cuda.synchronize() + + # we generated all the tokens we could + output_string = tokenizer.decode(all_tokens) + logger.info(f"Generated {len(all_tokens)} = ctx:{len(context_tokens)} + gen:{added_tokens_count} tokens, Result:{output_string}") + + # insert bolding of the input string TODO: wonky + #output_string = output_string[:len(context_str)] + "**" + output_string[len(context_str):] + + return output_string + +import logging + +from telegram import ForceReply, Update +from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters +# Read telegram token from .telegram_token +with open(".telegram_token") as f: + TOKEN = f.read().strip() + + +# Enable logging +logging.basicConfig( + filename='tamilm.log', + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) +# set higher logging level for httpx to avoid all GET and POST requests being logged +logging.getLogger("httpx").setLevel(logging.WARNING) + +logger = logging.getLogger(__name__) + + +# Define a few command handlers. These usually take the two arguments update and +# context. +async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /start is issued.""" + await update.message.reply_text("use /complete to complete your hebrew text in groupchats") + + +async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /help is issued.""" + await update.message.reply_text("use /complete to complete your hebrew text") + +async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /help is issued.""" + message_text = ' '.join(context.args) + + # if this is a reply to a message include it in the prompt + if update.message.reply_to_message != None: + message_text = update.message.reply_to_message.text + ' ' + message_text + + # format input into req structure + message_text = f"[REQ]\n{message_text}\n[RES]\n" + + generated_responce = generate_text(message_text) + #logger.info(generated_responce) + await update.message.reply_text(generated_responce) + + +async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + message_text = update.message.text + if update.message.chat.type == "private": + # format input into req structure + message_text = f"[REQ]\n{message_text}\n[RES]\n" + generated_responce = generate_text(message_text) + #(generated_responce) + await update.message.reply_text(generated_responce) + """Echo the user message.""" + #await update.message.reply_text(message_text) + + +def main() -> None: + """Start the bot.""" + # Create the Application and pass it your bot's token. + application = Application.builder().token(TOKEN).build() + + # on different commands - answer in Telegram + application.add_handler(CommandHandler("start", start)) + application.add_handler(CommandHandler("help", help_command)) + application.add_handler(CommandHandler("complete", complete_command)) + + # on non command i.e message - echo the message on Telegram + application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo)) + + print("BOT READY") + + # Run the bot until the user presses Ctrl-C + application.run_polling(allowed_updates=Update.ALL_TYPES) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/bot_tokenmonster_rwkv_cpp.py b/bot_tokenmonster_rwkv_cpp.py new file mode 100644 index 0000000..d2a7171 --- /dev/null +++ b/bot_tokenmonster_rwkv_cpp.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python +# pylint: disable=unused-argument +# This program is dedicated to the public domain under the CC0 license. + +""" +Simple Bot to reply to Telegram messages. + +First, a few handler functions are defined. Then, those functions are passed to +the Application and registered at their respective places. +Then, the bot is started and runs until we press Ctrl-C on the command line. + +Usage: +Basic Echobot example, repeats messages. +Press Ctrl-C on the command line or send a signal to the process to stop the +bot. +""" + +from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model + +import tokenmonster +tokenizer = tokenmonster.load("tokenizer_vocab/pol_vocab_exported") + +import math +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import numpy as np +np.set_printoptions(precision=4, suppress=True, linewidth=200) +import types, torch, copy, time +from typing import List +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True +# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True +torch._C._jit_set_autocast_mode(False) + +import torch.nn as nn +from torch.nn import functional as F + +MyModule = torch.jit.ScriptModule +MyFunction = torch.jit.script_method +MyStatic = torch.jit.script + +######################################################################################################## + +args = types.SimpleNamespace() +args.MODEL_NAME = 'tami-models/assholerwkv-240.ggml' +args.n_layer = 12 +args.n_embd = 768 +args.vocab_size = 500 +args.head_size = 64 + +#context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען" + + +######################################################################################################## + +print(f'\nUsing RWKV cpp. Loading {args.MODEL_NAME} ...') +#model = RWKV_RNN(args) + +# Load the model. +library = rwkv_cpp_shared_library.load_rwkv_shared_library() +model = rwkv_cpp_model.RWKVModel(library, args.MODEL_NAME) + +TEMPERATURE = 0.9 +NUM_TRIALS = 1 +LENGTH_PER_TRIAL = 768 +MIN_P = 0.05 +EOS_DECAY_DIV = 25 +EOS_IGNORE = 100 + +print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)') + +# docs for inputs, https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor.__call__ +def dry_token_penalty(input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + multiplier = 5 #Controls the magnitude of the penalty for the shortest penalized sequences + + base = 1.75 #Controls how fast the penalty grows with increasing sequence length + + allowed_length = 2 # Longest sequence that can be repeated without being penalized + + sequence_breakers = [10,4943,4944] #vocab idxs that break up sequences. aka newline, period, question mark etc etc + + _range = 0 # repetition penalty range, how long back in tokens do we go + # 0 means its the entire supplied sequence + + if _range > 0: + input_ids = input_ids[:, -_range:] + + for input_ids_row, scores_row in zip(input_ids, scores): + # Raw integer must be extracted here to check for set membership. + last_token = input_ids_row[-1].item() + + if last_token in sequence_breakers: #TODO: Patch to use idxs + continue + + # Exclude the last token as it always matches. + match_indices = (input_ids_row[:-1] == last_token).nonzero() + + # Stores the maximum matching sequence length + # for each token immediately following the sequence in the input. + match_lengths = {} + + for i in match_indices: + next_token = input_ids_row[i+1].item() + + if next_token in sequence_breakers: + continue + + # We have already found that `last_token` matches at this index, + # so the match is at least of length 1. + match_length = 1 + + # Extend the match backwards as far as possible. + while True: + j = i - match_length + if j < 0: + # Start of input reached. + break + + previous_token = input_ids_row[-(match_length+1)].item() + if input_ids_row[j] != previous_token: + # Start of match reached. + break + + if previous_token in sequence_breakers: + # Sequence-breaking token reached. + break + + match_length += 1 + + if next_token in match_lengths: + match_lengths[next_token] = max(match_length, match_lengths[next_token]) + else: + match_lengths[next_token] = match_length + + # Apply penalties. + for token, match_length in match_lengths.items(): + if match_length >= allowed_length: + penalty = multiplier * base ** (match_length - allowed_length) + scores_row[token] -= penalty + + return scores + +def generate_text(context_str): + context_tokens = tokenizer.tokenize(context_str) + print(context_tokens) + out, state = model.eval_sequence(context_tokens, None, None, None, use_numpy=False) + + ######################################################################################################## + + all_tokens = [] + all_tokens += context_tokens.tolist() + + added_tokens_count = 0 + for i in range(LENGTH_PER_TRIAL): + added_tokens_count+=1 + + out = dry_token_penalty(torch.tensor([all_tokens],dtype=torch.long),[out])[0] # returns a tensor which is a list of lists + + # applying temperature param + scores = out / TEMPERATURE + + min_p = MIN_P + + # convert logits to probabilities for min_p filtering + probs = torch.softmax(scores.float(),dim=-1) + + # decaying eos token to induce longer generation + if (added_tokens_count <= EOS_IGNORE): + decay_factor = 0 + else: + decay_factor = 1.0 - math.exp(-((added_tokens_count-EOS_IGNORE)/EOS_DECAY_DIV)) + probs[0] *= decay_factor + + # Get the probability of the top token for each sequence in the batch + max_prob, _ = (probs).max(dim=-1, keepdim=True) + # Calculate the actual min_p threshold by scaling min_p with the top token's probability + scaled_min_p = min_p * max_prob + # Create a mask for tokens that have a probability less than the scaled min_p + indices_to_remove = probs < scaled_min_p + + scores.masked_fill_(indices_to_remove, float("-Inf")) + #print(scores[scores != float("-Inf")]) + + # convert min_p filtered logits to probabilities + probs = torch.softmax(scores.float(),dim=-1) + # reapply eos token decay to new probabilities + probs[0] *= decay_factor + + # apply dry here because I think its independent of the probs accross the vocab + # aka the probs dont influence other tokens + #probs = dry_token_penalty(all_tokens,probs) + + token = torch.multinomial(probs, num_samples=1).squeeze(dim=-1) + + #print(token) + + all_tokens += [token] + if len(all_tokens) > 3: + # if the last 3 tokens are all the same break + if all_tokens[-1] == all_tokens[-2] == all_tokens[-3]: break + + if token == 0: break + + out, state = model.eval_sequence([token], state, None, None, use_numpy=False) + + #torch.cuda.synchronize() were cpuing now dont need allat + + # we generated all the tokens we could + output_string = tokenizer.decode(all_tokens) + logger.info(f"Generated {len(all_tokens)} = ctx:{len(context_tokens)} + gen:{added_tokens_count} tokens, Result:{output_string}") + + # insert bolding of the input string TODO: wonky + #output_string = output_string[:len(context_str)] + "**" + output_string[len(context_str):] + + # replace \n with newline in output_string + output_string = output_string.replace("\\n", "\n") + + return output_string + +import logging + +from telegram import ForceReply, Update +from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters +# Read telegram token from .telegram_token +with open(".telegram_token") as f: + TOKEN = f.read().strip() + + +# Enable logging +logging.basicConfig( + filename='tamilm.log', + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) +# set higher logging level for httpx to avoid all GET and POST requests being logged +logging.getLogger("httpx").setLevel(logging.WARNING) + +logger = logging.getLogger(__name__) + + +# Define a few command handlers. These usually take the two arguments update and +# context. +async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /start is issued.""" + await update.message.reply_text("use /complete to complete your hebrew text in groupchats") + + +async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /help is issued.""" + await update.message.reply_text("use /complete to complete your hebrew text") + +async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Send a message when the command /help is issued.""" + message_text = ' '.join(context.args) + + # if this is a reply to a message include it in the prompt + if update.message.reply_to_message != None: + message_text = update.message.reply_to_message.text + ' ' + message_text + + # format input into req structure + message_text = f"[REQ]\n{message_text}\n[RES]\n" + + generated_responce = generate_text(message_text) + #logger.info(generated_responce) + await update.message.reply_text(generated_responce) + + +async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + message_text = update.message.text + if update.message.chat.type == "private": + # format input into req structure + message_text = f"[REQ]\n{message_text}\n[RES]\n" + generated_responce = generate_text(message_text) + #(generated_responce) + await update.message.reply_text(generated_responce) + """Echo the user message.""" + #await update.message.reply_text(message_text) + + +def main() -> None: + """Start the bot.""" + # Create the Application and pass it your bot's token. + application = Application.builder().token(TOKEN).build() + + # on different commands - answer in Telegram + application.add_handler(CommandHandler("start", start)) + application.add_handler(CommandHandler("help", help_command)) + application.add_handler(CommandHandler("complete", complete_command)) + + # on non command i.e message - echo the message on Telegram + application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo)) + + print("BOT READY") + + # Run the bot until the user presses Ctrl-C + application.run_polling(allowed_updates=Update.ALL_TYPES) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tami-models/MODELS_GO_HERE b/tami-models/MODELS_GO_HERE new file mode 100644 index 0000000..e69de29 diff --git a/tokenizer_vocab/VOCABS_GO_HERE b/tokenizer_vocab/VOCABS_GO_HERE new file mode 100644 index 0000000..e69de29