­ ­ ­ ­ ­ ­ ­ ­ ­ ­ ­ ­ ­ ­ ­ ­ ­ ­ """WordPress-specific disabled rules data model. This module provides a separate data model for WordPress disabled rules, independent of the existing DisabledRule/DisabledRuleDomain models used by modsec/ossec plugins. Disable Behavior: Global and domain-level disables are independent and can coexist. A rule is considered effectively disabled for a given WordPress domain if EITHER of these conditions is true: - A global disable exists for the rule (applies to all domains) - A domain-specific disable exists for the rule and that domain Enabling a rule at one scope does not affect disables at the other scope. For example, removing a global disable leaves any domain-specific disables intact, and vice versa. """ from collections.abc import Iterator import logging import time from peewee import ( CharField, FloatField, IntegerField, IntegrityError, PrimaryKeyField, fn, ) from defence360agent.model import Model, instance logger = logging.getLogger(__name__) class WPDisabledRule(Model): """Stores disabled WordPress protection rules. Uses a scope-based design: - scope='global', scope_value=NULL: Rule disabled for all domains (root only) - scope='domain', scope_value='example.com': Rule disabled for specific domain """ class Meta: database = instance.db db_table = "wp_disabled_rules" indexes = ((("rule_id", "scope", "scope_value"), True),) id = PrimaryKeyField() # The rule identifier (e.g., "CVE-2025-001") rule_id = CharField(null=False) # The scope type: "global" or "domain" scope = CharField(null=False) # The scope value: NULL for global, domain name for domain scope scope_value = CharField(null=True) # Unix timestamp when the rule was disabled disabled_at = FloatField(null=False) # Origin of the disable action: "wordpress" (from wordpress admin ui) or "agent" (from CLI/RPC) source = CharField(null=False) # UID of the user who disabled the rule (0 for root) created_by_user_id = IntegerField(null=False) # Scope constants SCOPE_GLOBAL = "global" SCOPE_DOMAIN = "domain" # Source constants SOURCE_WORDPRESS = "wordpress" SOURCE_AGENT = "agent" @classmethod def store( cls, rule_id: str, domains: list[str] | None, source: str, user_id: int, timestamp: float | None = None, ) -> int: """ Disable a rule globally or for specific domains. Args: rule_id: The rule identifier (e.g., "CVE-2025-001") domains: List of domains to disable for, or None/empty for global disable source: Origin of the action ("wordpress" or "agent") user_id: UID of the user performing the action (0 for root) timestamp: Unix timestamp for when the rule was disabled. If None, uses current time. Returns: Number of new entries created (0 if all were no-ops). """ if timestamp is None: timestamp = time.time() if domains: return cls._disable_for_domains( rule_id, domains, timestamp, source, user_id ) return cls._disable_globally(rule_id, timestamp, source, user_id) @classmethod def _disable_globally( cls, rule_id: str, timestamp: float, source: str, user_id: int, ) -> int: """Disable a rule globally (independent of domain-specific entries).""" created = cls._create_if_not_exists( rule_id=rule_id, scope=cls.SCOPE_GLOBAL, scope_value=None, disabled_at=timestamp, source=source, user_id=user_id, ) if created: logger.debug( "Disabled rule %s globally (source=%s, user_id=%s)", rule_id, source, user_id, ) return int(created) @classmethod def _disable_for_domains( cls, rule_id: str, domains: list[str], timestamp: float, source: str, user_id: int, ) -> int: """Disable a rule for specific domains (independent of global state).""" count = 0 for domain in domains: created = cls._create_if_not_exists( rule_id=rule_id, scope=cls.SCOPE_DOMAIN, scope_value=domain, disabled_at=timestamp, source=source, user_id=user_id, ) if created: count += 1 logger.debug( "Disabled rule %s for domain %s (source=%s, user_id=%s)", rule_id, domain, source, user_id, ) return count @classmethod def _create_if_not_exists( cls, rule_id: str, scope: str, scope_value: str | None, disabled_at: float, source: str, user_id: int, ) -> bool: """ Create a new disabled rule entry if it doesn't already exist. Returns: True if a new entry was created, False if it already existed (no-op) """ try: cls.insert( rule_id=rule_id, scope=scope, scope_value=scope_value, disabled_at=disabled_at, source=source, created_by_user_id=user_id, ).execute() return True except IntegrityError: # Rule already disabled for this scope - no-op return False @classmethod def remove(cls, rule_id: str, domains: list[str] | None) -> int: """ Re-enable a rule globally or for specific domains. Args: rule_id: The rule identifier domains: List of domains to enable for, or None/empty to enable globally Returns: Number of rows deleted """ if not domains: # Enable globally - remove ONLY the global entry count = ( cls.delete() .where( cls.rule_id == rule_id, cls.scope == cls.SCOPE_GLOBAL, ) .execute() ) if count: logger.debug("Enabled rule %s globally", rule_id) else: # Enable for specific domains count = ( cls.delete() .where( cls.rule_id == rule_id, cls.scope == cls.SCOPE_DOMAIN, cls.scope_value.in_(domains), ) .execute() ) if count: logger.debug( "Enabled rule %s for %d domain(s)", rule_id, count, ) return count @classmethod def is_rule_disabled(cls, rule_id: str, domain: str | None = None) -> bool: """ Check if a rule is disabled globally or for a specific domain. Args: rule_id: The rule identifier domain: The domain to check. If None, only checks global disable. Returns: True if the rule is disabled, False otherwise """ if domain is None: return ( cls.select() .where( cls.rule_id == rule_id, cls.scope == cls.SCOPE_GLOBAL, ) .exists() ) return ( cls.select() .where( cls.rule_id == rule_id, ( (cls.scope == cls.SCOPE_GLOBAL) | ( (cls.scope == cls.SCOPE_DOMAIN) & (cls.scope_value == domain) ) ), ) .exists() ) @classmethod def get_domain_disabled( cls, domain: str, include_global: bool = False ) -> list[str]: """ Get all rule IDs that are disabled for a specific domain. Args: domain: The domain to get disabled rules for include_global: If True, also include globally disabled rules. If False (default), only return domain-specific disables. Returns: List of rule IDs that are disabled for the domain """ if include_global: query = ( cls.select(cls.rule_id) .where( (cls.scope == cls.SCOPE_GLOBAL) | ( (cls.scope == cls.SCOPE_DOMAIN) & (cls.scope_value == domain) ) ) .distinct() ) else: query = cls.select(cls.rule_id).where( cls.scope == cls.SCOPE_DOMAIN, cls.scope_value == domain, ) return [row.rule_id for row in query] @classmethod def get_global_disabled(cls) -> Iterator[str]: """ Get all rule IDs that are disabled globally. Returns: Iterator of globally disabled rule IDs """ query = cls.select(cls.rule_id).where(cls.scope == cls.SCOPE_GLOBAL) return (row.rule_id for row in query) @classmethod def _build_filter_condition( cls, user_domains: list[str] | None, include_global: bool, ): """ Build the WHERE condition for filtering rules. Returns: A Peewee expression for the WHERE clause, or None if no filter needed. """ if user_domains is not None: domain_match = (cls.scope == cls.SCOPE_DOMAIN) & ( cls.scope_value.in_(user_domains) ) if include_global: return (cls.scope == cls.SCOPE_GLOBAL) | domain_match return domain_match if not include_global: return cls.scope == cls.SCOPE_DOMAIN return None @classmethod def fetch( cls, limit: int, offset: int = 0, user_domains: list[str] | None = None, include_global: bool = False, ) -> tuple[int, list[dict]]: """ List disabled rules with aggregation by rule_id. Multiple domain entries for the same rule are aggregated into a single result with a list of domains. Results are ordered by most recently disabled first (using the latest disabled_at timestamp per rule_id). Uses a two-pass approach for efficiency: 1. First pass: Get rule_ids ordered by latest disabled_at with pagination 2. Second pass: Fetch only rows for the paginated rule_ids Args: limit: Maximum number of rules to return offset: Number of rules to skip user_domains: If provided, only return rules for these domains. If None, return all rules (for root users). include_global: Whether to include global rules in the result Returns: Tuple of (total_count, list of rule dicts) Each dict has: {"rule_id": str, "is_global": bool, "domains": list[str]} is_global is True if rule has a global disable, domains lists domain-specific disables """ # Build filter condition condition = cls._build_filter_condition(user_domains, include_global) # First pass: get rule_ids ordered by latest disabled_at (most recent first) rule_ids_query = ( cls.select(cls.rule_id) .group_by(cls.rule_id) .order_by(fn.MAX(cls.disabled_at).desc()) ) if condition is not None: rule_ids_query = rule_ids_query.where(condition) # Get total count of distinct rule_ids total_count = rule_ids_query.count() # Apply pagination at DB level paginated_rule_ids = [ row.rule_id for row in rule_ids_query.offset(offset).limit(limit) ] if not paginated_rule_ids: return total_count, [] # Second pass: fetch rows for the paginated rule_ids rows_query = cls.select().where(cls.rule_id.in_(paginated_rule_ids)) if condition is not None: rows_query = rows_query.where(condition) # Aggregate domains by rule_id rules_by_id: dict[str, dict] = {} for row in rows_query: if row.rule_id not in rules_by_id: rules_by_id[row.rule_id] = { "rule_id": row.rule_id, "is_global": False, "domains": [], } if row.scope == cls.SCOPE_GLOBAL: rules_by_id[row.rule_id]["is_global"] = True elif row.scope == cls.SCOPE_DOMAIN: rules_by_id[row.rule_id]["domains"].append(row.scope_value) # Build result in order from first query (preserves DB ordering) result = [] for rule_id in paginated_rule_ids: rule_data = rules_by_id[rule_id] rule_data["domains"] = sorted(rule_data["domains"]) result.append(rule_data) return total_count, result