324 lines
12 KiB
Python
324 lines
12 KiB
Python
import os
|
|
import re
|
|
import io
|
|
import uuid
|
|
import base64
|
|
import requests
|
|
from datetime import datetime
|
|
from PIL import Image, PngImagePlugin
|
|
from pyrogram import Client, filters
|
|
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup
|
|
from dotenv import load_dotenv
|
|
|
|
# Load environment variables
load_dotenv()

# Credentials and endpoints are read from the environment; each value is None
# when the corresponding variable is not set.
API_ID = os.environ.get("API_ID")
API_HASH = os.environ.get("API_HASH")
TOKEN = os.environ.get("TOKEN_givemtxt2img")
# Base URL of the Stable Diffusion API; request paths are appended as f"{SD_URL}/<endpoint>".
SD_URL = os.environ.get("SD_URL")

# Bot client shared by every handler registered below.
app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
# Directory (relative to the working directory) where generated PNG files are written.
IMAGE_PATH = 'images'

# Ensure IMAGE_PATH directory exists
os.makedirs(IMAGE_PATH, exist_ok=True)
|
|
|
|
# Model-specific embeddings for negative prompts.
# See the civit.ai model page for the specific embeddings recommended for each model.
# Keys are compared against the checkpoint name passed to update_negative_prompt();
# an empty string means no extra embedding is configured for that model.
model_negative_prompts = {
    "Anything-Diffusion": "",
    "Deliberate": "",
    "Dreamshaper": "",
    "DreamShaperXL_Lightning": "",
    "icbinp": "",
    "realisticVisionV60B1_v51VAE": "realisticvision-negative-embedding",
    "v1-5-pruned-emaonly": ""
}
|
|
|
|
|
|
def encode_file_to_base64(path):
    """Return the contents of the file at *path* as a base64 text string."""
    with open(path, 'rb') as source:
        raw_bytes = source.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
|
|
|
|
|
|
def decode_and_save_base64(base64_str, save_path):
    """Decode *base64_str* and write the resulting bytes to *save_path*.

    The target file is opened (created or truncated) before decoding starts.
    """
    with open(save_path, "wb") as target:
        decoded_bytes = base64.b64decode(base64_str)
        target.write(decoded_bytes)
|
|
|
|
|
|
|
|
# Set default payload values
# parse_input() starts from a shallow copy of this dict and only lets the user
# override keys that already exist here; update_negative_prompt() modifies the
# "negative_prompt" entry in place.
default_payload = {
    "prompt": "",
    "seed": -1,  # Random seed
    "negative_prompt": "extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs",
    "enable_hr": False,
    # NOTE(review): the web UI API appears to name this field "sampler_name";
    # confirm that a capitalised "Sampler" key is actually honoured by the server.
    "Sampler": "DPM++ SDE Karras",
    "denoising_strength": 0.35,
    "batch_size": 1,
    "n_iter": 1,
    "steps": 35,
    "cfg_scale": 7,
    "width": 512,
    "height": 512,
    "restore_faces": False,
    "override_settings": {},
    "override_settings_restore_afterwards": True,
}
|
|
|
|
# Negative prompt as it was before any model-specific embedding was appended.
# Captured lazily on the first call so the function can always rebuild from it.
_base_negative_prompt = None


def update_negative_prompt(model_name):
    """Set the default negative prompt for the selected model.

    The default negative prompt is rebuilt from its original text plus the
    embedding configured for *model_name* in ``model_negative_prompts``.
    Rebuilding (instead of appending) keeps repeated model switches from
    stacking embeddings, and avoids a dangling ``", "`` when the model has
    no embedding. Unknown models get the original text without any suffix.

    Args:
        model_name: Checkpoint name to look up in ``model_negative_prompts``.
    """
    global _base_negative_prompt
    if _base_negative_prompt is None:
        _base_negative_prompt = default_payload["negative_prompt"]

    suffix = model_negative_prompts.get(model_name, "")
    if suffix:
        default_payload["negative_prompt"] = f"{_base_negative_prompt}, {suffix}"
    else:
        default_payload["negative_prompt"] = _base_negative_prompt
|
|
|
|
# Short command keys and the payload fields they stand for.
_KEY_ALIASES = {
    "ds": "denoising_strength",
    "ng": "negative_prompt",
    "cfg": "cfg_scale",
}

# X/Y/Z plot axis keys -> axis-type value stored in the script arguments.
_XYZ_AXIS_TYPES = {
    "xsr": 7,  # Enum value for xsr
    "xsteps": 4,  # Enum value for xsteps
    "xds": 22,  # Enum value for xds
    "xcfg": 6,  # Enum value for CFG Scale
}

# X/Y/Z plot flag keys -> (index in the script arguments, value to store).
_XYZ_FLAGS = {
    "nl": (9, False),  # Draw legend
    "ks": (10, True),  # Keep sub images
    "rs": (11, True),  # Set random seed to sub images
}

# (axis-type index, axis-values index) in the script arguments, one pair per axis.
_XYZ_AXIS_SLOTS = ((0, 1), (3, 4), (6, 7))


def parse_input(input_string):
    """Turn a user command string into an API payload dict.

    The text is scanned for ``key:`` tokens. Each key's value runs up to the
    next whitespace-separated ``key:`` token or the end of the string.

    * Keys present in ``default_payload`` (after expanding the ``ds``, ``ng``
      and ``cfg`` aliases) override that payload entry; values stay strings.
    * ``xsr``/``xsteps``/``xds``/``xcfg`` each claim the next free plot axis,
      in order of appearance. Axis keys beyond the third are ignored.
    * ``nl``/``ks``/``rs`` set plot flags and never consume an axis.
    * Any other ``key:value`` pair, and all text outside of pairs, becomes
      part of the prompt.

    If nothing ends up in the prompt, the whole stripped input is used.

    Args:
        input_string: Raw text following the bot command.

    Returns:
        A shallow copy of ``default_payload`` with the parsed overrides, plus
        ``script_name``/``script_args`` when any plot key was given.
    """
    payload = default_payload.copy()
    prompt_parts = []
    script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
    script_name = None
    next_axis = 0
    last_index = 0

    for match in re.finditer(r"(\w+):", input_string):
        key = match.group(1).lower()
        value_start = match.end()

        # Free text between the previous value and this key belongs to the prompt.
        # Whitespace-only gaps are skipped so they do not add stray spaces.
        segment = input_string[last_index:match.start()].strip()
        if segment:
            prompt_parts.append(segment)

        # The "$" alternative guarantees this search always finds a position.
        value_end = value_start + re.search(r"(?=\s+\w+:|$)", input_string[value_start:]).start()
        value = input_string[value_start:value_end].strip()
        last_index = value_end

        key = _KEY_ALIASES.get(key, key)

        if key in default_payload:
            payload[key] = value
        elif key in _XYZ_AXIS_TYPES:
            script_name = "x/y/z plot"
            if next_axis < len(_XYZ_AXIS_SLOTS):
                type_index, values_index = _XYZ_AXIS_SLOTS[next_axis]
                script_args[type_index] = _XYZ_AXIS_TYPES[key]
                script_args[values_index] = value
                next_axis += 1
        elif key in _XYZ_FLAGS:
            script_name = "x/y/z plot"
            flag_index, flag_value = _XYZ_FLAGS[key]
            script_args[flag_index] = flag_value
        else:
            prompt_parts.append(f"{key}:{value}")

    payload["prompt"] = " ".join(prompt_parts).strip()
    if not payload["prompt"]:
        payload["prompt"] = input_string.strip()

    if script_name:
        payload["script_name"] = script_name
        payload["script_args"] = script_args

    return payload
|
|
|
|
|
|
def create_caption(payload, user_name, user_id, info):
    """Build the Markdown caption sent along with a generated image.

    The caption holds a mention link for the user, the seed extracted from
    *info* (when one is present) and the prompt, each in bold. The result is
    at most 1024 characters long; an over-long prompt is shortened with
    "..." so the closing bold marker is kept.

    Args:
        payload: Request payload; only ``payload["prompt"]`` is read.
        user_name: Display name used as the mention text.
        user_id: Numeric user id used in the ``tg://user`` link.
        info: Generation parameters text, e.g.
            "Steps: 3, Sampler: Euler, CFG scale: 7.0, Seed: 4094161400, ...".
            May be None or empty, in which case no seed line is added.

    Returns:
        The caption string.
    """
    max_length = 1024
    caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
    prompt = payload["prompt"]
    print(prompt)
    print(info)

    # `info or ""` keeps the search from raising when no info text is available.
    match = re.search(r"Seed: (\d+)", info or "")
    if match:
        seed_value = match.group(1)
        print(f"Seed value: {seed_value}")
        caption += f"**{seed_value}**\n"
    else:
        print("Seed value not found in the info string.")

    prompt_line = f"**{prompt}**\n"
    if len(caption) + len(prompt_line) > max_length:
        # Shorten the prompt itself rather than slicing the finished caption,
        # so the bold markup stays balanced.
        room = max_length - len(caption) - len("**...**\n")
        prompt_line = f"**{prompt[:max(room, 0)]}...**\n"
    caption += prompt_line

    # Final guard in case the header alone is already too long.
    return caption[:max_length]
|
|
|
|
|
|
def call_api(api_endpoint, payload):
    """POST *payload* as JSON to ``{SD_URL}/{api_endpoint}``.

    Returns:
        The decoded JSON body on success, or None when the request fails
        with a ``requests.RequestException`` (the error is printed).
    """
    target_url = f'{SD_URL}/{api_endpoint}'
    try:
        reply = requests.post(target_url, json=payload)
        reply.raise_for_status()
        decoded = reply.json()
    except requests.RequestException as err:
        print(f"API call failed: {err}")
        return None
    return decoded
|
|
|
|
|
|
def process_images(images, user_id, user_name):
    """Save generated images as PNG files with their generation parameters.

    For each base64 image, the parameters text is requested from the
    ``sdapi/v1/png-info`` endpoint and embedded as a "parameters" text chunk
    in the PNG written to ``{IMAGE_PATH}/{word}.png``. A single file name is
    used per call, so with several images the last one written wins; the
    callers in this file pass one image at a time.

    Args:
        images: Non-empty list of base64-encoded images.
        user_id: Unused; kept for interface compatibility.
        user_name: Display name used as the file-name prefix. It is
            user-controlled, so characters other than word characters, "."
            and "-" are replaced before it is used in a path.

    Returns:
        ``(word, info)``: the file name without directory or extension, and
        the parameters text for the last image ("" when the API gave none).

    Raises:
        ValueError: If *images* is empty.
        requests.RequestException: If the png-info request fails.
    """
    if not images:
        raise ValueError("process_images() requires at least one image")

    # Keep path separators and other unsafe characters out of the file name.
    safe_name = re.sub(r"[^\w.-]+", "_", str(user_name or "")).strip("._") or "user"
    unique_id = str(uuid.uuid4())[:7]
    word = f"{safe_name}-{unique_id}"

    info = ""
    for encoded in images:
        # NOTE(review): taking element [0] assumes the API returns bare base64
        # with no "data:...," prefix - confirm against the server response.
        image = Image.open(io.BytesIO(base64.b64decode(encoded.split(",", 1)[0])))
        png_payload = {"image": "data:image/png;base64," + encoded}
        response = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
        response.raise_for_status()

        # Parse the response once; fall back to "" so add_text() always gets a string.
        info = response.json().get("info") or ""

        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info)
        image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)

    return word, info
|
|
|
|
@app.on_message(filters.command(["draw"]))
def draw(client, message):
    """Handle ``/draw <prompt> [key:value ...]``: text-to-image generation.

    Replies with usage help when no prompt is given, rejects the ``xds:``
    key (only meaningful for ``/img``), and otherwise sends one photo per
    generated image. The temporary "please wait" message is removed whether
    or not generation succeeds.
    """
    msgs = message.text.split(" ", 1)
    if len(msgs) == 1:
        message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >")
        return

    payload = parse_input(msgs[1])
    print(payload)

    # Check if xds is used as a key. Matching "xds:" only when it is not
    # preceded by a word character avoids rejecting prompts that merely
    # contain the letters "xds".
    if re.search(r"(?<!\w)xds:", msgs[1], re.IGNORECASE):
        message.reply_text("`xds` key cannot be used in the `/draw` command. Use `/img` instead.")
        return

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        r = call_api('sdapi/v1/txt2img', payload)
        images = r.get("images") if r else None

        if images:
            for i in images:
                word, info = process_images([i], message.from_user.id, message.from_user.first_name)
                caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
                message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
        else:
            message.reply_text("Failed to generate image. Please try again later.")
    finally:
        # Remove the status message exactly once, even if processing raised.
        K.delete()
|
|
|
|
|
|
@app.on_message(filters.command(["img"]))
def img2img(client, message):
    """Handle ``/img <prompt> [key:value ...]`` sent as a reply to a photo.

    The replied-to photo is downloaded, base64-encoded and sent as the
    initial image to the img2img endpoint; one photo is sent back per
    generated image. The downloaded file and the temporary "please wait"
    message are cleaned up whether or not processing succeeds.
    """
    if not message.reply_to_message or not message.reply_to_message.photo:
        message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`")
        return

    msgs = message.text.split(" ", 1)
    if len(msgs) == 1:
        message.reply_text("dont FAIL in life")
        return

    payload = parse_input(msgs[1])
    print(f"input:\n{payload}")

    photo = message.reply_to_message.photo
    photo_file = app.download_media(photo)
    if not photo_file:
        # Nothing was downloaded, so there is nothing to encode or remove.
        message.reply_text("Failed to process image. Please try again later.")
        return

    try:
        init_image = encode_file_to_base64(photo_file)
    finally:
        os.remove(photo_file)  # Clean up downloaded image file

    payload["init_images"] = [init_image]

    K = message.reply_text("Please Wait 10-15 Seconds")
    try:
        r = call_api('sdapi/v1/img2img', payload)
        images = r.get("images") if r else None

        if images:
            for i in images:
                word, info = process_images([i], message.from_user.id, message.from_user.first_name)
                caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
                message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
        else:
            message.reply_text("Failed to process image. Please try again later.")
    finally:
        # Remove the status message exactly once, even if processing raised.
        K.delete()
|
|
|
|
|
|
@app.on_message(filters.command(["getmodels"]))
async def get_models(client, message):
    """Handle ``/getmodels``: reply with one inline button per checkpoint.

    Each button shows the model's "title" and carries its "model_name" as
    callback data. A failed request is reported back to the chat.
    """
    models_url = f"{SD_URL}/sdapi/v1/sd-models"
    try:
        reply = requests.get(models_url)
        reply.raise_for_status()
        available_models = reply.json()
        print(available_models)

        keyboard_rows = []
        for entry in available_models:
            button = InlineKeyboardButton(entry["title"], callback_data=entry["model_name"])
            keyboard_rows.append([button])

        markup = InlineKeyboardMarkup(keyboard_rows)
        await message.reply_text("Select a model [checkpoint] to use", reply_markup=markup)
    except requests.RequestException as err:
        await message.reply_text(f"Failed to get models: {err}")
|
|
|
|
|
|
@app.on_callback_query()
async def process_callback(client, callback_query):
    """Switch the active checkpoint to the one chosen via an inline button.

    The callback data is used as the checkpoint name: it is posted to the
    ``sdapi/v1/options`` endpoint and, on success, the default negative
    prompt is updated for that model. The outcome is reported in the chat.
    """
    # Acknowledge the button press first so the client stops showing its
    # progress indicator while the checkpoint is being switched.
    await callback_query.answer()

    sd_model_checkpoint = callback_query.data
    options = {"sd_model_checkpoint": sd_model_checkpoint}

    try:
        # NOTE(review): requests.post is a blocking call inside a coroutine;
        # other updates wait until it returns.
        response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options)
        response.raise_for_status()

        # Update the negative prompt based on the selected model
        update_negative_prompt(sd_model_checkpoint)

        await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
    except requests.RequestException as e:
        await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
        print(f"Error setting checkpoint: {e}")
|
|
|
|
|
|
@app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message):
    """Handle ``/info_sd_bot``: reply with the list of supported keys."""
    help_text = """
now support for xyz scripts, see [sd wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#xyz-plot) !
currently supported
`xsr` - search replace text/emoji in the prompt, more info [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-sr)
`xds` - denoise strength, only valid for img2img
`xsteps` - steps
**note** limit the overall `steps:` to lower value (10-20) for big xyz plots

aside from that you can use the usual `ng`, `ds`, `cfg`, `steps` for single image generation.
"""
    await message.reply_text(help_text, disable_web_page_preview=True)
|
|
|
|
|
|
# Start the bot client. This runs at import time, after every handler above
# has been registered on `app`.
app.run()
|