286 lines
10 KiB
Python
286 lines
10 KiB
Python
import asyncio
import base64
import io
import os
import re
import uuid
from datetime import datetime

import requests
from dotenv import load_dotenv
from PIL import Image, PngImagePlugin
from pyrogram import Client, filters
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup
|
|
|
|
# Load environment variables (python-dotenv; reads a .env file if present)
load_dotenv()

# Telegram credentials and the Stable Diffusion WebUI base URL.
API_ID = os.environ.get("API_ID")
API_HASH = os.environ.get("API_HASH")
TOKEN = os.environ.get("TOKEN_givemtxt2img")
SD_URL = os.environ.get("SD_URL")
if SD_URL:
    # Endpoints are built as f"{SD_URL}/{endpoint}"; a trailing slash in the
    # configured URL would otherwise yield a "//sdapi/..." path.
    SD_URL = SD_URL.rstrip("/")

app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)

# Directory (relative to the working directory) where generated PNGs are saved.
IMAGE_PATH = 'images'

# Ensure IMAGE_PATH directory exists
os.makedirs(IMAGE_PATH, exist_ok=True)
|
|
|
|
def timestamp():
    """Return the current local time as a ``YYYYmmdd-HHMMSS`` string."""
    stamp_format = "%Y%m%d-%H%M%S"
    current = datetime.now()
    return current.strftime(stamp_format)
|
|
|
|
def encode_file_to_base64(path):
    """Read the file at *path* in binary mode and return it base64-encoded.

    Returns:
        str: the base64 text of the file's bytes.
    """
    with open(path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
|
|
|
def decode_and_save_base64(base64_str, save_path):
    """Decode the base64 text *base64_str* and write the bytes to *save_path*.

    The destination is opened (and truncated) before decoding, so an invalid
    input still leaves an empty file behind.
    """
    with open(save_path, "wb") as sink:
        decoded = base64.b64decode(base64_str)
        sink.write(decoded)
|
|
|
|
|
|
|
|
# A "key:" token in user input.
_KEY_PATTERN = re.compile(r"(\w+):")
# End of a value: just before the next whitespace-preceded "key:" or end of text.
_VALUE_END_PATTERN = re.compile(r"(?=\s+\w+:|$)")

# Short aliases users may type, mapped to WebUI payload keys.
_KEY_ALIASES = {
    "ds": "denoising_strength",
    "ng": "negative_prompt",
    "cfg": "cfg_scale",
    "sampler": "sampler_name",
}

# X/Y/Z-plot axis keys -> axis-type enum value passed to the WebUI script.
# NOTE(review): these are indices into the script's axis-option list, which
# may differ between WebUI versions and between txt2img/img2img — TODO confirm.
_XYZ_AXIS_TYPES = {
    "xsr": 7,  # Prompt S/R
    "xsteps": 4,  # Steps
    "xds": 22,  # Denoising strength
    "xcfg": 6,  # CFG Scale
}

# X/Y/Z-plot option keys -> (script_args index, value to store).
_XYZ_FLAGS = {
    "nl": (9, False),  # draw_legend: hide the legend
    "ks": (10, True),  # include_lone_images: keep sub images
    "rs": (12, True),  # no_fixed_seeds: random seed for each sub image
}

# (type index, values index) in script_args for the X, Y and Z axes.
_XYZ_AXIS_SLOTS = ((0, 1), (3, 4), (6, 7))


def _coerce_value(raw, default):
    """Convert the text *raw* to the type of *default* when unambiguous.

    Booleans accept 1/0, true/false, yes/no, on/off (case-insensitive). Int
    defaults fall back to float, so ``cfg:7.5`` works although the default is
    7. Text that cannot be converted is returned unchanged, so the WebUI API
    can still validate it and report the problem.
    """
    if isinstance(default, bool):  # before int: bool is a subclass of int
        lowered = raw.strip().lower()
        if lowered in ("1", "true", "yes", "on"):
            return True
        if lowered in ("0", "false", "no", "off"):
            return False
        return raw
    if isinstance(default, (int, float)):
        casts = (int, float) if isinstance(default, int) else (float,)
        for cast in casts:
            try:
                return cast(raw)
            except ValueError:
                continue
    return raw


def parse_input(input_string):
    """Parse the text after a bot command into a WebUI txt2img/img2img payload.

    Free text becomes the prompt, and ``key:value`` pairs override payload
    fields. A value runs from just after its ``key:`` to the next
    whitespace-preceded ``word:`` token or the end of the string.

    Recognised keys (case-insensitive):
      * any key of the default payload below, plus the aliases ``ds``
        (denoising_strength), ``ng`` (negative_prompt), ``cfg`` (cfg_scale)
        and ``sampler`` (sampler_name);
      * X/Y/Z-plot keys: ``xsr``, ``xsteps``, ``xds``, ``xcfg`` fill up to
        three axes in order of appearance (extra axes are ignored); ``nl``
        hides the legend, ``ks`` keeps sub images, ``rs`` uses a random seed
        per sub image.
    Any other ``key:value`` is kept verbatim in the prompt. If no prompt text
    remains, the whole input string is used as the prompt.

    Args:
        input_string: user text following the command.

    Returns:
        dict: request payload. It contains ``script_name``/``script_args``
        only when an X/Y/Z key was given.
    """
    default_payload = {
        "prompt": "",
        "seed": -1,  # Random seed
        "negative_prompt": "ugly, bad face, distorted",
        "enable_hr": False,
        # The WebUI API field is "sampler_name"; a "Sampler" key is not a
        # recognised field, so the server's default sampler was used instead.
        "sampler_name": "DPM++ SDE Karras",
        "denoising_strength": 0.35,
        "batch_size": 1,
        "n_iter": 1,
        "steps": 35,
        "cfg_scale": 7,
        "width": 512,
        "height": 512,
        "restore_faces": False,
        "override_settings": {},
        "override_settings_restore_afterwards": True,
    }
    payload = dict(default_payload)

    # Positional args for the "x/y/z plot" script: three (type, values,
    # dropdown) axis triples at 0-8, then draw_legend, include_lone_images,
    # include_sub_grids, no_fixed_seeds, vary_seeds_x/y/z, margin_size, csv_mode.
    # Built fresh per call because it is mutated below.
    script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
    script_name = None
    axis_count = 0

    prompt_parts = []
    last_index = 0
    for match in _KEY_PATTERN.finditer(input_string):
        # Skip "word:" tokens that lie inside an already-consumed value
        # (e.g. the "text:" in "ng:bad,text:x"); they belong to that value.
        if match.start() < last_index:
            continue

        preceding = input_string[last_index:match.start()].strip()
        if preceding:
            prompt_parts.append(preceding)

        raw_key = match.group(1)
        value_start = match.end()
        # Always matches because of the "$" alternative.
        value_end = _VALUE_END_PATTERN.search(input_string, value_start).start()
        value = input_string[value_start:value_end].strip()
        last_index = value_end

        lowered = raw_key.lower()
        key = _KEY_ALIASES.get(lowered, lowered)
        if key in default_payload:
            payload[key] = _coerce_value(value, default_payload[key])
        elif key in _XYZ_AXIS_TYPES:
            script_name = "x/y/z plot"
            if axis_count < len(_XYZ_AXIS_SLOTS):
                type_pos, values_pos = _XYZ_AXIS_SLOTS[axis_count]
                script_args[type_pos] = _XYZ_AXIS_TYPES[key]
                script_args[values_pos] = value
                axis_count += 1
        elif key in _XYZ_FLAGS:
            script_name = "x/y/z plot"
            flag_pos, flag_value = _XYZ_FLAGS[key]
            script_args[flag_pos] = flag_value
        else:
            # Not a bot option (e.g. prompt weighting like "cat:1.2"): keep it.
            prompt_parts.append(f"{raw_key}:{value}")

    payload["prompt"] = " ".join(prompt_parts).strip()
    if not payload["prompt"]:
        payload["prompt"] = input_string.strip()

    if script_name:
        payload["script_name"] = script_name
        payload["script_args"] = script_args

    return payload
|
|
def create_caption(payload, user_name, user_id, info):
    """Build the Markdown photo caption for a generated image.

    The caption links to the requesting user, then shows in bold the seed
    parsed from the WebUI generation-info text (when present) and the prompt.

    Args:
        payload: request payload; only ``payload["prompt"]`` is read.
        user_name: display name shown in the user link.
        user_id: Telegram user id used in the ``tg://user`` link.
        info: WebUI generation-info text, e.g. "Steps: 3, Sampler: Euler,
            CFG scale: 7.0, Seed: 4094161400, Size: 512x512, ...". May be None.

    Returns:
        str: the caption, at most 1024 characters long.
    """
    max_len = 1024
    prompt = payload["prompt"]
    print(prompt)
    print(info)

    caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"

    match = re.search(r"Seed: (\d+)", info or "")
    if match:
        seed_value = match.group(1)
        print(f"Seed value: {seed_value}")
        caption += f"**{seed_value}**\n"
    else:
        print("Seed value not found in the info string.")

    # Shorten the prompt rather than the finished caption so the closing "**"
    # survives; cutting the whole string could leave bold markup unterminated.
    room = max_len - len(caption) - len("****\n")
    if len(prompt) > room:
        prompt = prompt[:max(room - 3, 0)] + "..."

    return caption + f"**{prompt}**\n"
|
|
|
|
def call_api(api_endpoint, payload, timeout=600):
    """POST *payload* as JSON to ``{SD_URL}/{api_endpoint}``; return the decoded JSON.

    Args:
        api_endpoint: path relative to SD_URL, e.g. "sdapi/v1/txt2img".
        payload: JSON-serialisable request body.
        timeout: seconds to wait on the connection and for the response before
            giving up. Generation can be slow, so the default is generous; it
            mainly stops a stalled server from hanging the handler forever.

    Returns:
        The parsed JSON response, or None on a network error, an HTTP error
        status, a timeout or a non-JSON body (the error is printed).
    """
    try:
        response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decode errors on older requests versions.
        print(f"API call failed: {e}")
        return None
|
|
|
|
def process_images(images, user_id, user_name):
    """Save WebUI base64 images as PNG files with their generation info embedded.

    For each image, the WebUI ``/sdapi/v1/png-info`` endpoint is asked for the
    generation parameters, which are stored in the PNG's "parameters" text
    chunk. The file is written to ``{IMAGE_PATH}/{name}.png``.

    Args:
        images: base64-encoded PNG strings, with or without a
            ``data:...;base64,`` prefix.
        user_id: not used; kept for caller compatibility.
        user_name: display name used as the file-name prefix (sanitised).

    Returns:
        tuple: ``(name, info)`` for the last image saved. ``name`` is the file
        stem under IMAGE_PATH and ``info`` the generation-info string. Both are
        None when *images* is empty; ``info`` is None if the server sent none.

    Raises:
        requests.RequestException: if the png-info request fails.
    """
    # user_name comes from Telegram and may contain "/", ".." or other
    # characters that are unsafe in a file path; keep only word chars and "-".
    safe_name = re.sub(r"[^\w-]+", "_", user_name or "user")

    word, info = None, None
    for encoded in images:
        # The base64 payload follows the comma of a data URL; plain base64
        # (no comma) is used as-is.
        b64_data = encoded.split(",", 1)[-1]
        image = Image.open(io.BytesIO(base64.b64decode(b64_data)))

        response = requests.post(
            f"{SD_URL}/sdapi/v1/png-info",
            json={"image": "data:image/png;base64," + b64_data},
            timeout=60,
        )
        response.raise_for_status()
        info = response.json().get("info")

        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info or "")

        # A fresh name per image so several images never overwrite each other.
        word = f"{safe_name}-{str(uuid.uuid4())[:7]}"
        image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)

    return word, info
|
|
|
|
@app.on_message(filters.command(["draw"]))
def draw(client, message):
    """Handle ``/draw <prompt> [key:value ...]``: run txt2img and reply with the images.

    Without a prompt, the usage text is sent instead. A "please wait" message
    is shown while the request runs and is always removed afterwards, even if
    saving or sending an image fails.
    """
    # Commands may also arrive as a media caption, in which case text is None.
    text = message.text or message.caption or ""
    parts = text.split(maxsplit=1)
    if len(parts) < 2:
        message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >")
        return

    payload = parse_input(parts[1])
    print(payload)

    # from_user is None for messages sent on behalf of a chat (e.g. anonymous admins).
    user = message.from_user
    user_id = user.id if user else 0
    user_name = user.first_name if user else "anonymous"

    wait_msg = message.reply_text("Please Wait 10-15 Seconds")
    try:
        result = call_api('sdapi/v1/txt2img', payload)
        if not result:
            message.reply_text("Failed to generate image. Please try again later.")
            return
        for image in result.get("images", []):
            word, info = process_images([image], user_id, user_name)
            caption = create_caption(payload, user_name, user_id, info)
            message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
    except requests.RequestException as e:
        # Raised by the png-info request inside process_images.
        print(f"Processing generated image failed: {e}")
        message.reply_text("Failed to generate image. Please try again later.")
    finally:
        wait_msg.delete()
|
|
|
|
@app.on_message(filters.command(["img"]))
def img2img(client, message):
    """Handle ``/img <prompt> [key:value ...]`` sent as a reply to a photo.

    The replied-to photo is downloaded, base64-encoded and used as the
    ``init_images`` of a WebUI img2img request. Each resulting image is saved
    and sent back with a caption. The downloaded file is always deleted, and
    the "please wait" message is always removed, even when a step fails.
    """
    usage = (
        "Reply to an image with\n`/img < prompt > ds:0-1.0`\n\n"
        "ds stands for `Denoising_strength` parameter. Set that low (like 0.2) "
        "if you just want to slightly change things. defaults to 0.35\n\n"
        "Example: `/img murder on the dance floor ds:0.2`"
    )

    reply = message.reply_to_message
    if not reply or not reply.photo:
        message.reply_text(usage)
        return

    # Commands may also arrive as a media caption, in which case text is None.
    text = message.text or message.caption or ""
    parts = text.split(maxsplit=1)
    if len(parts) < 2:
        # No prompt given: show how to use the command.
        message.reply_text(usage)
        return

    payload = parse_input(parts[1])
    print(f"input:\n{payload}")

    photo_file = client.download_media(reply.photo)
    if not photo_file:
        message.reply_text("Failed to process image. Please try again later.")
        return
    try:
        payload["init_images"] = [encode_file_to_base64(photo_file)]
    finally:
        os.remove(photo_file)  # Clean up downloaded image file, even on errors

    # from_user is None for messages sent on behalf of a chat (e.g. anonymous admins).
    user = message.from_user
    user_id = user.id if user else 0
    user_name = user.first_name if user else "anonymous"

    wait_msg = message.reply_text("Please Wait 10-15 Seconds")
    try:
        result = call_api('sdapi/v1/img2img', payload)
        if not result:
            message.reply_text("Failed to process image. Please try again later.")
            return
        for image in result.get("images", []):
            word, info = process_images([image], user_id, user_name)
            caption = create_caption(payload, user_name, user_id, info)
            message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
    except requests.RequestException as e:
        # Raised by the png-info request inside process_images.
        print(f"Processing generated image failed: {e}")
        message.reply_text("Failed to process image. Please try again later.")
    finally:
        wait_msg.delete()
|
|
|
|
@app.on_message(filters.command(["getmodels"]))
async def get_models(client, message):
    """Handle ``/getmodels``: list the WebUI checkpoints as inline buttons.

    Each button shows the model's ``title``. Its callback data is the model's
    ``model_name``, which the callback-query handler uses to switch checkpoints.
    """
    loop = asyncio.get_running_loop()
    try:
        # requests is blocking: run it in a worker thread so the bot's event
        # loop (and every other async handler) keeps running meanwhile.
        response = await loop.run_in_executor(
            None, lambda: requests.get(f"{SD_URL}/sdapi/v1/sd-models", timeout=30)
        )
        response.raise_for_status()
        models_json = response.json()
    except (requests.RequestException, ValueError) as e:
        await message.reply_text(f"Failed to get models: {e}")
        return

    # Telegram accepts at most 64 bytes of callback data per button; a longer
    # model name would make the whole reply fail, so such models are skipped.
    buttons = [
        [InlineKeyboardButton(model["title"], callback_data=model["model_name"])]
        for model in models_json
        if len(model["model_name"].encode("utf-8")) <= 64
    ]
    if not buttons:
        await message.reply_text("No selectable models found.")
        return

    await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(buttons))
|
|
|
|
@app.on_callback_query()
async def process_callback(client, callback_query):
    """Switch the WebUI checkpoint to the model named in the pressed button.

    The callback data is sent as the ``sd_model_checkpoint`` option. Note that
    no authorisation check is done: anyone who can press the button changes
    the checkpoint for all users of the server.
    """
    sd_model_checkpoint = callback_query.data
    options = {"sd_model_checkpoint": sd_model_checkpoint}

    # Acknowledge the press immediately so the client stops showing a loading
    # indicator; the options request may take a while if the server reloads
    # the checkpoint before responding.
    await callback_query.answer()

    loop = asyncio.get_running_loop()
    try:
        # Run the blocking HTTP call off the event loop.
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(f"{SD_URL}/sdapi/v1/options", json=options, timeout=600),
        )
        response.raise_for_status()
    except requests.RequestException as e:
        await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
        return

    await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
|
|
|
|
@app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message):
    """Handle ``/info_sd_bot``: describe the X/Y/Z-plot keys /draw and /img accept."""
    await message.reply_text("""
Now supports X/Y/Z plot scripts, see the [sd wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#xyz-plot)!

Axes (up to three, filled in the order given):
`xsr:` - search/replace text or emoji in the prompt, more info [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-sr)
`xsteps:` - steps
`xcfg:` - CFG scale
`xds:` - denoising strength, only valid for img2img

Options:
`nl:` - hide the legend
`ks:` - keep the individual sub images
`rs:` - random seed for each sub image

Every key needs its trailing colon, e.g.
`/draw a castle xsteps:10,20,30 xcfg:5,7 nl:`

**note** limit the overall `steps:` to a lower value (10-20) for big xyz plots
""")
|
|
|
|
# Start the bot only when executed as a script, not when imported.
if __name__ == "__main__":
    app.run()
|