feat(http/parser): extend query selector

This commit is contained in:
Tiara Rodney 2025-12-31 15:47:05 +01:00
parent db72017810
commit a4e215c69c
No known key found for this signature in database
GPG key ID: 5CD8EC1D46106723
2 changed files with 151 additions and 51 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations
from html.parser import HTMLParser
import re
from typing import Dict, Iterable, List, Optional, Generator, Union
@ -11,6 +12,10 @@ class Node:
:param attrs: Iterable of ``(key, value)`` attribute pairs.
:param parent: Parent :class:`Node` instance.
:param text: Text content for text nodes.
.. todo::
Mutation APIs (append_child, remove, replace_with)
"""
def __init__(
@ -113,49 +118,108 @@ class Node:
return results[0] if results else None
def query_selector_all(self, selector: str) -> List["Node"]:
"""
Return all elements matching a CSS-like selector chain.
Supports:
- ``tag``
- ``.class``
- ``#id``
- descendant chaining: ``div .item span``
"""
parts = selector.split()
current: List[Node] = [self]
for part in parts:
next_nodes: List[Node] = []
# Tokenize: split on spaces and > while keeping >
tokens = re.findall(r"[^\s>]+|>", selector)
# Current working set starts with the context node
current = [self]
# Helper: match a node against a simple selector
def match_simple(node: "Node", token: str) -> bool:
tag = None
id_ = None
classes = []
attrs = {}
# [attr=value]
attr_matches = re.findall(
r"\[([a-zA-Z0-9_-]+)=['\"]?([^'\"]+)['\"]?\]",
token
)
for k, v in attr_matches:
attrs[k] = v
token = re.sub(r"\[[^\]]+\]", "", token)
# tag
m = re.match(r"^[a-zA-Z0-9_-]+", token)
if m:
tag = m.group(0)
token = token[len(tag):]
# #id
m = re.search(r"#([a-zA-Z0-9_-]+)", token)
if m:
id_ = m.group(1)
token = token.replace("#" + id_, "")
# .classes
classes = [c for c in token.split(".") if c]
# match
if tag and node.tag != tag:
return False
if id_ and node.attrs.get("id") != id_:
return False
for cls in classes:
if "class" not in node.attrs or cls not in node.attrs["class"].split():
return False
for k, v in attrs.items():
if node.attrs.get(k) != v:
return False
return True
# ------------------------------------------------------------
# Main selector evaluation
# ------------------------------------------------------------
first_token = True
i = 0
while i < len(tokens):
token = tokens[i]
# --------------------------------------------------------
# Direct child selector
# --------------------------------------------------------
if token == ">":
i += 1
next_token = tokens[i]
next_nodes = []
for node in current:
for child in node.children:
if match_simple(child, next_token):
next_nodes.append(child)
current = next_nodes
first_token = False
i += 1
continue
# --------------------------------------------------------
# Descendant selector
# --------------------------------------------------------
next_nodes = []
seen = set()
for node in current:
# Tag selector
if not part.startswith(".") and not part.startswith("#"):
if node.tag == part:
# Only include the context node itself if NOT the first token
if not first_token and match_simple(node, token):
if id(node) not in seen:
seen.add(id(node))
next_nodes.append(node)
next_nodes.extend(node.get_elements_by_tag_name(part))
continue
# Class selector
if part.startswith("."):
cls = part[1:]
if "class" in node.attrs and cls in node.attrs["class"].split():
next_nodes.append(node)
next_nodes.extend(node.get_elements_by_class_name(cls))
continue
# ID selector
if part.startswith("#"):
ident = part[1:]
if node.attrs.get("id") == ident:
next_nodes.append(node)
found = node.get_element_by_id(ident)
if found:
next_nodes.append(found)
continue
# Always include descendants
for desc in node.iter():
if match_simple(desc, token):
if id(desc) not in seen:
seen.add(id(desc))
next_nodes.append(desc)
current = next_nodes
first_token = False
i += 1
return current
def xpath(self, expr: str) -> List["Node"]:
@ -168,6 +232,10 @@ class Node:
:param expr: XPath-like expression.
:return: List of :class:`Node` objects.
.. todo::
full XPath 1.0 subset
"""
expr = expr.strip()
parts = expr.split("/")

View file

@ -7,14 +7,6 @@ from byteb4rb1e.utils.http.parser import Node, TreeBuilder
def sample_dom():
"""
Build a small DOM tree for testing:
<div id="root" class="container">
<p class="text">Hello</p>
<span class="text highlight">World</span>
<div class="box">
<span id="inner">Inside</span>
</div>
</div>
"""
html = """
<div id="root" class="container">
@ -22,6 +14,7 @@ def sample_dom():
<span class="text highlight">World</span>
<div class="box">
<span id="inner">Inside</span>
<span id="inner2">Inside Too</span>
</div>
</div>
"""
@ -33,9 +26,10 @@ def sample_dom():
class TestGetElementsByTagName:
def test_find_all_spans(self, sample_dom):
spans = sample_dom.get_elements_by_tag_name("span")
assert len(spans) == 2
assert len(spans) == 3
assert spans[0].tag == "span"
assert spans[1].tag == "span"
assert spans[2].tag == "span"
def test_find_no_matches(self, sample_dom):
assert sample_dom.get_elements_by_tag_name("table") == []
@ -82,15 +76,53 @@ class TestQuerySelectorAll:
assert items[0].inner_content == "Hello"
def test_chained_selector(self, sample_dom):
items = sample_dom.query_selector_all("div .highlight")
items = sample_dom.query_selector_all(".text .highlight")
assert len(items) == 1
assert items[0].inner_content == "World"
def test_direct_child(self, sample_dom):
items = sample_dom.query_selector_all(".box > #inner")
assert len(items) == 1
assert items[0].inner_content == "Inside"
def test_direct_child_no_match(self, sample_dom):
items = sample_dom.query_selector_all("div > span.highlight")
# highlight span is NOT a direct child of inner div
assert len(items) == 0
def test_attribute_match(self, sample_dom):
items = sample_dom.query_selector_all('[id="inner"]')
assert len(items) == 1
assert items[0].inner_content == "Inside"
def test_attribute_no_match(self, sample_dom):
items = sample_dom.query_selector_all('[data-x="nope"]')
assert items == []
def test_tag_class(self, sample_dom):
items = sample_dom.query_selector_all("span.highlight")
assert len(items) == 1
assert items[0].inner_content == "World"
def test_multiple_classes(self, sample_dom):
items = sample_dom.query_selector_all(".text.highlight")
assert len(items) == 1
assert items[0].inner_content == "World"
def test_tag_id_class(self, sample_dom):
items = sample_dom.query_selector_all("span#inner")
assert len(items) == 1
assert items[0].inner_content == "Inside"
def test_descendant(self, sample_dom):
items = sample_dom.query_selector_all("div span")
assert len(items) == 2
class TestXPath:
def test_simple_tag(self, sample_dom):
spans = sample_dom.xpath("//span")
assert len(spans) == 2
assert len(spans) == 3
def test_attribute_match(self, sample_dom):
nodes = sample_dom.xpath('//span[@id="inner"]')