#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import json
import argparse
import traceback

from typing import Any
from xml.etree import ElementTree
from subprocess import getstatusoutput,getoutput

class MetaData(object):
    """Shared constants: JSON field names, parsing filters and the output path."""

    # Keys used in the scene-policy JSON documents.
    rules_field = "rules"
    modules_field = "modules"
    subject_field = "subject"
    object_field = "object"
    variables_field = "variables"
    name_field = "name"
    value_field = "value"
    type_field = "type"

    # Variable value types whose contents are collected.
    variable_type = ["path", "package"]

    # Rule subject/object values that are ignored (they do not name a variable).
    exclude_variables_values = ["", "*"]

    # Module names skipped when walking the module tree.
    exclude_modules = ["devctl", "kmod"]

    # Variable names that are never collected.
    exclude_variables = ["secadm_caps", "auditadm_caps"]

    # Output file of the resource parser.
    result_file = "/etc/kysec-scene/kysec-scene-origin-resource"

class JsonHandler(MetaData):
    """
    Helpers to read, query and write the scene-policy JSON files.
    """

    @staticmethod
    def read_json(file: str) -> dict:
        """
        Load the JSON document stored in ``file``.

        :param file: path of the JSON file (decoded as UTF-8)
        :return: the decoded top-level object, or ``{}`` when the file cannot
                 be read, is not valid JSON, or its top level is not an object
        """
        try:
            with open(file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError also covers JSONDecodeError and UnicodeDecodeError.
            print(f"json load failed {e}")
            return {}
        if not isinstance(data, dict):
            # Callers use dict methods (.get/.items) on the result.
            print(f"json load failed {file}: top-level value is not an object")
            return {}
        return data

    @classmethod
    def find_module_field_data(cls, json_data: dict, module_name: str, field_name: str) -> tuple[bool, list[Any]]:
        """
        Depth-first search of a ``modules`` mapping for ``module_name`` and
        return one of its fields, e.g. the ``variables`` of the netctl module.

        :param json_data: mapping of module name -> module dict; each module may
                          nest further modules under its ``modules`` key
        :param module_name: name of the module to find
        :param field_name: field to return from that module, e.g. rules / variables
        :return: ``(True, value)`` for the first matching module, where ``value``
                 is the module's own object (not a copy) or ``[]`` if the field
                 is absent; ``(False, [])`` when no module matches
        """
        if not isinstance(json_data, dict):
            return False, []

        for name, module in json_data.items():
            # Check the type first: a non-dict entry cannot hold fields or sub-modules.
            if not isinstance(module, dict):
                continue
            if name == module_name:
                return True, module.get(field_name, [])

            sub_modules = module.get(cls.modules_field)
            if isinstance(sub_modules, dict) and sub_modules:
                found, result = cls.find_module_field_data(sub_modules, module_name, field_name)
                if found:
                    return found, result

        return False, []

    @classmethod
    def get_variable_content(cls, json_data: dict, module_name: str, variable_name: str) -> tuple[bool, list[str]]:
        """
        Return the ``value.value`` list of variable ``variable_name`` defined in
        module ``module_name``.

        The returned list is the object stored inside ``json_data`` (it is
        created there if the variable has a ``value`` object without a
        ``value`` list), so appending to it modifies ``json_data`` in place.

        :param json_data: complete JSON document (with a top-level ``modules`` key)
        :param module_name: module name
        :param variable_name: variable name
        :return: ``(True, list)`` on success; ``(False, [])`` if the module, its
                 variables, the variable, or a well-formed value list is missing
        :raises ValueError: if ``json_data`` is empty or not a dict
        """
        if not json_data or not isinstance(json_data, dict):
            raise ValueError("invalid json data format")

        found, variable_list = cls.find_module_field_data(
            json_data.get(cls.modules_field), module_name, cls.variables_field)
        # An empty variable list is also treated as an error.
        if not found or not variable_list or not isinstance(variable_list, list):
            return False, []

        for variable in variable_list:
            if not isinstance(variable, dict) or variable.get(cls.name_field) != variable_name:
                continue
            value = variable.get(cls.value_field)
            if not isinstance(value, dict):
                return False, []
            # setdefault keeps the list attached to json_data, so callers' appends persist.
            content = value.setdefault(cls.value_field, [])
            if not isinstance(content, list):
                return False, []
            return True, content
        return False, []

    @classmethod
    def save_data(cls, json_file, json_data):
        """
        Write ``json_data`` to ``json_file`` as UTF-8 JSON indented by 4 spaces,
        keeping non-ASCII characters unescaped.
        """
        with open(json_file, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=4, ensure_ascii=False)

class ResourceParser(MetaData):
    """
    Collect the resources (paths / package names) referenced by the rules of
    the scene-policy JSON files and write them to ``result_file``.

    A rule's ``subject``/``object`` value is treated as a variable name. The
    ``value.value`` list of every variable with that name whose ``value.type``
    is in ``variable_type`` is collected as subject/object content.
    """

    # The four lists below are created per instance in __init__. Class-level
    # lists would be shared by every instance and keep growing across them.
    # Variable names used as a rule subject
    subject_variables: list
    # Variable names used as a rule object
    object_variables: list
    # Contents of the subject variables
    subject_content: list
    # Contents of the object variables
    object_content: list

    def __init__(self, *args):
        """
        :param args: paths of the JSON files to parse; paths that do not exist are skipped
        """
        self.json_files = args
        self.subject_variables = []
        self.object_variables = []
        self.subject_content = []
        self.object_content = []

    def parse_rules(self, rules: list):
        """
        Record the subject and object variable names used by ``rules`` in
        ``subject_variables`` / ``object_variables``.

        :param rules: list of rule dicts; non-dict entries are ignored
        """
        targets = (
            (self.subject_field, self.subject_variables),
            (self.object_field, self.object_variables),
        )
        for rule in rules:
            if not isinstance(rule, dict):
                continue
            for key, target in targets:
                val = rule.get(key)
                if not val or val in self.exclude_variables_values:
                    continue
                target.append(val)

    def parse_var_content(self, variables: list):
        """
        Append the contents of the variables referenced by already-parsed rules
        to ``subject_content`` / ``object_content``.

        :param variables: list of variable dicts shaped like
                          ``{"name": ..., "value": {"type": ..., "value": [...]}}``;
                          malformed entries are skipped
        """
        for var in variables:
            if not isinstance(var, dict):
                continue
            v_name = var.get(self.name_field)
            value = var.get(self.value_field)
            if not isinstance(value, dict):
                continue
            v_type = value.get(self.type_field)
            if not v_name or v_name in self.exclude_variables:
                continue
            if not v_type or v_type not in self.variable_type:
                continue

            items = value.get(self.value_field, [])
            if not isinstance(items, list):
                continue
            # Only string entries can be written out as lines of the result file.
            items = [item for item in items if isinstance(item, str)]

            if v_name in self.subject_variables:
                self.subject_content.extend(items)
            if v_name in self.object_variables:
                self.object_content.extend(items)

    def _collect_modules(self, json_data, out: list) -> list:
        """Append every non-excluded module dict of a ``modules`` mapping to ``out``, children first."""
        if not isinstance(json_data, dict):
            return out
        for name, module in json_data.items():
            # An excluded module is skipped together with all of its sub-modules.
            if name in self.exclude_modules or not isinstance(module, dict):
                continue
            sub_modules = module.get(self.modules_field)
            if sub_modules:
                self._collect_modules(sub_modules, out)
            out.append(module)
        return out

    def _parse_modules(self, modules: list):
        """Parse the rules of all ``modules`` first, then their variables."""
        # Two passes make the result independent of module order: a variable is
        # collected even if it is defined before the rule that references it.
        for module in modules:
            rules = module.get(self.rules_field)
            if rules:
                self.parse_rules(rules)
        for module in modules:
            variables = module.get(self.variables_field)
            if variables:
                self.parse_var_content(variables)

    def parse_data(self, json_data: dict):
        """
        Parse a ``modules`` mapping, including nested sub-modules.

        :param json_data: mapping of module name -> module dict
        """
        self._parse_modules(self._collect_modules(json_data, []))

    def parse_json_data(self):
        """
        Read every file in ``json_files`` and parse its top-level ``modules``.
        """
        modules = []
        for file in self.json_files:
            if not os.path.exists(file):
                continue
            tmp_data = JsonHandler.read_json(file)
            if not isinstance(tmp_data, dict):
                continue
            file_modules = tmp_data.get(self.modules_field)
            if not file_modules:
                continue
            # Collected per file rather than merged with dict.update(): a module
            # name defined in several files would otherwise keep only the last
            # file's definition.
            self._collect_modules(file_modules, modules)

        self._parse_modules(modules)

    def write_result(self):
        """
        De-duplicate the collected contents and write them to ``result_file``
        as a ``[subject]`` section followed by an ``[object]`` section, one
        entry per line.
        """
        result_dir = os.path.dirname(self.result_file)
        if result_dir:
            os.makedirs(result_dir, exist_ok=True)
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # file content is stable between runs (set() iteration order is not).
        self.subject_content = list(dict.fromkeys(self.subject_content))
        self.object_content = list(dict.fromkeys(self.object_content))
        result = "\n".join([
            "[subject]",
            "\n".join(self.subject_content),
            "[object]",
            "\n".join(self.object_content),
        ])
        # Explicit UTF-8: paths may contain non-ASCII characters, and the
        # default encoding depends on the process locale.
        with open(self.result_file, "w", encoding="utf-8") as f:
            f.write(result)

class KysecNetctlPkgNode(object):
    """
    One ``kysec_node`` entry of the netctl package XML.
    """

    def __init__(self, pkg_name="", type="", ico="", desktop="", server_type="") -> None:
        """
        :param pkg_name: package name, stored in the ``pkg`` attribute
        :param type: permission: 1 = allow, 2 = deny, 3 = authenticate (desktop platform)
        :param ico: icon file path
        :param desktop: desktop file path
        :param server_type: package type: 1 = system package
        """
        self.pkg_name = pkg_name
        self.type, self.ico = type, ico
        self.desktop, self.server_type = desktop, server_type
        # (child tag, text) pairs written by create(), in document order.
        # Captured once here; later attribute changes are not reflected.
        self.node_value = list(zip(
            ("type", "ico", "desktop", "server_type"),
            (type, ico, desktop, server_type),
        ))

    def create(self):
        """
        Build a new, detached ``kysec_node`` Element with one child element
        per ``node_value`` pair and return it.
        """
        element = ElementTree.Element("kysec_node", attrib={"pkg": self.pkg_name})
        for tag, text in self.node_value:
            child = ElementTree.SubElement(element, tag)
            child.text = text
        return element

class XmlHandler(object):
    """
    Base class that parses an XML file on construction and can write the
    (possibly modified) tree back to the same path.
    """

    def __init__(self, xml_file) -> None:
        # Path of the XML file
        self.xml_file = xml_file
        # Root Element, set by init_parser()
        self.root = None
        # Parsed ElementTree for the whole document, set by init_parser()
        self.doc = None

        self.init_parser()

    def init_parser(self):
        """
        Parse ``xml_file`` into ``doc`` and ``root``.

        Called by ``__init__``, so subclasses do not need to call it themselves.

        :raises FileNotFoundError: if ``xml_file`` does not exist
        """
        path = self.xml_file
        if not os.path.exists(path):
            raise FileNotFoundError(f"{path}")
        tree = ElementTree.parse(path)
        self.doc = tree
        self.root = tree.getroot()

    def create_node(self):
        """
        Abstract hook: subclasses implement node creation.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("crete_node func need child implement")

    def save_data(self):
        """
        Pretty-print the tree (4-space indent) and overwrite ``xml_file`` as
        UTF-8 with an XML declaration.
        """
        tree = self.doc
        ElementTree.indent(tree, space="    ")
        tree.write(self.xml_file, xml_declaration=True, encoding="utf-8")

class NetCtlXmlHandler(XmlHandler):
    """
    Editor for the netctl package XML. It tracks which packages already have a
    ``kysec_node`` and appends new nodes under ``kysec_user[@userID='0']``.
    """

    def __init__(self, file) -> None:
        """
        :param file: path of the XML file
        :raises FileNotFoundError: if ``file`` does not exist
        :raises ValueError: if the document has no ``kysec_user`` with userID '0'
        """
        super(NetCtlXmlHandler, self).__init__(file)
        # Package names that already have a kysec_node (kept as a list for existing callers)
        self.current_node_pkg_list = []
        # The same names as a set, for O(1) membership checks
        self._current_node_pkg_set = set()
        self.init_variable()

    def _remember_pkg(self, pkg):
        """Record ``pkg`` as present in the XML (in both the list and the set)."""
        if pkg in self._current_node_pkg_set:
            return
        self._current_node_pkg_set.add(pkg)
        self.current_node_pkg_list.append(pkg)

    def get_current_node_pkg_info(self):
        """
        Collect the ``pkg`` attribute of every ``kysec_node`` under each
        top-level ``kysec_user`` element.
        """
        # NOTE(review): this scans direct-child kysec_user elements of every user,
        # while new nodes go only to userID='0' (found with ".//"). A package
        # present only for another user is therefore not added for user 0.
        # Confirm that this is intended.
        for user in self.root.findall("kysec_user"):
            for node in user.findall("kysec_node"):
                pkg = node.get("pkg")
                if not pkg:
                    continue
                self._remember_pkg(pkg)

    def create_node(self, pkg_name):
        """
        Append a ``kysec_node`` for ``pkg_name`` (type "1", server_type "1",
        empty ico/desktop) under the userID '0' node, unless one already exists.

        :param pkg_name: package name
        """
        if pkg_name in self._current_node_pkg_set:
            return
        node = KysecNetctlPkgNode(pkg_name, "1", "", "", "1")
        self.kysec_user_nodes.append(node.create())
        # Record the new node so that a name repeated in the caller's package
        # list does not produce a second kysec_node.
        self._remember_pkg(pkg_name)

    def init_variable(self):
        """
        Locate the userID '0' ``kysec_user`` element and load the existing package names.

        :raises ValueError: if no such element exists
        """
        # First kysec_user element (at any depth) whose userID attribute is '0'
        self.kysec_user_nodes = self.root.find(".//kysec_user[@userID='0']")
        if self.kysec_user_nodes is None:
            raise ValueError("get kysec user node failed")
        self.get_current_node_pkg_info()

class NetPolicyHandler(object):
    """
    Add every package found on the system to the netctl policy: the
    ``netctl_package_list`` variable of the ``netctl`` module in ``json_file``,
    and a ``kysec_node`` in ``netctl_xml``.
    """
    netctl_xml = "/etc/kysec/netctl/netctl_pkg.xml"
    json_file = "/etc/kysec-scene/scene/ksc-defender.json"
    module_name = "netctl"
    pkg_var_name = "netctl_package_list"

    def __init__(self) -> None:
        # Package names collected by get_all_system_pkg()
        self.system_pkg_list = []
        # Parses netctl_xml immediately. Raises if the file is missing or has no userID='0' user node.
        self.netctl_xml_handler = NetCtlXmlHandler(self.netctl_xml)
        # JSON document loaded by update_json_data() and saved by run()
        self.json_data = None

    def get_all_system_pkg(self):
        """
        Collect the package names from the dpkg status file, plus those listed by
        ``kaiming`` and ``kare`` when these tools are installed, into
        ``system_pkg_list``.

        :raises ValueError: if reading the dpkg status file fails
        """
        # NOTE(review): every "Package:" stanza of the dpkg status file is taken,
        # whatever its Status field says. Confirm whether removed-but-not-purged
        # packages should be included.
        status, output = getstatusoutput("/usr/bin/awk  '/^Package:/ {print $2}' /var/lib/dpkg/status 2> /dev/null")
        if status != 0:
            raise ValueError("get all system pkg failed")

        pkg_list = output.splitlines()
        search_pkg_cmd_list = ["/usr/bin/kaiming list", "/usr/bin/kare -l"]

        for cmd in search_pkg_cmd_list:
            if not os.path.exists(cmd.split()[0].strip()):
                continue
            # Keep lines starting with a letter and take their first column.
            # stderr of the listing command itself must be discarded: getoutput()
            # merges stderr into the captured output, so error text would
            # otherwise be parsed as package names.
            cmd_str = "%s 2> /dev/null |grep '^[a-zA-Z]' |awk '{print $1}' 2> /dev/null" % cmd
            output = getoutput(cmd_str)
            if output:
                pkg_list.extend(output.splitlines())
            else:
                print("warning: %s no pkg data" % cmd_str)

        # Drop empty names and every "Name" column header (dirty data that can
        # appear, e.g. from kaiming with an English locale). Also remove
        # duplicates, keeping first-seen order.
        self.system_pkg_list = list(dict.fromkeys(
            pkg for pkg in pkg_list if pkg and pkg != "Name"))

    def update_json_data(self):
        """
        Load ``json_file`` and append each package not yet present to the
        ``netctl_package_list`` variable, in memory (``run`` saves it).

        :raises ValueError: if the variable is missing or its value is not a list
        """
        self.json_data = JsonHandler.read_json(self.json_file)
        ok, content = JsonHandler.get_variable_content(self.json_data, self.module_name, self.pkg_var_name)
        if not ok or not isinstance(content, list):
            raise ValueError(f"{self.pkg_var_name} not found")

        # content is the list stored inside self.json_data (when the variable's
        # value list exists), so appending to it updates the data saved by run().
        # Only string entries can equal a package name, so only those go in the lookup set.
        known = {entry for entry in content if isinstance(entry, str)}
        for pkg in self.system_pkg_list:
            if not pkg or pkg == "Name" or pkg in known:
                continue
            content.append(pkg)
            known.add(pkg)

    def update_xml_data(self):
        """
        Add a ``kysec_node`` to the netctl XML (in memory) for every package
        that does not have one yet.
        """
        handler = self.netctl_xml_handler
        # A set snapshot gives O(1) lookups. Names are added as nodes are
        # created, so a repeated name cannot produce a second kysec_node.
        existing = set(handler.current_node_pkg_list)
        for pkg in self.system_pkg_list:
            if not pkg or pkg in existing:
                continue
            handler.create_node(pkg)
            existing.add(pkg)

    def run(self):
        """
        Collect the system packages, update the JSON and XML data, then write both files.

        :raises ValueError: if no package could be collected, or on the errors
                            of the steps above
        """
        self.get_all_system_pkg()
        if not self.system_pkg_list:
            raise ValueError("get system pkg data failed")

        self.update_json_data()
        self.update_xml_data()

        JsonHandler.save_data(self.json_file, self.json_data)
        self.netctl_xml_handler.save_data()




class CmdParser(object):
    """
    Command-line front end. It parses the options and dispatches to
    ResourceParser (-g) or NetPolicyHandler (-n).
    """

    def __init__(self) -> None:
        # The built-in -h is disabled; init_args() registers its own help option.
        self.parser = argparse.ArgumentParser("kysec-scene-parser-resource", add_help=False)
        self.init_args()

    def init_args(self) -> None:
        """Register the options and parse ``sys.argv`` into ``self.parse_args``."""
        add = self.parser.add_argument
        add("-g", "--gen-resource", action="store_true",
            help="解析场景策略,生成/etc/kysec-scene/kysec-scene-origin-resource")
        add("-n", "--init-pkg-netpolicy", action="store_true",
            help="初始化软件包的网络策略")
        add("-f", "--file", type=str, metavar='', nargs="+",
            help="文件路径")
        add("-h", "--help", action="help", default=argparse.SUPPRESS,
            help="帮助信息")
        self.parse_args = self.parser.parse_args()

    def run(self):
        """
        Dispatch on the parsed options:
          -g  parse the scene JSON files (the ``-f`` paths, or the four default
              files) and write the resource file;
          -n  initialise the per-package network policy;
          otherwise print help and exit with status 1.
        """
        opts = self.parse_args
        if opts.gen_resource:
            scene_files = opts.file or (
                "/etc/kysec-scene/scene/ksc-defender.json",
                "/etc/kysec-scene/scene/ksc-defender.pblk.json",
                "/etc/kysec-scene/scene/kysec-global.json",
                "/etc/kysec-scene/scene/ksc-defender.nblk.json",
            )
            handler = ResourceParser(*scene_files)
            handler.parse_json_data()
            handler.write_result()
            return

        if opts.init_pkg_netpolicy:
            NetPolicyHandler().run()
            return

        self.help()

    def exit(self, msg):
        """Report ``msg`` as a usage error and terminate."""
        # argparse's error() already exits (status 2), so the status-1 exit
        # below is not normally reached.
        self.parser.error(f"{msg}")
        raise SystemExit(1)

    def help(self):
        """Print the help text and exit with status 1."""
        self.parser.print_help()
        raise SystemExit(1)


if __name__ == "__main__":
    try:
        CmdParser().run()
    except Exception as e:
        # Summary on stderr, next to the traceback that print_exc() writes there.
        print("Exception: %s" % e, file=sys.stderr)
        traceback.print_exc()
        # NOTE(review): the process still exits with status 0 after an error.
        # Confirm whether callers rely on that before adding sys.exit(1).
