diff --git a/main.py b/main.py index 36b8c2e..2c42960 100644 --- a/main.py +++ b/main.py @@ -15,7 +15,7 @@ from dotenv import load_dotenv load_dotenv() API_ID = os.environ.get("API_ID", None) API_HASH = os.environ.get("API_HASH", None) -TOKEN = os.environ.get("TOKEN", None) +TOKEN = os.environ.get("TOKEN_givemtxt2img", None) SD_URL = os.environ.get("SD_URL", None) app = Client("stable", api_id=API_ID, api_hash=API_HASH, bot_token=TOKEN) @@ -24,9 +24,6 @@ IMAGE_PATH = 'images' # Do not leave a trailing / # Ensure IMAGE_PATH directory exists os.makedirs(IMAGE_PATH, exist_ok=True) -# Default params -steps_value_default = 40 - def timestamp(): return datetime.now().strftime("%Y%m%d-%H%M%S") @@ -41,38 +38,21 @@ def decode_and_save_base64(base64_str, save_path): def parse_input(input_string): default_payload = { "prompt": "", - "negative_prompt": "", - "controlnet_input_image": [], - "controlnet_mask": [], - "controlnet_module": "", - "controlnet_model": "", - "controlnet_weight": 1, - "controlnet_resize_mode": "Scale to Fit (Inner Fit)", - "controlnet_lowvram": False, - "controlnet_processor_res": 64, - "controlnet_threshold_a": 64, - "controlnet_threshold_b": 64, - "controlnet_guidance": 1, - "controlnet_guessmode": True, + "negative_prompt": "ugly, bad face, distorted", "enable_hr": False, - "denoising_strength": 0.4, - "hr_scale": 1.5, - "hr_upscale": "Latent", - "seed": -1, - "subseed": -1, - "subseed_strength": -1, - "sampler_index": "", + "Sampler": "DPM++ SDE Karras", + "denoising_strength": 0.35, "batch_size": 1, "n_iter": 1, "steps": 35, "cfg_scale": 7, "width": 512, "height": 512, - "restore_faces": True, + "restore_faces": False, "override_settings": {}, "override_settings_restore_afterwards": True, } - payload = {"prompt": ""} + payload = default_payload.copy() prompt = [] matches = re.finditer(r"(\w+):", input_string) @@ -90,9 +70,15 @@ def parse_input(input_string): value_end_index = re.search(r"(?=\s+\w+:|$)", input_string[value_start_index:]).start() value = input_string[value_start_index: value_start_index + value_end_index].strip() - if isinstance(default_payload[key], int): - if value.isdigit(): - payload[key] = idefault_payloadnt(value) + if key == "denoising_strength": + try: + ds_value = float(value) + if 0.0 <= ds_value <= 1.0: + payload[key] = ds_value + else: + payload[key] = default_payload[key] + except ValueError: + payload[key] = default_payload[key] else: payload[key] = value @@ -101,16 +87,20 @@ def parse_input(input_string): prompt.append(f"{key}:") payload["prompt"] = " ".join(prompt) - if not payload["prompt"]: payload["prompt"] = input_string.strip() return payload + def call_api(api_endpoint, payload): - response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload) - response.raise_for_status() - return response.json() + try: + response = requests.post(f'{SD_URL}/{api_endpoint}', json=payload) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + print(f"API call failed: {e}") + return None def process_images(images, user_id, user_name): def generate_unique_name(): @@ -145,37 +135,41 @@ def draw(client, message): K = message.reply_text("Please Wait 10-15 Seconds") r = call_api('sdapi/v1/txt2img', payload) - for i in r["images"]: - word, info = process_images([i], message.from_user.id, message.from_user.first_name) + if r: + for i in r["images"]: + word, info = process_images([i], message.from_user.id, message.from_user.first_name) - seed_value = info.split(", Seed: ")[1].split(",")[0] - caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n" - for key, value in payload.items(): - caption += f"{key.capitalize()} - **{value}**\n" - caption += f"Seed - **{seed_value}**\n" + seed_value = info.split(", Seed: ")[1].split(",")[0] + caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n" + for key, value in payload.items(): + caption += f"{key.capitalize()} - **{value}**\n" + caption += f"Seed - **{seed_value}**\n" - message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) + # Ensure caption is within the allowed length + if len(caption) > 1024: + caption = caption[:1021] + "..." + + message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) + K.delete() + else: + message.reply_text("Failed to generate image. Please try again later.") K.delete() @app.on_message(filters.command(["img"])) def img2img(client, message): if not message.reply_to_message or not message.reply_to_message.photo: - message.reply_text(""" - reply to an image with - /img ds:0-1.0 - ds is for Denoising_strength. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4 - """) + message.reply_text("reply to an image with \n`/img < prompt > ds:0-1.0`\n\nds stand for `Denoising_strength` parameter. Set that low (like 0.2) if you just want to slightly change things. defaults to 0.4") return msgs = message.text.split(" ", 1) + print(msgs) + if len(msgs) == 1: - message.reply_text(""" - Format :\n/img - force: < 0.1-1.0, default 0.3 > + message.reply_text("""Format :\n/img < prompt >\nforce: < 0.1-1.0, default 0.3 > """) return - payload = parse_input(msgs[1]) + payload = parse_input(" ".join(msgs[1:])) print(payload) photo = message.reply_to_message.photo photo_file = app.download_media(photo) @@ -183,46 +177,65 @@ def img2img(client, message): os.remove(photo_file) # Clean up downloaded image file payload["init_images"] = [init_image] - payload["denoising_strength"] = 0.3 # Set default denoising strength or customize as needed K = message.reply_text("Please Wait 10-15 Seconds") r = call_api('sdapi/v1/img2img', payload) - for i in r["images"]: - word, info = process_images([i], message.from_user.id, message.from_user.first_name) + if r: + for i in r["images"]: + word, info = process_images([i], message.from_user.id, message.from_user.first_name) - caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n" - prompt = payload["prompt"] - caption += f"**{prompt}**\n" + caption = f"**[{message.from_user.first_name}](tg://user?id={message.from_user.id})**\n\n" + prompt = payload["prompt"] + caption += f"**{prompt}**\n" - message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) + message.reply_photo(photo=f"{IMAGE_PATH}/{word}.png", caption=caption) + K.delete() + else: + message.reply_text("Failed to process image. Please try again later.") K.delete() @app.on_message(filters.command(["getmodels"])) async def get_models(client, message): - response = requests.get(f"{SD_URL}/sdapi/v1/sd-models") - response.raise_for_status() - models_json = response.json() + try: + response = requests.get(f"{SD_URL}/sdapi/v1/sd-models") + response.raise_for_status() + models_json = response.json() - buttons = [ - [InlineKeyboardButton(model["title"], callback_data=model["model_name"])] - for model in models_json - ] - await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(buttons)) + buttons = [ + [InlineKeyboardButton(model["title"], callback_data=model["model_name"])] + for model in models_json + ] + await message.reply_text("Select a model [checkpoint] to use", reply_markup=InlineKeyboardMarkup(buttons)) + except requests.RequestException as e: + await message.reply_text(f"Failed to get models: {e}") @app.on_callback_query() async def process_callback(client, callback_query): sd_model_checkpoint = callback_query.data options = {"sd_model_checkpoint": sd_model_checkpoint} - response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options) - response.raise_for_status() + try: + response = requests.post(f"{SD_URL}/sdapi/v1/options", json=options) + response.raise_for_status() + await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}") + except requests.RequestException as e: + await callback_query.message.reply_text(f"Failed to set checkpoint: {e}") - await callback_query.message.reply_text(f"Checkpoint set to {sd_model_checkpoint}") +# @app.on_message(filters.command(["start"], prefixes=["/", "!"])) +# async def start(client, message): +# buttons = [[InlineKeyboardButton("Add to your group", url="https://t.me/gootmornbot?startgroup=true")]] +# await message.reply_text("Hello!\nAsk me to imagine anything\n\n/draw text to image", reply_markup=InlineKeyboardMarkup(buttons)) -@app.on_message(filters.command(["start"], prefixes=["/", "!"])) -async def start(client, message): - buttons = [[InlineKeyboardButton("Add to your group", url="https://t.me/gootmornbot?startgroup=true")]] - await message.reply_text("Hello!\nAsk me to imagine anything\n\n/draw text to image", reply_markup=InlineKeyboardMarkup(buttons)) +user_interactions = {} + +@app.on_message(filters.command(["user_stats"])) +def user_stats(client, message): + stats = "User Interactions:\n\n" + for user_id, info in user_interactions.items(): + stats += f"User: {info['username']} (ID: {user_id})\n" + stats += f"Commands: {', '.join(info['commands'])}\n\n" + + message.reply_text(stats) app.run()