Compare commits

...

2 Commits

Author SHA1 Message Date
tami-p40
3fa661a54c qa 2024-05-18 20:41:30 +03:00
tami-p40
c8e825b247 AIrefactor 2024-05-18 14:06:29 +03:00
4 changed files with 108 additions and 245 deletions

View File

@ -47,6 +47,11 @@ Set that low (like 0.2) if you just want to slightly change things. defaults to
basically anything the `/controlnet/img2img` API payload supports basically anything the `/controlnet/img2img` API payload supports
### general
`X/Y/Z script` [link](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#xyz-plot), one powerful thing
for prompts we use the Search Replace option (a.k.a. `prompt s/r`) [explained](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-sr)
## Setup ## Setup
Install requirements using venv Install requirements using venv

View File

@ -1,163 +0,0 @@
from datetime import datetime
import urllib.request
import base64
import json
import time
import os
# Host running the Stable Diffusion WebUI server (assumes the server was
# started with --api enabled — TODO confirm).
url="pop-os.local"
webui_server_url = f'http://{url}:7860'
# Output directories for saved images; created eagerly so later save calls
# never fail on a missing directory.
out_dir = 'api_out'
out_dir_t2i = os.path.join(out_dir, 'txt2img')
out_dir_i2i = os.path.join(out_dir, 'img2img')
os.makedirs(out_dir_t2i, exist_ok=True)
os.makedirs(out_dir_i2i, exist_ok=True)
def timestamp():
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``.

    Used to build unique-ish output filenames for saved images.
    """
    # datetime.now() is the direct idiom; the original took the detour
    # datetime.fromtimestamp(time.time()) which is equivalent but noisier.
    return datetime.now().strftime("%Y%m%d-%H%M%S")
def encode_file_to_base64(path):
    """Read the file at *path* and return its contents as a base64 string."""
    with open(path, 'rb') as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode('utf-8')
def decode_and_save_base64(base64_str, save_path):
    """Decode *base64_str* and write the resulting raw bytes to *save_path*."""
    raw = base64.b64decode(base64_str)
    with open(save_path, "wb") as sink:
        sink.write(raw)
def call_api(api_endpoint, **payload):
    """POST *payload* as JSON to ``{webui_server_url}/{api_endpoint}``.

    Returns the decoded JSON response body as a dict.  Raises
    ``urllib.error.URLError`` / ``HTTPError`` on connection or HTTP failure.
    """
    data = json.dumps(payload).encode('utf-8')
    request = urllib.request.Request(
        f'{webui_server_url}/{api_endpoint}',
        headers={'Content-Type': 'application/json'},
        data=data,
    )
    # Use a context manager so the HTTP response (and its socket) is always
    # closed; the original leaked the connection on every call.
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode('utf-8'))
def call_txt2img_api(**payload):
    """Run a txt2img generation and save each returned image to disk.

    Images land in ``out_dir_t2i`` as PNGs named with a timestamp and their
    index within the batch.
    """
    response = call_api('sdapi/v1/txt2img', **payload)
    # Fall back to an empty list: response.get('images') is None when the
    # server returns no images, and enumerate(None) would raise TypeError.
    for index, image in enumerate(response.get('images') or []):
        save_path = os.path.join(out_dir_t2i, f'txt2img-{timestamp()}-{index}.png')
        decode_and_save_base64(image, save_path)
def call_img2img_api(**payload):
    """Run an img2img generation and save each returned image to disk.

    Images land in ``out_dir_i2i`` as PNGs named with a timestamp and their
    index within the batch.
    """
    response = call_api('sdapi/v1/img2img', **payload)
    # Fall back to an empty list: response.get('images') is None when the
    # server returns no images, and enumerate(None) would raise TypeError.
    for index, image in enumerate(response.get('images') or []):
        save_path = os.path.join(out_dir_i2i, f'img2img-{timestamp()}-{index}.png')
        decode_and_save_base64(image, save_path)
if __name__ == '__main__':
    # --- txt2img demo: build a payload and generate one image. -------------
    payload = {
        "prompt": "masterpiece, (best quality:1.1), 1girl <lora:lora_model:1>",  # extra networks also go in prompts
        "negative_prompt": "",
        "seed": 1,
        "steps": 20,
        "width": 512,
        "height": 512,
        "cfg_scale": 7,
        "sampler_name": "DPM++ 2M",
        "n_iter": 1,
        "batch_size": 1,
        # example args for x/y/z plot
        # "script_name": "x/y/z plot",
        # "script_args": [
        #     1,
        #     "10,20",
        #     [],
        #     0,
        #     "",
        #     [],
        #     0,
        #     "",
        #     [],
        #     True,
        #     True,
        #     False,
        #     False,
        #     0,
        #     False
        # ],
        # example args for Refiner and ControlNet
        # "alwayson_scripts": {
        #     "ControlNet": {
        #         "args": [
        #             {
        #                 "batch_images": "",
        #                 "control_mode": "Balanced",
        #                 "enabled": True,
        #                 "guidance_end": 1,
        #                 "guidance_start": 0,
        #                 "image": {
        #                     "image": encode_file_to_base64(r"B:\path\to\control\img.png"),
        #                     "mask": None  # base64 string, or None when not needed
        #                 },
        #                 "input_mode": "simple",
        #                 "is_ui": True,
        #                 "loopback": False,
        #                 "low_vram": False,
        #                 "model": "control_v11p_sd15_canny [d14c016b]",
        #                 "module": "canny",
        #                 "output_dir": "",
        #                 "pixel_perfect": False,
        #                 "processor_res": 512,
        #                 "resize_mode": "Crop and Resize",
        #                 "threshold_a": 100,
        #                 "threshold_b": 200,
        #                 "weight": 1
        #             }
        #         ]
        #     },
        #     "Refiner": {
        #         "args": [
        #             True,
        #             "sd_xl_refiner_1.0",
        #             0.5
        #         ]
        #     }
        # },
        # "enable_hr": True,
        # "hr_upscaler": "R-ESRGAN 4x+ Anime6B",
        # "hr_scale": 2,
        # "denoising_strength": 0.5,
        # "styles": ['style 1', 'style 2'],
        # "override_settings": {
        #     'sd_model_checkpoint': "sd_xl_base_1.0",  # this can be used to switch the sd model
        # },
    }
    call_txt2img_api(**payload)
    # --- img2img demo: re-generate from an existing image. -----------------
    init_images = [
        encode_file_to_base64(r"../stable-diffusion-webui/output/img2img-images/2024-05-15/00012-357584826.png"),
        # encode_file_to_base64(r"B:\path\to\img_2.png"),
        # "https://image.can/also/be/a/http/url.png",
    ]
    batch_size = 2
    payload = {
        "prompt": "1girl, blue hair",
        "seed": 1,
        "steps": 20,
        "width": 512,
        "height": 512,
        "denoising_strength": 0.5,
        "n_iter": 1,
        "init_images": init_images,
        "batch_size": batch_size if len(init_images) == 1 else len(init_images),
        # "mask": encode_file_to_base64(r"B:\path\to\mask.png")
    }
    # if len(init_images) > 1 then batch_size should be == len(init_images)
    # else if len(init_images) == 1 then batch_size can be any int >= 1
    call_img2img_api(**payload)
    # There exists a useful extension that converts webui UI calls into an API
    # payload — particularly handy when you wish to set up arguments of
    # extensions and scripts:
    # https://github.com/huchenlei/sd-webui-api-payload-display

View File

@ -62,29 +62,36 @@ if __name__ == '__main__':
"width": 512, "width": 512,
"height": 512, "height": 512,
"cfg_scale": 7, "cfg_scale": 7,
"sampler_name": "DPM++ 2M", "sampler_name": "DPM++ SDE Karras",
"n_iter": 1, "n_iter": 1,
"batch_size": 1, "batch_size": 1,
# example args for x/y/z plot # example args for x/y/z plot
# "script_name": "x/y/z plot", #steps 4,"20,30"
# "script_args": [ #denoising==22
# 1, # S/R 7,"X,united states,china",
# "10,20", "script_args": [
# [], 4,
# 0, "20,30,40",
# "", [],
# [], 0,
# 0, "",
# "", [],
# [], 0,
# True, "",
# True, [],
# False, True,
# False, False,
# 0, False,
# False False,
# ], False,
False,
False,
0,
False
],
"script_name": "x/y/z plot",
# example args for Refiner and ControlNet # example args for Refiner and ControlNet
# "alwayson_scripts": { # "alwayson_scripts": {

130
main.py
View File

@ -1,10 +1,9 @@
import json
import requests
import io
import re
import os import os
import re
import io
import uuid import uuid
import base64 import base64
import requests
from datetime import datetime from datetime import datetime
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from pyrogram import Client, filters from pyrogram import Client, filters
@ -13,13 +12,13 @@ from dotenv import load_dotenv
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
API_ID = os.environ.get("API_ID", None) API_ID = os.environ.get("API_ID")
API_HASH = os.environ.get("API_HASH", None) API_HASH = os.environ.get("API_HASH")
TOKEN = os.environ.get("TOKEN_givemtxt2img", None) TOKEN = os.environ.get("TOKEN_givemtxt2img")
SD_URL = os.environ.get("SD_URL", None) SD_URL = os.environ.get("SD_URL")
app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN) app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
IMAGE_PATH = 'images' # Do not leave a trailing / IMAGE_PATH = 'images'
# Ensure IMAGE_PATH directory exists # Ensure IMAGE_PATH directory exists
os.makedirs(IMAGE_PATH, exist_ok=True) os.makedirs(IMAGE_PATH, exist_ok=True)
@ -36,6 +35,7 @@ def decode_and_save_base64(base64_str, save_path):
file.write(base64.b64decode(base64_str)) file.write(base64.b64decode(base64_str))
def parse_input(input_string): def parse_input(input_string):
# Set default payload values
default_payload = { default_payload = {
"prompt": "", "prompt": "",
"negative_prompt": "ugly, bad face, distorted", "negative_prompt": "ugly, bad face, distorted",
@ -58,33 +58,73 @@ def parse_input(input_string):
matches = re.finditer(r"(\w+):", input_string) matches = re.finditer(r"(\w+):", input_string)
last_index = 0 last_index = 0
# Initialize script_args with default values and placeholder enums
script_args = [0, "", [], 0, "", [], 0, "", [], True, False, False, False, False, False, False, 0, False]
script_name = None
for match in matches: for match in matches:
key = match.group(1).lower() key = match.group(1).lower()
value_start_index = match.end() value_start_index = match.end()
if last_index != match.start(): if last_index != match.start():
prompt.append(input_string[last_index: match.start()].strip()) prompt.append(input_string[last_index: match.start()].strip())
last_index = value_start_index last_index = value_start_index
value_end_match = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:])
if value_end_match:
value_end_index = value_end_match.start() + value_start_index
else:
value_end_index = len(input_string)
value = input_string[value_start_index: value_end_index].strip()
if key == "ds": if key == "ds":
key = "denoising_strength" key = "denoising_strength"
if key == "ng": if key == "ng":
key = "negative_prompt" key = "negative_prompt"
if key in default_payload: if key in default_payload:
value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:]).start()
value = input_string[value_start_index: value_start_index + value_end_index].strip()
payload[key] = value payload[key] = value
last_index += value_end_index elif key in ["xsr", "xsteps", "xds"]:
script_name = "x/y/z plot"
if key == "xsr":
script_args[0] = 7 # Enum value for xsr
script_args[1] = value
elif key == "xsteps":
try:
steps_values = [int(x) for x in value.split(',')]
if all(1 <= x <= 70 for x in steps_values):
script_args[3] = 4 # Enum value for xsteps
script_args[4] = value
else:
raise ValueError("xsteps values must be between 1 and 70.")
except ValueError:
raise ValueError("xsteps must contain only integers.")
elif key == "xds":
script_args[6] = 22 # Enum value for xds
script_args[7] = value
else: else:
prompt.append(f"{key}:") prompt.append(f"{key}:{value}")
payload["prompt"] = " ".join(prompt) last_index = value_end_index
payload["prompt"] = " ".join(prompt).strip()
if not payload["prompt"]: if not payload["prompt"]:
payload["prompt"] = input_string.strip() payload["prompt"] = input_string.strip()
if script_name:
payload["script_name"] = script_name
payload["script_args"] = script_args
return payload return payload
def create_caption(payload, user_name, user_id, info):
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
if len(caption) > 1024:
caption = caption[:1021] + "..."
return caption
def call_api(api_endpoint, payload): def call_api(api_endpoint, payload):
try: try:
response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload) response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload)
@ -103,7 +143,6 @@ def process_images(images, user_id, user_name):
for i in images: for i in images:
image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0]))) image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0])))
png_payload = {"image": "data:image/png;base64," + i} png_payload = {"image": "data:image/png;base64," + i}
response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload) response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
response2.raise_for_status() response2.raise_for_status()
@ -130,21 +169,9 @@ def draw(client, message):
if r: if r:
for i in r["images"]: for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name) word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
seed_value = info.split(", Seed: ")[1].split(",")[0]
caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n"
# for key, value in payload.items():
# caption += f"{key.capitalize()} - **{value}**\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
caption += f"Seed - **{seed_value}**\n"
# Ensure caption is within the allowed length
if len(caption) > 1024:
caption = caption[:1021] + "..."
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
K.delete() K.delete()
else: else:
message.reply_text("Failed to generate image. Please try again later.") message.reply_text("Failed to generate image. Please try again later.")
K.delete() K.delete()
@ -152,19 +179,15 @@ def draw(client, message):
@app.on_message(filters.command(["img"])) @app.on_message(filters.command(["img"]))
def img2img(client, message): def img2img(client, message):
if not message.reply_to_message or not message.reply_to_message.photo: if not message.reply_to_message or not message.reply_to_message.photo:
message.reply_text("reply to an image with \n`/img < prompt > ds:0-1.0`\n\nds stand for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4") message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.35\n\nExample: `/img murder on the dance floor ds:0.2`")
return return
msgs = message.text.split(" ", 1) msgs = message.text.split(" ", 1)
print(msgs)
if len(msgs) == 1: if len(msgs) == 1:
message.reply_text("""Format :\n/img < prompt >\nforce: < 0.1-1.0, default 0.3 > message.reply_text("dont FAIL in life")
""")
return return
payload = parse_input(" ".join(msgs[1:])) payload = parse_input(msgs[1])
print(payload)
photo = message.reply_to_message.photo photo = message.reply_to_message.photo
photo_file = app.download_media(photo) photo_file = app.download_media(photo)
init_image = encode_file_to_base64(photo_file) init_image = encode_file_to_base64(photo_file)
@ -178,13 +201,9 @@ def img2img(client, message):
if r: if r:
for i in r["images"]: for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name) word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
K.delete() K.delete()
else: else:
message.reply_text("Failed to process image. Please try again later.") message.reply_text("Failed to process image. Please try again later.")
K.delete() K.delete()
@ -216,20 +235,15 @@ async def process_callback(client, callback_query):
except requests.RequestException as e: except requests.RequestException as e:
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}") await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
# @app.on_message(filters.command(["start"], prefixes=["/", "!"])) @app.on_message(filters.command(["info_sd_bot"]))
# async def start(client, message): async def info(client, message):
# buttons = [[InlineKeyboardButton("Add to your group", url="https://t.me/gootmornbot?startgroup=true")]] await message.reply_text("""
# await message.reply_text("Hello!\nAsk me to imagine anything\n\n/draw text to image", reply_markup=InlineKeyboardMarkup(buttons)) now support for xyz scripts, see [sd wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#xyz-plot) !
currently supported
user_interactions = {} `xsr` - search replace text/emoji in the prompt, more info [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-sr)
`xds` - denoise strength, only valid for img2img
@app.on_message(filters.command(["user_stats"])) `xsteps` - steps
def user_stats(client, message): **note** limit the overall `steps:` to lower value (10-20) for big xyz plots
stats = "User Interactions:\n\n" """)
for user_id, info in user_interactions.items():
stats += f"User: {info['username']} (ID: {user_id})\n"
stats += f"Commands: {', '.join(info['commands'])}\n\n"
message.reply_text(stats)
app.run() app.run()