Coverage for mongo/ip_filter.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-11-05 04:22 +0000

1import re 

2from typing import Union 

3 

4__all__ = ('IPFilter', ) 

5 

6 

class OctetMatcher:
    """Match one IPv4 octet against a value, a wildcard, or a set of ranges.

    Accepted pattern forms:

    * a decimal number such as ``'10'`` -- matches exactly that value;
    * ``'*'`` -- matches every value from 0 to 255;
    * comma-separated ``low-high`` ranges such as ``'1-5, 20-30'`` --
      spaces are ignored and the two bounds of a range may be written in
      either order.

    The parsed result is stored in ``self.ranges`` as a tuple of inclusive
    ``(low, high)`` pairs.
    """

    def __init__(self, pattern: str) -> None:
        """Parse *pattern* into inclusive ``(low, high)`` ranges.

        Raises:
            ValueError: If a single number is outside ``[0, 255]``, if a
                comma-separated part is not exactly ``<digits>-<digits>``,
                or if a range extends above 255.
        """
        if pattern.isdecimal():
            num = int(pattern)
            if not 0 <= num < 256:
                raise ValueError(f'A octet must in range [0, 255], got {num}.')
            self.ranges = ((num, num), )
        elif pattern == '*':
            self.ranges = ((0, 255), )
        else:
            parts = pattern.replace(' ', '').split(',')
            # fullmatch rather than match: a prefix-only check lets trailing
            # junk through, e.g. '1-2_3' would be parsed by int() as 1-23 and
            # '1-2-3' would blow up later with an unrelated unpacking error.
            if not all(re.fullmatch(r'\d+-\d+', part) for part in parts):
                raise ValueError('Invalid range pattern.')
            # Sorting each pair lets callers write '10-5' as well as '5-10'.
            ranges = tuple(
                tuple(sorted(map(int, part.split('-')))) for part in parts)
            # The regex admits digits only, so a bound can never be negative;
            # only the upper bound needs checking.
            if any(high > 255 for _, high in ranges):
                raise ValueError('Invalid number range.')
            self.ranges = ranges

    def __repr__(self) -> str:
        return f'OctetMatcher({self.ranges})'

    def match(self, num: Union[int, str]) -> bool:
        """Return True if *num* lies inside any of the configured ranges.

        A string argument is converted with ``int()`` before comparison.
        """
        if isinstance(num, str):
            num = int(num)
        return any(low <= num <= high for low, high in self.ranges)

34 

35 

class IPFilter:
    """Filter dotted-quad IPv4 strings using one ``OctetMatcher`` per octet."""

    def __init__(self, pattern: str) -> None:
        """Build the filter from a four-part, dot-separated pattern.

        Raises:
            ValueError: If *pattern* does not split into exactly four parts
                on ``'.'``. Errors raised by ``OctetMatcher`` for an
                individual part are not caught here.
        """
        octet_patterns = pattern.split('.')
        if len(octet_patterns) != 4:
            raise ValueError('Invalid filter pattern.')
        self.matchers = list(map(OctetMatcher, octet_patterns))

    def __repr__(self) -> str:
        return 'IPFilter({})'.format(self.matchers)

    def is_valid_ip(self, ip: str) -> bool:
        """Return True if *ip* is four decimal parts, each within 0..255."""
        parts = ip.split('.')
        if len(parts) == 4 and all(part.isdecimal() for part in parts):
            # Every part is known to be decimal before any int() conversion.
            return all(0 <= int(part) <= 255 for part in parts)
        return False

    def match(self, ip: str) -> bool:
        """Return True if *ip* is valid and each octet satisfies its matcher."""
        if self.is_valid_ip(ip):
            octets = ip.split('.')
            return all(
                matcher.match(octet)
                for octet, matcher in zip(octets, self.matchers)
            )
        return False

58 if not self.is_valid_ip(ip): 

59 return False 

60 return all(m.match(i) for i, m in zip(ip.split('.'), self.matchers))