AIrefactor

This commit is contained in:
tami-p40 2024-05-18 14:06:29 +03:00
parent 9e73316ab8
commit c8e825b247
2 changed files with 30 additions and 220 deletions

View File

@ -1,163 +0,0 @@
from datetime import datetime
import urllib.request
import base64
import json
import time
import os
url="pop-os.local"
webui_server_url = f'http://{url}:7860'
out_dir = 'api_out'
out_dir_t2i = os.path.join(out_dir, 'txt2img')
out_dir_i2i = os.path.join(out_dir, 'img2img')
os.makedirs(out_dir_t2i, exist_ok=True)
os.makedirs(out_dir_i2i, exist_ok=True)
def timestamp():
return datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
def encode_file_to_base64(path):
with open(path, 'rb') as file:
return base64.b64encode(file.read()).decode('utf-8')
def decode_and_save_base64(base64_str, save_path):
with open(save_path, "wb") as file:
file.write(base64.b64decode(base64_str))
def call_api(api_endpoint, **payload):
data = json.dumps(payload).encode('utf-8')
request = urllib.request.Request(
f'{webui_server_url}/{api_endpoint}',
headers={'Content-Type': 'application/json'},
data=data,
)
response = urllib.request.urlopen(request)
return json.loads(response.read().decode('utf-8'))
def call_txt2img_api(**payload):
response = call_api('sdapi/v1/txt2img', **payload)
for index, image in enumerate(response.get('images')):
save_path = os.path.join(out_dir_t2i, f'txt2img-{timestamp()}-{index}.png')
decode_and_save_base64(image, save_path)
def call_img2img_api(**payload):
response = call_api('sdapi/v1/img2img', **payload)
for index, image in enumerate(response.get('images')):
save_path = os.path.join(out_dir_i2i, f'img2img-{timestamp()}-{index}.png')
decode_and_save_base64(image, save_path)
if __name__ == '__main__':
payload = {
"prompt": "masterpiece, (best quality:1.1), 1girl <lora:lora_model:1>", # extra networks also in prompts
"negative_prompt": "",
"seed": 1,
"steps": 20,
"width": 512,
"height": 512,
"cfg_scale": 7,
"sampler_name": "DPM++ 2M",
"n_iter": 1,
"batch_size": 1,
# example args for x/y/z plot
# "script_name": "x/y/z plot",
# "script_args": [
# 1,
# "10,20",
# [],
# 0,
# "",
# [],
# 0,
# "",
# [],
# True,
# True,
# False,
# False,
# 0,
# False
# ],
# example args for Refiner and ControlNet
# "alwayson_scripts": {
# "ControlNet": {
# "args": [
# {
# "batch_images": "",
# "control_mode": "Balanced",
# "enabled": True,
# "guidance_end": 1,
# "guidance_start": 0,
# "image": {
# "image": encode_file_to_base64(r"B:\path\to\control\img.png"),
# "mask": None # base64, None when not need
# },
# "input_mode": "simple",
# "is_ui": True,
# "loopback": False,
# "low_vram": False,
# "model": "control_v11p_sd15_canny [d14c016b]",
# "module": "canny",
# "output_dir": "",
# "pixel_perfect": False,
# "processor_res": 512,
# "resize_mode": "Crop and Resize",
# "threshold_a": 100,
# "threshold_b": 200,
# "weight": 1
# }
# ]
# },
# "Refiner": {
# "args": [
# True,
# "sd_xl_refiner_1.0",
# 0.5
# ]
# }
# },
# "enable_hr": True,
# "hr_upscaler": "R-ESRGAN 4x+ Anime6B",
# "hr_scale": 2,
# "denoising_strength": 0.5,
# "styles": ['style 1', 'style 2'],
# "override_settings": {
# 'sd_model_checkpoint': "sd_xl_base_1.0", # this can use to switch sd model
# },
}
call_txt2img_api(**payload)
init_images = [
encode_file_to_base64(r"../stable-diffusion-webui/output/img2img-images/2024-05-15/00012-357584826.png"),
# encode_file_to_base64(r"B:\path\to\img_2.png"),
# "https://image.can/also/be/a/http/url.png",
]
batch_size = 2
payload = {
"prompt": "1girl, blue hair",
"seed": 1,
"steps": 20,
"width": 512,
"height": 512,
"denoising_strength": 0.5,
"n_iter": 1,
"init_images": init_images,
"batch_size": batch_size if len(init_images) == 1 else len(init_images),
# "mask": encode_file_to_base64(r"B:\path\to\mask.png")
}
# if len(init_images) > 1 then batch_size should be == len(init_images)
# else if len(init_images) == 1 then batch_size can be any value int >= 1
call_img2img_api(**payload)
# there exist a useful extension that allows converting of webui calls to api payload
# particularly useful when you wish setup arguments of extensions and scripts
# https://github.com/huchenlei/sd-webui-api-payload-display

87
main.py
View File

@ -1,10 +1,9 @@
import json
import requests
import io
import re
import os import os
import re
import io
import uuid import uuid
import base64 import base64
import requests
from datetime import datetime from datetime import datetime
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from pyrogram import Client, filters from pyrogram import Client, filters
@ -13,13 +12,13 @@ from dotenv import load_dotenv
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
API_ID = os.environ.get("API_ID", None) API_ID = os.environ.get("API_ID")
API_HASH = os.environ.get("API_HASH", None) API_HASH = os.environ.get("API_HASH")
TOKEN = os.environ.get("TOKEN_givemtxt2img", None) TOKEN = os.environ.get("TOKEN_givemtxt2img")
SD_URL = os.environ.get("SD_URL", None) SD_URL = os.environ.get("SD_URL")
app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN) app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
IMAGE_PATH = 'images' # Do not leave a trailing / IMAGE_PATH = 'images'
# Ensure IMAGE_PATH directory exists # Ensure IMAGE_PATH directory exists
os.makedirs(IMAGE_PATH, exist_ok=True) os.makedirs(IMAGE_PATH, exist_ok=True)
@ -71,10 +70,10 @@ def parse_input(input_string):
key = "negative_prompt" key = "negative_prompt"
if key in default_payload: if key in default_payload:
value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:]).start() value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:])
value = input_string[value_start_index: value_start_index + value_end_index].strip() value = input_string[value_start_index: value_start_index + value_end_index.start()].strip()
payload[key] = value payload[key] = value
last_index += value_end_index last_index += value_end_index.start()
else: else:
prompt.append(f"{key}:") prompt.append(f"{key}:")
@ -84,7 +83,6 @@ def parse_input(input_string):
return payload return payload
def call_api(api_endpoint, payload): def call_api(api_endpoint, payload):
try: try:
response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload) response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload)
@ -103,7 +101,6 @@ def process_images(images, user_id, user_name):
for i in images: for i in images:
image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0]))) image = Image.open(io.BytesIO(base64.b64decode(i.split(",", 1)[0])))
png_payload = {"image": "data:image/png;base64," + i} png_payload = {"image": "data:image/png;base64," + i}
response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload) response2 = requests.post(f"{SD_URL}/sdapi/v1/png-info", json=png_payload)
response2.raise_for_status() response2.raise_for_status()
@ -114,6 +111,18 @@ def process_images(images, user_id, user_name):
return word, response2.json().get("info") return word, response2.json().get("info")
def create_caption(payload, user_name, user_id, info):
seed_value = info.split(", Seed: ")[1].split(",")[0]
caption = f"**[{user_name}](tg://user?id={user_id})**\n\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
caption += f"Seed - **{seed_value}**\n"
if len(caption) > 1024:
caption = caption[:1021] + "..."
return caption
@app.on_message(filters.command(["draw"])) @app.on_message(filters.command(["draw"]))
def draw(client, message): def draw(client, message):
msgs = message.text.split(" ", 1) msgs = message.text.split(" ", 1)
@ -130,21 +139,9 @@ def draw(client, message):
if r: if r:
for i in r["images"]: for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name) word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
seed_value = info.split(", Seed: ")[1].split(",")[0]
caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n"
# for key, value in payload.items():
# caption += f"{key.capitalize()} - **{value}**\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
caption += f"Seed - **{seed_value}**\n"
# Ensure caption is within the allowed length
if len(caption) > 1024:
caption = caption[:1021] + "..."
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
K.delete() K.delete()
else: else:
message.reply_text("Failed to generate image. Please try again later.") message.reply_text("Failed to generate image. Please try again later.")
K.delete() K.delete()
@ -152,19 +149,15 @@ def draw(client, message):
@app.on_message(filters.command(["img"])) @app.on_message(filters.command(["img"]))
def img2img(client, message): def img2img(client, message):
if not message.reply_to_message or not message.reply_to_message.photo: if not message.reply_to_message or not message.reply_to_message.photo:
message.reply_text("reply to an image with \n`/img < prompt > ds:0-1.0`\n\nds stand for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4") message.reply_text("Reply to an image with\n`/img < prompt > ds:0-1.0`\n\nds stands for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4")
return return
msgs = message.text.split(" ", 1) msgs = message.text.split(" ", 1)
print(msgs)
if len(msgs) == 1: if len(msgs) == 1:
message.reply_text("""Format :\n/img < prompt >\nforce: < 0.1-1.0, default 0.3 > message.reply_text("Format :\n/img < prompt >\nforce: < 0.1-1.0, default 0.3 >")
""")
return return
payload = parse_input(" ".join(msgs[1:])) payload = parse_input(msgs[1])
print(payload)
photo = message.reply_to_message.photo photo = message.reply_to_message.photo
photo_file = app.download_media(photo) photo_file = app.download_media(photo)
init_image = encode_file_to_base64(photo_file) init_image = encode_file_to_base64(photo_file)
@ -178,13 +171,9 @@ def img2img(client, message):
if r: if r:
for i in r["images"]: for i in r["images"]:
word, info = process_images([i], message.from_user.id, message.from_user.first_name) word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = create_caption(payload, message.from_user.first_name, message.from_user.id, info)
caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n"
prompt = payload["prompt"]
caption += f"**{prompt}**\n"
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
K.delete() K.delete()
else: else:
message.reply_text("Failed to process image. Please try again later.") message.reply_text("Failed to process image. Please try again later.")
K.delete() K.delete()
@ -216,20 +205,4 @@ async def process_callback(client, callback_query):
except requests.RequestException as e: except requests.RequestException as e:
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}") await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
# @app.on_message(filters.command(["start"], prefixes=["/", "!"]))
# async def start(client, message):
# buttons = [[InlineKeyboardButton("Add to your group", url="https://t.me/gootmornbot?startgroup=true")]]
# await message.reply_text("Hello!\nAsk me to imagine anything\n\n/draw text to image", reply_markup=InlineKeyboardMarkup(buttons))
user_interactions = {}
@app.on_message(filters.command(["user_stats"]))
def user_stats(client, message):
stats = "User Interactions:\n\n"
for user_id, info in user_interactions.items():
stats += f"User: {info['username']} (ID: {user_id})\n"
stats += f"Commands: {', '.join(info['commands'])}\n\n"
message.reply_text(stats)
app.run() app.run()