Compare commits

...

2 Commits

Author SHA1 Message Date
tami-p40
9f8bff3540 cfg for XL 2024-06-09 10:39:34 +03:00
tami-p40
ec900759a1 get model 2024-06-03 14:14:25 +03:00
2 changed files with 35 additions and 18 deletions

1
.gitignore vendored
View File

@ -6,3 +6,4 @@ venv/
*.session-journal
logs/stable_diff_telegram_bot.log
*.session
images/

52
main.py
View File

@ -3,6 +3,7 @@ import re
import io
import uuid
import base64
import json
import requests
from datetime import datetime
from PIL import Image, PngImagePlugin
@ -39,17 +40,32 @@ model_negative_prompts = {
"Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark"
}
def get_current_model_name():
try:
response = requests.get(f"{SD_URL}/sdapi/v1/options")
response.raise_for_status()
options = response.json()
current_model_name = options.get("sd_model_checkpoint", "Unknown")
return current_model_name
except requests.RequestException as e:
print(f"API call failed: {e}")
return None
# Fetch the current model name at the start
current_model_name = get_current_model_name()
if current_model_name:
print(f"Current model name: {current_model_name}")
else:
print("Failed to fetch the current model name.")
def encode_file_to_base64(path):
with open(path, 'rb') as file:
return base64.b64encode(file.read()).decode('utf-8')
def decode_and_save_base64(base64_str, save_path):
with open(save_path, "wb") as file:
file.write(base64.b64decode(base64_str))
# Set default payload values
default_payload = {
"prompt": "",
@ -69,14 +85,12 @@ default_payload = {
"override_settings_restore_afterwards": True,
}
def update_negative_prompt(model_name):
"""Update the negative prompt for a given model."""
if model_name in model_negative_prompts:
suffix = model_negative_prompts[model_name]
default_payload["negative_prompt"] += f", {suffix}"
def update_resolution(model_name):
"""Update resolution based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
@ -86,14 +100,20 @@ def update_resolution(model_name):
default_payload["width"] = 512
default_payload["height"] = 512
def update_cfg_scale(model_name):
"""Update CFG scale based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
default_payload["cfg_scale"] = 1
default_payload["cfg_scale"] = 2.5
else:
default_payload["cfg_scale"] = 7
# Update configurations based on the current model name
if current_model_name:
update_negative_prompt(current_model_name)
update_resolution(current_model_name)
update_cfg_scale(current_model_name)
else:
print("Failed to update configurations as the current model name is not available.")
def parse_input(input_string):
"""Parse the input string and create a payload."""
@ -170,7 +190,6 @@ def parse_input(input_string):
return payload, include_info
def create_caption(payload, user_name, user_id, info, include_info):
"""Create a caption for the generated image."""
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
@ -194,7 +213,6 @@ def create_caption(payload, user_name, user_id, info, include_info):
return caption
def call_api(api_endpoint, payload):
"""Call the API with the provided payload."""
try:
@ -205,12 +223,12 @@ def call_api(api_endpoint, payload):
print(f"API call failed: {e}")
return {"error": str(e)}
def process_images(images, user_id, user_name):
"""Process and save generated images."""
def generate_unique_name():
unique_id = str(uuid.uuid4())[:7]
return f"{user_name}-{unique_id}"
date = datetime.now().strftime("%Y-%m-%d-%H-%M")
return f"{date}-{user_name}-{unique_id}"
word = generate_unique_name()
@ -219,7 +237,11 @@ def process_images(images, user_id, user_name):
png_payload = {"image": "data:image/png;base64," + i}
response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
response2.raise_for_status()
# Write response2 json next to the image
with open(f"{IMAGE_PATH}/{word}.json", "w") as json_file:
json.dump(response2.json(), json_file)
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("parameters", response2.json().get("info"))
image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)
@ -227,9 +249,8 @@ def process_images(images, user_id, user_name):
# Save as JPG
jpg_path = f"{IMAGE_PATH}/{word}.jpg"
image.convert("RGB").save(jpg_path, "JPEG")
return word, response2.json().get("info")
return word, response2.json().get("info")
@app.on_message(filters.command(["draw"]))
def draw(client, message):
@ -259,7 +280,6 @@ def draw(client, message):
message.reply_text(error_message)
K.delete()
@app.on_message(filters.command(["img"]))
def img2img(client, message):
"""Handle /img command to generate images from existing images."""
@ -294,7 +314,6 @@ def img2img(client, message):
message.reply_text(error_message)
K.delete()
@app.on_message(filters.command(["getmodels"]))
async def get_models(client, message):
"""Handle /getmodels command to list available models."""
@ -310,7 +329,6 @@ async def get_models(client, message):
except requests.RequestException as e:
await message.reply_text(f"Failed to get models: {e}")
@app.on_callback_query()
async def process_callback(client, callback_query):
"""Process model selection from callback queries."""
@ -330,7 +348,6 @@ async def process_callback(client, callback_query):
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
print(f"Error setting checkpoint: {e}")
@app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message):
"""Provide information about the bot's commands and options."""
@ -391,5 +408,4 @@ For more details, visit the [Stable Diffusion Wiki](https://github.com/AUTOMATIC
Enjoy creating with Stable Diffusion Bot!
""", disable_web_page_preview=True)
app.run()