get model

This commit is contained in:
tami-p40 2024-06-03 14:14:25 +03:00
parent 49c9f6337a
commit ec900759a1

41
main.py
View File

@ -39,17 +39,32 @@ model_negative_prompts = {
"Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark"
}
def get_current_model_name():
try:
response = requests.get(f"{SD_URL}/sdapi/v1/options")
response.raise_for_status()
options = response.json()
current_model_name = options.get("sd_model_checkpoint", "Unknown")
return current_model_name
except requests.RequestException as e:
print(f"API call failed: {e}")
return None
# Fetch the current model name at the start
current_model_name = get_current_model_name()
if current_model_name:
print(f"Current model name: {current_model_name}")
else:
print("Failed to fetch the current model name.")
def encode_file_to_base64(path):
with open(path, 'rb') as file:
return base64.b64encode(file.read()).decode('utf-8')
def decode_and_save_base64(base64_str, save_path):
with open(save_path, "wb") as file:
file.write(base64.b64decode(base64_str))
# Set default payload values
default_payload = {
"prompt": "",
@ -69,14 +84,12 @@ default_payload = {
"override_settings_restore_afterwards": True,
}
def update_negative_prompt(model_name):
"""Update the negative prompt for a given model."""
if model_name in model_negative_prompts:
suffix = model_negative_prompts[model_name]
default_payload["negative_prompt"] += f", {suffix}"
def update_resolution(model_name):
"""Update resolution based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
@ -86,7 +99,6 @@ def update_resolution(model_name):
default_payload["width"] = 512
default_payload["height"] = 512
def update_cfg_scale(model_name):
"""Update CFG scale based on the selected model."""
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
@ -94,6 +106,13 @@ def update_cfg_scale(model_name):
else:
default_payload["cfg_scale"] = 7
# Update configurations based on the current model name
if current_model_name:
update_negative_prompt(current_model_name)
update_resolution(current_model_name)
update_cfg_scale(current_model_name)
else:
print("Failed to update configurations as the current model name is not available.")
def parse_input(input_string):
"""Parse the input string and create a payload."""
@ -170,7 +189,6 @@ def parse_input(input_string):
return payload, include_info
def create_caption(payload, user_name, user_id, info, include_info):
"""Create a caption for the generated image."""
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
@ -194,7 +212,6 @@ def create_caption(payload, user_name, user_id, info, include_info):
return caption
def call_api(api_endpoint, payload):
"""Call the API with the provided payload."""
try:
@ -205,12 +222,12 @@ def call_api(api_endpoint, payload):
print(f"API call failed: {e}")
return {"error": str(e)}
def process_images(images, user_id, user_name):
"""Process and save generated images."""
def generate_unique_name():
unique_id = str(uuid.uuid4())[:7]
return f"{user_name}-{unique_id}"
date = datetime.now().strftime("%Y-%m-%d-%H-%M")
return f"{date}-{user_name}-{unique_id}"
word = generate_unique_name()
@ -230,7 +247,6 @@ def process_images(images, user_id, user_name):
return word, response2.json().get("info")
@app.on_message(filters.command(["draw"]))
def draw(client, message):
"""Handle /draw command to generate images from text prompts."""
@ -259,7 +275,6 @@ def draw(client, message):
message.reply_text(error_message)
K.delete()
@app.on_message(filters.command(["img"]))
def img2img(client, message):
"""Handle /img command to generate images from existing images."""
@ -294,7 +309,6 @@ def img2img(client, message):
message.reply_text(error_message)
K.delete()
@app.on_message(filters.command(["getmodels"]))
async def get_models(client, message):
"""Handle /getmodels command to list available models."""
@ -310,7 +324,6 @@ async def get_models(client, message):
except requests.RequestException as e:
await message.reply_text(f"Failed to get models: {e}")
@app.on_callback_query()
async def process_callback(client, callback_query):
"""Process model selection from callback queries."""
@ -330,7 +343,6 @@ async def process_callback(client, callback_query):
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
print(f"Error setting checkpoint: {e}")
@app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message):
"""Provide information about the bot's commands and options."""
@ -391,5 +403,4 @@ For more details, visit the [Stable Diffusion Wiki](https://github.com/AUTOMATIC
Enjoy creating with Stable Diffusion Bot!
""", disable_web_page_preview=True)
app.run()