AIrefactor

tami-p40 2024-05-18 14:06:29 +03:00
parent 9e73316ab8
commit c8e825b247
2 changed files with 30 additions and 220 deletions


@@ -1,163 +0,0 @@
from datetime import datetime
import urllib.request
import base64
import json
import time
import os
url="pop-os.local"
webui_server_url = f'http://{url}:7860'
out_dir = 'api_out'
out_dir_t2i = os.path.join(out_dir, 'txt2img')
out_dir_i2i = os.path.join(out_dir, 'img2img')
os.makedirs(out_dir_t2i, exist_ok=True)
os.makedirs(out_dir_i2i, exist_ok=True)
def timestamp():
return datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
def encode_file_to_base64(path):
with open(path, 'rb') as file:
return base64.b64encode(file.read()).decode('utf-8')
def decode_and_save_base64(base64_str, save_path):
with open(save_path, "wb") as file:
file.write(base64.b64decode(base64_str))
def call_api(api_endpoint, **payload):
data = json.dumps(payload).encode('utf-8')
request = urllib.request.Request(
f'{webui_server_url}/{api_endpoint}',
headers={'Content-Type': 'application/json'},
data=data,
)
response = urllib.request.urlopen(request)
return json.loads(response.read().decode('utf-8'))
def call_txt2img_api(**payload):
response = call_api('sdapi/v1/txt2img', **payload)
    for index, image in enumerate(response.get('images', [])):
save_path = os.path.join(out_dir_t2i, f'txt2img-{timestamp()}-{index}.png')
decode_and_save_base64(image, save_path)
def call_img2img_api(**payload):
response = call_api('sdapi/v1/img2img', **payload)
    for index, image in enumerate(response.get('images', [])):
save_path = os.path.join(out_dir_i2i, f'img2img-{timestamp()}-{index}.png')
decode_and_save_base64(image, save_path)
if __name__ == '__main__':
payload = {
"prompt": "masterpiece, (best quality:1.1), 1girl <lora:lora_model:1>", # extra networks also in prompts
"negative_prompt": "",
"seed": 1,
"steps": 20,
"width": 512,
"height": 512,
"cfg_scale": 7,
"sampler_name": "DPM++ 2M",
"n_iter": 1,
"batch_size": 1,
# example args for x/y/z plot
# "script_name": "x/y/z plot",
# "script_args": [
# 1,
# "10,20",
# [],
# 0,
# "",
# [],
# 0,
# "",
# [],
# True,
# True,
# False,
# False,
# 0,
# False
# ],
# example args for Refiner and ControlNet
# "alwayson_scripts": {
# "ControlNet": {
# "args": [
# {
# "batch_images": "",
# "control_mode": "Balanced",
# "enabled": True,
# "guidance_end": 1,
# "guidance_start": 0,
# "image": {
# "image": encode_file_to_base64(r"B:\path\to\control\img.png"),
# "mask": None # base64, None when not need
# },
# "input_mode": "simple",
# "is_ui": True,
# "loopback": False,
# "low_vram": False,
# "model": "control_v11p_sd15_canny [d14c016b]",
# "module": "canny",
# "output_dir": "",
# "pixel_perfect": False,
# "processor_res": 512,
# "resize_mode": "Crop and Resize",
# "threshold_a": 100,
# "threshold_b": 200,
# "weight": 1
# }
# ]
# },
# "Refiner": {
# "args": [
# True,
# "sd_xl_refiner_1.0",
# 0.5
# ]
# }
# },
# "enable_hr": True,
# "hr_upscaler": "R-ESRGAN 4x+ Anime6B",
# "hr_scale": 2,
# "denoising_strength": 0.5,
# "styles": ['style 1', 'style 2'],
# "override_settings": {
# 'sd_model_checkpoint': "sd_xl_base_1.0", # this can be used to switch the SD model
# },
}
call_txt2img_api(**payload)
init_images = [
encode_file_to_base64(r"../stable-diffusion-webui/output/img2img-images/2024-05-15/00012-357584826.png"),
# encode_file_to_base64(r"B:\path\to\img_2.png"),
# "https://image.can/also/be/a/http/url.png",
]
batch_size = 2
payload = {
"prompt": "1girl, blue hair",
"seed": 1,
"steps": 20,
"width": 512,
"height": 512,
"denoising_strength": 0.5,
"n_iter": 1,
"init_images": init_images,
"batch_size": batch_size if len(init_images) == 1 else len(init_images),
# "mask": encode_file_to_base64(r"B:\path\to\mask.png")
}
# if len(init_images) > 1, batch_size must equal len(init_images);
# if len(init_images) == 1, batch_size can be any integer >= 1 (see the sketch after this listing)
call_img2img_api(**payload)
# there is a useful extension that converts webui calls into API payloads,
# particularly handy when you want to set up arguments for extensions and scripts:
# https://github.com/huchenlei/sd-webui-api-payload-display
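
A minimal sketch (not part of this commit) of the batch_size rule described in the comments above; the helper name is hypothetical.

def resolve_batch_size(init_images, requested=1):
    # several init images: the API pairs one output per input,
    # so batch_size must equal len(init_images)
    if len(init_images) > 1:
        return len(init_images)
    # a single init image accepts any batch_size >= 1
    return max(1, requested)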

main.py

@@ -1,10 +1,9 @@
import json
import requests
import io
import re
import os
import re
import io
import uuid
import base64
import requests
from datetime import datetime
from PIL import Image, PngImagePlugin
from pyrogram import Client, filters
@@ -13,13 +12,13 @@ from dotenv import load_dotenv
# Load environment variables
load_dotenv()
API_ID = os.environ.get("API_ID", None)
API_HASH = os.environ.get("API_HASH", None)
TOKEN = os.environ.get("TOKEN_givemtxt2img", None)
SD_URL = os.environ.get("SD_URL", None)
API_ID = os.environ.get("API_ID")
API_HASH = os.environ.get("API_HASH")
TOKEN = os.environ.get("TOKEN_givemtxt2img")
SD_URL = os.environ.get("SD_URL")
app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
IMAGE_PATH = 'images' # Do not leave a trailing /
IMAGE_PATH = 'images'
# Ensure IMAGE_PATH directory exists
os.makedirs(IMAGE_PATH, exist_ok=True)
@@ -71,10 +70,10 @@ def parse_input(input_string):
key = "negative_prompt"
if key in default_payload:
value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:]).start()
value = input_string[value_start_index: value_start_index + value_end_index].strip()
value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:])
value = input_string[value_start_index: value_start_index + value_end_index.start()].strip()
payload[key] = value
last_index += value_end_index
last_index += value_end_index.start()
else:
prompt.append(f"{key}:")
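
An illustration (not part of the diff) of how the lookahead regex above bounds a value: it matches at the position where the next `key:` token, or the end of the string, begins.

import re

rest = "30 seed:42"  # text following a parsed "steps:" key
match = re.search(r"(?=\s+\w+:|$)", rest)
value = rest[:match.start()].strip()
print(value)  # -> "30"; the value ends where " seed:" begins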
@@ -84,7 +83,6 @@ def parse_input(input_string):
return payload
def call_api(api_endpoint, payload):
try:
response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload)
@@ -103,7 +101,6 @@ def process_images(images, user_id, user_name):
for i in images:
image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0])))
png_payload = {"image": "data:image/png;base64," + i}
response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
response2.raise_for_status()
@@ -114,6 +111,18 @@ def process_images(images, user_id, user_name):
return word, response2.json().get("info")
def create_caption(payload, user_name, user_id, info):
seed_value = info.split(", Seed: ")[1].split(",")[0]
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
caption += f"Seed - **{seed_value}**\n"
if len(caption) > 1024:
caption = caption[:1021] + "..."
return caption
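
Usage sketch for the new create_caption helper; the info string and user values are made-up examples shaped like the png-info output it parses.

info = "Steps: 20, Sampler: DPM++ 2M, CFG scale: 7, Seed: 357584826, Size: 512x512"
caption = create_caption({"prompt": "1girl, blue hair"}, "alice", 12345, info)
# caption links the user, bolds the prompt, appends "Seed - 357584826",
# and is truncated to Telegram's 1024-character photo-caption limit if needed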
@app.on_message(filters.command(["draw"]))
def draw(client, message):
msgs = message.text.split(" ", 1)
@@ -130,19 +139,7 @@ def draw(client, message):
if r:
for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name)
seed_value = info.split(", Seed: ")[1].split(",")[0]
caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n"
# for key, value in payload.items():
# caption += f"{key.capitalize()} - **{value}**\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
caption += f"Seed - **{seed_value}**\n"
# Ensure caption is within the allowed length
if len(caption) > 1024:
caption = caption[:1021] + "..."
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
K.delete()
else:
@@ -152,19 +149,15 @@ def draw(client, message):
@app.on_message(filters.command(["img"]))
def img2img(client, message):
if not message.reply_to_message or not message.reply_to_message.photo:
message.reply_text("reply to an image with \n`/img < prompt > ds:0-1.0`\n\nds stand for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4")
message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4")
return
msgs = message.text.split(" ", 1)
print(msgs)
if len(msgs) == 1:
message.reply_text("""Format :\n/img < prompt >\nforce: < 0.1-1.0, default 0.3 >
""")
message.reply_text("Format :\n/img < prompt >\nforce: < 0.1-1.0, default 0.3 >")
return
payload = parse_input(" ".join(msgs[1:]))
print(payload)
payload = parse_input(msgs[1])
photo = message.reply_to_message.photo
photo_file = app.download_media(photo)
init_image = encode_file_to_base64(photo_file)
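
A sketch (assumed, not part of the diff) of the payload this handler builds once the reply photo is encoded; that parse_input maps the ds shorthand to denoising_strength is inferred from the help text above, not shown in the diff.

# user replies to a photo with: /img a cat on a roof ds:0.2
payload = parse_input("a cat on a roof ds:0.2")  # ds mapping assumed
payload["init_images"] = [init_image]            # the encoded reply photo
# expected shape (assumption): {"prompt": "a cat on a roof",
#                               "denoising_strength": 0.2, "init_images": [...]}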
@@ -178,11 +171,7 @@ def img2img(client, message):
if r:
for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
K.delete()
else:
@@ -216,20 +205,4 @@ async def process_callback(client, callback_query):
except requests.RequestException as e:
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
# @app.on_message(filters.command(["start"], prefixes=["/", "!"]))
# async def start(client, message):
# buttons = [[InlineKeyboardButton("Add to your group", url="https://t.me/gootmornbot?startgroup=true")]]
# await message.reply_text("Hello!\nAsk me to imagine anything\n\n/draw text to image", reply_markup=InlineKeyboardMarkup(buttons))
user_interactions = {}
@app.on_message(filters.command(["user_stats"]))
def user_stats(client, message):
stats = "User Interactions:\n\n"
for user_id, info in user_interactions.items():
stats += f"User: {info['username']} (ID: {user_id})\n"
stats += f"Commands: {', '.join(info['commands'])}\n\n"
message.reply_text(stats)
app.run()
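
A hypothetical helper (not in this commit) showing the entry shape the user_stats handler above expects when it reads info['username'] and info['commands'].

def record_interaction(user_id, username, command):
    # assumed structure, matching what user_stats iterates over
    entry = user_interactions.setdefault(user_id, {"username": username, "commands": []})
    entry["commands"].append(command)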