Skip to content

Commit ba87a2e

Browse files
authored
vega_templates: Handle content as dict instead of string. (#124)
Prevent unnecessary `dumps`/`loads` calls. Closes #23
1 parent 14f91b2 commit ba87a2e

File tree

4 files changed

+100
-72
lines changed

4 files changed

+100
-72
lines changed

src/dvc_render/vega.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
1-
from copy import deepcopy
1+
import json
22
from pathlib import Path
3-
from typing import List, Optional
3+
from typing import Any, Dict, List, Optional
44
from warnings import warn
55

66
from .base import Renderer
7-
from .exceptions import DvcRenderException
87
from .utils import list_dict_to_dict_list
9-
from .vega_templates import LinearTemplate, get_template
10-
11-
12-
class BadTemplateError(DvcRenderException):
13-
pass
8+
from .vega_templates import BadTemplateError, LinearTemplate, get_template
149

1510

1611
class VegaRenderer(Renderer):
@@ -44,16 +39,15 @@ def __init__(self, datapoints: List, name: str, **properties):
4439

4540
def get_filled_template(
4641
self, skip_anchors: Optional[List[str]] = None, strict: bool = True
47-
) -> str:
42+
) -> Dict[str, Any]:
4843
"""Returns a functional vega specification"""
44+
self.template.reset()
4945
if not self.datapoints:
50-
return ""
46+
return {}
5147

5248
if skip_anchors is None:
5349
skip_anchors = []
5450

55-
content = deepcopy(self.template.content)
56-
5751
if strict:
5852
if self.properties.get("x"):
5953
self.template.check_field_exists(
@@ -76,20 +70,20 @@ def get_filled_template(
7670
if value is None:
7771
continue
7872
if name == "data":
79-
if self.template.anchor_str(name) not in self.template.content:
73+
if not self.template.has_anchor(name):
8074
anchor = self.template.anchor(name)
8175
raise BadTemplateError(
8276
f"Template '{self.template.name}' "
8377
f"is not using '{anchor}' anchor"
8478
)
8579
elif name in {"x", "y"}:
8680
value = self.template.escape_special_characters(value)
87-
content = self.template.fill_anchor(content, name, value)
81+
self.template.fill_anchor(name, value)
8882

89-
return content
83+
return self.template.content
9084

9185
def partial_html(self, **kwargs) -> str:
92-
return self.get_filled_template()
86+
return json.dumps(self.get_filled_template())
9387

9488
def generate_markdown(self, report_path=None) -> str:
9589
if not isinstance(self.template, LinearTemplate):

src/dvc_render/vega_templates.py

+75-39
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pylint: disable=missing-function-docstring
12
import json
23
from pathlib import Path
34
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -27,67 +28,102 @@ def __init__(self, template_name: str, path: str):
2728
)
2829

2930

31+
class BadTemplateError(DvcRenderException):
    """Signal that a template cannot be used as given."""
33+
34+
35+
def dict_replace_value(d: dict, name: str, value: Any) -> dict:
    """Return a rebuilt copy of `d` with values equal to `name` swapped out.

    Nested dicts and lists are walked and rebuilt as new containers, so `d`
    is not modified. Only string *values* that equal `name` exactly are
    replaced by `value`; keys and every other value are carried over as-is.
    """
    replaced = {}
    for key, item in d.items():
        if isinstance(item, dict):
            replaced[key] = dict_replace_value(item, name, value)
        elif isinstance(item, list):
            replaced[key] = list_replace_value(item, name, value)
        elif isinstance(item, str) and item == name:
            replaced[key] = value
        else:
            replaced[key] = item
    return replaced
48+
49+
50+
def list_replace_value(l: list, name: str, value: Any) -> list:  # noqa: E741
    """Return a rebuilt copy of `l` with elements equal to `name` swapped out.

    Nested lists and dicts are walked and rebuilt as new containers, so `l`
    is not modified. Only string elements that equal `name` exactly are
    replaced by `value`; every other element is carried over as-is.

    `value` is typed `Any`, matching `dict_replace_value`, which forwards its
    own `Any` replacement here (it need not be a string).
    """
    replaced = []
    for item in l:
        if isinstance(item, list):
            replaced.append(list_replace_value(item, name, value))
        elif isinstance(item, dict):
            replaced.append(dict_replace_value(item, name, value))
        elif isinstance(item, str) and item == name:
            replaced.append(value)
        else:
            replaced.append(item)
    return replaced
62+
63+
64+
def dict_find_value(d: dict, value: str) -> bool:
    """Tell whether `value` occurs as a string value anywhere inside `d`.

    The search covers nested dicts and lists, the same containers that
    `dict_replace_value` descends into, so a value that can be replaced can
    also be found. Dict keys are never matched, only values.

    Returns True on the first exact match, False when nothing matches
    (including for an empty dict).
    """
    # Iterative walk: a nested container that does not hold `value` must not
    # end the search, the remaining siblings still have to be inspected.
    pending = [d]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)
        elif isinstance(node, str) and node == value:
            return True
    return False
72+
73+
3074
class Template:
31-
INDENT = 4
32-
SEPARATORS = (",", ": ")
3375
EXTENSION = ".json"
3476
ANCHOR = "<DVC_METRIC_{}>"
3577

36-
DEFAULT_CONTENT: Optional[Dict[str, Any]] = None
37-
DEFAULT_NAME: Optional[str] = None
38-
39-
def __init__(self, content=None, name=None):
40-
if content:
41-
self.content = content
42-
else:
43-
self.content = (
44-
json.dumps(
45-
self.DEFAULT_CONTENT,
46-
indent=self.INDENT,
47-
separators=self.SEPARATORS,
48-
)
49-
+ "\n"
50-
)
51-
78+
DEFAULT_CONTENT: Dict[str, Any] = {}
79+
DEFAULT_NAME: str = ""
80+
81+
def __init__(
82+
self, content: Optional[Dict[str, Any]] = None, name: Optional[str] = None
83+
):
84+
if (
85+
content
86+
and not isinstance(content, dict)
87+
or self.DEFAULT_CONTENT
88+
and not isinstance(self.DEFAULT_CONTENT, dict)
89+
):
90+
raise BadTemplateError()
91+
self._original_content = content or self.DEFAULT_CONTENT
92+
self.content: Dict[str, Any] = self._original_content
5293
self.name = name or self.DEFAULT_NAME
53-
assert self.content and self.name
5494
self.filename = Path(self.name).with_suffix(self.EXTENSION)
5595

5696
@classmethod
5797
def anchor(cls, name):
5898
"Get ANCHOR formatted with name."
5999
return cls.ANCHOR.format(name.upper())
60100

61-
def has_anchor(self, name) -> bool:
62-
"Check if ANCHOR formatted with name is in content."
63-
return self.anchor_str(name) in self.content
64-
65-
@classmethod
66-
def fill_anchor(cls, content, name, value) -> str:
67-
"Replace anchor `name` with `value` in content."
68-
value_str = json.dumps(
69-
value, indent=cls.INDENT, separators=cls.SEPARATORS, sort_keys=True
70-
)
71-
return content.replace(cls.anchor_str(name), value_str)
72-
73101
@classmethod
74102
def escape_special_characters(cls, value: str) -> str:
75103
"Escape special characters in `value`"
76104
for character in (".", "[", "]"):
77105
value = value.replace(character, "\\" + character)
78106
return value
79107

80-
@classmethod
81-
def anchor_str(cls, name) -> str:
82-
"Get string wrapping ANCHOR formatted with name."
83-
return f'"{cls.anchor(name)}"'
84-
85108
@staticmethod
86109
def check_field_exists(data, field):
87110
"Raise NoFieldInDataError if `field` not in `data`."
88111
if not any(field in row for row in data):
89112
raise NoFieldInDataError(field)
90113

114+
def reset(self):
    """Point `content` back at the originally supplied template content.

    The attribute is rebound to the same object held in `_original_content`;
    no copy is made here.
    """
    original = self._original_content
    self.content = original
117+
118+
def has_anchor(self, name) -> bool:
    "Tell whether the anchor built from `name` occurs as a value in content."
    return dict_find_value(self.content, self.anchor(name))
122+
123+
def fill_anchor(self, name, value) -> None:
    "Swap each value equal to the anchor for `name` with `value` in content."
    placeholder = self.anchor(name)
    # dict_replace_value returns a rebuilt dict, so content is rebound here
    # rather than edited in place.
    self.content = dict_replace_value(self.content, placeholder, value)
126+
91127

92128
class BarHorizontalSortedTemplate(Template):
93129
DEFAULT_NAME = "bar_horizontal_sorted"
@@ -606,7 +642,7 @@ def get_template(
606642
_open = open if fs is None else fs.open
607643
if template_path:
608644
with _open(template_path, encoding="utf-8") as f:
609-
content = f.read()
645+
content = json.load(f)
610646
return Template(content, name=template)
611647

612648
for template_cls in TEMPLATES:
@@ -635,6 +671,6 @@ def dump_templates(output: "StrPath", targets: Optional[List] = None) -> None:
635671
if path.exists():
636672
content = path.read_text(encoding="utf-8")
637673
if content != template.content:
638-
raise TemplateContentDoesNotMatch(template.DEFAULT_NAME or "", path)
674+
raise TemplateContentDoesNotMatch(template.DEFAULT_NAME, str(path))
639675
else:
640-
path.write_text(template.content, encoding="utf-8")
676+
path.write_text(json.dumps(template.content), encoding="utf-8")

tests/test_templates.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23

34
import pytest
@@ -38,8 +39,9 @@ def test_raise_on_no_template():
3839
],
3940
)
4041
def test_get_template_from_dir(tmp_dir, template_path, target_name):
41-
tmp_dir.gen(template_path, "template_content")
42-
assert get_template(target_name, ".dvc/plots").content == "template_content"
42+
template_content = {"template_content": "foo"}
43+
tmp_dir.gen(template_path, json.dumps(template_content))
44+
assert get_template(target_name, ".dvc/plots").content == template_content
4345

4446

4547
def test_get_template_exact_match(tmp_dir):
@@ -51,13 +53,16 @@ def test_get_template_exact_match(tmp_dir):
5153

5254

5355
def test_get_template_from_file(tmp_dir):
54-
tmp_dir.gen("foo/bar.json", "template_content")
55-
assert get_template("foo/bar.json").content == "template_content"
56+
template_content = {"template_content": "foo"}
57+
tmp_dir.gen("foo/bar.json", json.dumps(template_content))
58+
assert get_template("foo/bar.json").content == template_content
5659

5760

5861
def test_get_template_fs(tmp_dir, mocker):
59-
tmp_dir.gen("foo/bar.json", "template_content")
62+
template_content = {"template_content": "foo"}
63+
tmp_dir.gen("foo/bar.json", json.dumps(template_content))
6064
fs = mocker.MagicMock()
65+
mocker.patch("json.load", return_value={})
6166
get_template("foo/bar.json", fs=fs)
6267
fs.open.assert_called()
6368
fs.exists.assert_called()

tests/test_vega.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import json
2-
31
import pytest
42

53
from dvc_render.vega import BadTemplateError, VegaRenderer
@@ -33,7 +31,6 @@ def test_init_empty():
3331
assert renderer.name == ""
3432
assert renderer.properties == {}
3533

36-
assert renderer.generate_html() == ""
3734
assert renderer.generate_markdown("foo") == ""
3835

3936

@@ -43,7 +40,7 @@ def test_default_template_mark():
4340
{"first_val": 200, "second_val": 300, "val": 3},
4441
]
4542

46-
plot_content = json.loads(VegaRenderer(datapoints, "foo").partial_html())
43+
plot_content = VegaRenderer(datapoints, "foo").get_filled_template()
4744

4845
assert plot_content["layer"][0]["mark"] == "line"
4946

@@ -60,7 +57,7 @@ def test_choose_axes():
6057
{"first_val": 200, "second_val": 300, "val": 3},
6158
]
6259

63-
plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html())
60+
plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template()
6461

6562
assert plot_content["data"]["values"] == [
6663
{
@@ -85,7 +82,7 @@ def test_confusion():
8582
]
8683
props = {"template": "confusion", "x": "predicted", "y": "actual"}
8784

88-
plot_content = json.loads(VegaRenderer(datapoints, "foo", **props).partial_html())
85+
plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template()
8986

9087
assert plot_content["data"]["values"] == [
9188
{"predicted": "B", "actual": "A"},
@@ -100,12 +97,8 @@ def test_confusion():
10097

10198

10299
def test_bad_template():
103-
datapoints = [{"val": 2}, {"val": 3}]
104-
props = {"template": Template("name", "content")}
105-
renderer = VegaRenderer(datapoints, "foo", **props)
106100
with pytest.raises(BadTemplateError):
107-
renderer.get_filled_template()
108-
renderer.get_filled_template(skip_anchors=["data"])
101+
Template("name", "content")
109102

110103

111104
def test_raise_on_wrong_field():
@@ -177,7 +170,7 @@ def test_escape_special_characters():
177170
]
178171
props = {"template": "simple", "x": "foo.bar[0]", "y": "foo.bar[1]"}
179172
renderer = VegaRenderer(datapoints, "foo", **props)
180-
filled = json.loads(renderer.get_filled_template())
173+
filled = renderer.get_filled_template()
181174
# data is not escaped
182175
assert filled["data"]["values"][0] == datapoints[0]
183176
# field and title yes

0 commit comments

Comments
 (0)