Skip to content

Commit

Permalink
Refactor Rule and RuleSet type
Browse files Browse the repository at this point in the history
  • Loading branch information
xkww3n committed Jun 24, 2024
1 parent 9726b07 commit 5ef7736
Show file tree
Hide file tree
Showing 13 changed files with 211 additions and 255 deletions.
7 changes: 0 additions & 7 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,3 @@

URL_DOMESTIC_IP_V4 = "https://raw.githubusercontent.com/gaoyifan/china-operator-ip/ip-lists/china.txt"
URL_DOMESTIC_IP_V6 = "https://raw.githubusercontent.com/gaoyifan/china-operator-ip/ip-lists/china6.txt"

RULE_TYPE_CONVERSION = {
"DOMAIN": "DomainFull",
"DOMAIN-SUFFIX": "DomainSuffix",
"IP-CIDR": "IPCIDR",
"IP-CIDR6": "IPCIDR6"
}
61 changes: 23 additions & 38 deletions models/rule.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,59 @@
from enum import Enum
from ipaddress import ip_network, IPv4Network, IPv6Network

from utils.rule import is_domain


class RuleType(Enum):
DomainSuffix = "DOMAIN-SUFFIX"
DomainFull = "DOMAIN"
IPCIDR = "IP-CIDR"
IPCIDR6 = "IP-CIDR6"


class Rule:
_type: str # DomainSuffix / DomainFull / IPCIDR / IPCIDR6
type: RuleType
tag: str
_payload: str
_tag: str

def __init__(self, rule_type: str = "", payload: str = "", tag: str = ""):
self._type = ""
self._payload = ""
self._tag = tag
if rule_type:
self.type = rule_type
if payload:
self.payload = payload
def __init__(self, rule_type: RuleType, payload: str = "", tag: str = ""):
self.type = rule_type
self.tag = tag
self.payload = payload

def __str__(self):
return f'{self._type}: {self._payload}{f" ({self._tag})" if self._tag else ""}'
return f'{self.type.name}: {self._payload}{f" ({self.tag})" if self.tag else ""}'

def __hash__(self):
return hash((self._type, self._payload, self._tag))
return hash((self.type, self.tag, self._payload))

def __eq__(self, other):
# noinspection PyProtectedMember
return self._type == other._type and self._payload == other._payload and self._tag == other._tag

@property
def type(self) -> str:
return self._type

@type.setter
def type(self, rule_type: str):
allowed_types = ("DomainSuffix", "DomainFull", "IPCIDR", "IPCIDR6")
if rule_type not in allowed_types:
raise TypeError(f"Unsupported type: {rule_type}")
self._type = rule_type
return self.type == other.type and self.tag == other.tag and self._payload == other._payload

@property
def payload(self) -> str:
return self._payload

@payload.setter
def payload(self, payload: str):
if "Domain" in self._type:
if self.type in {RuleType.DomainSuffix, RuleType.DomainFull}:
if not is_domain(payload):
raise ValueError(f"Invalid domain: {payload}")
elif "IP" in self._type:
else:
ip_type = ip_network(payload, strict=False)
if self._type == "IPCIDR6" and isinstance(ip_type, IPv4Network):
if self.type == RuleType.IPCIDR6 and isinstance(ip_type, IPv4Network):
raise ValueError(f"IPv4 address stored in IPv6 type: {payload}")
elif self._type == "IPCIDR" and isinstance(ip_type, IPv6Network):
elif self.type == RuleType.IPCIDR and isinstance(ip_type, IPv6Network):
raise ValueError(f"IPv6 address stored in IPv4 type: {payload}")
self._payload = payload

@property
def tag(self) -> str:
return self._tag

@tag.setter
def tag(self, tag: str):
self._tag = tag

def includes(self, other):
if self._type == "DomainSuffix":
if self.type == RuleType.DomainSuffix:
# noinspection PyProtectedMember
if self._payload == other._payload:
return True
# noinspection PyProtectedMember
return other._payload.endswith("." + self._payload)
elif self._type == "DomainFull":
elif self.type == RuleType.DomainFull:
return self == other
64 changes: 28 additions & 36 deletions models/ruleset.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import logging
from enum import Enum

from models.rule import Rule
from models.rule import Rule, RuleType


class RuleSetType(Enum):
Domain = 1
IPCIDR = 2
Combined = 3


class RuleSet:
_type: str # Domain / IPCIDR / Combined
type: RuleSetType
_payload: list[Rule]

def __init__(self, ruleset_type: str, payload: list):
self._type = ""
self._payload = []
if ruleset_type:
self.type = ruleset_type
if payload:
self.payload = payload
def __init__(self, ruleset_type: RuleSetType, payload: list):
self.type = ruleset_type
self._payload = payload

def __hash__(self):
return hash((self._type, tuple(self._payload)))
return hash((self.type, tuple(self._payload)))

def __eq__(self, other):
# noinspection PyProtectedMember
return self._type == other._type and self._payload == other._payload
return self.type == other.type and self._payload == other._payload

def __len__(self):
return len(self._payload)
Expand All @@ -40,41 +43,30 @@ def __contains__(self, item):
def __iter__(self):
return iter(self._payload)

@property
def type(self) -> str:
return self._type

@type.setter
def type(self, ruleset_type: str):
allowed_types = {"Domain", "IPCIDR", "Combined"}
if ruleset_type not in allowed_types:
raise TypeError(f"Unsupported type: {ruleset_type}")
self._type = ruleset_type

@property
def payload(self) -> list:
return self._payload

@payload.setter
def payload(self, payload: list):
if self._type == "Domain":
if self.type == RuleSetType.Domain:
for item in payload:
if "Domain" not in item.type:
raise ValueError(f"{item.type}-type rule found in a domain-type ruleset.")
elif self._type == "IPCIDR":
if item.type not in {RuleType.DomainSuffix, RuleType.DomainFull}:
raise ValueError(f"{item.type.value} rule found in a domain-type ruleset.")
elif self.type == RuleSetType.IPCIDR:
for item in payload:
if "IPCIDR" not in item.type:
raise ValueError(f"{item.type}-type rule found in a IPCIDR-type ruleset.")
if item.type not in {RuleType.IPCIDR, RuleType.IPCIDR6}:
raise ValueError(f"{item.type.value} rule found in a IPCIDR-type ruleset.")
self._payload = payload

def deepcopy(self):
ruleset_copied = RuleSet(self._type, [])
ruleset_copied = RuleSet(self.type, [])
payload_copied = []
for rule in self._payload:
rule_copied = Rule()
rule_copied._type = rule.type
rule_copied._payload = rule.payload
rule_copied._tag = rule.tag
rule_copied_type = rule.type
rule_copied_payload = rule.payload
rule_copied_tag = rule.tag
rule_copied = Rule(rule_copied_type, rule_copied_payload, rule_copied_tag)
payload_copied.append(rule_copied)
ruleset_copied._payload = payload_copied
return ruleset_copied
Expand All @@ -87,7 +79,7 @@ def remove(self, rule):
self._payload.remove(rule)

def sort(self):
if self._type == "Combined":
if self.type == RuleSetType.Combined:
logging.warning("Skipped: Combined-type ruleset shouldn't be sorted as maybe ordered.")
return

Expand All @@ -96,9 +88,9 @@ def sort_key(item: Rule) -> tuple:
# Domain suffixes should always in front of full domains
# Shorter domains should in front of longer domains
# For IPCIDR ruleset, default sort method is ok.
case "DomainSuffix":
case RuleType.DomainSuffix:
sortkey = (0, len(item.payload), item.payload)
case "DomainFull":
case RuleType.DomainFull:
sortkey = (1, len(item.payload), item.payload)
case _:
sortkey = (2, len(item.payload), item.payload)
Expand Down
62 changes: 25 additions & 37 deletions tests/test_0_rule.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,65 @@
from pytest import raises

from models.rule import Rule
from models.rule import Rule, RuleType


class Test:
def test_type_checking_init(self):
with raises(TypeError):
Rule("NotAllowedType", "test_payload")
with raises(ValueError):
Rule("DomainSuffix", "[invalid_domain]")
Rule(RuleType.DomainSuffix, "[invalid_domain]")
with raises(ValueError):
Rule("DomainFull", "[invalid_domain]")
Rule(RuleType.DomainFull, "[invalid_domain]")
with raises(ValueError):
Rule("IPCIDR", "114514")
Rule(RuleType.IPCIDR, "114514")
with raises(ValueError):
Rule("IPCIDR6", "1919810")
Rule(RuleType.IPCIDR6, "1919810")
with raises(ValueError):
Rule("IPCIDR", "fc00:114::514")
Rule(RuleType.IPCIDR, "fc00:114::514")
with raises(ValueError):
Rule("IPCIDR6", "1.14.5.14")
Rule(RuleType.IPCIDR6, "1.14.5.14")

def test_type_checking_runtime(self):
test_rule = Rule()
with raises(TypeError):
test_rule.type = "NotAllowedType"

test_rule.type = "DomainSuffix"
with raises(ValueError):
test_rule.payload = "[invalid_domain]"
Rule(RuleType.DomainSuffix, "[invalid_domain]")

test_rule.type = "DomainFull"
with raises(ValueError):
test_rule.payload = "[invalid_domain]"
Rule(RuleType.DomainFull, "[invalid_domain]")

test_rule.type = "IPCIDR"
with raises(ValueError):
test_rule.payload = "114514"
Rule(RuleType.IPCIDR, "114514")

test_rule.type = "IPCIDR6"
with raises(ValueError):
test_rule.payload = "1919810"
Rule(RuleType.IPCIDR6, "1919810")

test_rule.type = "IPCIDR"
with raises(ValueError):
test_rule.payload = "fc00:114::514"
Rule(RuleType.IPCIDR, "fc00:114::514")

test_rule.type = "IPCIDR6"
with raises(ValueError):
test_rule.payload = "1.14.5.14"
Rule(RuleType.IPCIDR6, "1.14.5.14")

def test_to_str(self):
test_rule = Rule("DomainSuffix", "example.com", "TEST")
test_rule = Rule(RuleType.DomainSuffix, "example.com", "TEST")
assert str(test_rule) == "DomainSuffix: example.com (TEST)"

def test_hash(self):
test_rule_1 = Rule("DomainSuffix", "example.com", "TEST")
test_rule_2 = Rule("DomainFull", "example.com", "TEST2")
test_dict = [Rule("DomainSuffix", "example.com", "TEST")]
test_rule_1 = Rule(RuleType.DomainSuffix, "example.com", "TEST")
test_rule_2 = Rule(RuleType.DomainFull, "example.com", "TEST2")
test_dict = [Rule(RuleType.DomainSuffix, "example.com", "TEST")]
assert test_rule_1 in test_dict
assert test_rule_2 not in test_dict

def test_eq(self):
test_rule_1 = Rule("DomainSuffix", "example.com", "TEST")
test_rule_2 = Rule("DomainSuffix", "example.com", "TEST")
test_rule_1 = Rule(RuleType.DomainSuffix, "example.com", "TEST")
test_rule_2 = Rule(RuleType.DomainSuffix, "example.com", "TEST")
assert test_rule_1 == test_rule_2

def test_include(self):
test_self_rule = Rule("DomainSuffix", "example.com", "TEST")
test_rule_1 = Rule("DomainSuffix", "a.example.com", "TEST")
test_rule_2 = Rule("DomainFull", "b.example.com", "TEST")
test_rule_3 = Rule("DomainFull", "example.com", "TEST")
test_rule_4 = Rule("DomainSuffix", "example.com", "TEST")
test_rule_5 = Rule("DomainFull", "1example.com", "TEST")
test_self_rule = Rule(RuleType.DomainSuffix, "example.com", "TEST")
test_rule_1 = Rule(RuleType.DomainSuffix, "a.example.com", "TEST")
test_rule_2 = Rule(RuleType.DomainFull, "b.example.com", "TEST")
test_rule_3 = Rule(RuleType.DomainFull, "example.com", "TEST")
test_rule_4 = Rule(RuleType.DomainSuffix, "example.com", "TEST")
test_rule_5 = Rule(RuleType.DomainFull, "1example.com", "TEST")
assert test_self_rule.includes(test_rule_1)
assert test_self_rule.includes(test_rule_2)
assert test_self_rule.includes(test_rule_3)
Expand Down
Loading

0 comments on commit 5ef7736

Please sign in to comment.