stable-diffusion-telegram-bot/main.py

324 lines
12 KiB
Python
Raw Normal View History

2023-03-18 00:04:54 +02:00
import os
2024-05-18 14:06:29 +03:00
import re
import io
2023-04-28 10:02:47 +03:00
import uuid
2023-01-12 06:37:56 +02:00
import base64
2024-05-18 14:06:29 +03:00
import requests
2024-05-16 01:05:35 +03:00
from datetime import datetime
2023-03-18 00:04:54 +02:00
from PIL import Image, PngImagePlugin
from pyrogram import Client, filters
2024-05-16 01:05:35 +03:00
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup
2023-03-18 00:24:46 +02:00
from dotenv import load_dotenv
2023-03-18 00:04:54 +02:00
2024-05-16 01:05:35 +03:00
# Pull settings from a .env file (when present) into the process environment.
load_dotenv()

# Telegram credentials, bot token and the Stable Diffusion web UI base URL.
# os.environ.get yields None for any variable that is not set.
API_ID, API_HASH, TOKEN, SD_URL = (
    os.environ.get(variable)
    for variable in ("API_ID", "API_HASH", "TOKEN_givemtxt2img", "SD_URL")
)

app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
2024-05-18 14:06:29 +03:00
# Directory (relative to the working directory) where generated PNGs are written.
IMAGE_PATH = 'images'

# Create it up front; exist_ok makes this a no-op when it is already there.
os.makedirs(name=IMAGE_PATH, exist_ok=True)
2024-05-26 08:49:27 +03:00
# Model-specific embeddings to add to the negative prompt, keyed by checkpoint name.
# See each model's page on civitai.com for the embeddings its author recommends.
# Models that need nothing extra map to an empty string.
model_negative_prompts = {
    checkpoint: ""
    for checkpoint in (
        "Anything-Diffusion",
        "Deliberate",
        "Dreamshaper",
        "DreamShaperXL_Lightning",
        "icbinp",
        "realisticVisionV60B1_v51VAE",
        "v1-5-pruned-emaonly",
    )
}
# Re-assigning an existing key keeps its original position in the dict.
model_negative_prompts["realisticVisionV60B1_v51VAE"] = "realisticvision-negative-embedding"
2024-05-16 01:05:35 +03:00
def encode_file_to_base64(path):
    """Return the bytes of the file at *path* as a base64 text string.

    The file is read in binary mode; the base64 output is decoded as UTF-8
    (it is pure ASCII) so it can be embedded directly in a JSON payload.
    """
    with open(path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
2024-05-26 08:49:27 +03:00
2024-05-16 01:05:35 +03:00
def decode_and_save_base64(base64_str, save_path):
    """Decode *base64_str* and write the resulting bytes to *save_path*.

    The string is decoded before the file is opened, so malformed input
    raises without truncating an existing file at *save_path*.

    Raises:
        binascii.Error: if *base64_str* is not valid base64 (e.g. bad padding).
    """
    data = base64.b64decode(base64_str)
    with open(save_path, "wb") as file:
        file.write(data)
2023-05-06 18:59:02 +03:00
2024-05-20 10:29:34 +03:00
2024-05-18 20:41:30 +03:00
# Default request body for the SD web UI txt2img/img2img endpoints.
# parse_input() copies this dict and overrides entries from user "key:value" tokens.
default_payload = {
    "prompt": "",
    "seed": -1,  # Random seed
    "negative_prompt": "extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs",
    "enable_hr": False,
    # The API field is "sampler_name"; the previous "Sampler" key was not a
    # recognised field, so the web UI fell back to its default sampler.
    "sampler_name": "DPM++ SDE Karras",
    "denoising_strength": 0.35,
    "batch_size": 1,
    "n_iter": 1,
    "steps": 35,
    "cfg_scale": 7,
    "width": 512,
    "height": 512,
    "restore_faces": False,
    "override_settings": {},
    "override_settings_restore_afterwards": True,
}
2024-05-26 08:49:27 +03:00
# The negative prompt as it was before any model-specific embedding was added.
# Captured lazily on the first call so repeated model switches never stack suffixes.
_base_negative_prompt = None


def update_negative_prompt(model_name):
    """Set default_payload["negative_prompt"] to the base prompt plus *model_name*'s embedding.

    The suffix comes from ``model_negative_prompts``. Models with an empty
    suffix, or models not listed there, get the unmodified base prompt. A
    previously selected model's embedding is therefore removed instead of
    accumulating.
    """
    global _base_negative_prompt
    if _base_negative_prompt is None:
        _base_negative_prompt = default_payload["negative_prompt"]
    suffix = model_negative_prompts.get(model_name, "")
    if suffix:
        default_payload["negative_prompt"] = f"{_base_negative_prompt}, {suffix}"
    else:
        default_payload["negative_prompt"] = _base_negative_prompt

# "key:" tokens; a key is a run of word characters immediately followed by ":".
_KEY_PATTERN = re.compile(r"(\w+):")
# A value runs until whitespace followed by the next "key:" token, or to the end of the input.
_VALUE_END_PATTERN = re.compile(r"(?=\s+\w+:|$)")
# Short user-facing aliases for payload keys.
_KEY_ALIASES = {"ds": "denoising_strength", "ng": "negative_prompt", "cfg": "cfg_scale"}
# x/y/z plot axis keys -> axis type enum value sent in script_args.
_XYZ_AXIS_TYPES = {"xsr": 7, "xsteps": 4, "xds": 22, "xcfg": 6}
# x/y/z plot switches -> (script_args index, value to store).
_XYZ_FLAGS = {
    "nl": (9, False),   # Draw legend
    "ks": (10, True),   # Keep sub images
    "rs": (11, True),   # Set random seed to sub images
}
# (axis type index, axis values index) in script_args for the X, Y and Z axes.
_XYZ_SLOTS = ((0, 1), (3, 4), (6, 7))


def parse_input(input_string, defaults=None):
    """Parse a "/draw" or "/img" argument string into an SD API payload.

    Free text becomes the prompt; ``key:value`` tokens override payload fields.
    The aliases ``ds``, ``ng`` and ``cfg`` map to ``denoising_strength``,
    ``negative_prompt`` and ``cfg_scale``. ``xsr``, ``xsteps``, ``xds`` and
    ``xcfg`` fill the next free x/y/z plot axis; at most three are used and
    extras are ignored. ``nl``, ``ks`` and ``rs`` toggle x/y/z plot options
    without consuming an axis. Unknown ``key:value`` tokens are kept in the
    prompt verbatim.

    Args:
        input_string: Text after the bot command.
        defaults: Payload template to copy; defaults to the module-level
            ``default_payload``. It is copied, never modified.

    Returns:
        A new payload dict. Values taken from the input are strings. When any
        x/y/z key is present, "script_name" and "script_args" are also set.
    """
    base = default_payload if defaults is None else defaults
    payload = base.copy()
    prompt = []
    script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
    script_name = None
    slot_index = 0
    last_index = 0

    for match in _KEY_PATTERN.finditer(input_string):
        # A "key:" that starts inside the previous value belongs to that value.
        if match.start() < last_index:
            continue
        fragment = input_string[last_index:match.start()].strip()
        if fragment:
            prompt.append(fragment)

        value_start = match.end()
        # Always matches: "$" is an alternative of the lookahead.
        value_end = _VALUE_END_PATTERN.search(input_string, value_start).start()
        value = input_string[value_start:value_end].strip()
        last_index = value_end

        raw_key = match.group(1)
        key = raw_key.lower()
        key = _KEY_ALIASES.get(key, key)

        if key in base:
            payload[key] = value
        elif key in _XYZ_AXIS_TYPES:
            script_name = "x/y/z plot"
            if slot_index < len(_XYZ_SLOTS):
                type_pos, values_pos = _XYZ_SLOTS[slot_index]
                script_args[type_pos] = _XYZ_AXIS_TYPES[key]
                script_args[values_pos] = value
                slot_index += 1
        elif key in _XYZ_FLAGS:
            script_name = "x/y/z plot"
            flag_pos, flag_value = _XYZ_FLAGS[key]
            script_args[flag_pos] = flag_value
        else:
            # Not a setting: keep the token in the prompt with the user's casing.
            prompt.append(f"{raw_key}:{value}")

    payload["prompt"] = " ".join(prompt).strip()
    if not payload["prompt"]:
        payload["prompt"] = input_string.strip()

    if script_name:
        payload["script_name"] = script_name
        payload["script_args"] = script_args

    return payload
2024-05-26 08:49:27 +03:00
2024-05-18 20:41:30 +03:00
def create_caption(payload, user_name, user_id, info, max_length=1024):
    """Build the Markdown caption sent with a generated image.

    The caption has a bold mention link to the user, then the seed in bold
    (when it can be found in *info*), then the prompt in bold. If the caption
    would exceed *max_length* (Telegram's photo caption limit by default),
    the prompt is shortened with "..." so the closing ``**`` markup is kept.

    Args:
        payload: Request payload; only ``payload["prompt"]`` is read.
        user_name: Display name shown in the mention link.
        user_id: Telegram user id used in the ``tg://user?id=`` link.
        info: Generation info text from the SD API (e.g. "... Seed: 123, ..."),
            or None when the API returned none.
        max_length: Maximum caption length in characters.

    Returns:
        The caption string.
    """
    caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
    prompt = payload["prompt"]
    print(prompt)
    print(info)
    # Example info: "Steps: 3, Sampler: Euler, CFG scale: 7.0, Seed: 4094161400, Size: 512x512, ..."
    match = re.search(r"Seed: (\d+)", info or "")
    if match:
        seed_value = match.group(1)
        print(f"Seed value: {seed_value}")
        caption += f"**{seed_value}**\n"
    else:
        print("Seed value not found in the info string.")

    # Room left for the prompt once its "**...**\n" wrapper is accounted for.
    room = max_length - len(caption) - len("****\n")
    if len(prompt) > room:
        prompt = prompt[:max(room - 3, 0)] + "..."
    caption += f"**{prompt}**\n"

    # Last-resort hard cut if the header alone is already too long.
    if len(caption) > max_length:
        caption = caption[:max_length - 3] + "..."
    return caption
2024-05-26 08:49:27 +03:00
2024-05-16 01:05:35 +03:00
def call_api(api_endpoint, payload):
    """POST *payload* as JSON to ``{SD_URL}/{api_endpoint}`` and return the decoded JSON body.

    Any ``requests.RequestException`` (connection failure, HTTP error status,
    ...) is printed and swallowed; the function then returns None so callers
    can report a generic failure.
    """
    url = f'{SD_URL}/{api_endpoint}'
    try:
        resp = requests.post(url, json=payload)
        resp.raise_for_status()
        body = resp.json()
    except requests.RequestException as err:
        print(f"API call failed: {err}")
        return None
    return body
2024-05-16 01:05:35 +03:00
2024-05-26 08:49:27 +03:00
2024-05-16 01:05:35 +03:00
def process_images(images, user_id, user_name):
    """Save SD API images as PNG files with their generation parameters embedded.

    Each image is written to ``{IMAGE_PATH}/{word}.png``. *word* is a
    filesystem-safe form of *user_name* plus a 7-character random id. All
    images in one call share that name, so with several images only the last
    survives on disk; pass one image per call to keep them all.

    Args:
        images: Base64 image strings, optionally prefixed with a
            ``data:image/png;base64,`` header.
        user_id: Unused; kept for interface compatibility.
        user_name: Display name used in the file name (may be None).

    Returns:
        (word, info): the file stem, and the ``info`` text that
        ``/sdapi/v1/png-info`` reported for the last image (None if *images*
        is empty or the API returned no info).

    Raises:
        requests.HTTPError: if the png-info endpoint returns an error status.
    """
    # Display names may contain "/" or other characters that are unsafe in a
    # file name (or would escape IMAGE_PATH); keep only word chars and dashes.
    safe_name = re.sub(r"[^\w-]", "_", user_name or "user")
    word = f"{safe_name}-{str(uuid.uuid4())[:7]}"
    info = None
    for encoded in images:
        # Drop an optional "data:...;base64," header; [-1] also works without one.
        b64_data = encoded.split(",", 1)[-1]
        image = Image.open(io.BytesIO(base64.b64decode(b64_data)))
        png_payload = {"image": "data:image/png;base64," + b64_data}
        response = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
        response.raise_for_status()
        info = response.json().get("info")
        pnginfo = PngImagePlugin.PngInfo()
        if info:
            # add_text cannot store None, so skip the chunk when no info came back.
            pnginfo.add_text("parameters", info)
        image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)
    return word, info
2023-03-18 00:04:54 +02:00
2023-01-12 06:37:56 +02:00
@app.on_message(filters.command(["draw"]))
def draw(client, message):
    """Handle ``/draw <prompt> [key:value ...]``: run txt2img and reply with each image.

    The text after the command goes through ``parse_input``. The ``xds`` key
    is rejected here; the reply tells the user to use ``/img`` instead. The
    "please wait" notice is always deleted, even when generation fails.
    """
    # split() without a separator also accepts a newline right after the command.
    msgs = message.text.split(maxsplit=1)
    if len(msgs) == 1:
        message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >")
        return
    payload = parse_input(msgs[1])
    print(payload)

    # Check if xds is used in the payload
    if "xds" in msgs[1].lower():
        message.reply_text("`xds` key cannot be used in the `/draw` command. Use `/img` instead.")
        return

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        r = call_api('sdapi/v1/txt2img', payload)
        if not r:
            message.reply_text("Failed to generate image. Please try again later.")
            return
        for i in r["images"]:
            word, info = process_images([i], message.from_user.id, message.from_user.first_name)
            caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
            message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
    except requests.RequestException as e:
        # Post-processing also calls the SD API; tell the user instead of failing silently.
        print(f"Image post-processing failed: {e}")
        message.reply_text("Failed to generate image. Please try again later.")
    finally:
        # Remove the "please wait" notice exactly once, whatever happened above.
        K.delete()
2023-01-12 06:37:56 +02:00
2024-05-26 08:49:27 +03:00
2024-05-16 01:05:35 +03:00
@app.on_message(filters.command(["img"]))
def img2img(client, message):
    """Handle ``/img <prompt> [key:value ...]`` sent as a reply to a photo.

    Downloads the replied-to photo and sends it as the init image to the SD
    img2img endpoint with the parsed payload, then replies with each result.
    The downloaded file and the "please wait" notice are always cleaned up.
    """
    if not message.reply_to_message or not message.reply_to_message.photo:
        message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`")
        return

    # split() without a separator also accepts a newline right after the command.
    msgs = message.text.split(maxsplit=1)
    if len(msgs) == 1:
        message.reply_text("dont FAIL in life")
        return

    payload = parse_input(msgs[1])
    print(f"input:\n{payload}")

    photo_file = client.download_media(message.reply_to_message.photo)
    if not photo_file:
        # download_media returns None when the download did not complete.
        message.reply_text("Failed to process image. Please try again later.")
        return
    try:
        init_image = encode_file_to_base64(photo_file)
    finally:
        os.remove(photo_file)  # Clean up the downloaded file even if encoding fails
    payload["init_images"] = [init_image]

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        r = call_api('sdapi/v1/img2img', payload)
        if not r:
            message.reply_text("Failed to process image. Please try again later.")
            return
        for i in r["images"]:
            word, info = process_images([i], message.from_user.id, message.from_user.first_name)
            caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
            message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
    except requests.RequestException as e:
        # Post-processing also calls the SD API; tell the user instead of failing silently.
        print(f"Image post-processing failed: {e}")
        message.reply_text("Failed to process image. Please try again later.")
    finally:
        # Remove the "please wait" notice exactly once, whatever happened above.
        K.delete()
2024-05-26 08:49:27 +03:00
2023-03-18 00:04:54 +02:00
@app.on_message(filters.command(["getmodels"]))
async def get_models(client, message):
    """Reply with an inline keyboard of the checkpoints the SD web UI reports.

    Each button shows the model title; its callback data is the model name,
    which ``process_callback`` uses to switch checkpoints.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    try:
        # requests is blocking; run it in a worker thread so this async
        # handler does not stall the event loop while waiting for the web UI.
        response = await loop.run_in_executor(None, requests.get, f"{SD_URL}/sdapi/v1/sd-models")
        response.raise_for_status()
        models_json = response.json()
        print(models_json)
        buttons = [
            [InlineKeyboardButton(model["title"], callback_data=model["model_name"])]
            for model in models_json
            # Telegram accepts at most 64 bytes of callback_data; a single longer
            # value would make the whole reply fail, so such models are skipped.
            if len(model["model_name"].encode("utf-8")) <= 64
        ]
        await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(buttons))
    except requests.RequestException as e:
        await message.reply_text(f"Failed to get models: {e}")
2023-04-22 22:18:10 +03:00
2024-05-26 08:49:27 +03:00
2023-03-18 00:04:54 +02:00
@app.on_callback_query()
async def process_callback(client, callback_query):
    """Switch the SD web UI checkpoint to the model picked from the /getmodels keyboard.

    On success, the module-level negative prompt is updated for the new model
    and the user is told which checkpoint is active. Request failures are
    reported in the chat and printed.
    """
    import asyncio
    import functools

    sd_model_checkpoint = callback_query.data
    options = {"sd_model_checkpoint": sd_model_checkpoint}

    # Acknowledge the button press immediately; otherwise the Telegram client
    # keeps showing a loading indicator on the button.
    await callback_query.answer()

    try:
        loop = asyncio.get_running_loop()
        # requests is blocking; run the POST in a worker thread so the event
        # loop keeps serving other updates until the web UI responds.
        response = await loop.run_in_executor(
            None,
            functools.partial(requests.post, f"{SD_URL}/sdapi/v1/options", json=options),
        )
        response.raise_for_status()
        # Update the negative prompt based on the selected model
        update_negative_prompt(sd_model_checkpoint)
        await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
    except requests.RequestException as e:
        await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
        print(f"Error setting checkpoint: {e}")
2024-05-17 14:31:15 +03:00
2024-05-18 20:41:30 +03:00
@app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message):
    """Reply with help text describing the x/y/z plot keys and per-image parameters."""
    await message.reply_text("""
now support for xyz scripts, see [sd wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#xyz-plot) !
currently supported
`xsr` - search replace text/emoji in the prompt, more info [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-sr)
`xds` - denoise strength, only valid for img2img
`xsteps` - steps
`xcfg` - CFG scale
**note** limit the overall `steps:` to lower value (10-20) for big xyz plots
aside from that you can use the usual `ng`, `ds`, `cfg`, `steps` for single image generation.
""", disable_web_page_preview=True)
2024-05-18 20:41:30 +03:00
2023-01-12 06:37:56 +02:00
# Start polling only when executed as a script, so importing this module
# (e.g. to reuse parse_input) does not block forever in the bot loop.
if __name__ == "__main__":
    app.run()