TamiLM-Telegram-Bot/bot_tokenmonster_rwkv_cpp.py

302 lines
11 KiB
Python
Raw Permalink Normal View History

2024-10-24 16:44:37 +03:00
#!/usr/bin/env python
# pylint: disable=unused-argument
# This program is dedicated to the public domain under the CC0 license.
"""
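Telegram bot that replies with text generated by an RWKV language model.
The model is run through rwkv_cpp and text is tokenized with tokenmonster.
First the model, the sampling settings and the text-generation helpers are
set up. Then a few handler functions are defined, passed to the Application
and registered at their respective places.
Then, the bot is started and runs until we press Ctrl-C on the command line.
Usage:
Send /complete followed by some text (optionally as a reply to another
message) to get a completion. In private chats every plain text message is
completed.
Press Ctrl-C on the command line or send a signal to the process to stop the
bot.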
"""
# rwkv_cpp bindings: the shared-library loader and the model wrapper, both used
# further down to load the checkpoint.
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
import tokenmonster
# The tokenizer vocabulary is loaded once, at import time; generate_text() uses
# `tokenizer` for both tokenizing prompts and decoding generated tokens.
tokenizer = tokenmonster.load("tokenizer_vocab/pol_vocab_exported")
import math
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import numpy as np
# Compact numpy printing: 4 decimals, no scientific notation, 200-character lines.
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch, copy, time
from typing import List
# Global torch backend switches (cuDNN autotuning, TF32 matmuls).
# NOTE(review): inference in this file goes through rwkv_cpp; these flags look
# like leftovers from the torch RWKV reference script -- confirm before removing.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
# Private torch API call that switches the JIT autocast mode off.
torch._C._jit_set_autocast_mode(False)
import torch.nn as nn
from torch.nn import functional as F
# TorchScript aliases; not referenced anywhere in the visible part of this file.
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script
########################################################################################################
# Checkpoint location and model dimensions. Only MODEL_NAME is read in the
# visible part of this file; the other fields are presumably informational
# -- TODO confirm.
args = types.SimpleNamespace(
    MODEL_NAME='tami-models/assholerwkv-240.ggml',
    n_layer=12,
    n_embd=768,
    vocab_size=500,
    head_size=64,
)
# Example prompt kept around for manual testing:
#context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען"
########################################################################################################
print(f'\nUsing RWKV cpp. Loading {args.MODEL_NAME} ...')
# Load the rwkv_cpp shared library first, then open the checkpoint through it.
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, args.MODEL_NAME)

# Sampling settings read by generate_text().
TEMPERATURE = 0.9       # logits are divided by this before sampling
NUM_TRIALS = 1          # not referenced in the visible part of this file
LENGTH_PER_TRIAL = 768  # upper bound on the number of tokens generated per request
MIN_P = 0.05            # min-p cutoff, as a fraction of the top token's probability
EOS_DECAY_DIV = 25      # time constant of the ramp that re-enables token 0 (EOS)
EOS_IGNORE = 100        # token 0 (EOS) gets zero probability for this many tokens
print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
# docs for inputs, https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor.__call__
def dry_token_penalty(
    input_ids: torch.LongTensor,
    scores: torch.FloatTensor,
    *,
    multiplier: float = 5,
    base: float = 1.75,
    allowed_length: int = 2,
    sequence_breakers=(10, 4943, 4944),
    penalty_range: int = 0,
) -> torch.FloatTensor:
    """Apply a DRY ("don't repeat yourself") repetition penalty to next-token scores.

    For every earlier occurrence of the last token of a row, the function
    measures how far the text before that occurrence matches the text before
    the end of the row. The token that followed the earlier occurrence would
    extend the repetition, so its score is lowered by
    ``multiplier * base ** (match_length - allowed_length)`` whenever
    ``match_length >= allowed_length``.

    Args:
        input_ids: Token ids, one row per sequence (batch, sequence).
        scores: One score vector per row of ``input_ids``. May be a 2-D
            tensor or any sequence of 1-D tensors. Rows are modified IN PLACE.
        multiplier: Magnitude of the penalty for the shortest penalized
            sequences.
        base: How fast the penalty grows with increasing sequence length.
        allowed_length: Longest sequence that can be repeated without being
            penalized.
        sequence_breakers: Token ids that break up sequences (newline, period,
            question mark, ...). A match never extends across one of them.
        penalty_range: How many trailing tokens of each row to inspect;
            0 means the entire supplied sequence.

    Returns:
        The same ``scores`` object that was passed in, after modification.
    """
    # Set membership is O(1); the ids are looked up inside the inner loops.
    breakers = frozenset(sequence_breakers)

    if penalty_range > 0:
        input_ids = input_ids[:, -penalty_range:]

    for input_ids_row, scores_row in zip(input_ids, scores):
        # Work on plain Python ints: one conversion per row instead of a
        # tensor index + .item() call per visited token.
        row = input_ids_row.tolist()
        if not row:
            # Nothing generated yet, so there is nothing to repeat.
            continue

        last_token = row[-1]
        if last_token in breakers:
            continue

        # Maximum matching sequence length for each token that immediately
        # followed an earlier occurrence of `last_token`.
        match_lengths = {}

        # The final position is excluded because it always matches itself.
        for i in range(len(row) - 1):
            if row[i] != last_token:
                continue

            next_token = row[i + 1]
            if next_token in breakers:
                continue

            # `last_token` matches at index i, so the match has length >= 1.
            match_length = 1

            # Extend the match backwards as far as possible.
            while True:
                j = i - match_length
                if j < 0:
                    # Start of input reached.
                    break
                previous_token = row[-(match_length + 1)]
                if row[j] != previous_token:
                    # Start of match reached.
                    break
                if previous_token in breakers:
                    # Sequence-breaking token reached.
                    break
                match_length += 1

            if match_length > match_lengths.get(next_token, 0):
                match_lengths[next_token] = match_length

        # Apply penalties.
        for token, match_length in match_lengths.items():
            if match_length >= allowed_length:
                scores_row[token] -= multiplier * base ** (match_length - allowed_length)

    return scores
def generate_text(context_str):
    """Generate a continuation of ``context_str`` with the loaded RWKV model.

    The prompt is tokenized and fed to the model, then up to
    ``LENGTH_PER_TRIAL`` tokens are sampled one at a time. Each step applies
    the DRY repetition penalty, temperature scaling and min-p filtering.
    Token 0 is treated as end-of-sequence: it gets zero probability for the
    first ``EOS_IGNORE`` generated tokens and is then phased back in.

    Generation stops early when token 0 is sampled or when the last three
    tokens of the sequence are identical.

    Args:
        context_str: Prompt text.

    Returns:
        The decoded prompt plus generated tokens, with every literal
        backslash-n pair replaced by a real newline.
    """
    context_tokens = tokenizer.tokenize(context_str)
    print(context_tokens)
    out, state = model.eval_sequence(context_tokens, None, None, None, use_numpy=False)

    # Plain Python ints throughout, so the list can be turned into a tensor,
    # fed back to the model and decoded without mixing element types.
    all_tokens = context_tokens.tolist()
    added_tokens_count = 0

    for _ in range(LENGTH_PER_TRIAL):
        added_tokens_count += 1

        # The penalty is applied in place on the row it is given; the call
        # returns the list it was handed, hence the [0].
        out = dry_token_penalty(torch.tensor([all_tokens], dtype=torch.long), [out])[0]

        # Temperature scaling.
        scores = out / TEMPERATURE

        # Convert logits to probabilities for min-p filtering.
        probs = torch.softmax(scores.float(), dim=-1)

        # Decay the EOS token to induce longer generations: fully suppressed
        # at first, then ramping towards its real probability.
        if added_tokens_count <= EOS_IGNORE:
            decay_factor = 0
        else:
            decay_factor = 1.0 - math.exp(-((added_tokens_count - EOS_IGNORE) / EOS_DECAY_DIV))
        probs[0] *= decay_factor

        # min-p: drop every token whose probability is below MIN_P times the
        # probability of the most likely token.
        max_prob, _ = probs.max(dim=-1, keepdim=True)
        scaled_min_p = MIN_P * max_prob
        scores.masked_fill_(probs < scaled_min_p, float("-Inf"))

        # Probabilities of the surviving tokens, with the EOS decay reapplied.
        # torch.multinomial does not need the weights to sum to one.
        probs = torch.softmax(scores.float(), dim=-1)
        probs[0] *= decay_factor

        # Convert the sampled 0-dim tensor to a Python int right away.
        token = torch.multinomial(probs, num_samples=1).squeeze(dim=-1).item()
        all_tokens.append(token)

        # Stop if the last 3 tokens are all the same.
        if len(all_tokens) > 3 and all_tokens[-1] == all_tokens[-2] == all_tokens[-3]:
            break
        if token == 0:
            break

        out, state = model.eval_sequence([token], state, None, None, use_numpy=False)

    output_string = tokenizer.decode(all_tokens)
    logger.info(
        "Generated %d = ctx:%d + gen:%d tokens, Result:%s",
        len(all_tokens), len(context_tokens), added_tokens_count, output_string,
    )

    # Turn literal "\n" sequences in the decoded text into real newlines.
    output_string = output_string.replace("\\n", "\n")
    return output_string
import logging
from telegram import ForceReply, Update
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
# The bot token lives in a local file named .telegram_token; surrounding
# whitespace (such as a trailing newline) is stripped.
with open(".telegram_token") as token_file:
    TOKEN = token_file.read().strip()

# Log to tamilm.log at INFO level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    filename='tamilm.log',
)

# httpx would otherwise log every GET and POST request; keep it at WARNING.
logging.getLogger("httpx").setLevel(logging.WARNING)

# Module-level logger, used by generate_text().
logger = logging.getLogger(__name__)
# Define a few command handlers. These usually take the two arguments update and
# context.
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a usage hint when the command /start is issued."""
    # effective_message also covers edited messages, for which
    # update.message is None.
    message = update.effective_message
    if message is None:
        return
    await message.reply_text("use /complete to complete your hebrew text in groupchats")
async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a usage hint when the command /help is issued."""
    # effective_message also covers edited messages, for which
    # update.message is None.
    message = update.effective_message
    if message is None:
        return
    await message.reply_text("use /complete to complete your hebrew text")
async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply with a model completion when the command /complete is issued.

    The prompt is the text that follows the command. When the command is sent
    as a reply to another message, that message's text is put in front of it.
    """
    # effective_message also covers edited messages, for which
    # update.message is None.
    message = update.effective_message
    if message is None:
        return

    message_text = ' '.join(context.args)

    # If this is a reply to a message, include it in the prompt. Messages
    # without text (e.g. photos) have text == None and are skipped.
    replied_to = message.reply_to_message
    if replied_to is not None and replied_to.text is not None:
        message_text = replied_to.text + ' ' + message_text

    # Format the input into the request/response structure.
    prompt = f"[REQ]\n{message_text}\n[RES]\n"

    # NOTE: generate_text() is synchronous, so the event loop is blocked
    # until generation finishes.
    generated_response = generate_text(prompt)
    await message.reply_text(generated_response)
async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply to a plain text message with a model completion (private chats only).

    Text messages in any other chat type are ignored.
    """
    # effective_message also covers edited messages and channel posts, for
    # which update.message is None.
    message = update.effective_message
    if message is None or message.chat.type != "private":
        return

    # Format the input into the request/response structure.
    prompt = f"[REQ]\n{message.text}\n[RES]\n"

    # NOTE: generate_text() is synchronous, so the event loop is blocked
    # until generation finishes.
    generated_response = generate_text(prompt)
    await message.reply_text(generated_response)
def main() -> None:
    """Build the Telegram application, register all handlers and start polling."""
    app = Application.builder().token(TOKEN).build()

    # Commands the bot answers, in registration order.
    command_callbacks = (
        ("start", start),
        ("help", help_command),
        ("complete", complete_command),
    )
    for command_name, callback in command_callbacks:
        app.add_handler(CommandHandler(command_name, callback))

    # Every text message that is not a command is routed to echo().
    plain_text = filters.TEXT & ~filters.COMMAND
    app.add_handler(MessageHandler(plain_text, echo))

    print("BOT READY")

    # Runs until the user presses Ctrl-C.
    app.run_polling(allowed_updates=Update.ALL_TYPES)


if __name__ == "__main__":
    main()