|
| 1 | +# pylint: disable=missing-function-docstring |
1 | 2 | import json
|
2 | 3 | from pathlib import Path
|
3 | 4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
@@ -27,67 +28,102 @@ def __init__(self, template_name: str, path: str):
|
27 | 28 | )
|
28 | 29 |
|
29 | 30 |
|
| 31 | +class BadTemplateError(DvcRenderException): |
| 32 | + pass |
| 33 | + |
| 34 | + |
| 35 | +def dict_replace_value(d: dict, name: str, value: Any) -> dict: |
| 36 | + x = {} |
| 37 | + for k, v in d.items(): |
| 38 | + if isinstance(v, dict): |
| 39 | + v = dict_replace_value(v, name, value) |
| 40 | + elif isinstance(v, list): |
| 41 | + v = list_replace_value(v, name, value) |
| 42 | + elif isinstance(v, str): |
| 43 | + if v == name: |
| 44 | + x[k] = value |
| 45 | + continue |
| 46 | + x[k] = v |
| 47 | + return x |
| 48 | + |
| 49 | + |
| 50 | +def list_replace_value(l: list, name: str, value: str) -> list: # noqa: E741 |
| 51 | + x = [] |
| 52 | + for e in l: |
| 53 | + if isinstance(e, list): |
| 54 | + e = list_replace_value(e, name, value) |
| 55 | + elif isinstance(e, dict): |
| 56 | + e = dict_replace_value(e, name, value) |
| 57 | + elif isinstance(e, str): |
| 58 | + if e == name: |
| 59 | + e = value |
| 60 | + x.append(e) |
| 61 | + return x |
| 62 | + |
| 63 | + |
| 64 | +def dict_find_value(d: dict, value: str) -> bool: |
| 65 | + for v in d.values(): |
| 66 | + if isinstance(v, dict): |
| 67 | + return dict_find_value(v, value) |
| 68 | + if isinstance(v, str): |
| 69 | + if v == value: |
| 70 | + return True |
| 71 | + return False |
| 72 | + |
| 73 | + |
30 | 74 | class Template:
|
31 |
| - INDENT = 4 |
32 |
| - SEPARATORS = (",", ": ") |
33 | 75 | EXTENSION = ".json"
|
34 | 76 | ANCHOR = "<DVC_METRIC_{}>"
|
35 | 77 |
|
36 |
| - DEFAULT_CONTENT: Optional[Dict[str, Any]] = None |
37 |
| - DEFAULT_NAME: Optional[str] = None |
38 |
| - |
39 |
| - def __init__(self, content=None, name=None): |
40 |
| - if content: |
41 |
| - self.content = content |
42 |
| - else: |
43 |
| - self.content = ( |
44 |
| - json.dumps( |
45 |
| - self.DEFAULT_CONTENT, |
46 |
| - indent=self.INDENT, |
47 |
| - separators=self.SEPARATORS, |
48 |
| - ) |
49 |
| - + "\n" |
50 |
| - ) |
51 |
| - |
| 78 | + DEFAULT_CONTENT: Dict[str, Any] = {} |
| 79 | + DEFAULT_NAME: str = "" |
| 80 | + |
| 81 | + def __init__( |
| 82 | + self, content: Optional[Dict[str, Any]] = None, name: Optional[str] = None |
| 83 | + ): |
| 84 | + if ( |
| 85 | + content |
| 86 | + and not isinstance(content, dict) |
| 87 | + or self.DEFAULT_CONTENT |
| 88 | + and not isinstance(self.DEFAULT_CONTENT, dict) |
| 89 | + ): |
| 90 | + raise BadTemplateError() |
| 91 | + self._original_content = content or self.DEFAULT_CONTENT |
| 92 | + self.content: Dict[str, Any] = self._original_content |
52 | 93 | self.name = name or self.DEFAULT_NAME
|
53 |
| - assert self.content and self.name |
54 | 94 | self.filename = Path(self.name).with_suffix(self.EXTENSION)
|
55 | 95 |
|
56 | 96 | @classmethod
|
57 | 97 | def anchor(cls, name):
|
58 | 98 | "Get ANCHOR formatted with name."
|
59 | 99 | return cls.ANCHOR.format(name.upper())
|
60 | 100 |
|
61 |
| - def has_anchor(self, name) -> bool: |
62 |
| - "Check if ANCHOR formatted with name is in content." |
63 |
| - return self.anchor_str(name) in self.content |
64 |
| - |
65 |
| - @classmethod |
66 |
| - def fill_anchor(cls, content, name, value) -> str: |
67 |
| - "Replace anchor `name` with `value` in content." |
68 |
| - value_str = json.dumps( |
69 |
| - value, indent=cls.INDENT, separators=cls.SEPARATORS, sort_keys=True |
70 |
| - ) |
71 |
| - return content.replace(cls.anchor_str(name), value_str) |
72 |
| - |
73 | 101 | @classmethod
|
74 | 102 | def escape_special_characters(cls, value: str) -> str:
|
75 | 103 | "Escape special characters in `value`"
|
76 | 104 | for character in (".", "[", "]"):
|
77 | 105 | value = value.replace(character, "\\" + character)
|
78 | 106 | return value
|
79 | 107 |
|
80 |
| - @classmethod |
81 |
| - def anchor_str(cls, name) -> str: |
82 |
| - "Get string wrapping ANCHOR formatted with name." |
83 |
| - return f'"{cls.anchor(name)}"' |
84 |
| - |
85 | 108 | @staticmethod
|
86 | 109 | def check_field_exists(data, field):
|
87 | 110 | "Raise NoFieldInDataError if `field` not in `data`."
|
88 | 111 | if not any(field in row for row in data):
|
89 | 112 | raise NoFieldInDataError(field)
|
90 | 113 |
|
| 114 | + def reset(self): |
| 115 | + """Reset self.content to its original state.""" |
| 116 | + self.content = self._original_content |
| 117 | + |
| 118 | + def has_anchor(self, name) -> bool: |
| 119 | + "Check if ANCHOR formatted with name is in content." |
| 120 | + found = dict_find_value(self.content, self.anchor(name)) |
| 121 | + return found |
| 122 | + |
| 123 | + def fill_anchor(self, name, value) -> None: |
| 124 | + "Replace anchor `name` with `value` in content." |
| 125 | + self.content = dict_replace_value(self.content, self.anchor(name), value) |
| 126 | + |
91 | 127 |
|
92 | 128 | class BarHorizontalSortedTemplate(Template):
|
93 | 129 | DEFAULT_NAME = "bar_horizontal_sorted"
|
@@ -606,7 +642,7 @@ def get_template(
|
606 | 642 | _open = open if fs is None else fs.open
|
607 | 643 | if template_path:
|
608 | 644 | with _open(template_path, encoding="utf-8") as f:
|
609 |
| - content = f.read() |
| 645 | + content = json.load(f) |
610 | 646 | return Template(content, name=template)
|
611 | 647 |
|
612 | 648 | for template_cls in TEMPLATES:
|
@@ -635,6 +671,6 @@ def dump_templates(output: "StrPath", targets: Optional[List] = None) -> None:
|
635 | 671 | if path.exists():
|
636 | 672 | content = path.read_text(encoding="utf-8")
|
637 | 673 | if content != template.content:
|
638 |
| - raise TemplateContentDoesNotMatch(template.DEFAULT_NAME or "", path) |
| 674 | + raise TemplateContentDoesNotMatch(template.DEFAULT_NAME, str(path)) |
639 | 675 | else:
|
640 |
| - path.write_text(template.content, encoding="utf-8") |
| 676 | + path.write_text(json.dumps(template.content), encoding="utf-8") |
0 commit comments