This commit is contained in:
tami-p40 2024-05-18 20:41:30 +03:00
parent c8e825b247
commit 3fa661a54c
3 changed files with 97 additions and 44 deletions

View File

@ -47,6 +47,11 @@ Set that low (like 0.2) if you just want to slightly change things. defaults to
basicly anything the `/controlnet/img2img` API payload supports
### general
`X/Y/Z script` [link](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#xyz-plot), one powerfull thing
for prompt we use the Serach Replace option (a.k.a `prompt s/r`) [exaplined](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-sr)
## Setup
Install requirements using venv

View File

@ -62,29 +62,36 @@ if __name__ == '__main__':
"width": 512,
"height": 512,
"cfg_scale": 7,
"sampler_name": "DPM++ 2M",
"sampler_name": "DPM++ SDE Karras",
"n_iter": 1,
"batch_size": 1,
# example args for x/y/z plot
# "script_name": "x/y/z plot",
# "script_args": [
# 1,
# "10,20",
# [],
# 0,
# "",
# [],
# 0,
# "",
# [],
# True,
# True,
# False,
# False,
# 0,
# False
# ],
#steps 4,"20,30"
#denoising==22
# S/R 7,"X,united states,china",
"script_args": [
4,
"20,30,40",
[],
0,
"",
[],
0,
"",
[],
True,
False,
False,
False,
False,
False,
False,
0,
False
],
"script_name": "x/y/z plot",
# example args for Refiner and ControlNet
# "alwayson_scripts": {

89
main.py
View File

@ -35,6 +35,7 @@ def decode_and_save_base64(base64_str, save_path):
file.write(base64.b64decode(base64_str))
def parse_input(input_string):
# Set default payload values
default_payload = {
"prompt": "",
"negative_prompt": "ugly, bad face, distorted",
@ -57,32 +58,73 @@ def parse_input(input_string):
matches = re.finditer(r"(\w+):", input_string)
last_index = 0
# Initialize script_args with default values and placeholder enums
script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
script_name = None
for match in matches:
key = match.group(1).lower()
value_start_index = match.end()
if last_index != match.start():
prompt.append(input_string[last_index: match.start()].strip())
last_index = value_start_index
value_end_match = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:])
if value_end_match:
value_end_index = value_end_match.start() + value_start_index
else:
value_end_index = len(input_string)
value = input_string[value_start_index: value_end_index].strip()
if key == "ds":
key = "denoising_strength"
if key == "ng":
key = "negative_prompt"
if key in default_payload:
value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:])
value = input_string[value_start_index: value_start_index + value_end_index.start()].strip()
payload[key] = value
last_index += value_end_index.start()
else:
prompt.append(f"{key}:")
payload["prompt"] = " ".join(prompt)
if key in default_payload:
payload[key] = value
elif key in ["xsr", "xsteps", "xds"]:
script_name = "x/y/z plot"
if key == "xsr":
script_args[0] = 7 # Enum value for xsr
script_args[1] = value
elif key == "xsteps":
try:
steps_values = [int(x) for x in value.split(',')]
if all(1 <= x <= 70 for x in steps_values):
script_args[3] = 4 # Enum value for xsteps
script_args[4] = value
else:
raise ValueError("xsteps values must be between 1 and 70.")
except ValueError:
raise ValueError("xsteps must contain only integers.")
elif key == "xds":
script_args[6] = 22 # Enum value for xds
script_args[7] = value
else:
prompt.append(f"{key}:{value}")
last_index = value_end_index
payload["prompt"] = " ".join(prompt).strip()
if not payload["prompt"]:
payload["prompt"] = input_string.strip()
if script_name:
payload["script_name"] = script_name
payload["script_args"] = script_args
return payload
def create_caption(payload, user_name, user_id, info):
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
if len(caption) > 1024:
caption = caption[:1021] + "..."
return caption
def call_api(api_endpoint, payload):
try:
response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload)
@ -111,18 +153,6 @@ def process_images(images, user_id, user_name):
return word, response2.json().get("info")
def create_caption(payload, user_name, user_id, info):
seed_value = info.split(", Seed: ")[1].split(",")[0]
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
caption += f"Seed - **{seed_value}**\n"
if len(caption) > 1024:
caption = caption[:1021] + "..."
return caption
@app.on_message(filters.command(["draw"]))
def draw(client, message):
msgs = message.text.split(" ", 1)
@ -149,12 +179,12 @@ def draw(client, message):
@app.on_message(filters.command(["img"]))
def img2img(client, message):
if not message.reply_to_message or not message.reply_to_message.photo:
message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4")
message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`")
return
msgs = message.text.split(" ", 1)
if len(msgs) == 1:
message.reply_text("Format :\n/img < prompt >\nforce: < 0.1-1.0, default 0.3 >")
message.reply_text("dont FAIL in life")
return
payload = parse_input(msgs[1])
@ -205,4 +235,15 @@ async def process_callback(client, callback_query):
except requests.RequestException as e:
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
@app.on_message(filters.command(["info_sd_bot"]))
async def info(client, message):
await message.reply_text("""
now support for xyz scripts, see [sd wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#xyz-plot) !
currently supported
`xsr` - search replace text/emoji in the prompt, more info [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-sr)
`xds` - denoise strength, only valid for img2img
`xsteps` - steps
**note** limit the overall `steps:` to lower value (10-20) for big xyz plots
""")
app.run()