Compare commits

..

2 Commits

Author SHA1 Message Date
tami-p40
ca1c34b91b negative emmbedings 2024-05-26 08:49:27 +03:00
tami-p40
de2badd5a0 xcfg 2024-05-20 10:29:34 +03:00

124
main.py
View File

@ -23,22 +23,35 @@ IMAGE_PATH = 'images'
# Ensure IMAGE_PATH directory exists # Ensure IMAGE_PATH directory exists
os.makedirs(IMAGE_PATH, exist_ok=True) os.makedirs(IMAGE_PATH, exist_ok=True)
def timestamp(): # Model-specific emmbedings for negative prompts
return datetime.now().strftime("%Y%m%d-%H%M%S") # see civit.ai model page for specific emmbedings recommnded for each model
model_negative_prompts = {
"Anything-Diffusion": "",
"Deliberate": "",
"Dreamshaper": "",
"DreamShaperXL_Lightning": "",
"icbinp": "",
"realisticVisionV60B1_v51VAE": "realisticvision-negative-embedding",
"v1-5-pruned-emaonly": ""
}
def encode_file_to_base64(path): def encode_file_to_base64(path):
with open(path, 'rb') as file: with open(path, 'rb') as file:
return base64.b64encode(file.read()).decode('utf-8') return base64.b64encode(file.read()).decode('utf-8')
def decode_and_save_base64(base64_str, save_path): def decode_and_save_base64(base64_str, save_path):
with open(save_path, "wb") as file: with open(save_path, "wb") as file:
file.write(base64.b64decode(base64_str)) file.write(base64.b64decode(base64_str))
def parse_input(input_string):
# Set default payload values # Set default payload values
default_payload = { default_payload = {
"prompt": "", "prompt": "",
"negative_prompt": "ugly, bad face, distorted", "seed": -1, # Random seed
"negative_prompt": "extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs",
"enable_hr": False, "enable_hr": False,
"Sampler": "DPM++ SDE Karras", "Sampler": "DPM++ SDE Karras",
"denoising_strength": 0.35, "denoising_strength": 0.35,
@ -52,16 +65,29 @@ def parse_input(input_string):
"override_settings": {}, "override_settings": {},
"override_settings_restore_afterwards": True, "override_settings_restore_afterwards": True,
} }
def update_negative_prompt(model_name):
if model_name in model_negative_prompts:
suffix = model_negative_prompts[model_name]
default_payload["negative_prompt"] += f", {suffix}"
def parse_input(input_string):
payload = default_payload.copy() payload = default_payload.copy()
prompt = [] prompt = []
matches = re.finditer(r"(\w+):", input_string) matches = re.finditer(r"(\w+):", input_string)
last_index = 0 last_index = 0
# Initialize script_args with default values and placeholder enums
script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False] script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
script_name = None script_name = None
script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
script_name = None
slot_mapping = {0: (0, 1), 1: (3, 4), 2: (6, 7)}
slot_index = 0
for match in matches: for match in matches:
key = match.group(1).lower() key = match.group(1).lower()
value_start_index = match.end() value_start_index = match.end()
@ -78,27 +104,34 @@ def parse_input(input_string):
key = "denoising_strength" key = "denoising_strength"
if key == "ng": if key == "ng":
key = "negative_prompt" key = "negative_prompt"
if key == "cfg":
key = "cfg_scale"
if key in default_payload: if key in default_payload:
payload[key] = value payload[key] = value
elif key in ["xsr", "xsteps", "xds"]: elif key in ["xsr", "xsteps", "xds", "xcfg", "nl", "ks", "rs"]:
script_name = "x/y/z plot" script_name = "x/y/z plot"
if key == "xsr": if slot_index < 3:
script_args[0] = 7 # Enum value for xsr script_slot = slot_mapping[slot_index]
script_args[1] = value if key == "xsr":
elif key == "xsteps": script_args[script_slot[0]] = 7 # Enum value for xsr
try: script_args[script_slot[1]] = value
steps_values = [int(x) for x in value.split(',')] elif key == "xsteps":
if all(1 <= x <= 70 for x in steps_values): script_args[script_slot[0]] = 4 # Enum value for xsteps
script_args[3] = 4 # Enum value for xsteps script_args[script_slot[1]] = value
script_args[4] = value elif key == "xds":
else: script_args[script_slot[0]] = 22 # Enum value for xds
raise ValueError("xsteps values must be between 1 and 70.") script_args[script_slot[1]] = value
except ValueError: elif key == "xcfg":
raise ValueError("xsteps must contain only integers.") script_args[script_slot[0]] = 6 # Enum value for CFG Scale
elif key == "xds": script_args[script_slot[1]] = value
script_args[6] = 22 # Enum value for xds slot_index += 1
script_args[7] = value elif key == "nl":
script_args[9] = False # Draw legend
elif key == "ks":
script_args[10] = True # Keep sub images
elif key == "rs":
script_args[11] = True # Set random seed to sub images
else: else:
prompt.append(f"{key}:{value}") prompt.append(f"{key}:{value}")
@ -118,6 +151,24 @@ def parse_input(input_string):
def create_caption(payload, user_name, user_id, info): def create_caption(payload, user_name, user_id, info):
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n" caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
prompt = payload["prompt"] prompt = payload["prompt"]
print(payload["prompt"])
print(info)
# Steps: 3, Sampler: Euler, CFG scale: 7.0, Seed: 4094161400, Size: 512x512, Model hash: 15012c538f, Model: realisticVisionV60B1_v51VAE, Denoising strength: 0.35, Version: v1.8.0-1-g20cdc7c
# Define a regular expression pattern to match the seed value
seed_pattern = r"Seed: (\d+)"
# Search for the pattern in the info string
match = re.search(seed_pattern, info)
# Check if a match was found and extract the seed value
if match:
seed_value = match.group(1)
print(f"Seed value: {seed_value}")
caption += f"**{seed_value}**\n"
else:
print("Seed value not found in the info string.")
caption += f"**{prompt}**\n" caption += f"**{prompt}**\n"
if len(caption) > 1024: if len(caption) > 1024:
@ -125,6 +176,7 @@ def create_caption(payload, user_name, user_id, info):
return caption return caption
def call_api(api_endpoint, payload): def call_api(api_endpoint, payload):
try: try:
response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload) response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload)
@ -134,6 +186,7 @@ def call_api(api_endpoint, payload):
print(f"API call failed: {e}") print(f"API call failed: {e}")
return None return None
def process_images(images, user_id, user_name): def process_images(images, user_id, user_name):
def generate_unique_name(): def generate_unique_name():
unique_id = str(uuid.uuid4())[:7] unique_id = str(uuid.uuid4())[:7]
@ -163,6 +216,11 @@ def draw(client, message):
payload = parse_input(msgs[1]) payload = parse_input(msgs[1])
print(payload) print(payload)
# Check if xds is used in the payload
if "xds" in msgs[1].lower():
message.reply_text("`xds` key cannot be used in the `/draw` command. Use `/img` instead.")
return
K = message.reply_text("Please Wait 10-15 Seconds") K = message.reply_text("Please Wait 10-15 Seconds")
r = call_api('sdapi/v1/txt2img', payload) r = call_api('sdapi/v1/txt2img', payload)
@ -176,6 +234,7 @@ def draw(client, message):
message.reply_text("Failed to generate image. Please try again later.") message.reply_text("Failed to generate image. Please try again later.")
K.delete() K.delete()
@app.on_message(filters.command(["img"])) @app.on_message(filters.command(["img"]))
def img2img(client, message): def img2img(client, message):
if not message.reply_to_message or not message.reply_to_message.photo: if not message.reply_to_message or not message.reply_to_message.photo:
@ -188,7 +247,11 @@ def img2img(client, message):
return return
payload = parse_input(msgs[1]) payload = parse_input(msgs[1])
print(f"input:\n{payload}")
photo = message.reply_to_message.photo photo = message.reply_to_message.photo
# prompt_from_reply = message.reply_to_message.
# orginal_prompt = app.reply_to_message.message
# print(orginal_prompt)
photo_file = app.download_media(photo) photo_file = app.download_media(photo)
init_image = encode_file_to_base64(photo_file) init_image = encode_file_to_base64(photo_file)
os.remove(photo_file) # Clean up downloaded image file os.remove(photo_file) # Clean up downloaded image file
@ -208,13 +271,14 @@ def img2img(client, message):
message.reply_text("Failed to process image. Please try again later.") message.reply_text("Failed to process image. Please try again later.")
K.delete() K.delete()
@app.on_message(filters.command(["getmodels"])) @app.on_message(filters.command(["getmodels"]))
async def get_models(client, message): async def get_models(client, message):
try: try:
response = requests.get(f"{SD_URL}/sdapi/v1/sd-models") response = requests.get(f"{SD_URL}/sdapi/v1/sd-models")
response.raise_for_status() response.raise_for_status()
models_json = response.json() models_json = response.json()
print(models_json)
buttons = [ buttons = [
[InlineKeyboardButton(model["title"], callback_data=model["model_name"])] [InlineKeyboardButton(model["title"], callback_data=model["model_name"])]
for model in models_json for model in models_json
@ -223,6 +287,7 @@ async def get_models(client, message):
except requests.RequestException as e: except requests.RequestException as e:
await message.reply_text(f"Failed to get models: {e}") await message.reply_text(f"Failed to get models: {e}")
@app.on_callback_query() @app.on_callback_query()
async def process_callback(client, callback_query): async def process_callback(client, callback_query):
sd_model_checkpoint = callback_query.data sd_model_checkpoint = callback_query.data
@ -231,9 +296,15 @@ async def process_callback(client, callback_query):
try: try:
response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options) response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options)
response.raise_for_status() response.raise_for_status()
# Update the negative prompt based on the selected model
update_negative_prompt(sd_model_checkpoint)
await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}") await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
except requests.RequestException as e: except requests.RequestException as e:
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}") await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
print(f"Error setting checkpoint: {e}")
@app.on_message(filters.command(["info_sd_bot"])) @app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message): async def info(client, message):
@ -244,6 +315,9 @@ currently supported
`xds` - denoise strength, only valid for img2img `xds` - denoise strength, only valid for img2img
`xsteps` - steps `xsteps` - steps
**note** limit the overall `steps:` to lower value (10-20) for big xyz plots **note** limit the overall `steps:` to lower value (10-20) for big xyz plots
""")
aside from that you can use the usual `ng`, `ds`, `cfg`, `steps` for single image generation.
""", disable_web_page_preview=True)
app.run() app.run()