feat(http/parser): extend query selector

This commit is contained in:
Tiara Rodney 2025-12-31 15:47:05 +01:00
parent db72017810
commit a4e215c69c
No known key found for this signature in database
GPG key ID: 5CD8EC1D46106723
2 changed files with 151 additions and 51 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from html.parser import HTMLParser from html.parser import HTMLParser
import re
from typing import Dict, Iterable, List, Optional, Generator, Union from typing import Dict, Iterable, List, Optional, Generator, Union
@ -11,6 +12,10 @@ class Node:
:param attrs: Iterable of ``(key, value)`` attribute pairs. :param attrs: Iterable of ``(key, value)`` attribute pairs.
:param parent: Parent :class:`Node` instance. :param parent: Parent :class:`Node` instance.
:param text: Text content for text nodes. :param text: Text content for text nodes.
.. todo::
Mutation APIs (append_child, remove, replace_with)
""" """
def __init__( def __init__(
@ -113,49 +118,108 @@ class Node:
return results[0] if results else None return results[0] if results else None
def query_selector_all(self, selector: str) -> List["Node"]:
    """
    Return all elements matching a CSS-like selector chain.

    Supports:

    - ``tag``, ``.class`` and ``#id``, alone or combined
      (``span#inner.text``)
    - attribute tests: ``[attr=value]``, ``[attr="value"]`` and ``[attr]``
    - descendant chaining: ``div .item span``
    - direct-child chaining: ``.box > span``

    The context node itself is never a candidate for the first compound
    selector. For every later descendant step the nodes matched so far are
    candidates as well as their descendants, so ``.text .highlight`` also
    returns an element that carries both classes.

    :param selector: CSS-like selector string.
    :return: Matching :class:`Node` objects without duplicates, in the
        order they were encountered. An empty selector matches nothing.
    :raises ValueError: If a ``>`` combinator is not followed by a
        selector.
    """
    import re  # local import: the method needs nothing beyond the stdlib

    # A token is either a ``>`` combinator or one compound selector.
    # Bracket groups are consumed whole so that whitespace or ``>`` inside
    # an attribute value does not split the token.
    tokens = re.findall(r"(?:\[[^\]]*\]|[^\s>\[])+|>", selector)
    if not tokens:
        return []

    # ``[name]``, ``[name=value]``, ``[name="value"]`` or ``[name='value']``.
    # The value stops at the closing bracket, so ``[a=b][c=d]`` yields two
    # separate tests instead of one greedy match.
    attr_pattern = re.compile(
        r"\[\s*([a-zA-Z0-9_-]+)\s*"
        r"(?:=\s*(?:\"([^\"]*)\"|'([^']*)'|([^\]]*?))\s*)?\]"
    )

    def parse_compound(token):
        """Split one compound selector into (tag, id, classes, attr tests)."""
        attr_tests = []
        for found in attr_pattern.finditer(token):
            name, *values = found.groups()
            # ``None`` means "attribute must be present", any string
            # (including the empty one) means "attribute must equal it".
            value = next((v for v in values if v is not None), None)
            attr_tests.append((name, value))
        rest = re.sub(r"\[[^\]]*\]", "", token)

        tag = None
        found = re.match(r"[a-zA-Z0-9_-]+", rest)
        if found:
            tag = found.group(0)
            rest = rest[found.end():]

        ident = None
        found = re.search(r"#([a-zA-Z0-9_-]+)", rest)
        if found:
            ident = found.group(1)
            rest = rest[:found.start()] + rest[found.end():]

        classes = [name for name in rest.split(".") if name]
        return tag, ident, classes, attr_tests

    def matches(node, compound):
        """Return True when *node* satisfies every part of *compound*."""
        tag, ident, classes, attr_tests = compound
        if tag and node.tag != tag:
            return False
        if ident and node.attrs.get("id") != ident:
            return False
        if classes:
            # A valueless ``class`` attribute is stored as None by
            # HTMLParser; treat it like an empty class list.
            present = (node.attrs.get("class") or "").split()
            if any(name not in present for name in classes):
                return False
        for name, value in attr_tests:
            if value is None:
                if name not in node.attrs:
                    return False
            elif node.attrs.get(name) != value:
                return False
        return True

    current = [self]
    first_step = True
    child_only = False  # set by a ``>`` that still awaits its selector

    for token in tokens:
        if token == ">":
            if child_only:
                raise ValueError(
                    f"invalid selector {selector!r}: '>' must be followed "
                    "by a selector"
                )
            child_only = True
            continue

        compound = parse_compound(token)  # parsed once, not once per node
        matched = []
        seen = set()
        for node in current:
            if child_only:
                candidates = list(node.children)
            else:
                # NOTE(review): assumes Node.iter() yields descendants
                # only (not the node itself) -- confirm against Node.iter.
                candidates = list(node.iter())
                if not first_step:
                    candidates.insert(0, node)
            for candidate in candidates:
                key = id(candidate)
                if key not in seen and matches(candidate, compound):
                    seen.add(key)
                    matched.append(candidate)

        current = matched
        first_step = False
        child_only = False

    if child_only:
        raise ValueError(
            f"invalid selector {selector!r}: '>' must be followed by a selector"
        )
    return current
def xpath(self, expr: str) -> List["Node"]: def xpath(self, expr: str) -> List["Node"]:
@ -168,6 +232,10 @@ class Node:
:param expr: XPath-like expression. :param expr: XPath-like expression.
:return: List of :class:`Node` objects. :return: List of :class:`Node` objects.
.. todo::
full XPath 1.0 subset
""" """
expr = expr.strip() expr = expr.strip()
parts = expr.split("/") parts = expr.split("/")

View file

@ -7,14 +7,6 @@ from byteb4rb1e.utils.http.parser import Node, TreeBuilder
def sample_dom(): def sample_dom():
""" """
Build a small DOM tree for testing: Build a small DOM tree for testing:
<div id="root" class="container">
<p class="text">Hello</p>
<span class="text highlight">World</span>
<div class="box">
<span id="inner">Inside</span>
</div>
</div>
""" """
html = """ html = """
<div id="root" class="container"> <div id="root" class="container">
@ -22,6 +14,7 @@ def sample_dom():
<span class="text highlight">World</span> <span class="text highlight">World</span>
<div class="box"> <div class="box">
<span id="inner">Inside</span> <span id="inner">Inside</span>
<span id="inner2">Inside Too</span>
</div> </div>
</div> </div>
""" """
@ -33,9 +26,10 @@ def sample_dom():
class TestGetElementsByTagName: class TestGetElementsByTagName:
def test_find_all_spans(self, sample_dom): def test_find_all_spans(self, sample_dom):
spans = sample_dom.get_elements_by_tag_name("span") spans = sample_dom.get_elements_by_tag_name("span")
assert len(spans) == 2 assert len(spans) == 3
assert spans[0].tag == "span" assert spans[0].tag == "span"
assert spans[1].tag == "span" assert spans[1].tag == "span"
assert spans[2].tag == "span"
def test_find_no_matches(self, sample_dom): def test_find_no_matches(self, sample_dom):
assert sample_dom.get_elements_by_tag_name("table") == [] assert sample_dom.get_elements_by_tag_name("table") == []
@ -82,15 +76,53 @@ class TestQuerySelectorAll:
assert items[0].inner_content == "Hello" assert items[0].inner_content == "Hello"
def test_chained_selector(self, sample_dom): def test_chained_selector(self, sample_dom):
items = sample_dom.query_selector_all("div .highlight") items = sample_dom.query_selector_all(".text .highlight")
assert len(items) == 1 assert len(items) == 1
assert items[0].inner_content == "World" assert items[0].inner_content == "World"
def test_direct_child(self, sample_dom):
    """``.box > #inner`` yields exactly the ``#inner`` span."""
    matches = sample_dom.query_selector_all(".box > #inner")
    assert [node.inner_content for node in matches] == ["Inside"]
def test_direct_child_no_match(self, sample_dom):
    """``div > span.highlight`` yields nothing for the sample DOM."""
    matches = sample_dom.query_selector_all("div > span.highlight")
    # The highlight span is NOT a direct child of the inner div.
    assert not matches
def test_attribute_match(self, sample_dom):
    """A quoted ``[id="inner"]`` attribute test selects the ``#inner`` span."""
    matches = sample_dom.query_selector_all('[id="inner"]')
    assert [node.inner_content for node in matches] == ["Inside"]
def test_attribute_no_match(self, sample_dom):
    """An attribute test that no element satisfies returns an empty list."""
    matches = sample_dom.query_selector_all('[data-x="nope"]')
    assert matches == []
def test_tag_class(self, sample_dom):
    """A ``tag.class`` compound selects the single highlighted span."""
    matches = sample_dom.query_selector_all("span.highlight")
    assert [node.inner_content for node in matches] == ["World"]
def test_multiple_classes(self, sample_dom):
    """``.text.highlight`` requires both classes on the same element."""
    matches = sample_dom.query_selector_all(".text.highlight")
    assert [node.inner_content for node in matches] == ["World"]
def test_tag_id_class(self, sample_dom):
    """A ``tag#id`` compound (``span#inner``) selects the ``#inner`` span."""
    matches = sample_dom.query_selector_all("span#inner")
    assert [node.inner_content for node in matches] == ["Inside"]
def test_descendant(self, sample_dom):
    """``div span`` yields two spans for the sample DOM."""
    matches = sample_dom.query_selector_all("div span")
    match_count = len(matches)
    assert match_count == 2
class TestXPath: class TestXPath:
def test_simple_tag(self, sample_dom): def test_simple_tag(self, sample_dom):
spans = sample_dom.xpath("//span") spans = sample_dom.xpath("//span")
assert len(spans) == 2 assert len(spans) == 3
def test_attribute_match(self, sample_dom): def test_attribute_match(self, sample_dom):
nodes = sample_dom.xpath('//span[@id="inner"]') nodes = sample_dom.xpath('//span[@id="inner"]')