#!/usr/bin/env python
# pylint: disable=unused-argument
# This program is dedicated to the public domain under the CC0 license.
"""
Telegram bot that completes text with an RWKV model served through rwkv.cpp.

First, a few handler functions are defined. Then, those functions are passed to
the Application and registered at their respective places.
Then, the bot is started and runs until we press Ctrl-C on the command line.

Usage:
Send /complete <text> in a group chat, or any plain message in a private chat,
and the bot replies with the model's continuation.
Press Ctrl-C on the command line or send a signal to the process to stop the bot.
"""
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
import tokenmonster

tokenizer = tokenmonster.load("tokenizer_vocab/pol_vocab_exported")

import math

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch, copy, time
from typing import List
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)
import torch.nn as nn
from torch.nn import functional as F

# Leftover aliases from the reference RWKV-LM implementation (unused with rwkv.cpp).
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script

########################################################################################################

args = types.SimpleNamespace()
args.MODEL_NAME = 'tami-models/assholerwkv-240.ggml'
args.n_layer = 12
args.n_embd = 768
args.vocab_size = 500
args.head_size = 64

# context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען"  # (Hebrew: "urgently looking for the simplest laptop with a charger")

########################################################################################################

print(f'\nUsing RWKV cpp. Loading {args.MODEL_NAME} ...')
# model = RWKV_RNN(args)
# Load the model.
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, args.MODEL_NAME)

TEMPERATURE = 0.9        # softmax temperature used when sampling
NUM_TRIALS = 1           # (unused here)
LENGTH_PER_TRIAL = 768   # maximum number of tokens to generate per request
MIN_P = 0.05             # min-p threshold, relative to the top token's probability
EOS_DECAY_DIV = 25       # how quickly the EOS probability is restored after EOS_IGNORE
EOS_IGNORE = 100         # number of generated tokens during which EOS is fully suppressed

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
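
# dry_token_penalty below appears to follow the DRY ("don't repeat yourself")
# sampler design known from text-generation-webui: for every earlier position
# where the most recent token also occurred, the match is extended backwards
# and the token that followed that occurrence is penalised with
#     penalty = multiplier * base ** (match_length - allowed_length)
# so longer verbatim repeats receive exponentially larger penalties.
# Worked example with the values used here (multiplier=5, base=1.75, allowed_length=2):
#     match_length=2 -> penalty 5.0, match_length=4 -> penalty ~15.3
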
# docs for inputs: https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor.__call__
def dry_token_penalty(input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    multiplier = 5      # controls the magnitude of the penalty for the shortest penalized sequences
    base = 1.75         # controls how fast the penalty grows with increasing sequence length
    allowed_length = 2  # longest sequence that can be repeated without being penalized
    sequence_breakers = [10, 4943, 4944]  # vocab ids that break up sequences (newline, period, question mark, etc.)
    _range = 0          # penalty range: how far back in tokens to look; 0 means the entire supplied sequence

    if _range > 0:
        input_ids = input_ids[:, -_range:]

    for input_ids_row, scores_row in zip(input_ids, scores):
        # Raw integer must be extracted here to check for set membership.
        last_token = input_ids_row[-1].item()
        if last_token in sequence_breakers:  # TODO: patch to use idxs
            continue

        # Exclude the last token as it always matches.
        match_indices = (input_ids_row[:-1] == last_token).nonzero()

        # Stores the maximum matching sequence length
        # for each token immediately following the sequence in the input.
        match_lengths = {}

        for i in match_indices:
            next_token = input_ids_row[i+1].item()

            if next_token in sequence_breakers:
                continue

            # We have already found that `last_token` matches at this index,
            # so the match is at least of length 1.
            match_length = 1

            # Extend the match backwards as far as possible.
            while True:
                j = i - match_length
                if j < 0:
                    # Start of input reached.
                    break

                previous_token = input_ids_row[-(match_length+1)].item()
                if input_ids_row[j] != previous_token:
                    # Start of match reached.
                    break

                if previous_token in sequence_breakers:
                    # Sequence-breaking token reached.
                    break

                match_length += 1

            if next_token in match_lengths:
                match_lengths[next_token] = max(match_length, match_lengths[next_token])
            else:
                match_lengths[next_token] = match_length

        # Apply penalties.
        for token, match_length in match_lengths.items():
            if match_length >= allowed_length:
                penalty = multiplier * base ** (match_length - allowed_length)
                scores_row[token] -= penalty

    return scores


def generate_text(context_str):
    context_tokens = tokenizer.tokenize(context_str)
    print(context_tokens)
    out, state = model.eval_sequence(context_tokens, None, None, None, use_numpy=False)

    ########################################################################################################

    all_tokens = []
    all_tokens += context_tokens.tolist()
    added_tokens_count = 0
    for i in range(LENGTH_PER_TRIAL):
        added_tokens_count += 1
        # dry_token_penalty expects batched input_ids/scores, so wrap the single row and unwrap the result
        out = dry_token_penalty(torch.tensor([all_tokens], dtype=torch.long), [out])[0]
        # apply the temperature parameter
        scores = out / TEMPERATURE
        min_p = MIN_P
        # convert logits to probabilities for min_p filtering
        probs = torch.softmax(scores.float(), dim=-1)
        # decay the EOS token (id 0) to induce longer generations: fully suppressed for the
        # first EOS_IGNORE tokens, then its probability is gradually restored
        if added_tokens_count <= EOS_IGNORE:
            decay_factor = 0
        else:
            decay_factor = 1.0 - math.exp(-((added_tokens_count - EOS_IGNORE) / EOS_DECAY_DIV))
        probs[0] *= decay_factor
        # Get the probability of the top token for each sequence in the batch
        max_prob, _ = probs.max(dim=-1, keepdim=True)
        # Calculate the actual min_p threshold by scaling min_p with the top token's probability
        scaled_min_p = min_p * max_prob
        # Create a mask for tokens that have a probability less than the scaled min_p
        indices_to_remove = probs < scaled_min_p
        scores.masked_fill_(indices_to_remove, float("-Inf"))
        #print(scores[scores != float("-Inf")])
        # convert the min_p-filtered logits back to probabilities
        probs = torch.softmax(scores.float(), dim=-1)
        # reapply the EOS decay to the new probabilities
        probs[0] *= decay_factor
        # apply dry on the logits because I think it's independent of the probs across the vocab,
        # i.e. the probs don't influence other tokens
        #probs = dry_token_penalty(all_tokens, probs)
        token = torch.multinomial(probs, num_samples=1).item()  # sampled token id as a plain int
        all_tokens += [token]
        if len(all_tokens) > 3:
            # stop if the last 3 tokens are all the same
            if all_tokens[-1] == all_tokens[-2] == all_tokens[-3]:
                break
        if token == 0:
            break
        out, state = model.eval_sequence([token], state, None, None, use_numpy=False)
        # torch.cuda.synchronize()  # not needed: inference runs on the CPU via rwkv.cpp

    # we generated all the tokens we could
    output_string = tokenizer.decode(all_tokens)
    logger.info(f"Generated {len(all_tokens)} = ctx:{len(context_tokens)} + gen:{added_tokens_count} tokens, Result:{output_string}")
    # insert bolding of the input string TODO: wonky
    #output_string = output_string[:len(context_str)] + "**" + output_string[len(context_str):]
    # replace literal "\n" sequences emitted by the model with real newlines
    output_string = output_string.replace("\\n", "\n")
    return output_string
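
# The handlers below wrap the user's text in the [REQ]/[RES] prompt template the
# model is assumed to have been trained on, e.g. (hypothetical input):
#     generate_text("[REQ]\nhello there\n[RES]\n")
# The returned string is the decoded prompt plus the model's continuation.
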
+ "**" + output_string[len(context_str):] # replace \n with newline in output_string output_string = output_string.replace("\\n", "\n") return output_string import logging from telegram import ForceReply, Update from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters # Read telegram token from .telegram_token with open(".telegram_token") as f: TOKEN = f.read().strip() # Enable logging logging.basicConfig( filename='tamilm.log', format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) # set higher logging level for httpx to avoid all GET and POST requests being logged logging.getLogger("httpx").setLevel(logging.WARNING) logger = logging.getLogger(__name__) # Define a few command handlers. These usually take the two arguments update and # context. async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Send a message when the command /start is issued.""" await update.message.reply_text("use /complete to complete your hebrew text in groupchats") async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Send a message when the command /help is issued.""" await update.message.reply_text("use /complete to complete your hebrew text") async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Send a message when the command /help is issued.""" message_text = ' '.join(context.args) # if this is a reply to a message include it in the prompt if update.message.reply_to_message != None: message_text = update.message.reply_to_message.text + ' ' + message_text # format input into req structure message_text = f"[REQ]\n{message_text}\n[RES]\n" generated_responce = generate_text(message_text) #logger.info(generated_responce) await update.message.reply_text(generated_responce) async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: message_text = update.message.text if update.message.chat.type == "private": # format input into req structure message_text = f"[REQ]\n{message_text}\n[RES]\n" generated_responce = generate_text(message_text) #(generated_responce) await update.message.reply_text(generated_responce) """Echo the user message.""" #await update.message.reply_text(message_text) def main() -> None: """Start the bot.""" # Create the Application and pass it your bot's token. application = Application.builder().token(TOKEN).build() # on different commands - answer in Telegram application.add_handler(CommandHandler("start", start)) application.add_handler(CommandHandler("help", help_command)) application.add_handler(CommandHandler("complete", complete_command)) # on non command i.e message - echo the message on Telegram application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo)) print("BOT READY") # Run the bot until the user presses Ctrl-C application.run_polling(allowed_updates=Update.ALL_TYPES) if __name__ == "__main__": main()