Compare commits

..

No commits in common. "1aacd87547bafb3b6fb24c4e2573cb240a86a6d6" and "31907abe334da7a59bf6726ac096a0078a7d5546" have entirely different histories.

3 changed files with 83 additions and 91 deletions

5
.gitignore vendored
View File

@ -1,8 +1,5 @@
*.png *.png
.env .env
*.session .session
vscode/ vscode/
venv/ venv/
*.session-journal
logs/stable_diff_telegram_bot.log
*.session

167
main.py
View File

@ -15,7 +15,7 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
API_ID = os.environ.get("API_ID", None) API_ID = os.environ.get("API_ID", None)
API_HASH = os.environ.get("API_HASH", None) API_HASH = os.environ.get("API_HASH", None)
TOKEN = os.environ.get("TOKEN_givemtxt2img", None) TOKEN = os.environ.get("TOKEN", None)
SD_URL = os.environ.get("SD_URL", None) SD_URL = os.environ.get("SD_URL", None)
app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN) app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN)
@ -24,6 +24,9 @@ IMAGE_PATH = 'images' # Do not leave a trailing /
# Ensure IMAGE_PATH directory exists # Ensure IMAGE_PATH directory exists
os.makedirs(IMAGE_PATH, exist_ok=True) os.makedirs(IMAGE_PATH, exist_ok=True)
# Default params
steps_value_default = 40
def timestamp(): def timestamp():
return datetime.now().strftime("%Y%m%d-%H%M%S") return datetime.now().strftime("%Y%m%d-%H%M%S")
@ -38,21 +41,38 @@ def decode_and_save_base64(base64_str, save_path):
def parse_input(input_string): def parse_input(input_string):
default_payload = { default_payload = {
"prompt": "", "prompt": "",
"negative_prompt": "ugly, bad face, distorted", "negative_prompt": "",
"controlnet_input_image": [],
"controlnet_mask": [],
"controlnet_module": "",
"controlnet_model": "",
"controlnet_weight": 1,
"controlnet_resize_mode": "Scale to Fit (Inner Fit)",
"controlnet_lowvram": False,
"controlnet_processor_res": 64,
"controlnet_threshold_a": 64,
"controlnet_threshold_b": 64,
"controlnet_guidance": 1,
"controlnet_guessmode": True,
"enable_hr": False, "enable_hr": False,
"Sampler": "DPM++ SDE Karras", "denoising_strength": 0.4,
"denoising_strength": 0.35, "hr_scale": 1.5,
"hr_upscale": "Latent",
"seed": -1,
"subseed": -1,
"subseed_strength": -1,
"sampler_index": "",
"batch_size": 1, "batch_size": 1,
"n_iter": 1, "n_iter": 1,
"steps": 35, "steps": 35,
"cfg_scale": 7, "cfg_scale": 7,
"width": 512, "width": 512,
"height": 512, "height": 512,
"restore_faces": False, "restore_faces": True,
"override_settings": {}, "override_settings": {},
"override_settings_restore_afterwards": True, "override_settings_restore_afterwards": True,
} }
payload = default_payload.copy() payload = {"prompt": ""}
prompt = [] prompt = []
matches = re.finditer(r"(\w+):", input_string) matches = re.finditer(r"(\w+):", input_string)
@ -65,34 +85,32 @@ def parse_input(input_string):
if last_index != match.start(): if last_index != match.start():
prompt.append(input_string[last_index: match.start()].strip()) prompt.append(input_string[last_index: match.start()].strip())
last_index = value_start_index last_index = value_start_index
if key == "ds":
key = "denoising_strength"
if key == "ng":
key = "negative_prompt"
if key in default_payload: if key in default_payload:
value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:]).start() value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:]).start()
value = input_string[value_start_index: value_start_index + value_end_index].strip() value = input_string[value_start_index: value_start_index + value_end_index].strip()
payload[key] = value
if isinstance(default_payload[key], int):
if value.isdigit():
payload[key] = idefault_payloadnt(value)
else:
payload[key] = value
last_index += value_end_index last_index += value_end_index
else: else:
prompt.append(f"{key}:") prompt.append(f"{key}:")
payload["prompt"] = " ".join(prompt) payload["prompt"] = " ".join(prompt)
if not payload["prompt"]: if not payload["prompt"]:
payload["prompt"] = input_string.strip() payload["prompt"] = input_string.strip()
return payload return payload
def call_api(api_endpoint, payload): def call_api(api_endpoint, payload):
try: response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload)
response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload) response.raise_for_status()
response.raise_for_status() return response.json()
return response.json()
except requests.RequestException as e:
print(f"API call failed: {e}")
return None
def process_images(images, user_id, user_name): def process_images(images, user_id, user_name):
def generate_unique_name(): def generate_unique_name():
@ -127,41 +145,37 @@ def draw(client, message):
K = message.reply_text("Please Wait 10-15 Seconds") K = message.reply_text("Please Wait 10-15 Seconds")
r = call_api('sdapi/v1/txt2img', payload) r = call_api('sdapi/v1/txt2img', payload)
if r: for i in r["images"]:
for i in r["images"]: word, info = process_images([i], message.from_user.id, message.from_user.first_name)
word, info = process_images([i], message.from_user.id, message.from_user.first_name)
seed_value = info.split(", Seed: ")[1].split(",")[0] seed_value = info.split(", Seed: ")[1].split(",")[0]
caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n" caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n"
for key, value in payload.items(): for key, value in payload.items():
caption += f"{key.capitalize()} - **{value}**\n" caption += f"{key.capitalize()} - **{value}**\n"
caption += f"Seed - **{seed_value}**\n" caption += f"Seed - **{seed_value}**\n"
# Ensure caption is within the allowed length message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
if len(caption) > 1024:
caption = caption[:1021] + "..."
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
K.delete()
else:
message.reply_text("Failed to generate image. Please try again later.")
K.delete() K.delete()
@app.on_message(filters.command(["img"])) @app.on_message(filters.command(["img"]))
def img2img(client, message): def img2img(client, message):
if not message.reply_to_message or not message.reply_to_message.photo: if not message.reply_to_message or not message.reply_to_message.photo:
message.reply_text("reply to an image with \n`/img < prompt > ds:0-1.0`\n\nds stand for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4") message.reply_text("""
return reply to an image with
/img <prompt> ds:0-1.0
msgs = message.text.split(" ", 1) ds is for Denoising_strength. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4
print(msgs)
if len(msgs) == 1:
message.reply_text("""Format :\n/img < prompt >\nforce: < 0.1-1.0, default 0.3 >
""") """)
return return
payload = parse_input(" ".join(msgs[1:])) msgs = message.text.split(" ", 1)
if len(msgs) == 1:
message.reply_text("""
Format :\n/img <prompt>
force: < 0.1-1.0, default 0.3 >
""")
return
payload = parse_input(msgs[1])
print(payload) print(payload)
photo = message.reply_to_message.photo photo = message.reply_to_message.photo
photo_file = app.download_media(photo) photo_file = app.download_media(photo)
@ -169,65 +183,46 @@ def img2img(client, message):
os.remove(photo_file) # Clean up downloaded image file os.remove(photo_file) # Clean up downloaded image file
payload["init_images"] = [init_image] payload["init_images"] = [init_image]
payload["denoising_strength"] = 0.3 # Set default denoising strength or customize as needed
K = message.reply_text("Please Wait 10-15 Seconds") K = message.reply_text("Please Wait 10-15 Seconds")
r = call_api('sdapi/v1/img2img', payload) r = call_api('sdapi/v1/img2img', payload)
if r: for i in r["images"]:
for i in r["images"]: word, info = process_images([i], message.from_user.id, message.from_user.first_name)
word, info = process_images([i], message.from_user.id, message.from_user.first_name)
caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n" caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n"
prompt = payload["prompt"] prompt = payload["prompt"]
caption += f"**{prompt}**\n" caption += f"**{prompt}**\n"
message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption)
K.delete()
else:
message.reply_text("Failed to process image. Please try again later.")
K.delete() K.delete()
@app.on_message(filters.command(["getmodels"])) @app.on_message(filters.command(["getmodels"]))
async def get_models(client, message): async def get_models(client, message):
try: response = requests.get(f"{SD_URL}/sdapi/v1/sd-models")
response = requests.get(f"{SD_URL}/sdapi/v1/sd-models") response.raise_for_status()
response.raise_for_status() models_json = response.json()
models_json = response.json()
buttons = [ buttons = [
[InlineKeyboardButton(model["title"], callback_data=model["model_name"])] [InlineKeyboardButton(model["title"], callback_data=model["model_name"])]
for model in models_json for model in models_json
] ]
await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(buttons)) await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(buttons))
except requests.RequestException as e:
await message.reply_text(f"Failed to get models: {e}")
@app.on_callback_query() @app.on_callback_query()
async def process_callback(client, callback_query): async def process_callback(client, callback_query):
sd_model_checkpoint = callback_query.data sd_model_checkpoint = callback_query.data
options = {"sd_model_checkpoint": sd_model_checkpoint} options = {"sd_model_checkpoint": sd_model_checkpoint}
try: response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options)
response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options) response.raise_for_status()
response.raise_for_status()
await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
except requests.RequestException as e:
await callback_query.message.reply_text(f"Failed to set checkpoint: {e}")
# @app.on_message(filters.command(["start"], prefixes=["/", "!"])) await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}")
# async def start(client, message):
# buttons = [[InlineKeyboardButton("Add to your group", url="https://t.me/gootmornbot?startgroup=true")]]
# await message.reply_text("Hello!\nAsk me to imagine anything\n\n/draw text to image", reply_markup=InlineKeyboardMarkup(buttons))
user_interactions = {} @app.on_message(filters.command(["start"], prefixes=["/", "!"]))
async def start(client, message):
@app.on_message(filters.command(["user_stats"])) buttons = [[InlineKeyboardButton("Add to your group", url="https://t.me/gootmornbot?startgroup=true")]]
def user_stats(client, message): await message.reply_text("Hello!\nAsk me to imagine anything\n\n/draw text to image", reply_markup=InlineKeyboardMarkup(buttons))
stats = "User Interactions:\n\n"
for user_id, info in user_interactions.items():
stats += f"User: {info['username']} (ID: {user_id})\n"
stats += f"Commands: {', '.join(info['commands'])}\n\n"
message.reply_text(stats)
app.run() app.run()

View File

@ -1,4 +1,4 @@
pyrogram pyrogram==1.4.16
requests requests
tgcrypto==1.2.2 tgcrypto==1.2.2
Pillow Pillow