Coverage for tdom / processor.py: 97%

283 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-12 16:43 +0000

1import sys 

2import typing as t 

3from collections.abc import Iterable, Sequence 

4from functools import lru_cache 

5from string.templatelib import Interpolation, Template 

6from dataclasses import dataclass 

7 

8from markupsafe import Markup 

9 

10from .callables import get_callable_info 

11from .format import format_interpolation as base_format_interpolation 

12from .format import format_template 

13from .nodes import Comment, DocumentType, Element, Fragment, Node, Text 

14from .parser import ( 

15 HTMLAttribute, 

16 HTMLAttributesDict, 

17 TAttribute, 

18 TComment, 

19 TComponent, 

20 TDocumentType, 

21 TElement, 

22 TemplateParser, 

23 TFragment, 

24 TInterpolatedAttribute, 

25 TLiteralAttribute, 

26 TNode, 

27 TSpreadAttribute, 

28 TTemplatedAttribute, 

29 TText, 

30) 

31from .placeholders import TemplateRef 

32from .template_utils import template_from_parts 

33from .utils import CachableTemplate, LastUpdatedOrderedDict 

34 

35 

@t.runtime_checkable
class HasHTMLDunder(t.Protocol):
    """Structural type for objects that expose an ``__html__()`` method.

    The protocol is runtime-checkable, so it can be used with
    ``isinstance()`` and as a class pattern in ``match`` statements.
    """

    def __html__(self) -> str:  # pragma: no cover
        """Return the object's HTML representation as a string."""
        ...

39 

40 

41# TODO: in Ian's original PR, this caching was tethered to the 

42# TemplateParser. Here, it's tethered to the processor. I suspect we'll 

43# revisit this soon enough. 

44 

45 

@lru_cache(maxsize=512 if "pytest" not in sys.modules else 0)
def _parse_and_cache(cachable: CachableTemplate) -> TNode:
    """Parse the template wrapped by ``cachable`` and memoize the result.

    Results are cached per ``cachable`` argument, up to 512 entries. The
    cache is disabled (``maxsize=0``) when ``pytest`` has already been
    imported at the time this module is loaded.
    """
    parsed = TemplateParser.parse(cachable.template)
    return parsed

49 

50 

51type Attribute = tuple[str, object] 

52type AttributesDict = dict[str, object] 

53 

54 

55# -------------------------------------------------------------------------- 

56# Custom formatting for the processor 

57# -------------------------------------------------------------------------- 

58 

59 

def _format_safe(value: object, format_spec: str) -> str:
    """Wrap ``value`` in ``Markup`` so it is treated as safe HTML.

    Only meant to be called for the ``safe`` format spec; any other spec
    trips the assertion.
    """
    assert format_spec == "safe"
    trusted = Markup(value)
    return trusted

64 

65 

66def _format_unsafe(value: object, format_spec: str) -> str: 

67 """Convert a value to a plain string, forcing it to be treated as unsafe.""" 

68 assert format_spec == "unsafe" 

69 return str(value) 

70 

71 

# Pairs of (format spec, formatter) handed to the base interpolation
# formatter by format_interpolation().
CUSTOM_FORMATTERS = (
    ("safe", _format_safe),
    ("unsafe", _format_unsafe),
)

73 

74 

def format_interpolation(interpolation: Interpolation) -> object:
    """Format ``interpolation`` with the base formatter plus CUSTOM_FORMATTERS."""
    formatted = base_format_interpolation(interpolation, formatters=CUSTOM_FORMATTERS)
    return formatted

80 

81 

82# -------------------------------------------------------------------------- 

83# Placeholder Substitution 

84# -------------------------------------------------------------------------- 

85 

86 

87def _expand_aria_attr(value: object) -> t.Iterable[HTMLAttribute]: 

88 """Produce aria-* attributes based on the interpolated value for "aria".""" 

89 if value is None: 

90 return 

91 elif isinstance(value, dict): 

92 for sub_k, sub_v in value.items(): 

93 if sub_v is True: 

94 yield f"aria-{sub_k}", "true" 

95 elif sub_v is False: 

96 yield f"aria-{sub_k}", "false" 

97 elif sub_v is None: 

98 yield f"aria-{sub_k}", None 

99 else: 

100 yield f"aria-{sub_k}", str(sub_v) 

101 else: 

102 raise TypeError( 

103 f"Cannot use {type(value).__name__} as value for aria attribute" 

104 ) 

105 

106 

107def _expand_data_attr(value: object) -> t.Iterable[Attribute]: 

108 """Produce data-* attributes based on the interpolated value for "data".""" 

109 if value is None: 

110 return 

111 elif isinstance(value, dict): 

112 for sub_k, sub_v in value.items(): 

113 if sub_v is True or sub_v is False or sub_v is None: 

114 yield f"data-{sub_k}", sub_v 

115 else: 

116 yield f"data-{sub_k}", str(sub_v) 

117 else: 

118 raise TypeError( 

119 f"Cannot use {type(value).__name__} as value for data attribute" 

120 ) 

121 

122 

123def _substitute_spread_attrs(value: object) -> t.Iterable[Attribute]: 

124 """ 

125 Substitute a spread attribute based on the interpolated value. 

126 

127 A spread attribute is one where the key is a placeholder, indicating that 

128 the entire attribute set should be replaced by the interpolated value. 

129 The value must be a dict or iterable of key-value pairs. 

130 """ 

131 if value is None: 

132 return 

133 elif isinstance(value, dict): 

134 yield from value.items() 

135 else: 

136 raise TypeError( 

137 f"Cannot use {type(value).__name__} as value for spread attributes" 

138 ) 

139 

140 

# Attribute names whose interpolated value is expanded into several prefixed
# attributes, mapped to the generator that performs the expansion.
ATTR_EXPANDERS = dict(data=_expand_data_attr, aria=_expand_aria_attr)

145 

146 

147def parse_style_attribute_value(style_str: str) -> list[tuple[str, str | None]]: 

148 """ 

149 Parse the style declarations out of a style attribute string. 

150 """ 

151 props = [p.strip() for p in style_str.split(";")] 

152 styles: list[tuple[str, str | None]] = [] 

153 for prop in props: 

154 if prop: 

155 prop_parts = [p.strip() for p in prop.split(":") if p.strip()] 

156 if len(prop_parts) != 2: 

157 raise ValueError( 

158 f"Invalid number of parts for style property {prop} in {style_str}" 

159 ) 

160 styles.append((prop_parts[0], prop_parts[1])) 

161 return styles 

162 

163 

def make_style_accumulator(old_value: object) -> StyleAccumulator:
    """
    Build a StyleAccumulator seeded from an existing ``style`` value.

    A string is parsed into its declarations; ``True`` (a bare attribute)
    starts from an empty set of styles.

    Raises:
        TypeError: if ``old_value`` is neither a string nor ``True``.
    """
    if old_value is True:
        # A bare attribute will just default to {}.
        initial: dict[str, str | None] = {}
    elif isinstance(old_value, str):
        initial = dict(parse_style_attribute_value(old_value))
    else:
        raise TypeError(f"Unexpected value: {old_value}")
    return StyleAccumulator(styles=initial)

178 

179 

180@dataclass 

181class StyleAccumulator: 

182 styles: dict[str, str | None] 

183 

184 def merge_value(self, value: object) -> None: 

185 """ 

186 Merge in an interpolated style value. 

187 """ 

188 match value: 

189 case str(): 

190 self.styles.update( 

191 {name: value for name, value in parse_style_attribute_value(value)} 

192 ) 

193 case dict(): 

194 self.styles.update( 

195 { 

196 str(pn): str(pv) if pv is not None else None 

197 for pn, pv in value.items() 

198 } 

199 ) 

200 case None: 

201 pass 

202 case _: 

203 raise TypeError( 

204 f"Unknown interpolated style value {value}, use '' to omit." 

205 ) 

206 

207 def to_value(self) -> str | None: 

208 """ 

209 Serialize the special style value back into a string. 

210 

211 @NOTE: If the result would be `''` then use `None` to omit the attribute. 

212 """ 

213 style_value = "; ".join( 

214 [f"{pn}: {pv}" for pn, pv in self.styles.items() if pv is not None] 

215 ) 

216 return style_value if style_value else None 

217 

218 

def make_class_accumulator(old_value: object) -> ClassAccumulator:
    """
    Build a ClassAccumulator seeded from an existing ``class`` value.

    A string is split on whitespace and every class name starts toggled on;
    ``True`` (a bare attribute) starts with no classes.

    Raises:
        ValueError: if ``old_value`` is neither a string nor ``True``.
    """
    if old_value is True:
        toggles: dict[str, bool] = {}
    elif isinstance(old_value, str):
        toggles = dict.fromkeys(old_value.split(), True)
    else:
        raise ValueError(f"Unexpected value {old_value}")
    return ClassAccumulator(toggled_classes=toggles)

231 

232 

233@dataclass 

234class ClassAccumulator: 

235 toggled_classes: dict[str, bool] 

236 

237 def merge_value(self, value: object) -> None: 

238 """ 

239 Merge in an interpolated class value. 

240 """ 

241 if isinstance(value, dict): 

242 self.toggled_classes.update( 

243 {str(cn): bool(toggle) for cn, toggle in value.items()} 

244 ) 

245 else: 

246 if not isinstance(value, str) and isinstance(value, Sequence): 

247 items = value[:] 

248 else: 

249 items = (value,) 

250 for item in items: 

251 match item: 

252 case str(): 

253 self.toggled_classes.update({cn: True for cn in item.split()}) 

254 case None: 

255 pass 

256 case _: 

257 if item == value: 

258 raise TypeError( 

259 f"Unknown interpolated class value: {value}" 

260 ) 

261 else: 

262 raise TypeError( 

263 f"Unknown interpolated class item in {value}: {item}" 

264 ) 

265 

266 def to_value(self) -> str | None: 

267 """ 

268 Serialize the special class value back into a string. 

269 

270 @NOTE: If the result would be `''` then use `None` to omit the attribute. 

271 """ 

272 class_value = " ".join( 

273 [cn for cn, toggle in self.toggled_classes.items() if toggle] 

274 ) 

275 return class_value if class_value else None 

276 

277 

278ATTR_ACCUMULATOR_MAKERS = { 

279 "class": make_class_accumulator, 

280 "style": make_style_accumulator, 

281} 

282 

283 

284type AttributeValueAccumulator = StyleAccumulator | ClassAccumulator 

285 

286 

287def _resolve_t_attrs( 

288 attrs: t.Sequence[TAttribute], interpolations: tuple[Interpolation, ...] 

289) -> AttributesDict: 

290 """ 

291 Replace placeholder values in attributes with their interpolated values. 

292 

293 The values returned are not yet processed for HTML output; that is handled 

294 in a later step. 

295 """ 

296 new_attrs: AttributesDict = LastUpdatedOrderedDict() 

297 attr_accs: dict[str, AttributeValueAccumulator] = {} 

298 for attr in attrs: 

299 match attr: 

300 case TLiteralAttribute(name=name, value=value): 

301 attr_value = True if value is None else value 

302 if name in ATTR_ACCUMULATOR_MAKERS and name in new_attrs: 

303 if name not in attr_accs: 

304 attr_accs[name] = ATTR_ACCUMULATOR_MAKERS[name](new_attrs[name]) 

305 new_attrs[name] = attr_accs[name].merge_value(attr_value) 

306 else: 

307 new_attrs[name] = attr_value 

308 case TInterpolatedAttribute(name=name, value_i_index=i_index): 

309 interpolation = interpolations[i_index] 

310 attr_value = format_interpolation(interpolation) 

311 if name in ATTR_ACCUMULATOR_MAKERS: 

312 if name not in attr_accs: 

313 attr_accs[name] = ATTR_ACCUMULATOR_MAKERS[name]( 

314 new_attrs.get(name, True) 

315 ) 

316 new_attrs[name] = attr_accs[name].merge_value(attr_value) 

317 elif expander := ATTR_EXPANDERS.get(name): 

318 for sub_k, sub_v in expander(attr_value): 

319 new_attrs[sub_k] = sub_v 

320 else: 

321 new_attrs[name] = attr_value 

322 case TTemplatedAttribute(name=name, value_ref=ref): 

323 attr_t = _resolve_ref(ref, interpolations) 

324 attr_value = format_template(attr_t) 

325 if name in ATTR_ACCUMULATOR_MAKERS: 

326 if name not in attr_accs: 

327 attr_accs[name] = ATTR_ACCUMULATOR_MAKERS[name]( 

328 new_attrs.get(name, True) 

329 ) 

330 new_attrs[name] = attr_accs[name].merge_value(attr_value) 

331 elif expander := ATTR_EXPANDERS.get(name): 

332 raise TypeError(f"{name} attributes cannot be templated") 

333 else: 

334 new_attrs[name] = attr_value 

335 case TSpreadAttribute(i_index=i_index): 

336 interpolation = interpolations[i_index] 

337 spread_value = format_interpolation(interpolation) 

338 for sub_k, sub_v in _substitute_spread_attrs(spread_value): 

339 if sub_k in ATTR_ACCUMULATOR_MAKERS: 

340 if sub_k not in attr_accs: 

341 attr_accs[sub_k] = ATTR_ACCUMULATOR_MAKERS[sub_k]( 

342 new_attrs.get(sub_k, True) 

343 ) 

344 new_attrs[sub_k] = attr_accs[sub_k].merge_value(sub_v) 

345 elif expander := ATTR_EXPANDERS.get(sub_k): 

346 for exp_k, exp_v in expander(sub_v): 

347 new_attrs[exp_k] = exp_v 

348 else: 

349 new_attrs[sub_k] = sub_v 

350 case _: 

351 raise ValueError(f"Unknown TAttribute type: {type(attr).__name__}") 

352 for acc_name, acc in attr_accs.items(): 

353 new_attrs[acc_name] = acc.to_value() 

354 return new_attrs 

355 

356 

357def _resolve_html_attrs(attrs: AttributesDict) -> HTMLAttributesDict: 

358 """Resolve attribute values for HTML output.""" 

359 html_attrs: HTMLAttributesDict = {} 

360 for key, value in attrs.items(): 

361 match value: 

362 case True: 

363 html_attrs[key] = None 

364 case False | None: 

365 pass 

366 case _: 

367 html_attrs[key] = str(value) 

368 return html_attrs 

369 

370 

def _resolve_attrs(
    attrs: t.Sequence[TAttribute], interpolations: tuple[Interpolation, ...]
) -> HTMLAttributesDict:
    """
    Substitute placeholders in attributes for HTML elements.

    Runs both stages: placeholder substitution via ``_resolve_t_attrs()``,
    then conversion for HTML output via ``_resolve_html_attrs()``.
    """
    return _resolve_html_attrs(_resolve_t_attrs(attrs, interpolations))

381 

382 

def _flatten_nodes(nodes: t.Iterable[Node]) -> list[Node]:
    """Return ``nodes`` as a list, replacing each Fragment by its children.

    Only one level is expanded: a Fragment's children are appended as they
    are, without being inspected further.
    """
    flattened: list[Node] = []
    for candidate in nodes:
        if not isinstance(candidate, Fragment):
            flattened.append(candidate)
            continue
        flattened.extend(candidate.children)
    return flattened

392 

393 

def _substitute_and_flatten_children(
    children: t.Iterable[TNode], interpolations: tuple[Interpolation, ...]
) -> list[Node]:
    """Resolve every child TNode, then flatten any resulting Fragments."""
    return _flatten_nodes(
        [_resolve_t_node(t_child, interpolations) for t_child in children]
    )

401 

402 

def _node_from_value(value: object) -> Node:
    """
    Convert an arbitrary value to a Node.

    This is the primary action performed when replacing interpolations in child
    content positions. The checks below are order-sensitive: the first one
    that applies decides the result.
    """
    if isinstance(value, str):
        return Text(value)
    if isinstance(value, Node):
        return value
    if isinstance(value, Template):
        return html(value)
    # Consider: falsey values, not just False and None?
    if value is False or value is None:
        return Fragment(children=[])
    if isinstance(value, Iterable):
        return Fragment(children=[_node_from_value(member) for member in value])
    if isinstance(value, HasHTMLDunder):
        # CONSIDER: should we do this lazily?
        return Text(Markup(value.__html__()))
    if callable(value):
        # Treat all callable values in child content positions as if
        # they are zero-arg functions that return a value to be rendered.
        return _node_from_value(value())
    # CONSIDER: should we do this lazily?
    return Text(str(value))

433 

434 

435def _kebab_to_snake(name: str) -> str: 

436 """Convert a kebab-case name to snake_case.""" 

437 return name.replace("-", "_").lower() 

438 

439 

def _invoke_component(
    attrs: AttributesDict,
    children: list[Node],  # TODO: why not TNode, though?
    interpolation: Interpolation,
) -> Node:
    """
    Invoke a component callable with the provided attributes and children.

    Components are any callable that meets the required calling signature.
    Typically, that's a function, but it could also be the constructor or
    __call__() method for a class; dataclass constructors match our expected
    invocation style.

    We validate the callable's signature and invoke it with keyword-only
    arguments, then convert the result to a Node.

    Component invocation rules:

    1. All arguments are passed as keywords only. Components cannot require
       positional arguments.

    2. Children are passed as a tuple via a "children" parameter whenever the
       callable accepts "children" OR has **kwargs. If the template has no
       child content, that tuple is empty. Children are assigned after the
       attributes, so they replace an attribute that maps to "children".

    3. All other attributes are converted from kebab-case to snake_case
       and passed as keyword arguments if the callable accepts them (or has
       **kwargs). Attributes that don't match parameters are silently ignored.

    Raises:
        TypeError: if the interpolated value is not callable, if the callable
            requires positional arguments, or if required named parameters
            are not satisfied by the attributes and children.
    """
    value = format_interpolation(interpolation)
    if not callable(value):
        raise TypeError(
            f"Expected a callable for component invocation, got {type(value).__name__}"
        )
    callable_info = get_callable_info(value)

    if callable_info.requires_positional:
        raise TypeError(
            "Component callables cannot have required positional arguments."
        )

    kwargs: AttributesDict = {}

    # Add all supported attributes
    for attr_name, attr_value in attrs.items():
        snake_name = _kebab_to_snake(attr_name)
        if snake_name in callable_info.named_params or callable_info.kwargs:
            kwargs[snake_name] = attr_value

    # Add children if appropriate
    if "children" in callable_info.named_params or callable_info.kwargs:
        kwargs["children"] = tuple(children)

    # Check to make sure we've fully satisfied the callable's requirements
    missing = callable_info.required_named_params - kwargs.keys()
    if missing:
        # Sort so the message is deterministic regardless of set ordering.
        raise TypeError(
            f"Missing required parameters for component: {', '.join(sorted(missing))}"
        )

    result = value(**kwargs)
    return _node_from_value(result)

506 

507 

def _resolve_ref(
    ref: TemplateRef, interpolations: tuple[Interpolation, ...]
) -> Template:
    """Build a Template from ``ref``'s strings and the interpolations it indexes."""
    return template_from_parts(
        ref.strings,
        [interpolations[index] for index in ref.i_indexes],
    )

513 

514 

def _resolve_t_text_ref(
    ref: TemplateRef, interpolations: tuple[Interpolation, ...]
) -> Text | Fragment:
    """Resolve a TText ref into Text or Fragment by processing interpolations.

    A literal ref becomes a single Text. Otherwise each string part becomes
    a Text and each interpolation is formatted and converted to a Node; the
    parts are flattened, and a lone Text result is returned directly instead
    of being wrapped in a Fragment.
    """
    if ref.is_literal:
        return Text(ref.strings[0])

    nodes: list[Node] = []
    for part in _resolve_ref(ref, interpolations):
        if isinstance(part, str):
            nodes.append(Text(part))
        else:
            nodes.append(_node_from_value(format_interpolation(part)))
    flattened = _flatten_nodes(nodes)

    if len(flattened) != 1 or not isinstance(flattened[0], Text):
        return Fragment(children=flattened)
    return flattened[0]

534 

535 

536def _resolve_t_node(t_node: TNode, interpolations: tuple[Interpolation, ...]) -> Node: 

537 """Resolve a TNode tree into a Node tree by processing interpolations.""" 

538 match t_node: 

539 case TText(ref=ref): 

540 return _resolve_t_text_ref(ref, interpolations) 

541 case TComment(ref=ref): 

542 comment_t = _resolve_ref(ref, interpolations) 

543 comment = format_template(comment_t) 

544 return Comment(comment) 

545 case TDocumentType(text=text): 

546 return DocumentType(text) 

547 case TFragment(children=children): 

548 resolved_children = _substitute_and_flatten_children( 

549 children, interpolations 

550 ) 

551 return Fragment(children=resolved_children) 

552 case TElement(tag=tag, attrs=attrs, children=children): 

553 resolved_attrs = _resolve_attrs(attrs, interpolations) 

554 resolved_children = _substitute_and_flatten_children( 

555 children, interpolations 

556 ) 

557 return Element(tag=tag, attrs=resolved_attrs, children=resolved_children) 

558 case TComponent( 

559 start_i_index=start_i_index, 

560 end_i_index=end_i_index, 

561 attrs=t_attrs, 

562 children=children, 

563 ): 

564 start_interpolation = interpolations[start_i_index] 

565 end_interpolation = ( 

566 None if end_i_index is None else interpolations[end_i_index] 

567 ) 

568 resolved_attrs = _resolve_t_attrs(t_attrs, interpolations) 

569 resolved_children = _substitute_and_flatten_children( 

570 children, interpolations 

571 ) 

572 # HERE ALSO BE DRAGONS: validate matching start/end callables, since 

573 # the underlying TemplateParser cannot do that for us. 

574 if ( 

575 end_interpolation is not None 

576 and end_interpolation.value != start_interpolation.value 

577 ): 

578 raise TypeError("Mismatched component start and end callables.") 

579 return _invoke_component( 

580 attrs=resolved_attrs, 

581 children=resolved_children, 

582 interpolation=start_interpolation, 

583 ) 

584 case _: 

585 raise ValueError(f"Unknown TNode type: {type(t_node).__name__}") 

586 

587 

588# -------------------------------------------------------------------------- 

589# Public API 

590# -------------------------------------------------------------------------- 

591 

592 

def html(template: Template) -> Node:
    """Parse an HTML t-string, substitute values, and return a tree of Nodes.

    The parsed structure comes from ``_parse_and_cache()``; the template's
    interpolations are then resolved against it.
    """
    parsed = _parse_and_cache(CachableTemplate(template))
    return _resolve_t_node(parsed, template.interpolations)

597 return _resolve_t_node(t_node, template.interpolations)