size limit

This commit is contained in:
tami-p40 2024-06-03 13:35:14 +03:00
parent f62d6a6dbc
commit 49c9f6337a

80
main.py
View File

@ -17,6 +17,10 @@ API_HASH = os.environ.get("API_HASH")
TOKEN = os.environ.get("TOKEN_givemtxt2img") TOKEN = os.environ.get("TOKEN_givemtxt2img")
SD_URL = os.environ.get("SD_URL") SD_URL = os.environ.get("SD_URL")
# Ensure all required environment variables are loaded
if not all([API_ID, API_HASH, TOKEN, SD_URL]):
raise EnvironmentError("Missing one or more required environment variables: API_ID, API_HASH, TOKEN, SD_URL")
app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN) app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
IMAGE_PATH = 'images' IMAGE_PATH = 'images'
@ -25,13 +29,14 @@ os.makedirs(IMAGE_PATH, exist_ok=True)
# Model-specific embeddings for negative prompts # Model-specific embeddings for negative prompts
model_negative_prompts = { model_negative_prompts = {
"coloringPage_v10": "fake",
"Anything-Diffusion": "", "Anything-Diffusion": "",
"Deliberate": "", "Deliberate": "",
"Dreamshaper": "", "Dreamshaper": "",
"DreamShaperXL_Lightning": "", "DreamShaperXL_Lightning": "",
"realisticVisionV60B1_v51VAE": "realisticvision-negative-embedding", "realisticVisionV60B1_v51VAE": "realisticvision-negative-embedding",
"v1-5-pruned-emaonly": "", "v1-5-pruned-emaonly": "",
"Juggernaut-XL_v9_RunDiffusionPhoto_v2":"bad eyes, cgi, airbrushed, plastic, watermark" "Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark"
} }
@ -39,10 +44,12 @@ def encode_file_to_base64(path):
with open(path, 'rb') as file: with open(path, 'rb') as file:
return base64.b64encode(file.read()).decode('utf-8') return base64.b64encode(file.read()).decode('utf-8')
def decode_and_save_base64(base64_str, save_path): def decode_and_save_base64(base64_str, save_path):
with open(save_path, "wb") as file: with open(save_path, "wb") as file:
file.write(base64.b64decode(base64_str)) file.write(base64.b64decode(base64_str))
# Set default payload values # Set default payload values
default_payload = { default_payload = {
"prompt": "", "prompt": "",
@ -62,12 +69,34 @@ default_payload = {
"override_settings_restore_afterwards": True, "override_settings_restore_afterwards": True,
} }
def update_negative_prompt(model_name): def update_negative_prompt(model_name):
"""Update the negative prompt for a given model."""
if model_name in model_negative_prompts: if model_name in model_negative_prompts:
suffix = model_negative_prompts[model_name] suffix = model_negative_prompts[model_name]
default_payload["negative_prompt"] += f", {suffix}" default_payload["negative_prompt"] += f", {suffix}"
def update_resolution(model_name):
"""Update resolution based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
default_payload["width"] = 832
default_payload["height"] = 1216
else:
default_payload["width"] = 512
default_payload["height"] = 512
def update_cfg_scale(model_name):
"""Update CFG scale based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
default_payload["cfg_scale"] = 1
else:
default_payload["cfg_scale"] = 7
def parse_input(input_string): def parse_input(input_string):
"""Parse the input string and create a payload."""
payload = default_payload.copy() payload = default_payload.copy()
prompt = [] prompt = []
include_info = "info:" in input_string include_info = "info:" in input_string
@ -129,11 +158,6 @@ def parse_input(input_string):
else: else:
prompt.append(f"{key}:{value}") prompt.append(f"{key}:{value}")
# Adjust dimensions for specific model
if "Juggernaut-XL_v9_RunDiffusionPhoto_v2" in input_string:
payload["width"] = 823
payload["height"] = 1216
last_index = value_end_index last_index = value_end_index
payload["prompt"] = " ".join(prompt).strip() payload["prompt"] = " ".join(prompt).strip()
@ -147,18 +171,13 @@ def parse_input(input_string):
return payload, include_info return payload, include_info
def create_caption(payload, user_name, user_id, info, include_info): def create_caption(payload, user_name, user_id, info, include_info):
"""Create a caption for the generated image."""
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n" caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
prompt = payload["prompt"] prompt = payload["prompt"]
# Define a regular expression pattern to match the seed value
seed_pattern = r"Seed: (\d+)" seed_pattern = r"Seed: (\d+)"
# Search for the pattern in the info string
match = re.search(seed_pattern, info) match = re.search(seed_pattern, info)
# Check if a match was found and extract the seed value
if match: if match:
seed_value = match.group(1) seed_value = match.group(1)
caption += f"**{seed_value}**\n" caption += f"**{seed_value}**\n"
@ -177,15 +196,18 @@ def create_caption(payload, user_name, user_id, info, include_info):
def call_api(api_endpoint, payload): def call_api(api_endpoint, payload):
"""Call the API with the provided payload."""
try: try:
response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload) response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except requests.RequestException as e: except requests.RequestException as e:
print(f"API call failed: {e}") print(f"API call failed: {e}")
return None return {"error": str(e)}
def process_images(images, user_id, user_name): def process_images(images, user_id, user_name):
"""Process and save generated images."""
def generate_unique_name(): def generate_unique_name():
unique_id = str(uuid.uuid4())[:7] unique_id = str(uuid.uuid4())[:7]
return f"{user_name}-{unique_id}" return f"{user_name}-{unique_id}"
@ -202,11 +224,16 @@ def process_images(images, user_id, user_name):
pnginfo.add_text("parameters", response2.json().get("info")) pnginfo.add_text("parameters", response2.json().get("info"))
image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo) image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)
# Save as JPG
jpg_path = f"{IMAGE_PATH}/{word}.jpg"
image.convert("RGB").save(jpg_path, "JPEG")
return word, response2.json().get("info") return word, response2.json().get("info")
@app.on_message(filters.command(["draw"])) @app.on_message(filters.command(["draw"]))
def draw(client, message): def draw(client, message):
"""Handle /draw command to generate images from text prompts."""
msgs = message.text.split(" ", 1) msgs = message.text.split(" ", 1)
if len(msgs) == 1: if len(msgs) == 1:
message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >") message.reply_text("Format :\n/draw < text to image >\nng: < negative (optional) >\nsteps: < steps value (1-70, optional) >")
@ -214,7 +241,6 @@ def draw(client, message):
payload, include_info = parse_input(msgs[1]) payload, include_info = parse_input(msgs[1])
# Check if xds is used in the payload
if "xds" in msgs[1].lower(): if "xds" in msgs[1].lower():
message.reply_text("`xds` key cannot be used in the `/draw` command. Use `/img` instead.") message.reply_text("`xds` key cannot be used in the `/draw` command. Use `/img` instead.")
return return
@ -222,25 +248,28 @@ def draw(client, message):
K = message.reply_text("Please Wait 10-15 Seconds") K = message.reply_text("Please Wait 10-15 Seconds")
r = call_api('sdapi/v1/txt2img', payload) r = call_api('sdapi/v1/txt2img', payload)
if r: if r and "images" in r:
for i in r["images"]: for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name) word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info) caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info)
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) message.reply_photo(photo=f"{IMAGE_PATH}/{word}.jpg", caption=caption)
K.delete() K.delete()
else: else:
message.reply_text("Failed to generate image. Please try again later.") error_message = r.get("error", "Failed to generate image. Please try again later.")
message.reply_text(error_message)
K.delete() K.delete()
@app.on_message(filters.command(["img"])) @app.on_message(filters.command(["img"]))
def img2img(client, message): def img2img(client, message):
"""Handle /img command to generate images from existing images."""
if not message.reply_to_message or not message.reply_to_message.photo: if not message.reply_to_message or not message.reply_to_message.photo:
message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`") message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`")
return return
msgs = message.text.split(" ", 1) msgs = message.text.split(" ", 1)
if len(msgs) == 1: if len(msgs) == 1:
message.reply_text("dont FAIL in life") message.reply_text("Don't FAIL in life")
return return
payload, include_info = parse_input(msgs[1]) payload, include_info = parse_input(msgs[1])
@ -254,19 +283,21 @@ def img2img(client, message):
K = message.reply_text("Please Wait 10-15 Seconds") K = message.reply_text("Please Wait 10-15 Seconds")
r = call_api('sdapi/v1/img2img', payload) r = call_api('sdapi/v1/img2img', payload)
if r: if r and "images" in r:
for i in r["images"]: for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name) word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info) caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info, include_info)
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) message.reply_photo(photo=f"{IMAGE_PATH}/{word}.jpg", caption=caption)
K.delete() K.delete()
else: else:
message.reply_text("Failed to process image. Please try again later.") error_message = r.get("error", "Failed to process image. Please try again later.")
message.reply_text(error_message)
K.delete() K.delete()
@app.on_message(filters.command(["getmodels"])) @app.on_message(filters.command(["getmodels"]))
async def get_models(client, message): async def get_models(client, message):
"""Handle /getmodels command to list available models."""
try: try:
response = requests.get(f"{SD_URL}/sdapi/v1/sd-models") response = requests.get(f"{SD_URL}/sdapi/v1/sd-models")
response.raise_for_status() response.raise_for_status()
@ -279,8 +310,10 @@ async def get_models(client, message):
except requests.RequestException as e: except requests.RequestException as e:
await message.reply_text(f"Failed to get models: {e}") await message.reply_text(f"Failed to get models: {e}")
@app.on_callback_query() @app.on_callback_query()
async def process_callback(client, callback_query): async def process_callback(client, callback_query):
"""Process model selection from callback queries."""
sd_model_checkpoint = callback_query.data sd_model_checkpoint = callback_query.data
options = {"sd_model_checkpoint": sd_model_checkpoint} options = {"sd_model_checkpoint": sd_model_checkpoint}
@ -288,16 +321,19 @@ async def process_callback(client, callback_query):
response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options) response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options)
response.raise_for_status() response.raise_for_status()
# Update the negative prompt based on the selected model
update_negative_prompt(sd_model_checkpoint) update_negative_prompt(sd_model_checkpoint)
update_resolution(sd_model_checkpoint)
update_cfg_scale(sd_model_checkpoint)
await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}") await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
except requests.RequestException as e: except requests.RequestException as e:
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}") await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
print(f"Error setting checkpoint: {e}") print(f"Error setting checkpoint: {e}")
@app.on_message(filters.command(["info_sd_bot"])) @app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message): async def info(client, message):
"""Provide information about the bot's commands and options."""
await message.reply_text(""" await message.reply_text("""
**Stable Diffusion Bot Commands and Options:** **Stable Diffusion Bot Commands and Options:**