stable-diffusion-telegram-bot/main.py

425 lines
16 KiB
Python
Raw Normal View History

2023-03-18 00:04:54 +02:00
import os
2024-05-18 14:06:29 +03:00
import re
import io
2023-04-28 10:02:47 +03:00
import uuid
2023-01-12 06:37:56 +02:00
import base64
2024-06-09 10:39:34 +03:00
import json
2024-05-18 14:06:29 +03:00
import requests
2024-05-16 01:05:35 +03:00
from datetime import datetime
2023-03-18 00:04:54 +02:00
from PIL import Image, PngImagePlugin
from pyrogram import Client, filters
2024-05-16 01:05:35 +03:00
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup
2023-03-18 00:24:46 +02:00
from dotenv import load_dotenv
2023-03-18 00:04:54 +02:00
2024-05-16 01:05:35 +03:00
# Load environment variables (also from a .env file, via python-dotenv).
load_dotenv()

API_ID = os.environ.get("API_ID")
API_HASH = os.environ.get("API_HASH")
# The bot token lives under this specific variable name, not "TOKEN".
TOKEN = os.environ.get("TOKEN_givemtxt2img")
SD_URL = os.environ.get("SD_URL")
if SD_URL:
    # Endpoints are built as f"{SD_URL}/sdapi/...": drop a trailing slash so
    # we never request "//sdapi/...".
    SD_URL = SD_URL.rstrip("/")

# Fail fast, naming exactly which variables are absent or empty.
_required_env = {
    "API_ID": API_ID,
    "API_HASH": API_HASH,
    "TOKEN_givemtxt2img": TOKEN,
    "SD_URL": SD_URL,
}
_missing_env = [name for name, value in _required_env.items() if not value]
if _missing_env:
    raise EnvironmentError(
        f"Missing required environment variables: {', '.join(_missing_env)}"
    )

app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
2024-05-18 14:06:29 +03:00
# Output directory for generated images and their png-info side files.
IMAGE_PATH = "images"

# Create it once at import time so later saves never hit a missing folder.
os.makedirs(name=IMAGE_PATH, exist_ok=True)
2024-05-26 08:49:27 +03:00
2024-06-03 14:14:25 +03:00
def get_current_model_name(timeout=30):
    """Return the checkpoint currently selected in the SD web UI.

    Queries ``GET {SD_URL}/sdapi/v1/options`` and reads ``sd_model_checkpoint``.

    Args:
        timeout: Seconds to wait for the API. Without a timeout a stalled
            server would block bot start-up indefinitely.

    Returns:
        The ``sd_model_checkpoint`` value, ``"Unknown"`` if the key is absent,
        or ``None`` if the request failed or the body was not valid JSON.
    """
    try:
        response = requests.get(f"{SD_URL}/sdapi/v1/options", timeout=timeout)
        response.raise_for_status()
        options = response.json()
    # ValueError covers JSON decode errors on requests versions where they are
    # not RequestException subclasses.
    except (requests.RequestException, ValueError) as e:
        print(f"API call failed: {e}")
        return None
    return options.get("sd_model_checkpoint", "Unknown")
# Ask the web UI once, at import time, which checkpoint is loaded.
current_model_name = get_current_model_name()
print(
    f"Current model name: {current_model_name}"
    if current_model_name
    else "Failed to fetch the current model name."
)
2024-05-29 09:28:23 +03:00
2024-05-16 01:05:35 +03:00
def encode_file_to_base64(path):
    """Read the file at ``path`` in binary mode and return it as base64 text."""
    with open(path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def decode_and_save_base64(base64_str, save_path):
    """Decode ``base64_str`` and write the bytes to ``save_path``, overwriting it."""
    # The file is opened (and truncated) before decoding, as before.
    with open(save_path, mode="wb") as out_file:
        decoded = base64.b64decode(base64_str)
        out_file.write(decoded)
2023-05-06 18:59:02 +03:00
2024-05-26 09:00:32 +03:00
# Baseline generation settings. parse_input() shallow-copies this dict for
# every request; the update_* helpers rewrite some fields on model changes.
# NOTE(review): the SD web UI API appears to call its sampler field
# "sampler_name"; the "Sampler" key may be ignored server-side — confirm.
default_payload = dict(
    prompt="",
    seed=-1,  # Random seed
    negative_prompt="extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs",
    enable_hr=False,
    Sampler="DPM++ SDE Karras",
    denoising_strength=0.35,
    batch_size=1,
    n_iter=1,
    steps=35,
    cfg_scale=7,
    width=512,
    height=512,
    restore_faces=False,
    override_settings={},
    override_settings_restore_afterwards=True,
)
2024-05-26 08:49:27 +03:00
2024-06-09 11:33:02 +03:00
# Model name -> extra negative-prompt terms for that model (may be empty).
# Keys match the model_name values sent by the /getmodels buttons. Some values
# look like textual-inversion embedding names — presumably installed on the
# SD server; verify there.
model_negative_prompts = dict([
    ("coloringPage_v10", "fake"),
    ("Anything-Diffusion", ""),
    ("Deliberate", ""),
    ("Dreamshaper", ""),
    ("DreamShaperXL_Lightning", ""),
    ("realisticVisionV60B1_v51VAE", "realisticvision-negative-embedding"),
    ("v1-5-pruned-emaonly", ""),
    ("Juggernaut-XL_v9_RunDiffusionPhoto_v2", "bad eyes, cgi, airbrushed, plastic, watermark"),
])
2024-05-26 08:49:27 +03:00
# Model-agnostic default negative prompt, captured the first time
# update_negative_prompt runs (before it has modified anything) so that model
# switches replace the per-model suffix instead of accumulating suffixes.
_BASE_NEGATIVE_PROMPT = None


def update_negative_prompt(model_name):
    """Set ``default_payload["negative_prompt"]`` for ``model_name``.

    The result is the original default negative prompt, followed by
    ``", <suffix>"`` when ``model_negative_prompts`` holds a non-empty suffix
    for the model. Unknown models and empty suffixes get the bare default.
    Repeated calls for the same model yield the same value.
    """
    global _BASE_NEGATIVE_PROMPT
    if _BASE_NEGATIVE_PROMPT is None:
        _BASE_NEGATIVE_PROMPT = default_payload["negative_prompt"]

    suffix = model_negative_prompts.get(model_name, "")
    negative_prompt = _BASE_NEGATIVE_PROMPT
    if suffix:
        # Only join with ", " when there is something to append.
        negative_prompt += f", {suffix}"
    default_payload["negative_prompt"] = negative_prompt
    print(f"Updated negative prompt to: {default_payload['negative_prompt']}")
2024-05-26 08:49:27 +03:00
2024-06-03 13:35:14 +03:00
def update_resolution(model_name):
    """Set default width/height: 832x1216 for the Juggernaut XL model, else 512x512."""
    if model_name != "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
        size = (512, 512)
    else:
        size = (832, 1216)
    # Tuple unpacking assigns width first, then height.
    default_payload["width"], default_payload["height"] = size
    print(f"Updated resolution to {default_payload['width']}x{default_payload['height']}")
def update_steps(model_name):
    """Set the default sampling step count for the selected model.

    Juggernaut-XL_v9_RunDiffusionPhoto_v2 uses 15 steps; every other model
    uses 35.
    """
    if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
        default_payload["steps"] = 15
    else:
        default_payload["steps"] = 35
    # Report the field that was actually changed (this used to print cfg_scale).
    print(f"Updated steps to {default_payload['steps']}")
2024-06-03 13:35:14 +03:00
def update_cfg_scale(model_name):
    """Set default CFG scale: 2.5 for the Juggernaut XL model, 7 for any other."""
    is_juggernaut_xl = model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2"
    default_payload["cfg_scale"] = 2.5 if is_juggernaut_xl else 7
    print(f"Updated CFG scale to {default_payload['cfg_scale']}")
2024-06-03 13:35:14 +03:00
2024-06-03 14:14:25 +03:00
def _checkpoint_to_model_name(checkpoint):
    """Reduce a checkpoint title to the bare model name used as a lookup key.

    NOTE(review): the options endpoint appears to report titles such as
    "name.safetensors [0123abcd]", while the /getmodels buttons send the bare
    ``model_name``; strip the hash suffix and weight-file extension so both
    paths hit the same keys. Confirm against the running web UI.
    """
    name = re.sub(r"\s*\[[0-9a-fA-F]+\]\s*$", "", checkpoint)
    return re.sub(r"\.(safetensors|ckpt)$", "", name, flags=re.IGNORECASE)


# Apply the per-model defaults for the checkpoint loaded at start-up.
if current_model_name:
    _startup_model = _checkpoint_to_model_name(current_model_name)
    update_negative_prompt(_startup_model)
    update_resolution(_startup_model)
    update_cfg_scale(_startup_model)
    update_steps(_startup_model)
else:
    print("Failed to update configurations as the current model name is not available.")
2024-06-03 13:35:14 +03:00
2024-05-26 08:49:27 +03:00
def parse_input(input_string):
    """Parse a command argument string into an SD API payload.

    The string is free text interleaved with ``key:value`` options. A value
    runs until the next whitespace-preceded ``word:`` or the end of input.

    * Keys present in ``default_payload`` (after the aliases ``ds`` ->
      denoising_strength, ``ng`` -> negative_prompt, ``cfg`` -> cfg_scale)
      override that field. Values are kept as strings.
    * ``xsr``/``xsteps``/``xds``/``xcfg`` fill the next free x/y/z plot axis
      (at most three; extras are ignored).
    * ``nl``/``ks``/``rs`` are x/y/z plot flags; they never use an axis slot.
    * Any other ``word:value`` is kept verbatim as part of the prompt.
    * ``info:`` anywhere requests the full payload in the caption.

    Returns:
        (payload, include_info): a new dict based on a shallow copy of
        ``default_payload``, and whether ``info:`` was present.
    """
    payload = default_payload.copy()
    prompt = []

    include_info = "info:" in input_string
    input_string = input_string.replace("info:", "").strip()

    aliases = {"ds": "denoising_strength", "ng": "negative_prompt", "cfg": "cfg_scale"}
    # Axis-type indices the x/y/z plot script expects for each key (values
    # taken from the previous implementation).
    axis_types = {"xsr": 7, "xsteps": 4, "xds": 22, "xcfg": 6}
    # Flag -> (script_args position, value). Positions follow the x/y/z plot
    # argument order: 9 draw_legend, 10 include_lone_images,
    # 11 include_sub_grids, 12 no_fixed_seeds.
    # NOTE(review): verify this order against the installed web UI version.
    flag_args = {"nl": (9, False), "ks": (10, True), "rs": (12, True)}

    script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
    script_name = None
    # Axis slot -> (type position, values position) in script_args.
    slot_mapping = {0: (0, 1), 1: (3, 4), 2: (6, 7)}
    slot_index = 0
    last_index = 0

    for match in re.finditer(r"(\w+):", input_string):
        # A "word:" inside an already-consumed value (e.g. "ng:foo:bar") is
        # part of that value, not a new key.
        if match.start() < last_index:
            continue

        text_before = input_string[last_index:match.start()].strip()
        if text_before:
            prompt.append(text_before)

        value_start_index = match.end()
        # The "$" alternative always matches, so this search never fails.
        value_end_match = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:])
        value_end_index = value_start_index + value_end_match.start()
        value = input_string[value_start_index:value_end_index].strip()

        key = match.group(1).lower()
        key = aliases.get(key, key)

        if key in default_payload:
            payload[key] = value
        elif key in axis_types:
            script_name = "x/y/z plot"
            if slot_index < len(slot_mapping):
                type_pos, values_pos = slot_mapping[slot_index]
                script_args[type_pos] = axis_types[key]
                script_args[values_pos] = value
                slot_index += 1
        elif key in flag_args:
            script_name = "x/y/z plot"
            arg_pos, arg_value = flag_args[key]
            script_args[arg_pos] = arg_value
        else:
            # Not an option: keep the user's original spelling in the prompt.
            prompt.append(f"{match.group(1)}:{value}")

        last_index = value_end_index

    payload["prompt"] = " ".join(prompt).strip()
    if not payload["prompt"]:
        payload["prompt"] = input_string.strip()

    if script_name:
        payload["script_name"] = script_name
        payload["script_args"] = script_args
    print(f"Generated payload: {payload}")
    return payload, include_info
2024-05-26 08:49:27 +03:00
2024-05-26 09:00:32 +03:00
def create_caption(payload, user_name, user_id, info, include_info):
    """Build the Markdown caption sent with a generated image.

    The caption contains a mention link to the requesting user, the seed
    parsed from ``info`` (when found), the prompt, and, if ``include_info``,
    the full payload repr. It is truncated to 1024 characters, the Telegram
    Bot API photo-caption limit.

    Args:
        payload: generation payload; only ``payload["prompt"]`` is read
            (plus its repr when ``include_info``).
        user_name, user_id: used for the ``tg://user?id=`` mention link.
        info: png-info text returned by the SD API, or None if unavailable.
        include_info: append the full payload when True.
    """
    caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
    prompt = payload["prompt"]

    # info may be None when the API omits it; re.search(None) would raise.
    match = re.search(r"Seed: (\d+)", info) if info else None
    if match:
        caption += f"**{match.group(1)}**\n"
    else:
        print("Seed value not found in the info string.")

    caption += f"**{prompt}**\n"
    if include_info:
        caption += f"\nFull Payload:\n`{payload}`\n"

    if len(caption) > 1024:
        caption = caption[:1021] + "..."
    return caption
2024-05-16 01:05:35 +03:00
def call_api(api_endpoint, payload):
    """POST ``payload`` as JSON to ``{SD_URL}/{api_endpoint}``.

    Returns the decoded JSON body on success. Any ``requests`` error
    (connection failure, HTTP error status, undecodable body on versions
    where that is a RequestException) is printed and reported as
    ``{"error": "<message>"}`` instead of being raised.
    """
    url = f'{SD_URL}/{api_endpoint}'
    try:
        resp = requests.post(url, json=payload)
        resp.raise_for_status()
        result = resp.json()
    except requests.RequestException as err:
        print(f"API call failed: {err}")
        result = {"error": str(err)}
    return result
2024-05-16 01:05:35 +03:00
def process_images(images, user_id, user_name):
    """Decode, annotate and save generated images under IMAGE_PATH.

    For each base64 image, the ``/sdapi/v1/png-info`` endpoint is asked for
    its generation parameters. Three files sharing one unique stem are then
    written: ``<stem>.json`` (the png-info response), ``<stem>.png`` (with a
    ``parameters`` text chunk) and ``<stem>.jpg`` (RGB copy).

    Args:
        images: base64 image strings, optionally prefixed with a
            ``data:...;base64,`` header.
        user_id: Telegram user id (currently unused).
        user_name: display name of the requester; sanitized before being
            embedded in file names.

    Returns:
        ``(stem, info)`` for the last image processed, where ``info`` is the
        png-info ``info`` field (may be None). ``(None, None)`` if
        ``images`` is empty.

    Raises:
        requests.HTTPError: if the png-info call returns an error status.
    """
    # first_name is user-controlled: keep only word characters and '-' so it
    # cannot inject path separators ("../") or characters invalid in filenames.
    safe_user = re.sub(r"[^\w-]", "_", user_name or "") or "user"

    def generate_unique_name():
        unique_id = str(uuid.uuid4())[:7]
        date = datetime.now().strftime("%Y-%m-%d-%H-%M")
        return f"{date}-{safe_user}-{unique_id}"

    word, info = None, None
    for encoded in images:
        # Accept raw base64 or a data URI; the data is after the (only) comma.
        b64_data = encoded.split(",", 1)[-1]
        # A fresh stem per image so several images never overwrite each other.
        word = generate_unique_name()
        image = Image.open(io.BytesIO(base64.b64decode(b64_data)))

        png_payload = {"image": "data:image/png;base64," + b64_data}
        response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
        response2.raise_for_status()
        png_info = response2.json()
        info = png_info.get("info")

        # Write the png-info response next to the image.
        with open(f"{IMAGE_PATH}/{word}.json", "w") as json_file:
            json.dump(png_info, json_file)

        pnginfo = PngImagePlugin.PngInfo()
        # add_text cannot encode None; store an empty chunk instead.
        pnginfo.add_text("parameters", info or "")
        image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)

        # JPEG copy (what the handlers send to Telegram).
        image.convert("RGB").save(f"{IMAGE_PATH}/{word}.jpg", "JPEG")

    return word, info
2024-05-26 09:00:32 +03:00
2023-01-12 06:37:56 +02:00
@app.on_message(filters.command(["draw"]))
def draw(client, message):
    """Handle /draw: generate image(s) from a text prompt via txt2img.

    Usage: ``/draw <prompt> [key:value ...]`` (options parsed by
    parse_input). Replies with one photo per generated image, or with the
    API error text.
    """
    # split() with no separator also accepts a newline after the command and
    # treats "/draw " (nothing after it) as missing a prompt.
    msgs = message.text.split(maxsplit=1)
    if len(msgs) == 1:
        message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >")
        return

    # Reject the img2img-only option before doing any work. Match it as an
    # option key, not as a substring of ordinary words.
    if re.search(r"(?<!\w)xds:", msgs[1], re.IGNORECASE):
        message.reply_text("`xds` key cannot be used in the `/draw` command. Use `/img` instead.")
        return

    payload, include_info = parse_input(msgs[1])

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        r = call_api('sdapi/v1/txt2img', payload)
        if r and "images" in r:
            for i in r["images"]:
                word, info = process_images([i], message.from_user.id, message.from_user.first_name)
                caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info)
                message.reply_photo(photo=f"{IMAGE_PATH}/{word}.jpg", caption=caption)
        else:
            error_message = r.get("error", "Failed to generate image. Please try again later.")
            message.reply_text(error_message)
    finally:
        # Remove the "please wait" notice exactly once, even if a step failed.
        K.delete()
2023-01-12 06:37:56 +02:00
2024-05-16 01:05:35 +03:00
@app.on_message(filters.command(["img"]))
def img2img(client, message):
    """Handle /img: re-generate a replied-to photo via img2img.

    Must be sent as a reply to a photo, with a prompt and optional
    ``key:value`` options (see parse_input); ``ds:`` sets the denoising
    strength. Replies with one photo per generated image, or an error.
    """
    if not message.reply_to_message or not message.reply_to_message.photo:
        message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`")
        return

    msgs = message.text.split(maxsplit=1)
    if len(msgs) == 1:
        message.reply_text("Please add a prompt after /img, e.g. `/img modern art ds:0.2`")
        return

    payload, include_info = parse_input(msgs[1])

    # Inline the source photo as base64; always delete the downloaded file,
    # even if reading or encoding it fails.
    photo_file = client.download_media(message.reply_to_message.photo)
    try:
        payload["init_images"] = [encode_file_to_base64(photo_file)]
    finally:
        os.remove(photo_file)

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        r = call_api('sdapi/v1/img2img', payload)
        if r and "images" in r:
            for i in r["images"]:
                word, info = process_images([i], message.from_user.id, message.from_user.first_name)
                caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info)
                message.reply_photo(photo=f"{IMAGE_PATH}/{word}.jpg", caption=caption)
        else:
            error_message = r.get("error", "Failed to process image. Please try again later.")
            message.reply_text(error_message)
    finally:
        # Remove the "please wait" notice exactly once, even if a step failed.
        K.delete()
2023-03-18 00:04:54 +02:00
@app.on_message(filters.command(["getmodels"]))
async def get_models(client, message):
    """Handle /getmodels: list SD checkpoints as inline keyboard buttons.

    Each button shows the model's ``title`` and carries its ``model_name`` as
    callback data, which process_callback then uses to switch checkpoints.
    """
    import asyncio

    def fetch_models():
        response = requests.get(f"{SD_URL}/sdapi/v1/sd-models")
        response.raise_for_status()
        return response.json()

    try:
        # requests is blocking; run it in the default executor so the event
        # loop (and every other async handler) keeps running meanwhile.
        models_json = await asyncio.get_running_loop().run_in_executor(None, fetch_models)
    except requests.RequestException as e:
        await message.reply_text(f"Failed to get models: {e}")
        return

    buttons = [
        [InlineKeyboardButton(model["title"], callback_data=model["model_name"])]
        for model in models_json
    ]
    await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(buttons))
2023-04-22 22:18:10 +03:00
2023-03-18 00:04:54 +02:00
@app.on_callback_query()
async def process_callback(client, callback_query):
    """Switch the SD checkpoint to the model picked from /getmodels.

    ``callback_query.data`` is the ``model_name`` attached by get_models. On
    success the per-model defaults (negative prompt, resolution, CFG scale,
    steps) are refreshed, the same set applied at start-up, and the user is
    notified. On failure the error is reported and nothing is changed.
    """
    import asyncio

    sd_model_checkpoint = callback_query.data
    options = {"sd_model_checkpoint": sd_model_checkpoint}

    # Acknowledge right away so the client stops its loading indicator; the
    # checkpoint switch below may take a while. Best-effort only.
    try:
        await callback_query.answer()
    except Exception as e:
        print(f"Failed to answer callback query: {e}")

    def set_checkpoint():
        response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options)
        response.raise_for_status()

    try:
        # requests is blocking; keep it off the event loop.
        await asyncio.get_running_loop().run_in_executor(None, set_checkpoint)
    except requests.RequestException as e:
        await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
        print(f"Error setting checkpoint: {e}")
        return

    update_negative_prompt(sd_model_checkpoint)
    update_resolution(sd_model_checkpoint)
    update_cfg_scale(sd_model_checkpoint)
    # Steps were previously left at the old model's value after a switch.
    update_steps(sd_model_checkpoint)

    await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
2024-05-18 20:41:30 +03:00
@app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message):
    """Reply with a Markdown overview of the bot's commands and options."""
    # The advanced example uses /img because the /draw handler rejects `xds`.
    await message.reply_text("""
**Stable Diffusion Bot Commands and Options:**

1. **/draw <prompt> [options]**
   - Generates an image based on the provided text prompt.
   - **Options:**
     - `ng:<negative_prompt>` - Add a negative prompt to avoid specific features.
     - `steps:<value>` - Number of steps for generation (1-70).
     - `ds:<value>` - Denoising strength (0-1.0).
     - `cfg:<value>` - CFG scale (1-30).
     - `width:<value>` - Width of the generated image.
     - `height:<value>` - Height of the generated image.
     - `info:` - Include full payload information in the caption.
   **Example:** `/draw beautiful sunset ng:ugly steps:30 ds:0.5 info:`

2. **/img <prompt> [options]**
   - Generates an image based on an existing image and the provided text prompt.
   - **Options:**
     - `ds:<value>` - Denoising strength (0-1.0).
     - `steps:<value>` - Number of steps for generation (1-70).
     - `cfg:<value>` - CFG scale (1-30).
     - `width:<value>` - Width of the generated image.
     - `height:<value>` - Height of the generated image.
     - `info:` - Include full payload information in the caption.
   **Example:** Reply to an image with `/img modern art ds:0.2 info:`

3. **/getmodels**
   - Retrieves and lists all available models for the user to select.
   - User can then choose a model to set as the current model for image generation.

4. **/info_sd_bot**
   - Provides detailed information about the bot's commands and options.

**Additional Options for Advanced Users:**
- **x/y/z plot options** for advanced generation:
  - `xsr:<value>` - Search and replace text/emoji in the prompt.
  - `xsteps:<value>` - Steps value for x/y/z plot.
  - `xds:<value>` - Denoising strength for x/y/z plot (only with `/img`).
  - `xcfg:<value>` - CFG scale for x/y/z plot.
  - `nl:` - No legend in x/y/z plot.
  - `ks:` - Keep sub-images in x/y/z plot.
  - `rs:` - Set random seed for sub-images in x/y/z plot.

**Notes:**
- Use lower step values (10-20) for large x/y/z plots to avoid long processing times.
- Use `info:` option to include full payload details in the caption of generated images for better troubleshooting and analysis.

**Example for Advanced Users:** Reply to an image with `/img beautiful landscape xsteps:10 xds:0.5 xcfg:7 nl: ks: rs: info:`

For the bot code visit: [Stable Diffusion Bot](https://git.telavivmakers.space/ro/stable-diffusion-telegram-bot)
For more details, visit the [Stable Diffusion Wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#xyz-plot).
Enjoy creating with Stable Diffusion Bot!
""", disable_web_page_preview=True)
2023-01-12 06:37:56 +02:00
if __name__ == "__main__":
    # Start the bot only when run as a script, so importing this module
    # (e.g. for tests) does not connect to Telegram.
    app.run()