MindSpore / vllm-mindspore
[RPC] Multi-framework support for vllm
Pinned
TODO
#IBTNRG
RFC
zichun_ye
Member
Created on 2025-03-16 05:32
# Motivation

There have been discussions in the community about supporting different frameworks such as [JAX](https://github.com/vllm-project/vllm/issues/11507) and [MindSpore](https://github.com/vllm-project/vllm/issues/11862). Expanding vllm from PyTorch to other DL frameworks would benefit vllm in the following ways:

- Different frameworks optimize for different areas. For example, TF/JAX and MindSpore have both invested heavily in static graphs, and JAX works well with XLA on TPU. This can increase inference speed in certain scenarios.
- Each framework has its own users and ecosystem. Supporting a new framework makes vllm more accessible to these users, and also lets vllm make use of their models and ecosystems.

To support a new framework in vllm, we want to:

1. **Avoid massive changes.** We want to support the new framework without rewriting large parts of the project.
2. **Utilize the new framework.** Different frameworks have their own ways to run and accelerate models, such as `torch.compile`, `jax.jit`, and `mindspore.jit`. We want to use them in the project.

So we propose to make the framework **pluggable**. The benefits are:

- Decoupling the framework keeps vllm's code clean and easy to maintain, and helps vllm avoid conflicts such as interface incompatibilities when an AI framework evolves.
- Each framework can use its own acceleration optimizations and models.

# Design

## Key Observation

Although vllm uses PyTorch as its default and only AI framework, the usage of the PyTorch API is not evenly distributed across the project.

Code for scheduling and KV cache management is independent of actual model execution. This part of the code includes the **LLMEngine base class, the scheduler, and the Executor except for the model runner**. It only uses a few torch APIs, including:

- tensor and related APIs;
- data types;
- device management;
- communication.

Such APIs are standard across AI frameworks: TensorFlow, JAX, MindSpore, and PaddlePaddle all provide APIs with the same functionality.

The rest of the code is closely tied to actual model execution, including the **model runner, models, layers, and attention backend**. Most PyTorch API usage is concentrated in this model-related code, and these APIs are not standard, including:

- torch.compile
- torch._utils
- torch.cuda.CUDAGraph / torch.cuda.graph
- ...

The distribution of PyTorch usage across the project, together with a detailed summary of every PyTorch API used, is captured in the attachment [vllm接口列表.xlsx](/mind_spore/dashboard/attach_files/2095209/download).

## Multi-framework Plugin Design

Based on this observation, we use the following strategy to build the plugin:

- provide a list of common APIs and mappings for the **LLMEngine base class, the scheduler, and the Executor except for the model runner**;
- provide a full implementation of the **model runner, models, layers, and attention backend**.

The changes to the project include:

- refactor the base classes, create a framework adapter, and route all PyTorch API calls through the framework adapter;
- create a plugin for the new framework. The plugin provides the common API mapping for the framework adapter and the implementation of the model runner and everything below it.

# Proposed Change

## Framework Adapter

Today vllm imports torch as a library and calls torch APIs directly. To support multiple frameworks, the first refactoring step is to make these torch API calls indirect, as sketched below.
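A minimal before/after sketch of that indirection (illustrative only: `alloc_buffer` is a made-up helper, and `_TorchAdapter` is a stand-in for the adapter classes defined in the appendix):

```python
import torch


# Before: model-independent code depends on torch directly.
def alloc_buffer(size: int, device: str) -> torch.Tensor:
    return torch.zeros(size, dtype=torch.half, device=device)


# Stand-in for the PyTorch adapter class described in the appendix.
class _TorchAdapter:
    Tensor = torch.Tensor
    zeros = staticmethod(torch.zeros)

    @staticmethod
    def half():
        return torch.half


framework = _TorchAdapter()  # the global adapter variable introduced below


# After: the same logic goes through the adapter, so a plugin can redirect
# `framework` to MindSpore, JAX, etc. without touching this code.
def alloc_buffer_adapted(size: int, device: str):
    return framework.zeros(size, dtype=framework.half(), device=device)
```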
The Framework Adapter consists of a global variable `framework` and a framework class for each framework:

- a framework class is an abstraction over all framework API usage, such as tensor creation and `zeros`;
- the global variable `framework` represents the active framework module. All framework API calls go through methods of this global variable, such as `framework.tensor` and `framework.zeros`.

Note that this adapter only handles the torch API calls in the **LLMEngine base class, the scheduler, and the Executor except for the model runner**. The number of APIs involved is limited:

- tensor and related APIs;
- data types;
- device management;
- communication.

Thus the Framework Adapter stays small, and every framework provides such APIs. When a framework plugin exists, the global variable is redirected to the new framework class, and API calls are routed to the new framework via the common API mapping in the plugin.

## Model Execution

In a single model forward call, the original call stack is:

```
LLM Engine -> Executor -> Worker -> Model Runner -> Model -> Layer/Ops/Attn
```

The plugin takes over from the Model Runner and executes the model in the new framework.

Let's look at the components one by one:

- LLM Engine, Executor, Worker: after adopting the framework adapter, this code is framework agnostic. The plugin does not need to provide replacement classes;
- model runner, models (attention, ops): the plugin provides a full implementation.

## Usage

Usage stays the same as before; the only change is installing a new plugin package:

```
pip install vllm
pip install vllm-mindspore-plugin
```

```python
# The inference will run with MindSpore automatically.
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = ["Hello, my name is"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM. This model is written in MindSpore and registered in the plugin.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

# RoadMap

The initial version will include the following:

* an initial version of the `Framework Adapter` base class, with PyTorch API calls replaced by framework adapter calls;
* a MindSpore Framework Plugin demo that uses MindSpore as vllm's underlying framework.

# Appendix

## Framework class representation and variable

The base class `FrameworkBase` defines all APIs needed in the model-independent code as methods of the class. According to our summary, such APIs cover tensors, data types, and device management, so a draft version of the class could be:

```python
from abc import ABC, abstractmethod


class FrameworkBase(ABC):
    """Abstract class for frameworks."""

    # tensor related methods
    @staticmethod
    @abstractmethod
    def Tensor():
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def tensor():
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def zeros(*size, out=None, dtype=None, layout=None, device=None,
              requires_grad=False):
        raise NotImplementedError

    # dtype related methods
    @staticmethod
    @abstractmethod
    def half():
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def float():
        raise NotImplementedError

    # more methods about tensors, dtypes and device management
```

The derived class implements each method using the corresponding API of its framework.
For example, the PyTorch implementation simply returns the corresponding torch APIs:

```python
import torch


class FrameworkPytorch(FrameworkBase):
    """PyTorch."""

    # tensor related methods
    @staticmethod
    def Tensor():
        return torch.Tensor

    @staticmethod
    def tensor():
        return torch.tensor

    @staticmethod
    def zeros():
        return torch.zeros

    # dtype related methods
    @staticmethod
    def half():
        return torch.half

    @staticmethod
    def float():
        return torch.float

    # more methods about tensors, dtypes and device management
```

To make the code framework neutral we introduce a global variable `framework`. vllm selects the underlying framework by assigning a value to this global variable. Below is an example; for frameworks other than PyTorch, some API mapping work is required to fit the API contract used in vllm (namely the PyTorch conventions).

```python
global framework
if is_pytorch_available():
    framework = FrameworkPytorch()
elif is_mindspore_available():
    framework = FrameworkMindspore()
elif is_jax_available():
    framework = FrameworkJax()
else:
    raise RuntimeError("No framework available for vllm in env!")
```

Then framework API calls can be replaced by method calls on the global variable `framework`. For example, below is the function that allocates the KV cache in the cache engine, already made framework neutral:

```python
def _allocate_kv_cache(
    self,
    num_blocks: int,
    device: str,
) -> List[framework.Tensor]:
    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
        num_blocks, self.block_size, self.num_kv_heads, self.head_size)
    pin_memory = is_pin_memory_available() if device == "cpu" else False
    kv_cache: List[framework.Tensor] = []
    if self.align_cache:
        entry_shape = kv_cache_shape[2:]
        entry_size = np.prod(entry_shape)
        alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
        alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)
    else:
        alloc_shape = kv_cache_shape
    for _ in range(self.num_attention_layers):
        layer_kv_cache = framework.zeros(alloc_shape,
                                         dtype=self.dtype,
                                         pin_memory=pin_memory,
                                         device=device)
        if self.align_cache:
            layer_kv_cache = layer_kv_cache[..., :entry_size]
        kv_cache.append(layer_kv_cache.view(kv_cache_shape))
    return kv_cache
```
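As an illustration of the API mapping work mentioned above, a MindSpore-backed adapter might look roughly like the sketch below. This is not code from the proposal: the MindSpore calls shown (`mindspore.Tensor`, `mindspore.ops.zeros`, `mindspore.float16`, `mindspore.float32`) are natural candidates, but the exact mapping and argument compatibility would still need to be verified against vllm's PyTorch-style API contract.

```python
# Hedged sketch of a possible MindSpore mapping for the same FrameworkBase
# interface; not part of this RFC's code.
import mindspore


class FrameworkMindspore(FrameworkBase):
    """MindSpore."""

    # tensor related methods
    @staticmethod
    def Tensor():
        return mindspore.Tensor

    @staticmethod
    def tensor():
        # MindSpore constructs tensors through the Tensor class itself.
        return mindspore.Tensor

    @staticmethod
    def zeros():
        return mindspore.ops.zeros

    # dtype related methods: PyTorch dtypes map onto MindSpore dtypes.
    @staticmethod
    def half():
        return mindspore.float16

    @staticmethod
    def float():
        return mindspore.float32

    # Device management and communication would need their own mapping,
    # e.g. onto mindspore.set_context and mindspore.communication.
```

With such a class provided by the plugin, the selection snippet above would pick it up through `is_mindspore_available()` and redirect all `framework.*` calls to MindSpore.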
Attachments
vllm接口列表.xlsx (17.94 KB)
zichun_ye
2025-03-16 11:05
Comments (0)