diff --git a/src/byteb4rb1e/utils/http/parser.py b/src/byteb4rb1e/utils/http/parser.py index 3f814d6..58f082f 100644 --- a/src/byteb4rb1e/utils/http/parser.py +++ b/src/byteb4rb1e/utils/http/parser.py @@ -1,5 +1,6 @@ from __future__ import annotations from html.parser import HTMLParser +import re from typing import Dict, Iterable, List, Optional, Generator, Union @@ -11,6 +12,10 @@ class Node: :param attrs: Iterable of ``(key, value)`` attribute pairs. :param parent: Parent :class:`Node` instance. :param text: Text content for text nodes. + + .. todo:: + + Mutation APIs (append_child, remove, replace_with) """ def __init__( @@ -113,49 +118,108 @@ class Node: return results[0] if results else None def query_selector_all(self, selector: str) -> List["Node"]: - """ - Return all elements matching a CSS-like selector chain. - - Supports: - - ``tag`` - - ``.class`` - - ``#id`` - - descendant chaining: ``div .item span`` - """ - parts = selector.split() - current: List[Node] = [self] - - for part in parts: - next_nodes: List[Node] = [] - + # Tokenize: split on spaces and > while keeping > + tokens = re.findall(r"[^\s>]+|>", selector) + + # Current working set starts with the context node + current = [self] + + # Helper: match a node against a simple selector + def match_simple(node: "Node", token: str) -> bool: + tag = None + id_ = None + classes = [] + attrs = {} + + # [attr=value] + attr_matches = re.findall( + r"\[([a-zA-Z0-9_-]+)=['\"]?([^'\"]+)['\"]?\]", + token + ) + for k, v in attr_matches: + attrs[k] = v + token = re.sub(r"\[[^\]]+\]", "", token) + + # tag + m = re.match(r"^[a-zA-Z0-9_-]+", token) + if m: + tag = m.group(0) + token = token[len(tag):] + + # #id + m = re.search(r"#([a-zA-Z0-9_-]+)", token) + if m: + id_ = m.group(1) + token = token.replace("#" + id_, "") + + # .classes + classes = [c for c in token.split(".") if c] + + # match + if tag and node.tag != tag: + return False + if id_ and node.attrs.get("id") != id_: + return False + for cls in classes: + if "class" not in node.attrs or cls not in node.attrs["class"].split(): + return False + for k, v in attrs.items(): + if node.attrs.get(k) != v: + return False + + return True + + # ------------------------------------------------------------ + # Main selector evaluation + # ------------------------------------------------------------ + first_token = True + i = 0 + + while i < len(tokens): + token = tokens[i] + + # -------------------------------------------------------- + # Direct child selector + # -------------------------------------------------------- + if token == ">": + i += 1 + next_token = tokens[i] + next_nodes = [] + + for node in current: + for child in node.children: + if match_simple(child, next_token): + next_nodes.append(child) + + current = next_nodes + first_token = False + i += 1 + continue + + # -------------------------------------------------------- + # Descendant selector + # -------------------------------------------------------- + next_nodes = [] + seen = set() + for node in current: - # Tag selector - if not part.startswith(".") and not part.startswith("#"): - if node.tag == part: + # Only include the context node itself if NOT the first token + if not first_token and match_simple(node, token): + if id(node) not in seen: + seen.add(id(node)) next_nodes.append(node) - next_nodes.extend(node.get_elements_by_tag_name(part)) - continue - - # Class selector - if part.startswith("."): - cls = part[1:] - if "class" in node.attrs and cls in node.attrs["class"].split(): - next_nodes.append(node) - next_nodes.extend(node.get_elements_by_class_name(cls)) - continue - - # ID selector - if part.startswith("#"): - ident = part[1:] - if node.attrs.get("id") == ident: - next_nodes.append(node) - found = node.get_element_by_id(ident) - if found: - next_nodes.append(found) - continue - + + # Always include descendants + for desc in node.iter(): + if match_simple(desc, token): + if id(desc) not in seen: + seen.add(id(desc)) + next_nodes.append(desc) + current = next_nodes - + first_token = False + i += 1 + return current def xpath(self, expr: str) -> List["Node"]: @@ -168,6 +232,10 @@ class Node: :param expr: XPath-like expression. :return: List of :class:`Node` objects. + + .. todo:: + + full XPath 1.0 subset """ expr = expr.strip() parts = expr.split("/") diff --git a/tests/unit/byteb4rb1e/utils/http/parser/test_node.py b/tests/unit/byteb4rb1e/utils/http/parser/test_node.py index 721821f..a75de46 100644 --- a/tests/unit/byteb4rb1e/utils/http/parser/test_node.py +++ b/tests/unit/byteb4rb1e/utils/http/parser/test_node.py @@ -7,14 +7,6 @@ from byteb4rb1e.utils.http.parser import Node, TreeBuilder def sample_dom(): """ Build a small DOM tree for testing: - -
Hello
- World -