diff --git a/cve-fix/analyze.py b/cve-fix/analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..52536a534a96f57158d660c58f838ac6f0c5210a --- /dev/null +++ b/cve-fix/analyze.py @@ -0,0 +1,48 @@ +import argparse +from vulnerability.database import CVEDatabase +from vulnerability.tracker import CVETracker +from vulnerability.call import create_comment, make_body + +def main(): + parser = argparse.ArgumentParser(description='查询指定 CVE 名称在不同 OpenEuler 版本中的信息。') + parser.add_argument('cve_name', type=str, help='要查询的 CVE 名称,例如 CVE-2023-1234') + parser.add_argument('access_token', type=str) + parser.add_argument('full_name', type=str) + parser.add_argument('number', type=str) + + args = parser.parse_args() + cve_name = str(args.cve_name).strip() + access_token = str(args.access_token).strip() + full_name = str(args.full_name).strip() + number = str(args.number).strip() + + print("读取数据库") + cve_db = CVEDatabase('fixed_commit_link') + tracker = CVETracker(cve_db) + print("查询CVE") + try: + tracker.update_cve_commit(cve_name) + except Exception as e: + print(f"查询CVE失败:{str(e)}") + print(f"查询每个OS是否引入{cve_name}") + code_dict, intro_commit, fixed_commit = tracker.query_all_os_cve(cve_name) + result_dict = {} + for os_name, code in code_dict.items(): + print(f"检查{cve_name}在{os_name}中的引入及修复情况") + result_dict[os_name] = tracker.get_analyze_result(code) + + body = make_body(result=result_dict, tracker=tracker) + body += "\n" + body += f"引入补丁链接:https://web.git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/commit/?id={intro_commit}\n" + body += f"修复补丁链接:https://web.git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/commit/?id={fixed_commit}\n" + + cve_db.close() + parameter = { + "access_token": access_token, + "full_name": full_name, + "number": number, + } + create_comment(parameter, body) + +if __name__ == "__main__": + result = main() diff --git a/cve-fix/create_pr.py b/cve-fix/create_pr.py new file mode 100644 index 0000000000000000000000000000000000000000..35c4d3880c0be1fd73d5c525e44896469c630b96 --- /dev/null +++ b/cve-fix/create_pr.py @@ -0,0 +1,102 @@ +import argparse +from vulnerability.database import CVEDatabase +from vulnerability.tracker import CVETracker +from vulnerability.call import create_pr, create_comment, make_body +import re + +cve_db = CVEDatabase('fixed_commit_link') + +def create_one_pr(cve_name, branch, access_token, full_name_openeuler, head, commit_parameter, user_name, user_email, number): + if(branch != "all" and branch not in cve_db.get_os_versions("all")): + result = "无法识别的操作系统版本名称" + create_comment(commit_parameter, result) + return + + if(branch != "all" and cve_db.check_os_version_is_openeuler(branch)): + result = f"{branch}非内核分支,请使用 /create_pr {cve_db.get_os_parents(branch)}" + create_comment(commit_parameter, result) + return + + tracker = CVETracker(cve_db) + print("Update cve commit for cve: ", cve_name) + tracker.update_cve_commit(cve_name) + + result = {} + return_message = {} + print("Fix cve for branch: ", branch) + code, push_branch = tracker.fix_cve(cve_name, branch, user_name, user_email, number) + return_message[push_branch] = "" + if code == 0: + parameter = { + "access_token": access_token, + "full_name": full_name_openeuler, + "title": cve_name, + "head": f"{head}:{push_branch}", + "base": branch, + "body": "", + "prune_source_branch": True, + } + try: + response = create_pr(parameter) + return_message[branch] = response.json()['html_url'] + print("第一次响应成功") + except: + print("第一次响应失败") + response = create_pr(parameter) + extra_message = response.json()['message'] + urls = re.findall(r'href="(https?://[^"]+)"', extra_message) + if urls: + return_message[branch] = urls[0] + else: + return_message[branch] = extra_message + + + result[branch] = tracker.get_fixed_result(code) + if code == 0: + result[branch] += return_message[branch] + + body = make_body(result, tracker) + create_comment(commit_parameter, body) + +def main(): + parser = argparse.ArgumentParser(description='创建CVE在指定分支下的修复补丁。') + parser.add_argument('cve_name', type=str, help='CVE 名称,例如 CVE-2023-1234') + parser.add_argument('branch', type=str, help='分支名称,all为数据库中所有的openEuler分支') + parser.add_argument('access_token', type=str) + parser.add_argument('full_name', type=str) + parser.add_argument('number', type=str) + parser.add_argument('user_name', type=str) + parser.add_argument('user_email', type=str) + + args = parser.parse_args() + cve_name = str(args.cve_name).strip() + branch = str(args.branch).strip() + access_token = str(args.access_token).strip() + full_name_srcopeneuler = str(args.full_name).strip() + full_name_openeuler = "openeuler/kernel" + number = str(args.number).strip() + head = "ci-robot/kernel" + user_name = str(args.user_name).strip() + user_email = str(args.user_email).strip() + + commit_parameter = { + "access_token": access_token, + "full_name": full_name_srcopeneuler, + "number": number, + } + + if(branch == "all"): + all_os_version = cve_db.get_os_versions(type='stable') + for os_version in all_os_version: + print("Create pr for branch: ", os_version) + create_one_pr(cve_name, os_version, access_token, full_name_openeuler, head, commit_parameter, user_name, user_email, number) + else: + print("Create pr for branch: ", branch) + create_one_pr(cve_name, branch, access_token, full_name_openeuler, head, commit_parameter, user_name, user_email, number) + + cve_db.close() + + +if __name__ == "__main__": + result = main() + \ No newline at end of file diff --git a/cve-fix/fixed_commit_link b/cve-fix/fixed_commit_link new file mode 100644 index 0000000000000000000000000000000000000000..0817870614f001eeb1913218874c165115f42123 Binary files /dev/null and b/cve-fix/fixed_commit_link differ diff --git a/cve-fix/init.py b/cve-fix/init.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5f6469cb2eabee04d7c766b4585e2d403bff62 --- /dev/null +++ b/cve-fix/init.py @@ -0,0 +1,45 @@ +from vulnerability.database import CVEDatabase +from vulnerability.tracker import CVETracker + +import os +import stat +import time +import subprocess + +start_time = time.time() + +cve_db = CVEDatabase('fixed_commit_link') + +cve_db.add_os_version("openEuler-1.0-LTS", version_type="stable", parent="4.19") +cve_db.add_os_version("OLK-5.10", version_type="stable", parent="5.10") +cve_db.add_os_version("OLK-6.6", version_type="stable", parent="6.6") + +cve_db.add_os_version("openEuler-20.03-LTS-SP4", version_type="openEuler", parent="openEuler-1.0-LTS") +cve_db.add_os_version("openEuler-22.03-LTS-SP3", version_type="openEuler", parent="OLK-5.10") +cve_db.add_os_version("openEuler-22.03-LTS-SP4", version_type="openEuler", parent="OLK-5.10") +cve_db.add_os_version("openEuler-24.03-LTS", version_type="openEuler", parent="OLK-6.6") +cve_db.add_os_version("openEuler-24.03-LTS-Next", version_type="openEuler", parent="OLK-6.6") +cve_db.add_os_version("openEuler-24.03-LTS-SP1", version_type="openEuler", parent="OLK-6.6") + +script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "vulnerability/scripts") +shell_script_path = os.path.join(script_dir, "init.sh") +st = os.stat(shell_script_path) +if not st.st_mode & stat.S_IEXEC: +# 如果脚本不可执行,则添加执行权限 + os.chmod(shell_script_path, st.st_mode | stat.S_IEXEC) +subprocess.run( + [shell_script_path], + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True + ) + +tracker = CVETracker(cve_db) +tracker.check_immediately() + +cve_db.close() + +end_time = time.time() +estime = end_time - start_time +print("time: ", estime) \ No newline at end of file diff --git a/cve-fix/requirements.txt b/cve-fix/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1040fa6a1d652f23f4d136cffa3acd6dff449de1 --- /dev/null +++ b/cve-fix/requirements.txt @@ -0,0 +1,5 @@ +tqdm +python-dotenv +requests +bs4 +curl_cffi \ No newline at end of file diff --git a/cve-fix/vulnerability/__init__.py b/cve-fix/vulnerability/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cve-fix/vulnerability/analysis.py b/cve-fix/vulnerability/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea1ac2fe280933b12da6c177a8c8a8b94deaa6e --- /dev/null +++ b/cve-fix/vulnerability/analysis.py @@ -0,0 +1,167 @@ +import re +import curl_cffi +from bs4 import BeautifulSoup + +class CommitQuery: + """ + 从 https://lore.kernel.org/linux-cve-announce 中进行commit查询 + """ + def __init__(self): + self.base_url = "https://lore.kernel.org/linux-cve-announce/" + self.headers = { + "X-Requested-With": "XMLHttpRequest" + } + + def get_cve_url(self, cve: str): + # 查询参数 + params = {"q": cve} + # 发送GET请求 + response = curl_cffi.get( + self.base_url, + params=params, + headers=self.headers, + verify=False + ) + # 检查请求是否成功 + if response.status_code != 200: + print(f"Failed to retrieve data. Status code: {response.status_code}") + return None + # 解析HTML内容 + soup = BeautifulSoup(response.text, 'html.parser') + # 查找包含CVE编号的链接 + cve_link = soup.find('a', href=True, string=lambda text: text and cve in text) + if not cve_link: + print("CVE link not found in the response.") + return None + if "REJECTED" in cve_link.get_text().upper(): + return None + # 获取CVE详情页面的相对URL + relative_url = cve_link['href'] + # 将相对URL转换为完整URL + cve_detail_url = self.base_url + relative_url + return cve_detail_url + + def get_all_cve(self): + url = self.base_url + cve_list = [] + # 定义CVE编号匹配模式(包含锚定边界) + cve_pattern = re.compile(r'\b(CVE-\d{4}-\d+)\b') + + while url: + response = curl_cffi.get(url, headers=self.headers, verify=False) + soup = BeautifulSoup(response.text, 'html.parser') + + for link in soup.find_all('a', href=True): + if "REJECTED" in link.get_text().upper(): # 跳过包含REJECTED标识的条目 + continue + link_text = link.text.strip() + # 提取文本中的CVE编号 + cve_match = cve_pattern.search(link_text) + if cve_match: + cve_number = cve_match.group(1).upper() # 统一转为大写 + cve_list.append(cve_number) + # 分页处理 + next_link = soup.find('a', string='next (older)') + url = self.base_url + next_link['href'] if next_link else None + # 去重并保持顺序(Python 3.7+) + return list(dict.fromkeys(cve_list)) + + def get_all_cve_url(self): + url = self.base_url + cve_url_list = [] + while url: + response = curl_cffi.get(url, headers=self.headers, verify=False) + soup = BeautifulSoup(response.text, 'html.parser') + + # Find all CVE links + for link in soup.find_all('a', href=True): + if 'CVE' in link.text: + relative_url = link['href'] + cve_url = self.base_url + relative_url + cve_url_list.append(cve_url) + # You can add code here to visit the CVE link and process its content + next_link = soup.find('a', string='next (older)') + if next_link: + url = self.base_url + next_link['href'] + else: + url = None + return cve_url_list + + def _get_commit_msg(self, url: str): + # 发送请求获取CVE详情页面内容 + detail_response = curl_cffi.get(url, headers=self.headers, verify=False) + if detail_response.status_code != 200: + print(f"Failed to retrieve CVE detail page. Status code: {detail_response.status_code}") + return None + # 解析CVE详情页面内容 + detail_soup = BeautifulSoup(detail_response.text, 'html.parser') + return detail_soup.get_text() + + def _get_commit_info_from_msg(self, msg: str) -> list[dict]: + if not msg: + return None + commits_info = [] + lines = msg.split('\n') + for line in lines: + introduced_version = None + introduced_commit = None + fixed_version = None + fixed_commit = None + + if "introduced in" not in line and "fixed in" not in line and "Fixed in" not in line: + continue + line = line.strip() + pattern = re.compile(r"Issue introduced in ([a-zA-Z0-9-.]+) with commit ([a-f0-9]+) and fixed in ([a-zA-Z0-9-.]+) with commit ([a-f0-9]+)") + matches = pattern.findall(line) + if len(matches) != 0: + introduced_version, introduced_commit, fixed_version, fixed_commit = matches[0] + else: + pattern = re.compile(r"Issue introduced in ([a-zA-Z0-9-.]+) with commit ([a-f0-9]+)") + matches = pattern.findall(line) + if len(matches) != 0: + introduced_version, introduced_commit = matches[0] + else: + pattern = re.compile(r"Fixed in ([a-zA-Z0-9-.]+) with commit ([a-f0-9]+)") + matches = pattern.findall(line) + if len(matches) != 0: + fixed_version, fixed_commit = matches[0] + else: + print(f'The line: {line} cannot be resolved') + + commit_info = { + "introduced_version": introduced_version, + "introduced_commit": introduced_commit, + "fixed_version": fixed_version, + "fixed_commit": fixed_commit + } + commits_info.append(commit_info) + return commits_info + + def get_commit_infos(self, cve: str) -> list[dict]: + url = self.get_cve_url(cve) + if url: + msg = self._get_commit_msg(url) + if msg: + commit_infos = self._get_commit_info_from_msg(msg) + return commit_infos + return [] + + def get_introduced_commit(self, commits_info: list[dict]) -> dict: + introduced_commit_info = {} + for commit_info in commits_info: + version = commit_info.get('introduced_version') + commit = commit_info.get('introduced_commit') + if version is not None and commit is not None: + introduced_commit_info[version] = commit + return introduced_commit_info + + def get_fixed_commit(self, commits_info: list[dict]) -> dict: + fixed_commit_info = {} + for commit_info in commits_info: + version = commit_info.get('fixed_version') + commit = commit_info.get('fixed_commit') + if version is not None and commit is not None: + fixed_commit_info[version] = commit + return fixed_commit_info + + diff --git a/cve-fix/vulnerability/call.py b/cve-fix/vulnerability/call.py new file mode 100644 index 0000000000000000000000000000000000000000..71a129879173fb532cf710dd557bc02c08be1f26 --- /dev/null +++ b/cve-fix/vulnerability/call.py @@ -0,0 +1,104 @@ +import requests +from .tracker import CVETracker +import time + +def make_body(result: dict, tracker: CVETracker) -> str: + """ + 将result转化为str + """ + new_result = {} + for key, value in result.items(): + db = tracker.cve_data_db + parent = None + if db.check_os_version_is_openeuler(key): + parent = db.get_os_parents(db.get_os_parents(key)) + else: + parent = db.get_os_parents(key) + full_key = f"{key} ({parent})" + new_result[full_key] = result[key] + return '\n'.join(f"{key}: {value}" for key, value in new_result.items()) + +def create_comment(parameter: dict, body: str): + """ + 在指定ISSUE下创建评论 + body: 评论内容 + """ + print("创建评论, 参数: ", parameter) + access_token = parameter["access_token"] + full_name = parameter["full_name"] + number = parameter["number"] + # 构造 API 地址 + url = f"https://gitee.com/api/v5/repos/{full_name}/issues/{number}/comments" + + # 请求头 + headers = { + "Content-Type": "application/json;charset=UTF-8" + } + + # 请求参数(包含认证信息) + payload = { + "access_token": access_token, # Gitee 的认证方式 + "body": body + } + + try: + response = requests.post(url, json=payload, headers=headers) + response.raise_for_status() # 检查 HTTP 错误 + + print("评论创建成功!响应数据:") + print(response.json()) + return response.json() + except requests.exceptions.RequestException as e: + print(f"请求失败:{str(e)}") + if response: + print(f"错误响应内容:{response.text}") + return None + +def create_pr(parameter: dict): + """ + 创建pr + """ + access_token = parameter["access_token"] + full_name = parameter["full_name"] + title = parameter["title"] + head = parameter["head"] + base = parameter["base"] + body = parameter["body"] + # 构造 API 地址 + url = f"https://gitee.com/api/v5/repos/{full_name}/pulls" + + # 请求头 + headers = { + "Content-Type": "application/json;charset=UTF-8" + } + + # 请求参数(包含认证信息) + payload = { + "access_token": access_token, + "title": title, + "head": head, + "base": base, + "prune_source_branch": "true" + } + + print("head type: ", type(head)) + print("head: ", head) + + print("创建pr:\njson: ", payload) + print("headers: ", headers) + print("url: ", url) + + try: + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() # 检查 HTTP 错误 + + print("pr创建成功!响应数据:") + print(response.json()) + return response + except requests.exceptions.RequestException as e: + print(f"Response Code: {response.status_code}") + print("Response Headers:", response.headers) + print("Response Body:", response.json()) + if (response.status_code == 400): + print("Error Message: ", response.json()["message"]) + return response \ No newline at end of file diff --git a/cve-fix/vulnerability/database.py b/cve-fix/vulnerability/database.py new file mode 100644 index 0000000000000000000000000000000000000000..64e671cde6436f08d0d4c29636f966becc993231 --- /dev/null +++ b/cve-fix/vulnerability/database.py @@ -0,0 +1,444 @@ +import sqlite3 +import re +from typing import Optional + +class CVEDatabase: + def __init__(self, db_name='cve.db'): + self.conn = sqlite3.connect(db_name) + self._create_tables() + + def _create_tables(self): + """创建基础表结构""" + self.conn.execute('''CREATE TABLE IF NOT EXISTS CVE ( + cve_id INTEGER PRIMARY KEY, + name TEXT UNIQUE, + intro_mainline_id TEXT, + fixed_mainline_id TEXT + )''') + + self.conn.execute('''CREATE TABLE IF NOT EXISTS OS_Version ( + os_id INTEGER PRIMARY KEY, + version_name TEXT UNIQUE, + version_type TEXT CHECK(version_type IN ('stable', 'openEuler')), + parent TEXT, + CHECK ( + (version_type = 'openEuler' AND parent IS NOT NULL) OR + (version_type = 'stable' AND parent IS NOT NULL) + ) + )''') + + # 创建触发器确保 openEuler 版本关联的父版本存在且为 stable 类型 + self.conn.execute('''CREATE TRIGGER IF NOT EXISTS validate_openeuler_parent_insert + BEFORE INSERT ON OS_Version + FOR EACH ROW + WHEN NEW.version_type = 'openEuler' + BEGIN + SELECT RAISE(ABORT, 'Parent must be an existing stable version') + WHERE NOT EXISTS ( + SELECT 1 FROM OS_Version + WHERE version_name = NEW.parent + AND version_type = 'stable' + ); + END;''') + + self.conn.execute('''CREATE TRIGGER IF NOT EXISTS validate_openeuler_parent_update + BEFORE UPDATE ON OS_Version + FOR EACH ROW + WHEN NEW.version_type = 'openEuler' + BEGIN + SELECT RAISE(ABORT, 'Parent must be an existing stable version') + WHERE NOT EXISTS ( + SELECT 1 FROM OS_Version + WHERE version_name = NEW.parent + AND version_type = 'stable' + ); + END;''') + + self.conn.execute('''CREATE TABLE IF NOT EXISTS CVE_OS_Link ( + cve_id INTEGER, + os_id INTEGER, + commit_id TEXT, + PRIMARY KEY (cve_id, os_id), + FOREIGN KEY (cve_id) REFERENCES CVE(cve_id), + FOREIGN KEY (os_id) REFERENCES OS_Version(os_id) + )''') + self.conn.commit() + + def check_os_version_is_openeuler(self, version_name): + """ + 检查指定的操作系统版本是否为openEuler类型 + :param version_name: 要检查的操作系统版本名称 + :return: + - True: 是openEuler版本 + - False: 不是openEuler版本或版本不存在 + """ + try: + cursor = self.conn.execute( + 'SELECT version_type FROM OS_Version WHERE version_name = ?', + (version_name,) + ) + result = cursor.fetchone() + # 如果记录存在且类型匹配 + return result is not None and result[0] == "openEuler" + except sqlite3.DatabaseError as e: + print(f"数据库查询错误: {e}") + return False + + def get_os_versions(self, type: str = "all") -> list: + """ + 获取操作系统版本(严格类型限制) + :param type: 过滤类型 + - "all": 返回所有记录(默认) + - "openEuler": 仅返回openEuler类型 + - "stable": 仅返回stable类型 + :return: 包含完整信息的字典列表 + :raises ValueError: 当传入无效type参数时 + """ + # 严格参数验证 + if type not in ("all", "openEuler", "stable"): + raise ValueError(f"无效类型参数: {type},只允许 'all'/'openEuler'/'stable'") + + try: + query = 'SELECT version_name FROM OS_Version' + params = () + if type != "all": + query += " WHERE version_type = ?" + params = (type,) + query += " ORDER BY version_name" + # 执行查询 + cursor = self.conn.execute(query, params) + return [row[0] for row in cursor.fetchall()] + + except sqlite3.DatabaseError as e: + print(f"数据库查询失败: {e}") + return [] + + def get_os_parents(self, os_version: str) -> Optional[str]: + """直接返回 parent 字段的值""" + try: + cursor = self.conn.cursor() + cursor.execute('SELECT parent FROM OS_Version WHERE version_name = ?', (os_version,)) + result = cursor.fetchone() + return result[0] if result else None + except sqlite3.Error as e: + print(f"Database error: {str(e)}") + return None + + def get_all_cves(self) -> list[dict]: + """ + 获取数据库中所有CVE记录 + :return: 包含CVE详细信息的字典列, 结构为: + [{ + "cve_id": int, + "name": str, + "intro_mainline_id": str, + "fixed_mainline_id": str + }, ...] + """ + try: + cursor = self.conn.execute( + '''SELECT cve_id, name, intro_mainline_id, fixed_mainline_id FROM CVE ORDER BY cve_id''' + ) + return [{ + "cve_id": row[0], + "name": row[1], + "intro_mainline_id": row[2], + "fixed_mainline_id": row[3] + } + for row in cursor.fetchall() + ] + except sqlite3.DatabaseError as e: + print(f"获取CVE列表失败: {e}") + return [] + + def check_fixed_commit_id(self, cve_name: str, os_version: str) -> Optional[str]: + """ + 获取指定CVE在特定操作系统版本下的修复commit标识 + :param cve_name: CVE名称 + :param os_version: 操作系统版本名称 + :return: + - str: 找到的commit链接标识 + - None: 未找到记录或发生错误 + """ + try: + # 原子性获取所有必要ID + cursor = self.conn.execute(''' + SELECT l.commit_id + FROM CVE_OS_Link l + JOIN CVE c ON l.cve_id = c.cve_id + JOIN OS_Version o ON l.os_id = o.os_id + WHERE c.name = ? AND o.version_name = ? + ''', (cve_name, os_version)) + + result = cursor.fetchone() + return result[0] if result else None + + except sqlite3.DatabaseError as e: + print(f"数据库查询失败: {e}") + return None + + def check_intro_commit_id_mainline(self, cve_name: str) -> Optional[str]: + """ + 获取指定CVE的引入commit标识 + :param cve_name: CVE名称 + :return: + - str: 引入commit标识 + - None: CVE不存在或查询失败 + """ + try: + cursor = self.conn.execute("SELECT intro_mainline_id FROM CVE WHERE name = ?", (cve_name,)) + result = cursor.fetchone() + return result[0] if result else None + except sqlite3.DatabaseError as e: + print(f"查询引入commit失败: {e}") + return None + + def check_fixed_commit_id_mainline(self, cve_name: str) -> Optional[str]: + """ + 获取指定CVE在mainline分支的修复commit标识 (mainline) + :param cve_name: CVE名称 + :return: + - str: 有效修复commit标识 + - None: 无修复记录/CVE不存在/数据异常 + """ + try: + cursor = self.conn.execute("SELECT fixed_mainline_id FROM CVE WHERE name = ?", (cve_name,)) + result = cursor.fetchone() + return result[0] if result else None + except sqlite3.DatabaseError as e: + print(f"数据库查询失败: {e}") + return None + + def check_os_version_is_stable_version(self, version_name: str) -> Optional[str]: + """ + 检查并返回匹配的stable基准版本名称 + :param version_name: 要检查的版本名称 + :return: + - str: 匹配的stable基准版本名称 + - None: 无匹配版本或发生错误 + """ + def parse_version(ver: str) -> list[int]: + """解析版本字符串为数字序列(忽略非数字后缀)""" + parts = [] + for part in ver.split('.'): + match = re.match(r'^(\d+)', part) + if match: + parts.append(int(match.group(1))) + else: + break + return parts + + try: + # 获取所有stable基准版本并解析 + cursor = self.conn.execute( + "SELECT version_name FROM OS_Version " + "WHERE version_type = 'stable'" + ) + + # 存储格式:{"name": "OLK-5.10", "parts": [5,10]} + base_versions = [] + for (db_version,) in cursor.fetchall(): + version_part = self.get_os_parents(db_version) + parts = parse_version(version_part) + if parts: + base_versions.append({ + "name": db_version, + "parts": parts + }) + + # 按版本段长度降序排序(优先匹配更长更精确的版本) + base_versions.sort(key=lambda x: len(x["parts"]), reverse=True) + + # 解析输入版本 + input_parts = parse_version(version_name) + if not input_parts: + return None + + # 寻找最长匹配 + for base in base_versions: + base_parts = base["parts"] + # 输入版本需至少包含基准版本的所有段 + if len(input_parts) < len(base_parts): + continue + # 严格前缀匹配 + if input_parts[:len(base_parts)] == base_parts: + return base["name"] + + return None + + except sqlite3.DatabaseError as e: + print(f"数据库查询失败: {e}") + return None + except Exception as e: + print(f"版本解析异常: {e}") + return None + + def add_os_version(self, version_name: str, version_type: str, parent: str): + """ + 添加新的操作系统版本 + :param version_name: 版本名称 (如 "openEuler-22.03-LTS") + :param version_type: 系统类型 (stable/openEuler) + :param parent: 父版本名称(openEuler类型必须为存在的stable版本名称) + """ + cursor = self.conn.cursor() + try: + # 参数基础验证 + if version_type not in ("stable", "openEuler"): + raise ValueError("version_type 必须为 stable 或 openEuler") + if not parent: + raise ValueError("parent 参数不能为空") + + # 直接插入 parent 字符串(触发器会验证 openEuler 类型的父版本有效性) + cursor.execute(''' + INSERT INTO OS_Version + (version_name, version_type, parent) + VALUES (?, ?, ?) + ''', (version_name, version_type, parent)) + + self.conn.commit() + print(f"成功添加 {version_type} 版本: {version_name}") + + except sqlite3.IntegrityError as e: + self.conn.rollback() + if "UNIQUE constraint failed" in str(e): + print(f"版本名称 '{version_name}' 已存在") + elif "CHECK constraint failed" in str(e): + print(f"父版本参数不符合约束条件") + else: + print(f"数据完整性错误: {e}") + except Exception as e: + self.conn.rollback() + raise RuntimeError(f"添加操作系统版本失败: {str(e)}") from e + + def add_cve(self, cve_name: str, intro_mainline_id: str, fixed_mainline_id: str) -> Optional[int]: + """ + 添加或更新 CVE 到数据库 + 返回新插入或更新的 cve_id, 若失败返回 None + """ + # 参数校验 + if not re.match(r'^CVE-\d{4}-\d+$', cve_name): + raise ValueError(f"Invalid CVE name: {cve_name}. Expected format: CVE-YYYY-NNNN") + try: + cur = self.conn.cursor() + # 使用 UPSERT 语法插入或更新 CVE 记录 + cur.execute(''' + INSERT INTO CVE (name, intro_mainline_id, fixed_mainline_id) + VALUES (?, ?, ?) + ON CONFLICT(name) DO UPDATE SET + intro_mainline_id = excluded.intro_mainline_id, + fixed_mainline_id = excluded.fixed_mainline_id + ''', (cve_name, intro_mainline_id, fixed_mainline_id)) + # 获取受影响的行数,如果是更新操作,则返回现有的 cve_id + self.conn.commit() + if cur.rowcount == 0: + # 新插入的记录 + cve_id = cur.lastrowid + else: + # 更新的记录,获取现有的 cve_id + cur.execute('SELECT cve_id FROM CVE WHERE name = ?', (cve_name,)) + cve_id = cur.fetchone()[0] + return cve_id + except sqlite3.IntegrityError as e: + self.conn.rollback() + # 解析具体违反的约束 + error_msg = str(e) + if "CVE.name" in error_msg: + print(f"CVE {cve_name} 已存在") + elif "CVE.intro_mainline_id" in error_msg: + print(f"主线版本 {intro_mainline_id} 已被其他 CVE 使用") + elif "CVE.fixed_mainline_id" in error_msg: + print(f"主线版本 {fixed_mainline_id} 已被其他 CVE 使用") + else: + raise RuntimeError(f"数据库错误: {error_msg}") from None + except Exception as e: + self.conn.rollback() + raise RuntimeError(f"操作失败: {str(e)}") from None + + def add_cve_commit(self, cve_name: str, os_version: str, commit_id: str) -> bool: + """ + 更新/插入CVE与操作系统版本的关联修复commit + :param cve_name: CVE名称 + :param os_version: 操作系统版本名称 + :param commit_id: 修复commit链接标识 + :return: + - True: 操作成功 + - False: 操作失败 (CVE/版本不存在或数据库错误) + """ + try: + # 获取CVE ID + cve_row = self.conn.execute("SELECT cve_id FROM CVE WHERE name = ?", (cve_name,)).fetchone() + if not cve_row: + print(f"错误: CVE '{cve_name}' 不存在") + return False + cve_id = cve_row[0] + + # 获取OS版本ID + os_row = self.conn.execute("SELECT os_id FROM OS_Version WHERE version_name = ?", (os_version,)).fetchone() + if not os_row: + print(f"错误:操作系统版本 '{os_version}' 不存在") + return False + os_id = os_row[0] + + # 执行UPSERT操作 + self.conn.execute( + """INSERT INTO CVE_OS_Link (cve_id, os_id, commit_id) + VALUES (?, ?, ?) + ON CONFLICT(cve_id, os_id) + DO UPDATE SET commit_id = excluded.commit_id""", + (cve_id, os_id, commit_id) + ) + self.conn.commit() + return True + + except sqlite3.IntegrityError as e: + print(f"数据完整性错误: {e}") + self.conn.rollback() + return False + except sqlite3.DatabaseError as e: + print(f"数据库操作失败: {e}") + self.conn.rollback() + return False + + def update_os_version_name(self, old_version_name: str, new_version_name: str) -> bool: + """ + 更新操作系统版本名称(级联更新 openEuler 子版本的 parent 字段) + """ + try: + with self.conn: + cursor = self.conn.cursor() + + # 检查新名称是否已存在 + cursor.execute("SELECT 1 FROM OS_Version WHERE version_name = ?", (new_version_name,)) + if cursor.fetchone(): + print(f"错误: 版本名称 '{new_version_name}' 已存在") + return False + + # 获取旧版本信息 + cursor.execute("SELECT version_type, parent FROM OS_Version WHERE version_name = ?", (old_version_name,)) + record = cursor.fetchone() + if not record: + print(f"错误: 找不到版本 '{old_version_name}'") + return False + old_version_type, old_parent = record + + # 更新自身版本名称 + cursor.execute( + "UPDATE OS_Version SET version_name = ? WHERE version_name = ?", + (new_version_name, old_version_name) + ) + + # 如果是 stable 版本,级联更新所有 openEuler 子版本的 parent 字段 + if old_version_type == "stable": + cursor.execute( + "UPDATE OS_Version SET parent = ? WHERE parent = ? AND version_type = 'openEuler'", + (new_version_name, old_version_name) + ) + + return True + + except sqlite3.Error as e: + print(f"数据库错误: {str(e)}") + return False + + def close(self): + self.conn.close() diff --git a/cve-fix/vulnerability/scripts/check_commit.sh b/cve-fix/vulnerability/scripts/check_commit.sh new file mode 100755 index 0000000000000000000000000000000000000000..78ef6d2061fe614cef6db79c6df5cdce6ece1dab --- /dev/null +++ b/cve-fix/vulnerability/scripts/check_commit.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# 该脚本提供了从仓库中查询是否引入了某个commit的功能 + +# 0:表示 commit 已合入分支 +# 1:表示 commit 未合入分支 +# 2:表示无法切换到指定分支 +# 3:表示无法克隆或访问仓库 +# 4:表示命令错误 +# 5: 表示其他错误 + +# 设置退出时显示错误信息 +set -euo pipefail + +# 检查参数 +if [ $# -ne 2 ]; then + exit 4 +fi + +branch_name="$1" +commit_id="$2" +repo_path="$(pwd)/stash/kernel" # 持久化仓库路径 +remote_url="https://gitee.com/openeuler/kernel.git" + +mkdir -p $(pwd)/stash + +# 初始化仓库 +init_repo() { + echo "初始化仓库" "$repo_path" + if [ ! -d "$repo_path/.git" ]; then + git clone "$remote_url" "$repo_path" + local clone_status=$? + if [ $clone_status -ne 0 ]; then + return 3 # 无法克隆仓库 + fi + fi + return 0 +} + +# 切换到指定分支 +switch_branch() { + echo "切换到分支$branch_name" + # 检查分支是否存在 + if ! git -C "$repo_path" show-ref --verify --quiet "refs/heads/$branch_name"; then + if ! git -C "$repo_path" fetch origin "$branch_name:$branch_name"; then + return 2 # 无法拉取远程分支 + fi + fi + + # 切换分支 + if ! git -C "$repo_path" checkout "$branch_name"; then + return 2 # 无法切换到分支 + fi + + # 尝试普通拉取 + if ! git -C "$repo_path" pull --ff-only origin "$branch_name"; then + echo "警告:无法快进拉取,尝试强制同步远程分支..." >&2 + # 强制重置本地分支到远程分支 + git -C "$repo_path" fetch origin "$branch_name" || return 3 + git -C "$repo_path" reset --hard "origin/$branch_name" || return 3 + fi + return 0 +} + +# 查询是否存在commit_id +check_commit() { + echo "在 $branch_name 中查询commit $commit_id " + # 步骤1:使用 git show 检查 commit 是否存在于当前分支 + if git -C "$repo_path" show "$commit_id" >/dev/null 2>&1; then + echo "commit $commit_id 已经合入分支 $branch_name" + return 0 # Commit 已合入 + fi + + # 步骤2:如果 git show 失败,使用 git log --grep 进行进一步检查 + if git -C "$repo_path" log --max-count=1 --grep="$commit_id" "$branch_name" | grep -q "$commit_id"; then + echo "commit $commit_id 已经合入分支 $branch_name" + return 0 # Commit 已合入 + else + echo "commit $commit_id 没有合入分支 $branch_name" + return 1 # Commit 未合入 + fi +} + +# 主流程 +init_repo || exit 3 +switch_branch || exit 2 +check_commit +exit $? diff --git a/cve-fix/vulnerability/scripts/check_mainline_commit.sh b/cve-fix/vulnerability/scripts/check_mainline_commit.sh new file mode 100755 index 0000000000000000000000000000000000000000..9c32de9ad12d4d5fc8bc27a7e5af5163f2395920 --- /dev/null +++ b/cve-fix/vulnerability/scripts/check_mainline_commit.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +# 检查输入参数 +if [ $# -ne 1 ]; then + echo "Usage: $0 " >&2 + exit 1 +fi + +commit_id="$1" +repo_path="$(pwd)/stash/linux" # 持久化仓库路径 +remote_url="https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git" +max_retries=1 # 最大重试次数 + +mkdir -p $(pwd)/stash + +# 初始化仓库函数 +init_repo() { + # 首次克隆 + if [ ! -d "$repo_path/.git" ]; then + echo "克隆仓库 git clone $remote_url $repo_path" + git clone "$remote_url" "$repo_path" || return 1 + fi + + # 强制切换到master分支 + echo "切换到master分支" + git -C "$repo_path" checkout master --force || return 1 +} + +# 更新仓库函数 +update_repo() { + echo "强制拉取仓库" + git -C "$repo_path" pull --ff-only origin master || return 1 +} + +# 主验证函数 +check_commit() { + # 首次验证 + echo "验证commit" "$commit_id" "是否合法" + if git -C "$repo_path" rev-parse --quiet --verify "$commit_id^{commit}"; then + return 0 + fi + + # 重试逻辑 + for ((i=1; i<=$max_retries; i++)); do + update_repo || return 1 + if git -C "$repo_path" rev-parse --quiet --verify "$commit_id^{commit}"; then + echo "Found commit after update" + return 0 + fi + done + + return 1 +} + +# 执行初始化 +if ! init_repo; then + exit 1 +fi + +# 主检查流程 +if ! check_commit; then + echo "Error: Commit $commit_id not found after retries" + exit 1 +fi + +# 提取mainline信息 +mainline_commit=$( + git -C "$repo_path" show -s --format=%B "$commit_id" | + sed -nE ' + s/.*[Uu]pstream[[:space:]]+commit[[:space:]]+([0-9a-f]{7,}).*/\1/p + s/.*commit[[:space:]]+([0-9a-f]{7,})[[:space:]]+[Uu]pstream.*/\1/p + ' | + grep -m1 -E '^[0-9a-f]{7,}$' +) + +[ -n "$mainline_commit" ] && echo "$mainline_commit" || echo "$commit_id" \ No newline at end of file diff --git a/cve-fix/vulnerability/scripts/init.sh b/cve-fix/vulnerability/scripts/init.sh new file mode 100755 index 0000000000000000000000000000000000000000..582088a442d351605daeb69e83756ebcc3835781 --- /dev/null +++ b/cve-fix/vulnerability/scripts/init.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +kernel_path="$(pwd)/stash/linux" +kernel_url="https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git" + +mkdir -p $(pwd)/stash + +# 初始化内核仓库 +init_repository() { + if [ ! -d "$1" ]; then + echo "Cloning $2 repository to $1..." + git clone "$2" "$1" + else + echo "Repository $1 already exists, skipping clone." + fi +} + +init_repository "$kernel_path" "$kernel_url" diff --git a/cve-fix/vulnerability/scripts/push_patch.sh b/cve-fix/vulnerability/scripts/push_patch.sh new file mode 100755 index 0000000000000000000000000000000000000000..6d4da7f2b93707bbc3a340937ec65ce36286d91e --- /dev/null +++ b/cve-fix/vulnerability/scripts/push_patch.sh @@ -0,0 +1,222 @@ +#!/bin/bash + +# 返回码 +# 0 补丁应用成功 +# 2 补丁存在冲突 + +# 参数校验 +if [ $# -ne 7 ]; then + echo "Usage: $0 " + exit 1 +fi + +commit_id=$1 +origin_branch=$2 +user_name=$3 +user_email=$4 +CVE=$5 +number=$6 +inclusion=$7 + +line=2 +dev_branch="patch_auto_$(date +%s)" + +# 全局变量存储生成的补丁文件名 +generated_patch="" + +linux_path="$(pwd)/stash/linux" # 持久化仓库路径 +linux_url="https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git" + +kernel_path="$(pwd)/stash/kernel" +kernel_url="https://gitee.com/openeuler/kernel.git" + +robot_url="https://ci-robot:03d33936c22f95aa4b47c1cce6caa790@gitee.com/ci-robot/kernel.git" + +mkdir -p $(pwd)/stash + +MSG="commit_msg.txt" +inclusion="mainline" +bugzilla="https://gitee.com/src-openeuler/kernel/issues/$number" +namerev="unknown" + +add_str() +{ + line=$((line + 1)) + if [ -z "$1" ]; then + sed -i "${line}{x;p;x;}" "$MSG" + else + sed -i "${line}i\\$1" "$MSG" + fi +} + +get_namerev() { + namerev=$(git name-rev "$commit_id" 2>/dev/null) + + if ! echo "$namerev" | grep -q "tags/v"; then + echo "没有找到标签" + else + namerev=${namerev#*/} + namerev=${namerev%%~*} + fi +} + +add_mainline_head() { + local ref # 声明局部变量[7](@ref) + + if [ "$inclusion" = "stable" ]; then + add_str "stable inclusion" + add_str "from stable-$namerev" + ref="Reference: https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/commit/?id=$commit_id" + else + add_str "mainline inclusion" + add_str "from mainline-$namerev" + ref="Reference: https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/commit/?id=$commit_id" + fi + + add_str "commit $commit_id" + add_str "category: bugfix" + add_str "bugzilla: ${bugzilla}" + add_str "CVE: ${CVE}" + add_str "" + add_str "$ref" + add_str "-------------------------------------------------" + add_str "" +} + +# 初始化内核仓库 +init_repository() { + local repo_dir="$1" + local repo_url="$2" + + echo "$repo_dir $repo_url" + if [ ! -d "$repo_dir/.git" ]; then + echo "正在克隆仓库……" + git clone "$repo_url" "$repo_dir" + else + echo "仓库 $repo_dir 已存在,跳过克隆。" + fi + git -C "$repo_dir" checkout master +} + +# 生成内核patch +generate_patch() { + cd "$linux_path" || { echo "无法进入 $linux_path"; exit 1; } + + echo "更新 Linux 仓库代码……" + git pull + + echo "获取namerev" + get_namerev + echo "namerev: $namerev" + + # 清理旧补丁 + echo "清理旧补丁文件……" + find . -maxdepth 1 -name "*.patch" -delete + + # 生成新补丁 + echo "为提交 $commit_id 生成补丁..." + # 获取生成的补丁文件名 + generated_patch=$(git format-patch -1 "$commit_id" 2>/dev/null) + if [ -z "$generated_patch" ]; then + echo "补丁生成失败" + exit 1 + fi + echo "补丁生成成功:$generated_patch" + + # 复制到内核仓库 + echo "将补丁 ${generated_patch##*/} 复制到 Kernel 仓库..." + cp "$generated_patch" "$kernel_path/${generated_patch##*/}" + generated_patch="${generated_patch##*/}" # 仅保留文件名 + +} + +# 应用patch并创建PR分支 +apply_patch() { + local patch_name="$1" + cd "$kernel_path" || { echo "无法进入目录 $kernel_path"; exit 1; } + echo "更新 Kernel 仓库代码……" + git fetch origin + echo "切换到分支 $origin_branch 并重置到远程最新状态……" + git checkout "$origin_branch" + git reset --hard origin/"$origin_branch" + + echo "开始尝试应用补丁文件:$patch" + if ! git am "$patch_name" > /dev/null 2>&1; then + echo "补丁应用失败,存在冲突" + git am --abort + exit 2 + else + echo "补丁应用成功" + git log -n 1 --pretty=%B > $MSG + echo "add mainline head" + add_mainline_head + sed -i '$ { /^$/d }' $MSG + echo "添加签名信息:Signed-off-by: $user_name <$user_email>" + echo "Signed-off-by: $user_name <$user_email>" >> $MSG # 追加签名信息 + git commit --amend -F $MSG + fi + echo "执行 git reset --hard HEAD 清理工作区……" + git reset --hard HEAD + git remote add robot "$robot_url" + + # 检查原始目标分支是否存在 + if git ls-remote --exit-code --heads robot "$dev_branch" > /dev/null; then + echo "目标分支 $dev_branch 已存在,自动添加后缀" + dev_branch="${dev_branch}_tmp" + echo "新的目标分支为 $dev_branch" + else + echo "目标分支 $dev_branch 不存在,将直接推送" + fi + + # 推送代码到目标分支(自动创建不存在的分支) + echo "推送代码到分支 $dev_branch..." + git push robot "$origin_branch:$dev_branch" || { + echo "推送失败,请检查权限或分支冲突" + exit 1 + } + + # 新增基于时间的检查逻辑 + timeout=$((60*3)) # 最大等待时间3分钟(180秒) + start_time=$(date +%s) # 记录开始时间戳 + branch_exists=false + + echo "开始等待分支创建(最长3分钟)..." + while true; do + # 检查分支是否存在 + if git ls-remote --exit-code --heads robot "$dev_branch" >/dev/null; then + branch_exists=true + break + fi + + # 计算已用时间 + current_time=$(date +%s) + elapsed=$((current_time - start_time)) + + # 超时检查 + if [ $elapsed -ge $timeout ]; then + echo "等待超时(已等待 ${elapsed} 秒),分支仍未创建" + break + fi + + # 进度提示(动态显示剩余时间) + remaining=$((timeout - elapsed)) + echo "等待分支生效(剩余时间:${remaining}秒)..." + sleep 5 # 保持5秒检查间隔 + done + + if [ "$branch_exists" = false ]; then + echo "错误:推送后分支 $dev_branch 仍未在远程仓库存在" + exit 1 + fi + + # 切换回原分支后删除临时开发分支 + git checkout "$origin_branch" 2>/dev/null || git checkout master + echo "代码已成功推送到分支" + echo $dev_branch +} + +# 主执行流程 +init_repository "$linux_path" "$linux_url" +init_repository "$kernel_path" "$kernel_url" +generate_patch +apply_patch "$generated_patch" diff --git a/cve-fix/vulnerability/tracker.py b/cve-fix/vulnerability/tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..ef0a4cc83090f1056cfa547e233669f07ad51f61 --- /dev/null +++ b/cve-fix/vulnerability/tracker.py @@ -0,0 +1,305 @@ +from .analysis import CommitQuery +from .database import CVEDatabase +from subprocess import Popen, PIPE, STDOUT +from tqdm import tqdm +import os +import stat +import re +import requests + +class CVETracker: + def __init__(self, cve_data_db: CVEDatabase): + # 初始化组件 + self.commit_query = CommitQuery() + self.cve_data_db = cve_data_db + self.commit_url = "https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/commit/?id=" + + def _check_mainline_commit(self, commit_id: str): + try: + # 定义shell脚本的路径 + script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts") + shell_script_path = os.path.join(script_dir, "check_mainline_commit.sh") + + st = os.stat(shell_script_path) + if not st.st_mode & stat.S_IEXEC: + # 如果脚本不可执行,则添加执行权限 + os.chmod(shell_script_path, st.st_mode | stat.S_IEXEC) + + # 执行shell脚本,并传递commit_id作为参数 + output = [] + command = f"bash {shell_script_path} {commit_id}" + print(f"执行脚本:{command}") + process = Popen(command, stdout=PIPE, stderr=STDOUT, shell=True) + with process.stdout: + for line in iter(process.stdout.readline, b''): + out = line.decode().strip() + print(out) + output.append(out) + # 捕获并处理标准输出和标准错误 + result = output[-1] + return result + except Exception as e: + print(f"An error occurred while executing the shell script: {e}") + return str(e) + + def _check_mainline_commit_url(self, commit_id: str): + try: + url = self.commit_url + commit_id + # 发送 HTTP 请求(添加浏览器头避免被拦截) + headers = {'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/115.0'} + response = requests.get(url, headers=headers, timeout=10) + response.raise_for_status() + + # 匹配图片中的两种 commit 格式 + patterns = [ + # 匹配格式:commit e563b01208f4d1f609bcb23... upstream + r'commit\s+([0-9a-f]{12,40})\b.*?\bupstream', + # 匹配格式:Upstream commit 7828e9363ac4d23b02419fbf + r'[uU]pstream\s+commit\s+([0-9a-f]{12,40})\b' + ] + + # 在页面内容中搜索匹配项 + content = response.text + mainline_commit = None + for pattern in patterns: + match = re.search(pattern, content, re.DOTALL) + if match: + mainline_commit = match.group(1) + # 取第一个有效匹配后退出循环 + if len(mainline_commit) >= 12: + break + + # 返回结果逻辑 + if mainline_commit: + # 确保返回完整 40 位或至少 12 位 hash + return mainline_commit[:40] if len(mainline_commit) > 40 else mainline_commit + else: + return None + + except requests.exceptions.RequestException as e: + return f"HTTP error: {str(e)}" + except Exception as e: + return f"Error: {str(e)}" + + def _check_commit(self, os_version: str, commit_id: str): + try: + # 定义shell脚本的路径 + script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts") + shell_script_path = os.path.join(script_dir, "check_commit.sh") + + st = os.stat(shell_script_path) + if not st.st_mode & stat.S_IEXEC: + # 如果脚本不可执行,则添加执行权限 + os.chmod(shell_script_path, st.st_mode | stat.S_IEXEC) + + # 执行shell脚本,并传递commit_id作为参数 + command = f"bash {shell_script_path} {os_version} {commit_id}" + print(f"执行脚本:{command}") + process = Popen(command, stdout=PIPE, stderr=STDOUT, shell=True) + with process.stdout: + for line in iter(process.stdout.readline, b''): + print(line.decode().strip()) + return_code = process.wait() + print("return_code: ", return_code) + return return_code + except Exception as e: + print(f"An error occurred while executing the shell script: {e}") + return None, str(e), -1 + + def update_cve_commit(self, cve: str): + """检查单个CVE在每一个操作系统版本中是否引入, 是否修复""" + # 获取commit信息 + commits_info = self.commit_query.get_commit_infos(cve) + if len(commits_info) == 0: + print(f'Warning! Current CVE: {cve} has no commit info...') + return False + + # 引入commit查询,只查询mainline版本 + intro_commits_info: dict = self.commit_query.get_introduced_commit(commits_info) + intro_commit_mainline = None + if intro_commits_info: + intro_commit_0 = list(intro_commits_info.values())[0] + intro_commit_mainline = self._check_mainline_commit(intro_commit_0) + if not intro_commit_mainline: + intro_commit_mainline = intro_commit_0 + print(f"引入commit: {intro_commit_mainline}") + + # 修补commit查询,这里查询mainline版本 + fixed_commits_info: dict = self.commit_query.get_fixed_commit(commits_info) + fixed_commit_mainline = None + if fixed_commits_info: + fixed_commit_0 = list(fixed_commits_info.values())[0] + fixed_commit_mainline = self._check_mainline_commit(fixed_commit_0) + if not fixed_commit_mainline: + fixed_commit_mainline = fixed_commit_0 + print(f"修复commit: {fixed_commit_mainline}") + + # 创建cve, 同时填入引入commit和修补commit版本 + self.cve_data_db.add_cve(cve, intro_commit_mainline, fixed_commit_mainline) + + # stable版本修补commit查询 + for fixed_version, fixed_commit in fixed_commits_info.items(): + stable_fixed_version = self.cve_data_db.check_os_version_is_stable_version(fixed_version) + if stable_fixed_version: + print(f"修复commit: {fixed_commit} in {stable_fixed_version}") + self.cve_data_db.add_cve_commit(cve, stable_fixed_version, fixed_commit) + + # openEuler版本修补commit查询 + openEuler_os_versions = self.cve_data_db.get_os_versions(type="openEuler") + for os_version in openEuler_os_versions: + parent_os_version = self.cve_data_db.get_os_parents(os_version) + stable_commit = self.cve_data_db.check_fixed_commit_id(cve, parent_os_version) + if stable_commit: + print(f"修复commit: {stable_commit} in {os_version}") + self.cve_data_db.add_cve_commit(cve, os_version, stable_commit) + + def check_immediately(self): + """立刻更新数据库""" + cve_list = self.commit_query.get_all_cve() + for cve in tqdm(cve_list): + try: + self.update_cve_commit(cve) + except Exception as e: + print(f"Error processing {cve}: {str(e)}") + + def query_cve(self, cve_name: str, os_version: str, intro_commit: str, fixed_commit: str) -> tuple[int]: + """ + 查询CVE状态 + + return: CVE的状态, 返回值为tuple, 包含两个int, 分别表示引入和修补的状态 + 引入状态: 0 - 查询不到引入commit 1 - 受影响 2 - 不受影响 3 - 其他错误 + 修复状态: 0 - 查询不到修补commit 1 - 已修复 2 - 未修复 3 - 其他错误 + """ + intro_code = 0 + fixed_code = 0 + + if not intro_commit: + intro_code = 0 + else: + return_code = self._check_commit(os_version, intro_commit) + print("return_code: ", return_code) + if return_code == 0: # commit合入了分支 + intro_code = 1 # 受影响 + elif return_code == 1: # commit没有合入分支中 + intro_code = 2 # 不受影响 + else: + intro_code = 3 + + if not fixed_commit: + fixed_code = 0 + else: + return_code = self._check_commit(os_version, fixed_commit) + print("return_code: ", return_code) + if return_code == 0: # commit合入了分支 + fixed_code = 1 # 已修复 + elif return_code == 1: # commit没有合入分支中 + fixed_code = 2 # 未修复 + else: + fixed_code = 3 + return (intro_code, fixed_code) + + def query_all_os_cve(self, cve_name: str) -> tuple: + code_dict = {} + intro_commit = self.cve_data_db.check_intro_commit_id_mainline(cve_name) + fixed_commit = self.cve_data_db.check_fixed_commit_id_mainline(cve_name) + print(f"引入commit: {intro_commit}, 修复commit: {fixed_commit}") + + for os_version in self.cve_data_db.get_os_versions(type='stable'): + print(f"检查{cve_name}是否被引入了{os_version}中") + code = self.query_cve(cve_name, os_version, intro_commit, fixed_commit) + code_dict[os_version] = code + result = {} + for os_version in self.cve_data_db.get_os_versions(type='openEuler'): + parent_os = self.cve_data_db.get_os_parents(os_version) + result[os_version] = code_dict[parent_os] + return result, intro_commit, fixed_commit + + def fix_cve(self, cve_name: str, os_version: str, user_name: str, user_email: str, number: str): + """ + 修补漏洞 (创建pr) + + return: 修复的状态, 返回值为int + 0: 推送成功,创建pr + 1: 已经修复,无需重新修复 + 2: 没有修复补丁 + 3: 无法修复,存在冲突 + 4: 网络异常,请重新尝试 + """ + fixed_commit = self.cve_data_db.check_fixed_commit_id_mainline(cve_name) + print(f"查询到修复commit: {fixed_commit}") + if not fixed_commit: + return 2, None + else: + return_code = self._check_commit(os_version, fixed_commit) + if return_code == 0: + return 1, None + + # 执行PR创建脚本 + script_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts") + shell_script_path = os.path.join(script_dir, "push_patch.sh") + st = os.stat(shell_script_path) + if not st.st_mode & stat.S_IEXEC: + # 如果脚本不可执行,则添加执行权限 + os.chmod(shell_script_path, st.st_mode | stat.S_IEXEC) + try: + output = [] + command = f"bash {shell_script_path} {fixed_commit} {os_version} {user_name} {user_email} {cve_name} {number} 'mainline'" + print(f"执行脚本:{command}") + process = Popen(command, stdout=PIPE, stderr=STDOUT, shell=True) + with process.stdout: + for line in iter(process.stdout.readline, b''): + out = line.decode().strip() + print(out) + output.append(out) + return_code = process.wait() + branch = output[-1] + if return_code == 0: + return 0, branch + elif return_code == 2: + return 3, None + else: + return 4, None + except FileNotFoundError: + print(f"错误:找不到patch推送脚本 {shell_script_path}") + return 3, None + except Exception as e: + print(f"未知错误:{str(e)}") + return 3, None + + def get_analyze_result(self, code: tuple[int]) -> str: + """ + 将return_code转化为str + code: return code + mode: 仅有analyze create_pr两种模式 + """ + result = "" + if code[0] == 0: + result += "查询不到引入commit," + elif code[0] == 1: + result += "受影响," + elif code[0] == 2: + result += "无影响," + if code[1] == 0: + result += "查询不到修补commit。" + elif code[1] == 1: + result += "已修复。" + elif code[1] == 2: + result += "未修复。" + + if code[0] == 3 or code[1] == 3: + result = "未知错误,请检查日志" + print(result) + return result + + def get_fixed_result(self, code: int) -> str: + if code == 0: + return "pr创建成功," + elif code == 1: + return "已经修复,无需重复修复。" + elif code == 2: + return "没有修复补丁。" + elif code == 3: + return "无法修复,存在冲突。" + elif code == 4: + return "网络异常,请重新尝试。" +