get model
This commit is contained in:
		
							parent
							
								
									49c9f6337a
								
							
						
					
					
						commit
						ec900759a1
					
				
							
								
								
									
										43
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								main.py
									
									
									
									
									
								
							| @ -39,17 +39,32 @@ model_negative_prompts = { | ||||
|     "Juggernaut-XL_v9_RunDiffusionPhoto_v2": "bad eyes, cgi, airbrushed, plastic, watermark" | ||||
| } | ||||
| 
 | ||||
| def get_current_model_name(): | ||||
|     try: | ||||
|         response = requests.get(f"{SD_URL}/sdapi/v1/options") | ||||
|         response.raise_for_status() | ||||
|         options = response.json() | ||||
|         current_model_name = options.get("sd_model_checkpoint", "Unknown") | ||||
|         return current_model_name | ||||
|     except requests.RequestException as e: | ||||
|         print(f"API call failed: {e}") | ||||
|         return None | ||||
| 
 | ||||
| # Fetch the current model name at the start | ||||
| current_model_name = get_current_model_name() | ||||
| if current_model_name: | ||||
|     print(f"Current model name: {current_model_name}") | ||||
| else: | ||||
|     print("Failed to fetch the current model name.") | ||||
| 
 | ||||
| def encode_file_to_base64(path): | ||||
|     with open(path, 'rb') as file: | ||||
|         return base64.b64encode(file.read()).decode('utf-8') | ||||
| 
 | ||||
| 
 | ||||
| def decode_and_save_base64(base64_str, save_path): | ||||
|     with open(save_path, "wb") as file: | ||||
|         file.write(base64.b64decode(base64_str)) | ||||
| 
 | ||||
| 
 | ||||
| # Set default payload values | ||||
| default_payload = { | ||||
|     "prompt": "", | ||||
| @ -69,14 +84,12 @@ default_payload = { | ||||
|     "override_settings_restore_afterwards": True, | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| def update_negative_prompt(model_name): | ||||
|     """Update the negative prompt for a given model.""" | ||||
|     if model_name in model_negative_prompts: | ||||
|         suffix = model_negative_prompts[model_name] | ||||
|         default_payload["negative_prompt"] += f", {suffix}" | ||||
| 
 | ||||
| 
 | ||||
| def update_resolution(model_name): | ||||
|     """Update resolution based on the selected model.""" | ||||
|     if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2": | ||||
| @ -86,7 +99,6 @@ def update_resolution(model_name): | ||||
|         default_payload["width"] = 512 | ||||
|         default_payload["height"] = 512 | ||||
| 
 | ||||
| 
 | ||||
| def update_cfg_scale(model_name): | ||||
|     """Update CFG scale based on the selected model.""" | ||||
|     if model_name == "Juggernaut-XL_v9_RunDiffusionPhoto_v2": | ||||
| @ -94,6 +106,13 @@ def update_cfg_scale(model_name): | ||||
|     else: | ||||
|         default_payload["cfg_scale"] = 7 | ||||
| 
 | ||||
| # Update configurations based on the current model name | ||||
| if current_model_name: | ||||
|     update_negative_prompt(current_model_name) | ||||
|     update_resolution(current_model_name) | ||||
|     update_cfg_scale(current_model_name) | ||||
| else: | ||||
|     print("Failed to update configurations as the current model name is not available.") | ||||
| 
 | ||||
| def parse_input(input_string): | ||||
|     """Parse the input string and create a payload.""" | ||||
| @ -170,7 +189,6 @@ def parse_input(input_string): | ||||
| 
 | ||||
|     return payload, include_info | ||||
| 
 | ||||
| 
 | ||||
| def create_caption(payload, user_name, user_id, info, include_info): | ||||
|     """Create a caption for the generated image.""" | ||||
|     caption = f"**[{user_name}](tg://user?id={user_id})**\n\n" | ||||
| @ -194,7 +212,6 @@ def create_caption(payload, user_name, user_id, info, include_info): | ||||
| 
 | ||||
|     return caption | ||||
| 
 | ||||
| 
 | ||||
| def call_api(api_endpoint, payload): | ||||
|     """Call the API with the provided payload.""" | ||||
|     try: | ||||
| @ -205,12 +222,12 @@ def call_api(api_endpoint, payload): | ||||
|         print(f"API call failed: {e}") | ||||
|         return {"error": str(e)} | ||||
| 
 | ||||
| 
 | ||||
| def process_images(images, user_id, user_name): | ||||
|     """Process and save generated images.""" | ||||
|     def generate_unique_name(): | ||||
|         unique_id = str(uuid.uuid4())[:7] | ||||
|         return f"{user_name}-{unique_id}" | ||||
|         date = datetime.now().strftime("%Y-%m-%d-%H-%M") | ||||
|         return f"{date}-{user_name}-{unique_id}" | ||||
| 
 | ||||
|     word = generate_unique_name() | ||||
| 
 | ||||
| @ -227,9 +244,8 @@ def process_images(images, user_id, user_name): | ||||
|         # Save as JPG | ||||
|         jpg_path = f"{IMAGE_PATH}/{word}.jpg" | ||||
|         image.convert("RGB").save(jpg_path, "JPEG") | ||||
|          | ||||
|         return word, response2.json().get("info") | ||||
| 
 | ||||
|         return word, response2.json().get("info") | ||||
| 
 | ||||
| @app.on_message(filters.command(["draw"])) | ||||
| def draw(client, message): | ||||
| @ -259,7 +275,6 @@ def draw(client, message): | ||||
|         message.reply_text(error_message) | ||||
|         K.delete() | ||||
| 
 | ||||
| 
 | ||||
| @app.on_message(filters.command(["img"])) | ||||
| def img2img(client, message): | ||||
|     """Handle /img command to generate images from existing images.""" | ||||
| @ -294,7 +309,6 @@ def img2img(client, message): | ||||
|         message.reply_text(error_message) | ||||
|         K.delete() | ||||
| 
 | ||||
| 
 | ||||
| @app.on_message(filters.command(["getmodels"])) | ||||
| async def get_models(client, message): | ||||
|     """Handle /getmodels command to list available models.""" | ||||
| @ -310,7 +324,6 @@ async def get_models(client, message): | ||||
|     except requests.RequestException as e: | ||||
|         await message.reply_text(f"Failed to get models: {e}") | ||||
| 
 | ||||
| 
 | ||||
| @app.on_callback_query() | ||||
| async def process_callback(client, callback_query): | ||||
|     """Process model selection from callback queries.""" | ||||
| @ -330,7 +343,6 @@ async def process_callback(client, callback_query): | ||||
|         await callback_query.message.reply_text(f"Failed to set checkpoint: {e}") | ||||
|         print(f"Error setting checkpoint: {e}") | ||||
| 
 | ||||
| 
 | ||||
| @app.on_message(filters.command(["info_sd_bot"])) | ||||
| async def info(client, message): | ||||
|     """Provide information about the bot's commands and options.""" | ||||
| @ -391,5 +403,4 @@ For more details, visit the [Stable Diffusion Wiki](https://github.com/AUTOMATIC | ||||
| Enjoy creating with Stable Diffusion Bot! | ||||
| """, disable_web_page_preview=True) | ||||
| 
 | ||||
| 
 | ||||
| app.run() | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user