Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add force index for joins #722

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions pypika/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,21 +994,24 @@ def orderby(self, *fields: Any, **kwargs: Any) -> "QueryBuilder":

@builder
def join(
self, item: Union[Table, "QueryBuilder", AliasedQuery, Selectable], how: JoinType = JoinType.inner
self,
item: Union[Table, "QueryBuilder", AliasedQuery, Selectable],
how: JoinType = JoinType.inner,
force_index: Optional[str] = None,
) -> "Joiner":
if isinstance(item, Table):
return Joiner(self, item, how, type_label="table")
return Joiner(self, item, how, type_label="table", force_index=force_index)

elif isinstance(item, QueryBuilder):
if item.alias is None:
self._tag_subquery(item)
return Joiner(self, item, how, type_label="subquery")
return Joiner(self, item, how, type_label="subquery", force_index=force_index)

elif isinstance(item, AliasedQuery):
return Joiner(self, item, how, type_label="table")
return Joiner(self, item, how, type_label="table", force_index=force_index)

elif isinstance(item, Selectable):
return Joiner(self, item, how, type_label="subquery")
return Joiner(self, item, how, type_label="subquery", force_index=force_index)

raise ValueError("Cannot join on type '%s'" % type(item))

Expand Down Expand Up @@ -1539,12 +1542,17 @@ def _set_sql(self, **kwargs: Any) -> str:

class Joiner:
def __init__(
self, query: QueryBuilder, item: Union[Table, "QueryBuilder", AliasedQuery], how: JoinType, type_label: str
self,
query: QueryBuilder,
item: Union[Table, "QueryBuilder", AliasedQuery],
how: JoinType, type_label: str,
force_index: Optional[str] = None,
) -> None:
self.query = query
self.item = item
self.how = how
self.type_label = type_label
self.force_index = force_index

def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> QueryBuilder:
if criterion is None:
Expand All @@ -1553,7 +1561,7 @@ def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> Q
"{type} JOIN but was not supplied.".format(type=self.type_label)
)

self.query.do_join(JoinOn(self.item, self.how, criterion, collate))
self.query.do_join(JoinOn(self.item, self.how, criterion, collate, self.force_index))
return self.query

def on_field(self, *fields: Any) -> QueryBuilder:
Expand All @@ -1567,7 +1575,7 @@ def on_field(self, *fields: Any) -> QueryBuilder:
consituent = Field(field, table=self.query._from[0]) == Field(field, table=self.item)
criterion = consituent if criterion is None else criterion & consituent

self.query.do_join(JoinOn(self.item, self.how, criterion))
self.query.do_join(JoinOn(self.item, self.how, criterion, self.force_index))
return self.query

def using(self, *fields: Any) -> QueryBuilder:
Expand All @@ -1579,21 +1587,25 @@ def using(self, *fields: Any) -> QueryBuilder:

def cross(self) -> QueryBuilder:
"""Return cross join"""
self.query.do_join(Join(self.item, JoinType.cross))
self.query.do_join(Join(self.item, JoinType.cross, self.force_index))

return self.query


class Join:
def __init__(self, item: Term, how: JoinType) -> None:
def __init__(self, item: Term, how: JoinType, force_index: Optional[str] = None) -> None:
self.item = item
self.how = how
self.force_index = force_index

def get_sql(self, **kwargs: Any) -> str:
sql = "JOIN {table}".format(
table=self.item.get_sql(subquery=True, with_alias=True, **kwargs),
)

if self.force_index:
sql = sql + " FORCE INDEX ({index})".format(index=self.force_index)

if self.how.value:
return "{type} {join}".format(join=sql, type=self.how.value)
return sql
Expand All @@ -1618,8 +1630,15 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl


class JoinOn(Join):
def __init__(self, item: Term, how: JoinType, criteria: QueryBuilder, collate: Optional[str] = None) -> None:
super().__init__(item, how)
def __init__(
self,
item: Term,
how: JoinType,
criteria: QueryBuilder,
collate: Optional[str] = None,
force_index: Optional[str] = None,
) -> None:
super().__init__(item, how, force_index)
self.criterion = criteria
self.collate = collate

Expand Down Expand Up @@ -1661,8 +1680,8 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl


class JoinUsing(Join):
def __init__(self, item: Term, how: JoinType, fields: Sequence[Field]) -> None:
super().__init__(item, how)
def __init__(self, item: Term, how: JoinType, fields: Sequence[Field], force_index: Optional[str] = None) -> None:
super().__init__(item, how, force_index)
self.fields = fields

def get_sql(self, **kwargs: Any) -> str:
Expand Down
14 changes: 14 additions & 0 deletions pypika/tests/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,20 @@ def test_temporal_join(self):
str(query),
)

def test_join_with_force_index(self):
table_a, table_b = Tables("a", "b")

q1 = (
Query.from_(table_a)
.select(table_b.ouch)
.join(table_b, force_index='PRIMARY')
.on(table_a.foo == table_b.boo)
)

self.assertEqual(
'SELECT "b"."ouch" FROM "a" JOIN "b" FORCE INDEX (PRIMARY) ON "a"."foo"="b"."boo"',
str(q1),
)

class JoinBehaviorTests(unittest.TestCase):
table_abc, table_efg, table_hij, table_klm = Tables("abc", "efg", "hij", "klm")
Expand Down