Skip to content

Commit f49c2f4

Browse files
committed
Enable querying related flexible attributes
1 parent 0712a2a commit f49c2f4

File tree

6 files changed

+107
-54
lines changed

6 files changed

+107
-54
lines changed

beets/dbcore/db.py

+1-28
Original file line numberDiff line numberDiff line change
@@ -743,33 +743,6 @@ def set_parse(self, key, string: str):
743743
"""Set the object's key to a value represented by a string."""
744744
self[key] = self._parse(key, string)
745745

746-
# Convenient queries.
747-
748-
@classmethod
749-
def field_query(
750-
cls,
751-
field,
752-
pattern,
753-
query_cls: Type[FieldQuery] = MatchQuery,
754-
) -> FieldQuery:
755-
"""Get a `FieldQuery` for this model."""
756-
return query_cls(field, pattern, field in cls._fields)
757-
758-
@classmethod
759-
def all_fields_query(
760-
cls: Type["Model"],
761-
pats: Mapping,
762-
query_cls: Type[FieldQuery] = MatchQuery,
763-
):
764-
"""Get a query that matches many fields with different patterns.
765-
766-
`pats` should be a mapping from field names to patterns. The
767-
resulting query is a conjunction ("and") of per-field queries
768-
for all of these field/pattern pairs.
769-
"""
770-
subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
771-
return AndQuery(subqueries)
772-
773746

774747
# Database controller and supporting interfaces.
775748

@@ -1246,7 +1219,7 @@ def _fetch(
12461219

12471220
relation_fields = query.model_fields - set(model_cls._fields)
12481221
_from = model_cls.table_with_flex_attrs
1249-
if relation_fields:
1222+
if relation_fields or query.flex_fields:
12501223
_from += f" {model_cls.relation_join}"
12511224

12521225
table = model_cls._table_name

beets/dbcore/query.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ def model_fields(self) -> Set[str]:
9292
"""Return query fields that are (any) model attributes."""
9393
return {f for f, fast in self.fields_info if fast}
9494

95+
@property
96+
def flex_fields(self) -> Set[str]:
97+
"""Return query fields that are (any) model attributes."""
98+
return {f for f, fast in self.fields_info if not fast}
99+
95100
def clause(self) -> Tuple[Optional[str], Sequence[Any]]:
96101
"""Generate an SQLite expression implementing the query.
97102
@@ -149,7 +154,10 @@ def fields_info(self) -> Set[Tuple[str, bool]]:
149154
@property
150155
def col_name(self) -> str:
151156
if not self.fast:
152-
return f'json_extract("flex_attrs [json_str]", "$.{self.field}")'
157+
column = '"flex_attrs [json_str]"'
158+
if self.table:
159+
column = f"{self.table}.{column}"
160+
return f'json_extract({column}, "$.{self.field}")'
153161

154162
return f"{self.table}.{self.field}" if self.table else self.field
155163

beets/dbcore/queryparse.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,23 @@
1616

1717
import itertools
1818
import re
19-
from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type
19+
from typing import (
20+
TYPE_CHECKING,
21+
Collection,
22+
Dict,
23+
List,
24+
Optional,
25+
Sequence,
26+
Tuple,
27+
Type,
28+
)
2029

21-
from .. import library
2230
from . import Model, query
2331
from .query import Sort
2432

33+
if TYPE_CHECKING:
34+
from ..library import LibModel
35+
2536
PARSE_QUERY_PART_REGEX = re.compile(
2637
# Non-capturing optional segment for the keyword.
2738
r"(-|\^)?" # Negation prefixes.
@@ -105,7 +116,7 @@ def parse_query_part(
105116

106117

107118
def construct_query_part(
108-
model_cls: Type[Model],
119+
model_cls: Type["LibModel"],
109120
prefixes: Dict,
110121
query_part: str,
111122
) -> query.Query:
@@ -153,17 +164,7 @@ def construct_query_part(
153164
# Field queries get constructed according to the name of the field
154165
# they are querying.
155166
else:
156-
key = key.lower()
157-
album_fields = library.Album._fields.keys()
158-
item_fields = library.Item._fields.keys()
159-
fast = key in album_fields | item_fields
160-
161-
if key in album_fields & item_fields:
162-
# This field exists in both tables, so SQLite will encounter
163-
# an OperationalError. Using an explicit table name resolves this.
164-
key = f"{model_cls._table_name}.{key}"
165-
166-
out_query = query_class(key, pattern, fast)
167+
out_query = model_cls.field_query(key.lower(), pattern, query_class)
167168

168169
# Apply negation.
169170
if negate:

beets/library.py

+55-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import sys
2323
import time
2424
import unicodedata
25+
from typing import Mapping, Set, Type
2526

2627
from mediafile import MediaFile, UnreadableFileError
2728

@@ -387,6 +388,14 @@ class LibModel(dbcore.Model):
387388
# Config key that specifies how an instance should be formatted.
388389
_format_config_key = None
389390

391+
@cached_classproperty
392+
def all_model_db_fields(cls) -> Set[str]:
393+
return Album._fields.keys() | Item._fields.keys()
394+
395+
@cached_classproperty
396+
def shared_model_db_fields(cls) -> Set[str]:
397+
return Album._fields.keys() & Item._fields.keys()
398+
390399
def _template_funcs(self):
391400
funcs = DefaultTemplateFunctions(self, self._db).functions()
392401
funcs.update(plugins.template_funcs())
@@ -416,6 +425,46 @@ def __str__(self):
416425
def __bytes__(self):
417426
return self.__str__().encode("utf-8")
418427

428+
# Convenient queries.
429+
430+
@classmethod
431+
def field_query(
432+
cls, field: str, pattern: str, query_cls: Type[dbcore.FieldQuery]
433+
) -> dbcore.Query:
434+
"""Get a `FieldQuery` for this model."""
435+
if field not in cls.all_model_db_fields:
436+
# this is a flexible attribute. Since we do not know which entity
437+
# it belongs, we check them both
438+
tables = [Item._table_name, Album._table_name]
439+
return dbcore.OrQuery(
440+
[query_cls(f"{t}.{field}", pattern, False) for t in tables]
441+
)
442+
443+
if field in cls.shared_model_db_fields:
444+
# This field exists in both tables, so SQLite will encounter
445+
# an OperationalError if we try to use it in a query.
446+
# Using an explicit table name resolves this.
447+
field = f"{cls._table_name}.{field}"
448+
449+
return query_cls(field, pattern, True)
450+
451+
@classmethod
452+
def all_fields_query(
453+
cls, pattern_by_field: Mapping[str, str]
454+
) -> dbcore.AndQuery:
455+
"""Get a query that matches many fields with different patterns.
456+
457+
`pattern_by_field` should be a mapping from field names to patterns.
458+
The resulting query is a conjunction ("and") of per-field queries
459+
for all of these field/pattern pairs.
460+
"""
461+
return dbcore.AndQuery(
462+
[
463+
cls.field_query(f, p, dbcore.MatchQuery)
464+
for f, p in pattern_by_field.items()
465+
]
466+
)
467+
419468

420469
class FormattedItemMapping(dbcore.db.FormattedMapping):
421470
"""Add lookup for album-level fields.
@@ -648,8 +697,7 @@ def relation_join(cls) -> str:
648697
We need to use a LEFT JOIN here, otherwise items that are not part of
649698
an album (e.g. singletons) would be left out.
650699
"""
651-
other_table = Album._table_name
652-
return f"LEFT JOIN {other_table} ON {cls._table_name}.album_id = {other_table}.id"
700+
return f"LEFT JOIN {Album.table_with_flex_attrs} ON {cls._table_name}.album_id = {Album._table_name}.id"
653701

654702
@property
655703
def _cached_album(self):
@@ -1258,8 +1306,7 @@ def relation_join(cls) -> str:
12581306
Here we can use INNER JOIN (which is more performant than LEFT JOIN),
12591307
since we only want to see albums that
12601308
"""
1261-
other_table = Item._table_name
1262-
return f"INNER JOIN {other_table} ON {cls._table_name}.id = {other_table}.album_id"
1309+
return f"INNER JOIN {Item.table_with_flex_attrs} ON {cls._table_name}.id = {Item._table_name}.album_id"
12631310

12641311
@classmethod
12651312
def _getters(cls):
@@ -1949,9 +1996,10 @@ def _tmpl_unique(
19491996
subqueries.extend(initial_subqueries)
19501997
for key in keys:
19511998
value = db_item.get(key, "")
1952-
# Use slow queries for flexible attributes.
1953-
fast = key in item_keys
1954-
subqueries.append(dbcore.MatchQuery(key, value, fast))
1999+
subqueries.append(
2000+
db_item.field_query(key, value, dbcore.MatchQuery)
2001+
)
2002+
19552003
query = dbcore.AndQuery(subqueries)
19562004
ambigous_items = (
19572005
self.lib.items(query)

test/test_dbcore.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tempfile import mkstemp
2222

2323
from beets import dbcore
24+
from beets.library import LibModel
2425
from beets.test import _common
2526

2627
# Fixture: concrete database and model classes. For migration tests, we
@@ -42,7 +43,7 @@ def match(self):
4243
return True
4344

4445

45-
class ModelFixture1(dbcore.Model):
46+
class ModelFixture1(LibModel):
4647
_table_name = "test"
4748
_flex_table = "testflex"
4849
_fields = {
@@ -590,15 +591,17 @@ def test_two_parts(self):
590591
self.assertIsInstance(q, dbcore.query.AndQuery)
591592
self.assertEqual(len(q.subqueries), 2)
592593
self.assertIsInstance(q.subqueries[0], dbcore.query.AnyFieldQuery)
593-
self.assertIsInstance(q.subqueries[1], dbcore.query.SubstringQuery)
594+
self.assertIsInstance(q.subqueries[1], dbcore.query.OrQuery)
595+
self.assertIsInstance(q.subqueries[1][0], dbcore.query.SubstringQuery)
594596

595597
def test_parse_fixed_type_query(self):
596598
q = self.qfs(["field_one:2..3"])
597-
self.assertIsInstance(q.subqueries[0], dbcore.query.NumericQuery)
599+
self.assertIsInstance(q.subqueries[0], dbcore.query.OrQuery)
600+
self.assertIsInstance(q.subqueries[0][0], dbcore.query.NumericQuery)
598601

599602
def test_parse_flex_type_query(self):
600603
q = self.qfs(["some_float_field:2..3"])
601-
self.assertIsInstance(q.subqueries[0], dbcore.query.NumericQuery)
604+
self.assertIsInstance(q.subqueries[0], dbcore.query.OrQuery)
602605

603606
def test_empty_query_part(self):
604607
q = self.qfs([""])

test/test_query.py

+20
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,26 @@ def test_get_albums_filter_by_album_flex(self):
11521152
results = self.lib.albums(q)
11531153
self.assert_albums_matched(results, ["Album1"])
11541154

1155+
def test_get_albums_filter_by_track_flex(self):
1156+
q = "item_flex1:Album1"
1157+
results = self.lib.albums(q)
1158+
self.assert_albums_matched(results, ["Album1"])
1159+
1160+
def test_get_items_filter_by_album_flex(self):
1161+
q = "album_flex:Album1"
1162+
results = self.lib.items(q)
1163+
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])
1164+
1165+
def test_filter_by_flex(self):
1166+
q = "item_flex1:'Item1 Flex1'"
1167+
results = self.lib.items(q)
1168+
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])
1169+
1170+
def test_filter_by_many_flex(self):
1171+
q = "item_flex1:'Item1 Flex1' item_flex2:Album1"
1172+
results = self.lib.items(q)
1173+
self.assert_items_matched(results, ["Album1 Item1"])
1174+
11551175

11561176
def suite():
11571177
return unittest.TestLoader().loadTestsFromName(__name__)

0 commit comments

Comments
 (0)