Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

crosscluster: add new insert/update/delete replication statements #143157

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pkg/crosscluster/logical/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ go_library(
"offline_initial_scan_processor.go",
"purgatory.go",
"range_stats.go",
"replication_statements.go",
"udf_row_processor.go",
],
importpath = "github.com/cockroachdb/cockroach/pkg/crosscluster/logical",
Expand Down Expand Up @@ -76,6 +77,7 @@ go_library(
"//pkg/sql/sem/catid",
"//pkg/sql/sem/eval",
"//pkg/sql/sem/tree",
"//pkg/sql/sem/tree/treecmp",
"//pkg/sql/sessiondata",
"//pkg/sql/sessiondatapb",
"//pkg/sql/stats",
Expand Down Expand Up @@ -117,9 +119,12 @@ go_test(
"main_test.go",
"purgatory_test.go",
"range_stats_test.go",
"replication_statements_test.go",
"udf_row_processor_test.go",
],
data = ["//c-deps:libgeos"],
data = glob(["testdata/**"]) + [
"//c-deps:libgeos",
],
embed = [":logical"],
exec_properties = {"test.Pool": "large"},
deps = [
Expand Down Expand Up @@ -163,6 +168,7 @@ go_test(
"//pkg/sql/sqltestutils",
"//pkg/sql/stats",
"//pkg/testutils",
"//pkg/testutils/datapathutils",
"//pkg/testutils/jobutils",
"//pkg/testutils/serverutils",
"//pkg/testutils/skip",
Expand All @@ -182,6 +188,7 @@ go_test(
"//pkg/util/timeutil",
"//pkg/util/uuid",
"@com_github_cockroachdb_cockroach_go_v2//crdb",
"@com_github_cockroachdb_datadriven//:datadriven",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_redact//:redact",
"@com_github_lib_pq//:pq",
Expand Down
180 changes: 180 additions & 0 deletions pkg/crosscluster/logical/replication_statements.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Copyright 2025 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package logical

import (
"fmt"

"github.com/cockroachdb/cockroach/pkg/sql/catalog"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree/treecmp"
)

// getPhysicalColumns returns the columns of table that the replication
// statements read and write. It walks table.AllColumns() in order and drops
// every column that is computed, virtual, or a system column; the relative
// order of the remaining columns is preserved.
func getPhysicalColumns(table catalog.TableDescriptor) []catalog.Column {
	all := table.AllColumns()
	physical := make([]catalog.Column, 0, len(all))
	for _, c := range all {
		// Skip anything that is derived or synthesized rather than stored
		// directly from user-supplied values.
		if c.IsComputed() || c.IsVirtual() || c.IsSystemColumn() {
			continue
		}
		physical = append(physical, c)
	}
	return physical
}

// newInsertStatement builds an INSERT statement that writes a single row into
// table.
//
// The statement names every column returned by getPhysicalColumns and supplies
// one placeholder per column in a single VALUES row. Placeholders are numbered
// from 1 and follow the order of getPhysicalColumns, so column i (0-based) is
// bound to parameter i+1. The target table is referenced by descriptor ID and
// aliased as "replication_target".
func newInsertStatement(table catalog.TableDescriptor) (tree.Statement, error) {
	columns := getPhysicalColumns(table)

	names := make(tree.NameList, 0, len(columns))
	values := make(tree.Exprs, 0, len(columns))
	for i, col := range columns {
		// Placeholders are 1-indexed.
		placeholder, err := tree.NewPlaceholder(fmt.Sprintf("%d", i+1))
		if err != nil {
			return nil, err
		}
		names = append(names, tree.Name(col.GetName()))
		values = append(values, placeholder)
	}

	return &tree.Insert{
		Table: &tree.TableRef{
			TableID: int64(table.GetID()),
			As:      tree.AliasClause{Alias: "replication_target"},
		},
		Columns: names,
		Rows: &tree.Select{
			Select: &tree.ValuesClause{Rows: []tree.Exprs{values}},
		},
		Returning: tree.AbsentReturningClause,
	}, nil
}

// newMatchesLastRow creates a WHERE clause for matching all columns of a row.
// It returns a tree.Expr that compares each column to a placeholder parameter.
// Parameters are ordered by column ID, starting from startParamIdx.
func newMatchesLastRow(columns []catalog.Column, startParamIdx int) (tree.Expr, error) {
var whereClause tree.Expr
for i, col := range columns {
placeholder, err := tree.NewPlaceholder(fmt.Sprintf("%d", startParamIdx+i))
if err != nil {
return nil, err
}
colExpr := &tree.ComparisonExpr{
Operator: treecmp.MakeComparisonOperator(treecmp.EQ),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the one thing I noticed: for nullable columns, this will need to be treecmp.IsNotDistinctFrom rather than treecmp.EQ.

Left: &tree.ColumnItem{ColumnName: tree.Name(col.GetName())},
Right: placeholder,
}

if whereClause == nil {
whereClause = colExpr
} else {
whereClause = &tree.AndExpr{
Left: whereClause,
Right: colExpr,
}
}
}
return whereClause, nil
}

// newUpdateStatement returns a statement that can be used to update a row in
// the table. If a table has `n` columns, the statement will have `2n`
// parameters, where the first `n` parameters are the previous values of the row
// and the last `n` parameters are the new values of the row.
//
// Parameters are ordered by column ID.
func newUpdateStatement(table catalog.TableDescriptor) (tree.Statement, error) {
columns := getPhysicalColumns(table)

// Create WHERE clause for matching the previous row values
whereClause, err := newMatchesLastRow(columns, 1)
if err != nil {
return nil, err
}

exprs := make(tree.UpdateExprs, len(columns))
for i, col := range columns {
nameNode := tree.Name(col.GetName())
names := tree.NameList{nameNode}

// Create a placeholder for the new value (len(columns)+i+1) since we
// use 1-indexed placeholders and the first len(columns) placeholders
// are for the where clause.
placeholder, err := tree.NewPlaceholder(fmt.Sprintf("%d", len(columns)+i+1))
if err != nil {
return nil, err
}

exprs[i] = &tree.UpdateExpr{
Names: names,
Expr: &tree.CastExpr{
Expr: placeholder,
Type: col.GetType(),
SyntaxMode: tree.CastPrepend,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: When printing these statements, it's a little easier to read when using tree.CastShort.

random aside: I wonder if we'll want these casts in the where clause as well? In the past we've had some trouble with correct typing of placeholders, and have fixed customer problems by adding casts to placeholders wherever they are used. But for these simple statements where the datum type exactly matches the column type, we should be able to type the placeholders correctly... so this should be fine.

},
}
}

// Create the final update statement
update := &tree.Update{
Table: &tree.TableRef{
TableID: int64(table.GetID()),
As: tree.AliasClause{Alias: "replication_target"},
},
Exprs: exprs,
Where: &tree.Where{Type: tree.AstWhere, Expr: whereClause},
Returning: tree.AbsentReturningClause,
}

return update, nil
}

// newDeleteStatement builds a DELETE statement that removes a single row from
// table.
//
// The statement takes one parameter per column returned by
// getPhysicalColumns, numbered from 1 in that order. All of them appear in the
// WHERE clause (see newMatchesLastRow) so that the row is identified by its
// full set of values. The target table is referenced by descriptor ID and
// aliased as "replication_target".
func newDeleteStatement(table catalog.TableDescriptor) (tree.Statement, error) {
	matchRow, err := newMatchesLastRow(getPhysicalColumns(table), 1)
	if err != nil {
		return nil, err
	}

	target := &tree.TableRef{
		TableID: int64(table.GetID()),
		As:      tree.AliasClause{Alias: "replication_target"},
	}
	stmt := &tree.Delete{
		Table:     target,
		Where:     &tree.Where{Type: tree.AstWhere, Expr: matchRow},
		Returning: tree.AbsentReturningClause,
	}
	return stmt, nil
}
106 changes: 106 additions & 0 deletions pkg/crosscluster/logical/replication_statements_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2025 The Cockroach Authors.
//
// Use of this software is governed by the CockroachDB Software License
// included in the /LICENSE file.

package logical

import (
"context"
"fmt"
"math/rand"
"testing"

"github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/roachpb"
"github.com/cockroachdb/cockroach/pkg/sql/catalog"
"github.com/cockroachdb/cockroach/pkg/sql/catalog/desctestutils"
"github.com/cockroachdb/cockroach/pkg/testutils/datapathutils"
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/datadriven"
"github.com/stretchr/testify/require"
)

// TestReplicationStatements is a data-driven test for the statement builders
// in replication_statements.go. It starts a test server and walks the files
// under the package's testdata directory, supporting these commands:
//
//   - exec: runs the input SQL and returns "ok" or the error text.
//   - show-insert / show-update / show-delete table=<name>: builds the
//     corresponding statement for the named table in defaultdb.public, checks
//     that the server accepts it via PREPARE, and returns its string form.
func TestReplicationStatements(t *testing.T) {
	defer leaktest.AfterTest(t)()
	defer log.Scope(t).Close(t)

	ctx := context.Background()
	serverArgs := base.TestServerArgs{
		Locality: roachpb.Locality{
			Tiers: []roachpb.Tier{
				{Key: "region", Value: "us-east1"},
			},
		},
	}
	s, sqlDB, _ := serverutils.StartServer(t, serverArgs)
	defer s.Stopper().Stop(ctx)

	// scanTable reads the "table" argument of the directive and resolves it to
	// a descriptor.
	scanTable := func(t *testing.T, d *datadriven.TestData) catalog.TableDescriptor {
		var tableName string
		d.ScanArgs(t, "table", &tableName)
		return desctestutils.TestingGetTableDescriptor(
			s.DB(), s.Codec(), "defaultdb", "public", tableName,
		)
	}

	// prepareAndRender asserts that the statement was built without error,
	// prepares it under a random name to ensure it is valid SQL, and returns
	// its string form.
	prepareAndRender := func(t *testing.T, stmt fmt.Stringer, buildErr error) string {
		require.NoError(t, buildErr)
		_, err := sqlDB.Exec(fmt.Sprintf("PREPARE stmt_%d AS %s", rand.Int(), stmt.String()))
		require.NoError(t, err)
		return stmt.String()
	}

	datadriven.Walk(t, datapathutils.TestDataPath(t), func(t *testing.T, path string) {
		datadriven.RunTest(t, path, func(t *testing.T, d *datadriven.TestData) string {
			switch d.Cmd {
			case "exec":
				if _, err := sqlDB.Exec(d.Input); err != nil {
					return err.Error()
				}
				return "ok"
			case "show-insert":
				stmt, err := newInsertStatement(scanTable(t, d))
				return prepareAndRender(t, stmt, err)
			case "show-update":
				stmt, err := newUpdateStatement(scanTable(t, d))
				return prepareAndRender(t, stmt, err)
			case "show-delete":
				stmt, err := newDeleteStatement(scanTable(t, d))
				return prepareAndRender(t, stmt, err)
			default:
				return "unknown command: " + d.Cmd
			}
		})
	})
}
Loading