get model
This commit is contained in:
parent
49c9f6337a
commit
ec900759a1
41
main.py
41
main.py
|
@ -39,17 +39,32 @@ model_negative_prompts = {
|
|||
"Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark"
|
||||
}
|
||||
|
||||
def get_current_model_name():
|
||||
try:
|
||||
response = requests.get(f"{SD_URL}/sdapi/v1/options")
|
||||
response.raise_for_status()
|
||||
options = response.json()
|
||||
current_model_name = options.get("sd_model_checkpoint", "Unknown")
|
||||
return current_model_name
|
||||
except requests.RequestException as e:
|
||||
print(f"API call failed: {e}")
|
||||
return None
|
||||
|
||||
# Fetch the current model name at the start
|
||||
current_model_name = get_current_model_name()
|
||||
if current_model_name:
|
||||
print(f"Current model name: {current_model_name}")
|
||||
else:
|
||||
print("Failed to fetch the current model name.")
|
||||
|
||||
def encode_file_to_base64(path):
|
||||
with open(path, 'rb') as file:
|
||||
return base64.b64encode(file.read()).decode('utf-8')
|
||||
|
||||
|
||||
def decode_and_save_base64(base64_str, save_path):
|
||||
with open(save_path, "wb") as file:
|
||||
file.write(base64.b64decode(base64_str))
|
||||
|
||||
|
||||
# Set default payload values
|
||||
default_payload = {
|
||||
"prompt": "",
|
||||
|
@ -69,14 +84,12 @@ default_payload = {
|
|||
"override_settings_restore_afterwards": True,
|
||||
}
|
||||
|
||||
|
||||
def update_negative_prompt(model_name):
|
||||
"""Update the negative prompt for a given model."""
|
||||
if model_name in model_negative_prompts:
|
||||
suffix = model_negative_prompts[model_name]
|
||||
default_payload["negative_prompt"] += f", {suffix}"
|
||||
|
||||
|
||||
def update_resolution(model_name):
|
||||
"""Update resolution based on the selected model."""
|
||||
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
|
||||
|
@ -86,7 +99,6 @@ def update_resolution(model_name):
|
|||
default_payload["width"] = 512
|
||||
default_payload["height"] = 512
|
||||
|
||||
|
||||
def update_cfg_scale(model_name):
|
||||
"""Update CFG scale based on the selected model."""
|
||||
if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2":
|
||||
|
@ -94,6 +106,13 @@ def update_cfg_scale(model_name):
|
|||
else:
|
||||
default_payload["cfg_scale"] = 7
|
||||
|
||||
# Update configurations based on the current model name
|
||||
if current_model_name:
|
||||
update_negative_prompt(current_model_name)
|
||||
update_resolution(current_model_name)
|
||||
update_cfg_scale(current_model_name)
|
||||
else:
|
||||
print("Failed to update configurations as the current model name is not available.")
|
||||
|
||||
def parse_input(input_string):
|
||||
"""Parse the input string and create a payload."""
|
||||
|
@ -170,7 +189,6 @@ def parse_input(input_string):
|
|||
|
||||
return payload, include_info
|
||||
|
||||
|
||||
def create_caption(payload, user_name, user_id, info, include_info):
|
||||
"""Create a caption for the generated image."""
|
||||
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
|
||||
|
@ -194,7 +212,6 @@ def create_caption(payload, user_name, user_id, info, include_info):
|
|||
|
||||
return caption
|
||||
|
||||
|
||||
def call_api(api_endpoint, payload):
|
||||
"""Call the API with the provided payload."""
|
||||
try:
|
||||
|
@ -205,12 +222,12 @@ def call_api(api_endpoint, payload):
|
|||
print(f"API call failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
def process_images(images, user_id, user_name):
|
||||
"""Process and save generated images."""
|
||||
def generate_unique_name():
|
||||
unique_id = str(uuid.uuid4())[:7]
|
||||
return f"{user_name}-{unique_id}"
|
||||
date = datetime.now().strftime("%Y-%m-%d-%H-%M")
|
||||
return f"{date}-{user_name}-{unique_id}"
|
||||
|
||||
word = generate_unique_name()
|
||||
|
||||
|
@ -230,7 +247,6 @@ def process_images(images, user_id, user_name):
|
|||
|
||||
return word, response2.json().get("info")
|
||||
|
||||
|
||||
@app.on_message(filters.command(["draw"]))
|
||||
def draw(client, message):
|
||||
"""Handle /draw command to generate images from text prompts."""
|
||||
|
@ -259,7 +275,6 @@ def draw(client, message):
|
|||
message.reply_text(error_message)
|
||||
K.delete()
|
||||
|
||||
|
||||
@app.on_message(filters.command(["img"]))
|
||||
def img2img(client, message):
|
||||
"""Handle /img command to generate images from existing images."""
|
||||
|
@ -294,7 +309,6 @@ def img2img(client, message):
|
|||
message.reply_text(error_message)
|
||||
K.delete()
|
||||
|
||||
|
||||
@app.on_message(filters.command(["getmodels"]))
|
||||
async def get_models(client, message):
|
||||
"""Handle /getmodels command to list available models."""
|
||||
|
@ -310,7 +324,6 @@ async def get_models(client, message):
|
|||
except requests.RequestException as e:
|
||||
await message.reply_text(f"Failed to get models: {e}")
|
||||
|
||||
|
||||
@app.on_callback_query()
|
||||
async def process_callback(client, callback_query):
|
||||
"""Process model selection from callback queries."""
|
||||
|
@ -330,7 +343,6 @@ async def process_callback(client, callback_query):
|
|||
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
|
||||
print(f"Error setting checkpoint: {e}")
|
||||
|
||||
|
||||
@app.on_message(filters.command(["info_sd_bot"]))
|
||||
async def info(client, message):
|
||||
"""Provide information about the bot's commands and options."""
|
||||
|
@ -391,5 +403,4 @@ For more details, visit the [Stable Diffusion Wiki](https://github.com/AUTOMATIC
|
|||
Enjoy creating with Stable Diffusion Bot!
|
||||
""", disable_web_page_preview=True)
|
||||
|
||||
|
||||
app.run()
|
||||
|
|
Loading…
Reference in New Issue
Block a user