stable-diffusion-telegram-bot/main.py
tami-p40 3fa661a54c qa
2024-05-18 20:41:30 +03:00

250 lines
8.9 KiB
Python

import os
import re
import io
import uuid
import base64
import requests
from datetime import datetime
from PIL import Image, PngImagePlugin
from pyrogram import Client, filters
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup
from dotenv import load_dotenv
# Load environment variables from a local .env file (no-op if absent).
load_dotenv()
# Telegram credentials — obtained from my.telegram.org; bot token from @BotFather.
API_ID = os.environ.get("API_ID")
API_HASH = os.environ.get("API_HASH")
TOKEN = os.environ.get("TOKEN_givemtxt2img")
# Base URL of the AUTOMATIC1111 stable-diffusion web API, e.g. http://127.0.0.1:7860.
SD_URL = os.environ.get("SD_URL")
# NOTE(review): if any of these env vars is missing, the values are None and the
# Client/API calls will fail later at runtime — consider validating here.
app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
# Directory where generated PNGs are written before being sent back to the user.
IMAGE_PATH = 'images'
# Ensure IMAGE_PATH directory exists
os.makedirs(IMAGE_PATH, exist_ok=True)
def timestamp():
    """Return the current local time formatted as 'YYYYMMDD-HHMMSS'."""
    return f"{datetime.now():%Y%m%d-%H%M%S}"
def encode_file_to_base64(path):
    """Read the file at *path* and return its contents as a base64 ASCII string."""
    with open(path, 'rb') as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode('utf-8')
def decode_and_save_base64(base64_str, save_path):
    """Decode *base64_str* and write the resulting bytes to *save_path*."""
    raw = base64.b64decode(base64_str)
    with open(save_path, "wb") as handle:
        handle.write(raw)
def parse_input(input_string):
    """Parse a user message into an A1111 txt2img/img2img payload.

    Recognizes ``key: value`` pairs (case-insensitive) for any default payload
    key, the aliases ``ds``/``ng`` (denoising_strength / negative_prompt), and
    the x/y/z-plot script keys ``xsr``/``xsteps``/``xds``. Any text that is not
    a recognized key:value pair becomes part of the prompt.

    Returns the payload dict. Raises ValueError for malformed ``xsteps`` input.
    """
    # Set default payload values
    default_payload = {
        "prompt": "",
        "negative_prompt": "ugly, bad face, distorted",
        "enable_hr": False,
        "Sampler": "DPM++ SDE Karras",
        "denoising_strength": 0.35,
        "batch_size": 1,
        "n_iter": 1,
        "steps": 35,
        "cfg_scale": 7,
        "width": 512,
        "height": 512,
        "restore_faces": False,
        "override_settings": {},
        "override_settings_restore_afterwards": True,
    }
    payload = default_payload.copy()
    # Map lowercase key -> canonical payload key so that e.g. "sampler:" matches
    # the "Sampler" entry (keys are lowercased before lookup below; previously
    # "Sampler" could never be set from user input).
    canonical = {k.lower(): k for k in default_payload}
    prompt_parts = []
    matches = re.finditer(r"(\w+):", input_string)
    last_index = 0
    # Initialize script_args with default values and placeholder enums
    script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
    script_name = None
    for match in matches:
        key = match.group(1).lower()
        value_start_index = match.end()
        # Text between the previous value and this key belongs to the prompt.
        if last_index != match.start():
            prompt_parts.append(input_string[last_index: match.start()].strip())
        last_index = value_start_index
        # Value runs until the next "word:" key or the end of the string.
        value_end_match = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:])
        if value_end_match:
            value_end_index = value_end_match.start() + value_start_index
        else:
            value_end_index = len(input_string)
        value = input_string[value_start_index: value_end_index].strip()
        # Short aliases for the two most common overrides.
        if key == "ds":
            key = "denoising_strength"
        elif key == "ng":
            key = "negative_prompt"
        if key in canonical:
            payload[canonical[key]] = value
        elif key in ("xsr", "xsteps", "xds"):
            script_name = "x/y/z plot"
            if key == "xsr":
                script_args[0] = 7  # Enum value for xsr
                script_args[1] = value
            elif key == "xsteps":
                # Parse first; validate range outside the try so the range error
                # is not swallowed and re-raised with the wrong message.
                try:
                    steps_values = [int(x) for x in value.split(',')]
                except ValueError:
                    raise ValueError("xsteps must contain only integers.") from None
                if not all(1 <= x <= 70 for x in steps_values):
                    raise ValueError("xsteps values must be between 1 and 70.")
                script_args[3] = 4  # Enum value for xsteps
                script_args[4] = value
            else:  # xds
                script_args[6] = 22  # Enum value for xds
                script_args[7] = value
        else:
            # Unknown key — keep it verbatim as part of the prompt text.
            prompt_parts.append(f"{key}:{value}")
        last_index = value_end_index
    payload["prompt"] = " ".join(prompt_parts).strip()
    # If nothing was recognized as prompt text, use the whole input.
    if not payload["prompt"]:
        payload["prompt"] = input_string.strip()
    if script_name:
        payload["script_name"] = script_name
        payload["script_args"] = script_args
    return payload
def create_caption(payload, user_name, user_id, info):
    """Build the Telegram photo caption: a user mention plus the prompt.

    Truncated to Telegram's 1024-character caption limit. *info* is accepted
    for interface compatibility but is not currently included in the caption.
    """
    mention = f"**[{user_name}](tg://user?id={user_id})**\n\n"
    caption = mention + f"**{payload['prompt']}**\n"
    if len(caption) <= 1024:
        return caption
    return caption[:1021] + "..."
def call_api(api_endpoint, payload, timeout=60):
    """POST *payload* as JSON to ``{SD_URL}/{api_endpoint}``.

    Returns the decoded JSON response, or None on any request failure
    (connection error, HTTP error status, or timeout).

    *timeout* (seconds) is new and defaults to 60 — without it a stalled
    SD server would hang the bot forever; requests.Timeout is a
    RequestException subclass, so it takes the same None-returning path.
    """
    try:
        response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except requests.RequestException as e:
        print(f"API call failed: {e}")
        return None
def process_images(images, user_id, user_name):
    """Decode, annotate, and save base64-encoded images from the SD API.

    For each image: fetch its generation parameters from the png-info
    endpoint, embed them as PNG metadata, and save under IMAGE_PATH.
    Returns ``(name, info)`` for the last image, where *name* is the saved
    filename stem and *info* is the parameter string.

    Raises ValueError if *images* is empty (previously a confusing
    NameError), and requests.HTTPError if the png-info call fails.
    Note: *user_id* is accepted for interface compatibility but unused.
    """
    if not images:
        raise ValueError("process_images called with an empty image list")

    def generate_unique_name():
        # Short per-call identifier so repeated requests don't collide.
        unique_id = str(uuid.uuid4())[:7]
        return f"{user_name}-{unique_id}"

    word = generate_unique_name()
    for i in images:
        # split(",", 1)[-1] strips a "data:image/png;base64," prefix when one
        # is present and is a no-op for plain base64 (the previous [0] kept
        # the prefix instead of the data when a prefix existed).
        image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[-1])))
        png_payload = {"image": "data:image/png;base64," + i}
        response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
        response2.raise_for_status()
        info = response2.json().get("info")
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info)
        image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)
    return word, info
@app.on_message(filters.command(["draw"]))
def draw(client, message):
    """Handle /draw: parse the prompt text and reply with a txt2img result."""
    parts = message.text.split(" ", 1)
    if len(parts) == 1:
        # No prompt supplied — show usage help.
        message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >")
        return
    payload = parse_input(parts[1])
    print(payload)
    placeholder = message.reply_text("Please Wait 10-15 Seconds")
    result = call_api('sdapi/v1/txt2img', payload)
    if not result:
        message.reply_text("Failed to generate image. Please try again later.")
        placeholder.delete()
        return
    user = message.from_user
    for image_b64 in result["images"]:
        word, info = process_images([image_b64], user.id, user.first_name)
        caption = create_caption(payload, user.first_name, user.id, info)
        message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
    placeholder.delete()
@app.on_message(filters.command(["img"]))
def img2img(client, message):
    """Handle /img: run img2img on the replied-to photo with the given prompt."""
    replied = message.reply_to_message
    if not replied or not replied.photo:
        # Must be used as a reply to a photo — show usage help.
        message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`")
        return
    parts = message.text.split(" ", 1)
    if len(parts) == 1:
        message.reply_text("dont FAIL in life")
        return
    payload = parse_input(parts[1])
    # Download the source photo, base64-encode it for the API, then remove it.
    downloaded = app.download_media(replied.photo)
    payload["init_images"] = [encode_file_to_base64(downloaded)]
    os.remove(downloaded)  # Clean up downloaded image file
    placeholder = message.reply_text("Please Wait 10-15 Seconds")
    result = call_api('sdapi/v1/img2img', payload)
    if not result:
        message.reply_text("Failed to process image. Please try again later.")
        placeholder.delete()
        return
    user = message.from_user
    for image_b64 in result["images"]:
        word, info = process_images([image_b64], user.id, user.first_name)
        caption = create_caption(payload, user.first_name, user.id, info)
        message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
    placeholder.delete()
@app.on_message(filters.command(["getmodels"]))
async def get_models(client, message):
    """Handle /getmodels: list available SD checkpoints as inline buttons."""
    # NOTE(review): requests is blocking inside an async handler — it will stall
    # the event loop while the HTTP call runs; consider run_in_executor.
    try:
        response = requests.get(f"{SD_URL}/sdapi/v1/sd-models")
        response.raise_for_status()
        keyboard = []
        for model in response.json():
            # One button per model; callback carries the model_name to set.
            keyboard.append([InlineKeyboardButton(model["title"], callback_data=model["model_name"])])
        await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(keyboard))
    except requests.RequestException as e:
        await message.reply_text(f"Failed to get models: {e}")
@app.on_callback_query()
async def process_callback(client, callback_query):
    """Apply the checkpoint chosen via a /getmodels inline button."""
    sd_model_checkpoint = callback_query.data
    try:
        response = requests.post(
            f"{SD_URL}/sdapi/v1/options",
            json={"sd_model_checkpoint": sd_model_checkpoint},
        )
        response.raise_for_status()
    except requests.RequestException as e:
        await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
    else:
        await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
@app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message):
    """Handle /info_sd_bot: reply with usage notes for the x/y/z plot keys."""
    await message.reply_text("""
now support for xyz scripts, see [sd wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#xyz-plot) !
currently supported
`xsr` - search replace text/emoji in the prompt, more info [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-sr)
`xds` - denoise strength, only valid for img2img
`xsteps` - steps
**note** limit the overall `steps:` to lower value (10-20) for big xyz plots
""")
app.run()