# stable-diffusion-telegram-bot/main.py
# 2024-05-16 01:05:35 +03:00
#
# 229 lines
# 7.6 KiB
# Python

import asyncio
import base64
import io
import json
import os
import re
import uuid
from datetime import datetime

import requests
from dotenv import load_dotenv
from PIL import Image, PngImagePlugin
from pyrogram import Client, filters
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup
# Pull settings from a local .env file (when present) into os.environ.
load_dotenv()

# Telegram credentials and the Stable Diffusion WebUI base URL; each one is
# None when the corresponding environment variable is not set.
API_ID, API_HASH, TOKEN, SD_URL = (
    os.environ.get(var_name, None)
    for var_name in ("API_ID", "API_HASH", "TOKEN", "SD_URL")
)

app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)

# Directory for generated images. Keep it free of a trailing "/": file paths
# elsewhere in this module are built as f"{IMAGE_PATH}/{name}.png".
IMAGE_PATH = 'images'
os.makedirs(IMAGE_PATH, exist_ok=True)  # create the directory on startup if missing

# Default params (NOTE(review): steps_value_default appears unused in this file)
steps_value_default = 40
def timestamp():
    """Return the current local time as a ``YYYYMMDD-HHMMSS`` string."""
    now = datetime.now()
    return f"{now:%Y%m%d-%H%M%S}"
def encode_file_to_base64(path):
    """Read the file at *path* in binary mode and return its base64 text."""
    with open(path, 'rb') as source:
        raw_bytes = source.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def decode_and_save_base64(base64_str, save_path):
    """Decode *base64_str* and write the resulting bytes to *save_path*.

    The target file is opened (and truncated) before decoding, so it exists
    even when the input turns out not to be valid base64.
    """
    with open(save_path, "wb") as target:
        decoded = base64.b64decode(base64_str)
        target.write(decoded)
def _coerce_value(value, default):
    """Convert a raw ``key: value`` string to the type of *default* when possible.

    Values that cannot be converted, and values whose default is neither bool
    nor a number (str, list, dict), are returned unchanged as strings.
    """
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(default, bool):
        lowered = value.lower()
        if lowered in ("true", "yes", "on", "1"):
            return True
        if lowered in ("false", "no", "off", "0"):
            return False
        return value
    if isinstance(default, (int, float)):
        if isinstance(default, int):
            try:
                return int(value)
            except ValueError:
                pass  # may still be a fractional number, e.g. "cfg_scale: 7.5"
        try:
            number = float(value)
        except ValueError:
            return value
        # Reject "nan"/"inf": strict JSON has no representation for them.
        return number if number == number and abs(number) != float("inf") else value
    return value


def parse_input(input_string):
    """Split a user's command text into a prompt and Stable Diffusion parameters.

    The text is scanned for ``key: value`` pairs. A value runs until the next
    whitespace-separated ``word:`` token or the end of the string. Keys are
    matched case-insensitively against the names in ``default_payload`` below,
    plus the short aliases ``ng`` (negative_prompt) and ``ds``
    (denoising_strength). Everything else -- including ``word:`` tokens that
    are not recognised keys -- is kept verbatim as prompt text.

    Recognised values are converted to the type of their default (bool, int,
    float); a value that cannot be converted is kept as the raw string.

    Args:
        input_string: the text following the bot command.

    Returns:
        A dict with a ``"prompt"`` entry plus one entry per parameter the user
        supplied. If no prompt text remains, the whole stripped input is used
        as the prompt.
    """
    default_payload = {
        "prompt": "",
        "negative_prompt": "",
        "controlnet_input_image": [],
        "controlnet_mask": [],
        "controlnet_module": "",
        "controlnet_model": "",
        "controlnet_weight": 1,
        "controlnet_resize_mode": "Scale to Fit (Inner Fit)",
        "controlnet_lowvram": False,
        "controlnet_processor_res": 64,
        "controlnet_threshold_a": 64,
        "controlnet_threshold_b": 64,
        "controlnet_guidance": 1,
        "controlnet_guessmode": True,
        "enable_hr": False,
        "denoising_strength": 0.4,
        "hr_scale": 1.5,
        "hr_upscale": "Latent",
        "seed": -1,
        "subseed": -1,
        "subseed_strength": -1,
        "sampler_index": "",
        "batch_size": 1,
        "n_iter": 1,
        "steps": 35,
        "cfg_scale": 7,
        "width": 512,
        "height": 512,
        "restore_faces": True,
        "override_settings": {},
        "override_settings_restore_afterwards": True,
    }
    # Short names advertised in the bot's help messages.
    aliases = {"ng": "negative_prompt", "ds": "denoising_strength"}

    payload = {"prompt": ""}
    prompt = []
    last_index = 0
    for match in re.finditer(r"(\w+):", input_string):
        if match.start() < last_index:
            # This "word:" lies inside a value that was already consumed
            # (e.g. "ng: blurry,bad:art"), so it belongs to that value.
            continue
        key = match.group(1).lower()
        key = aliases.get(key, key)
        value_start_index = match.end()
        prompt.append(input_string[last_index:match.start()].strip())
        last_index = value_start_index
        if key not in default_payload:
            # Not a parameter: keep the token exactly as typed in the prompt.
            prompt.append(match.group(0))
            continue
        # The value ends right before the next " word:" token, or at the end.
        value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:]).start()
        value = input_string[value_start_index:value_start_index + value_end_index].strip()
        last_index += value_end_index
        if key == "prompt":
            # An explicit "prompt:" contributes to the prompt text rather than
            # being overwritten by the joined prompt below.
            prompt.append(value)
        else:
            payload[key] = _coerce_value(value, default_payload[key])
    # Whatever follows the last key (or the whole text if there is none)
    # is prompt text as well.
    prompt.append(input_string[last_index:].strip())

    payload["prompt"] = " ".join(part for part in prompt if part)
    if not payload["prompt"]:
        payload["prompt"] = input_string.strip()
    return payload
def call_api(api_endpoint, payload):
    """POST *payload* as JSON to ``{SD_URL}/{api_endpoint}`` and return the decoded reply.

    Raises ``requests.HTTPError`` when the WebUI answers with a 4xx/5xx status.
    """
    url = f'{SD_URL}/{api_endpoint}'
    reply = requests.post(url, json=payload)
    reply.raise_for_status()
    decoded = reply.json()
    return decoded
def process_images(images, user_id, user_name):
    """Save base64-encoded PNGs returned by the SD API into IMAGE_PATH.

    For each image, the WebUI's ``/sdapi/v1/png-info`` endpoint is asked for
    its generation-parameters text, which is embedded in the saved PNG under
    the ``parameters`` key. Every image gets its own file.

    Args:
        images: base64 image strings, optionally prefixed with a
            ``data:image/png;base64,`` data-URL header.
        user_id: Telegram user id; used as the filename prefix when
            *user_name* contains no usable characters.
        user_name: display name used as the filename prefix.

    Returns:
        ``(word, info)`` for the last image. ``word`` is the file stem (the
        file is ``f"{IMAGE_PATH}/{word}.png"``). ``info`` is the parameters
        text, which may be None if the endpoint returned no ``"info"``.

    Raises:
        ValueError: if *images* is empty.
        requests.HTTPError: if the png-info request fails.
    """
    def generate_unique_name():
        # user_name comes from the Telegram user (untrusted). Replace anything
        # that could form a path separator or traversal ("/", "..") before it
        # becomes part of a filename.
        safe_name = re.sub(r"[^\w-]", "_", user_name or "") or str(user_id)
        unique_id = str(uuid.uuid4())[:7]
        return f"{safe_name}-{unique_id}"

    if not images:
        raise ValueError("process_images() needs at least one image")

    for encoded in images:
        # Accept both bare base64 and "data:...;base64,<data>" strings.
        b64_data = encoded.split(",", 1)[-1]
        image = Image.open(io.BytesIO(base64.b64decode(b64_data)))
        png_payload = {"image": "data:image/png;base64," + b64_data}
        response = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
        response.raise_for_status()
        info = response.json().get("info")
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info or "")
        # A fresh name per image, so several results don't overwrite each other.
        word = generate_unique_name()
        image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)
    return word, info
@app.on_message(filters.command(["draw"]))
def draw(client, message):
    """Handle ``/draw <prompt> [key: value ...]`` by running txt2img.

    With no prompt, reply with usage help. Otherwise:

    - send a "please wait" notice;
    - call the WebUI txt2img endpoint with the payload built by ``parse_input``;
    - reply with each generated image, captioned with the requesting user,
      the parameters sent and the seed.

    The notice is deleted afterwards, even if generation fails.
    """
    caption_limit = 1024  # Telegram rejects photo captions longer than this

    # NOTE(review): message.text can be None (e.g. a command sent as a media
    # caption); fall back so split() never runs on None.
    text = message.text or message.caption or ""
    msgs = text.split(maxsplit=1)  # any whitespace (incl. newline) after /draw
    if len(msgs) < 2:
        message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >")
        return
    payload = parse_input(msgs[1])
    print(payload)

    # from_user is None for anonymous senders (e.g. messages sent as a chat).
    user = message.from_user
    user_id = user.id if user else 0
    user_name = user.first_name if user else "Anonymous"
    mention = f"**[{user_name}](tg://user?id={user_id})**" if user else "**Anonymous**"

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        r = call_api('sdapi/v1/txt2img', payload)
        for i in r["images"]:
            word, info = process_images([i], user_id, user_name)
            # The info text contains "..., Seed: <value>, ..."; tolerate its absence.
            seed_match = re.search(r", Seed: ([^,]*)", info or "")
            seed_value = seed_match.group(1) if seed_match else "unknown"
            caption = f"{mention}\n\n"
            caption += "".join(f"{key.capitalize()} - **{value}**\n" for key, value in payload.items())
            caption += f"Seed - **{seed_value}**\n"
            message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption[:caption_limit])
    except requests.RequestException as err:
        # Log the details server-side; don't leak SD_URL to the chat.
        print(f"txt2img failed: {err}")
        message.reply_text("Image generation failed, please try again later.")
    finally:
        K.delete()
@app.on_message(filters.command(["img"]))
def img2img(client, message):
    """Handle ``/img <prompt> [key: value ...]`` sent as a reply to a photo.

    The replied-to photo is downloaded, base64-encoded and sent to the WebUI
    img2img endpoint together with the payload built by ``parse_input``. Each
    resulting image is sent back, captioned with the requesting user and the
    prompt. ``denoising_strength`` defaults to 0.3 unless the user supplies
    one. The "please wait" notice is deleted even if generation fails.
    """
    caption_limit = 1024  # Telegram rejects photo captions longer than this
    usage = (
        "Reply to an image with\n"
        "/img <prompt> denoising_strength: <0-1.0>\n"
        "denoising_strength controls how much the image changes. Set it low "
        "(like 0.2) if you just want to slightly change things. Defaults to 0.3"
    )

    reply = message.reply_to_message
    if not reply or not reply.photo:
        message.reply_text(usage)
        return
    # NOTE(review): message.text can be None (e.g. a command sent as a media
    # caption); fall back so split() never runs on None.
    text = message.text or message.caption or ""
    msgs = text.split(maxsplit=1)
    if len(msgs) < 2:
        message.reply_text(usage)
        return
    payload = parse_input(msgs[1])
    # Only fill in the default, so a strength given by the user is respected.
    payload.setdefault("denoising_strength", 0.3)
    print(payload)

    # from_user is None for anonymous senders (e.g. messages sent as a chat).
    user = message.from_user
    user_id = user.id if user else 0
    user_name = user.first_name if user else "Anonymous"
    mention = f"**[{user_name}](tg://user?id={user_id})**" if user else "**Anonymous**"

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        photo_file = client.download_media(reply.photo)
        if not photo_file:
            message.reply_text("Could not download the image, please try again.")
            return
        try:
            payload["init_images"] = [encode_file_to_base64(photo_file)]
        finally:
            os.remove(photo_file)  # clean up the download even if reading it fails
        r = call_api('sdapi/v1/img2img', payload)
        for i in r["images"]:
            word, _info = process_images([i], user_id, user_name)
            caption = f"{mention}\n\n**{payload['prompt']}**\n"
            message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption[:caption_limit])
    except requests.RequestException as err:
        # Log the details server-side; don't leak SD_URL to the chat.
        print(f"img2img failed: {err}")
        message.reply_text("Image generation failed, please try again later.")
    finally:
        K.delete()
@app.on_message(filters.command(["getmodels"]))
async def get_models(client, message):
    """Reply with an inline keyboard listing the WebUI's available checkpoints.

    Each button shows the model's ``title`` and carries its ``model_name`` as
    callback data. The callback-query handler uses that value as
    ``sd_model_checkpoint``.
    """
    loop = asyncio.get_running_loop()
    try:
        # requests is blocking; run it in a worker thread so the event loop
        # keeps serving other updates while the WebUI responds.
        response = await loop.run_in_executor(None, requests.get, f"{SD_URL}/sdapi/v1/sd-models")
        response.raise_for_status()
        models_json = response.json()
    except requests.RequestException as err:
        print(f"Fetching models failed: {err}")
        await message.reply_text("Could not fetch the model list, please try again later.")
        return

    buttons = [
        [InlineKeyboardButton(model["title"], callback_data=model["model_name"])]
        for model in models_json
        # Telegram limits callback_data to 64 bytes; a longer value would make
        # the whole reply fail, so such models are left out.
        if len(model["model_name"].encode("utf-8")) <= 64
    ]
    if not buttons:
        await message.reply_text("No selectable models were found.")
        return
    await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(buttons))
@app.on_callback_query()
async def process_callback(client, callback_query):
    """Switch the WebUI checkpoint to the model chosen from the keyboard.

    The button's callback data is posted as ``sd_model_checkpoint`` to
    ``/sdapi/v1/options``. The user is then told whether it succeeded.
    """
    sd_model_checkpoint = callback_query.data
    options = {"sd_model_checkpoint": sd_model_checkpoint}
    # Acknowledge the press right away so the Telegram client stops showing a
    # loading indicator; loading a checkpoint can take a long time.
    await callback_query.answer()
    loop = asyncio.get_running_loop()
    try:
        # Run the blocking request in a worker thread so the event loop is not
        # stalled while the WebUI switches models.
        response = await loop.run_in_executor(
            None, lambda: requests.post(f"{SD_URL}/sdapi/v1/options", json=options)
        )
        response.raise_for_status()
    except requests.RequestException as err:
        print(f"Setting checkpoint {sd_model_checkpoint!r} failed: {err}")
        await callback_query.message.reply_text(f"Failed to set checkpoint {sd_model_checkpoint}")
        return
    await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
@app.on_message(filters.command(["start"], prefixes=["/", "!"]))
async def start(client, message):
    """Greet the user and offer a button that adds the bot to a group."""
    add_to_group = InlineKeyboardButton("Add to your group", url="https://t.me/gootmornbot?startgroup=true")
    markup = InlineKeyboardMarkup([[add_to_group]])
    greeting = "Hello!\nAsk me to imagine anything\n\n/draw text to image"
    await message.reply_text(greeting, reply_markup=markup)
if __name__ == "__main__":
    # Start the Pyrogram client only when executed as a script, so importing
    # this module (e.g. to reuse parse_input) does not launch the bot.
    app.run()