# diff --git a/aikg/examples/run_evolve_island_example.py b/aikg/examples/run_evolve_island_example.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..609efb03801ce2f2673039bea5fe3a47c49a3959
# --- /dev/null
# +++ b/aikg/examples/run_evolve_island_example.py
# @@ -0,0 +1,302 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from pathlib import Path
import json
from datetime import datetime
from ai_kernel_generator.core.evolve_improved import evolve_with_islands
from ai_kernel_generator.core.async_pool.task_pool import TaskPool
from ai_kernel_generator.core.async_pool.device_pool import DevicePool
from ai_kernel_generator.config.config_validator import load_config
from ai_kernel_generator.utils.environment_check import check_env_for_task


def get_op_name():
    """Return the name of the operator used by this example."""
    return "aikg_matmul"


def get_task_desc():
    """Return the task description: source code of a reference torch matmul model."""
    return '''
import torch
import torch.nn as nn


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.matmul(a, b)


M, N, K = 512, 512, 512  # 矩阵维度


def get_inputs():
    a = torch.randn(M, K)
    b = torch.randn(K, N)
    return [a, b]


def get_init_inputs():
    return []  # No special initialization inputs needed
'''


def _format_profile(profile_data):
    """Render a profile value as a human-readable string.

    Args:
        profile_data: Either a ``(gen_time, base_time, speedup)`` sequence, a
            shorter sequence whose first item is a time, or a bare number.

    Returns:
        The formatted description, or the "N/A" text when no usable number is
        present (infinity, empty sequence, ``None``, ...).
    """
    if isinstance(profile_data, (list, tuple)):
        if len(profile_data) >= 3:
            gen_time, base_time, speedup = profile_data[0], profile_data[1], profile_data[2]
            return f"生成代码: {gen_time:.4f}s, 基准代码: {base_time:.4f}s, 加速比: {speedup:.2f}x"
        if len(profile_data) >= 1:
            return f"执行时间: {profile_data[0]:.4f}s"
        return "性能: N/A"
    if isinstance(profile_data, (int, float)) and profile_data != float('inf'):
        return f"执行时间: {profile_data:.4f}s"
    return "性能: N/A"


def _to_serializable_impl(impl):
    """Return a shallow copy of *impl* that is safe to dump as JSON.

    The ``task_info`` entry may hold objects that cannot be serialized, so only
    its key code/verification fields are lifted to the top level and the entry
    itself is dropped. A tuple ``profile`` is converted to a list.
    """
    slim = impl.copy()
    if 'task_info' in slim:
        task_info = slim.pop('task_info')
        if not isinstance(task_info, dict):
            # A missing/odd task_info must not abort saving the whole result.
            task_info = {}
        slim['designer_code'] = task_info.get('designer_code', '')
        slim['coder_code'] = task_info.get('coder_code', '')
        slim['task_desc'] = task_info.get('task_desc', '')
        slim['verifier_result'] = task_info.get('verifier_result', False)
        slim['verifier_error'] = task_info.get('verifier_error', '')
    if isinstance(slim.get('profile'), tuple):
        slim['profile'] = list(slim['profile'])
    return slim


def _to_serializable_result(evolution_result):
    """Return a JSON-serializable copy of the evolution result dictionary.

    Only ``best_implementations`` and the ``implementations`` of every entry in
    ``round_results`` are rewritten; all other keys are copied as they are.
    """
    serializable = evolution_result.copy()
    if 'best_implementations' in serializable:
        serializable['best_implementations'] = [
            _to_serializable_impl(impl) for impl in serializable['best_implementations']
        ]
    if 'round_results' in serializable:
        rounds = []
        for round_result in serializable['round_results']:
            round_copy = round_result.copy()
            if 'implementations' in round_copy:
                round_copy['implementations'] = [
                    _to_serializable_impl(impl) for impl in round_copy['implementations']
                ]
            rounds.append(round_copy)
        serializable['round_results'] = rounds
    return serializable


async def run_evolve_example():
    """Run the island-model evolutionary kernel generation example.

    Prints the configuration, runs ``evolve_with_islands``, prints a summary of
    the outcome and writes the result as JSON into the configured ``log_dir``.

    Returns:
        The result dictionary returned by ``evolve_with_islands``, or ``None``
        when that call returned an empty result.
    """
    # Basic task parameters.
    op_name = get_op_name()
    task_desc = get_task_desc()
    dsl = "triton"  # options: "triton", "swft"
    framework = "torch"  # options: "mindspore", "torch", "numpy"
    backend = "cuda"  # options: "ascend", "cuda"
    arch = "a100"  # architecture matching the chosen backend

    # Evolution parameters.
    max_rounds = 5  # number of evolution rounds
    parallel_num = 4  # parallel tasks per round

    # Island-model parameters (0 or 1 islands disables the island model).
    num_islands = 2  # number of islands
    migration_interval = 2  # rounds between migrations (0 disables migration)
    elite_size = 5  # number of elites kept (0 disables the elite mechanism)
    # Probability of taking the parent from the current island; otherwise the
    # parent island is chosen through the elite pool.
    parent_selection_prob = 0.7

    print("="*80)
    print("AI KERNEL GENERATOR - 带岛屿模型的进化式算子生成示例")
    print("="*80)
    print(f"算子名称: {op_name}")
    print(f"实现类型: {dsl}")
    print(f"框架: {framework}")
    print(f"后端: {backend}")
    print(f"架构: {arch}")
    # The two lines below were previously printed twice.
    print(f"进化轮数: {max_rounds}")
    print(f"并行任务数: {parallel_num}")

    # Describe the island-model configuration.
    if num_islands <= 1:
        print("岛屿模型: 禁用(简单进化模式)")
    else:
        print(f"岛屿数量: {num_islands}")
        if migration_interval <= 0:
            print("迁移: 禁用")
        else:
            print(f"迁移间隔: {migration_interval}")

    if elite_size <= 0:
        print("精英机制: 禁用")
    else:
        print(f"精英数量: {elite_size}")

    if num_islands > 1 and elite_size > 0:
        print(f"父代选择概率: {parent_selection_prob}")
    print("="*80)

    # Set up shared resources.
    task_pool = TaskPool(max_concurrency=parallel_num)
    device_pool = DevicePool([6])  # run on device 6

    config = load_config(config_path="./python/ai_kernel_generator/config/vllm_triton_evolve_config.yaml")
    check_env_for_task(framework, backend, dsl, config)

    # Run the evolution.
    print("开始进化过程...")
    evolution_result = await evolve_with_islands(
        op_name=op_name,
        task_desc=task_desc,
        dsl=dsl,
        framework=framework,
        backend=backend,
        arch=arch,
        config=config,
        device_pool=device_pool,
        task_pool=task_pool,
        max_rounds=max_rounds,
        parallel_num=parallel_num,
        num_islands=num_islands,
        migration_interval=migration_interval,
        elite_size=elite_size,
        parent_selection_prob=parent_selection_prob,
    )

    # Nothing to report when the evolution returned an empty result.
    if not evolution_result:
        print("\n❌ 进化过程返回空结果")
        return None

    # Summary of the evolution.
    print("\n" + "="*80)
    print("进化完成!最终结果汇总:")
    print("="*80)
    print(f"算子名称: {evolution_result.get('op_name', 'Unknown')}")
    print(f"总轮数: {evolution_result.get('total_rounds', 0)}")
    print(f"总任务数: {evolution_result.get('total_tasks', 0)}")
    print(f"成功任务数: {evolution_result.get('successful_tasks', 0)}")
    print(f"最终成功率: {evolution_result.get('final_success_rate', 0.0):.2%}")
    print(f"最佳成功率: {evolution_result.get('best_success_rate', 0.0):.2%}")
    print(f"实现类型: {evolution_result.get('implementation_type', 'Unknown')}")
    print(f"框架: {evolution_result.get('framework', 'Unknown')}")
    print(f"后端: {evolution_result.get('backend', 'Unknown')}")
    print(f"架构: {evolution_result.get('architecture', 'Unknown')}")

    # Island information. Numeric defaults are used because the values are
    # compared with integers; a string default such as 'N/A' raises TypeError.
    island_info = evolution_result.get('island_info', {})
    if island_info:
        num_islands_used = island_info.get('num_islands', 0)
        if num_islands_used <= 1:
            print("进化模式: 简单进化(无岛屿模型)")
        else:
            print(f"岛屿数量: {num_islands_used}")
            migration_interval_used = island_info.get('migration_interval', 0)
            if migration_interval_used <= 0:
                print("迁移: 禁用")
            else:
                print(f"迁移间隔: {migration_interval_used}")

        elite_size_used = island_info.get('elite_size', 0)
        if elite_size_used <= 0:
            print("精英机制: 禁用")
        else:
            print(f"精英数量: {elite_size_used}")

    # Where the implementations were stored.
    storage_dir = evolution_result.get('storage_dir', '')
    if storage_dir:
        print(f"存储目录: {storage_dir}")

    # Best implementations.
    best_implementations = evolution_result.get('best_implementations', [])
    if best_implementations:
        print(f"\n最佳实现 (前{len(best_implementations)}个):")
        for i, impl in enumerate(best_implementations, 1):
            profile_str = _format_profile(impl.get('profile', float('inf')))
            print(f"  {i}. {impl.get('op_name', 'Unknown')} (轮次 {impl.get('round', 'N/A')}, {profile_str})")
    else:
        print("\n⚠️ 没有找到成功的实现")

    # Per-round details.
    round_results = evolution_result.get('round_results', [])
    if round_results:
        print("\n每轮详细结果:")
        for round_result in round_results:
            round_num = round_result.get('round', 'N/A')
            success_rate = round_result.get('success_rate', 0.0)
            successful = round_result.get('successful_tasks', 0)
            total = round_result.get('total_tasks', 0)
            print(f"  轮次 {round_num}: {successful}/{total} 成功 ({success_rate:.2%})")

    print("="*80)

    # Save the result; the timestamp has the form "YYYYMMDDHHMM".
    timestamp_str = datetime.now().strftime("%Y%m%d%H%M")
    file_name = f"evolve_result_{evolution_result.get('op_name', 'unknown')}_{dsl}_{framework}_{timestamp_str}.json"
    result_file = Path(config.get("log_dir", "")) / file_name

    serializable_result = _to_serializable_result(evolution_result)
    with open(result_file, 'w', encoding='utf-8') as f:
        json.dump(serializable_result, f, indent=2, ensure_ascii=False)
    print(f"结果已保存到: {result_file}")

    return evolution_result


def main():
    """Entry point: run the asynchronous example and report the outcome."""
    result = asyncio.run(run_evolve_example())

    if result:
        print("\n🎉 进化式算子生成成功完成!")
        successful_tasks = result.get('successful_tasks', 0)
        if successful_tasks > 0:
            print(f"✅ 成功生成了 {successful_tasks} 个有效的算子实现")
        else:
            print("⚠️ 未能生成成功的算子实现,请检查配置和任务描述")
    else:
        print("\n❌ 进化过程失败,请检查日志获取详细信息")


if __name__ == "__main__":
    main()
# \ No newline at end of file
# diff --git a/aikg/python/ai_kernel_generator/core/evolve_improved.py b/aikg/python/ai_kernel_generator/core/evolve_improved.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..c6b7e04b8b22fbc0d6b43d598b6e5d82b570f2ab
# --- /dev/null
# +++ b/aikg/python/ai_kernel_generator/core/evolve_improved.py
# @@ -0,0 +1,766 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
import random
import json
import uuid
from functools import partial
from typing import List, Dict, Any, Tuple
from pathlib import Path
from ai_kernel_generator.core.task import Task
from ai_kernel_generator.core.async_pool.task_pool import TaskPool
from ai_kernel_generator.core.async_pool.device_pool import DevicePool
from ai_kernel_generator.core.sketch import Sketch
from ai_kernel_generator import get_project_root
from ai_kernel_generator.utils.collector import get_collector


# NOTE(review): this unconditionally overrides any value the user exported,
# which makes the later os.getenv(..., "off") check always true. If opting out
# should be possible, os.environ.setdefault looks like the intended call -
# confirm before changing, since it alters import-time behavior.
os.environ['AIKG_DATA_COLLECT'] = 'on'
logger = logging.getLogger(__name__)


def generate_unique_id() -> str:
    """Return a new random UUID4 string."""
    return str(uuid.uuid4())


def _gen_time(impl: Dict[str, Any]) -> float:
    """Return the ranking time of *impl* (smaller is better).

    The ``profile`` entry is either a sequence whose first item is the time of
    the generated code, or a bare number. Anything else (missing, ``None``,
    empty sequence, non-number, NaN) ranks last as ``inf`` instead of raising.
    """
    value = impl.get('profile')
    if isinstance(value, (list, tuple)):
        value = value[0] if value else None
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return float('inf')
    if value != value:  # NaN would make the sort order undefined
        return float('inf')
    return float(value)


def _speedup(impl: Dict[str, Any]):
    """Return the speedup (third profile item) of *impl*, or ``None`` if absent."""
    profile = impl.get('profile')
    if isinstance(profile, (list, tuple)) and len(profile) >= 3:
        value = profile[2]
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
    return None


def pretty_print_results(results: List[Tuple[str, bool]]):
    """Log one line per task plus the overall success rate of a round.

    Args:
        results: ``(op_name, success)`` pairs of the executed tasks.
    """
    logger.info("=" * 60)
    logger.info("EVOLVE ROUND RESULTS")
    logger.info("=" * 60)

    success_count = 0
    total_count = len(results)

    for op_name, success in results:
        status = "✓ SUCCESS" if success else "✗ FAILED"
        logger.info(f"{op_name}: {status}")
        if success:
            success_count += 1

    success_rate = success_count / total_count if total_count > 0 else 0
    logger.info("-" * 60)
    logger.info(f"Success Rate: {success_count}/{total_count} ({success_rate:.2%})")
    logger.info("=" * 60)


def load_meta_prompts(parallel_num: int) -> list[str]:
    """Return ``parallel_num`` meta prompts drawn from the triton prompt list.

    With ``n`` available prompts: when ``parallel_num <= n`` the prompts are
    sampled without repetition; otherwise freshly shuffled copies of the whole
    list are concatenated until ``parallel_num`` entries are collected.

    Args:
        parallel_num: Number of prompts requested.

    Returns:
        The selected prompts. On any failure (import error, malformed or empty
        prompt list) the error is logged and empty strings are returned.
    """
    try:
        from ai_kernel_generator.resources.docs.triton_docs.meta_prompts import (
            triton_meta_prompts,
        )

        # Explicit checks instead of assert, which is stripped under -O.
        if not isinstance(triton_meta_prompts, list):
            raise TypeError("triton_meta_prompts should be a list")
        if not triton_meta_prompts:
            raise ValueError("triton_meta_prompts is empty")

        n = len(triton_meta_prompts)
        if parallel_num <= n:
            return random.sample(triton_meta_prompts, parallel_num)

        result = []
        while len(result) < parallel_num:
            shuffled_prompts = triton_meta_prompts.copy()
            random.shuffle(shuffled_prompts)
            result.extend(shuffled_prompts[:parallel_num - len(result)])
        return result

    except Exception as e:
        # Best effort: prompts are optional, so fall back to empty ones.
        logger.error(f"Failed to load meta prompts: {e}")
        return [""] * parallel_num


def save_implementation(impl_data: Dict[str, Any], storage_dir: str) -> None:
    """Write one implementation as a JSON file into *storage_dir*.

    An ``id`` is added to *impl_data* when it has none. Failures are logged and
    swallowed so that one bad record does not abort the evolution.

    Args:
        impl_data: Implementation dictionary (mutated: may receive an ``id``).
        storage_dir: Target directory; created if it does not exist.
    """
    try:
        os.makedirs(storage_dir, exist_ok=True)

        if 'id' not in impl_data:
            impl_data['id'] = generate_unique_id()

        round_idx = impl_data.get('round', 0)
        task_id = impl_data.get('task_id', 'unknown')
        impl_id = impl_data.get('id', 'unknown')[:8]  # short id keeps names unique
        filename = f"impl_{round_idx}_{task_id}_{impl_id}.json"
        filepath = os.path.join(storage_dir, filename)

        # Serialize before opening the file: if the data is not serializable
        # no truncated *.json file is left behind for the loader to trip over.
        payload = json.dumps(impl_data, ensure_ascii=False, indent=2)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

        logger.debug(f"Saved implementation to {filepath}")
    except Exception as e:
        logger.error(f"Failed to save implementation: {e}")


def load_best_implementations(storage_dir: str, max_count: int = 5) -> List[Dict[str, Any]]:
    """Load the best implementations stored as ``*.json`` in *storage_dir*.

    Args:
        storage_dir: Directory previously filled by ``save_implementation``.
        max_count: Maximum number of implementations returned.

    Returns:
        Up to ``max_count`` implementations ordered by generated-code time,
        fastest first. Unreadable files are skipped with a warning; a missing
        directory yields an empty list.
    """
    implementations = []

    try:
        if not os.path.exists(storage_dir):
            return implementations

        # Sorted listing keeps the order of equally fast entries deterministic.
        for filename in sorted(os.listdir(storage_dir)):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(storage_dir, filename)
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    impl_data = json.load(f)
                if 'id' not in impl_data:
                    impl_data['id'] = generate_unique_id()
                implementations.append(impl_data)
            except Exception as e:
                logger.warning(f"Failed to load {filepath}: {e}")

        # _gen_time tolerates scalar/None profiles, which previously raised
        # here and made the except branch return an unsorted, untruncated list.
        implementations.sort(key=_gen_time)

        logger.info(f"Loaded {len(implementations)} implementations from {storage_dir}")
        return implementations[:max_count]

    except Exception as e:
        logger.error(f"Failed to load implementations from {storage_dir}: {e}")
        return implementations


def classify_implementations_by_performance(implementations: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Split implementations into ``good`` / ``medium`` / ``poor`` tiers by speedup.

    Only implementations with a finite, positive speedup (third profile item)
    take part. They are ordered by speedup, highest first; the top 30% (at
    least one) are ``good``, the next 40% (at least one) ``medium`` and the
    rest ``poor``.

    Args:
        implementations: Candidate implementations.

    Returns:
        Dictionary with the keys ``'good'``, ``'medium'`` and ``'poor'``.
    """
    empty = {'good': [], 'medium': [], 'poor': []}
    if not implementations:
        return empty

    valid_impls = []
    for impl in implementations:
        speedup = _speedup(impl)
        if speedup is not None and speedup != float('inf') and speedup > 0:
            valid_impls.append(impl)

    if not valid_impls:
        return empty

    total_count = len(valid_impls)
    valid_impls.sort(key=lambda x: x['profile'][2], reverse=True)

    good_count = max(1, int(total_count * 0.3))
    medium_count = max(1, int(total_count * 0.4))

    classified = {
        'good': valid_impls[:good_count],
        'medium': valid_impls[good_count:good_count + medium_count],
        'poor': valid_impls[good_count + medium_count:]
    }

    logger.info(f"Performance classification: good={len(classified['good'])}, "
                f"medium={len(classified['medium'])}, poor={len(classified['poor'])}")

    return classified


def sample_inspirations(implementations: List[Dict[str, Any]], sample_num: int = 2, use_all: bool = False, use_tiered_sampling: bool = False, parent_implementations: List[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
    """Select implementations and convert them to the inspiration format.

    Args:
        implementations: Candidates, expected to be ordered best first.
        sample_num: Number of samples wanted when ``use_all`` is False. Tiered
            sampling takes one entry per non-empty tier first, so it may return
            more than ``sample_num`` entries.
        use_all: Return every candidate (after parent exclusion), in order.
        use_tiered_sampling: Pick the best of each performance tier; needs at
            least three candidates, otherwise plain sampling is used.
        parent_implementations: Implementations to exclude, matched by ``id``
            or, for entries without an id, by equality.

    Returns:
        Inspiration dictionaries with the keys ``id``, ``sketch``,
        ``impl_code``, ``profile`` and ``strategy_mode``.
    """
    if not implementations:
        return []

    # Treat a missing parent list as empty; "impl not in None" raised TypeError
    # for id-less implementations.
    parents = parent_implementations or []
    parent_ids = {parent.get('id') for parent in parents if parent.get('id')}

    candidates = []
    for impl in implementations:
        impl_id = impl.get('id')
        if impl_id:
            if impl_id not in parent_ids:
                candidates.append(impl)
        elif impl not in parents:
            candidates.append(impl)

    if use_all:
        selected = candidates
    elif use_tiered_sampling and len(candidates) >= 3:
        classified = classify_implementations_by_performance(candidates)

        # Tiers are sorted by speedup, so index 0 is the best of each tier.
        selected = [classified[tier][0] for tier in ('good', 'medium', 'poor') if classified[tier]]

        # Top up from the best tier while more samples are wanted.
        while len(selected) < sample_num and classified['good']:
            remaining_good = [impl for impl in classified['good'] if impl not in selected]
            if not remaining_good:
                break
            selected.append(remaining_good[0])

        logger.info(f"Tiered sampling selected {len(selected)} inspirations from different performance tiers")
    elif len(candidates) <= sample_num:
        selected = candidates
    else:
        # Half of the samples are the best candidates, the rest are random.
        best_count = max(1, sample_num // 2)
        random_count = sample_num - best_count

        selected = candidates[:best_count]
        if random_count > 0 and len(candidates) > best_count:
            remaining = candidates[best_count:]
            selected.extend(random.sample(remaining, min(random_count, len(remaining))))

    inspirations = []
    for impl in selected:
        inspirations.append({
            'id': impl.get('id'),
            'sketch': impl.get('sketch', ''),
            'impl_code': impl.get('impl_code', ''),
            'profile': impl.get('profile', (float('inf'), 0.0, 0.0)),  # full triple
            'strategy_mode': 'evolution'
        })

    return inspirations


def migrate_elites(islands: List[List[Dict[str, Any]]], migration_size: int = 1) -> List[List[Dict[str, Any]]]:
    """Copy the best individuals of every island into the other islands.

    Each island contributes its ``migration_size`` fastest individuals. Every
    island then receives up to ``migration_size`` of the other islands' elites,
    chosen in random order, skipping individuals it already holds.

    Args:
        islands: Implementation lists, one per island (not modified).
        migration_size: Elites contributed by / delivered to each island.

    Returns:
        New island lists including the migrants. With fewer than two islands
        the input is returned unchanged.
    """
    if len(islands) < 2:
        return islands

    migration_size = max(0, migration_size)  # a negative slice would select almost everything
    updated_islands = [island.copy() for island in islands]

    # The original sorted elites but then migrated the unsorted first entries.
    island_elites = [sorted(island, key=_gen_time)[:migration_size] for island in islands]

    for i, island in enumerate(updated_islands):
        existing_ids = {impl.get('id') for impl in island if impl.get('id')}

        candidates = [elite for j, elites in enumerate(island_elites) if j != i for elite in elites]
        # Random order so migrants are not always taken from the lowest-indexed island.
        random.shuffle(candidates)

        selected_elites = []
        for elite in candidates:
            if len(selected_elites) >= migration_size:
                break
            elite_id = elite.get('id')
            if elite_id:
                if elite_id in existing_ids:
                    continue
                existing_ids.add(elite_id)
            elif any(elite is known for known in island) or any(elite is known for known in selected_elites):
                # Individuals without an id are de-duplicated by identity.
                continue
            selected_elites.append(elite)

        island.extend(selected_elites)

    return updated_islands


def select_parent_from_elite(current_island_idx: int, num_islands: int, elite_pool: List[Dict[str, Any]]) -> int:
    """Choose the island a parent is taken from, guided by the elite pool.

    A random elite is drawn; if it records a valid ``source_island`` other than
    the current one, that island is returned. Otherwise a random other island
    is returned.

    Args:
        current_island_idx: Index of the island that needs a parent.
        num_islands: Total number of islands.
        elite_pool: Elite implementations.

    Returns:
        Island index; ``current_island_idx`` when the pool is empty or there is
        no other island.
    """
    if not elite_pool:
        return current_island_idx

    selected_elite = random.choice(elite_pool)
    source_island = selected_elite.get('source_island', None)
    if source_island is not None and source_island != current_island_idx and 0 <= source_island < num_islands:
        return source_island

    other_islands = [i for i in range(num_islands) if i != current_island_idx]
    if other_islands:
        return random.choice(other_islands)
    return current_island_idx


async def evolve_with_islands(
    op_name: str,
    task_desc: str,
    dsl: str,
    framework: str,
    backend: str,
    arch: str,
    config: dict,
    device_pool: DevicePool,
    task_pool: TaskPool,
    max_rounds: int = 1,
    parallel_num: int = 1,
    num_islands: int = 2,
    migration_interval: int = 2,
    elite_size: int = 1,
    parent_selection_prob: float = 0.5,
) -> Dict[str, Any]:
    """Evolutionary kernel generation with an island model and an elite pool.

    Args:
        op_name: Operator name.
        task_desc: Task description.
        dsl: Implementation type (e.g. "triton", "swft").
        framework: Framework name (e.g. "mindspore", "torch", "numpy").
        backend: Backend name (e.g. "ascend", "cuda").
        arch: Architecture name (e.g. "ascend910b4", "a100").
        config: Configuration dictionary.
        device_pool: Device pool handed to every task.
        task_pool: Task pool used to run generation and sketch tasks.
        max_rounds: Number of evolution rounds.
        parallel_num: Tasks per round; split evenly over the islands, with at
            least one task per island.
        num_islands: Number of islands; values below 1 are treated as 1.
        migration_interval: Rounds between migrations (0 disables migration).
        elite_size: Elites kept per island; also the migration size.
        parent_selection_prob: Probability of taking the parent from the
            current island instead of via the elite pool.

    Returns:
        Result dictionary with the overall statistics, the five best
        implementations, the per-round results, the storage directory and the
        island settings actually used.
    """
    # 0 is documented by the example as "island model disabled"; it used to
    # raise ZeroDivisionError when the tasks were split over the islands.
    if num_islands < 1:
        logger.warning(f"num_islands={num_islands} is below 1, using a single island")
        num_islands = 1

    logger.info(f"Starting evolve process with islands for {op_name}")
    logger.info(f"Configuration: {dsl} on {backend}/{arch} using {framework}")
    logger.info(f"Islands: {num_islands}, Migration interval: {migration_interval}, Elite size: {elite_size}")
    logger.info(f"Parent selection probability: {parent_selection_prob}")

    # Local storage: one random sub-directory per run.
    random_hash = uuid.uuid4().hex[:8]
    storage_dir = os.path.expanduser(f"~/aikg_evolve/{op_name}_{dsl}_{framework}_{backend}_{arch}/{random_hash}/")
    os.makedirs(storage_dir, exist_ok=True)

    best_success_rate = 0.0
    round_results = []
    best_implementations = []
    total_tasks = 0
    total_successful_tasks = 0

    # One storage directory per island.
    tasks_per_island = max(1, parallel_num // num_islands)
    islands_storage_dirs = []
    for i in range(num_islands):
        island_storage_dir = os.path.join(storage_dir, f"island_{i}")
        os.makedirs(island_storage_dir, exist_ok=True)
        islands_storage_dirs.append(island_storage_dir)

    # In-memory history of every island and the shared elite pool.
    island_implementations = [[] for _ in range(num_islands)]
    elite_pool = []
    # Bookkeeping that only drives the "Switched to island" debug log.
    current_island = 0
    current_island_counter = 0
    tasks_per_island_switch = max(1, tasks_per_island)

    # Meta prompts exist for triton only.
    meta_prompt_path = None
    if dsl == "triton":
        root_dir = get_project_root()
        meta_prompt_path = Path(root_dir) / "resources" / "docs" / f"{dsl}_docs" / "meta_prompts.py"
    use_meta_prompts = bool(dsl == "triton" and meta_prompt_path and meta_prompt_path.exists())

    for round_idx in range(1, max_rounds + 1):
        logger.info(f"Evolve round {round_idx}/{max_rounds} started")

        # Migrate after every `migration_interval` completed rounds. The former
        # test `round_idx % migration_interval == 1` selects the same rounds
        # for intervals >= 2 but never fired for an interval of 1.
        if (num_islands > 1 and migration_interval > 0 and round_idx > 1
                and (round_idx - 1) % migration_interval == 0):
            logger.info("Performing migration between islands")
            residents = [{id(impl) for impl in island} for island in island_implementations]
            island_implementations = migrate_elites(island_implementations, elite_size)
            # Parents are sampled from the on-disk stores, so migrants must be
            # written to their new island or the migration has no effect.
            for idx, island in enumerate(island_implementations):
                for impl in island:
                    if id(impl) not in residents[idx]:
                        save_implementation(impl, islands_storage_dirs[idx])

        # Inspirations and meta prompts for every island (all take part each round).
        island_inspirations = [[] for _ in range(num_islands)]
        island_meta_prompts = [[] for _ in range(num_islands)]

        for island_idx in range(num_islands):
            if round_idx == 1:
                # First round: nothing to be inspired by yet.
                island_inspirations[island_idx] = [[] for _ in range(tasks_per_island)]
            else:
                island_inspirations[island_idx] = []
                for pid in range(tasks_per_island):
                    if random.random() < parent_selection_prob:
                        # Stay on the current island (keeps islands isolated).
                        parent_island_idx = island_idx
                    else:
                        parent_island_idx = select_parent_from_elite(island_idx, num_islands, elite_pool)

                    # Inspirations come only from the chosen island's store.
                    if num_islands == 1:
                        stored_implementations = load_best_implementations(islands_storage_dirs[0])
                    else:
                        stored_implementations = load_best_implementations(
                            islands_storage_dirs[parent_island_idx], max_count=tasks_per_island * 2)

                    # NOTE(review): the randomly chosen parent is only used to
                    # exclude itself from the inspirations; it is not passed to
                    # the task. Confirm this is intended.
                    parent_implementation = None
                    if stored_implementations:
                        parent_implementation = random.choice(stored_implementations)

                    sampled = sample_inspirations(
                        stored_implementations,
                        sample_num=min(tasks_per_island, 3),
                        use_tiered_sampling=True,
                        parent_implementations=[parent_implementation] if parent_implementation else None)
                    island_inspirations[island_idx].append(sampled)

            if use_meta_prompts:
                island_meta_prompts[island_idx] = load_meta_prompts(tasks_per_island)
            else:
                island_meta_prompts[island_idx] = [""] * tasks_per_island

        # Create the tasks of all islands; task_mapping[i] is the island of task i.
        task_mapping = []
        for island_idx in range(num_islands):
            # Pad so that every task index has an entry.
            while len(island_inspirations[island_idx]) < tasks_per_island:
                island_inspirations[island_idx].append([])
            while len(island_meta_prompts[island_idx]) < tasks_per_island:
                island_meta_prompts[island_idx].append("")

            for pid in range(tasks_per_island):
                task_id = f"{round_idx}_{island_idx}_{pid}"

                task = Task(
                    op_name=op_name,
                    task_desc=task_desc,
                    task_id=task_id,
                    backend=backend,
                    arch=arch,
                    dsl=dsl,
                    config=config,
                    device_pool=device_pool,
                    framework=framework,
                    task_type="profile",
                    workflow="default_workflow",
                    inspirations=island_inspirations[island_idx][pid],
                    meta_prompts=island_meta_prompts[island_idx][pid] if island_meta_prompts[island_idx] else None,
                )

                task_pool.create_task(partial(task.run,))
                task_mapping.append(island_idx)

        # Wait for all tasks of this round.
        results = await task_pool.wait_all()
        task_pool.tasks.clear()

        round_implementations = []

        # Group the results by island. A result that is not an
        # (op_name, success, task_info) triple - e.g. an exception object, as
        # the sketch phase below already anticipates - counts as a failure.
        island_results = [[] for _ in range(num_islands)]
        for island_idx, result in zip(task_mapping, results):
            if not (isinstance(result, (list, tuple)) and len(result) == 3):
                logger.error(f"Task returned an unexpected result, counted as failed: {result!r}")
                result = (op_name, False, {})
            island_results[island_idx].append(result)

        for island_idx in range(num_islands):
            current_island_results = island_results[island_idx]

            island_implementations_list = []

            # Sketch agent for this island (shares the config).
            sketch_agent = Sketch(
                op_name=op_name,
                task_desc=task_desc,
                dsl=dsl,
                backend=backend,
                arch=arch,
                config=config
            )

            # Collect the successful tasks.
            successful_impls = []
            for task_op_name, success, task_info in current_island_results:
                total_tasks += 1
                if not success:
                    continue
                total_successful_tasks += 1

                impl_info = {
                    'id': generate_unique_id(),
                    'op_name': task_op_name,
                    'round': round_idx,
                    'task_id': task_info.get('task_id', ''),
                    'task_info': task_info,
                    # Full (gen_time, base_time, speedup) triple.
                    'profile': task_info.get("profile_res", (float('inf'), 0.0, 0.0)),
                    'impl_code': task_info.get("coder_code", ""),
                    'framework_code': task_desc,
                    'backend': backend,
                    'arch': arch,
                    'dsl': dsl,
                    'framework': framework,
                    'sketch': '',
                    'source_island': island_idx
                }
                successful_impls.append(impl_info)

            # Generate sketches through the task pool; only implementations
            # that carry code are sketched and recorded.
            sketch_tasks = []
            for impl_info in successful_impls:
                if impl_info['impl_code']:
                    task_pool.create_task(partial(sketch_agent.run, impl_info['task_info']))
                    sketch_tasks.append(impl_info)

            if sketch_tasks:
                sketch_results = await task_pool.wait_all()
                task_pool.tasks.clear()

                # Results are matched to their implementation by position.
                for i, impl_info in enumerate(sketch_tasks):
                    if i < len(sketch_results):
                        sketch_content = sketch_results[i]
                        impl_info['sketch'] = sketch_content if not isinstance(sketch_content, Exception) else ""

                    island_implementations_list.append(impl_info)
                    round_implementations.append(impl_info)

                    save_implementation(impl_info, islands_storage_dirs[island_idx])
                    best_implementations.append(impl_info)
                    island_implementations[island_idx].append(impl_info)

            # Update the elite pool (new implementations have unique ids, so
            # no de-duplication is needed) and cap its size.
            if island_implementations_list:
                elite_pool.extend(island_implementations_list)
                elite_pool.sort(key=_gen_time)
                elite_pool = elite_pool[:elite_size * num_islands]

            if num_islands > 1:
                current_island_counter += len(current_island_results)
                if current_island_counter >= tasks_per_island_switch:
                    current_island = (current_island + 1) % num_islands
                    current_island_counter = 0
                    logger.debug(f"Switched to island {current_island}")

        # Statistics of this round.
        round_total_count = len(results)
        round_success_count = sum(1 for per_island in island_results for result in per_island if result[1])
        round_success_rate = round_success_count / round_total_count if round_total_count > 0 else 0.0
        cumulative_success_rate = total_successful_tasks / total_tasks if total_tasks > 0 else 0.0
        if cumulative_success_rate > best_success_rate:
            best_success_rate = cumulative_success_rate

        round_results.append({
            'round': round_idx,
            'total_tasks': round_total_count,
            'successful_tasks': round_success_count,
            'success_rate': round_success_rate,
            'implementations': round_implementations
        })

        if os.getenv("AIKG_DATA_COLLECT", "off").lower() == "on":
            try:
                collector = await get_collector()
                collector.set_config(config)
                await collector.prepare_and_remove_data()
            except Exception as e:
                logger.error(f"Failed to prepare data for transmission in evolve round {round_idx}: {e}")

        pretty_print_results([(impl['op_name'], True) for impl in round_implementations] +
                             [(f"failed_task_{i}", False) for i in range(round_total_count - round_success_count)])

    # Fastest generated code first.
    best_implementations.sort(key=_gen_time)

    final_success_rate = total_successful_tasks / total_tasks if total_tasks > 0 else 0.0

    evolution_result = {
        'op_name': op_name,
        'total_rounds': max_rounds,
        'total_tasks': total_tasks,
        'successful_tasks': total_successful_tasks,
        'final_success_rate': final_success_rate,
        'best_success_rate': best_success_rate,
        'implementation_type': dsl,
        'framework': framework,
        'backend': backend,
        'architecture': arch,
        'best_implementations': best_implementations[:5],  # five best only
        'round_results': round_results,
        'storage_dir': storage_dir,
        'island_info': {
            'num_islands': num_islands,
            'migration_interval': migration_interval,
            'elite_size': elite_size,
            'parent_selection_prob': parent_selection_prob
        }
    }

    logger.info(f"Evolution completed for {op_name}")
    logger.info(f"Total tasks: {total_tasks}, Successful: {total_successful_tasks}")
    logger.info(f"Final success rate: {final_success_rate:.2%}")
    logger.info(f"Results stored in: {storage_dir}")

    return evolution_result
# \ No newline at end of file