diff --git a/main.py b/main.py
index c405483..57db5ed 100644
--- a/main.py
+++ b/main.py
@@ -39,17 +39,32 @@ model_negative_prompts = {
     "Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark"
 }
 
+def get_current_model_name():
+    try:
+        response = requests.get(f"{SD_URL}/sdapi/v1/options")
+        response.raise_for_status()
+        options = response.json()
+        current_model_name = options.get("sd_model_checkpoint", "Unknown")
+        return current_model_name
+    except requests.RequestException as e:
+        print(f"API call failed: {e}")
+        return None
+
+# Fetch the current model name at the start
+current_model_name = get_current_model_name()
+if current_model_name:
+    print(f"Current model name: {current_model_name}")
+else:
+    print("Failed to fetch the current model name.")
 
 def encode_file_to_base64(path):
     with open(path, 'rb') as file:
         return base64.b64encode(file.read()).decode('utf-8')
 
-
 def decode_and_save_base64(base64_str, save_path):
     with open(save_path, "wb") as file:
         file.write(base64.b64decode(base64_str))
 
-
 # Set default payload values
 default_payload = {
     "prompt": "",
@@ -69,14 +84,12 @@ default_payload = {
     "override_settings_restore_afterwards": True,
 }
 
-
 def update_negative_prompt(model_name):
     """Update the negative prompt for a given model."""
     if model_name in model_negative_prompts:
         suffix = model_negative_prompts[model_name]
         default_payload["negative_prompt"] += f", {suffix}"
 
-
 def update_resolution(model_name):
     """Update resolution based on the selected model."""
     if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
@@ -86,7 +99,6 @@ def update_resolution(model_name):
         default_payload["width"] = 512
         default_payload["height"] = 512
 
-
 def update_cfg_scale(model_name):
     """Update CFG scale based on the selected model."""
     if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
@@ -94,6 +106,13 @@
         default_payload["cfg_scale"] = 3
     else:
         default_payload["cfg_scale"] = 7
+# Update configurations based on the current model name
+if current_model_name:
+    update_negative_prompt(current_model_name)
+    update_resolution(current_model_name)
+    update_cfg_scale(current_model_name)
+else:
+    print("Failed to update configurations as the current model name is not available.")
 
 def parse_input(input_string):
     """Parse the input string and create a payload."""
@@ -170,7 +189,6 @@
 
     return payload, include_info
 
-
 def create_caption(payload, user_name, user_id, info, include_info):
     """Create a caption for the generated image."""
     caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
@@ -194,7 +212,6 @@
 
     return caption
 
-
 def call_api(api_endpoint, payload):
     """Call the API with the provided payload."""
     try:
@@ -205,12 +222,12 @@
         print(f"API call failed: {e}")
         return {"error": str(e)}
 
-
 def process_images(images, user_id, user_name):
     """Process and save generated images."""
 
     def generate_unique_name():
         unique_id = str(uuid.uuid4())[:7]
-        return f"{user_name}-{unique_id}"
+        date = datetime.now().strftime("%Y-%m-%d-%H-%M")
+        return f"{date}-{user_name}-{unique_id}"
 
     word = generate_unique_name()
@@ -227,9 +244,8 @@
         # Save as JPG
         jpg_path = f"{IMAGE_PATH}/{word}.jpg"
         image.convert("RGB").save(jpg_path, "JPEG")
-
-        return word, response2.json().get("info")
+    return word, response2.json().get("info")
 
 
 @app.on_message(filters.command(["draw"]))
 def draw(client, message):
@@ -259,7 +275,6 @@ def draw(client, message):
         message.reply_text(error_message)
     K.delete()
 
-
 @app.on_message(filters.command(["img"]))
 def img2img(client, message):
     """Handle /img command to generate images from existing images."""
@@ -294,7 +309,6 @@ def img2img(client, message):
         message.reply_text(error_message)
     K.delete()
 
-
 @app.on_message(filters.command(["getmodels"]))
 async def get_models(client, message):
     """Handle /getmodels command to list available models."""
@@ -310,7 +324,6 @@ async def get_models(client, message):
     except requests.RequestException as e:
         await message.reply_text(f"Failed to get models: {e}")
 
-
 @app.on_callback_query()
 async def process_callback(client, callback_query):
     """Process model selection from callback queries."""
@@ -330,7 +343,6 @@ async def process_callback(client, callback_query):
         await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
         print(f"Error setting checkpoint: {e}")
 
-
 @app.on_message(filters.command(["info_sd_bot"]))
 async def info(client, message):
     """Provide information about the bot's commands and options."""
@@ -391,5 +403,4 @@ For more details, visit the [Stable Diffusion Wiki](https://github.com/AUTOMATIC
 Enjoy creating with Stable Diffusion Bot!
 """, disable_web_page_preview=True)
 
-
 app.run()
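
A quick standalone way to exercise the options call that the new get_current_model_name() helper depends on (a minimal sketch, not part of the patch; it assumes an AUTOMATIC1111 WebUI reachable at SD_URL with the API enabled via --api, and the local default URL shown here is only an assumption):

    # check_model.py -- illustrative only; mirrors the GET /sdapi/v1/options call from the patch
    import requests

    SD_URL = "http://127.0.0.1:7860"  # assumed local WebUI default; substitute the bot's real SD_URL

    response = requests.get(f"{SD_URL}/sdapi/v1/options", timeout=10)
    response.raise_for_status()
    # "sd_model_checkpoint" is the same options key the patch reads at startup
    print(response.json().get("sd_model_checkpoint", "Unknown"))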