#!/usr/bin/env python
# pylint: disable=unused-argument
# This program is dedicated to the public domain under the CC0 license.

"""
Simple bot that completes Telegram messages with an RWKV language model.

First, a few handler functions are defined. Then, those functions are passed to
the Application and registered at their respective places.
Then, the bot is started and runs until we press Ctrl-C on the command line.

Usage:
Send /complete <text> in a group chat, or any plain message in a private chat,
and the bot replies with a model-generated continuation.
Press Ctrl-C on the command line or send a signal to the process to stop the
bot.
"""
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model

import tokenmonster
tokenizer = tokenmonster.load("tokenizer_vocab/pol_vocab_exported")
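# Note: as used below, tokenizer.tokenize(text) is expected to return an array of
# token ids (hence the .tolist() call in generate_text) and tokenizer.decode(ids)
# to return the decoded string; nothing beyond those two calls is assumed here.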

import math

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch, copy, time
from typing import List
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)

import torch.nn as nn
from torch.nn import functional as F

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script

########################################################################################################

args = types.SimpleNamespace()
args.MODEL_NAME = 'tami-models/assholerwkv-240.ggml'
args.n_layer = 12
args.n_embd = 768
args.vocab_size = 500
args.head_size = 64

#context = "מחפשת בדחיפות מחשב נייד הכי פשוט עם מטען"  # example Hebrew prompt: "Urgently looking for the simplest laptop with a charger"


########################################################################################################

print(f'\nUsing RWKV cpp. Loading {args.MODEL_NAME} ...')
#model = RWKV_RNN(args)

# Load the model.
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, args.MODEL_NAME)

TEMPERATURE = 0.9
NUM_TRIALS = 1
LENGTH_PER_TRIAL = 768
MIN_P = 0.05
EOS_DECAY_DIV = 25
EOS_IGNORE = 100
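# Overview of the sampling pipeline in generate_text() below: the logits are divided
# by TEMPERATURE, penalized for verbatim repeats by dry_token_penalty(), filtered with
# min-p (keeping only tokens whose probability is at least MIN_P times the top token's),
# and the EOS probability is held at zero for the first EOS_IGNORE generated tokens,
# then faded back in with time constant EOS_DECAY_DIV before torch.multinomial sampling.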

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')

# docs for inputs: https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor.__call__
def dry_token_penalty(input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    multiplier = 5       # Controls the magnitude of the penalty for the shortest penalized sequences
    base = 1.75          # Controls how fast the penalty grows with increasing sequence length
    allowed_length = 2   # Longest sequence that can be repeated without being penalized
    sequence_breakers = [10, 4943, 4944]  # vocab idxs that break up sequences, e.g. newline, period, question mark
    _range = 0           # repetition penalty range: how far back in tokens we look;
                         # 0 means the entire supplied sequence

    if _range > 0:
        input_ids = input_ids[:, -_range:]

    for input_ids_row, scores_row in zip(input_ids, scores):
        # Raw integer must be extracted here to check for set membership.
        last_token = input_ids_row[-1].item()

        if last_token in sequence_breakers:  # TODO: patch to use idxs
            continue

        # Exclude the last token as it always matches.
        match_indices = (input_ids_row[:-1] == last_token).nonzero()

        # Stores the maximum matching sequence length
        # for each token immediately following the sequence in the input.
        match_lengths = {}

        for i in match_indices:
            next_token = input_ids_row[i + 1].item()

            if next_token in sequence_breakers:
                continue

            # We have already found that `last_token` matches at this index,
            # so the match is at least of length 1.
            match_length = 1

            # Extend the match backwards as far as possible.
            while True:
                j = i - match_length
                if j < 0:
                    # Start of input reached.
                    break

                previous_token = input_ids_row[-(match_length + 1)].item()
                if input_ids_row[j] != previous_token:
                    # Start of match reached.
                    break

                if previous_token in sequence_breakers:
                    # Sequence-breaking token reached.
                    break

                match_length += 1

            if next_token in match_lengths:
                match_lengths[next_token] = max(match_length, match_lengths[next_token])
            else:
                match_lengths[next_token] = match_length

        # Apply penalties.
        for token, match_length in match_lengths.items():
            if match_length >= allowed_length:
                penalty = multiplier * base ** (match_length - allowed_length)
                scores_row[token] -= penalty

    return scores
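
# Illustrative numbers only (computed from the constants above): with multiplier = 5,
# base = 1.75 and allowed_length = 2, a repeat of length 2 is penalized by
# 5 * 1.75**0 = 5.0, length 3 by 5 * 1.75**1 = 8.75 and length 5 by
# 5 * 1.75**3 ≈ 26.8, so longer verbatim repeats are pushed down increasingly hard.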

def generate_text(context_str):
    context_tokens = tokenizer.tokenize(context_str)
    print(context_tokens)
    out, state = model.eval_sequence(context_tokens, None, None, None, use_numpy=False)

    ########################################################################################################

    all_tokens = []
    all_tokens += context_tokens.tolist()

    added_tokens_count = 0
    for i in range(LENGTH_PER_TRIAL):
        added_tokens_count += 1

        out = dry_token_penalty(torch.tensor([all_tokens], dtype=torch.long), [out])[0]  # [0] extracts the penalized logits tensor for this single sequence

        # apply the temperature parameter
        scores = out / TEMPERATURE

        min_p = MIN_P

        # convert logits to probabilities for min_p filtering
        probs = torch.softmax(scores.float(), dim=-1)

        # decay the EOS token (id 0) to induce longer generations
        if added_tokens_count <= EOS_IGNORE:
            decay_factor = 0
        else:
            decay_factor = 1.0 - math.exp(-((added_tokens_count - EOS_IGNORE) / EOS_DECAY_DIV))
        probs[0] *= decay_factor
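
        # Illustrative values (from EOS_IGNORE = 100 and EOS_DECAY_DIV = 25 above): the EOS
        # probability is zeroed for the first 100 generated tokens, then scaled by
        # 1 - e**-1 ≈ 0.63 at token 125 and 1 - e**-2 ≈ 0.86 at token 150, approaching 1.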

        # Get the probability of the top token for each sequence in the batch
        max_prob, _ = probs.max(dim=-1, keepdim=True)
        # Calculate the actual min_p threshold by scaling min_p with the top token's probability
        scaled_min_p = min_p * max_prob
        # Create a mask for tokens that have a probability less than the scaled min_p
        indices_to_remove = probs < scaled_min_p

        scores.masked_fill_(indices_to_remove, float("-Inf"))
        #print(scores[scores != float("-Inf")])
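
        # Illustrative numbers (MIN_P = 0.05 above): if the most likely token has
        # probability 0.40, the cutoff is 0.05 * 0.40 = 0.02, so every token whose
        # probability is below 2% has its logit set to -inf before re-normalizing.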

        # convert the min_p-filtered logits back to probabilities
        probs = torch.softmax(scores.float(), dim=-1)
        # reapply the EOS decay to the new probabilities
        probs[0] *= decay_factor

        # dry could also be applied here, because I think it's independent of the probs
        # across the vocab, i.e. one token's probability doesn't influence the others
        #probs = dry_token_penalty(all_tokens, probs)

        # sample the next token id and extract it as a plain integer
        token = torch.multinomial(probs, num_samples=1).item()

        #print(token)

        all_tokens += [token]
        if len(all_tokens) > 3:
            # if the last 3 tokens are all the same, stop
            if all_tokens[-1] == all_tokens[-2] == all_tokens[-3]: break

        if token == 0: break

        out, state = model.eval_sequence([token], state, None, None, use_numpy=False)

        #torch.cuda.synchronize()  # not needed: inference runs on the CPU now

    # we generated all the tokens we could
    output_string = tokenizer.decode(all_tokens)
    logger.info(f"Generated {len(all_tokens)} = ctx:{len(context_tokens)} + gen:{added_tokens_count} tokens, Result:{output_string}")

    # insert bolding of the input string TODO: wonky
    #output_string = output_string[:len(context_str)] + "**" + output_string[len(context_str):]

    # replace literal \n with newline in output_string
    output_string = output_string.replace("\\n", "\n")

    return output_string


import logging

from telegram import ForceReply, Update
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters

# Read telegram token from .telegram_token
with open(".telegram_token") as f:
    TOKEN = f.read().strip()


# Enable logging
logging.basicConfig(
    filename='tamilm.log',
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
# set higher logging level for httpx to avoid all GET and POST requests being logged
logging.getLogger("httpx").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)


# Define a few command handlers. These usually take the two arguments update and
# context.
async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /start is issued."""
    await update.message.reply_text("use /complete to complete your hebrew text in groupchats")


async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Send a message when the command /help is issued."""
    await update.message.reply_text("use /complete to complete your hebrew text")


async def complete_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Generate a completion when the command /complete is issued."""
    message_text = ' '.join(context.args)

    # if this is a reply to a message, include the replied-to text in the prompt
    if update.message.reply_to_message is not None:
        message_text = update.message.reply_to_message.text + ' ' + message_text

    # format the input into the [REQ]/[RES] prompt structure
    message_text = f"[REQ]\n{message_text}\n[RES]\n"
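    # e.g. for the text "hello" the prompt sent to the model is "[REQ]\nhello\n[RES]\n"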

    generated_response = generate_text(message_text)
    #logger.info(generated_response)
    await update.message.reply_text(generated_response)


async def echo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply to private messages with a model-generated completion."""
    message_text = update.message.text
    if update.message.chat.type == "private":
        # format the input into the [REQ]/[RES] prompt structure
        message_text = f"[REQ]\n{message_text}\n[RES]\n"
        generated_response = generate_text(message_text)
        #(generated_response)
        await update.message.reply_text(generated_response)
    #await update.message.reply_text(message_text)


def main() -> None:
    """Start the bot."""
    # Create the Application and pass it your bot's token.
    application = Application.builder().token(TOKEN).build()

    # on different commands - answer in Telegram
    application.add_handler(CommandHandler("start", start))
    application.add_handler(CommandHandler("help", help_command))
    application.add_handler(CommandHandler("complete", complete_command))

    # on non-command messages, reply with a completion (private chats only, see echo)
    application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, echo))

    print("BOT READY")

    # Run the bot until the user presses Ctrl-C
    application.run_polling(allowed_updates=Update.ALL_TYPES)


if __name__ == "__main__":
    main()