stable-diffusion-telegram-bot/main.py

209 lines
7.1 KiB
Python
Raw Normal View History

2023-03-18 00:04:54 +02:00
import os
2024-05-18 14:06:29 +03:00
import re
import io
2023-04-28 10:02:47 +03:00
import uuid
2023-01-12 06:37:56 +02:00
import base64
2024-05-18 14:06:29 +03:00
import requests
2024-05-16 01:05:35 +03:00
from datetime import datetime
2023-03-18 00:04:54 +02:00
from PIL import Image, PngImagePlugin
from pyrogram import Client, filters
2024-05-16 01:05:35 +03:00
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup
2023-03-18 00:24:46 +02:00
from dotenv import load_dotenv
2023-03-18 00:04:54 +02:00
2024-05-16 01:05:35 +03:00
# Load environment variables (values from a local .env file are honoured too).
load_dotenv()
API_ID = os.environ.get("API_ID")
API_HASH = os.environ.get("API_HASH")
TOKEN = os.environ.get("TOKEN_givemtxt2img")
# Base URL of the Stable Diffusion web API. Every request appends "/sdapi/...",
# so a trailing slash is stripped to avoid "//sdapi" paths that may not route.
SD_URL = (os.environ.get("SD_URL") or "").rstrip("/")
if not SD_URL:
    # Every command talks to the SD API; without this the bot would start and
    # then fail each request against "None/sdapi/..." with an obscure error.
    raise SystemExit("SD_URL environment variable is not set")

app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)

# Directory where generated images are written before being sent to Telegram.
IMAGE_PATH = 'images'

# Ensure IMAGE_PATH directory exists
os.makedirs(IMAGE_PATH, exist_ok=True)
def timestamp():
    """Return the current local time as a compact ``YYYYMMDD-HHMMSS`` string."""
    now = datetime.now()
    return f"{now:%Y%m%d-%H%M%S}"
def encode_file_to_base64(path):
    """Read the file at *path* as bytes and return its base64 encoding as text."""
    with open(path, "rb") as source:
        raw_bytes = source.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def decode_and_save_base64(base64_str, save_path):
    """Decode the base64 text *base64_str* and write the raw bytes to *save_path*.

    The target file is opened (and truncated) before decoding, so an existing
    file is overwritten.
    """
    with open(save_path, "wb") as target:
        decoded_bytes = base64.b64decode(base64_str)
        target.write(decoded_bytes)
2023-05-06 18:59:02 +03:00
def _coerce_value(raw, default):
    """Convert a user-supplied option string to the type of its default value.

    Booleans accept true/false, yes/no, on/off and 1/0 (case-insensitive);
    numeric defaults accept anything ``int()`` or ``float()`` can parse.
    Values that cannot be converted are returned unchanged as strings.
    """
    # bool must be checked before int because bool is a subclass of int.
    if isinstance(default, bool):
        lowered = raw.lower()
        if lowered in ("true", "yes", "on", "1"):
            return True
        if lowered in ("false", "no", "off", "0"):
            return False
        return raw
    if isinstance(default, (int, float)):
        for cast in (int, float):
            try:
                return cast(raw)
            except ValueError:
                pass
    return raw


def parse_input(input_string):
    """Build a Stable Diffusion API payload from a user's command text.

    Free text becomes the prompt. ``key: value`` options whose key names a
    payload field override the defaults below. Keys are case-insensitive, and
    the aliases ``ds`` (denoising_strength), ``ng`` (negative_prompt) and
    ``sampler`` (sampler_name) are accepted. An option's value runs until the
    next whitespace-preceded ``word:`` token or the end of the text, so values
    may contain spaces. ``word:`` tokens that are not options (e.g. ``12:30``
    or ``style:``) stay in the prompt exactly as typed.

    Numeric and boolean option values are converted to the type of their
    default. An explicit ``prompt:`` option is appended after any free text.
    If no prompt text remains at all, the whole input is used as the prompt.

    Returns:
        A new payload dict that the caller may freely mutate.
    """
    default_payload = {
        "prompt": "",
        "negative_prompt": "ugly, bad face, distorted",
        "enable_hr": False,
        # The web API field is "sampler_name"; the previous "Sampler" key was
        # not a recognised field and could never match the lower-cased keys.
        "sampler_name": "DPM++ SDE Karras",
        "denoising_strength": 0.35,
        "batch_size": 1,
        "n_iter": 1,
        "steps": 35,
        "cfg_scale": 7,
        "width": 512,
        "height": 512,
        "restore_faces": False,
        "override_settings": {},
        "override_settings_restore_afterwards": True,
    }
    aliases = {
        "ds": "denoising_strength",
        "ng": "negative_prompt",
        "sampler": "sampler_name",
    }
    payload = default_payload.copy()

    prompt_parts = []
    last_index = 0  # end of the last consumed option value
    for match in re.finditer(r"(\w+):", input_string):
        if match.start() < last_index:
            # This "word:" lies inside an option value that was already consumed.
            continue
        key = match.group(1).lower()
        key = aliases.get(key, key)
        if key not in default_payload:
            # Not an option: leave it in place so it is kept verbatim as prompt text.
            continue
        # Free text between the previous option and this one belongs to the prompt.
        prompt_parts.append(input_string[last_index:match.start()].strip())
        value_start = match.end()
        rest = input_string[value_start:]
        # The value stops right before the next " word:" token, or at the end.
        value_end = re.search(r"(?=\s+\w+:|$)", rest).start()
        payload[key] = _coerce_value(rest[:value_end].strip(), default_payload[key])
        last_index = value_start + value_end
    # Whatever follows the last option is prompt text too.
    prompt_parts.append(input_string[last_index:].strip())

    free_text = " ".join(part for part in prompt_parts if part)
    explicit_prompt = payload["prompt"]
    payload["prompt"] = (
        " ".join(part for part in (free_text, explicit_prompt) if part)
        or input_string.strip()
    )
    return payload
2024-05-16 01:05:35 +03:00
def call_api(api_endpoint, payload, timeout=600):
    """POST *payload* as JSON to ``{SD_URL}/{api_endpoint}`` and return the parsed reply.

    Args:
        api_endpoint: Path below SD_URL, e.g. ``'sdapi/v1/txt2img'``.
        payload: JSON-serialisable request body.
        timeout: Seconds passed to ``requests``, which applies it to both the
            connect and the read (time between received bytes) phases.
            Generation can be slow, so the default is generous. Without any
            limit, a stalled server would hold the calling handler forever.

    Returns:
        The decoded JSON response, or ``None`` if the request failed, the
        server returned an HTTP error status, or the body was not valid JSON.
        The error is printed in each of those cases.
    """
    try:
        response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers invalid JSON bodies on requests versions whose
        # JSONDecodeError is not a RequestException.
        print(f"API call failed: {e}")
        return None
2024-05-16 01:05:35 +03:00
def process_images(images, user_id, user_name):
    """Save API-returned images as PNGs with their generation parameters embedded.

    Each base64 image is decoded and its parameters text is fetched from the
    ``/sdapi/v1/png-info`` endpoint. The image is then saved under IMAGE_PATH
    as ``<sanitised user_name>-<random id>.png``, with the parameters stored
    in a "parameters" PNG text chunk.

    Args:
        images: Base64 image strings, raw or with a ``data:...;base64,`` prefix.
        user_id: Telegram user id (not used here; kept for callers).
        user_name: Display name used as the file-name prefix.

    Returns:
        ``(name, info)`` for the last image: the file stem (without ``.png``)
        and the ``info`` text returned by png-info. Returns ``(None, None)``
        when *images* is empty.

    Raises:
        requests.RequestException: if the png-info request fails.
    """
    # Display names are user-controlled: replace anything but word characters
    # and "-" so names like "../x" or "a/b" cannot escape or break IMAGE_PATH.
    safe_name = re.sub(r"[^\w-]", "_", user_name or "user")
    word, info = None, None
    for encoded in images:
        # Raw base64 has no comma; for a data URI the payload follows the first comma.
        b64_data = encoded.split(",", 1)[-1]
        image = Image.open(io.BytesIO(base64.b64decode(b64_data)))
        response = requests.post(
            f"{SD_URL}/sdapi/v1/png-info",
            json={"image": "data:image/png;base64," + b64_data},
            timeout=60,
        )
        response.raise_for_status()
        info = response.json().get("info")
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info or "")
        # A fresh name per image so several images never overwrite one file.
        word = f"{safe_name}-{str(uuid.uuid4())[:7]}"
        image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)
    return word, info
2023-03-18 00:04:54 +02:00
2024-05-18 14:06:29 +03:00
def create_caption(payload, user_name, user_id, info, max_length=1024):
    """Build the Markdown caption sent alongside a generated image.

    The caption links the requesting user, shows the prompt in bold and adds
    the seed taken from the generation-info text.

    Args:
        payload: Request payload; only ``payload["prompt"]`` is read.
        user_name: Display name used as the mention text.
        user_id: Telegram user id for the ``tg://user?id=`` link.
        info: Generation-info text. The seed is read from its last
            ``", Seed: <value>"`` field. ``None``, or text without such a
            field, yields ``"unknown"`` instead of raising.
        max_length: Maximum caption length (Telegram's photo caption limit is 1024).

    Returns:
        The caption. If it would exceed *max_length*, only the prompt is
        shortened with "...", so the bold markers stay balanced and the seed
        line is kept.
    """
    # The parameters line comes after the prompt in the info text, so the last
    # match is used; this ignores a ", Seed: " that was typed into the prompt.
    seeds = re.findall(r", Seed: ([^,\n]*)", info or "")
    seed_value = seeds[-1].strip() if seeds else "unknown"

    header = f"**[{user_name}](tg://user?id={user_id})**\n\n"
    footer = f"Seed - **{seed_value}**\n"
    prompt = payload["prompt"]
    caption = f"{header}**{prompt}**\n{footer}"

    overflow = len(caption) - max_length
    if overflow > 0:
        # Trim the prompt by the overflow plus 3 characters for the "..." marker.
        keep = max(len(prompt) - overflow - 3, 0)
        caption = f"{header}**{prompt[:keep]}...**\n{footer}"
        if len(caption) > max_length:
            # Header/footer alone are too long: fall back to a hard cut.
            caption = caption[:max_length - 3] + "..."
    return caption
2023-01-12 06:37:56 +02:00
@app.on_message(filters.command(["draw"]))
def draw(client, message):
    """Handle ``/draw <prompt> [key: value ...]``: run txt2img and reply with the images.

    Without prompt text a usage message is sent instead. A "please wait"
    message is shown while the API runs. It is deleted on every path,
    including failures, and the user is told when generation fails.
    """
    # The command filter can also fire on captions, where .text is None;
    # split() on any whitespace so "/draw\n<prompt>" is accepted as well.
    msgs = (message.text or message.caption or "").split(maxsplit=1)
    if len(msgs) < 2:
        message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >")
        return

    payload = parse_input(msgs[1])
    print(payload)

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        r = call_api('sdapi/v1/txt2img', payload)
        if not r:
            message.reply_text("Failed to generate image. Please try again later.")
            return
        user = message.from_user
        for i in r.get("images", []):
            word, info = process_images([i], user.id, user.first_name)
            caption = create_caption(payload, user.first_name, user.id, info)
            message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
    except Exception as e:
        # Handler boundary: log and tell the user instead of failing silently.
        print(f"txt2img failed: {e}")
        message.reply_text("Failed to generate image. Please try again later.")
    finally:
        K.delete()
2023-01-12 06:37:56 +02:00
2024-05-16 01:05:35 +03:00
@app.on_message(filters.command(["img"]))
def img2img(client, message):
    """Handle ``/img <prompt> [ds: 0-1.0] ...`` sent as a reply to a photo.

    The replied-to photo is downloaded, base64-encoded as the init image and
    deleted from disk again, even if encoding fails. It is then sent to the
    img2img endpoint, and every resulting image is posted back with a caption.
    A "please wait" message is shown meanwhile and removed on every path.
    """
    if not message.reply_to_message or not message.reply_to_message.photo:
        message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35")
        return

    # The command filter can also fire on captions, where .text is None.
    msgs = (message.text or message.caption or "").split(maxsplit=1)
    if len(msgs) < 2:
        # "ds:" is the option parse_input understands; its default is 0.35.
        message.reply_text("Format :\n/img < prompt >\nds: < 0.1-1.0, default 0.35 >")
        return

    payload = parse_input(msgs[1])

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        photo_file = client.download_media(message.reply_to_message.photo)
        if not photo_file:
            # A failed or stopped download yields no path.
            message.reply_text("Failed to process image. Please try again later.")
            return
        try:
            payload["init_images"] = [encode_file_to_base64(photo_file)]
        finally:
            os.remove(photo_file)  # Clean up downloaded image file

        r = call_api('sdapi/v1/img2img', payload)
        if not r:
            message.reply_text("Failed to process image. Please try again later.")
            return
        user = message.from_user
        for i in r.get("images", []):
            word, info = process_images([i], user.id, user.first_name)
            caption = create_caption(payload, user.first_name, user.id, info)
            message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
    except Exception as e:
        # Handler boundary: log and tell the user instead of failing silently.
        print(f"img2img failed: {e}")
        message.reply_text("Failed to process image. Please try again later.")
    finally:
        K.delete()
2023-03-18 00:04:54 +02:00
@app.on_message(filters.command(["getmodels"]))
async def get_models(client, message):
    """Reply with an inline keyboard that has one button per SD checkpoint.

    Each button shows the model's ``title`` and carries its ``model_name`` as
    callback data. The blocking HTTP request runs in a worker thread, so this
    coroutine does not stall the bot's event loop while the server responds.
    """
    import asyncio
    import functools

    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None,
            functools.partial(requests.get, f"{SD_URL}/sdapi/v1/sd-models", timeout=60),
        )
        response.raise_for_status()
        models_json = response.json()
    except (requests.RequestException, ValueError) as e:
        await message.reply_text(f"Failed to get models: {e}")
        return

    # Telegram limits callback_data to 64 bytes, and one invalid button makes
    # the whole keyboard fail, so such checkpoints are skipped and logged.
    buttons = []
    for model in models_json:
        name = model["model_name"]
        if len(name.encode("utf-8")) > 64:
            print(f"Skipping model with over-long name: {name}")
            continue
        buttons.append([InlineKeyboardButton(model["title"], callback_data=name)])

    if not buttons:
        await message.reply_text("No selectable models found.")
        return
    await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(buttons))
2023-04-22 22:18:10 +03:00
2023-03-18 00:04:54 +02:00
@app.on_callback_query()
async def process_callback(client, callback_query):
    """Set the SD server's ``sd_model_checkpoint`` option to the pressed button's data.

    The button press is acknowledged first. The options request, which can
    take a long time while the checkpoint loads, runs in a worker thread so
    the event loop stays responsive. The outcome is reported in the chat.
    """
    import asyncio
    import functools

    sd_model_checkpoint = callback_query.data
    options = {"sd_model_checkpoint": sd_model_checkpoint}

    # Telegram clients show a progress indicator until the query is answered,
    # so answer before the potentially slow checkpoint switch.
    await callback_query.answer()

    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None,
            functools.partial(requests.post, f"{SD_URL}/sdapi/v1/options", json=options, timeout=600),
        )
        response.raise_for_status()
    except requests.RequestException as e:
        await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
        return
    await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
2023-01-12 06:37:56 +02:00
# Start the bot only when executed as a script, so importing this module
# (e.g. for tests) does not block in the Telegram event loop.
if __name__ == "__main__":
    app.run()