ML-exp-4/route.py
fly6516 161b7a0eea feat: add clustering algorithms and Q-learning algorithm
- Add clustering_algorithms.py, implementing K-means and K-medoids clustering
- Add route.py, implementing a Q-learning-based path-planning algorithm
- Add analysis of clustering results and visualization
- Implement parallel Q-learning training to improve training efficiency
2025-03-30 02:52:03 +08:00

97 lines
4.2 KiB
Python

import numpy as np # Import NumPy for numerical operations
import random # Import random module for exploration
import multiprocessing as mp # Import multiprocessing for parallel computing
try:
    import cupy as cp  # Import CuPy for GPU acceleration
    GPU_AVAILABLE = True  # Set flag if CuPy is available
except ImportError:
    GPU_AVAILABLE = False  # Set flag if CuPy is not available
# Define the environment
grid_size = (4, 4) # Grid size (4x4)
start = (0, 0) # Start position
end = (3, 3) # End position
obstacles = {(1, 0), (2, 1), (1, 2), (0, 3), (3, 2)} # Set of obstacles
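# Grid layout implied by the definitions above ('S' = start, 'E' = end, '#' = obstacle, '.' = free):
#   S . . #
#   # . # .
#   . # . .
#   . . # E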
# Define possible actions and their effects on position
actions = {'up': (-1, 0), 'down': (1, 0), 'left': (0, -1), 'right': (0, 1)}
def is_valid(state):
    """Check if a state is within the grid and not an obstacle."""
    return (0 <= state[0] < grid_size[0]) and (0 <= state[1] < grid_size[1]) and (state not in obstacles)

def get_next_state(state, action):
    """Get the next state based on the current state and action."""
    new_state = (state[0] + actions[action][0], state[1] + actions[action][1])
    return new_state if is_valid(new_state) else state
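# For example: get_next_state((0, 0), 'right') returns (0, 1), while
# get_next_state((0, 0), 'up') returns (0, 0) unchanged (the move would leave the grid),
# and get_next_state((0, 0), 'down') returns (0, 0) because (1, 0) is an obstacle.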
# Q-learning parameters
alpha = 0.5  # Learning rate
gamma = 0.9  # Discount factor
epsilon = 0.1  # Exploration rate for the ε-greedy policy
episodes = 5000  # Total training episodes across all processes
max_steps = 200  # Safety cap per episode: with this obstacle set the goal is unreachable from the start, so an uncapped episode would never terminate
np.random.seed(42)  # Set random seed for reproducibility
random.seed(42)
# Initialize the Q-table with all states and actions
grid_states = [(i, j) for i in range(grid_size[0]) for j in range(grid_size[1]) if (i, j) not in obstacles]
Q = {state: {action: 0 for action in actions} for state in grid_states}
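# The resulting table maps each free cell to per-action values, e.g.:
# Q[(0, 0)] == {'up': 0, 'down': 0, 'left': 0, 'right': 0}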
# GPU availability check (note: the Q-table below is still a host-side Python dict;
# CuPy is imported but not actually used by the training loop)
if GPU_AVAILABLE:
    Q = {state: {action: 0.0 for action in actions} for state in grid_states}  # Re-initialize with float values
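    # A minimal sketch of what genuinely GPU-backed storage could look like
    # (an assumption, not what this file does: one row per state, one column per action):
    # state_index = {s: i for i, s in enumerate(grid_states)}
    # Q_gpu = cp.zeros((len(grid_states), len(actions)), dtype=cp.float32)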
actions_list = list(actions.keys()) # Store actions as a list
num_processes = max(1, mp.cpu_count() // 2)  # Use half the available CPU cores (module level, so workers can see it)

def train_q_learning(worker_id):
    """Train Q-learning in one worker process; each worker runs its share of the episodes."""
    random.seed(42 + worker_id)  # Seed each worker differently so the workers do not replay identical episodes
    local_Q = {state: Q[state].copy() for state in grid_states}  # Create a local copy of the Q-table
    for _ in range(episodes // num_processes):  # Each process handles its share of the episodes
        state = start  # Start at the initial position
        for _ in range(max_steps):  # Capped so the episode ends even if the goal is never reached
            if state == end:  # Stop once the goal is reached
                break
            # Choose an action using the ε-greedy policy
            if random.uniform(0, 1) > epsilon:
                action = max(local_Q[state], key=local_Q[state].get)  # Exploit: best known action
            else:
                action = random.choice(actions_list)  # Explore: random action
            next_state = get_next_state(state, action)  # Get the next state
            reward = 1 if next_state == end else -0.1  # Reward: +1 at the goal, small step penalty otherwise
            # Update the Q-value using the Bellman update rule
            local_Q[state][action] += alpha * (
                reward + gamma * max(local_Q[next_state].values()) - local_Q[state][action])
            state = next_state  # Move to the next state
    return local_Q  # Return the updated local Q-table
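# Worked example of one update, using alpha = 0.5 and gamma = 0.9: if Q[s][a] is 0.0,
# the step reward is -0.1, and the best next-state value is 1.0, the new value is
# 0.0 + 0.5 * (-0.1 + 0.9 * 1.0 - 0.0) = 0.4.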
# Parallel Q-learning training
if __name__ == "__main__":
    with mp.Pool(num_processes) as pool:  # Create a process pool with half the CPU cores
        results = pool.map(train_q_learning, range(num_processes))  # Distribute training across the workers
    # Merge the Q-tables from all processes
    for state in grid_states:
        for action in actions:
            Q[state][action] = sum(r[state][action] for r in results) / len(results)  # Average Q-values
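    # e.g. if two workers report 0.3 and 0.5 for the same state-action pair,
    # the merged estimate is (0.3 + 0.5) / 2 = 0.4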
# Compute the optimal path from start to end
def get_best_path():
    """Find the best path using the learned Q-values."""
    state = start  # Start at the initial position
    path = [state]  # Initialize the path
    visited = set()  # Track visited states to avoid loops
    while state != end:
        if state in visited:
            break  # Avoid infinite loops when no improving move exists
        visited.add(state)  # Mark the state as visited
        action = max(Q[state], key=Q[state].get)  # Choose the best action based on Q-values
        state = get_next_state(state, action)  # Move to the next state
        path.append(state)  # Append to the path
    return path  # Return the computed path
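# --- Illustrative sketch, not part of the original file: render the grid and a path ---
# A minimal helper (hypothetical name render_path), assuming the module-level
# grid_size, start, end, and obstacles defined above.
def render_path(path):
    """Print an ASCII view of the grid, marking the visited cells with '*'."""
    cells = [['.'] * grid_size[1] for _ in range(grid_size[0])]
    for r, c in obstacles:
        cells[r][c] = '#'  # Mark obstacles
    for r, c in path:
        cells[r][c] = '*'  # Mark the path
    cells[start[0]][start[1]] = 'S'  # Start marker overrides the path marker
    cells[end[0]][end[1]] = 'E'  # End marker overrides the path marker
    print('\n'.join(' '.join(row) for row in cells))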
if __name__ == "__main__":
    print(get_best_path())  # Print the learned path (guarded so spawned workers do not execute it at import time)