#!/usr/bin/env python
# pylint: disable=unused-argument
# This program is dedicated to the public domain under the CC0 license.

"""Simple Telegram bot that completes Hebrew text with an RWKV language model.

First, a few handler functions are defined. Then, those functions are passed to
the Application and registered at their respective places. Then, the bot is
started and runs until we press Ctrl-C on the command line.

Usage:
Send /complete <text> in a group chat (or a plain message in a private chat)
and the bot replies with a model-generated continuation. Press Ctrl-C on the
command line or send a signal to the process to stop the bot.
"""
import tokenmonster

tokenizer = tokenmonster.load("tokenizer_vocab/pol_vocab_exported")

import math

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)

import types, torch, copy, time
from typing import List

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)

import torch.nn as nn
from torch.nn import functional as F

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script

########################################################################################################

args = types.SimpleNamespace()
args.MODEL_NAME = 'tami-models/asshole-rwkv-6t'
args.n_layer = 12
args.n_embd = 768
args.vocab_size = 500
args.head_size = 64

#context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען"

########################################################################################################

class RWKV_RNN(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.n_embd = args.n_embd
        self.n_layer = args.n_layer
        self.eval()

        self.z = torch.load(args.MODEL_NAME + '.pth', map_location='cuda')
        z = self.z

        z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias'])

        keys = list(z.keys())
        for k in keys:
            if '.time_' in k:
                z[k] = z[k].squeeze()
            if k.endswith('.time_decay'):
                z[k] = z[k].float()
            if k.endswith('.time_faaaa'):
                z[k] = z[k].unsqueeze(-1).float()
        # fuse the five time_maa_{w,k,v,r,g} vectors into a single (5, C) tensor
        for k in keys:
            if k.endswith('maa_w'):
                z[k.replace('maa_w','maa_wkvrg')] = torch.concat([z[k],z[k.replace('maa_w','maa_k')],z[k.replace('maa_w','maa_v')],z[k.replace('maa_w','maa_r')],z[k.replace('maa_w','maa_g')]]).clone().reshape(5, -1)
                del z[k]
                del z[k.replace('maa_w','maa_k')]
                del z[k.replace('maa_w','maa_v')]
                del z[k.replace('maa_w','maa_r')]
                del z[k.replace('maa_w','maa_g')]

        self.n_head = z['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = z['blocks.0.ln1.weight'].shape[0] // self.n_head
        assert self.head_size == args.head_size

    @MyFunction
    def forward(self, token:int, state:List[torch.Tensor]):
        with torch.no_grad():
            z = self.z
            x = z['emb.weight'][token]

            for i in range(self.n_layer):
                bbb = f'blocks.{i}.'
                att = f'blocks.{i}.att.'
                ffn = f'blocks.{i}.ffn.'

                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
                xx, state[i*3+0], state[i*3+1] = time_mixing(self.n_head, self.head_size, xx, state[i*3+0], state[i*3+1],
                    z[att+'time_maa_x'], z[att+'time_maa_wkvrg'], z[att+'time_maa_w1'], z[att+'time_maa_w2'],
                    z[att+'time_decay_w1'], z[att+'time_decay_w2'], z[att+'time_faaaa'], z[att+'time_decay'],
                    z[att+'key.weight'], z[att+'value.weight'], z[att+'receptance.weight'], z[att+'gate.weight'],
                    z[att+'output.weight'], z[att+'ln_x.weight'], z[att+'ln_x.bias'])
                x = x + xx

                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
                xx, state[i*3+2] = channel_mixing(xx, state[i*3+2], z[ffn+'time_maa_k'], z[ffn+'time_maa_r'], z[ffn+'key.weight'], z[ffn+'value.weight'], z[ffn+'receptance.weight'])
                x = x + xx

            x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
            x = z['head.weight'] @ x

            return x, state
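
# Note on usage (illustrative only, not executed here): the model is driven
# strictly one token at a time. Given a token id `tok` (a plain int, hypothetical
# name) and a `state` list holding 3 tensors per layer (see global_init_state
# below), a single step looks like:
#
#     logits, state = model.forward(tok, state)   # logits: (vocab_size,) on CUDA
#
# generate_text() below does exactly this, first over the prompt tokens and then
# over the tokens it samples.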

########################################################################################################

def time_mixing__(H:int, N:int, x, x_prev, state, maa_x, maa_wkvrg, tm_w1, tm_w2, td_w1, td_w2, time_faaaa, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
    sx = x_prev - x
    xxx = x + sx * maa_x                              # C
    xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)      # C @ C*5L => 5L => 5*1*L
    xxx = torch.bmm(xxx, tm_w2).view(5, -1)           # 5*1*L @ 5*L*C => 5*1*C => 5*C
    xxx = xxx + maa_wkvrg
    xxx = xxx * sx.expand(5, -1) + x.expand(5, -1)
    w, k, v, r, g = xxx.unbind(dim=0)

    w = torch.tanh(w @ td_w1) @ td_w2
    w = w.float() + time_decay # assert w.dtype == torch.float
    w = torch.exp(-torch.exp(w))

    k = (kw @ k).view(H, N, 1)
    v = (vw @ v).view(H, 1, N)
    r = (rw @ r).view(H, 1, N)
    g = torch.nn.functional.silu(gw @ g)

    kv = (k @ v).float()
    out = r @ (time_faaaa * kv + state).to(torch.bfloat16)
    state = kv + w.view(H, N, 1) * state

    out = torch.nn.functional.group_norm(out.view(1, H*N), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5).view(H*N) # same as gn(x/8, eps=1e-5)
    return ow @ (out * g), x, state

try:
    time_mixing = torch.compile(time_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
except:
    time_mixing = torch.jit.script(time_mixing__)

########################################################################################################

def channel_mixing__(x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
    sx = x_prev - x
    k = x + sx * time_maa_k
    r = x + sx * time_maa_r
    r = torch.sigmoid(rw @ r)
    k = torch.relu(kw @ k) ** 2
    return r * (vw @ k), x

try:
    channel_mixing = torch.compile(channel_mixing__, mode="max-autotune", fullgraph=True, dynamic=False)
except:
    channel_mixing = torch.jit.script(channel_mixing__)

########################################################################################################

print(f'\nUsing CUDA bf16. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

TEMPERATURE = 1.5
NUM_TRIALS = 1
LENGTH_PER_TRIAL = 2048
MIN_P = 0.1
EOS_DECAY_DIV = 25
EOS_IGNORE = 25
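
# Sampling notes (describing the loop in generate_text below):
# - min-p filtering: after temperature scaling, any token whose probability is
#   below MIN_P * p(top token) is masked to -inf and the remainder renormalised
#   before sampling. With MIN_P = 0.1 and a top-token probability of 0.6,
#   everything below 0.06 is dropped.
# - EOS decay: token 0 (EOS) is suppressed entirely for the first EOS_IGNORE
#   generated tokens, then rescaled by 1 - exp(-(n - EOS_IGNORE) / EOS_DECAY_DIV),
#   i.e. roughly 63% of its original probability at n = 50 and approaching 100%
#   for long generations. This pushes the model toward longer completions.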

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')

global_init_state = [None for _ in range(args.n_layer * 3)]
for i in range(args.n_layer):  # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev
    global_init_state[i*3+0] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")
    global_init_state[i*3+1] = torch.zeros((args.n_embd // args.head_size, args.head_size, args.head_size), dtype=torch.float, requires_grad=False, device="cuda")
    global_init_state[i*3+2] = torch.zeros(args.n_embd, dtype=torch.bfloat16, requires_grad=False, device="cuda")

def generate_text(context_str):
    context_tokens = tokenizer.tokenize(context_str)
    print(context_tokens)

    # feed the prompt through the model one token at a time to build up the state
    state = global_init_state.copy()
    for token in context_tokens:
        out, state = model.forward(token, state)

    ########################################################################################################

    all_tokens = []
    all_tokens += context_tokens.tolist()
    added_tokens_count = 0
    for i in range(LENGTH_PER_TRIAL):
        added_tokens_count += 1

        # applying temperature param
        scores = out / TEMPERATURE
        min_p = MIN_P

        # convert logits to probabilities for min_p filtering
        probs = torch.softmax(scores.float(), dim=-1)

        # decaying eos token to induce longer generation
        if (added_tokens_count <= EOS_IGNORE):
            decay_factor = 0
        else:
            decay_factor = 1.0 - math.exp(-((added_tokens_count-EOS_IGNORE)/EOS_DECAY_DIV))
        probs[0] *= decay_factor

        # Get the probability of the top token for each sequence in the batch
        max_prob, _ = (probs).max(dim=-1, keepdim=True)
        # Calculate the actual min_p threshold by scaling min_p with the top token's probability
        scaled_min_p = min_p * max_prob
        # Create a mask for tokens that have a probability less than the scaled min_p
        indices_to_remove = probs < scaled_min_p
        scores.masked_fill_(indices_to_remove, float("-Inf"))
        #print(scores[scores != float("-Inf")])

        # convert min_p filtered logits to probabilities
        probs = torch.softmax(scores.float(), dim=-1)
        # reapply eos token decay to new probabilities
        probs[0] *= decay_factor

        # plain int so the token can be appended, compared, fed back to forward() and decoded
        token = int(torch.multinomial(probs, num_samples=1).squeeze(dim=-1))
        #print(token)
        all_tokens += [token]

        # if the last 3 tokens are all the same, break
        if len(all_tokens) > 3:
            if all_tokens[-1] == all_tokens[-2] == all_tokens[-3]:
                break

        # if any one token is overrepresented in the generated output, assume the output is repeating
        if len(all_tokens) > 256:
            if any(num > len(all_tokens) / 4 for num in np.unique_counts(all_tokens).counts):
                break

        if token == 0:
            break

        out, state = model.forward(token, state)
        torch.cuda.synchronize()

    # we generated all the tokens we could
    output_string = tokenizer.decode(all_tokens)
    logger.info(f"Generated {len(all_tokens)} = ctx:{len(context_tokens)} + gen:{added_tokens_count} tokens, Result:{output_string}")

    # insert bolding of the input string TODO: wonky
    #output_string = output_string[:len(context_str)] + "**" + output_string[len(context_str):]

    return output_string
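
# Quick local check (commented out; the prompt string below is just a made-up
# example of the [REQ]/[RES] wrapping the handlers further down apply):
#print(generate_text("[REQ]\nsome text to complete\n[RES]\n"))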

import logging

from telegram import ForceReply, Update
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters

# Read telegram token from .telegram_token
with open(".telegram_token") as f:
    TOKEN = f.read().strip()

# Enable logging
logging.basicConfig(
    filename='tamilm.log',
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO
)
# set higher logging level for httpx to avoid all GET and POST requests being logged
logging.getLogger("httpx").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)


# Define a few command handlers. These usually take the two arguments update and
# context.
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /start is issued."""
    await update.message.reply_text("use /complete to complete your Hebrew text in group chats")


async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /help is issued."""
    await update.message.reply_text("use /complete to complete your Hebrew text")


async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Complete the text given after /complete (and the replied-to message, if any)."""
    message_text = ' '.join(context.args)

    # if this is a reply to a message, include it in the prompt
    if update.message.reply_to_message is not None:
        message_text = update.message.reply_to_message.text + ' ' + message_text

    # format input into req structure
    message_text = f"[REQ]\n{message_text}\n[RES]\n"
    generated_response = generate_text(message_text)
    #logger.info(generated_response)
    await update.message.reply_text(generated_response)


async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Complete any plain text message sent in a private chat."""
    message_text = update.message.text
    if update.message.chat.type == "private":
        # format input into req structure
        message_text = f"[REQ]\n{message_text}\n[RES]\n"
        generated_response = generate_text(message_text)
        await update.message.reply_text(generated_response)
    #await update.message.reply_text(message_text)


def main() -> None:
    """Start the bot."""
    # Create the Application and pass it your bot's token.
    application = Application.builder().token(TOKEN).build()

    # on different commands - answer in Telegram
    application.add_handler(CommandHandler("start", start))
    application.add_handler(CommandHandler("help", help_command))
    application.add_handler(CommandHandler("complete", complete_command))

    # on non command i.e. message - complete the message text with the model (private chats only)
    application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo))

    print("BOT READY")
    # Run the bot until the user presses Ctrl-C
    application.run_polling(allowed_updates=Update.ALL_TYPES)


if __name__ == "__main__":
    main()
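
# Usage note: the Telegram bot token is read from a file named .telegram_token
# in the working directory, and logs are written to tamilm.log. Stop the bot
# with Ctrl-C.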