get model
This commit is contained in:
parent
49c9f6337a
commit
ec900759a1
41
main.py
41
main.py
|
@ -39,17 +39,32 @@ model_negative_prompts = {
|
||||||
"Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark"
|
"Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_current_model_name():
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{SD_URL}/sdapi/v1/options")
|
||||||
|
response.raise_for_status()
|
||||||
|
options = response.json()
|
||||||
|
current_model_name = options.get("sd_model_checkpoint", "Unknown")
|
||||||
|
return current_model_name
|
||||||
|
except requests.RequestException as e:
|
||||||
|
print(f"API call failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Fetch the current model name at the start
|
||||||
|
current_model_name = get_current_model_name()
|
||||||
|
if current_model_name:
|
||||||
|
print(f"Current model name: {current_model_name}")
|
||||||
|
else:
|
||||||
|
print("Failed to fetch the current model name.")
|
||||||
|
|
||||||
def encode_file_to_base64(path):
|
def encode_file_to_base64(path):
|
||||||
with open(path, 'rb') as file:
|
with open(path, 'rb') as file:
|
||||||
return base64.b64encode(file.read()).decode('utf-8')
|
return base64.b64encode(file.read()).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
def decode_and_save_base64(base64_str, save_path):
|
def decode_and_save_base64(base64_str, save_path):
|
||||||
with open(save_path, "wb") as file:
|
with open(save_path, "wb") as file:
|
||||||
file.write(base64.b64decode(base64_str))
|
file.write(base64.b64decode(base64_str))
|
||||||
|
|
||||||
|
|
||||||
# Set default payload values
|
# Set default payload values
|
||||||
default_payload = {
|
default_payload = {
|
||||||
"prompt": "",
|
"prompt": "",
|
||||||
|
@ -69,14 +84,12 @@ default_payload = {
|
||||||
"override_settings_restore_afterwards": True,
|
"override_settings_restore_afterwards": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def update_negative_prompt(model_name):
|
def update_negative_prompt(model_name):
|
||||||
"""Update the negative prompt for a given model."""
|
"""Update the negative prompt for a given model."""
|
||||||
if model_name in model_negative_prompts:
|
if model_name in model_negative_prompts:
|
||||||
suffix = model_negative_prompts[model_name]
|
suffix = model_negative_prompts[model_name]
|
||||||
default_payload["negative_prompt"] += f", {suffix}"
|
default_payload["negative_prompt"] += f", {suffix}"
|
||||||
|
|
||||||
|
|
||||||
def update_resolution(model_name):
|
def update_resolution(model_name):
|
||||||
"""Update resolution based on the selected model."""
|
"""Update resolution based on the selected model."""
|
||||||
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
|
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
|
||||||
|
@ -86,7 +99,6 @@ def update_resolution(model_name):
|
||||||
default_payload["width"] = 512
|
default_payload["width"] = 512
|
||||||
default_payload["height"] = 512
|
default_payload["height"] = 512
|
||||||
|
|
||||||
|
|
||||||
def update_cfg_scale(model_name):
|
def update_cfg_scale(model_name):
|
||||||
"""Update CFG scale based on the selected model."""
|
"""Update CFG scale based on the selected model."""
|
||||||
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
|
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
|
||||||
|
@ -94,6 +106,13 @@ def update_cfg_scale(model_name):
|
||||||
else:
|
else:
|
||||||
default_payload["cfg_scale"] = 7
|
default_payload["cfg_scale"] = 7
|
||||||
|
|
||||||
|
# Update configurations based on the current model name
|
||||||
|
if current_model_name:
|
||||||
|
update_negative_prompt(current_model_name)
|
||||||
|
update_resolution(current_model_name)
|
||||||
|
update_cfg_scale(current_model_name)
|
||||||
|
else:
|
||||||
|
print("Failed to update configurations as the current model name is not available.")
|
||||||
|
|
||||||
def parse_input(input_string):
|
def parse_input(input_string):
|
||||||
"""Parse the input string and create a payload."""
|
"""Parse the input string and create a payload."""
|
||||||
|
@ -170,7 +189,6 @@ def parse_input(input_string):
|
||||||
|
|
||||||
return payload, include_info
|
return payload, include_info
|
||||||
|
|
||||||
|
|
||||||
def create_caption(payload, user_name, user_id, info, include_info):
|
def create_caption(payload, user_name, user_id, info, include_info):
|
||||||
"""Create a caption for the generated image."""
|
"""Create a caption for the generated image."""
|
||||||
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
|
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
|
||||||
|
@ -194,7 +212,6 @@ def create_caption(payload, user_name, user_id, info, include_info):
|
||||||
|
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
|
|
||||||
def call_api(api_endpoint, payload):
|
def call_api(api_endpoint, payload):
|
||||||
"""Call the API with the provided payload."""
|
"""Call the API with the provided payload."""
|
||||||
try:
|
try:
|
||||||
|
@ -205,12 +222,12 @@ def call_api(api_endpoint, payload):
|
||||||
print(f"API call failed: {e}")
|
print(f"API call failed: {e}")
|
||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
def process_images(images, user_id, user_name):
|
def process_images(images, user_id, user_name):
|
||||||
"""Process and save generated images."""
|
"""Process and save generated images."""
|
||||||
def generate_unique_name():
|
def generate_unique_name():
|
||||||
unique_id = str(uuid.uuid4())[:7]
|
unique_id = str(uuid.uuid4())[:7]
|
||||||
return f"{user_name}-{unique_id}"
|
date = datetime.now().strftime("%Y-%m-%d-%H-%M")
|
||||||
|
return f"{date}-{user_name}-{unique_id}"
|
||||||
|
|
||||||
word = generate_unique_name()
|
word = generate_unique_name()
|
||||||
|
|
||||||
|
@ -230,7 +247,6 @@ def process_images(images, user_id, user_name):
|
||||||
|
|
||||||
return word, response2.json().get("info")
|
return word, response2.json().get("info")
|
||||||
|
|
||||||
|
|
||||||
@app.on_message(filters.command(["draw"]))
|
@app.on_message(filters.command(["draw"]))
|
||||||
def draw(client, message):
|
def draw(client, message):
|
||||||
"""Handle /draw command to generate images from text prompts."""
|
"""Handle /draw command to generate images from text prompts."""
|
||||||
|
@ -259,7 +275,6 @@ def draw(client, message):
|
||||||
message.reply_text(error_message)
|
message.reply_text(error_message)
|
||||||
K.delete()
|
K.delete()
|
||||||
|
|
||||||
|
|
||||||
@app.on_message(filters.command(["img"]))
|
@app.on_message(filters.command(["img"]))
|
||||||
def img2img(client, message):
|
def img2img(client, message):
|
||||||
"""Handle /img command to generate images from existing images."""
|
"""Handle /img command to generate images from existing images."""
|
||||||
|
@ -294,7 +309,6 @@ def img2img(client, message):
|
||||||
message.reply_text(error_message)
|
message.reply_text(error_message)
|
||||||
K.delete()
|
K.delete()
|
||||||
|
|
||||||
|
|
||||||
@app.on_message(filters.command(["getmodels"]))
|
@app.on_message(filters.command(["getmodels"]))
|
||||||
async def get_models(client, message):
|
async def get_models(client, message):
|
||||||
"""Handle /getmodels command to list available models."""
|
"""Handle /getmodels command to list available models."""
|
||||||
|
@ -310,7 +324,6 @@ async def get_models(client, message):
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
await message.reply_text(f"Failed to get models: {e}")
|
await message.reply_text(f"Failed to get models: {e}")
|
||||||
|
|
||||||
|
|
||||||
@app.on_callback_query()
|
@app.on_callback_query()
|
||||||
async def process_callback(client, callback_query):
|
async def process_callback(client, callback_query):
|
||||||
"""Process model selection from callback queries."""
|
"""Process model selection from callback queries."""
|
||||||
|
@ -330,7 +343,6 @@ async def process_callback(client, callback_query):
|
||||||
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
|
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
|
||||||
print(f"Error setting checkpoint: {e}")
|
print(f"Error setting checkpoint: {e}")
|
||||||
|
|
||||||
|
|
||||||
@app.on_message(filters.command(["info_sd_bot"]))
|
@app.on_message(filters.command(["info_sd_bot"]))
|
||||||
async def info(client, message):
|
async def info(client, message):
|
||||||
"""Provide information about the bot's commands and options."""
|
"""Provide information about the bot's commands and options."""
|
||||||
|
@ -391,5 +403,4 @@ For more details, visit the [Stable Diffusion Wiki](https://github.com/AUTOMATIC
|
||||||
Enjoy creating with Stable Diffusion Bot!
|
Enjoy creating with Stable Diffusion Bot!
|
||||||
""", disable_web_page_preview=True)
|
""", disable_web_page_preview=True)
|
||||||
|
|
||||||
|
|
||||||
app.run()
|
app.run()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user