size limit

This commit is contained in:
tami-p40 2024-06-03 13:35:14 +03:00
parent f62d6a6dbc
commit 49c9f6337a

84
main.py
View File

@ -17,6 +17,10 @@ API_HASH = os.environ.get("API_HASH")
TOKEN = os.environ.get("TOKEN_givemtxt2img") TOKEN = os.environ.get("TOKEN_givemtxt2img")
SD_URL = os.environ.get("SD_URL") SD_URL = os.environ.get("SD_URL")
# Ensure all required environment variables are loaded
if not all([API_ID, API_HASH, TOKEN, SD_URL]):
raise EnvironmentError("Missing one or more required environment variables: API_ID, API_HASH, TOKEN, SD_URL")
app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN) app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
IMAGE_PATH = 'images' IMAGE_PATH = 'images'
@ -25,13 +29,14 @@ os.makedirs(IMAGE_PATH, exist_ok=True)
# Model-specific embeddings for negative prompts # Model-specific embeddings for negative prompts
model_negative_prompts = { model_negative_prompts = {
"coloringPage_v10": "fake",
"Anything-Diffusion": "", "Anything-Diffusion": "",
"Deliberate": "", "Deliberate": "",
"Dreamshaper": "", "Dreamshaper": "",
"DreamShaperXL_Lightning": "", "DreamShaperXL_Lightning": "",
"realisticVisionV60B1_v51VAE": "realisticvision-negative-embedding", "realisticVisionV60B1_v51VAE": "realisticvision-negative-embedding",
"v1-5-pruned-emaonly": "", "v1-5-pruned-emaonly": "",
"Juggernaut-XL_v9_RunDiffusionPhoto_v2":"bad eyes, cgi, airbrushed, plastic, watermark" "Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark"
} }
@ -39,15 +44,17 @@ def encode_file_to_base64(path):
with open(path, 'rb') as file: with open(path, 'rb') as file:
return base64.b64encode(file.read()).decode('utf-8') return base64.b64encode(file.read()).decode('utf-8')
def decode_and_save_base64(base64_str, save_path): def decode_and_save_base64(base64_str, save_path):
with open(save_path, "wb") as file: with open(save_path, "wb") as file:
file.write(base64.b64decode(base64_str)) file.write(base64.b64decode(base64_str))
# Set default payload values # Set default payload values
default_payload = { default_payload = {
"prompt": "", "prompt": "",
"seed": -1, # Random seed "seed": -1, # Random seed
"negative_prompt": "extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs", "negative_prompt": "extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs",
"enable_hr": False, "enable_hr": False,
"Sampler": "DPM++ SDE Karras", "Sampler": "DPM++ SDE Karras",
"denoising_strength": 0.35, "denoising_strength": 0.35,
@ -62,12 +69,34 @@ default_payload = {
"override_settings_restore_afterwards": True, "override_settings_restore_afterwards": True,
} }
def update_negative_prompt(model_name): def update_negative_prompt(model_name):
"""Update the negative prompt for a given model."""
if model_name in model_negative_prompts: if model_name in model_negative_prompts:
suffix = model_negative_prompts[model_name] suffix = model_negative_prompts[model_name]
default_payload["negative_prompt"] += f", {suffix}" default_payload["negative_prompt"] += f", {suffix}"
def update_resolution(model_name):
"""Update resolution based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
default_payload["width"] = 832
default_payload["height"] = 1216
else:
default_payload["width"] = 512
default_payload["height"] = 512
def update_cfg_scale(model_name):
"""Update CFG scale based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
default_payload["cfg_scale"] = 1
else:
default_payload["cfg_scale"] = 7
def parse_input(input_string): def parse_input(input_string):
"""Parse the input string and create a payload."""
payload = default_payload.copy() payload = default_payload.copy()
prompt = [] prompt = []
include_info = "info:" in input_string include_info = "info:" in input_string
@ -128,12 +157,7 @@ def parse_input(input_string):
script_args[11] = True # Set random seed to sub images script_args[11] = True # Set random seed to sub images
else: else:
prompt.append(f"{key}:{value}") prompt.append(f"{key}:{value}")
# Adjust dimensions for specific model
if "Juggernaut-XL_v9_RunDiffusionPhoto_v2" in input_string:
payload["width"] = 823
payload["height"] = 1216
last_index = value_end_index last_index = value_end_index
payload["prompt"] = " ".join(prompt).strip() payload["prompt"] = " ".join(prompt).strip()
@ -147,18 +171,13 @@ def parse_input(input_string):
return payload, include_info return payload, include_info
def create_caption(payload, user_name, user_id, info, include_info): def create_caption(payload, user_name, user_id, info, include_info):
"""Create a caption for the generated image."""
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n" caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
prompt = payload["prompt"] prompt = payload["prompt"]
# Define a regular expression pattern to match the seed value
seed_pattern = r"Seed: (\d+)" seed_pattern = r"Seed: (\d+)"
# Search for the pattern in the info string
match = re.search(seed_pattern, info) match = re.search(seed_pattern, info)
# Check if a match was found and extract the seed value
if match: if match:
seed_value = match.group(1) seed_value = match.group(1)
caption += f"**{seed_value}**\n" caption += f"**{seed_value}**\n"
@ -177,15 +196,18 @@ def create_caption(payload, user_name, user_id, info, include_info):
def call_api(api_endpoint, payload): def call_api(api_endpoint, payload):
"""Call the API with the provided payload."""
try: try:
response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload) response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except requests.RequestException as e: except requests.RequestException as e:
print(f"API call failed: {e}") print(f"API call failed: {e}")
return None return {"error": str(e)}
def process_images(images, user_id, user_name): def process_images(images, user_id, user_name):
"""Process and save generated images."""
def generate_unique_name(): def generate_unique_name():
unique_id = str(uuid.uuid4())[:7] unique_id = str(uuid.uuid4())[:7]
return f"{user_name}-{unique_id}" return f"{user_name}-{unique_id}"
@ -202,11 +224,16 @@ def process_images(images, user_id, user_name):
pnginfo.add_text("parameters", response2.json().get("info")) pnginfo.add_text("parameters", response2.json().get("info"))
image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo) image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)
# Save as JPG
jpg_path = f"{IMAGE_PATH}/{word}.jpg"
image.convert("RGB").save(jpg_path, "JPEG")
return word, response2.json().get("info") return word, response2.json().get("info")
@app.on_message(filters.command(["draw"])) @app.on_message(filters.command(["draw"]))
def draw(client, message): def draw(client, message):
"""Handle /draw command to generate images from text prompts."""
msgs = message.text.split(" ", 1) msgs = message.text.split(" ", 1)
if len(msgs) == 1: if len(msgs) == 1:
message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >") message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >")
@ -214,7 +241,6 @@ def draw(client, message):
payload, include_info = parse_input(msgs[1]) payload, include_info = parse_input(msgs[1])
# Check if xds is used in the payload
if "xds" in msgs[1].lower(): if "xds" in msgs[1].lower():
message.reply_text("`xds` key cannot be used in the `/draw` command. Use `/img` instead.") message.reply_text("`xds` key cannot be used in the `/draw` command. Use `/img` instead.")
return return
@ -222,25 +248,28 @@ def draw(client, message):
K = message.reply_text("Please Wait 10-15 Seconds") K = message.reply_text("Please Wait 10-15 Seconds")
r = call_api('sdapi/v1/txt2img', payload) r = call_api('sdapi/v1/txt2img', payload)
if r: if r and "images" in r:
for i in r["images"]: for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name) word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info) caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info)
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) message.reply_photo(photo=f"{IMAGE_PATH}/{word}.jpg", caption=caption)
K.delete() K.delete()
else: else:
message.reply_text("Failed to generate image. Please try again later.") error_message = r.get("error", "Failed to generate image. Please try again later.")
message.reply_text(error_message)
K.delete() K.delete()
@app.on_message(filters.command(["img"])) @app.on_message(filters.command(["img"]))
def img2img(client, message): def img2img(client, message):
"""Handle /img command to generate images from existing images."""
if not message.reply_to_message or not message.reply_to_message.photo: if not message.reply_to_message or not message.reply_to_message.photo:
message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`") message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`")
return return
msgs = message.text.split(" ", 1) msgs = message.text.split(" ", 1)
if len(msgs) == 1: if len(msgs) == 1:
message.reply_text("dont FAIL in life") message.reply_text("Don't FAIL in life")
return return
payload, include_info = parse_input(msgs[1]) payload, include_info = parse_input(msgs[1])
@ -254,19 +283,21 @@ def img2img(client, message):
K = message.reply_text("Please Wait 10-15 Seconds") K = message.reply_text("Please Wait 10-15 Seconds")
r = call_api('sdapi/v1/img2img', payload) r = call_api('sdapi/v1/img2img', payload)
if r: if r and "images" in r:
for i in r["images"]: for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name) word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info) caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info)
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) message.reply_photo(photo=f"{IMAGE_PATH}/{word}.jpg", caption=caption)
K.delete() K.delete()
else: else:
message.reply_text("Failed to process image. Please try again later.") error_message = r.get("error", "Failed to process image. Please try again later.")
message.reply_text(error_message)
K.delete() K.delete()
@app.on_message(filters.command(["getmodels"])) @app.on_message(filters.command(["getmodels"]))
async def get_models(client, message): async def get_models(client, message):
"""Handle /getmodels command to list available models."""
try: try:
response = requests.get(f"{SD_URL}/sdapi/v1/sd-models") response = requests.get(f"{SD_URL}/sdapi/v1/sd-models")
response.raise_for_status() response.raise_for_status()
@ -279,8 +310,10 @@ async def get_models(client, message):
except requests.RequestException as e: except requests.RequestException as e:
await message.reply_text(f"Failed to get models: {e}") await message.reply_text(f"Failed to get models: {e}")
@app.on_callback_query() @app.on_callback_query()
async def process_callback(client, callback_query): async def process_callback(client, callback_query):
"""Process model selection from callback queries."""
sd_model_checkpoint = callback_query.data sd_model_checkpoint = callback_query.data
options = {"sd_model_checkpoint": sd_model_checkpoint} options = {"sd_model_checkpoint": sd_model_checkpoint}
@ -288,16 +321,19 @@ async def process_callback(client, callback_query):
response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options) response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options)
response.raise_for_status() response.raise_for_status()
# Update the negative prompt based on the selected model
update_negative_prompt(sd_model_checkpoint) update_negative_prompt(sd_model_checkpoint)
update_resolution(sd_model_checkpoint)
update_cfg_scale(sd_model_checkpoint)
await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}") await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
except requests.RequestException as e: except requests.RequestException as e:
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}") await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
print(f"Error setting checkpoint: {e}") print(f"Error setting checkpoint: {e}")
@app.on_message(filters.command(["info_sd_bot"])) @app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message): async def info(client, message):
"""Provide information about the bot's commands and options."""
await message.reply_text(""" await message.reply_text("""
**Stable Diffusion Bot Commands and Options:** **Stable Diffusion Bot Commands and Options:**