Compare commits

...

2 Commits

Author SHA1 Message Date
tami-p40
9f8bff3540 cfg for XL 2024-06-09 10:39:34 +03:00
tami-p40
ec900759a1 get model 2024-06-03 14:14:25 +03:00
2 changed files with 35 additions and 18 deletions

1
.gitignore vendored
View File

@ -6,3 +6,4 @@ venv/
*.session-journal *.session-journal
logs/stable_diff_telegram_bot.log logs/stable_diff_telegram_bot.log
*.session *.session
images/

52
main.py
View File

@ -3,6 +3,7 @@ import re
import io import io
import uuid import uuid
import base64 import base64
import json
import requests import requests
from datetime import datetime from datetime import datetime
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
@ -39,17 +40,32 @@ model_negative_prompts = {
"Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark" "Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark"
} }
def get_current_model_name():
try:
response = requests.get(f"{SD_URL}/sdapi/v1/options")
response.raise_for_status()
options = response.json()
current_model_name = options.get("sd_model_checkpoint", "Unknown")
return current_model_name
except requests.RequestException as e:
print(f"API call failed: {e}")
return None
# Fetch the current model name at the start
current_model_name = get_current_model_name()
if current_model_name:
print(f"Current model name: {current_model_name}")
else:
print("Failed to fetch the current model name.")
def encode_file_to_base64(path): def encode_file_to_base64(path):
with open(path, 'rb') as file: with open(path, 'rb') as file:
return base64.b64encode(file.read()).decode('utf-8') return base64.b64encode(file.read()).decode('utf-8')
def decode_and_save_base64(base64_str, save_path): def decode_and_save_base64(base64_str, save_path):
with open(save_path, "wb") as file: with open(save_path, "wb") as file:
file.write(base64.b64decode(base64_str)) file.write(base64.b64decode(base64_str))
# Set default payload values # Set default payload values
default_payload = { default_payload = {
"prompt": "", "prompt": "",
@ -69,14 +85,12 @@ default_payload = {
"override_settings_restore_afterwards": True, "override_settings_restore_afterwards": True,
} }
def update_negative_prompt(model_name): def update_negative_prompt(model_name):
"""Update the negative prompt for a given model.""" """Update the negative prompt for a given model."""
if model_name in model_negative_prompts: if model_name in model_negative_prompts:
suffix = model_negative_prompts[model_name] suffix = model_negative_prompts[model_name]
default_payload["negative_prompt"] += f", {suffix}" default_payload["negative_prompt"] += f", {suffix}"
def update_resolution(model_name): def update_resolution(model_name):
"""Update resolution based on the selected model.""" """Update resolution based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2": if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
@ -86,14 +100,20 @@ def update_resolution(model_name):
default_payload["width"] = 512 default_payload["width"] = 512
default_payload["height"] = 512 default_payload["height"] = 512
def update_cfg_scale(model_name): def update_cfg_scale(model_name):
"""Update CFG scale based on the selected model.""" """Update CFG scale based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2": if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
default_payload["cfg_scale"] = 1 default_payload["cfg_scale"] = 2.5
else: else:
default_payload["cfg_scale"] = 7 default_payload["cfg_scale"] = 7
# Update configurations based on the current model name
if current_model_name:
update_negative_prompt(current_model_name)
update_resolution(current_model_name)
update_cfg_scale(current_model_name)
else:
print("Failed to update configurations as the current model name is not available.")
def parse_input(input_string): def parse_input(input_string):
"""Parse the input string and create a payload.""" """Parse the input string and create a payload."""
@ -170,7 +190,6 @@ def parse_input(input_string):
return payload, include_info return payload, include_info
def create_caption(payload, user_name, user_id, info, include_info): def create_caption(payload, user_name, user_id, info, include_info):
"""Create a caption for the generated image.""" """Create a caption for the generated image."""
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n" caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
@ -194,7 +213,6 @@ def create_caption(payload, user_name, user_id, info, include_info):
return caption return caption
def call_api(api_endpoint, payload): def call_api(api_endpoint, payload):
"""Call the API with the provided payload.""" """Call the API with the provided payload."""
try: try:
@ -205,12 +223,12 @@ def call_api(api_endpoint, payload):
print(f"API call failed: {e}") print(f"API call failed: {e}")
return {"error": str(e)} return {"error": str(e)}
def process_images(images, user_id, user_name): def process_images(images, user_id, user_name):
"""Process and save generated images.""" """Process and save generated images."""
def generate_unique_name(): def generate_unique_name():
unique_id = str(uuid.uuid4())[:7] unique_id = str(uuid.uuid4())[:7]
return f"{user_name}-{unique_id}" date = datetime.now().strftime("%Y-%m-%d-%H-%M")
return f"{date}-{user_name}-{unique_id}"
word = generate_unique_name() word = generate_unique_name()
@ -219,7 +237,11 @@ def process_images(images, user_id, user_name):
png_payload = {"image": "data:image/png;base64," + i} png_payload = {"image": "data:image/png;base64," + i}
response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload) response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
response2.raise_for_status() response2.raise_for_status()
# Write response2 json next to the image
with open(f"{IMAGE_PATH}/{word}.json", "w") as json_file:
json.dump(response2.json(), json_file)
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("parameters", response2.json().get("info")) pnginfo.add_text("parameters", response2.json().get("info"))
image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo) image.save(f"{IMAGE_PATH}/{word}.png", pnginfo=pnginfo)
@ -227,9 +249,8 @@ def process_images(images, user_id, user_name):
# Save as JPG # Save as JPG
jpg_path = f"{IMAGE_PATH}/{word}.jpg" jpg_path = f"{IMAGE_PATH}/{word}.jpg"
image.convert("RGB").save(jpg_path, "JPEG") image.convert("RGB").save(jpg_path, "JPEG")
return word, response2.json().get("info")
return word, response2.json().get("info")
@app.on_message(filters.command(["draw"])) @app.on_message(filters.command(["draw"]))
def draw(client, message): def draw(client, message):
@ -259,7 +280,6 @@ def draw(client, message):
message.reply_text(error_message) message.reply_text(error_message)
K.delete() K.delete()
@app.on_message(filters.command(["img"])) @app.on_message(filters.command(["img"]))
def img2img(client, message): def img2img(client, message):
"""Handle /img command to generate images from existing images.""" """Handle /img command to generate images from existing images."""
@ -294,7 +314,6 @@ def img2img(client, message):
message.reply_text(error_message) message.reply_text(error_message)
K.delete() K.delete()
@app.on_message(filters.command(["getmodels"])) @app.on_message(filters.command(["getmodels"]))
async def get_models(client, message): async def get_models(client, message):
"""Handle /getmodels command to list available models.""" """Handle /getmodels command to list available models."""
@ -310,7 +329,6 @@ async def get_models(client, message):
except requests.RequestException as e: except requests.RequestException as e:
await message.reply_text(f"Failed to get models: {e}") await message.reply_text(f"Failed to get models: {e}")
@app.on_callback_query() @app.on_callback_query()
async def process_callback(client, callback_query): async def process_callback(client, callback_query):
"""Process model selection from callback queries.""" """Process model selection from callback queries."""
@ -330,7 +348,6 @@ async def process_callback(client, callback_query):
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}") await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
print(f"Error setting checkpoint: {e}") print(f"Error setting checkpoint: {e}")
@app.on_message(filters.command(["info_sd_bot"])) @app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message): async def info(client, message):
"""Provide information about the bot's commands and options.""" """Provide information about the bot's commands and options."""
@ -391,5 +408,4 @@ For more details, visit the [Stable Diffusion Wiki](https://github.com/AUTOMATIC
Enjoy creating with Stable Diffusion Bot! Enjoy creating with Stable Diffusion Bot!
""", disable_web_page_preview=True) """, disable_web_page_preview=True)
app.run() app.run()