-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enhancement: use scikit-learn to get the most distinct color for the …
…embed
- Loading branch information
Showing
2 changed files
with
28 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,46 @@ | ||
import requests | ||
from PIL import Image | ||
import requests | ||
from io import BytesIO | ||
from colorthief import ColorThief | ||
import numpy as np | ||
from sklearn.cluster import KMeans | ||
|
||
def get_discord_color(image_url, border_percentage=0.2): | ||
def get_discord_color(image_url, num_colors=5, crop_percentage=0.5): | ||
""" | ||
Get the most common color from an image for Discord usage. | ||
Get the most distinct color from the center of an image for Discord usage. | ||
Args: | ||
image_url (str): The URL of the image to analyze. | ||
border_percentage (float): The percentage of border to exclude when cropping the image (default is 0.2). | ||
num_colors (int): Number of dominant colors to extract (default is 5). | ||
crop_percentage (float): Percentage of the image to keep in the center (default is 0.5). | ||
Returns: | ||
int: The most common color in hexadecimal format. | ||
int: The most distinct color in hexadecimal format. | ||
Raises: | ||
(Exception): If there is an issue with fetching or processing the image. | ||
""" | ||
response = requests.get(image_url) | ||
img = Image.open(BytesIO(response.content)).convert('RGB') | ||
|
||
# Crop the image to exclude borders | ||
img = Image.open(BytesIO(response.content)) | ||
# Calculate the crop dimensions | ||
width, height = img.size | ||
img = img.crop((width * border_percentage, height * border_percentage, width * (1 - border_percentage), height * (1 - border_percentage))) | ||
crop_width = int(width * crop_percentage) | ||
crop_height = int(height * crop_percentage) | ||
left = (width - crop_width) // 2 | ||
top = (height - crop_height) // 2 | ||
right = left + crop_width | ||
bottom = top + crop_height | ||
|
||
img = img.crop((left, top, right, bottom)) | ||
|
||
img = img.resize((img.width // 2, img.height // 2)) # Resize for faster processing | ||
|
||
color_thief = ColorThief(BytesIO(response.content)) | ||
palette = color_thief.get_palette(color_count=4, quality=4) | ||
img_array = np.array(img) | ||
img_flattened = img_array.reshape(-1, 3) | ||
|
||
img_arr = np.array(img) | ||
color_counts = {tuple(color): np.sum(np.all(img_arr == color, axis=-1)) for color in palette} | ||
kmeans = KMeans(n_clusters=num_colors) | ||
kmeans.fit(img_flattened) | ||
|
||
# Find the most common color | ||
most_common_color = max(color_counts, key=color_counts.get) | ||
dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))] | ||
|
||
return int('0x{:02x}{:02x}{:02x}'.format(*most_common_color), 16) | ||
return int('0x{:02x}{:02x}{:02x}'.format(*dominant_color.astype(int)), 16) |