Compare commits
2 Commits
3fa661a54c
...
ca1c34b91b
Author | SHA1 | Date | |
---|---|---|---|
|
ca1c34b91b | ||
|
de2badd5a0 |
124
main.py
124
main.py
|
@ -23,22 +23,35 @@ IMAGE_PATH = 'images'
|
|||
# Ensure IMAGE_PATH directory exists
|
||||
os.makedirs(IMAGE_PATH, exist_ok=True)
|
||||
|
||||
def timestamp():
|
||||
return datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
# Model-specific emmbedings for negative prompts
|
||||
# see civit.ai model page for specific emmbedings recommnded for each model
|
||||
model_negative_prompts = {
|
||||
"Anything-Diffusion": "",
|
||||
"Deliberate": "",
|
||||
"Dreamshaper": "",
|
||||
"DreamShaperXL_Lightning": "",
|
||||
"icbinp": "",
|
||||
"realisticVisionV60B1_v51VAE": "realisticvision-negative-embedding",
|
||||
"v1-5-pruned-emaonly": ""
|
||||
}
|
||||
|
||||
|
||||
def encode_file_to_base64(path):
|
||||
with open(path, 'rb') as file:
|
||||
return base64.b64encode(file.read()).decode('utf-8')
|
||||
|
||||
|
||||
def decode_and_save_base64(base64_str, save_path):
|
||||
with open(save_path, "wb") as file:
|
||||
file.write(base64.b64decode(base64_str))
|
||||
|
||||
def parse_input(input_string):
|
||||
|
||||
|
||||
# Set default payload values
|
||||
default_payload = {
|
||||
default_payload = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "ugly, bad face, distorted",
|
||||
"seed": -1, # Random seed
|
||||
"negative_prompt": "extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs",
|
||||
"enable_hr": False,
|
||||
"Sampler": "DPM++ SDE Karras",
|
||||
"denoising_strength": 0.35,
|
||||
|
@ -52,16 +65,29 @@ def parse_input(input_string):
|
|||
"override_settings": {},
|
||||
"override_settings_restore_afterwards": True,
|
||||
}
|
||||
|
||||
def update_negative_prompt(model_name):
|
||||
if model_name in model_negative_prompts:
|
||||
suffix = model_negative_prompts[model_name]
|
||||
default_payload["negative_prompt"] += f", {suffix}"
|
||||
|
||||
def parse_input(input_string):
|
||||
payload = default_payload.copy()
|
||||
prompt = []
|
||||
|
||||
matches = re.finditer(r"(\w+):", input_string)
|
||||
last_index = 0
|
||||
|
||||
# Initialize script_args with default values and placeholder enums
|
||||
script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
|
||||
script_name = None
|
||||
|
||||
|
||||
script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
|
||||
script_name = None
|
||||
|
||||
slot_mapping = {0: (0, 1), 1: (3, 4), 2: (6, 7)}
|
||||
slot_index = 0
|
||||
|
||||
for match in matches:
|
||||
key = match.group(1).lower()
|
||||
value_start_index = match.end()
|
||||
|
@ -78,27 +104,34 @@ def parse_input(input_string):
|
|||
key = "denoising_strength"
|
||||
if key == "ng":
|
||||
key = "negative_prompt"
|
||||
if key == "cfg":
|
||||
key = "cfg_scale"
|
||||
|
||||
if key in default_payload:
|
||||
payload[key] = value
|
||||
elif key in ["xsr", "xsteps", "xds"]:
|
||||
elif key in ["xsr", "xsteps", "xds", "xcfg", "nl", "ks", "rs"]:
|
||||
script_name = "x/y/z plot"
|
||||
if key == "xsr":
|
||||
script_args[0] = 7 # Enum value for xsr
|
||||
script_args[1] = value
|
||||
elif key == "xsteps":
|
||||
try:
|
||||
steps_values = [int(x) for x in value.split(',')]
|
||||
if all(1 <= x <= 70 for x in steps_values):
|
||||
script_args[3] = 4 # Enum value for xsteps
|
||||
script_args[4] = value
|
||||
else:
|
||||
raise ValueError("xsteps values must be between 1 and 70.")
|
||||
except ValueError:
|
||||
raise ValueError("xsteps must contain only integers.")
|
||||
elif key == "xds":
|
||||
script_args[6] = 22 # Enum value for xds
|
||||
script_args[7] = value
|
||||
if slot_index < 3:
|
||||
script_slot = slot_mapping[slot_index]
|
||||
if key == "xsr":
|
||||
script_args[script_slot[0]] = 7 # Enum value for xsr
|
||||
script_args[script_slot[1]] = value
|
||||
elif key == "xsteps":
|
||||
script_args[script_slot[0]] = 4 # Enum value for xsteps
|
||||
script_args[script_slot[1]] = value
|
||||
elif key == "xds":
|
||||
script_args[script_slot[0]] = 22 # Enum value for xds
|
||||
script_args[script_slot[1]] = value
|
||||
elif key == "xcfg":
|
||||
script_args[script_slot[0]] = 6 # Enum value for CFG Scale
|
||||
script_args[script_slot[1]] = value
|
||||
slot_index += 1
|
||||
elif key == "nl":
|
||||
script_args[9] = False # Draw legend
|
||||
elif key == "ks":
|
||||
script_args[10] = True # Keep sub images
|
||||
elif key == "rs":
|
||||
script_args[11] = True # Set random seed to sub images
|
||||
else:
|
||||
prompt.append(f"{key}:{value}")
|
||||
|
||||
|
@ -118,6 +151,24 @@ def parse_input(input_string):
|
|||
def create_caption(payload, user_name, user_id, info):
|
||||
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
|
||||
prompt = payload["prompt"]
|
||||
print(payload["prompt"])
|
||||
print(info)
|
||||
# Steps: 3, Sampler: Euler, CFG scale: 7.0, Seed: 4094161400, Size: 512x512, Model hash: 15012c538f, Model: realisticVisionV60B1_v51VAE, Denoising strength: 0.35, Version: v1.8.0-1-g20cdc7c
|
||||
|
||||
# Define a regular expression pattern to match the seed value
|
||||
seed_pattern = r"Seed: (\d+)"
|
||||
|
||||
# Search for the pattern in the info string
|
||||
match = re.search(seed_pattern, info)
|
||||
|
||||
# Check if a match was found and extract the seed value
|
||||
if match:
|
||||
seed_value = match.group(1)
|
||||
print(f"Seed value: {seed_value}")
|
||||
caption += f"**{seed_value}**\n"
|
||||
else:
|
||||
print("Seed value not found in the info string.")
|
||||
|
||||
caption += f"**{prompt}**\n"
|
||||
|
||||
if len(caption) > 1024:
|
||||
|
@ -125,6 +176,7 @@ def create_caption(payload, user_name, user_id, info):
|
|||
|
||||
return caption
|
||||
|
||||
|
||||
def call_api(api_endpoint, payload):
|
||||
try:
|
||||
response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload)
|
||||
|
@ -134,6 +186,7 @@ def call_api(api_endpoint, payload):
|
|||
print(f"API call failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def process_images(images, user_id, user_name):
|
||||
def generate_unique_name():
|
||||
unique_id = str(uuid.uuid4())[:7]
|
||||
|
@ -163,6 +216,11 @@ def draw(client, message):
|
|||
payload = parse_input(msgs[1])
|
||||
print(payload)
|
||||
|
||||
# Check if xds is used in the payload
|
||||
if "xds" in msgs[1].lower():
|
||||
message.reply_text("`xds` key cannot be used in the `/draw` command. Use `/img` instead.")
|
||||
return
|
||||
|
||||
K = message.reply_text("Please Wait 10-15 Seconds")
|
||||
r = call_api('sdapi/v1/txt2img', payload)
|
||||
|
||||
|
@ -176,6 +234,7 @@ def draw(client, message):
|
|||
message.reply_text("Failed to generate image. Please try again later.")
|
||||
K.delete()
|
||||
|
||||
|
||||
@app.on_message(filters.command(["img"]))
|
||||
def img2img(client, message):
|
||||
if not message.reply_to_message or not message.reply_to_message.photo:
|
||||
|
@ -188,7 +247,11 @@ def img2img(client, message):
|
|||
return
|
||||
|
||||
payload = parse_input(msgs[1])
|
||||
print(f"input:\n{payload}")
|
||||
photo = message.reply_to_message.photo
|
||||
# prompt_from_reply = message.reply_to_message.
|
||||
# orginal_prompt = app.reply_to_message.message
|
||||
# print(orginal_prompt)
|
||||
photo_file = app.download_media(photo)
|
||||
init_image = encode_file_to_base64(photo_file)
|
||||
os.remove(photo_file) # Clean up downloaded image file
|
||||
|
@ -208,13 +271,14 @@ def img2img(client, message):
|
|||
message.reply_text("Failed to process image. Please try again later.")
|
||||
K.delete()
|
||||
|
||||
|
||||
@app.on_message(filters.command(["getmodels"]))
|
||||
async def get_models(client, message):
|
||||
try:
|
||||
response = requests.get(f"{SD_URL}/sdapi/v1/sd-models")
|
||||
response.raise_for_status()
|
||||
models_json = response.json()
|
||||
|
||||
print(models_json)
|
||||
buttons = [
|
||||
[InlineKeyboardButton(model["title"], callback_data=model["model_name"])]
|
||||
for model in models_json
|
||||
|
@ -223,6 +287,7 @@ async def get_models(client, message):
|
|||
except requests.RequestException as e:
|
||||
await message.reply_text(f"Failed to get models: {e}")
|
||||
|
||||
|
||||
@app.on_callback_query()
|
||||
async def process_callback(client, callback_query):
|
||||
sd_model_checkpoint = callback_query.data
|
||||
|
@ -231,9 +296,15 @@ async def process_callback(client, callback_query):
|
|||
try:
|
||||
response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options)
|
||||
response.raise_for_status()
|
||||
|
||||
# Update the negative prompt based on the selected model
|
||||
update_negative_prompt(sd_model_checkpoint)
|
||||
|
||||
await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
|
||||
except requests.RequestException as e:
|
||||
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
|
||||
print(f"Error setting checkpoint: {e}")
|
||||
|
||||
|
||||
@app.on_message(filters.command(["info_sd_bot"]))
|
||||
async def info(client, message):
|
||||
|
@ -244,6 +315,9 @@ currently supported
|
|||
`xds` - denoise strength, only valid for img2img
|
||||
`xsteps` - steps
|
||||
**note** limit the overall `steps:` to lower value (10-20) for big xyz plots
|
||||
""")
|
||||
|
||||
aside from that you can use the usual `ng`, `ds`, `cfg`, `steps` for single image generation.
|
||||
""", disable_web_page_preview=True)
|
||||
|
||||
|
||||
app.run()
|
||||
|
|
Loading…
Reference in New Issue
Block a user