feat: 添加聚类算法和 Q 学习算法
- 新增 clustering_algorithms.py 文件,实现 K 均值和 K 中心点聚类算法 - 新增 route.py 文件,实现基于 Q 学习的路径规划算法 - 添加分析聚类结果和可视化功能 - 实现并行 Q 学习训练,提高训练效率
This commit is contained in:
commit
161b7a0eea
77
clustering_algorithms.py
Normal file
77
clustering_algorithms.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# 导入必要的库
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.cluster import KMeans
|
||||||
|
from sklearn_extra.cluster import KMedoids
|
||||||
|
from sklearn.metrics import silhouette_score
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib
|
||||||
|
import os
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
# Set LOKY_MAX_CPU_COUNT explicitly to silence the joblib warning.
# psutil.cpu_count(logical=False) may return None when the physical core
# count cannot be determined; fall back to the logical count (and finally 1)
# so the variable never ends up holding the literal string "None".
_physical_cores = psutil.cpu_count(logical=False) or os.cpu_count() or 1
os.environ['LOKY_MAX_CPU_COUNT'] = str(_physical_cores)
|
||||||
|
|
||||||
|
# Configure Matplotlib so that Chinese text can be rendered in figures.
matplotlib.rcParams.update({
    'font.sans-serif': ['SimHei'],  # use the SimHei (bold) typeface
    'axes.unicode_minus': False,    # work around minus-sign display issues
})
|
||||||
|
|
||||||
|
# Load the Iris dataset once at import time.
iris = load_iris()
X, y = iris.data, iris.target  # feature data and target labels
|
||||||
|
|
||||||
|
# K-means clustering
def kmeans_clustering(X, n_clusters=3):
    """
    Cluster the data with the K-means algorithm.

    :param X: input data
    :param n_clusters: number of clusters
    :return: cluster labels
    """
    # Fixed random_state keeps the result reproducible between runs.
    model = KMeans(n_clusters=n_clusters, random_state=42)
    return model.fit(X).labels_
|
||||||
|
|
||||||
|
# K-medoids clustering
def kmedoids_clustering(X, n_clusters=3):
    """
    Cluster the data with the K-medoids algorithm.

    :param X: input data
    :param n_clusters: number of clusters
    :return: cluster labels
    """
    # Fixed random_state keeps the result reproducible between runs.
    model = KMedoids(n_clusters=n_clusters, random_state=42)
    return model.fit(X).labels_
|
||||||
|
|
||||||
|
# Clustering analysis
def analyze_clustering(X, labels, algorithm_name):
    """
    Report the silhouette score of a clustering and plot the result.

    Prints the score, then shows a scatter plot of the first two columns
    of X coloured by cluster label.

    :param X: input data
    :param labels: cluster labels
    :param algorithm_name: algorithm name used in the printed line and the
        plot title
    """
    score = silhouette_score(X, labels)
    print(f"{algorithm_name} 轮廓系数: {score}")

    # Only the first two feature columns are drawn.
    first_col, second_col = X[:, 0], X[:, 1]
    plt.scatter(first_col, second_col, c=labels, cmap='viridis')
    plt.title(f"{algorithm_name} 聚类结果")
    plt.show()
|
||||||
|
|
||||||
|
# Entry point
def main():
    """Run K-means and then K-medoids on the Iris features, analysing each."""
    runs = (
        ("K均值", kmeans_clustering),
        ("K中心点", kmedoids_clustering),
    )
    for title, cluster in runs:
        analyze_clustering(X, cluster(X), title)


if __name__ == "__main__":
    main()
|
BIN
myplot-1.png
Normal file
BIN
myplot-1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
BIN
myplot-2.png
Normal file
BIN
myplot-2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
97
route.py
Normal file
97
route.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import numpy as np # Import NumPy for numerical operations
|
||||||
|
import random # Import random module for exploration
|
||||||
|
import multiprocessing as mp # Import multiprocessing for parallel computing
|
||||||
|
#import cupy as cp
|
||||||
|
|
||||||
|
|
||||||
|
# Optional dependency: record whether CuPy can be imported.
try:
    import cupy as cp
except ImportError:
    GPU_AVAILABLE = False
else:
    GPU_AVAILABLE = True
|
||||||
|
|
||||||
|
# Grid-world layout. Positions are (row, column) tuples.
# NOTE(review): with this obstacle set the cells reachable from `start` are
# only (0, 0), (0, 1), (0, 2) and (1, 1), so `end` cannot be reached --
# confirm the intended layout.
grid_size = (4, 4)
start, end = (0, 0), (3, 3)
obstacles = {(1, 0), (2, 1), (1, 2), (0, 3), (3, 2)}

# Each action name maps to the (row, column) offset it applies.
actions = {
    'up': (-1, 0),
    'down': (1, 0),
    'left': (0, -1),
    'right': (0, 1),
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid(state):
    """Return True when `state` lies inside the grid and is not an obstacle."""
    if not 0 <= state[0] < grid_size[0]:
        return False
    if not 0 <= state[1] < grid_size[1]:
        return False
    return state not in obstacles
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_state(state, action):
    """Apply `action` to `state`; stay in place when the move is not valid."""
    d_row, d_col = actions[action]
    candidate = (state[0] + d_row, state[1] + d_col)
    if is_valid(candidate):
        return candidate
    return state
|
||||||
|
|
||||||
|
|
||||||
|
# Q-learning hyperparameters
alpha, gamma = 0.5, 0.9  # learning rate, discount factor
epsilon = 0.1            # probability of picking a random action
episodes = 5000          # episode budget; each worker runs episodes // cpu_count

# Seed both random number generators for reproducibility.
random.seed(42)
np.random.seed(42)
|
||||||
|
|
||||||
|
# Every non-obstacle cell of the grid is a state of the Q-table.
grid_states = [
    (row, col)
    for row in range(grid_size[0])
    for col in range(grid_size[1])
    if (row, col) not in obstacles
]

# Action names as a list, in the insertion order of `actions`.
actions_list = list(actions)

# Q-table: state -> {action -> value}, initialised to 0.0 everywhere.
# The table is a plain Python dict whether or not CuPy is importable (nothing
# is placed on a GPU), so it is built once, with float zeros, and
# `actions_list` is always defined.
Q = {state: dict.fromkeys(actions_list, 0.0) for state in grid_states}
|
||||||
|
|
||||||
|
def train_q_learning(_, max_steps=None):
    """Train a private copy of the Q-table and return it.

    Used as the worker function for ``Pool.map``: each call copies the
    module-level ``Q`` table, runs ``episodes // mp.cpu_count()`` episodes of
    tabular Q-learning with an epsilon-greedy policy, and returns the copy.

    :param _: worker index supplied by ``Pool.map``; when it is an int it is
        used to derive a distinct, reproducible random seed for this worker.
    :param max_steps: upper bound on the number of steps in a single episode;
        defaults to ``100 * len(grid_states)``.
    :return: dict mapping state -> {action -> learned Q-value}
    """
    # The module seeds the global generator with a fixed value at import time,
    # so every worker (forked or re-imported) would draw the same random
    # sequence and return an identical table. A per-worker generator keeps
    # runs reproducible while making the workers explore differently.
    worker_seed = 42 + _ if isinstance(_, int) else 42
    rng = random.Random(worker_seed)

    # An episode previously ran until the goal was reached, with no bound, so
    # it never finished when the goal is unreachable from the start (which is
    # the case for the obstacle layout defined in this file).
    if max_steps is None:
        max_steps = 100 * len(grid_states)

    action_names = list(actions)
    local_Q = {state: Q[state].copy() for state in grid_states}

    # NOTE(review): the share is computed from mp.cpu_count(), while the
    # caller starts cpu_count // 2 workers, so fewer than `episodes` episodes
    # are run in total -- confirm whether that is intended.
    for _episode in range(episodes // mp.cpu_count()):
        state = start
        for _step in range(max_steps):
            if state == end:
                break
            q_row = local_Q[state]
            # Epsilon-greedy: exploit the best known action, else explore.
            if rng.random() > epsilon:
                action = max(q_row, key=q_row.get)
            else:
                action = rng.choice(action_names)
            next_state = get_next_state(state, action)
            reward = 1 if next_state == end else -0.1
            # Q-learning update towards reward + discounted best next value.
            best_next = max(local_Q[next_state].values())
            q_row[action] += alpha * (reward + gamma * best_next - q_row[action])
            state = next_state
    return local_Q
|
||||||
|
|
||||||
|
# Parallel Q-learning training
if __name__ == "__main__":
    # Use half of the CPU cores, but always at least one worker.
    num_processes = max(1, mp.cpu_count() // 2)
    with mp.Pool(num_processes) as pool:
        # One task per worker; each task returns its own Q-table.
        results = pool.map(train_q_learning, range(num_processes))

    # Merge: replace every Q-value with the mean over the workers' tables.
    n_tables = len(results)
    for state in grid_states:
        merged = Q[state]
        for action in actions:
            merged[action] = sum(table[state][action] for table in results) / n_tables
|
||||||
|
|
||||||
|
|
||||||
|
# Extract a path from start towards end using the learned Q-values
def get_best_path():
    """Follow the highest-valued action from `start` and return the path.

    The walk stops when `end` is reached or when a state is about to be
    expanded a second time, so the returned list may not finish at `end`.

    :return: list of states, beginning with `start`
    """
    current = start
    path = [current]
    seen = set()
    # Stop at the goal, or as soon as a state repeats, to avoid looping forever.
    while current != end and current not in seen:
        seen.add(current)
        q_row = Q[current]
        best_action = max(q_row, key=q_row.get)
        current = get_next_state(current, best_action)
        path.append(current)
    return path
|
||||||
|
|
||||||
|
|
||||||
|
# Only report the path when run as a script. Under the "spawn" start method
# each worker process re-imports this module, and an unguarded print would
# also run there, against the untrained Q-table.
if __name__ == "__main__":
    print(get_best_path())  # Print the path found from the learned Q-values
|
Loading…
Reference in New Issue
Block a user